Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Johannes Pfeifer
dynare
Commits
3bd3c78e
Unverified
Commit
3bd3c78e
authored
Jun 04, 2021
by
Sébastien Villemot
Browse files
A_times_B_kronecker_C MEX: rewrite in Fortran
parent
c4ca0ef0
Changes
3
Hide whitespace changes
Inline
Side-by-side
mex/build/kronecker.am
View file @
3bd3c78e
mex_PROGRAMS = sparse_hessian_times_B_kronecker_C A_times_B_kronecker_C
nodist_sparse_hessian_times_B_kronecker_C_SOURCES = sparse_hessian_times_B_kronecker_C.cc
nodist_A_times_B_kronecker_C_SOURCES = A_times_B_kronecker_C.
cc
nodist_A_times_B_kronecker_C_SOURCES = A_times_B_kronecker_C.
f08 matlab_mex.F08 blas_lapack.F08
sparse_hessian_times_B_kronecker_C_CXXFLAGS = $(AM_CXXFLAGS) -fopenmp
sparse_hessian_times_B_kronecker_C_LDFLAGS = $(AM_LDFLAGS) $(OPENMP_LDFLAGS)
...
...
@@ -11,3 +11,8 @@ CLEANFILES = $(nodist_sparse_hessian_times_B_kronecker_C_SOURCES) $(nodist_A_tim
%.cc: $(top_srcdir)/../../sources/kronecker/%.cc
$(LN_S) -f $< $@
A_times_B_kronecker_C.o : matlab_mex.mod lapack.mod
%.f08: $(top_srcdir)/../../sources/kronecker/%.f08
$(LN_S) -f $< $@
mex/sources/kronecker/A_times_B_kronecker_C.cc
deleted
100644 → 0
View file @
c4ca0ef0
/*
* Copyright © 2007-2020 Dynare Team
*
* This file is part of Dynare.
*
* Dynare is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dynare is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* This mex file computes A·(B⊗C) or A·(B⊗B) without explicitly building B⊗C or B⊗B, so that
* one can consider large matrices B and/or C.
*/
#include
<dynmex.h>
#include
<dynblas.h>
void
full_A_times_kronecker_B_C
(
const
double
*
A
,
const
double
*
B
,
const
double
*
C
,
double
*
D
,
blas_int
mA
,
blas_int
nA
,
blas_int
mB
,
blas_int
nB
,
blas_int
mC
,
blas_int
nC
)
{
const
blas_int
shiftA
=
mA
*
mC
;
const
blas_int
shiftD
=
mA
*
nC
;
blas_int
kd
=
0
,
ka
=
0
;
double
one
=
1.0
;
for
(
blas_int
col
=
0
;
col
<
nB
;
col
++
)
{
ka
=
0
;
for
(
blas_int
row
=
0
;
row
<
mB
;
row
++
)
{
dgemm
(
"N"
,
"N"
,
&
mA
,
&
nC
,
&
mC
,
&
B
[
mB
*
col
+
row
],
&
A
[
ka
],
&
mA
,
C
,
&
mC
,
&
one
,
&
D
[
kd
],
&
mA
);
ka
+=
shiftA
;
}
kd
+=
shiftD
;
}
}
void
full_A_times_kronecker_B_B
(
const
double
*
A
,
const
double
*
B
,
double
*
D
,
blas_int
mA
,
blas_int
nA
,
blas_int
mB
,
blas_int
nB
)
{
const
blas_int
shiftA
=
mA
*
mB
;
const
blas_int
shiftD
=
mA
*
nB
;
blas_int
kd
=
0
,
ka
=
0
;
double
one
=
1.0
;
for
(
blas_int
col
=
0
;
col
<
nB
;
col
++
)
{
ka
=
0
;
for
(
blas_int
row
=
0
;
row
<
mB
;
row
++
)
{
dgemm
(
"N"
,
"N"
,
&
mA
,
&
nB
,
&
mB
,
&
B
[
mB
*
col
+
row
],
&
A
[
ka
],
&
mA
,
B
,
&
mB
,
&
one
,
&
D
[
kd
],
&
mA
);
ka
+=
shiftA
;
}
kd
+=
shiftD
;
}
}
void
mexFunction
(
int
nlhs
,
mxArray
*
plhs
[],
int
nrhs
,
const
mxArray
*
prhs
[])
{
// Check input and output:
if
(
nrhs
>
3
||
nrhs
<
2
||
nlhs
!=
1
)
{
mexErrMsgTxt
(
"A_times_B_kronecker_C takes 2 or 3 input arguments and provides 1 output argument."
);
return
;
// Needed to shut up some GCC warnings
}
// Get & Check dimensions (columns and rows):
size_t
mA
=
mxGetM
(
prhs
[
0
]);
size_t
nA
=
mxGetN
(
prhs
[
0
]);
size_t
mB
=
mxGetM
(
prhs
[
1
]);
size_t
nB
=
mxGetN
(
prhs
[
1
]);
size_t
mC
,
nC
;
if
(
nrhs
==
3
)
// A·(B⊗C) is to be computed.
{
mC
=
mxGetM
(
prhs
[
2
]);
nC
=
mxGetN
(
prhs
[
2
]);
if
(
mB
*
mC
!=
nA
)
mexErrMsgTxt
(
"Input dimension error!"
);
}
else
// A·(B⊗B) is to be computed.
{
if
(
mB
*
mB
!=
nA
)
mexErrMsgTxt
(
"Input dimension error!"
);
}
// Get input matrices:
const
double
*
A
=
mxGetPr
(
prhs
[
0
]);
const
double
*
B
=
mxGetPr
(
prhs
[
1
]);
const
double
*
C
{
nullptr
};
if
(
nrhs
==
3
)
C
=
mxGetPr
(
prhs
[
2
]);
// Initialization of the ouput:
if
(
nrhs
==
3
)
plhs
[
0
]
=
mxCreateDoubleMatrix
(
mA
,
nB
*
nC
,
mxREAL
);
else
plhs
[
0
]
=
mxCreateDoubleMatrix
(
mA
,
nB
*
nB
,
mxREAL
);
double
*
D
=
mxGetPr
(
plhs
[
0
]);
// Computational part:
if
(
nrhs
==
2
)
full_A_times_kronecker_B_B
(
A
,
B
,
D
,
mA
,
nA
,
mB
,
nB
);
else
full_A_times_kronecker_B_C
(
A
,
B
,
C
,
D
,
mA
,
nA
,
mB
,
nB
,
mC
,
nC
);
}
mex/sources/kronecker/A_times_B_kronecker_C.f08
0 → 100644
View file @
3bd3c78e
! This MEX file computes A·(B⊗C) or A·(B⊗B) without explicitly building B⊗C or
! B⊗B, so that one can consider large matrices B and/or C.
! Copyright © 2007-2021 Dynare Team
!
! This file is part of Dynare.
!
! Dynare is free software: you can redistribute it and/or modify
! it under the terms of the GNU General Public License as published by
! the Free Software Foundation, either version 3 of the License, or
! (at your option) any later version.
!
! Dynare is distributed in the hope that it will be useful,
! but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
! GNU General Public License for more details.
!
! You should have received a copy of the GNU General Public License
! along with Dynare. If not, see <http://www.gnu.org/licenses/>.
subroutine
mexFunction
(
nlhs
,
plhs
,
nrhs
,
prhs
)
bind
(
c
,
name
=
'mexFunction'
)
use
iso_fortran_env
,
only
:
real64
use
iso_c_binding
,
only
:
c_int
use
matlab_mex
use
blas
implicit
none
type
(
c_ptr
),
dimension
(
*
),
intent
(
in
),
target
::
prhs
type
(
c_ptr
),
dimension
(
*
),
intent
(
out
)
::
plhs
integer
(
c_int
),
intent
(
in
),
value
::
nlhs
,
nrhs
integer
(
c_size_t
)
::
mA
,
nA
,
mB
,
nB
,
mC
,
nC
real
(
real64
),
dimension
(:,
:),
pointer
,
contiguous
::
A
,
B
,
C
,
D
if
(
nrhs
>
3
.or.
nrhs
<
2
.or.
nlhs
/
=
1
)
then
call
mexErrMsgTxt
(
"A_times_B_kronecker_C takes 2 or 3 input arguments and provides 1 output argument"
)
end
if
if
(
.not.
mxIsDouble
(
prhs
(
1
))
.or.
mxIsComplex
(
prhs
(
1
))
&
.or.
.not.
mxIsDouble
(
prhs
(
2
))
.or.
mxIsComplex
(
prhs
(
2
)))
then
call
mexErrMsgTxt
(
"A_times_B_kronecker_C: first two arguments should be real matrices"
)
end
if
mA
=
mxGetM
(
prhs
(
1
))
nA
=
mxGetN
(
prhs
(
1
))
mB
=
mxGetM
(
prhs
(
2
))
nB
=
mxGetN
(
prhs
(
2
))
A
(
1
:
mA
,
1
:
nA
)
=>
mxGetPr
(
prhs
(
1
))
B
(
1
:
mB
,
1
:
nB
)
=>
mxGetPr
(
prhs
(
2
))
if
(
nrhs
==
3
)
then
! A·(B⊗C) is to be computed.
if
(
.not.
mxIsDouble
(
prhs
(
3
))
.or.
mxIsComplex
(
prhs
(
3
)))
then
call
mexErrMsgTxt
(
"A_times_B_kronecker_C: third argument should be a real matrix"
)
end
if
mC
=
mxGetM
(
prhs
(
3
))
nC
=
mxGetN
(
prhs
(
3
))
if
(
mB
*
mC
/
=
nA
)
then
call
mexErrMsgTxt
(
"Input dimension error!"
)
end
if
C
(
1
:
mC
,
1
:
nC
)
=>
mxGetPr
(
prhs
(
3
))
plhs
(
1
)
=
mxCreateDoubleMatrix
(
mA
,
nB
*
nC
,
mxREAL
)
D
(
1
:
mA
,
1
:
nB
*
nC
)
=>
mxGetPr
(
plhs
(
1
))
call
full_A_times_kronecker_B_C
else
! A·(B⊗B) is to be computed.
if
(
mB
*
mB
/
=
nA
)
then
call
mexErrMsgTxt
(
"Input dimension error!"
)
end
if
plhs
(
1
)
=
mxCreateDoubleMatrix
(
mA
,
nB
*
nB
,
mxREAL
)
D
(
1
:
mA
,
1
:
nB
*
nB
)
=>
mxGetPr
(
plhs
(
1
))
call
full_A_times_kronecker_B_B
end
if
contains
! Computes D=A·(B⊗C)
subroutine
full_A_times_kronecker_B_C
integer
(
c_size_t
)
::
i
,
j
,
ka
,
kd
kd
=
1
do
j
=
1
,
nB
ka
=
1
do
i
=
1
,
mB
! D(:,kd:kd+nC) += B(i,j)·A(:,ka:ka+mC)·C
call
dgemm
(
"N"
,
"N"
,
int
(
mA
,
blint
),
int
(
nC
,
blint
),
int
(
mC
,
blint
),
B
(
i
,
j
),
&
A
(:,
ka
:
ka
+
mC
),
int
(
mA
,
blint
),
C
,
int
(
mC
,
blint
),
1._real64
,
&
D
(:,
kd
:
kd
+
nC
),
int
(
mA
,
blint
))
ka
=
ka
+
mC
end
do
kd
=
kd
+
nC
end
do
end
subroutine
full_A_times_kronecker_B_C
! Computes D=A·(B⊗B)
subroutine
full_A_times_kronecker_B_B
integer
(
c_size_t
)
::
i
,
j
,
ka
,
kd
kd
=
1
do
j
=
1
,
nB
ka
=
1
do
i
=
1
,
mB
! D(:,kd:kd+nB) += B(i,j)·A(:,ka:ka+mB)·B
call
dgemm
(
"N"
,
"N"
,
int
(
mA
,
blint
),
int
(
nB
,
blint
),
int
(
mB
,
blint
),
B
(
i
,
j
),
&
A
(:,
ka
:
ka
+
mB
),
int
(
mA
,
blint
),
B
,
int
(
mB
,
blint
),
1._real64
,
&
D
(:,
kd
:
kd
+
nB
),
int
(
mA
,
blint
))
ka
=
ka
+
mB
end
do
kd
=
kd
+
nB
end
do
end
subroutine
full_A_times_kronecker_B_B
end
subroutine
mexFunction
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment