From f5df7e75671d2a8236d95d0d5b2f231761bff513 Mon Sep 17 00:00:00 2001 From: Houtan Bastani <houtan@dynare.org> Date: Mon, 28 Jan 2019 14:57:30 +0100 Subject: [PATCH] when an equation is of the form `X` = `constant`, replace all occurrences of `X` in other equations with `constant` --- src/ExprNode.cc | 127 +++++++++++++++++++++++++++++++++++++++++++++++ src/ExprNode.hh | 22 ++++++++ src/ModFile.cc | 2 + src/ModelTree.cc | 25 +++++++++- src/ModelTree.hh | 8 ++- 5 files changed, 182 insertions(+), 2 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index f6e52e48..b4e88c01 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -719,6 +719,18 @@ NumConstNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c { } +void +NumConstNode::findConstantEquations(map<expr_t, expr_t> &table) const +{ + return; +} + +expr_t +NumConstNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const +{ + return const_cast<NumConstNode *>(this); +} + VariableNode::VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg) : ExprNode{datatree_arg, idx_arg}, symb_id{symb_id_arg}, @@ -2023,6 +2035,21 @@ VariableNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c { } +void +VariableNode::findConstantEquations(map<expr_t, expr_t> &table) const +{ + return; +} + +expr_t +VariableNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const +{ + for (auto & it : table) + if (dynamic_cast<VariableNode *>(it.first)->symb_id == symb_id) + return dynamic_cast<NumConstNode *>(it.second); + return const_cast<VariableNode *>(this); +} + UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, int idx_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, string adl_param_name_arg, vector<int> adl_lags_arg) : ExprNode{datatree_arg, idx_arg}, arg{arg_arg}, @@ -3847,6 +3874,19 @@ UnaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, co arg->fillErrorCorrectionRow(eqn, nontrend_lhs, trend_lhs, EC); } +void +UnaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const +{ + arg->findConstantEquations(table); +} + +expr_t +UnaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const +{ + expr_t argsubst = arg->replaceVarsInEquation(table); + return buildSimilarUnaryOpNode(argsubst, datatree); +} + BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, int idx_arg, const expr_t arg1_arg, BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder_arg) : ExprNode{datatree_arg, idx_arg}, @@ -5813,6 +5853,36 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2, EC[make_tuple(eqn, -max_lag, colidx)] = arg1; } +void +BinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const +{ + if (op_code == BinaryOpcode::equal) + if (dynamic_cast<VariableNode *>(arg1) != nullptr && dynamic_cast<NumConstNode *>(arg2) != nullptr) + table[arg1] = arg2; + else if (dynamic_cast<VariableNode *>(arg2) != nullptr && dynamic_cast<NumConstNode *>(arg1) != nullptr) + table[arg2] = arg1; + else + { + arg1->findConstantEquations(table); + arg2->findConstantEquations(table); + } +} + +expr_t +BinaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const +{ + if (op_code == BinaryOpcode::equal) + for (auto & it : table) + if ((dynamic_cast<VariableNode *>(it.first) == arg1 + && dynamic_cast<NumConstNode *>(it.second) == arg2) + || (dynamic_cast<VariableNode *>(it.first) == arg2 + && dynamic_cast<NumConstNode *>(it.second) == arg1)) + return const_cast<BinaryOpNode *>(this); + expr_t arg1subst = arg1->replaceVarsInEquation(table); + expr_t arg2subst = arg2->replaceVarsInEquation(table); + return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree); +} + void BinaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const { @@ -6817,6 +6887,23 @@ TrinaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, arg3->fillErrorCorrectionRow(eqn, nontrend_lhs, trend_lhs, EC); } +void +TrinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const +{ + arg1->findConstantEquations(table); + arg2->findConstantEquations(table); + arg3->findConstantEquations(table); +} + +expr_t +TrinaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const +{ + expr_t arg1subst = arg1->replaceVarsInEquation(table); + expr_t arg2subst = arg2->replaceVarsInEquation(table); + expr_t arg3subst = arg3->replaceVarsInEquation(table); + return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree); +} + AbstractExternalFunctionNode::AbstractExternalFunctionNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, @@ -7462,6 +7549,22 @@ AbstractExternalFunctionNode::fillErrorCorrectionRow(int eqn, const vector<int> exit(EXIT_FAILURE); } +void +AbstractExternalFunctionNode::findConstantEquations(map<expr_t, expr_t> &table) const +{ + for (auto argument : arguments) + argument->findConstantEquations(table); +} + +expr_t +AbstractExternalFunctionNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const +{ + vector<expr_t> arguments_subst; + for (auto argument : arguments) + arguments_subst.push_back(argument->replaceVarsInEquation(table)); + return buildSimilarExternalFunctionNode(arguments_subst, datatree); +} + ExternalFunctionNode::ExternalFunctionNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, @@ -8982,6 +9085,18 @@ VarExpectationNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_ exit(EXIT_FAILURE); } +void +VarExpectationNode::findConstantEquations(map<expr_t, expr_t> &table) const +{ + return; +} + +expr_t +VarExpectationNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const +{ + return const_cast<VarExpectationNode *>(this); +} + void VarExpectationNode::writeJsonAST(ostream &output) const { @@ -9496,6 +9611,18 @@ PacExpectationNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_ exit(EXIT_FAILURE); } +void +PacExpectationNode::findConstantEquations(map<expr_t, expr_t> &table) const +{ + return; +} + +expr_t +PacExpectationNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const +{ + return const_cast<PacExpectationNode *>(this); +} + void PacExpectationNode::writeJsonAST(ostream &output) const { diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 71b1d34a..ef7bdfd8 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -615,6 +615,12 @@ class ExprNode virtual void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const = 0; + //! Finds equations where a variable is equal to a constant + virtual void findConstantEquations(map<expr_t, expr_t> &table) const = 0; + + //! Replaces variables found in findConstantEquations() with their constant values + virtual expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const = 0; + //! Returns true if PacExpectationNode encountered virtual bool containsPacExpectation(const string &pac_model_name = "") const = 0; @@ -730,6 +736,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override; + void findConstantEquations(map<expr_t, expr_t> &table) const override; + expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars, set<pair<int, pair<int, int>>> ¶ms_and_vars) const override; @@ -820,6 +828,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override; + void findConstantEquations(map<expr_t, expr_t> &table) const override; + expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars, set<pair<int, pair<int, int>>> ¶ms_and_vars) const override; @@ -938,6 +948,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override; + void findConstantEquations(map<expr_t, expr_t> &table) const override; + expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars, set<pair<int, pair<int, int>>> ¶ms_and_vars) const override; @@ -1074,6 +1086,8 @@ public: int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &AR) const; void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override; + void findConstantEquations(map<expr_t, expr_t> &table) const override; + expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars, set<pair<int, pair<int, int>>> ¶ms_and_vars) const override; @@ -1189,6 +1203,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override; + void findConstantEquations(map<expr_t, expr_t> &table) const override; + expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars, set<pair<int, pair<int, int>>> ¶ms_and_vars) const override; @@ -1316,6 +1332,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override; + void findConstantEquations(map<expr_t, expr_t> &table) const override; + expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars, set<pair<int, pair<int, int>>> ¶ms_and_vars) const override; @@ -1531,6 +1549,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override; + void findConstantEquations(map<expr_t, expr_t> &table) const override; + expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars, set<pair<int, pair<int, int>>> ¶ms_and_vars) const override; @@ -1632,6 +1652,8 @@ public: void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override; void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override; void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override; + void findConstantEquations(map<expr_t, expr_t> &table) const override; + expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override; bool containsPacExpectation(const string &pac_model_name = "") const override; void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars, set<pair<int, pair<int, int>>> ¶ms_and_vars) const override; diff --git a/src/ModFile.cc b/src/ModFile.cc index 8557d682..81f49912 100644 --- a/src/ModFile.cc +++ b/src/ModFile.cc @@ -370,6 +370,8 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const dynamic_model.setLeadsLagsOrig(); original_model = dynamic_model; + dynamic_model.simplifyEquations(); + if (nostrict) { set<int> unusedEndogs = dynamic_model.findUnusedEndogenous(); diff --git a/src/ModelTree.cc b/src/ModelTree.cc index 71e1a285..29486fa3 100644 --- a/src/ModelTree.cc +++ b/src/ModelTree.cc @@ -1,5 +1,5 @@ /* - * Copyright (C) 2003-2018 Dynare Team + * Copyright (C) 2003-2019 Dynare Team * * This file is part of Dynare. * @@ -1921,6 +1921,29 @@ ModelTree::addEquation(expr_t eq, int lineno) equations_lineno.push_back(lineno); } +void +ModelTree::simplifyEquations() +{ + size_t last_subst_table_size = 0; + map<expr_t, expr_t> subst_table; + findConstantEquations(subst_table); + while (subst_table.size() != last_subst_table_size) + { + last_subst_table_size = subst_table.size(); + for (auto & equation : equations) + equation = dynamic_cast<BinaryOpNode *>(equation->replaceVarsInEquation(subst_table)); + subst_table.clear(); + findConstantEquations(subst_table); + } +} + +void +ModelTree::findConstantEquations(map<expr_t, expr_t> &subst_table) const +{ + for (auto & equation : equations) + equation->findConstantEquations(subst_table); +} + void ModelTree::addEquation(expr_t eq, int lineno, const vector<pair<string, string>> &eq_tags) { diff --git a/src/ModelTree.hh b/src/ModelTree.hh index 2eb3c2e0..95af4faa 100644 --- a/src/ModelTree.hh +++ b/src/ModelTree.hh @@ -1,5 +1,5 @@ /* - * Copyright (C) 2003-2018 Dynare Team + * Copyright (C) 2003-2019 Dynare Team * * This file is part of Dynare. * @@ -348,6 +348,12 @@ public: void set_cutoff_to_zero(); //! Helper for writing the Jacobian elements in MATLAB and C /*! Writes either (i+1,j+1) or [i+j*no_eq] */ + + //! Simplify model equations: if a variable is equal to a constant, replace that variable elsewhere in the model + void simplifyEquations(); + //! Find equations where variable is equal to a constant + void findConstantEquations(map<expr_t, expr_t> &subst_table) const; + void jacobianHelper(ostream &output, int eq_nb, int col_nb, ExprNodeOutputType output_type) const; //! Helper for writing the sparse Hessian or third derivatives in MATLAB and C /*! If order=2, writes either v2(i+1,j+1) or v2[i+j*NNZDerivatives[2]] -- GitLab