From f5df7e75671d2a8236d95d0d5b2f231761bff513 Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Mon, 28 Jan 2019 14:57:30 +0100
Subject: [PATCH] when an equation is of the form `X` = `constant`, replace all
 occurrences of `X` in other equations with `constant`

---
 src/ExprNode.cc  | 127 +++++++++++++++++++++++++++++++++++++++++++++++
 src/ExprNode.hh  |  22 ++++++++
 src/ModFile.cc   |   2 +
 src/ModelTree.cc |  25 +++++++++-
 src/ModelTree.hh |   8 ++-
 5 files changed, 182 insertions(+), 2 deletions(-)

diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index f6e52e48..b4e88c01 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -719,6 +719,18 @@ NumConstNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c
 {
 }
 
+void
+NumConstNode::findConstantEquations(map<expr_t, expr_t> &table) const
+{
+  return;
+}
+
+expr_t
+NumConstNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+{
+  return const_cast<NumConstNode *>(this);
+}
+
 VariableNode::VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg) :
   ExprNode{datatree_arg, idx_arg},
   symb_id{symb_id_arg},
@@ -2023,6 +2035,21 @@ VariableNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c
 {
 }
 
+void
+VariableNode::findConstantEquations(map<expr_t, expr_t> &table) const
+{
+  return;
+}
+
+expr_t
+VariableNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+{
+  for (auto & it : table)
+    if (dynamic_cast<VariableNode *>(it.first)->symb_id == symb_id)
+      return dynamic_cast<NumConstNode *>(it.second);
+  return const_cast<VariableNode *>(this);
+}
+
 UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, int idx_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, string adl_param_name_arg, vector<int> adl_lags_arg) :
   ExprNode{datatree_arg, idx_arg},
   arg{arg_arg},
@@ -3847,6 +3874,19 @@ UnaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, co
   arg->fillErrorCorrectionRow(eqn, nontrend_lhs, trend_lhs, EC);
 }
 
+void
+UnaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
+{
+  arg->findConstantEquations(table);
+}
+
+expr_t
+UnaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+{
+  expr_t argsubst =  arg->replaceVarsInEquation(table);
+  return buildSimilarUnaryOpNode(argsubst, datatree);
+}
+
 BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, int idx_arg, const expr_t arg1_arg,
                            BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder_arg) :
   ExprNode{datatree_arg, idx_arg},
@@ -5813,6 +5853,36 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
   EC[make_tuple(eqn, -max_lag, colidx)] = arg1;
 }
 
+void
+BinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
+{
+  if (op_code == BinaryOpcode::equal)
+    if (dynamic_cast<VariableNode *>(arg1) != nullptr && dynamic_cast<NumConstNode *>(arg2) != nullptr)
+      table[arg1] = arg2;
+    else if (dynamic_cast<VariableNode *>(arg2) != nullptr && dynamic_cast<NumConstNode *>(arg1) != nullptr)
+      table[arg2] = arg1;
+  else
+    {
+      arg1->findConstantEquations(table);
+      arg2->findConstantEquations(table);
+    }
+}
+
+expr_t
+BinaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+{
+  if (op_code == BinaryOpcode::equal)
+    for (auto & it : table)
+      if ((dynamic_cast<VariableNode *>(it.first) == arg1
+           && dynamic_cast<NumConstNode *>(it.second) == arg2)
+          || (dynamic_cast<VariableNode *>(it.first) == arg2
+              && dynamic_cast<NumConstNode *>(it.second) == arg1))
+        return const_cast<BinaryOpNode *>(this);
+  expr_t arg1subst = arg1->replaceVarsInEquation(table);
+  expr_t arg2subst = arg2->replaceVarsInEquation(table);
+  return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
+}
+
 void
 BinaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const
 {
@@ -6817,6 +6887,23 @@ TrinaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs,
   arg3->fillErrorCorrectionRow(eqn, nontrend_lhs, trend_lhs, EC);
 }
 
+void
+TrinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
+{
+  arg1->findConstantEquations(table);
+  arg2->findConstantEquations(table);
+  arg3->findConstantEquations(table);
+}
+
+expr_t
+TrinaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+{
+  expr_t arg1subst = arg1->replaceVarsInEquation(table);
+  expr_t arg2subst = arg2->replaceVarsInEquation(table);
+  expr_t arg3subst = arg3->replaceVarsInEquation(table);
+  return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
+}
+
 AbstractExternalFunctionNode::AbstractExternalFunctionNode(DataTree &datatree_arg,
                                                            int idx_arg,
                                                            int symb_id_arg,
@@ -7462,6 +7549,22 @@ AbstractExternalFunctionNode::fillErrorCorrectionRow(int eqn, const vector<int>
   exit(EXIT_FAILURE);
 }
 
+void
+AbstractExternalFunctionNode::findConstantEquations(map<expr_t, expr_t> &table) const
+{
+  for (auto argument : arguments)
+    argument->findConstantEquations(table);
+}
+
+expr_t
+AbstractExternalFunctionNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+{
+  vector<expr_t> arguments_subst;
+  for (auto argument : arguments)
+    arguments_subst.push_back(argument->replaceVarsInEquation(table));
+  return buildSimilarExternalFunctionNode(arguments_subst, datatree);
+}
+
 ExternalFunctionNode::ExternalFunctionNode(DataTree &datatree_arg,
                                            int idx_arg,
                                            int symb_id_arg,
@@ -8982,6 +9085,18 @@ VarExpectationNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_
   exit(EXIT_FAILURE);
 }
 
+void
+VarExpectationNode::findConstantEquations(map<expr_t, expr_t> &table) const
+{
+  return;
+}
+
+expr_t
+VarExpectationNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+{
+  return const_cast<VarExpectationNode *>(this);
+}
+
 void
 VarExpectationNode::writeJsonAST(ostream &output) const
 {
@@ -9496,6 +9611,18 @@ PacExpectationNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_
   exit(EXIT_FAILURE);
 }
 
+void
+PacExpectationNode::findConstantEquations(map<expr_t, expr_t> &table) const
+{
+  return;
+}
+
+expr_t
+PacExpectationNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+{
+  return const_cast<PacExpectationNode *>(this);
+}
+
 void
 PacExpectationNode::writeJsonAST(ostream &output) const
 {
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 71b1d34a..ef7bdfd8 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -615,6 +615,12 @@ class ExprNode
       virtual void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs,
                                           map<tuple<int, int, int>, expr_t> &EC) const = 0;
 
+      //! Finds equations where a variable is equal to a constant
+      virtual void findConstantEquations(map<expr_t, expr_t> &table) const = 0;
+
+      //! Replaces variables found in findConstantEquations() with their constant values
+      virtual expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const = 0;
+
       //! Returns true if PacExpectationNode encountered
       virtual bool containsPacExpectation(const string &pac_model_name = "") const = 0;
 
@@ -730,6 +736,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
+  void findConstantEquations(map<expr_t, expr_t> &table) const override;
+  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -820,6 +828,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
+  void findConstantEquations(map<expr_t, expr_t> &table) const override;
+  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -938,6 +948,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
+  void findConstantEquations(map<expr_t, expr_t> &table) const override;
+  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -1074,6 +1086,8 @@ public:
                                     int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs,
                                     map<tuple<int, int, int>, expr_t> &AR) const;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
+  void findConstantEquations(map<expr_t, expr_t> &table) const override;
+  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -1189,6 +1203,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
+  void findConstantEquations(map<expr_t, expr_t> &table) const override;
+  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -1316,6 +1332,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
+  void findConstantEquations(map<expr_t, expr_t> &table) const override;
+  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -1531,6 +1549,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
+  void findConstantEquations(map<expr_t, expr_t> &table) const override;
+  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -1632,6 +1652,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
+  void findConstantEquations(map<expr_t, expr_t> &table) const override;
+  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
diff --git a/src/ModFile.cc b/src/ModFile.cc
index 8557d682..81f49912 100644
--- a/src/ModFile.cc
+++ b/src/ModFile.cc
@@ -370,6 +370,8 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
   dynamic_model.setLeadsLagsOrig();
   original_model = dynamic_model;
 
+  dynamic_model.simplifyEquations();
+
   if (nostrict)
     {
       set<int> unusedEndogs = dynamic_model.findUnusedEndogenous();
diff --git a/src/ModelTree.cc b/src/ModelTree.cc
index 71e1a285..29486fa3 100644
--- a/src/ModelTree.cc
+++ b/src/ModelTree.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2003-2018 Dynare Team
+ * Copyright (C) 2003-2019 Dynare Team
  *
  * This file is part of Dynare.
  *
@@ -1921,6 +1921,29 @@ ModelTree::addEquation(expr_t eq, int lineno)
   equations_lineno.push_back(lineno);
 }
 
+void
+ModelTree::simplifyEquations()
+{
+  size_t last_subst_table_size = 0;
+  map<expr_t, expr_t> subst_table;
+  findConstantEquations(subst_table);
+  while (subst_table.size() != last_subst_table_size)
+    {
+      last_subst_table_size = subst_table.size();
+      for (auto & equation : equations)
+        equation = dynamic_cast<BinaryOpNode *>(equation->replaceVarsInEquation(subst_table));
+      subst_table.clear();
+      findConstantEquations(subst_table);
+    }
+}
+
+void
+ModelTree::findConstantEquations(map<expr_t, expr_t> &subst_table) const
+{
+  for (auto & equation : equations)
+    equation->findConstantEquations(subst_table);
+}
+
 void
 ModelTree::addEquation(expr_t eq, int lineno, const vector<pair<string, string>> &eq_tags)
 {
diff --git a/src/ModelTree.hh b/src/ModelTree.hh
index 2eb3c2e0..95af4faa 100644
--- a/src/ModelTree.hh
+++ b/src/ModelTree.hh
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2003-2018 Dynare Team
+ * Copyright (C) 2003-2019 Dynare Team
  *
  * This file is part of Dynare.
  *
@@ -348,6 +348,12 @@ public:
   void set_cutoff_to_zero();
   //! Helper for writing the Jacobian elements in MATLAB and C
   /*! Writes either (i+1,j+1) or [i+j*no_eq] */
+
+  //! Simplify model equations: if a variable is equal to a constant, replace that variable elsewhere in the model
+  void simplifyEquations();
+  //! Find equations where variable is equal to a constant
+  void findConstantEquations(map<expr_t, expr_t> &subst_table) const;
+
   void jacobianHelper(ostream &output, int eq_nb, int col_nb, ExprNodeOutputType output_type) const;
   //! Helper for writing the sparse Hessian or third derivatives in MATLAB and C
   /*! If order=2, writes either v2(i+1,j+1) or v2[i+j*NNZDerivatives[2]]
-- 
GitLab