diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 68b8a5c1ba25443d22e16a23d447be6ce9c6defb..175ad2a6be779d144c629b461d691a6a60b07bc7 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -189,11 +189,11 @@ DynamicModel::operator=(const DynamicModel &m) } void -DynamicModel::compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const +DynamicModel::compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const { if (auto it = derivatives[1].find({ eq, getDerivID(symbol_table.getID(SymbolType::endogenous, symb_id), lag) }); it != derivatives[1].end()) - it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, true, false); + it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, true, false, tef_terms); else { FLDZ_ fldz; @@ -202,11 +202,11 @@ DynamicModel::compileDerivative(ofstream &code_file, unsigned int &instruction_n } void -DynamicModel::compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const +DynamicModel::compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const { if (auto it = blocks_derivatives[blk].find({ eq, var, lag }); it != blocks_derivatives[blk].end()) - it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, true, false); + it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, true, false, tef_terms); else { FLDZ_ fldz; @@ -907,9 +907,11 @@ DynamicModel::writeDynamicBytecode(const string &basename) const fbeginblock.write(code_file, instruction_number); temporary_terms_t temporary_terms_union; - compileTemporaryTerms(code_file, instruction_number, true, false, temporary_terms_union, temporary_terms_idxs); + deriv_node_temp_terms_t tef_terms; + + compileTemporaryTerms(code_file, instruction_number, true, false, temporary_terms_union, temporary_terms_idxs, tef_terms); - compileModelEquations(code_file, instruction_number, true, false, temporary_terms_union, temporary_terms_idxs); + compileModelEquations(code_file, instruction_number, true, false, temporary_terms_union, temporary_terms_idxs, tef_terms); FENDEQU_ fendequ; fendequ.write(code_file, instruction_number); @@ -936,7 +938,7 @@ DynamicModel::writeDynamicBytecode(const string &basename) const if (!my_derivatives[eq].size()) my_derivatives[eq].clear(); my_derivatives[eq].emplace_back(var, lag, count_u); - d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false, tef_terms); FSTPU_ fstpu(count_u); fstpu.write(code_file, instruction_number); @@ -998,7 +1000,7 @@ DynamicModel::writeDynamicBytecode(const string &basename) const prev_lag = lag; count_col_endo++; } - d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false, tef_terms); FSTPG3_ fstpg3(eq, var, lag, count_col_endo-1); fstpg3.write(code_file, instruction_number); } @@ -1017,7 +1019,7 @@ DynamicModel::writeDynamicBytecode(const string &basename) const prev_lag = lag; count_col_exo++; } - d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, true, false, tef_terms); FSTPG3_ fstpg3(eq, var, lag, count_col_exo-1); fstpg3.write(code_file, instruction_number); } @@ -1173,16 +1175,16 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); - lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); } else if (equ_type == EquationType::evaluateRenormalized) { eq_node = getBlockEquationRenormalizedExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); - lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); } break; case BlockSimulationType::solveBackwardComplete: @@ -1203,8 +1205,8 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); FBINARY_ fbinary{static_cast<int>(BinaryOpcode::minus)}; fbinary.write(code_file, instruction_number); @@ -1242,7 +1244,7 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, getBlockEquationID(block, 0), getBlockVariableID(block, 0), 0); fnumexpr.write(code_file, instruction_number); } - compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), 0, temporary_terms_union, blocks_temporary_terms_idxs); + compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), 0, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); { FSTPG_ fstpg(0); fstpg.write(code_file, instruction_number); @@ -1281,7 +1283,7 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const Uf[eqr].Ufl->lag = lag; FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - compileChainRuleDerivative(code_file, instruction_number, block, eq, var, lag, temporary_terms_union, blocks_temporary_terms_idxs); + compileChainRuleDerivative(code_file, instruction_number, block, eq, var, lag, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); FSTPU_ fstpu(count_u); fstpu.write(code_file, instruction_number); count_u++; @@ -1358,7 +1360,7 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const int varr = getBlockVariableID(block, var); FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - compileDerivative(code_file, instruction_number, eqr, varr, lag, temporary_terms_union, blocks_temporary_terms_idxs); + compileDerivative(code_file, instruction_number, eqr, varr, lag, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); FSTPG3_ fstpg3(eq, var, lag, blocks_jacob_cols_endo[block].at({ var, lag })); fstpg3.write(code_file, instruction_number); } @@ -1369,7 +1371,7 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const int varr = 0; // Dummy value, actually unused by the bytecode MEX FNUMEXPR_ fnumexpr(ExpressionType::FirstExoDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); FSTPG3_ fstpg3(eq, var, lag, blocks_jacob_cols_exo[block].at({ var, lag })); fstpg3.write(code_file, instruction_number); } @@ -1380,7 +1382,7 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const int varr = 0; // Dummy value, actually unused by the bytecode MEX FNUMEXPR_ fnumexpr(ExpressionType::FirstExodetDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); FSTPG3_ fstpg3(eq, var, lag, blocks_jacob_cols_exo_det[block].at({ var, lag })); fstpg3.write(code_file, instruction_number); } @@ -1391,7 +1393,7 @@ DynamicModel::writeDynamicBlockBytecode(const string &basename) const int varr = 0; // Dummy value, actually unused by the bytecode MEX FNUMEXPR_ fnumexpr(ExpressionType::FirstOtherEndoDerivative, eqr, varr, lag); fnumexpr.write(code_file, instruction_number); - d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false); + d->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, true, false, tef_terms); FSTPG3_ fstpg3(eq, var, lag, blocks_jacob_cols_other_endo[block].at({ var, lag })); fstpg3.write(code_file, instruction_number); } diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh index ee08a698d7d51d8ded6d1814362af52834695fed..17b74b674db94351740d9053d9eeb3bb0af559ba 100644 --- a/src/DynamicModel.hh +++ b/src/DynamicModel.hh @@ -214,9 +214,9 @@ private: map<expr_t, tuple<int, int, int>> &reference_count) const override; //! Write derivative code of an equation w.r. to a variable - void compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const; + void compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; //! Write chain rule derivative code of an equation w.r. to a variable - void compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const; + void compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; //! Get the type corresponding to a derivation ID SymbolType getTypeByDerivID(int deriv_id) const noexcept(false) override; diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 4ce8ae0b331a7d591d6884d88ce7db7268fb9d46..4bf7d466f2974e124eae251426d211807bdc684f 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -169,14 +169,6 @@ ExprNode::writeOutput(ostream &output, ExprNodeOutputType output_type, const tem writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, {}); } -void -ExprNode::compile(ostream &CompileCode, unsigned int &instruction_number, - bool lhs_rhs, const temporary_terms_t &temporary_terms, - const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic) const -{ - compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, temporary_terms_idxs, dynamic, steady_dynamic, {}); -} - void ExprNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 6cb7a2471a79d2a70b90381c53f562422feabfff..14dd67d120125e761c6573ed589baeed4cd67617 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -409,7 +409,7 @@ public: virtual double eval(const eval_context_t &eval_context) const noexcept(false) = 0; virtual void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const = 0; - void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic) const; + //! Creates a static version of this node /*! This method duplicates the current node by creating a similar node from which all leads/lags have been stripped, diff --git a/src/ModelTree.cc b/src/ModelTree.cc index 5c0db750c8ba19234ccbc37e758d8c837cb31da1..b79ee6b5ff1bbafc0e1abfd300c96fdfdb6c4c39 100644 --- a/src/ModelTree.cc +++ b/src/ModelTree.cc @@ -1263,10 +1263,9 @@ ModelTree::testNestedParenthesis(const string &str) const } void -ModelTree::compileTemporaryTerms(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs) const +ModelTree::compileTemporaryTerms(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs, deriv_node_temp_terms_t &tef_terms) const { // To store the functions that have already been written in the form TEF* = ext_fun(); - deriv_node_temp_terms_t tef_terms; for (auto [tt, idx] : temporary_terms_idxs) { if (dynamic_cast<AbstractExternalFunctionNode *>(tt)) @@ -1401,7 +1400,7 @@ ModelTree::writeModelEquations(ostream &output, ExprNodeOutputType output_type, } void -ModelTree::compileModelEquations(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, const temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs) const +ModelTree::compileModelEquations(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, const temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const { for (int eq = 0; eq < static_cast<int>(equations.size()); eq++) { @@ -1422,8 +1421,8 @@ ModelTree::compileModelEquations(ostream &code_file, unsigned int &instruction_n if (vrhs != 0) // The right hand side of the equation is not empty ==> residual=lhs-rhs; { - lhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic); - rhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); FBINARY_ fbinary{static_cast<int>(BinaryOpcode::minus)}; fbinary.write(code_file, instruction_number); @@ -1433,7 +1432,7 @@ ModelTree::compileModelEquations(ostream &code_file, unsigned int &instruction_n } else // The right hand side of the equation is empty ==> residual=lhs; { - lhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); FSTPR_ fstpr(eq); fstpr.write(code_file, instruction_number); } diff --git a/src/ModelTree.hh b/src/ModelTree.hh index 32f1932d870dcb23f98f258809a3006994a01f84..f736160ff5815a6d7afbacd8894b026cdc1dfe2a 100644 --- a/src/ModelTree.hh +++ b/src/ModelTree.hh @@ -235,7 +235,7 @@ protected: void writeTemporaryTerms(const temporary_terms_t &tt, temporary_terms_t &temp_term_union, const temporary_terms_idxs_t &tt_idxs, ostream &output, ExprNodeOutputType output_type, deriv_node_temp_terms_t &tef_terms) const; void writeJsonTemporaryTerms(const temporary_terms_t &tt, temporary_terms_t &temp_term_union, ostream &output, deriv_node_temp_terms_t &tef_terms, const string &concat) const; //! Compiles temporary terms - void compileTemporaryTerms(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs) const; + void compileTemporaryTerms(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs, deriv_node_temp_terms_t &tef_terms) const; //! Adds information for (non-block) bytecode simulation in a separate .bin file void writeBytecodeBinFile(const string &filename, int &u_count_int, bool &file_open, bool is_two_boundaries) const; //! Fixes output when there are more than 32 nested parens, Issue #1201 @@ -258,7 +258,7 @@ protected: Optionally put the external function variable calls into TEF terms */ void writeJsonModelLocalVariables(ostream &output, bool write_tef_terms, deriv_node_temp_terms_t &tef_terms) const; //! Compiles model equations - void compileModelEquations(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, const temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs) const; + void compileModelEquations(ostream &code_file, unsigned int &instruction_number, bool dynamic, bool steady_dynamic, const temporary_terms_t &temporary_terms_union, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; //! Writes LaTeX model file void writeLatexModelFile(const string &mod_basename, const string &latex_basename, ExprNodeOutputType output_type, bool write_equation_tags) const; diff --git a/src/StaticModel.cc b/src/StaticModel.cc index 0a06c9d1f5a7b34b81b688c1c1fdd57224a3ad1e..964ed5b70259c4e5d85fb42e0f91b8bbed0302a6 100644 --- a/src/StaticModel.cc +++ b/src/StaticModel.cc @@ -103,11 +103,11 @@ StaticModel::StaticModel(const DynamicModel &m) : } void -StaticModel::compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const +StaticModel::compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const { if (auto it = derivatives[1].find({ eq, getDerivID(symbol_table.getID(SymbolType::endogenous, symb_id), 0) }); it != derivatives[1].end()) - it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, false, false); + it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, false, false, tef_terms); else { FLDZ_ fldz; @@ -116,11 +116,11 @@ StaticModel::compileDerivative(ofstream &code_file, unsigned int &instruction_nu } void -StaticModel::compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const +StaticModel::compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const { if (auto it = blocks_derivatives[blk].find({ eq, var, lag }); it != blocks_derivatives[blk].end()) - it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, false, false); + it->second->compile(code_file, instruction_number, false, temporary_terms, temporary_terms_idxs, false, false, tef_terms); else { FLDZ_ fldz; @@ -420,9 +420,11 @@ StaticModel::writeStaticBytecode(const string &basename) const fbeginblock.write(code_file, instruction_number); temporary_terms_t temporary_terms_union; - compileTemporaryTerms(code_file, instruction_number, false, false, temporary_terms_union, temporary_terms_idxs); + deriv_node_temp_terms_t tef_terms; + + compileTemporaryTerms(code_file, instruction_number, false, false, temporary_terms_union, temporary_terms_idxs, tef_terms); - compileModelEquations(code_file, instruction_number, false, false, temporary_terms_union, temporary_terms_idxs); + compileModelEquations(code_file, instruction_number, false, false, temporary_terms_union, temporary_terms_idxs, tef_terms); FENDEQU_ fendequ; fendequ.write(code_file, instruction_number); @@ -449,7 +451,7 @@ StaticModel::writeStaticBytecode(const string &basename) const my_derivatives[eq].clear(); my_derivatives[eq].emplace_back(var, count_u); - d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, false, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, false, false, tef_terms); FSTPSU_ fstpsu(count_u); fstpsu.write(code_file, instruction_number); @@ -511,7 +513,7 @@ StaticModel::writeStaticBytecode(const string &basename) const my_derivatives[eq].clear(); my_derivatives[eq].emplace_back(var, count_u); - d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, false, false); + d1->compile(code_file, instruction_number, false, temporary_terms_union, temporary_terms_idxs, false, false, tef_terms); FSTPG2_ fstpg2(eq, var); fstpg2.write(code_file, instruction_number); } @@ -660,16 +662,16 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); - lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); } else if (equ_type == EquationType::evaluateRenormalized) { eq_node = getBlockEquationRenormalizedExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); - lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); } break; case BlockSimulationType::solveBackwardComplete: @@ -688,8 +690,8 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); FBINARY_ fbinary{static_cast<int>(BinaryOpcode::minus)}; fbinary.write(code_file, instruction_number); @@ -716,7 +718,7 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, 0, 0); fnumexpr.write(code_file, instruction_number); } - compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), temporary_terms_union, blocks_temporary_terms_idxs); + compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); { FSTPG_ fstpg(0); fstpg.write(code_file, instruction_number); @@ -748,7 +750,7 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const Uf[eqr].Ufl->var = varr; FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, eqr, varr); fnumexpr.write(code_file, instruction_number); - compileChainRuleDerivative(code_file, instruction_number, block, eq, var, 0, temporary_terms_union, blocks_temporary_terms_idxs); + compileChainRuleDerivative(code_file, instruction_number, block, eq, var, 0, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); FSTPSU_ fstpsu(count_u); fstpsu.write(code_file, instruction_number); count_u++; @@ -836,16 +838,16 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); - lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); } else if (equ_type == EquationType::evaluateRenormalized) { eq_node = getBlockEquationRenormalizedExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); - lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); + lhs->compile(code_file, instruction_number, true, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); } break; case BlockSimulationType::solveBackwardComplete: @@ -864,8 +866,8 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const eq_node = getBlockEquationExpr(block, i); lhs = eq_node->arg1; rhs = eq_node->arg2; - lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); - rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false); + lhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); + rhs->compile(code_file, instruction_number, false, temporary_terms_union, blocks_temporary_terms_idxs, false, false, tef_terms); FBINARY_ fbinary{static_cast<int>(BinaryOpcode::minus)}; fbinary.write(code_file, instruction_number); @@ -890,7 +892,7 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, 0, 0); fnumexpr.write(code_file, instruction_number); } - compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), temporary_terms_union, blocks_temporary_terms_idxs); + compileDerivative(code_file, instruction_number, getBlockEquationID(block, 0), getBlockVariableID(block, 0), temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); { FSTPG2_ fstpg2(0, 0); fstpg2.write(code_file, instruction_number); @@ -909,7 +911,7 @@ StaticModel::writeStaticBlockBytecode(const string &basename) const FNUMEXPR_ fnumexpr(ExpressionType::FirstEndoDerivative, eqr, varr, 0); fnumexpr.write(code_file, instruction_number); - compileChainRuleDerivative(code_file, instruction_number, block, eq, var, 0, temporary_terms_union, blocks_temporary_terms_idxs); + compileChainRuleDerivative(code_file, instruction_number, block, eq, var, 0, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms); FSTPG2_ fstpg2(eq, var); fstpg2.write(code_file, instruction_number); diff --git a/src/StaticModel.hh b/src/StaticModel.hh index c2fc1d6f110739119b212eb87ba1a7278997409b..5af7147179392c594a8b4dd534290ed6d131edb5 100644 --- a/src/StaticModel.hh +++ b/src/StaticModel.hh @@ -78,9 +78,9 @@ private: void evaluateJacobian(const eval_context_t &eval_context, jacob_map_t *j_m, bool dynamic); //! Write derivative code of an equation w.r. to a variable - void compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const; + void compileDerivative(ofstream &code_file, unsigned int &instruction_number, int eq, int symb_id, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; //! Write chain rule derivative code of an equation w.r. to a variable - void compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const; + void compileChainRuleDerivative(ofstream &code_file, unsigned int &instruction_number, int blk, int eq, int var, int lag, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; //! Get the type corresponding to a derivation ID SymbolType getTypeByDerivID(int deriv_id) const noexcept(false) override;