From fe3f18947e49c8f9eb901c90121ed93a08422aa2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Thu, 23 Mar 2023 18:28:13 +0100
Subject: [PATCH] No longer replace all auxiliary variables by their definition
 in the static model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This is effectively a revert of commits 1b4f68f93433cf1fd826789f32a567c34f4c1e7d,
32fb90d5f3c729f0744ad97ffc7c51e31cb3731b and f6f4ea70fbce9fd584c72c15a2932fa02a854d0b.

This transformation had been introduced in order to fix the computation of the
Ramsey steady state in the case where Lagrange multipliers appeared with a lead
or lag ⩾ 2 (and where thus part of the definition of an auxiliary variable).

But this transformation had introduced bugs in the handling of external
functions which were difficult to tackle.

Moreover, it seems preferable to keep the strict correspondence between the
dynamic and static model, in order to make reasoning about the preprocessor
internals easier (in particular, for this reason this transformation was not
implemented in ModFile::transformPass() but in ModFile::computingPass(), which
was a bit confusing).

A better solution for the Ramsey steady state issue will is implemented in the
descendent of the present commit.

Ref. dynare#633, dynare#1119, dynare#1133
---
 src/ExprNode.cc    | 64 ----------------------------------------------
 src/ExprNode.hh    | 17 ------------
 src/StaticModel.cc | 22 +++-------------
 src/SymbolTable.cc | 14 ----------
 src/SymbolTable.hh |  2 --
 5 files changed, 3 insertions(+), 116 deletions(-)

diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 98ac170a..faeb7d70 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -841,12 +841,6 @@ NumConstNode::isParamTimesEndogExpr() const
   return false;
 }
 
-expr_t
-NumConstNode::substituteStaticAuxiliaryVariable() const
-{
-  return const_cast<NumConstNode *>(this);
-}
-
 expr_t
 NumConstNode::replaceVarsInEquation([[maybe_unused]] map<VariableNode *, NumConstNode *> &table) const
 {
@@ -1382,20 +1376,6 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
     }
 }
 
-expr_t
-VariableNode::substituteStaticAuxiliaryVariable() const
-{
-  if (get_type() == SymbolType::endogenous)
-    try
-      {
-        return datatree.symbol_table.getAuxiliaryVarsExprNode(symb_id)->substituteStaticAuxiliaryVariable();
-      }
-    catch (SymbolTable::SearchFailedException &e)
-      {
-      }
-  return const_cast<VariableNode *>(this);
-}
-
 double
 VariableNode::eval(const eval_context_t &eval_context) const noexcept(false)
 {
@@ -3997,19 +3977,6 @@ UnaryOpNode::isParamTimesEndogExpr() const
   return arg->isParamTimesEndogExpr();
 }
 
-expr_t
-UnaryOpNode::substituteStaticAuxiliaryVariable() const
-{
-  if (op_code == UnaryOpcode::diff)
-    return datatree.Zero;
-
-  expr_t argsubst = arg->substituteStaticAuxiliaryVariable();
-  if (op_code == UnaryOpcode::expectation)
-    return argsubst;
-  else
-    return buildSimilarUnaryOpNode(argsubst, datatree);
-}
-
 expr_t
 UnaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
 {
@@ -5853,19 +5820,6 @@ BinaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table)
   return recurseTransform(&ExprNode::replaceVarsInEquation, table);
 }
 
-expr_t
-BinaryOpNode::substituteStaticAuxiliaryVariable() const
-{
-  return recurseTransform(&ExprNode::substituteStaticAuxiliaryVariable);
-}
-
-expr_t
-BinaryOpNode::substituteStaticAuxiliaryDefinition() const
-{
-  expr_t arg2subst = arg2->substituteStaticAuxiliaryVariable();
-  return buildSimilarBinaryOpNode(arg1, arg2subst, datatree);
-}
-
 void
 BinaryOpNode::matchMatchedMoment(vector<int> &symb_ids, vector<int> &lags, vector<int> &powers) const
 {
@@ -6721,12 +6675,6 @@ TrinaryOpNode::isParamTimesEndogExpr() const
     || arg3->isParamTimesEndogExpr();
 }
 
-expr_t
-TrinaryOpNode::substituteStaticAuxiliaryVariable() const
-{
-  return recurseTransform(&ExprNode::substituteStaticAuxiliaryVariable);
-}
-
 expr_t
 TrinaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
 {
@@ -7252,12 +7200,6 @@ AbstractExternalFunctionNode::containsExternalFunction() const
   return true;
 }
 
-expr_t
-AbstractExternalFunctionNode::substituteStaticAuxiliaryVariable() const
-{
-  return recurseTransform(&ExprNode::substituteStaticAuxiliaryVariable);
-}
-
 expr_t
 AbstractExternalFunctionNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
 {
@@ -8557,12 +8499,6 @@ SubModelNode::isParamTimesEndogExpr() const
   return false;
 }
 
-expr_t
-SubModelNode::substituteStaticAuxiliaryVariable() const
-{
-  return const_cast<SubModelNode *>(this);
-}
-
 expr_t
 SubModelNode::replaceVarsInEquation([[maybe_unused]] map<VariableNode *, NumConstNode *> &table) const
 {
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 4c371569..875e27d9 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -713,9 +713,6 @@ public:
   //! Returns true if the expression is in static form (no lead, no lag, no expectation, no STEADY_STATE)
   virtual bool isInStaticForm() const = 0;
 
-  //! Substitute auxiliary variables by their expression in static model
-  virtual expr_t substituteStaticAuxiliaryVariable() const = 0;
-
   //! Matches a linear combination of variables (endo or exo), where scalars can be constant*parameter
   /*! Returns a list of (variable_id, lag, param_id, constant)
     corresponding to the terms in the expression. When there is no
@@ -916,7 +913,6 @@ public:
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   bool containsPacTargetNonstationary(const string &pac_model_name = "") const override;
   bool isParamTimesEndogExpr() const override;
-  expr_t substituteStaticAuxiliaryVariable() const override;
   expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
 };
 
@@ -990,8 +986,6 @@ public:
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   bool containsPacTargetNonstationary(const string &pac_model_name = "") const override;
   bool isParamTimesEndogExpr() const override;
-  //! Substitute auxiliary variables by their expression in static model
-  expr_t substituteStaticAuxiliaryVariable() const override;
   void matchMatchedMoment(vector<int> &symb_ids, vector<int> &lags, vector<int> &powers) const override;
   pair<int, expr_t> matchEndogenousTimesConstant() const override;
   expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
@@ -1105,8 +1099,6 @@ public:
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   bool containsPacTargetNonstationary(const string &pac_model_name = "") const override;
   bool isParamTimesEndogExpr() const override;
-  //! Substitute auxiliary variables by their expression in static model
-  expr_t substituteStaticAuxiliaryVariable() const override;
   void decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int current_sign) const override;
   expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
 };
@@ -1256,10 +1248,6 @@ public:
   pair<optional<int>, expr_t> getPacOptimizingShareAndExprNodesHelper(int lhs_orig_symb_id) const;
   expr_t getPacNonOptimizingPart(int optim_share_symb_id) const;
   bool isParamTimesEndogExpr() const override;
-  //! Substitute auxiliary variables by their expression in static model
-  expr_t substituteStaticAuxiliaryVariable() const override;
-  //! Substitute auxiliary variables by their expression in static model auxiliary variable definition
-  expr_t substituteStaticAuxiliaryDefinition() const;
   void decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int current_sign) const override;
   void decomposeMultiplicativeFactors(vector<pair<expr_t, int>> &factors, int current_exponent = 1) const override;
   void matchMatchedMoment(vector<int> &symb_ids, vector<int> &lags, vector<int> &powers) const override;
@@ -1372,8 +1360,6 @@ public:
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   bool containsPacTargetNonstationary(const string &pac_model_name = "") const override;
   bool isParamTimesEndogExpr() const override;
-  //! Substitute auxiliary variables by their expression in static model
-  expr_t substituteStaticAuxiliaryVariable() const override;
   expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
 };
 
@@ -1498,8 +1484,6 @@ public:
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   bool containsPacTargetNonstationary(const string &pac_model_name = "") const override;
   bool isParamTimesEndogExpr() const override;
-  //! Substitute auxiliary variables by their expression in static model
-  expr_t substituteStaticAuxiliaryVariable() const override;
   expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
 };
 
@@ -1665,7 +1649,6 @@ public:
   bool isInStaticForm() const override;
   expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
   bool isParamTimesEndogExpr() const override;
-  expr_t substituteStaticAuxiliaryVariable() const override;
   expr_t differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t decreaseLeadsLagsPredeterminedVariables() const override;
   expr_t replaceTrendVar() const override;
diff --git a/src/StaticModel.cc b/src/StaticModel.cc
index 3244ecf2..038ea375 100644
--- a/src/StaticModel.cc
+++ b/src/StaticModel.cc
@@ -174,22 +174,6 @@ StaticModel::computingPass(int derivsOrder, int paramsDerivsOrder, const eval_co
 {
   initializeVariablesAndEquations();
 
-  vector<BinaryOpNode *> neweqs;
-  for (int eq = 0; eq < static_cast<int>(equations.size() - aux_equations.size()); eq++)
-    {
-      expr_t eq_tmp = equations[eq]->substituteStaticAuxiliaryVariable();
-      neweqs.push_back(dynamic_cast<BinaryOpNode *>(eq_tmp->toStatic(*this)));
-    }
-
-  for (auto &aux_equation : aux_equations)
-    {
-      expr_t eq_tmp = aux_equation->substituteStaticAuxiliaryDefinition();
-      neweqs.push_back(dynamic_cast<BinaryOpNode *>(eq_tmp->toStatic(*this)));
-    }
-
-  equations.clear();
-  copy(neweqs.begin(), neweqs.end(), back_inserter(equations));
-
   /* In both MATLAB and Julia, tensors for higher-order derivatives are stored
      in matrices whose columns correspond to variable multi-indices. Since we
      currently are limited to 32-bit signed integers (hence 31 bits) for matrix
@@ -769,7 +753,7 @@ StaticModel::writeAuxVarRecursiveDefinitions(ostream &output, ExprNodeOutputType
       dynamic_cast<ExprNode *>(aux_equation)->writeExternalFunctionOutput(output, output_type, {}, {}, tef_terms);
   for (auto aux_equation : aux_equations)
     {
-      dynamic_cast<ExprNode *>(aux_equation->substituteStaticAuxiliaryDefinition())->writeOutput(output, output_type, {}, {}, tef_terms);
+      aux_equation->writeOutput(output, output_type, {}, {}, tef_terms);
       output << ";" << endl;
     }
 }
@@ -787,7 +771,7 @@ StaticModel::writeLatexAuxVarRecursiveDefinitions(ostream &output) const
   for (auto aux_equation : aux_equations)
     {
       output << R"(\begin{dmath})" << endl;
-      dynamic_cast<ExprNode *>(aux_equation->substituteStaticAuxiliaryDefinition())->writeOutput(output, ExprNodeOutputType::latexStaticModel);
+      dynamic_cast<ExprNode *>(aux_equation)->writeOutput(output, ExprNodeOutputType::latexStaticModel);
       output << endl << R"(\end{dmath})" << endl;
     }
 }
@@ -820,7 +804,7 @@ StaticModel::writeJsonAuxVarRecursiveDefinitions(ostream &output) const
       output << R"(, {"lhs": ")";
       aux_equation->arg1->writeJsonOutput(output, temporary_terms, tef_terms, false);
       output << R"(", "rhs": ")";
-      dynamic_cast<BinaryOpNode *>(aux_equation->substituteStaticAuxiliaryDefinition())->arg2->writeJsonOutput(output, temporary_terms, tef_terms, false);
+      aux_equation->arg2->writeJsonOutput(output, temporary_terms, tef_terms, false);
       output << R"("})";
     }
 }
diff --git a/src/SymbolTable.cc b/src/SymbolTable.cc
index 8e9db624..37e3841d 100644
--- a/src/SymbolTable.cc
+++ b/src/SymbolTable.cc
@@ -706,20 +706,6 @@ SymbolTable::unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false)
   return { symb_id, lag };
 }
 
-expr_t
-SymbolTable::getAuxiliaryVarsExprNode(int symb_id) const noexcept(false)
-// throw exception if it is a Lagrange multiplier
-{
-  for (const auto &aux_var : aux_vars)
-    if (aux_var.symb_id == symb_id)
-      if (expr_t expr_node = aux_var.expr_node;
-          expr_node)
-        return expr_node;
-      else
-        throw SearchFailedException(symb_id);
-  throw SearchFailedException(symb_id);
-}
-
 void
 SymbolTable::markPredetermined(int symb_id) noexcept(false)
 {
diff --git a/src/SymbolTable.hh b/src/SymbolTable.hh
index 4de52425..e8e787be 100644
--- a/src/SymbolTable.hh
+++ b/src/SymbolTable.hh
@@ -309,8 +309,6 @@ public:
   {
     return aux_vars.size();
   };
-  //! Retruns expr_node for an auxiliary variable
-  expr_t getAuxiliaryVarsExprNode(int symb_id) const noexcept(false);
   //! Tests if symbol already exists
   inline bool exists(const string &name) const;
   //! Get symbol name (by ID)
-- 
GitLab