diff --git a/src/DataTree.cc b/src/DataTree.cc index df054eb50b89741469b04d5f7cb44c9461c91325..5649aa0ea2076f3de8ff4632b0c4c8f890bc237f 100644 --- a/src/DataTree.cc +++ b/src/DataTree.cc @@ -98,6 +98,11 @@ DataTree::operator=(const DataTree &d) // Constants must be initialized first because they are used in some Add* methods initConstants(); + /* Model local variables must be next, because they can be evaluated in Add* + methods when the model equations are added */ + for (const auto &it : d.local_variables_table) + local_variables_table[it.first] = it.second->clone(*this); + for (const auto &it : d.node_list) it->clone(*this); @@ -105,9 +110,6 @@ DataTree::operator=(const DataTree &d) local_variables_vector = d.local_variables_vector; - for (const auto &it : d.local_variables_table) - local_variables_table[it.first] = it.second->clone(*this); - return *this; } diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 7d2814be28de5b735c9129cd1aa52d5c50062b66..1c0725243b881367e78c0c83cf321bd774f76e28 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -859,6 +859,9 @@ VariableNode::computeDerivative(int deriv_id) bool VariableNode::containsExternalFunction() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->containsExternalFunction(); + return false; } @@ -1203,6 +1206,9 @@ VariableNode::substituteStaticAuxiliaryVariable() const double VariableNode::eval(const eval_context_t &eval_context) const noexcept(false) { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->eval(eval_context); + auto it = eval_context.find(symb_id); if (it == eval_context.end()) throw EvalException(); @@ -1311,6 +1317,8 @@ VariableNode::computeSubExprContainingVariable(int symb_id_arg, int lag_arg, set { if (symb_id == symb_id_arg && lag == lag_arg) contain_var.insert(const_cast<VariableNode*>(this)); + if (get_type() == SymbolType::modelLocalVariable) + datatree.getLocalVariable(symb_id)->computeSubExprContainingVariable(symb_id_arg, lag_arg, contain_var); } BinaryOpNode * @@ -1318,6 +1326,9 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs { assert(contain_var.count(const_cast<VariableNode *>(this)) > 0); + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->normalizeEquationHelper(contain_var, rhs); + // This the LHS variable: we have finished the normalization return datatree.AddEqual(const_cast<VariableNode *>(this), rhs); } @@ -1391,10 +1402,12 @@ VariableNode::computeXrefs(EquationInfo &ei) const case SymbolType::parameter: ei.param.emplace(symb_id, 0); break; + case SymbolType::modFileLocalVariable: + datatree.getLocalVariable(symb_id)->computeXrefs(ei); + break; case SymbolType::trend: case SymbolType::logTrend: case SymbolType::modelLocalVariable: - case SymbolType::modFileLocalVariable: case SymbolType::statementDeclaredVariable: case SymbolType::unusedEndogenous: case SymbolType::externalFunction: @@ -1559,6 +1572,9 @@ VariableNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const int VariableNode::PacMaxLag(int lhs_symb_id) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->PacMaxLag(lhs_symb_id); + if (lhs_symb_id == symb_id) return -lag; return 0; @@ -1567,28 +1583,41 @@ VariableNode::PacMaxLag(int lhs_symb_id) const expr_t VariableNode::substituteAdl() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->substituteAdl(); + return const_cast<VariableNode *>(this); } expr_t VariableNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->substituteVarExpectation(subst_table); + return const_cast<VariableNode *>(this); } void VariableNode::findDiffNodes(lag_equivalence_table_t &nodes) const { + if (get_type() == SymbolType::modelLocalVariable) + datatree.getLocalVariable(symb_id)->findDiffNodes(nodes); } void VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const { + if (get_type() == SymbolType::modelLocalVariable) + datatree.getLocalVariable(symb_id)->findUnaryOpNodesForAuxVarCreation(nodes); } int VariableNode::findTargetVariable(int lhs_symb_id) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->findTargetVariable(lhs_symb_id); + return -1; } @@ -1596,18 +1625,27 @@ expr_t VariableNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->substituteDiff(nodes, subst_table, neweqs); + return const_cast<VariableNode *>(this); } expr_t VariableNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->substituteUnaryOpNodes(nodes, subst_table, neweqs); + return const_cast<VariableNode *>(this); } expr_t VariableNode::substitutePacExpectation(const string &name, expr_t subexpr) { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->substitutePacExpectation(name, subexpr); + return const_cast<VariableNode *>(this); } @@ -1632,6 +1670,9 @@ VariableNode::decreaseLeadsLags(int n) const expr_t VariableNode::decreaseLeadsLagsPredeterminedVariables() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->decreaseLeadsLagsPredeterminedVariables(); + if (datatree.symbol_table.isPredetermined(symb_id)) return decreaseLeadsLags(1); else @@ -1773,6 +1814,9 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode * expr_t VariableNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->substituteExpectation(subst_table, neweqs, partial_information_model); + return const_cast<VariableNode *>(this); } @@ -1831,12 +1875,18 @@ VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int la bool VariableNode::containsPacExpectation(const string &pac_model_name) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->containsPacExpectation(pac_model_name); + return false; } bool VariableNode::containsEndogenous() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->containsEndogenous(); + if (get_type() == SymbolType::endogenous) return true; else @@ -1846,12 +1896,18 @@ VariableNode::containsEndogenous() const bool VariableNode::containsExogenous() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->containsExogenous(); + return get_type() == SymbolType::exogenous || get_type() == SymbolType::exogenousDet; } expr_t VariableNode::replaceTrendVar() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->replaceTrendVar(); + if (get_type() == SymbolType::trend) return datatree.One; else if (get_type() == SymbolType::logTrend) @@ -1863,6 +1919,9 @@ VariableNode::replaceTrendVar() const expr_t VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->detrend(symb_id, log_trend, trend); + if (this->symb_id != symb_id) return const_cast<VariableNode *>(this); @@ -1885,12 +1944,18 @@ VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const int VariableNode::countDiffs() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->countDiffs(); + return 0; } expr_t VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->removeTrendLeadLag(trend_symbols_map); + if ((get_type() != SymbolType::trend && get_type() != SymbolType::logTrend) || lag == 0) return const_cast<VariableNode *>(this); @@ -1936,24 +2001,36 @@ VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) cons bool VariableNode::isInStaticForm() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->isInStaticForm(); + return lag == 0; } bool VariableNode::isParamTimesEndogExpr() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->isParamTimesEndogExpr(); + return false; } bool VariableNode::isVarModelReferenced(const string &model_info_name) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->isVarModelReferenced(model_info_name); + return false; } void VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->getEndosAndMaxLags(model_endos_and_lags); + if (get_type() == SymbolType::endogenous) if (string varname = datatree.symbol_table.getName(symb_id); model_endos_and_lags.find(varname) == model_endos_and_lags.end()) @@ -1971,6 +2048,9 @@ VariableNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) expr_t VariableNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->replaceVarsInEquation(table); + for (auto &it : table) if (it.first->symb_id == symb_id) return it.second; @@ -1980,6 +2060,9 @@ VariableNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) void VariableNode::matchMatchedMoment(vector<int> &symb_ids, vector<int> &lags, vector<int> &powers) const { + /* Used for simple expression outside model block, so no need to special-case + model local variables */ + if (get_type() != SymbolType::endogenous) throw MatchFailureException{"Variable " + datatree.symbol_table.getName(symb_id) + " is not an endogenous"};