From 02ae0af3e51c9d14cdc9e0eb9d7ee92c8bc71391 Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Tue, 29 Jan 2019 17:29:24 +0100
Subject: [PATCH] change map type for readability

---
 src/ExprNode.cc  | 40 ++++++++++++++++++++--------------------
 src/ExprNode.hh  | 37 +++++++++++++++++++------------------
 src/ModelTree.cc |  4 ++--
 src/ModelTree.hh |  2 +-
 4 files changed, 42 insertions(+), 41 deletions(-)

diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 556f0efd..870be77f 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -714,13 +714,13 @@ NumConstNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c
 }
 
 void
-NumConstNode::findConstantEquations(map<expr_t, expr_t> &table) const
+NumConstNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
 {
   return;
 }
 
 expr_t
-NumConstNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+NumConstNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
 {
   return const_cast<NumConstNode *>(this);
 }
@@ -2024,17 +2024,17 @@ VariableNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c
 }
 
 void
-VariableNode::findConstantEquations(map<expr_t, expr_t> &table) const
+VariableNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
 {
   return;
 }
 
 expr_t
-VariableNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+VariableNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
 {
   for (auto & it : table)
-    if (dynamic_cast<VariableNode *>(it.first)->symb_id == symb_id)
-      return dynamic_cast<NumConstNode *>(it.second);
+    if (it.first->symb_id == symb_id)
+      return it.second;
   return const_cast<VariableNode *>(this);
 }
 
@@ -3857,13 +3857,13 @@ UnaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, co
 }
 
 void
-UnaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
+UnaryOpNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
 {
   arg->findConstantEquations(table);
 }
 
 expr_t
-UnaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+UnaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
 {
   expr_t argsubst =  arg->replaceVarsInEquation(table);
   return buildSimilarUnaryOpNode(argsubst, datatree);
@@ -5830,13 +5830,13 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
 }
 
 void
-BinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
+BinaryOpNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
 {
   if (op_code == BinaryOpcode::equal)
     if (dynamic_cast<VariableNode *>(arg1) != nullptr && dynamic_cast<NumConstNode *>(arg2) != nullptr)
-      table[arg1] = arg2;
+      table[dynamic_cast<VariableNode *>(arg1)] = dynamic_cast<NumConstNode *>(arg2);
     else if (dynamic_cast<VariableNode *>(arg2) != nullptr && dynamic_cast<NumConstNode *>(arg1) != nullptr)
-      table[arg2] = arg1;
+      table[dynamic_cast<VariableNode *>(arg2)] = dynamic_cast<NumConstNode *>(arg1);
   else
     {
       arg1->findConstantEquations(table);
@@ -5845,7 +5845,7 @@ BinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
 }
 
 expr_t
-BinaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+BinaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
 {
   if (op_code == BinaryOpcode::equal)
     for (auto & it : table)
@@ -6858,7 +6858,7 @@ TrinaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs,
 }
 
 void
-TrinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
+TrinaryOpNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
 {
   arg1->findConstantEquations(table);
   arg2->findConstantEquations(table);
@@ -6866,7 +6866,7 @@ TrinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
 }
 
 expr_t
-TrinaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+TrinaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
 {
   expr_t arg1subst = arg1->replaceVarsInEquation(table);
   expr_t arg2subst = arg2->replaceVarsInEquation(table);
@@ -7511,14 +7511,14 @@ AbstractExternalFunctionNode::fillErrorCorrectionRow(int eqn, const vector<int>
 }
 
 void
-AbstractExternalFunctionNode::findConstantEquations(map<expr_t, expr_t> &table) const
+AbstractExternalFunctionNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
 {
   for (auto argument : arguments)
     argument->findConstantEquations(table);
 }
 
 expr_t
-AbstractExternalFunctionNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+AbstractExternalFunctionNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
 {
   vector<expr_t> arguments_subst;
   for (auto argument : arguments)
@@ -9040,13 +9040,13 @@ VarExpectationNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_
 }
 
 void
-VarExpectationNode::findConstantEquations(map<expr_t, expr_t> &table) const
+VarExpectationNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
 {
   return;
 }
 
 expr_t
-VarExpectationNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+VarExpectationNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
 {
   return const_cast<VarExpectationNode *>(this);
 }
@@ -9560,13 +9560,13 @@ PacExpectationNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_
 }
 
 void
-PacExpectationNode::findConstantEquations(map<expr_t, expr_t> &table) const
+PacExpectationNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
 {
   return;
 }
 
 expr_t
-PacExpectationNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
+PacExpectationNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
 {
   return const_cast<PacExpectationNode *>(this);
 }
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index df57b822..ff84420c 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -34,6 +34,7 @@ using namespace std;
 #include "SymbolList.hh"
 
 class DataTree;
+class NumConstNode;
 class VariableNode;
 class UnaryOpNode;
 class BinaryOpNode;
@@ -613,10 +614,10 @@ class ExprNode
                                           map<tuple<int, int, int>, expr_t> &EC) const = 0;
 
       //! Finds equations where a variable is equal to a constant
-      virtual void findConstantEquations(map<expr_t, expr_t> &table) const = 0;
+      virtual void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const = 0;
 
       //! Replaces variables found in findConstantEquations() with their constant values
-      virtual expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const = 0;
+      virtual expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const = 0;
 
       //! Returns true if PacExpectationNode encountered
       virtual bool containsPacExpectation(const string &pac_model_name = "") const = 0;
@@ -732,8 +733,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
-  void findConstantEquations(map<expr_t, expr_t> &table) const override;
-  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
+  void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
+  expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -823,8 +824,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
-  void findConstantEquations(map<expr_t, expr_t> &table) const override;
-  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
+  void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
+  expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -942,8 +943,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
-  void findConstantEquations(map<expr_t, expr_t> &table) const override;
-  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
+  void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
+  expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -1079,8 +1080,8 @@ public:
                                     int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs,
                                     map<tuple<int, int, int>, expr_t> &AR) const;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
-  void findConstantEquations(map<expr_t, expr_t> &table) const override;
-  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
+  void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
+  expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -1195,8 +1196,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
-  void findConstantEquations(map<expr_t, expr_t> &table) const override;
-  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
+  void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
+  expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -1323,8 +1324,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
-  void findConstantEquations(map<expr_t, expr_t> &table) const override;
-  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
+  void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
+  expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -1539,8 +1540,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
-  void findConstantEquations(map<expr_t, expr_t> &table) const override;
-  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
+  void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
+  expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
@@ -1641,8 +1642,8 @@ public:
   void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
   void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
   void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
-  void findConstantEquations(map<expr_t, expr_t> &table) const override;
-  expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
+  void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
+  expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
   bool containsPacExpectation(const string &pac_model_name = "") const override;
   void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
                             set<pair<int, pair<int, int>>> &params_and_vars) const override;
diff --git a/src/ModelTree.cc b/src/ModelTree.cc
index 29486fa3..136b39bb 100644
--- a/src/ModelTree.cc
+++ b/src/ModelTree.cc
@@ -1925,7 +1925,7 @@ void
 ModelTree::simplifyEquations()
 {
   size_t last_subst_table_size = 0;
-  map<expr_t, expr_t> subst_table;
+  map<VariableNode *, NumConstNode *> subst_table;
   findConstantEquations(subst_table);
   while (subst_table.size() != last_subst_table_size)
     {
@@ -1938,7 +1938,7 @@ ModelTree::simplifyEquations()
 }
 
 void
-ModelTree::findConstantEquations(map<expr_t, expr_t> &subst_table) const
+ModelTree::findConstantEquations(map<VariableNode *, NumConstNode *> &subst_table) const
 {
   for (auto & equation : equations)
     equation->findConstantEquations(subst_table);
diff --git a/src/ModelTree.hh b/src/ModelTree.hh
index 95af4faa..f607fb3e 100644
--- a/src/ModelTree.hh
+++ b/src/ModelTree.hh
@@ -352,7 +352,7 @@ public:
   //! Simplify model equations: if a variable is equal to a constant, replace that variable elsewhere in the model
   void simplifyEquations();
   //! Find equations where variable is equal to a constant
-  void findConstantEquations(map<expr_t, expr_t> &subst_table) const;
+  void findConstantEquations(map<VariableNode *, NumConstNode *> &subst_table) const;
 
   void jacobianHelper(ostream &output, int eq_nb, int col_nb, ExprNodeOutputType output_type) const;
   //! Helper for writing the sparse Hessian or third derivatives in MATLAB and C
-- 
GitLab