From 7ba77751c375f415b8a5b9329b74dea1bc46fc4a Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Tue, 5 May 2020 14:55:09 -0400
Subject: [PATCH] simplify external function C output

---
 src/ExprNode.cc | 201 ++++++++++++++++++------------------------------
 src/ExprNode.hh |   2 +-
 2 files changed, 77 insertions(+), 126 deletions(-)

diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 156f201d..10e9eaa3 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -7051,13 +7051,12 @@ void
 AbstractExternalFunctionNode::writePrhs(ostream &output, ExprNodeOutputType output_type,
                                         const temporary_terms_t &temporary_terms,
                                         const temporary_terms_idxs_t &temporary_terms_idxs,
-                                        const deriv_node_temp_terms_t &tef_terms, const string &ending) const
+                                        const deriv_node_temp_terms_t &tef_terms) const
 {
-  output << "mxArray *prhs"<< ending << "[nrhs"<< ending << "];" << endl;
   int i = 0;
   for (auto argument : arguments)
     {
-      output << "prhs" << ending << "[" << i++ << "] = mxCreateDoubleScalar("; // All external_function arguments are scalars
+      output << "  prhs[" << i++ << "] = mxCreateDoubleScalar("; // All external_function arguments are scalars
       argument->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
       output << ");" << endl;
     }
@@ -7283,44 +7282,36 @@ ExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutpu
 
       if (isCOutput(output_type))
         {
-          stringstream ending;
-          ending << "_tef_" << getIndxInTefTerms(symb_id, tef_terms);
-          if (symb_id == first_deriv_symb_id
-              && symb_id == second_deriv_symb_id)
-            output << "int nlhs" << ending.str() << " = 3;" << endl
-                   << "double *TEF_" << indx << ", "
-                   << "*TEFD_" << indx << ", "
-                   << "*TEFDD_" << indx << ";" << endl;
-          else if (symb_id == first_deriv_symb_id)
-            output << "int nlhs" << ending.str() << " = 2;" << endl
-                   << "double *TEF_" << indx << ", "
-                   << "*TEFD_" << indx << "; " << endl;
-          else
-            output << "int nlhs" << ending.str() << " = 1;" << endl
-                   << "double *TEF_" << indx << ";" << endl;
-
-          output << "mxArray *plhs" << ending.str()<< "[nlhs"<< ending.str() << "];" << endl;
-          output << "int nrhs" << ending.str()<< " = " << arguments.size() << ";" << endl;
-          writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms, ending.str());
-
-          output << "mexCallMATLAB("
-                 << "nlhs" << ending.str() << ", "
-                 << "plhs" << ending.str() << ", "
-                 << "nrhs" << ending.str() << ", "
-                 << "prhs" << ending.str() << R"(, ")"
+          output << "double *TEF_" << indx;
+          if (symb_id == first_deriv_symb_id)
+            output << ", *TEFD_" << indx;
+          if (symb_id == second_deriv_symb_id)
+            output << ", *TEFDD_" << indx;
+          output << ";" << endl;
+
+          if (symb_id == first_deriv_symb_id && symb_id == second_deriv_symb_id)
+            output << "int TEFDD_" << indx << "_nrows;" << endl;
+
+          int nlhs =
+            symb_id == first_deriv_symb_id && symb_id == second_deriv_symb_id ? 3
+            : symb_id == first_deriv_symb_id ? 2 : 1;
+          output << "{" << endl
+                 << "  mxArray *plhs[" << nlhs << "], *prhs[" << arguments.size() << "];" << endl;
+
+          writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
+
+          output << "  mexCallMATLAB(" << nlhs << ", plhs, " << arguments.size() << ", prhs, " << R"(")"
                  << datatree.symbol_table.getName(symb_id) << R"(");)" << endl;
 
-          if (symb_id == first_deriv_symb_id
-              && symb_id == second_deriv_symb_id)
-            output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl
-                   << "TEFD_" << indx << " = mxGetPr(plhs" << ending.str() << "[1]);" << endl
-                   << "TEFDD_" << indx << " = mxGetPr(plhs" << ending.str() << "[2]);" << endl
-                   << "int TEFDD_" << indx << "_nrows = (int)mxGetM(plhs" << ending.str()<< "[2]);" << endl;
-          else if (symb_id == first_deriv_symb_id)
-            output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl
-                   << "TEFD_" << indx << " = mxGetPr(plhs" << ending.str() << "[1]);" << endl;
-          else
-            output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
+          output << "  TEF_" << indx << " = mxGetPr(plhs[0]);" << endl;
+          if (symb_id == first_deriv_symb_id)
+            {
+              output << "  TEFD_" << indx << " = mxGetPr(plhs[1]);" << endl;
+              if (symb_id == second_deriv_symb_id)
+                output << "  TEFDD_" << indx << " = mxGetPr(plhs[2]);" << endl
+                       << "  TEFDD_" << indx << "_nrows = (int)mxGetM(plhs[2]);" << endl;
+            }
+          output << "}" << endl;
         }
       else
         {
@@ -7577,62 +7568,42 @@ FirstDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Exp
   if (isCOutput(output_type))
     if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
       {
-        stringstream ending;
-        ending << "_tefd_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex;
-        output << "int nlhs" << ending.str() << " = 1;" << endl
-               << "double *TEFD_fdd_" <<  getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << ";" << endl
-               << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl
-               << "int nrhs" << ending.str() << " = 3;" << endl
-               << "mxArray *prhs" << ending.str() << "[nrhs"<< ending.str() << "];" << endl
-               << "mwSize dims" << ending.str() << "[2];" << endl;
-
-        output << "dims" << ending.str() << "[0] = 1;" << endl
-               << "dims" << ending.str() << "[1] = " << arguments.size() << ";" << endl;
-
-        output << "prhs" << ending.str() << R"([0] = mxCreateString(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl
-               << "prhs" << ending.str() << "[1] = mxCreateDoubleScalar(" << inputIndex << ");"<< endl
-               << "prhs" << ending.str() << "[2] = mxCreateCellArray(2, dims" << ending.str() << ");"<< endl;
+        output << "double *TEFD_fdd_" <<  getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << ";" << endl
+               << "{" << endl
+               << "  const mwSize dims[2] = {1, " << arguments.size() << "};" << endl
+               << "  mxArray *plhs[1], *prhs[3];" << endl
+               << R"(  prhs[0] = mxCreateString(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl
+               << "  prhs[1] = mxCreateDoubleScalar(" << inputIndex << ");"<< endl
+               << "  prhs[2] = mxCreateCellArray(2, dims);"<< endl;
 
         int i = 0;
         for (auto argument : arguments)
           {
-            output << "mxSetCell(prhs" << ending.str() << "[2], "
-                   << i++ << ", "
+            output << "  mxSetCell(prhs[2], " << i++ << ", "
                    << "mxCreateDoubleScalar("; // All external_function arguments are scalars
             argument->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
             output << "));" << endl;
           }
 
-        output << "mexCallMATLAB("
-               << "nlhs" << ending.str() << ", "
-               << "plhs" << ending.str() << ", "
-               << "nrhs" << ending.str() << ", "
-               << "prhs" << ending.str() << R"(, ")"
-               << R"(jacob_element");)" << endl;
-
-        output << "TEFD_fdd_" <<  getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex
-               << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
+        output << "  mexCallMATLAB(1, plhs, 3, prhs," << R"("jacob_element");)" << endl
+               << "  TEFD_fdd_" <<  getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex
+               << " = mxGetPr(plhs[0]);" << endl
+               << "}" << endl;
       }
     else
       {
         tef_terms[{ first_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
         int indx = getIndxInTefTerms(first_deriv_symb_id, tef_terms);
-        stringstream ending;
-        ending << "_tefd_def_" << indx;
-        output << "int nlhs" << ending.str() << " = 1;" << endl
-               << "double *TEFD_def_" << indx << ";" << endl
-               << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl
-               << "int nrhs" << ending.str() << " = " << arguments.size() << ";" << endl;
-        writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms, ending.str());
-
-        output << "mexCallMATLAB("
-               << "nlhs" << ending.str() << ", "
-               << "plhs" << ending.str() << ", "
-               << "nrhs" << ending.str() << ", "
-               << "prhs" << ending.str() << R"(, ")"
-               << datatree.symbol_table.getName(first_deriv_symb_id) << R"(");)" << endl;
-
-        output << "TEFD_def_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
+        output << "double *TEFD_def_" << indx << ";" << endl
+               << "{" << endl
+               << "  mxArray *plhs[1], *prhs[" << arguments.size() << "];" << endl;
+
+        writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
+
+        output << "  mexCallMATLAB(1, plhs, " << arguments.size() << ", prhs," << R"(")"
+               << datatree.symbol_table.getName(first_deriv_symb_id) << R"(");)" << endl
+               << "  TEFD_def_" << indx << " = mxGetPr(plhs[0]);" << endl
+               << "}" << endl;
       }
   else
     {
@@ -7926,63 +7897,43 @@ SecondDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Ex
     if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
       {
         stringstream ending;
-        ending << "_tefdd_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2;
-        output << "int nlhs" << ending.str() << " = 1;" << endl
-               << "double *TEFDD_fdd_" <<  getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 << ";" << endl
-               << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl
-               << "int nrhs" << ending.str() << " = 4;" << endl
-               << "mxArray *prhs" << ending.str() << "[nrhs"<< ending.str() << "];" << endl
-               << "mwSize dims" << ending.str() << "[2];" << endl;
-
-        output << "dims" << ending.str() << "[0] = 1;" << endl
-               << "dims" << ending.str() << "[1] = " << arguments.size() << ";" << endl;
-
-        output << "prhs" << ending.str() << R"([0] = mxCreateString(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl
-               << "prhs" << ending.str() << "[1] = mxCreateDoubleScalar(" << inputIndex1 << ");"<< endl
-               << "prhs" << ending.str() << "[2] = mxCreateDoubleScalar(" << inputIndex2 << ");"<< endl
-               << "prhs" << ending.str() << "[3] = mxCreateCellArray(2, dims" << ending.str() << ");"<< endl;
+        output << "double *TEFDD_fdd_" <<  getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 << ";" << endl
+               << "{" << endl
+               << "  const mwSize dims[2]= {1, " << arguments.size() << "};" << endl
+               << "  mxArray *plhs[1], *prhs[4];" << endl
+               << R"(  prhs[0] = mxCreateString(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl
+               << "  prhs[1] = mxCreateDoubleScalar(" << inputIndex1 << ");"<< endl
+               << "  prhs[2] = mxCreateDoubleScalar(" << inputIndex2 << ");"<< endl
+               << "  prhs[3] = mxCreateCellArray(2, dims);"<< endl;
 
         int i = 0;
         for (auto argument : arguments)
           {
-            output << "mxSetCell(prhs" << ending.str() << "[3], "
-                   << i++ << ", "
-                   << "mxCreateDoubleScalar("; // All external_function arguments are scalars
+            output << "  mxSetCell(prhs[3], " << i++ << ", "
+                   << "  mxCreateDoubleScalar("; // All external_function arguments are scalars
             argument->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
             output << "));" << endl;
           }
 
-        output << "mexCallMATLAB("
-               << "nlhs" << ending.str() << ", "
-               << "plhs" << ending.str() << ", "
-               << "nrhs" << ending.str() << ", "
-               << "prhs" << ending.str() << R"(, ")"
-               << R"(hess_element");)" << endl;
-
-        output << "TEFDD_fdd_" <<  getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2
-               << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
+        output << "  mexCallMATLAB(1, plhs, 4, prhs, " << R"("hess_element");)" << endl
+               << "  TEFDD_fdd_" <<  getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2
+               << " = mxGetPr(plhs[0]);" << endl
+               << "}" << endl;
       }
     else
       {
         tef_terms[{ second_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
         int indx = getIndxInTefTerms(second_deriv_symb_id, tef_terms);
-        stringstream ending;
-        ending << "_tefdd_def_" << indx;
-
-        output << "int nlhs" << ending.str() << " = 1;" << endl
-               << "double *TEFDD_def_" << indx << ";" << endl
-               << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl
-               << "int nrhs" << ending.str() << " = " << arguments.size() << ";" << endl;
-        writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms, ending.str());
-
-        output << "mexCallMATLAB("
-               << "nlhs" << ending.str() << ", "
-               << "plhs" << ending.str() << ", "
-               << "nrhs" << ending.str() << ", "
-               << "prhs" << ending.str() << R"(, ")"
-               << datatree.symbol_table.getName(second_deriv_symb_id) << R"(");)" << endl;
-
-        output << "TEFDD_def_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
+        output << "double *TEFDD_def_" << indx << ";" << endl
+               << "{" << endl
+               << "  mxArray *plhs[1], *prhs[" << arguments.size() << "];" << endl;
+
+        writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
+
+        output << "  mexCallMATLAB(1, plhs, " << arguments.size() << ", prhs, " << R"(")"
+               << datatree.symbol_table.getName(second_deriv_symb_id) << R"(");)" << endl
+               << "  TEFDD_def_" << indx << " = mxGetPr(plhs[0]);" << endl
+               << "}" << endl;
       }
   else
     {
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 4c6f0dfa..677abd6e 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -1336,7 +1336,7 @@ public:
   bool containsExogenous() const override;
   int countDiffs() const override;
   bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const override;
-  void writePrhs(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms, const string &ending) const;
+  void writePrhs(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const;
   expr_t replaceTrendVar() const override;
   expr_t detrend(int symb_id, bool log_trend, expr_t trend) const override;
   expr_t clone(DataTree &datatree) const override = 0;
-- 
GitLab