Verified Commit 82cef48e authored by Sébastien Villemot
Browse files

A_times_B_kronecker_C MEX: remove the OpenMP codepath

Testing shows that it is slower than the BLAS path.
parent c6a09a65
......@@ -69,7 +69,6 @@ options_.mode_check.nolik = false;
options_.huge_number = 1e7;
% Default number of threads for parallelized mex files.
options_.threads.kronecker.A_times_B_kronecker_C = 1;
options_.threads.kronecker.sparse_hessian_times_B_kronecker_C = 1;
options_.threads.local_state_space_iteration_2 = 1;
......
......@@ -183,7 +183,6 @@ else
print_info(info,options.noprint,options);
end
dr = dyn_second_order_solver(d1,d2,dr,M,...
options.threads.kronecker.A_times_B_kronecker_C,...
options.threads.kronecker.sparse_hessian_times_B_kronecker_C);
end
......@@ -274,7 +273,6 @@ if info
end
dr_np = dyn_second_order_solver(d1_np,d2_np,dr_np,pm.M_np,...
options.threads.kronecker.A_times_B_kronecker_C,...
options.threads.kronecker.sparse_hessian_times_B_kronecker_C);
end
......@@ -312,7 +310,6 @@ if isfield(options,'portfolio') && options.portfolio == 1
end
dr_np = dyn_second_order_solver(d1_np,d2_np,dr_np,pm.M_np,...
options.threads.kronecker.A_times_B_kronecker_C,...
options.threads.kronecker.sparse_hessian_times_B_kronecker_C);
end
......
function dr = dyn_second_order_solver(jacobia,hessian_mat,dr,M_,threads_ABC,threads_BC)
function dr = dyn_second_order_solver(jacobia,hessian_mat,dr,M_,threads_BC)
%@info:
%! @deftypefn {Function File} {@var{dr} =} dyn_second_order_solver (@var{jacobia},@var{hessian_mat},@var{dr},@var{M_},@var{threads_ABC},@var{threads_BC})
%! @deftypefn {Function File} {@var{dr} =} dyn_second_order_solver (@var{jacobia},@var{hessian_mat},@var{dr},@var{M_},@var{threads_BC})
%! @anchor{dyn_second_order_solver}
%! @sp 1
%! Computes the second order reduced form of the DSGE model
......@@ -17,8 +17,6 @@ function dr = dyn_second_order_solver(jacobia,hessian_mat,dr,M_,threads_ABC,thre
%! Matlab's structure describing the reduced form solution of the model.
%! @item M_
%! Matlab's structure describing the model (initialized by @code{dynare}).
%! @item threads_ABC
%! Integer controlling number of threads in A_times_B_kronecker_C
%! @item threads_BC
%! Integer controlling number of threads in sparse_hessian_times_B_kronecker_C
%! @end table
......@@ -125,7 +123,7 @@ hu1 = [hu;zeros(np-nspred,M_.exo_nbr)];
[nrhx,nchx] = size(Gy);
[nrhu1,nchu1] = size(hu1);
[abcOut,err] = A_times_B_kronecker_C(dr.ghxx,Gy,hu1,threads_ABC);
[abcOut,err] = A_times_B_kronecker_C(dr.ghxx,Gy,hu1);
mexErrCheck('A_times_B_kronecker_C', err);
B1 = B*abcOut;
rhs = -[rhs; zeros(n-M_.endo_nbr,size(rhs,2))]-B1;
......@@ -139,7 +137,7 @@ dr.ghxu = A\rhs;
[rhs, err] = sparse_hessian_times_B_kronecker_C(hessian_mat,zu,threads_BC);
mexErrCheck('sparse_hessian_times_B_kronecker_C', err);
[B1, err] = A_times_B_kronecker_C(B*dr.ghxx,hu1,threads_ABC);
[B1, err] = A_times_B_kronecker_C(B*dr.ghxx,hu1);
mexErrCheck('A_times_B_kronecker_C', err);
rhs = -[rhs; zeros(n-M_.endo_nbr,size(rhs,2))]-B1;
......
......@@ -54,20 +54,20 @@ ys = oo.dr.ys;
%second order terms
Uyy = full(Uyy);
[Uyygygy, err] = A_times_B_kronecker_C(Uyy,gy,gy,options.threads.kronecker.A_times_B_kronecker_C);
[Uyygygy, err] = A_times_B_kronecker_C(Uyy,gy,gy);
mexErrCheck('A_times_B_kronecker_C', err);
[Uyygugu, err] = A_times_B_kronecker_C(Uyy,gu,gu,options.threads.kronecker.A_times_B_kronecker_C);
[Uyygugu, err] = A_times_B_kronecker_C(Uyy,gu,gu);
mexErrCheck('A_times_B_kronecker_C', err);
[Uyygygu, err] = A_times_B_kronecker_C(Uyy,gy,gu,options.threads.kronecker.A_times_B_kronecker_C);
[Uyygygu, err] = A_times_B_kronecker_C(Uyy,gy,gu);
mexErrCheck('A_times_B_kronecker_C', err);
Wbar =U/(1-beta); %steady state welfare
Wy = Uy*gy/(eye(nspred)-beta*Gy);
Wu = Uy*gu+beta*Wy*Gu;
Wyy = Uyygygy/(eye(nspred*nspred)-beta*kron(Gy,Gy));
[Wyygugu, err] = A_times_B_kronecker_C(Wyy,Gu,Gu,options.threads.kronecker.A_times_B_kronecker_C);
[Wyygugu, err] = A_times_B_kronecker_C(Wyy,Gu,Gu);
mexErrCheck('A_times_B_kronecker_C', err);
[Wyygygu,err] = A_times_B_kronecker_C(Wyy,Gy,Gu,options.threads.kronecker.A_times_B_kronecker_C);
[Wyygygu,err] = A_times_B_kronecker_C(Wyy,Gy,Gu);
mexErrCheck('A_times_B_kronecker_C', err);
Wuu = Uyygugu+beta*Wyygugu;
Wyu = Uyygygu+beta*Wyygygu;
......@@ -90,19 +90,19 @@ end
yhat1 = yhat1(dr.order_var(nstatic+(1:nspred)),1)-dr.ys(dr.order_var(nstatic+(1:nspred)));
u = oo.exo_simul(1,:)';
[Wyyyhatyhat1, err] = A_times_B_kronecker_C(Wyy,yhat1,yhat1,options.threads.kronecker.A_times_B_kronecker_C);
[Wyyyhatyhat1, err] = A_times_B_kronecker_C(Wyy,yhat1,yhat1);
mexErrCheck('A_times_B_kronecker_C', err);
[Wuuuu, err] = A_times_B_kronecker_C(Wuu,u,u,options.threads.kronecker.A_times_B_kronecker_C);
[Wuuuu, err] = A_times_B_kronecker_C(Wuu,u,u);
mexErrCheck('A_times_B_kronecker_C', err);
[Wyuyhatu1, err] = A_times_B_kronecker_C(Wyu,yhat1,u,options.threads.kronecker.A_times_B_kronecker_C);
[Wyuyhatu1, err] = A_times_B_kronecker_C(Wyu,yhat1,u);
mexErrCheck('A_times_B_kronecker_C', err);
planner_objective_value(1) = Wbar+Wy*yhat1+Wu*u+Wyuyhatu1 ...
+ 0.5*(Wyyyhatyhat1 + Wuuuu+Wss);
if options.ramsey_policy
yhat2 = yhat2(dr.order_var(nstatic+(1:nspred)),1)-dr.ys(dr.order_var(nstatic+(1:nspred)));
[Wyyyhatyhat2, err] = A_times_B_kronecker_C(Wyy,yhat2,yhat2,options.threads.kronecker.A_times_B_kronecker_C);
[Wyyyhatyhat2, err] = A_times_B_kronecker_C(Wyy,yhat2,yhat2);
mexErrCheck('A_times_B_kronecker_C', err);
[Wyuyhatu2, err] = A_times_B_kronecker_C(Wyu,yhat2,u,options.threads.kronecker.A_times_B_kronecker_C);
[Wyuyhatu2, err] = A_times_B_kronecker_C(Wyu,yhat2,u);
mexErrCheck('A_times_B_kronecker_C', err);
planner_objective_value(2) = Wbar+Wy*yhat2+Wu*u+Wyuyhatu2 ...
+ 0.5*(Wyyyhatyhat2 + Wuuuu+Wss);
......
function [D, err] = A_times_B_kronecker_C(A,B,C,fake)
function [D, err] = A_times_B_kronecker_C(A,B,C)
%@info:
%! @deftypefn {Function File} {[@var{D}, @var{err}] =} A_times_B_kronecker_C (@var{A},@var{B},@var{C},@var{fake})
%! @deftypefn {Function File} {[@var{D}, @var{err}] =} A_times_B_kronecker_C (@var{A},@var{B},@var{C})
%! @anchor{kronecker/A_times_B_kronecker_C}
%! @sp 1
%! Computes A*kron(B,C).
......@@ -15,8 +15,6 @@ function [D, err] = A_times_B_kronecker_C(A,B,C,fake)
%! mB*nB matrix of doubles.
%! @item C
%! mC*nC matrix of doubles.
%! @item fake
%! Anything you want, just a fake parameter (because the mex version admits a last argument specifying the number of threads to be used in parallel mode).
%! @end table
%! @sp 2
%! @strong{Outputs}
......@@ -111,4 +109,4 @@ else
D = A * kron(B,B);
end
end
err = 0;
\ No newline at end of file
err = 0;
......@@ -79,12 +79,12 @@ function [y,y_] = local_state_space_iteration_2(yhat,epsilon,ghx,ghu,constant,gh
% frederic DOT karame AT univ DASH evry DOT fr
if nargin==9
pruning = 0; numthreads = a;
pruning = 0;
if nargout>1
error('local_state_space_iteration_2:: Numbers of input and output argument are inconsistent!')
end
elseif nargin==11
pruning = 1; numthreads = c; yhat_ = a; ss = b;
pruning = 1; yhat_ = a; ss = b;
if nargout~=2
error('local_state_space_iteration_2:: Numbers of input and output argument are inconsistent!')
end
......@@ -92,22 +92,20 @@ else
error('local_state_space_iteration_2:: Wrong number of input arguments!')
end
number_of_threads = numthreads;
switch pruning
case 0
for i =1:size(yhat,2)
y(:,i) = constant + ghx*yhat(:,i) + ghu*epsilon(:,i) ...
+ A_times_B_kronecker_C(.5*ghxx,yhat(:,i),number_of_threads) ...
+ A_times_B_kronecker_C(.5*ghuu,epsilon(:,i),number_of_threads) ...
+ A_times_B_kronecker_C(ghxu,yhat(:,i),epsilon(:,i),number_of_threads);
+ A_times_B_kronecker_C(.5*ghxx,yhat(:,i)) ...
+ A_times_B_kronecker_C(.5*ghuu,epsilon(:,i)) ...
+ A_times_B_kronecker_C(ghxu,yhat(:,i),epsilon(:,i));
end
case 1
for i =1:size(yhat,2)
y(:,i) = constant + ghx*yhat(:,i) + ghu*epsilon(:,i) ...
+ A_times_B_kronecker_C(.5*ghxx,yhat_(:,i),number_of_threads) ...
+ A_times_B_kronecker_C(.5*ghuu,epsilon(:,i),number_of_threads) ...
+ A_times_B_kronecker_C(ghxu,yhat_(:,i),epsilon(:,i),number_of_threads);
+ A_times_B_kronecker_C(.5*ghxx,yhat_(:,i)) ...
+ A_times_B_kronecker_C(.5*ghuu,epsilon(:,i)) ...
+ A_times_B_kronecker_C(ghxu,yhat_(:,i),epsilon(:,i));
end
y_ = ghx*yhat_+ghu*epsilon;
y_ = bsxfun(@plus,y_,ss);
......
......@@ -37,8 +37,6 @@ if ~isint(n)
end
switch mexname
case 'A_times_B_kronecker_C'
options_.threads.kronecker.A_times_B_kronecker_C = n;
case 'sparse_hessian_times_B_kronecker_C'
options_.threads.kronecker.sparse_hessian_times_B_kronecker_C = n;
case 'local_state_space_iteration_2'
......@@ -47,4 +45,4 @@ switch mexname
message = [ mexname ' is not a known parallel mex file.' ];
message_id = 'Dynare:Threads:UnknownParallelMex';
warning(message_id,message);
end
\ No newline at end of file
end
......@@ -98,11 +98,11 @@ else
yhat1 = y__(order_var(k2))-dr.ys(order_var(k2));
yhat2 = y_(order_var(k2),i-1)-dr.ys(order_var(k2));
epsilon = ex_(i-1,:)';
[abcOut1, err] = A_times_B_kronecker_C(.5*dr.ghxx,yhat1,options_.threads.kronecker.A_times_B_kronecker_C);
[abcOut1, err] = A_times_B_kronecker_C(.5*dr.ghxx,yhat1);
mexErrCheck('A_times_B_kronecker_C', err);
[abcOut2, err] = A_times_B_kronecker_C(.5*dr.ghuu,epsilon,options_.threads.kronecker.A_times_B_kronecker_C);
[abcOut2, err] = A_times_B_kronecker_C(.5*dr.ghuu,epsilon);
mexErrCheck('A_times_B_kronecker_C', err);
[abcOut3, err] = A_times_B_kronecker_C(dr.ghxu,yhat1,epsilon,options_.threads.kronecker.A_times_B_kronecker_C);
[abcOut3, err] = A_times_B_kronecker_C(dr.ghxu,yhat1,epsilon);
mexErrCheck('A_times_B_kronecker_C', err);
y_(order_var,i) = constant + dr.ghx*yhat2 + dr.ghu*epsilon ...
+ abcOut1 + abcOut2 + abcOut3;
......@@ -112,11 +112,11 @@ else
for i = 2:iter+M_.maximum_lag
yhat = y_(order_var(k2),i-1)-dr.ys(order_var(k2));
epsilon = ex_(i-1,:)';
[abcOut1, err] = A_times_B_kronecker_C(.5*dr.ghxx,yhat,options_.threads.kronecker.A_times_B_kronecker_C);
[abcOut1, err] = A_times_B_kronecker_C(.5*dr.ghxx,yhat);
mexErrCheck('A_times_B_kronecker_C', err);
[abcOut2, err] = A_times_B_kronecker_C(.5*dr.ghuu,epsilon,options_.threads.kronecker.A_times_B_kronecker_C);
[abcOut2, err] = A_times_B_kronecker_C(.5*dr.ghuu,epsilon);
mexErrCheck('A_times_B_kronecker_C', err);
[abcOut3, err] = A_times_B_kronecker_C(dr.ghxu,yhat,epsilon,options_.threads.kronecker.A_times_B_kronecker_C);
[abcOut3, err] = A_times_B_kronecker_C(dr.ghxu,yhat,epsilon);
mexErrCheck('A_times_B_kronecker_C', err);
y_(dr.order_var,i) = constant + dr.ghx*yhat + dr.ghu*epsilon ...
+ abcOut1 + abcOut2 + abcOut3;
......@@ -138,7 +138,6 @@ else
ghuuu = dr.ghuuu;
ghxss = dr.ghxss;
ghuss = dr.ghuss;
threads = options_.threads.kronecker.A_times_B_kronecker_C;
nspred = M_.nspred;
ipred = M_.nstatic+(1:nspred);
%construction follows Andreasen et al (2013), Technical
......@@ -151,29 +150,29 @@ else
u = ex_(i-1,:)';
%construct terms of order 2 from second order part, based
%on linear component yhat1
[gyy, err] = A_times_B_kronecker_C(ghxx,yhat1,threads);
[gyy, err] = A_times_B_kronecker_C(ghxx,yhat1);
mexErrCheck('A_times_B_kronecker_C', err);
[guu, err] = A_times_B_kronecker_C(ghuu,u,threads);
[guu, err] = A_times_B_kronecker_C(ghuu,u);
mexErrCheck('A_times_B_kronecker_C', err);
[gyu, err] = A_times_B_kronecker_C(ghxu,yhat1,u,threads);
[gyu, err] = A_times_B_kronecker_C(ghxu,yhat1,u);
mexErrCheck('A_times_B_kronecker_C', err);
%construct terms of order 3 from second order part, based
%on order 2 component yhat2
[gyy12, err] = A_times_B_kronecker_C(ghxx,yhat1,yhat2,threads);
[gyy12, err] = A_times_B_kronecker_C(ghxx,yhat1,yhat2);
mexErrCheck('A_times_B_kronecker_C', err);
[gy2u, err] = A_times_B_kronecker_C(ghxu,yhat2,u,threads);
[gy2u, err] = A_times_B_kronecker_C(ghxu,yhat2,u);
mexErrCheck('A_times_B_kronecker_C', err);
%construct terms of order 3, all based on first order component yhat1
y2a = kron(yhat1,yhat1);
[gyyy, err] = A_times_B_kronecker_C(ghxxx,y2a,yhat1,threads);
[gyyy, err] = A_times_B_kronecker_C(ghxxx,y2a,yhat1);
mexErrCheck('A_times_B_kronecker_C', err);
u2a = kron(u,u);
[guuu, err] = A_times_B_kronecker_C(ghuuu,u2a,u,threads);
[guuu, err] = A_times_B_kronecker_C(ghuuu,u2a,u);
mexErrCheck('A_times_B_kronecker_C', err);
yu = kron(yhat1,u);
[gyyu, err] = A_times_B_kronecker_C(ghxxu,yhat1,yu,threads);
[gyyu, err] = A_times_B_kronecker_C(ghxxu,yhat1,yu);
mexErrCheck('A_times_B_kronecker_C', err);
[gyuu, err] = A_times_B_kronecker_C(ghxuu,yu,u,threads);
[gyuu, err] = A_times_B_kronecker_C(ghxuu,yu,u);
mexErrCheck('A_times_B_kronecker_C', err);
%add all terms of order 3, linear component based on third
%order yhat3
......
......@@ -268,7 +268,6 @@ else
if local_order > 1
% Second order
dr = dyn_second_order_solver(jacobia_,hessian1,dr,M_,...
options_.threads.kronecker.A_times_B_kronecker_C,...
options_.threads.kronecker.sparse_hessian_times_B_kronecker_C);
% reordering second order derivatives, used for deterministic
......
......@@ -25,37 +25,10 @@
#include <dynmex.h>
#include <dynblas.h>
#ifdef USE_OMP
# include <omp.h>
#endif
#define DEBUG_OMP 0
void
full_A_times_kronecker_B_C(const double *A, const double *B, const double *C, double *D,
blas_int mA, blas_int nA, blas_int mB, blas_int nB, blas_int mC, blas_int nC, int number_of_threads)
blas_int mA, blas_int nA, blas_int mB, blas_int nB, blas_int mC, blas_int nC)
{
#ifdef USE_OMP
# pragma omp parallel for num_threads(number_of_threads)
for (blas_int colD = 0; colD < nB*nC; colD++)
{
# if DEBUG_OMP
mexPrintf("%d thread number is %d (%d).\n", colD, omp_get_thread_num(), omp_get_num_threads());
# endif
blas_int colB = colD/nC;
blas_int colC = colD%nC;
for (blas_int colA = 0; colA < nA; colA++)
{
blas_int rowB = colA/mC;
blas_int rowC = colA%mC;
blas_int idxA = colA*mA;
blas_int idxD = colD*mA;
double BC = B[colB*mB+rowB]*C[colC*mC+rowC];
for (blas_int rowD = 0; rowD < mA; rowD++)
D[idxD+rowD] += A[idxA+rowD]*BC;
}
}
#else
const blas_int shiftA = mA*mC;
const blas_int shiftD = mA*nC;
blas_int kd = 0, ka = 0;
......@@ -70,33 +43,11 @@ full_A_times_kronecker_B_C(const double *A, const double *B, const double *C, do
}
kd += shiftD;
}
#endif
}
void
full_A_times_kronecker_B_B(const double *A, const double *B, double *D, blas_int mA, blas_int nA, blas_int mB, blas_int nB, int number_of_threads)
full_A_times_kronecker_B_B(const double *A, const double *B, double *D, blas_int mA, blas_int nA, blas_int mB, blas_int nB)
{
#ifdef USE_OMP
# pragma omp parallel for num_threads(number_of_threads)
for (blas_int colD = 0; colD < nB*nB; colD++)
{
# if DEBUG_OMP
mexPrintf("%d thread number is %d (%d).\n", colD, omp_get_thread_num(), omp_get_num_threads());
# endif
blas_int j1B = colD/nB;
blas_int j2B = colD%nB;
for (blas_int colA = 0; colA < nA; colA++)
{
blas_int i1B = colA/mB;
blas_int i2B = colA%mB;
blas_int idxA = colA*mA;
blas_int idxD = colD*mA;
double BB = B[j1B*mB+i1B]*B[j2B*mB+i2B];
for (blas_int rowD = 0; rowD < mA; rowD++)
D[idxD+rowD] += A[idxA+rowD]*BB;
}
}
#else
const blas_int shiftA = mA*mB;
const blas_int shiftD = mA*nB;
blas_int kd = 0, ka = 0;
......@@ -111,15 +62,14 @@ full_A_times_kronecker_B_B(const double *A, const double *B, double *D, blas_int
}
kd += shiftD;
}
#endif
}
void
mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
// Check input and output:
if (nrhs > 4 || nrhs < 3)
DYN_MEX_FUNC_ERR_MSG_TXT("A_times_B_kronecker_C takes 3 or 4 input arguments and provides 2 output arguments.");
if (nrhs > 3 || nrhs < 2)
DYN_MEX_FUNC_ERR_MSG_TXT("A_times_B_kronecker_C takes 2 or 3 input arguments and provides 2 output arguments.");
// Get & Check dimensions (columns and rows):
size_t mA = mxGetM(prhs[0]);
......@@ -127,7 +77,7 @@ mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
size_t mB = mxGetM(prhs[1]);
size_t nB = mxGetN(prhs[1]);
size_t mC, nC;
if (nrhs == 4) // A·(B⊗C) is to be computed.
if (nrhs == 3) // A·(B⊗C) is to be computed.
{
mC = mxGetM(prhs[2]);
nC = mxGetN(prhs[2]);
......@@ -140,30 +90,24 @@ mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
DYN_MEX_FUNC_ERR_MSG_TXT("Input dimension error!");
}
// Get input matrices:
int numthreads;
const double *A = mxGetPr(prhs[0]);
const double *B = mxGetPr(prhs[1]);
const double *C;
if (nrhs == 4)
{
C = mxGetPr(prhs[2]);
numthreads = static_cast<int>(mxGetScalar(prhs[3]));
}
else
numthreads = static_cast<int>(mxGetScalar(prhs[2]));
const double *C{nullptr};
if (nrhs == 3)
C = mxGetPr(prhs[2]);
// Initialization of the output:
if (nrhs == 4)
if (nrhs == 3)
plhs[0] = mxCreateDoubleMatrix(mA, nB*nC, mxREAL);
else
plhs[0] = mxCreateDoubleMatrix(mA, nB*nB, mxREAL);
double *D = mxGetPr(plhs[0]);
// Computational part:
if (nrhs == 3)
full_A_times_kronecker_B_B(A, B, D, mA, nA, mB, nB, numthreads);
if (nrhs == 2)
full_A_times_kronecker_B_B(A, B, D, mA, nA, mB, nB);
else
full_A_times_kronecker_B_C(A, B, C, D, mA, nA, mB, nB, mC, nC, numthreads);
full_A_times_kronecker_B_C(A, B, C, D, mA, nA, mB, nB, mC, nC);
plhs[1] = mxCreateDoubleScalar(0);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment