From 7ba77751c375f415b8a5b9329b74dea1bc46fc4a Mon Sep 17 00:00:00 2001 From: Houtan Bastani <houtan@dynare.org> Date: Tue, 5 May 2020 14:55:09 -0400 Subject: [PATCH] simplify external function C output --- src/ExprNode.cc | 201 ++++++++++++++++++------------------------------ src/ExprNode.hh | 2 +- 2 files changed, 77 insertions(+), 126 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 156f201d..10e9eaa3 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -7051,13 +7051,12 @@ void AbstractExternalFunctionNode::writePrhs(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, - const deriv_node_temp_terms_t &tef_terms, const string &ending) const + const deriv_node_temp_terms_t &tef_terms) const { - output << "mxArray *prhs"<< ending << "[nrhs"<< ending << "];" << endl; int i = 0; for (auto argument : arguments) { - output << "prhs" << ending << "[" << i++ << "] = mxCreateDoubleScalar("; // All external_function arguments are scalars + output << " prhs[" << i++ << "] = mxCreateDoubleScalar("; // All external_function arguments are scalars argument->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); output << ");" << endl; } @@ -7283,44 +7282,36 @@ ExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutpu if (isCOutput(output_type)) { - stringstream ending; - ending << "_tef_" << getIndxInTefTerms(symb_id, tef_terms); - if (symb_id == first_deriv_symb_id - && symb_id == second_deriv_symb_id) - output << "int nlhs" << ending.str() << " = 3;" << endl - << "double *TEF_" << indx << ", " - << "*TEFD_" << indx << ", " - << "*TEFDD_" << indx << ";" << endl; - else if (symb_id == first_deriv_symb_id) - output << "int nlhs" << ending.str() << " = 2;" << endl - << "double *TEF_" << indx << ", " - << "*TEFD_" << indx << "; " << endl; - else - output << "int nlhs" << ending.str() << " = 1;" << endl - << "double *TEF_" << indx << ";" << endl; - - output << "mxArray *plhs" << ending.str()<< "[nlhs"<< ending.str() << "];" << endl; - output << "int nrhs" << ending.str()<< " = " << arguments.size() << ";" << endl; - writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms, ending.str()); - - output << "mexCallMATLAB(" - << "nlhs" << ending.str() << ", " - << "plhs" << ending.str() << ", " - << "nrhs" << ending.str() << ", " - << "prhs" << ending.str() << R"(, ")" + output << "double *TEF_" << indx; + if (symb_id == first_deriv_symb_id) + output << ", *TEFD_" << indx; + if (symb_id == second_deriv_symb_id) + output << ", *TEFDD_" << indx; + output << ";" << endl; + + if (symb_id == first_deriv_symb_id && symb_id == second_deriv_symb_id) + output << "int TEFDD_" << indx << "_nrows;" << endl; + + int nlhs = + symb_id == first_deriv_symb_id && symb_id == second_deriv_symb_id ? 3 + : symb_id == first_deriv_symb_id ? 2 : 1; + output << "{" << endl + << " mxArray *plhs[" << nlhs << "], *prhs[" << arguments.size() << "];" << endl; + + writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); + + output << " mexCallMATLAB(" << nlhs << ", plhs, " << arguments.size() << ", prhs, " << R"(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl; - if (symb_id == first_deriv_symb_id - && symb_id == second_deriv_symb_id) - output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl - << "TEFD_" << indx << " = mxGetPr(plhs" << ending.str() << "[1]);" << endl - << "TEFDD_" << indx << " = mxGetPr(plhs" << ending.str() << "[2]);" << endl - << "int TEFDD_" << indx << "_nrows = (int)mxGetM(plhs" << ending.str()<< "[2]);" << endl; - else if (symb_id == first_deriv_symb_id) - output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl - << "TEFD_" << indx << " = mxGetPr(plhs" << ending.str() << "[1]);" << endl; - else - output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl; + output << " TEF_" << indx << " = mxGetPr(plhs[0]);" << endl; + if (symb_id == first_deriv_symb_id) + { + output << " TEFD_" << indx << " = mxGetPr(plhs[1]);" << endl; + if (symb_id == second_deriv_symb_id) + output << " TEFDD_" << indx << " = mxGetPr(plhs[2]);" << endl + << " TEFDD_" << indx << "_nrows = (int)mxGetM(plhs[2]);" << endl; + } + output << "}" << endl; } else { @@ -7577,62 +7568,42 @@ FirstDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Exp if (isCOutput(output_type)) if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet) { - stringstream ending; - ending << "_tefd_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex; - output << "int nlhs" << ending.str() << " = 1;" << endl - << "double *TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << ";" << endl - << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl - << "int nrhs" << ending.str() << " = 3;" << endl - << "mxArray *prhs" << ending.str() << "[nrhs"<< ending.str() << "];" << endl - << "mwSize dims" << ending.str() << "[2];" << endl; - - output << "dims" << ending.str() << "[0] = 1;" << endl - << "dims" << ending.str() << "[1] = " << arguments.size() << ";" << endl; - - output << "prhs" << ending.str() << R"([0] = mxCreateString(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl - << "prhs" << ending.str() << "[1] = mxCreateDoubleScalar(" << inputIndex << ");"<< endl - << "prhs" << ending.str() << "[2] = mxCreateCellArray(2, dims" << ending.str() << ");"<< endl; + output << "double *TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << ";" << endl + << "{" << endl + << " const mwSize dims[2] = {1, " << arguments.size() << "};" << endl + << " mxArray *plhs[1], *prhs[3];" << endl + << R"( prhs[0] = mxCreateString(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl + << " prhs[1] = mxCreateDoubleScalar(" << inputIndex << ");"<< endl + << " prhs[2] = mxCreateCellArray(2, dims);"<< endl; int i = 0; for (auto argument : arguments) { - output << "mxSetCell(prhs" << ending.str() << "[2], " - << i++ << ", " + output << " mxSetCell(prhs[2], " << i++ << ", " << "mxCreateDoubleScalar("; // All external_function arguments are scalars argument->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); output << "));" << endl; } - output << "mexCallMATLAB(" - << "nlhs" << ending.str() << ", " - << "plhs" << ending.str() << ", " - << "nrhs" << ending.str() << ", " - << "prhs" << ending.str() << R"(, ")" - << R"(jacob_element");)" << endl; - - output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex - << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl; + output << " mexCallMATLAB(1, plhs, 3, prhs," << R"("jacob_element");)" << endl + << " TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex + << " = mxGetPr(plhs[0]);" << endl + << "}" << endl; } else { tef_terms[{ first_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size()); int indx = getIndxInTefTerms(first_deriv_symb_id, tef_terms); - stringstream ending; - ending << "_tefd_def_" << indx; - output << "int nlhs" << ending.str() << " = 1;" << endl - << "double *TEFD_def_" << indx << ";" << endl - << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl - << "int nrhs" << ending.str() << " = " << arguments.size() << ";" << endl; - writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms, ending.str()); - - output << "mexCallMATLAB(" - << "nlhs" << ending.str() << ", " - << "plhs" << ending.str() << ", " - << "nrhs" << ending.str() << ", " - << "prhs" << ending.str() << R"(, ")" - << datatree.symbol_table.getName(first_deriv_symb_id) << R"(");)" << endl; - - output << "TEFD_def_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl; + output << "double *TEFD_def_" << indx << ";" << endl + << "{" << endl + << " mxArray *plhs[1], *prhs[" << arguments.size() << "];" << endl; + + writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); + + output << " mexCallMATLAB(1, plhs, " << arguments.size() << ", prhs," << R"(")" + << datatree.symbol_table.getName(first_deriv_symb_id) << R"(");)" << endl + << " TEFD_def_" << indx << " = mxGetPr(plhs[0]);" << endl + << "}" << endl; } else { @@ -7926,63 +7897,43 @@ SecondDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Ex if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet) { stringstream ending; - ending << "_tefdd_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2; - output << "int nlhs" << ending.str() << " = 1;" << endl - << "double *TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 << ";" << endl - << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl - << "int nrhs" << ending.str() << " = 4;" << endl - << "mxArray *prhs" << ending.str() << "[nrhs"<< ending.str() << "];" << endl - << "mwSize dims" << ending.str() << "[2];" << endl; - - output << "dims" << ending.str() << "[0] = 1;" << endl - << "dims" << ending.str() << "[1] = " << arguments.size() << ";" << endl; - - output << "prhs" << ending.str() << R"([0] = mxCreateString(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl - << "prhs" << ending.str() << "[1] = mxCreateDoubleScalar(" << inputIndex1 << ");"<< endl - << "prhs" << ending.str() << "[2] = mxCreateDoubleScalar(" << inputIndex2 << ");"<< endl - << "prhs" << ending.str() << "[3] = mxCreateCellArray(2, dims" << ending.str() << ");"<< endl; + output << "double *TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 << ";" << endl + << "{" << endl + << " const mwSize dims[2]= {1, " << arguments.size() << "};" << endl + << " mxArray *plhs[1], *prhs[4];" << endl + << R"( prhs[0] = mxCreateString(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl + << " prhs[1] = mxCreateDoubleScalar(" << inputIndex1 << ");"<< endl + << " prhs[2] = mxCreateDoubleScalar(" << inputIndex2 << ");"<< endl + << " prhs[3] = mxCreateCellArray(2, dims);"<< endl; int i = 0; for (auto argument : arguments) { - output << "mxSetCell(prhs" << ending.str() << "[3], " - << i++ << ", " - << "mxCreateDoubleScalar("; // All external_function arguments are scalars + output << " mxSetCell(prhs[3], " << i++ << ", " + << " mxCreateDoubleScalar("; // All external_function arguments are scalars argument->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); output << "));" << endl; } - output << "mexCallMATLAB(" - << "nlhs" << ending.str() << ", " - << "plhs" << ending.str() << ", " - << "nrhs" << ending.str() << ", " - << "prhs" << ending.str() << R"(, ")" - << R"(hess_element");)" << endl; - - output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 - << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl; + output << " mexCallMATLAB(1, plhs, 4, prhs, " << R"("hess_element");)" << endl + << " TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 + << " = mxGetPr(plhs[0]);" << endl + << "}" << endl; } else { tef_terms[{ second_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size()); int indx = getIndxInTefTerms(second_deriv_symb_id, tef_terms); - stringstream ending; - ending << "_tefdd_def_" << indx; - - output << "int nlhs" << ending.str() << " = 1;" << endl - << "double *TEFDD_def_" << indx << ";" << endl - << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl - << "int nrhs" << ending.str() << " = " << arguments.size() << ";" << endl; - writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms, ending.str()); - - output << "mexCallMATLAB(" - << "nlhs" << ending.str() << ", " - << "plhs" << ending.str() << ", " - << "nrhs" << ending.str() << ", " - << "prhs" << ending.str() << R"(, ")" - << datatree.symbol_table.getName(second_deriv_symb_id) << R"(");)" << endl; - - output << "TEFDD_def_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl; + output << "double *TEFDD_def_" << indx << ";" << endl + << "{" << endl + << " mxArray *plhs[1], *prhs[" << arguments.size() << "];" << endl; + + writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); + + output << " mexCallMATLAB(1, plhs, " << arguments.size() << ", prhs, " << R"(")" + << datatree.symbol_table.getName(second_deriv_symb_id) << R"(");)" << endl + << " TEFDD_def_" << indx << " = mxGetPr(plhs[0]);" << endl + << "}" << endl; } else { diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 4c6f0dfa..677abd6e 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -1336,7 +1336,7 @@ public: bool containsExogenous() const override; int countDiffs() const override; bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const override; - void writePrhs(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms, const string &ending) const; + void writePrhs(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; expr_t replaceTrendVar() const override; expr_t detrend(int symb_id, bool log_trend, expr_t trend) const override; expr_t clone(DataTree &datatree) const override = 0; -- GitLab