From 49fcdd9e80c4d5d88daad531bae0bed4a69a53e3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Mon, 28 Feb 2022 12:18:58 +0100
Subject: [PATCH] Bytecode: fix bug in sparse matrix multiplication routines
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Since those routines compute Aᵀ·B, the number of rows of the output is equal to
the number of columns of A.

(cherry picked from commit a7cc4662852061d5b35bc62309ec5318e30e0a2f)
---
 mex/sources/bytecode/SparseMatrix.cc | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/mex/sources/bytecode/SparseMatrix.cc b/mex/sources/bytecode/SparseMatrix.cc
index 9fc01320b0..1c8da9fab4 100644
--- a/mex/sources/bytecode/SparseMatrix.cc
+++ b/mex/sources/bytecode/SparseMatrix.cc
@@ -1796,13 +1796,12 @@ mxArray *
 dynSparseMatrix::mult_SAT_B(const mxArray *A_m, const mxArray *B_m)
 {
   size_t n_A = mxGetN(A_m);
-  size_t m_A = mxGetM(A_m);
   mwIndex *A_i = mxGetIr(A_m);
   mwIndex *A_j = mxGetJc(A_m);
   double *A_d = mxGetPr(A_m);
   size_t n_B = mxGetN(B_m);
   double *B_d = mxGetPr(B_m);
-  mxArray *C_m = mxCreateDoubleMatrix(m_A, n_B, mxREAL);
+  mxArray *C_m = mxCreateDoubleMatrix(n_A, n_B, mxREAL);
   double *C_d = mxGetPr(C_m);
   for (int j = 0; j < static_cast<int>(n_B); j++)
     for (unsigned int i = 0; i < n_A; i++)
@@ -1823,14 +1822,13 @@ mxArray *
 dynSparseMatrix::Sparse_mult_SAT_B(const mxArray *A_m, const mxArray *B_m)
 {
   size_t n_A = mxGetN(A_m);
-  size_t m_A = mxGetM(A_m);
   mwIndex *A_i = mxGetIr(A_m);
   mwIndex *A_j = mxGetJc(A_m);
   double *A_d = mxGetPr(A_m);
   size_t n_B = mxGetN(B_m);
   size_t m_B = mxGetM(B_m);
   double *B_d = mxGetPr(B_m);
-  mxArray *C_m = mxCreateSparse(m_A, n_B, m_A*n_B, mxREAL);
+  mxArray *C_m = mxCreateSparse(n_A, n_B, n_A*n_B, mxREAL);
   mwIndex *C_i = mxGetIr(C_m);
   mwIndex *C_j = mxGetJc(C_m);
   double *C_d = mxGetPr(C_m);
@@ -1868,7 +1866,6 @@ mxArray *
 dynSparseMatrix::Sparse_mult_SAT_SB(const mxArray *A_m, const mxArray *B_m)
 {
   size_t n_A = mxGetN(A_m);
-  size_t m_A = mxGetM(A_m);
   mwIndex *A_i = mxGetIr(A_m);
   mwIndex *A_j = mxGetJc(A_m);
   double *A_d = mxGetPr(A_m);
@@ -1876,7 +1873,7 @@ dynSparseMatrix::Sparse_mult_SAT_SB(const mxArray *A_m, const mxArray *B_m)
   mwIndex *B_i = mxGetIr(B_m);
   mwIndex *B_j = mxGetJc(B_m);
   double *B_d = mxGetPr(B_m);
-  mxArray *C_m = mxCreateSparse(m_A, n_B, m_A*n_B, mxREAL);
+  mxArray *C_m = mxCreateSparse(n_A, n_B, n_A*n_B, mxREAL);
   mwIndex *C_i = mxGetIr(C_m);
   mwIndex *C_j = mxGetJc(C_m);
   double *C_d = mxGetPr(C_m);
-- 
GitLab