From cfb41d291cdeb54e96efca7a1a7f297b7902d049 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Tue, 10 Nov 2020 18:04:05 +0100
Subject: [PATCH] Substitute out model-local variables early in the model
 transform pass
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Incidentally, this makes it possible to use model-local variables with either
the “block” or “bytecode” option. (Ref: dynare#1243)
---
 src/DynamicModel.cc | 23 ++++++++++-------
 src/DynamicModel.hh |  5 ++--
 src/ExprNode.cc     | 60 +++++++++++++++++++++++++++++++++++++++++++++
 src/ExprNode.hh     | 11 +++++++++
 src/ModFile.cc      |  9 +++----
 5 files changed, 91 insertions(+), 17 deletions(-)

diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc
index ab730efc..56e74514 100644
--- a/src/DynamicModel.cc
+++ b/src/DynamicModel.cc
@@ -5632,6 +5632,20 @@ DynamicModel::substituteAdl()
     equation = dynamic_cast<BinaryOpNode *>(equation->substituteAdl());
 }
 
+void
+DynamicModel::substituteModelLocalVariables()
+{
+  for (auto &equation : equations)
+    equation = dynamic_cast<BinaryOpNode *>(equation->substituteModelLocalVariables());
+
+  /* We can’t clear local_variables_table at this point, because in case of
+     ramsey_policy, the original model is saved via DynamicModel::operator=()
+     before computing the FOC. But since DataTree::operator=() clones all
+     nodes, it will try to clone nodes containing model-local variables, and
+     this will fail at the point where DataTree methods try to evaluate those
+     nodes to a numerical value. */
+}
+
 set<int>
 DynamicModel::getEquationNumbersFromTags(const set<string> &eqtags) const
 {
@@ -5907,15 +5921,6 @@ DynamicModel::fillEvalContext(eval_context_t &eval_context) const
     eval_context[trendVar] = 2; //not <= 0 bc of log, not 1 bc of powers
 }
 
-bool
-DynamicModel::isModelLocalVariableUsed() const
-{
-  set<int> used_local_vars;
-  for (size_t i = 0; i < equations.size() && used_local_vars.empty(); i++)
-    equations[i]->collectVariables(SymbolType::modelLocalVariable, used_local_vars);
-  return !used_local_vars.empty();
-}
-
 void
 DynamicModel::addStaticOnlyEquation(expr_t eq, int lineno, const map<string, string> &eq_tags)
 {
diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh
index d020b879..7752964f 100644
--- a/src/DynamicModel.hh
+++ b/src/DynamicModel.hh
@@ -486,6 +486,9 @@ public:
   //! Substitutes adl operator
   void substituteAdl();
 
+  //! Substitutes out all model-local variables
+  void substituteModelLocalVariables();
+
   //! Creates aux vars for all unary operators
   pair<lag_equivalence_table_t, ExprNode::subst_table_t> substituteUnaryOps();
 
@@ -573,8 +576,6 @@ public:
     return tuple(static_only_equations, static_only_equations_lineno, static_only_equations_equation_tags);
   };
 
-  bool isModelLocalVariableUsed() const;
-
   //! Returns true if a parameter was used in the model block with a lead or lag
   bool ParamUsedWithLeadLag() const;
 
diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index a9df2293..5d4180ca 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -613,6 +613,12 @@ NumConstNode::substituteAdl() const
   return const_cast<NumConstNode *>(this);
 }
 
+expr_t
+NumConstNode::substituteModelLocalVariables() const
+{
+  return const_cast<NumConstNode *>(this);
+}
+
 expr_t
 NumConstNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
 {
@@ -1541,6 +1547,15 @@ VariableNode::substituteAdl() const
   return const_cast<VariableNode *>(this);
 }
 
+expr_t
+VariableNode::substituteModelLocalVariables() const
+{
+  if (get_type() == SymbolType::modelLocalVariable)
+    return datatree.getLocalVariable(symb_id);
+
+  return const_cast<VariableNode *>(this);
+}
+
 expr_t
 VariableNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
 {
@@ -3259,6 +3274,13 @@ UnaryOpNode::substituteAdl() const
   return retval;
 }
 
+expr_t
+UnaryOpNode::substituteModelLocalVariables() const
+{
+  expr_t argsubst = arg->substituteModelLocalVariables();
+  return buildSimilarUnaryOpNode(argsubst, datatree);
+}
+
 expr_t
 UnaryOpNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
 {
@@ -5043,6 +5065,14 @@ BinaryOpNode::substituteAdl() const
   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
 }
 
+expr_t
+BinaryOpNode::substituteModelLocalVariables() const
+{
+  expr_t arg1subst = arg1->substituteModelLocalVariables();
+  expr_t arg2subst = arg2->substituteModelLocalVariables();
+  return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
+}
+
 expr_t
 BinaryOpNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
 {
@@ -6267,6 +6297,15 @@ TrinaryOpNode::substituteAdl() const
   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
 }
 
+expr_t
+TrinaryOpNode::substituteModelLocalVariables() const
+{
+  expr_t arg1subst = arg1->substituteModelLocalVariables();
+  expr_t arg2subst = arg2->substituteModelLocalVariables();
+  expr_t arg3subst = arg3->substituteModelLocalVariables();
+  return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
+}
+
 expr_t
 TrinaryOpNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
 {
@@ -6683,6 +6722,15 @@ AbstractExternalFunctionNode::substituteAdl() const
   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
 }
 
+expr_t
+AbstractExternalFunctionNode::substituteModelLocalVariables() const
+{
+  vector<expr_t> arguments_subst;
+  for (auto argument : arguments)
+    arguments_subst.push_back(argument->substituteModelLocalVariables());
+  return buildSimilarExternalFunctionNode(arguments_subst, datatree);
+}
+
 expr_t
 AbstractExternalFunctionNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
 {
@@ -8244,6 +8292,12 @@ VarExpectationNode::substituteAdl() const
   return const_cast<VarExpectationNode *>(this);
 }
 
+expr_t
+VarExpectationNode::substituteModelLocalVariables() const
+{
+  return const_cast<VarExpectationNode *>(this);
+}
+
 expr_t
 VarExpectationNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
 {
@@ -8630,6 +8684,12 @@ PacExpectationNode::substituteAdl() const
   return const_cast<PacExpectationNode *>(this);
 }
 
+expr_t
+PacExpectationNode::substituteModelLocalVariables() const
+{
+  return const_cast<PacExpectationNode *>(this);
+}
+
 expr_t
 PacExpectationNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
 {
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 9f8bdd2f..3d9c3da9 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -588,6 +588,9 @@ public:
   //! Substitute adl operator
   virtual expr_t substituteAdl() const = 0;
 
+  //! Substitute out model-local variables
+  virtual expr_t substituteModelLocalVariables() const = 0;
+
   //! Substitute VarExpectation nodes
   virtual expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const = 0;
 
@@ -768,6 +771,7 @@ public:
   expr_t substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
   expr_t substituteAdl() const override;
+  expr_t substituteModelLocalVariables() const override;
   expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
   void findDiffNodes(lag_equivalence_table_t &nodes) const override;
   void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
@@ -840,6 +844,7 @@ public:
   expr_t substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
   expr_t substituteAdl() const override;
+  expr_t substituteModelLocalVariables() const override;
   expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
   void findDiffNodes(lag_equivalence_table_t &nodes) const override;
   void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
@@ -942,6 +947,7 @@ public:
   expr_t substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
   expr_t substituteAdl() const override;
+  expr_t substituteModelLocalVariables() const override;
   expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
   void findDiffNodes(lag_equivalence_table_t &nodes) const override;
   bool createAuxVarForUnaryOpNode() const;
@@ -1048,6 +1054,7 @@ public:
   expr_t substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
   expr_t substituteAdl() const override;
+  expr_t substituteModelLocalVariables() const override;
   expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
   void findDiffNodes(lag_equivalence_table_t &nodes) const override;
   void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
@@ -1183,6 +1190,7 @@ public:
   expr_t substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
   expr_t substituteAdl() const override;
+  expr_t substituteModelLocalVariables() const override;
   expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
   void findDiffNodes(lag_equivalence_table_t &nodes) const override;
   void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
@@ -1292,6 +1300,7 @@ public:
   expr_t substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
   expr_t substituteAdl() const override;
+  expr_t substituteModelLocalVariables() const override;
   expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
   void findDiffNodes(lag_equivalence_table_t &nodes) const override;
   void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
@@ -1468,6 +1477,7 @@ public:
   expr_t substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
   expr_t substituteAdl() const override;
+  expr_t substituteModelLocalVariables() const override;
   expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
   void findDiffNodes(lag_equivalence_table_t &nodes) const override;
   void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
@@ -1540,6 +1550,7 @@ public:
   expr_t substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
   expr_t substituteAdl() const override;
+  expr_t substituteModelLocalVariables() const override;
   expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
   void findDiffNodes(lag_equivalence_table_t &nodes) const override;
   void findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const override;
diff --git a/src/ModFile.cc b/src/ModFile.cc
index 4b3a6764..18e379e2 100644
--- a/src/ModFile.cc
+++ b/src/ModFile.cc
@@ -205,12 +205,6 @@ ModFile::checkPass(bool nostrict, bool stochastic)
       exit(EXIT_FAILURE);
     }
 
-  if ((block || bytecode) && dynamic_model.isModelLocalVariableUsed())
-    {
-      cerr << "ERROR: In 'model' block, 'block' or 'bytecode' options are not yet compatible with pound expressions" << endl;
-      exit(EXIT_FAILURE);
-    }
-
   if ((stochastic_statement_present || mod_file_struct.check_present || mod_file_struct.steady_present) && no_static)
     {
       cerr << "ERROR: no_static option is incompatible with stoch_simul, estimation, osr, ramsey_policy, discretionary_policy, steady and check commands" << endl;
@@ -402,6 +396,9 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, bool
   original_model = dynamic_model;
   dynamic_model.expandEqTags();
 
+  // Replace all model-local variables by their expression
+  dynamic_model.substituteModelLocalVariables();
+
   // Check that all declared endogenous are used in equations
   set<int> unusedEndogs = dynamic_model.findUnusedEndogenous();
   bool unusedEndogsIsErr = !nostrict && !mod_file_struct.bvar_present && unusedEndogs.size();
-- 
GitLab