diff --git a/mex/build/sparse_hessian_times_B_kronecker_C.am b/mex/build/sparse_hessian_times_B_kronecker_C.am index ab8537b6173b1d96d8bc0f91bf78a1371fa55ff8..1b296e06ff5d0f9db43e37b3770acb344f57fb17 100644 --- a/mex/build/sparse_hessian_times_B_kronecker_C.am +++ b/mex/build/sparse_hessian_times_B_kronecker_C.am @@ -5,7 +5,7 @@ mex_PROGRAMS = sparse_hessian_times_B_kronecker_C -nodist_sparse_hessian_times_B_kronecker_C_SOURCES = sparse_hessian_times_B_kronecker_C.f08 matlab_mex.F08 +nodist_sparse_hessian_times_B_kronecker_C_SOURCES = sparse_hessian_times_B_kronecker_C.f08 matlab_mex.F08 blas_lapack.F08 AM_FCFLAGS += -fopenmp AM_LDFLAGS += $(OPENMP_LDFLAGS) @@ -13,7 +13,7 @@ AM_LDFLAGS += $(OPENMP_LDFLAGS) BUILT_SOURCES = $(nodist_sparse_hessian_times_B_kronecker_C_SOURCES) CLEANFILES = $(nodist_sparse_hessian_times_B_kronecker_C_SOURCES) -sparse_hessian_times_B_kronecker_C.o : matlab_mex.mod +sparse_hessian_times_B_kronecker_C.o : matlab_mex.mod blas.mod %.f08: $(top_srcdir)/../../sources/kronecker/%.f08 $(LN_S) -f $< $@ diff --git a/mex/sources/blas_lapack.F08 b/mex/sources/blas_lapack.F08 index 8df746a17734744deab6fc495044168483a9706f..8042df0a3032c29895dd3500d07aa85ef2348ed7 100644 --- a/mex/sources/blas_lapack.F08 +++ b/mex/sources/blas_lapack.F08 @@ -48,6 +48,16 @@ module blas real(real64), dimension(*), intent(inout) :: y end subroutine dgemv end interface + + interface + subroutine dger(m, n, alpha, x, incx, y, incy, a, lda) + import :: blint, real64 + integer(blint), intent(in) :: m, n, incx, incy, lda + real(real64), intent(in) :: alpha + real(real64), dimension(*), intent(in) :: x, y + real(real64), dimension(*), intent(inout) :: a + end subroutine dger + end interface end module blas module lapack diff --git a/mex/sources/kronecker/sparse_hessian_times_B_kronecker_C.f08 b/mex/sources/kronecker/sparse_hessian_times_B_kronecker_C.f08 index acd7532bee774b4ed445132f24d7eac0a1309a7e..6978470344f20c1325efc5e443ba768fda2c226f 100644 --- a/mex/sources/kronecker/sparse_hessian_times_B_kronecker_C.f08 +++ b/mex/sources/kronecker/sparse_hessian_times_B_kronecker_C.f08 @@ -24,6 +24,8 @@ subroutine mexFunction(nlhs, plhs, nrhs, prhs) bind(c, name='mexFunction') use iso_fortran_env use iso_c_binding use matlab_mex + use blas + use omp_lib implicit none type(c_ptr), dimension(*), intent(in), target :: prhs @@ -96,32 +98,43 @@ subroutine mexFunction(nlhs, plhs, nrhs, prhs) bind(c, name='mexFunction') end if contains subroutine sparse_hessian_times_B_kronecker_C - integer(c_size_t) :: ii,jj,jB,jC,iB,iC - integer(mwIndex) :: k1,k2,k,kk - real(real64) :: bc - - ! Loop over the columns of B⊗C (or of the result matrix D) - !$omp parallel do num_threads(numthreads) default(none) shared(nA,nB,nC,mC,A_jc,A_ir,A_v,B,C,D) & - !$omp private(ii,jj,jB,jC,iB,iC,k1,k2,k,kk,bc) - do jj = 1,nB*nC ! column of B⊗C index - jB = (jj-1)/nC+1 - jC = mod(jj-1, nC)+1 - ! Loop over the rows of B⊗C (column jj) - do ii=1,nA - k1 = A_jc(ii) - k2 = A_jc(ii+1) - if (k1 < k2) then ! Otherwise column ii of A does not have non zero elements (and there is nothing to compute) - iC = mod(ii-1, mC)+1 - iB = (ii-1)/mC+1 - bc = B(iB,jB)*C(iC,jC) - ! Loop over the non zero entries of A(:,ii) - do k=k1+1,k2 - kk = A_ir(k)+1 - D(kk,jj) = D(kk,jj) + A_v(k)*bc - end do - end if - end do + integer(c_size_t) :: ii,k,kk,k1,k2,iB,iC + real(real64), dimension(:,:), allocatable :: Dt ! Transpose of D + integer(omp_lock_kind), dimension(:), allocatable :: locks + + allocate(Dt(nB*nC, mA)) + Dt = 0._real64 + + allocate(locks(mA)) + do ii=1,mA + call omp_init_lock(locks(ii)) end do + + ! Loop over columns of A (and therefore rows of B⊗C) + ! Scheduling is made dynamic because the number of non-zero elements is not constant + + !$omp parallel do num_threads(numthreads) default(none) shared(nA,nB,nC,mC,A_jc,A_ir,A_v,B,C,Dt,locks) & + !$omp private(ii,k,kk,k1,k2,iB,iC) schedule(dynamic) + do ii=1,nA + k1 = A_jc(ii) + k2 = A_jc(ii+1) + if (k1 < k2) then ! Otherwise column ii of A does not have non zero elements (and there is nothing to compute) + iC = mod(ii-1, mC)+1 + iB = (ii-1)/mC+1 + ! Loop over the non-zero entries of A(:,ii) + do k=k1+1,k2 + kk = A_ir(k)+1 + ! D(kk,:) += A(kk,ii)·vec(C(iC,:)·B(iB,:)ᵀ) + ! NB: This call creates temporary contiguous copies of B(iB,:) and C(iC,:), hence incx=incy=1 + call omp_set_lock(locks(kk)) + call dger(int(nC, blint), int(nB, blint), A_v(k), C(iC,:), 1_blint, & + B(iB,:), 1_blint, Dt(:,kk), int(nC, blint)) + call omp_unset_lock(locks(kk)) + end do + end if + end do + + D = transpose(Dt) end subroutine sparse_hessian_times_B_kronecker_C subroutine sparse_hessian_times_B_kronecker_B