diff --git a/src/DataTree.cc b/src/DataTree.cc index e2f6ed1ef0a82afd02d679bf7492f170c46e54b1..d479ecdbf0ea8d6da13c5f50acf1bf13f8876f26 100644 --- a/src/DataTree.cc +++ b/src/DataTree.cc @@ -98,6 +98,11 @@ DataTree::operator=(const DataTree &d) // Constants must be initialized first because they are used in some Add* methods initConstants(); + /* Model local variables must be next, because they can be evaluated in Add* + methods when the model equations are added */ + for (const auto &it : d.local_variables_table) + local_variables_table[it.first] = it.second->clone(*this); + for (const auto &it : d.node_list) it->clone(*this); @@ -105,9 +110,6 @@ DataTree::operator=(const DataTree &d) local_variables_vector = d.local_variables_vector; - for (const auto &it : d.local_variables_table) - local_variables_table[it.first] = it.second->clone(*this); - return *this; } diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 4d6db015955ee2aaa6873c6d5bdfcc2bc023c752..3358e73a8dd182b754292127bfb240477615c97b 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -887,6 +887,9 @@ VariableNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, t bool VariableNode::containsExternalFunction() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->containsExternalFunction(); + return false; } @@ -1245,6 +1248,9 @@ VariableNode::substituteStaticAuxiliaryVariable() const double VariableNode::eval(const eval_context_t &eval_context) const noexcept(false) { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->eval(eval_context); + auto it = eval_context.find(symb_id); if (it == eval_context.end()) throw EvalException(); @@ -1461,10 +1467,12 @@ VariableNode::computeXrefs(EquationInfo &ei) const case SymbolType::parameter: ei.param.emplace(symb_id, 0); break; + case SymbolType::modFileLocalVariable: + datatree.getLocalVariable(symb_id)->computeXrefs(ei); + break; case SymbolType::trend: case SymbolType::logTrend: case SymbolType::modelLocalVariable: - case SymbolType::modFileLocalVariable: case SymbolType::statementDeclaredVariable: case SymbolType::unusedEndogenous: case SymbolType::externalFunction: @@ -1629,6 +1637,9 @@ VariableNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const int VariableNode::PacMaxLag(int lhs_symb_id) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->PacMaxLag(lhs_symb_id); + if (lhs_symb_id == symb_id) return -lag; return 0; @@ -1643,28 +1654,41 @@ VariableNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const expr_t VariableNode::substituteAdl() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->substituteAdl(); + return const_cast<VariableNode *>(this); } expr_t VariableNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->substituteVarExpectation(subst_table); + return const_cast<VariableNode *>(this); } void VariableNode::findDiffNodes(lag_equivalence_table_t &nodes) const { + if (get_type() == SymbolType::modelLocalVariable) + datatree.getLocalVariable(symb_id)->findDiffNodes(nodes); } void VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const { + if (get_type() == SymbolType::modelLocalVariable) + datatree.getLocalVariable(symb_id)->findUnaryOpNodesForAuxVarCreation(nodes); } int VariableNode::findTargetVariable(int lhs_symb_id) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->findTargetVariable(lhs_symb_id); + return -1; } @@ -1672,18 +1696,27 @@ expr_t VariableNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->substituteDiff(nodes, subst_table, neweqs); + return const_cast<VariableNode *>(this); } expr_t VariableNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->substituteUnaryOpNodes(nodes, subst_table, neweqs); + return const_cast<VariableNode *>(this); } expr_t VariableNode::substitutePacExpectation(const string &name, expr_t subexpr) { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->substitutePacExpectation(name, subexpr); + return const_cast<VariableNode *>(this); } @@ -1708,6 +1741,9 @@ VariableNode::decreaseLeadsLags(int n) const expr_t VariableNode::decreaseLeadsLagsPredeterminedVariables() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->decreaseLeadsLagsPredeterminedVariables(); + if (datatree.symbol_table.isPredetermined(symb_id)) return decreaseLeadsLags(1); else @@ -1849,6 +1885,9 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode * expr_t VariableNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->substituteExpectation(subst_table, neweqs, partial_information_model); + return const_cast<VariableNode *>(this); } @@ -1907,12 +1946,18 @@ VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int la bool VariableNode::containsPacExpectation(const string &pac_model_name) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->containsPacExpectation(pac_model_name); + return false; } bool VariableNode::containsEndogenous() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->containsEndogenous(); + if (get_type() == SymbolType::endogenous) return true; else @@ -1922,12 +1967,18 @@ VariableNode::containsEndogenous() const bool VariableNode::containsExogenous() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->containsExogenous(); + return get_type() == SymbolType::exogenous || get_type() == SymbolType::exogenousDet; } expr_t VariableNode::replaceTrendVar() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->replaceTrendVar(); + if (get_type() == SymbolType::trend) return datatree.One; else if (get_type() == SymbolType::logTrend) @@ -1939,6 +1990,9 @@ VariableNode::replaceTrendVar() const expr_t VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->detrend(symb_id, log_trend, trend); + if (this->symb_id != symb_id) return const_cast<VariableNode *>(this); @@ -1961,12 +2015,18 @@ VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const int VariableNode::countDiffs() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->countDiffs(); + return 0; } expr_t VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->removeTrendLeadLag(trend_symbols_map); + if ((get_type() != SymbolType::trend && get_type() != SymbolType::logTrend) || lag == 0) return const_cast<VariableNode *>(this); @@ -2012,24 +2072,36 @@ VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) cons bool VariableNode::isInStaticForm() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->isInStaticForm(); + return lag == 0; } bool VariableNode::isParamTimesEndogExpr() const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->isParamTimesEndogExpr(); + return false; } bool VariableNode::isVarModelReferenced(const string &model_info_name) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->isVarModelReferenced(model_info_name); + return false; } void VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->getEndosAndMaxLags(model_endos_and_lags); + if (get_type() == SymbolType::endogenous) if (string varname = datatree.symbol_table.getName(symb_id); model_endos_and_lags.find(varname) == model_endos_and_lags.end()) @@ -2047,6 +2119,9 @@ VariableNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) expr_t VariableNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const { + if (get_type() == SymbolType::modelLocalVariable) + return datatree.getLocalVariable(symb_id)->replaceVarsInEquation(table); + for (auto &it : table) if (it.first->symb_id == symb_id) return it.second;