From 1cc3e3c82872f306eee5745ccff510f44eb72823 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org> Date: Tue, 5 Jul 2022 15:04:36 +0200 Subject: [PATCH] Fix interaction of temporary terms with steady_state operator When the same complex expression appears outside and inside a steady_state() operator, the same temporary term would be used for both cases, which was obviously wrong. The fix consists in never substituting temporary terms for expressions inside the steady_state operator(). Incidentally, this implies that external functions can no longer be used inside steady_state operators (since their computed values are stored inside temporary terms). (manually cherry picked from commit c27342cfeb7fee793cab4ed58dfd4b9f72b6b30a) --- src/ExprNode.cc | 49 ++++++++++++++++++++++++++++++++++++++++++++----- src/ExprNode.hh | 9 +++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 4de5a55d..2b998783 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -96,6 +96,13 @@ ExprNode::checkIfTemporaryTermThenWrite(ostream &output, ExprNodeOutputType outp if (auto it = temporary_terms.find(const_cast<ExprNode *>(this)); it == temporary_terms.end()) return false; + /* If we are inside a steady_state() operator, the temporary terms do not + apply, since those refer to the dynamic model (assuming that writeOutput() + was initially not called with a steady state output type, which is + typically the case). */ + if (isSteadyStateOperatorOutput(output_type)) + return false; + auto it2 = temporary_terms_idxs.find(const_cast<ExprNode *>(this)); // It is the responsibility of the caller to ensure that all temporary terms have their index assert(it2 != temporary_terms_idxs.end()); @@ -2266,7 +2273,8 @@ UnaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder, it == reference_count.end()) { reference_count[this2] = { 1, derivOrder }; - arg->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab); + if (op_code != UnaryOpcode::steadyState) // See comment in checkIfTemporaryTermThenWrite{,Bytecode}() + arg->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab); } else { @@ -2286,7 +2294,8 @@ UnaryOpNode::computeBlockTemporaryTerms(int blk, int eq, vector<vector<temporary it == reference_count.end()) { reference_count[this2] = { 1, blk, eq }; - arg->computeBlockTemporaryTerms(blk, eq, blocks_temporary_terms, reference_count); + if (op_code != UnaryOpcode::steadyState) // See comment in checkIfTemporaryTermThenWrite{,Bytecode}() + arg->computeBlockTemporaryTerms(blk, eq, blocks_temporary_terms, reference_count); } else { @@ -2902,7 +2911,7 @@ UnaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number, const deriv_node_temp_terms_t &tef_terms) const { if (auto this2 = const_cast<UnaryOpNode *>(this); - temporary_terms.find(this2) != temporary_terms.end()) + temporary_terms.find(this2) != temporary_terms.end() && !steady_dynamic) { if (dynamic) { @@ -4144,7 +4153,7 @@ BinaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number, { // If current node is a temporary term if (auto this2 = const_cast<BinaryOpNode *>(this); - temporary_terms.find(this2) != temporary_terms.end()) + temporary_terms.find(this2) != temporary_terms.end() && !steady_dynamic) { if (dynamic) { @@ -5792,7 +5801,7 @@ TrinaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number, { // If current node is a temporary term if (auto this2 = const_cast<TrinaryOpNode *>(this); - temporary_terms.find(this2) != temporary_terms.end()) + temporary_terms.find(this2) != temporary_terms.end() && !steady_dynamic) { if (dynamic) { @@ -6946,6 +6955,12 @@ ExternalFunctionNode::compile(ostream &CompileCode, unsigned int &instruction_nu const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const { + if (steady_dynamic) + { + cerr << "ERROR: The expression inside a steady_state operator cannot contain external functions" << endl; + exit(EXIT_FAILURE); + } + if (auto this2 = const_cast<ExternalFunctionNode *>(this); temporary_terms.find(this2) != temporary_terms.end()) { @@ -7079,6 +7094,12 @@ ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_typ return; } + if (isSteadyStateOperatorOutput(output_type)) + { + cerr << "ERROR: The expression inside a steady_state operator cannot contain external functions" << endl; + exit(EXIT_FAILURE); + } + if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs)) return; @@ -7311,6 +7332,12 @@ FirstDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType return; } + if (isSteadyStateOperatorOutput(output_type)) + { + cerr << "ERROR: The expression inside a steady_state operator cannot contain external functions" << endl; + exit(EXIT_FAILURE); + } + if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs)) return; @@ -7339,6 +7366,12 @@ FirstDerivExternalFunctionNode::compile(ostream &CompileCode, unsigned int &inst const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const { + if (steady_dynamic) + { + cerr << "ERROR: The expression inside a steady_state operator cannot contain external functions" << endl; + exit(EXIT_FAILURE); + } + if (auto this2 = const_cast<FirstDerivExternalFunctionNode *>(this); temporary_terms.find(this2) != temporary_terms.end()) { @@ -7675,6 +7708,12 @@ SecondDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType return; } + if (isSteadyStateOperatorOutput(output_type)) + { + cerr << "ERROR: The expression inside a steady_state operator cannot contain external functions" << endl; + exit(EXIT_FAILURE); + } + if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs)) return; diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 14dd67d1..cc4aba7b 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -139,6 +139,15 @@ isLatexOutput(ExprNodeOutputType output_type) || output_type == ExprNodeOutputType::latexDynamicSteadyStateOperator; } +inline bool +isSteadyStateOperatorOutput(ExprNodeOutputType output_type) +{ + return output_type == ExprNodeOutputType::latexDynamicSteadyStateOperator + || output_type == ExprNodeOutputType::matlabDynamicSteadyStateOperator + || output_type == ExprNodeOutputType::CDynamicSteadyStateOperator + || output_type == ExprNodeOutputType::juliaDynamicSteadyStateOperator; +} + /* Equal to 1 for Matlab langage or Julia, or to 0 for C language. Not defined for LaTeX. In Matlab and Julia, array indexes begin at 1, while they begin at 0 in C */ inline int -- GitLab