From 1cc3e3c82872f306eee5745ccff510f44eb72823 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Tue, 5 Jul 2022 15:04:36 +0200
Subject: [PATCH] Fix interaction of temporary terms with steady_state operator

When the same complex expression appears outside and inside a steady_state()
operator, the same temporary term would be used for both cases, which was
obviously wrong.

The fix consists in never substituting temporary terms for expressions inside
the steady_state operator().

Incidentally, this implies that external functions can no longer be used inside
steady_state operators (since their computed values are stored inside temporary
terms).

(manually cherry picked from commit c27342cfeb7fee793cab4ed58dfd4b9f72b6b30a)
---
 src/ExprNode.cc | 49 ++++++++++++++++++++++++++++++++++++++++++++-----
 src/ExprNode.hh |  9 +++++++++
 2 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 4de5a55d..2b998783 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -96,6 +96,13 @@ ExprNode::checkIfTemporaryTermThenWrite(ostream &output, ExprNodeOutputType outp
   if (auto it = temporary_terms.find(const_cast<ExprNode *>(this)); it == temporary_terms.end())
     return false;
 
+  /* If we are inside a steady_state() operator, the temporary terms do not
+     apply, since those refer to the dynamic model (assuming that writeOutput()
+     was initially not called with a steady state output type, which is
+     typically the case). */
+  if (isSteadyStateOperatorOutput(output_type))
+    return false;
+
   auto it2 = temporary_terms_idxs.find(const_cast<ExprNode *>(this));
   // It is the responsibility of the caller to ensure that all temporary terms have their index
   assert(it2 != temporary_terms_idxs.end());
@@ -2266,7 +2273,8 @@ UnaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
       it == reference_count.end())
     {
       reference_count[this2] = { 1, derivOrder };
-      arg->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
+      if (op_code != UnaryOpcode::steadyState) // See comment in checkIfTemporaryTermThenWrite{,Bytecode}()
+        arg->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
     }
   else
     {
@@ -2286,7 +2294,8 @@ UnaryOpNode::computeBlockTemporaryTerms(int blk, int eq, vector<vector<temporary
       it == reference_count.end())
     {
       reference_count[this2] = { 1, blk, eq };
-      arg->computeBlockTemporaryTerms(blk, eq, blocks_temporary_terms, reference_count);
+      if (op_code != UnaryOpcode::steadyState) // See comment in checkIfTemporaryTermThenWrite{,Bytecode}()
+        arg->computeBlockTemporaryTerms(blk, eq, blocks_temporary_terms, reference_count);
     }
   else
     {
@@ -2902,7 +2911,7 @@ UnaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number,
                      const deriv_node_temp_terms_t &tef_terms) const
 {
   if (auto this2 = const_cast<UnaryOpNode *>(this);
-      temporary_terms.find(this2) != temporary_terms.end())
+      temporary_terms.find(this2) != temporary_terms.end() && !steady_dynamic)
     {
       if (dynamic)
         {
@@ -4144,7 +4153,7 @@ BinaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number,
 {
   // If current node is a temporary term
   if (auto this2 = const_cast<BinaryOpNode *>(this);
-      temporary_terms.find(this2) != temporary_terms.end())
+      temporary_terms.find(this2) != temporary_terms.end() && !steady_dynamic)
     {
       if (dynamic)
         {
@@ -5792,7 +5801,7 @@ TrinaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number,
 {
   // If current node is a temporary term
   if (auto this2 = const_cast<TrinaryOpNode *>(this);
-      temporary_terms.find(this2) != temporary_terms.end())
+      temporary_terms.find(this2) != temporary_terms.end() && !steady_dynamic)
     {
       if (dynamic)
         {
@@ -6946,6 +6955,12 @@ ExternalFunctionNode::compile(ostream &CompileCode, unsigned int &instruction_nu
                               const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic,
                               const deriv_node_temp_terms_t &tef_terms) const
 {
+  if (steady_dynamic)
+    {
+      cerr << "ERROR: The expression inside a steady_state operator cannot contain external functions" << endl;
+      exit(EXIT_FAILURE);
+    }
+
   if (auto this2 = const_cast<ExternalFunctionNode *>(this);
       temporary_terms.find(this2) != temporary_terms.end())
     {
@@ -7079,6 +7094,12 @@ ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_typ
       return;
     }
 
+  if (isSteadyStateOperatorOutput(output_type))
+    {
+      cerr << "ERROR: The expression inside a steady_state operator cannot contain external functions" << endl;
+      exit(EXIT_FAILURE);
+    }
+
   if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
     return;
 
@@ -7311,6 +7332,12 @@ FirstDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType
       return;
     }
 
+  if (isSteadyStateOperatorOutput(output_type))
+    {
+      cerr << "ERROR: The expression inside a steady_state operator cannot contain external functions" << endl;
+      exit(EXIT_FAILURE);
+    }
+
   if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
     return;
 
@@ -7339,6 +7366,12 @@ FirstDerivExternalFunctionNode::compile(ostream &CompileCode, unsigned int &inst
                                         const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic,
                                         const deriv_node_temp_terms_t &tef_terms) const
 {
+  if (steady_dynamic)
+    {
+      cerr << "ERROR: The expression inside a steady_state operator cannot contain external functions" << endl;
+      exit(EXIT_FAILURE);
+    }
+
   if (auto this2 = const_cast<FirstDerivExternalFunctionNode *>(this);
       temporary_terms.find(this2) != temporary_terms.end())
     {
@@ -7675,6 +7708,12 @@ SecondDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType
       return;
     }
 
+  if (isSteadyStateOperatorOutput(output_type))
+    {
+      cerr << "ERROR: The expression inside a steady_state operator cannot contain external functions" << endl;
+      exit(EXIT_FAILURE);
+    }
+
   if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
     return;
 
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 14dd67d1..cc4aba7b 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -139,6 +139,15 @@ isLatexOutput(ExprNodeOutputType output_type)
     || output_type == ExprNodeOutputType::latexDynamicSteadyStateOperator;
 }
 
+inline bool
+isSteadyStateOperatorOutput(ExprNodeOutputType output_type)
+{
+  return output_type == ExprNodeOutputType::latexDynamicSteadyStateOperator
+    || output_type == ExprNodeOutputType::matlabDynamicSteadyStateOperator
+    || output_type == ExprNodeOutputType::CDynamicSteadyStateOperator
+    || output_type == ExprNodeOutputType::juliaDynamicSteadyStateOperator;
+}
+
 /* Equal to 1 for Matlab langage or Julia, or to 0 for C language. Not defined for LaTeX.
    In Matlab and Julia, array indexes begin at 1, while they begin at 0 in C */
 inline int
-- 
GitLab