diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 98c0e910e4d8f4cbea2edae1e50249599a866d70..5282909aaa62a4e27c7615be6667f5bf115ed6a2 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -3466,23 +3466,46 @@ DynamicModel::updateVarAndTrendModelRhs() const else if (i == 1) eqnums = trend_component_model_table.getEqNums(); + map<string, vector<int>> trend_varr; map<string, vector<set<pair<int, int>>>> rhsr; for (const auto & it : eqnums) { + vector<int> lhs; + vector<int> trend_var; vector<set<pair<int, int>>> rhs; + int lhs_idx = 0; + if (i == 1) + lhs = trend_component_model_table.getLhs(it.first); for (auto eqn : it.second) { set<pair<int, int>> rhs_set; equations[eqn]->get_arg2()->collectDynamicVariables(SymbolType::endogenous, rhs_set); rhs.push_back(rhs_set); + if (i == 1) + { + int lhs_symb_id = lhs[lhs_idx++]; + if (symbol_table.isAuxiliaryVariable(lhs_symb_id)) + try + { + lhs_symb_id = symbol_table.getOrigSymbIdForAuxVar(lhs_symb_id); + } + catch (...) + { + } + trend_var.push_back(equations[eqn]->get_arg2()->findTrendVariable(lhs_symb_id)); + } } rhsr[it.first] = rhs; + trend_varr[it.first] = trend_var; } if (i == 0) var_model_table.setRhs(rhsr); else if (i == 1) - trend_component_model_table.setRhs(rhsr); + { + trend_component_model_table.setRhs(rhsr); + trend_component_model_table.setTrendVar(trend_varr); + } } } diff --git a/src/ExprNode.cc b/src/ExprNode.cc index be3e9387e71bb4cc02432fe417b500404213d0ce..c5c47f247a20e0779109650585d0a3487f0fabe3 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -563,6 +563,12 @@ NumConstNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_ { } +int +NumConstNode::findTrendVariable(int lhs_symb_id) const +{ + return -1; +} + expr_t NumConstNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const { @@ -1458,6 +1464,12 @@ VariableNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_ { } +int +VariableNode::findTrendVariable(int lhs_symb_id) const +{ + return -1; +} + expr_t VariableNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const @@ -3157,6 +3169,12 @@ UnaryOpNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) diff_table[sthis][arg_max_lag] = const_cast<UnaryOpNode *>(this); } +int +UnaryOpNode::findTrendVariable(int lhs_symb_id) const +{ + return arg->findTrendVariable(lhs_symb_id); +} + expr_t UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const @@ -5078,6 +5096,65 @@ BinaryOpNode::isInStaticForm() const return arg1->isInStaticForm() && arg2->isInStaticForm(); } +bool +BinaryOpNode::findTrendVariableHelper1(int lhs_symb_id, int rhs_symb_id) const +{ + if (lhs_symb_id == rhs_symb_id) + return true; + + try + { + if (datatree.symbol_table.isAuxiliaryVariable(rhs_symb_id) + && lhs_symb_id == datatree.symbol_table.getOrigSymbIdForAuxVar(rhs_symb_id)) + return true; + } + catch (...) + { + } + return false; +} + +int +BinaryOpNode::findTrendVariableHelper(const expr_t arg1, const expr_t arg2, + int lhs_symb_id) const +{ + set<int> params; + arg1->collectVariables(SymbolType::parameter, params); + if (params.size() != 1) + return -1; + + set<pair<int, int>> endogs; + arg2->collectDynamicVariables(SymbolType::endogenous, endogs); + if (endogs.size() == 2) + { + auto *testarg2 = dynamic_cast<BinaryOpNode *>(arg2); + if (testarg2 != nullptr && testarg2->get_op_code() == BinaryOpcode::minus) + { + auto *test_arg1 = dynamic_cast<VariableNode *>(testarg2->get_arg1()); + auto *test_arg2 = dynamic_cast<VariableNode *>(testarg2->get_arg2()); + if (test_arg1 != nullptr && test_arg2 != nullptr ) + if (findTrendVariableHelper1(lhs_symb_id, endogs.begin()->first)) + return endogs.rbegin()->first; + else if (findTrendVariableHelper1(lhs_symb_id, endogs.rbegin()->first)) + return endogs.begin()->first; + } + } + return -1; +} + +int +BinaryOpNode::findTrendVariable(int lhs_symb_id) const +{ + int retval = findTrendVariableHelper(arg1, arg2, lhs_symb_id); + if (retval < 0) + retval = findTrendVariableHelper(arg2, arg1, lhs_symb_id); + if (retval < 0) + retval = arg1->findTrendVariable(lhs_symb_id); + if (retval < 0) + retval = arg2->findTrendVariable(lhs_symb_id); + return retval; +} + void BinaryOpNode::getPacOptimizingPartHelper(const expr_t arg1, const expr_t arg2, pair<int, vector<int>> &ec_params_and_vars, @@ -6073,6 +6150,17 @@ TrinaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff arg3->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes); } +int +TrinaryOpNode::findTrendVariable(int lhs_symb_id) const +{ + int retval = arg1->findTrendVariable(lhs_symb_id); + if (retval < 0) + retval = arg2->findTrendVariable(lhs_symb_id); + if (retval < 0) + retval = arg3->findTrendVariable(lhs_symb_id); + return retval; +} + expr_t TrinaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const @@ -6535,6 +6623,18 @@ AbstractExternalFunctionNode::findUnaryOpNodesForAuxVarCreation(DataTree &static argument->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes); } +int +AbstractExternalFunctionNode::findTrendVariable(int lhs_symb_id) const +{ + for (auto argument : arguments) + { + int retval = argument->findTrendVariable(lhs_symb_id); + if (retval >= 0) + return retval; + } + return -1; +} + expr_t AbstractExternalFunctionNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const @@ -8168,6 +8268,12 @@ VarExpectationNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, { } +int +VarExpectationNode::findTrendVariable(int lhs_symb_id) const +{ + return -1; +} + expr_t VarExpectationNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const @@ -8689,6 +8795,12 @@ PacExpectationNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, { } +int +PacExpectationNode::findTrendVariable(int lhs_symb_id) const +{ + return -1; +} + expr_t PacExpectationNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 7346250a8f5395ff0598f612d75ff10ff885719e..2e7363e203d700644fe0afebb7b94b327a409c51 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -505,6 +505,7 @@ class ExprNode //! Substitute diff operator virtual void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const = 0; virtual void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const = 0; + virtual int findTrendVariable(int lhs_symb_id) const = 0; virtual expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0; virtual expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0; @@ -617,6 +618,7 @@ public: expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override; void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override; void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override; + int findTrendVariable(int lhs_symb_id) const override; expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override; @@ -715,6 +717,7 @@ public: expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override; void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override; void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override; + int findTrendVariable(int lhs_symb_id) const override; expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override; @@ -837,6 +840,7 @@ public: void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override; bool createAuxVarForUnaryOpNode() const; void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override; + int findTrendVariable(int lhs_symb_id) const override; expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override; @@ -980,6 +984,9 @@ public: expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override; void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override; void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override; + bool findTrendVariableHelper1(int lhs_symb_id, int rhs_symb_id) const; + int findTrendVariableHelper(const expr_t arg1, const expr_t arg2, int lhs_symb_id) const; + int findTrendVariable(int lhs_symb_id) const override; expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override; @@ -1092,6 +1099,7 @@ public: expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override; void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override; void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override; + int findTrendVariable(int lhs_symb_id) const override; expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override; @@ -1210,6 +1218,7 @@ public: expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override; void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override; void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override; + int findTrendVariable(int lhs_symb_id) const override; expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override; @@ -1412,6 +1421,7 @@ public: expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override; void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override; void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override; + int findTrendVariable(int lhs_symb_id) const override; expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override; @@ -1507,6 +1517,7 @@ public: expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override; void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override; void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override; + int findTrendVariable(int lhs_symb_id) const override; expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override; diff --git a/src/SubModel.cc b/src/SubModel.cc index e4786116235ada65dad5fe2caf34e246876447d2..62613b39f8d4d0268ddb26416bd5a37841c167ea 100644 --- a/src/SubModel.cc +++ b/src/SubModel.cc @@ -78,6 +78,12 @@ TrendComponentModelTable::setNonstationary(map<string, vector<bool>> nonstationa nonstationary = move(nonstationary_arg); } +void +TrendComponentModelTable::setTrendVar(map<string, vector<int>> trend_vars_arg) +{ + trend_vars = move(trend_vars_arg); +} + void TrendComponentModelTable::setLhs(map<string, vector<int>> lhs_arg) { @@ -275,6 +281,10 @@ TrendComponentModelTable::writeOutput(ostream &output) const i++; } + output << "M_.trend_component." << name << ".trend_vars = ["; + for (auto it : trend_vars.at(name)) + output << (it >= 0 ? symbol_table.getTypeSpecificID(it) + 1 : -1) << " "; + output << "];" << endl; } } diff --git a/src/SubModel.hh b/src/SubModel.hh index d8b691f0202b03dbb2f99145a448744f77c7b84f..e5cba518d4b5142b30ea5a89ff1431aae226c6b2 100644 --- a/src/SubModel.hh +++ b/src/SubModel.hh @@ -44,6 +44,7 @@ private: map<string, vector<set<pair<int, int>>>> rhs; map<string, vector<bool>> diff, nonstationary; map<string, vector<expr_t>> lhs_expr_t; + map<string, vector<int>> trend_vars; public: TrendComponentModelTable(SymbolTable &symbol_table_arg); @@ -77,6 +78,7 @@ public: void setDiff(map<string, vector<bool>> diff_arg); void setOrigDiffVar(map<string, vector<int>> orig_diff_var_arg); void setNonstationary(map<string, vector<bool>> nonstationary_arg); + void setTrendVar(map<string, vector<int>> trend_vars_arg); //! Write output of this class void writeOutput(ostream &output) const; diff --git a/src/SymbolTable.cc b/src/SymbolTable.cc index e5b3ed765b7f7dbba8aa19d527c19e743cd53b0b..722edfdb36010428b55b4cfe0a546ffe64fc3b00 100644 --- a/src/SymbolTable.cc +++ b/src/SymbolTable.cc @@ -865,7 +865,10 @@ int SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false) { for (const auto & aux_var : aux_vars) - if ((aux_var.get_type() == AuxVarType::endoLag || aux_var.get_type() == AuxVarType::exoLag || aux_var.get_type() == AuxVarType::diff) + if ((aux_var.get_type() == AuxVarType::endoLag + || aux_var.get_type() == AuxVarType::exoLag + || aux_var.get_type() == AuxVarType::diff + || aux_var.get_type() == AuxVarType::diffLag) && aux_var.get_symb_id() == aux_var_symb_id) return aux_var.get_orig_symb_id(); throw UnknownSymbolIDException(aux_var_symb_id);