From 84d792bced340c52d90cd962e51e22eec605cea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org> Date: Wed, 5 Jun 2024 15:26:34 +0200 Subject: [PATCH] Allow model-local variables with leads or lags Ref. dynare#1929 --- src/DataTree.hh | 6 +-- src/ExprNode.cc | 108 ++++++++++++++++++++++--------------------- src/ParsingDriver.cc | 4 -- 3 files changed, 58 insertions(+), 60 deletions(-) diff --git a/src/DataTree.hh b/src/DataTree.hh index 8502f3a8..3a02b7eb 100644 --- a/src/DataTree.hh +++ b/src/DataTree.hh @@ -1,5 +1,5 @@ /* - * Copyright © 2003-2023 Dynare Team + * Copyright © 2003-2024 Dynare Team * * This file is part of Dynare. * @@ -360,13 +360,13 @@ public: }; [[nodiscard]] expr_t - getLocalVariable(int symb_id) const + getLocalVariable(int symb_id, int lead_lag) const { auto it = local_variables_table.find(symb_id); if (it == local_variables_table.end()) throw UnknownLocalVariableException {symb_id}; - return it->second; + return it->second->decreaseLeadsLags(-lead_lag); } static void diff --git a/src/ExprNode.cc b/src/ExprNode.cc index d475e5cf..13557d0a 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -880,9 +880,7 @@ VariableNode::VariableNode(DataTree& datatree_arg, int idx_arg, int symb_id_arg, // It makes sense to allow a lead/lag on parameters: during steady state calibration, endogenous // and parameters can be swapped assert(get_type() != SymbolType::externalFunction - && (lag == 0 - || (get_type() != SymbolType::modelLocalVariable - && get_type() != SymbolType::modFileLocalVariable))); + && (lag == 0 || get_type() != SymbolType::modFileLocalVariable)); } void @@ -909,9 +907,9 @@ VariableNode::prepareForDerivation() non_null_derivatives.insert(datatree.getDerivID(symb_id, lag)); break; case SymbolType::modelLocalVariable: - datatree.getLocalVariable(symb_id)->prepareForDerivation(); + datatree.getLocalVariable(symb_id, lag)->prepareForDerivation(); // Non null derivatives are those of the value of the local parameter - non_null_derivatives = datatree.getLocalVariable(symb_id)->non_null_derivatives; + non_null_derivatives = datatree.getLocalVariable(symb_id, lag)->non_null_derivatives; break; case SymbolType::modFileLocalVariable: case SymbolType::statementDeclaredVariable: @@ -967,7 +965,7 @@ VariableNode::prepareForChainRuleDerivation( break; case SymbolType::modelLocalVariable: { - expr_t def {datatree.getLocalVariable(symb_id)}; + expr_t def {datatree.getLocalVariable(symb_id, lag)}; // Non null derivatives are those of the value of the model local variable def->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives); non_null_chain_rule_derivatives.emplace(const_cast<VariableNode*>(this), @@ -1002,7 +1000,7 @@ VariableNode::computeDerivative(int deriv_id) else return datatree.Zero; case SymbolType::modelLocalVariable: - return datatree.getLocalVariable(symb_id)->getDerivative(deriv_id); + return datatree.getLocalVariable(symb_id, lag)->getDerivative(deriv_id); case SymbolType::modFileLocalVariable: cerr << "modFileLocalVariable is not derivable" << endl; exit(EXIT_FAILURE); @@ -1025,7 +1023,7 @@ bool VariableNode::containsExternalFunction() const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->containsExternalFunction(); + return datatree.getLocalVariable(symb_id, lag)->containsExternalFunction(); return false; } @@ -1158,8 +1156,8 @@ VariableNode::writeOutput(ostream& output, ExprNodeOutputType output_type, || output_type == ExprNodeOutputType::CDynamicSteadyStateOperator) { output << "("; - datatree.getLocalVariable(symb_id)->writeOutput(output, output_type, temporary_terms, - temporary_terms_idxs, tef_terms); + datatree.getLocalVariable(symb_id, lag) + ->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms); output << ")"; } else @@ -1422,7 +1420,7 @@ double VariableNode::eval(const eval_context_t& eval_context) const noexcept(false) { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->eval(eval_context); + return datatree.getLocalVariable(symb_id, lag)->eval(eval_context); auto it = eval_context.find(symb_id); if (it == eval_context.end()) @@ -1443,8 +1441,9 @@ VariableNode::writeBytecodeOutput(Bytecode::Writer& code_file, return; if (auto type = get_type(); type == SymbolType::modelLocalVariable) - datatree.getLocalVariable(symb_id)->writeBytecodeOutput(code_file, output_type, temporary_terms, - temporary_terms_idxs, tef_terms); + datatree.getLocalVariable(symb_id, lag) + ->writeBytecodeOutput(code_file, output_type, temporary_terms, temporary_terms_idxs, + tef_terms); else { int tsid = datatree.symbol_table.getTypeSpecificID(symb_id); @@ -1487,7 +1486,7 @@ VariableNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>>& if (get_type() == type_arg) result.emplace(symb_id, lag); if (get_type() == SymbolType::modelLocalVariable) - datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result); + datatree.getLocalVariable(symb_id, lag)->collectDynamicVariables(type_arg, result); } void @@ -1497,8 +1496,8 @@ VariableNode::computeSubExprContainingVariable(int symb_id_arg, int lag_arg, if (symb_id == symb_id_arg && lag == lag_arg) contain_var.insert(const_cast<VariableNode*>(this)); if (get_type() == SymbolType::modelLocalVariable) - datatree.getLocalVariable(symb_id)->computeSubExprContainingVariable(symb_id_arg, lag_arg, - contain_var); + datatree.getLocalVariable(symb_id, lag) + ->computeSubExprContainingVariable(symb_id_arg, lag_arg, contain_var); } BinaryOpNode* @@ -1507,7 +1506,7 @@ VariableNode::normalizeEquationHelper(const set<expr_t>& contain_var, expr_t rhs assert(contain_var.contains(const_cast<VariableNode*>(this))); if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->normalizeEquationHelper(contain_var, rhs); + return datatree.getLocalVariable(symb_id, lag)->normalizeEquationHelper(contain_var, rhs); // This the LHS variable: we have finished the normalization return datatree.AddEqual(const_cast<VariableNode*>(this), rhs); @@ -1541,8 +1540,9 @@ VariableNode::computeChainRuleDerivative( return datatree.Zero; case SymbolType::modelLocalVariable: - return datatree.getLocalVariable(symb_id)->getChainRuleDerivative( - deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache); + return datatree.getLocalVariable(symb_id, lag) + ->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, + cache); case SymbolType::modFileLocalVariable: cerr << "modFileLocalVariable is not derivable" << endl; exit(EXIT_FAILURE); @@ -1585,7 +1585,7 @@ VariableNode::computeXrefs(EquationInfo& ei) const ei.param.emplace(symb_id, 0); break; case SymbolType::modelLocalVariable: - datatree.getLocalVariable(symb_id)->computeXrefs(ei); + datatree.getLocalVariable(symb_id, lag)->computeXrefs(ei); break; case SymbolType::trend: case SymbolType::logTrend: @@ -1619,7 +1619,7 @@ VariableNode::maxEndoLead() const case SymbolType::endogenous: return max(lag, 0); case SymbolType::modelLocalVariable: - return datatree.getLocalVariable(symb_id)->maxEndoLead(); + return datatree.getLocalVariable(symb_id, lag)->maxEndoLead(); default: return 0; } @@ -1633,7 +1633,7 @@ VariableNode::maxExoLead() const case SymbolType::exogenous: return max(lag, 0); case SymbolType::modelLocalVariable: - return datatree.getLocalVariable(symb_id)->maxExoLead(); + return datatree.getLocalVariable(symb_id, lag)->maxExoLead(); default: return 0; } @@ -1647,7 +1647,7 @@ VariableNode::maxEndoLag() const case SymbolType::endogenous: return max(-lag, 0); case SymbolType::modelLocalVariable: - return datatree.getLocalVariable(symb_id)->maxEndoLag(); + return datatree.getLocalVariable(symb_id, lag)->maxEndoLag(); default: return 0; } @@ -1661,7 +1661,7 @@ VariableNode::maxExoLag() const case SymbolType::exogenous: return max(-lag, 0); case SymbolType::modelLocalVariable: - return datatree.getLocalVariable(symb_id)->maxExoLag(); + return datatree.getLocalVariable(symb_id, lag)->maxExoLag(); default: return 0; } @@ -1677,7 +1677,7 @@ VariableNode::maxLead() const case SymbolType::exogenousDet: return lag; case SymbolType::modelLocalVariable: - return datatree.getLocalVariable(symb_id)->maxLead(); + return datatree.getLocalVariable(symb_id, lag)->maxLead(); default: return 0; } @@ -1693,7 +1693,7 @@ VariableNode::maxLag() const case SymbolType::exogenousDet: return -lag; case SymbolType::modelLocalVariable: - return datatree.getLocalVariable(symb_id)->maxLag(); + return datatree.getLocalVariable(symb_id, lag)->maxLag(); default: return 0; } @@ -1710,7 +1710,7 @@ VariableNode::maxLagWithDiffsExpanded() const case SymbolType::epilogue: return -lag; case SymbolType::modelLocalVariable: - return datatree.getLocalVariable(symb_id)->maxLagWithDiffsExpanded(); + return datatree.getLocalVariable(symb_id, lag)->maxLagWithDiffsExpanded(); default: return 0; } @@ -1744,7 +1744,7 @@ expr_t VariableNode::substituteModelLocalVariables() const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id); + return datatree.getLocalVariable(symb_id, lag); return const_cast<VariableNode*>(this); } @@ -1753,7 +1753,7 @@ expr_t VariableNode::substituteVarExpectation(const map<string, expr_t>& subst_table) const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->substituteVarExpectation(subst_table); + return datatree.getLocalVariable(symb_id, lag)->substituteVarExpectation(subst_table); return const_cast<VariableNode*>(this); } @@ -1762,21 +1762,21 @@ void VariableNode::findDiffNodes(lag_equivalence_table_t& nodes) const { if (get_type() == SymbolType::modelLocalVariable) - datatree.getLocalVariable(symb_id)->findDiffNodes(nodes); + datatree.getLocalVariable(symb_id, lag)->findDiffNodes(nodes); } void VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t& nodes) const { if (get_type() == SymbolType::modelLocalVariable) - datatree.getLocalVariable(symb_id)->findUnaryOpNodesForAuxVarCreation(nodes); + datatree.getLocalVariable(symb_id, lag)->findUnaryOpNodesForAuxVarCreation(nodes); } optional<int> VariableNode::findTargetVariable(int lhs_symb_id) const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->findTargetVariable(lhs_symb_id); + return datatree.getLocalVariable(symb_id, lag)->findTargetVariable(lhs_symb_id); return nullopt; } @@ -1786,7 +1786,7 @@ VariableNode::substituteDiff(const lag_equivalence_table_t& nodes, subst_table_t vector<BinaryOpNode*>& neweqs) const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->substituteDiff(nodes, subst_table, neweqs); + return datatree.getLocalVariable(symb_id, lag)->substituteDiff(nodes, subst_table, neweqs); return const_cast<VariableNode*>(this); } @@ -1797,7 +1797,8 @@ VariableNode::substituteUnaryOpNodes(const lag_equivalence_table_t& nodes, vector<BinaryOpNode*>& neweqs) const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->substituteUnaryOpNodes(nodes, subst_table, neweqs); + return datatree.getLocalVariable(symb_id, lag) + ->substituteUnaryOpNodes(nodes, subst_table, neweqs); return const_cast<VariableNode*>(this); } @@ -1806,7 +1807,7 @@ expr_t VariableNode::substitutePacExpectation(const string& name, expr_t subexpr) { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->substitutePacExpectation(name, subexpr); + return datatree.getLocalVariable(symb_id, lag)->substitutePacExpectation(name, subexpr); return const_cast<VariableNode*>(this); } @@ -1815,7 +1816,7 @@ expr_t VariableNode::substitutePacTargetNonstationary(const string& name, expr_t subexpr) { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->substitutePacTargetNonstationary(name, subexpr); + return datatree.getLocalVariable(symb_id, lag)->substitutePacTargetNonstationary(name, subexpr); return const_cast<VariableNode*>(this); } @@ -1832,7 +1833,7 @@ VariableNode::decreaseLeadsLags(int n) const case SymbolType::logTrend: return datatree.AddVariable(symb_id, lag - n); case SymbolType::modelLocalVariable: - return datatree.getLocalVariable(symb_id)->decreaseLeadsLags(n); + return datatree.getLocalVariable(symb_id, lag)->decreaseLeadsLags(n); default: return const_cast<VariableNode*>(this); } @@ -1863,7 +1864,7 @@ VariableNode::substituteEndoLeadGreaterThanTwo(subst_table_t& subst_table, else return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs); case SymbolType::modelLocalVariable: - if (expr_t value = datatree.getLocalVariable(symb_id); value->maxEndoLead() <= 1) + if (expr_t value = datatree.getLocalVariable(symb_id, lag); value->maxEndoLead() <= 1) return const_cast<VariableNode*>(this); else return value->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model); @@ -1912,7 +1913,7 @@ VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t& subst_table, return substexpr; case SymbolType::modelLocalVariable: - if (expr_t value = datatree.getLocalVariable(symb_id); value->maxEndoLag() <= 1) + if (expr_t value = datatree.getLocalVariable(symb_id, lag); value->maxEndoLag() <= 1) return const_cast<VariableNode*>(this); else return value->substituteEndoLagGreaterThanTwo(subst_table, neweqs); @@ -1933,7 +1934,7 @@ VariableNode::substituteExoLead(subst_table_t& subst_table, vector<BinaryOpNode* else return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs); case SymbolType::modelLocalVariable: - if (expr_t value = datatree.getLocalVariable(symb_id); value->maxExoLead() == 0) + if (expr_t value = datatree.getLocalVariable(symb_id, lag); value->maxExoLead() == 0) return const_cast<VariableNode*>(this); else return value->substituteExoLead(subst_table, neweqs, deterministic_model); @@ -1981,7 +1982,7 @@ VariableNode::substituteExoLag(subst_table_t& subst_table, vector<BinaryOpNode*> return substexpr; case SymbolType::modelLocalVariable: - if (expr_t value = datatree.getLocalVariable(symb_id); value->maxExoLag() == 0) + if (expr_t value = datatree.getLocalVariable(symb_id, lag); value->maxExoLag() == 0) return const_cast<VariableNode*>(this); else return value->substituteExoLag(subst_table, neweqs); @@ -1995,8 +1996,8 @@ VariableNode::substituteExpectation(subst_table_t& subst_table, vector<BinaryOpN bool partial_information_model) const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->substituteExpectation(subst_table, neweqs, - partial_information_model); + return datatree.getLocalVariable(symb_id, lag) + ->substituteExpectation(subst_table, neweqs, partial_information_model); return const_cast<VariableNode*>(this); } @@ -2032,7 +2033,7 @@ VariableNode::differentiateForwardVars(const vector<string>& subset, subst_table return datatree.AddPlus(datatree.AddVariable(symb_id, 0), diffvar); } case SymbolType::modelLocalVariable: - if (expr_t value = datatree.getLocalVariable(symb_id); value->maxEndoLead() <= 0) + if (expr_t value = datatree.getLocalVariable(symb_id, lag); value->maxEndoLead() <= 0) return const_cast<VariableNode*>(this); else return value->differentiateForwardVars(subset, subst_table, neweqs); @@ -2061,7 +2062,7 @@ bool VariableNode::containsPacExpectation(const string& pac_model_name) const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->containsPacExpectation(pac_model_name); + return datatree.getLocalVariable(symb_id, lag)->containsPacExpectation(pac_model_name); return false; } @@ -2070,7 +2071,7 @@ bool VariableNode::containsPacTargetNonstationary(const string& pac_model_name) const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->containsPacTargetNonstationary(pac_model_name); + return datatree.getLocalVariable(symb_id, lag)->containsPacTargetNonstationary(pac_model_name); return false; } @@ -2079,7 +2080,7 @@ expr_t VariableNode::replaceTrendVar() const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->replaceTrendVar(); + return datatree.getLocalVariable(symb_id, lag)->replaceTrendVar(); if (get_type() == SymbolType::trend) return datatree.One; @@ -2093,7 +2094,7 @@ expr_t VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->detrend(symb_id, log_trend, trend); + return datatree.getLocalVariable(symb_id, lag)->detrend(symb_id, log_trend, trend); if (this->symb_id != symb_id) return const_cast<VariableNode*>(this); @@ -2118,7 +2119,7 @@ int VariableNode::countDiffs() const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->countDiffs(); + return datatree.getLocalVariable(symb_id, lag)->countDiffs(); return 0; } @@ -2127,7 +2128,7 @@ expr_t VariableNode::removeTrendLeadLag(const map<int, expr_t>& trend_symbols_map) const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->removeTrendLeadLag(trend_symbols_map); + return datatree.getLocalVariable(symb_id, lag)->removeTrendLeadLag(trend_symbols_map); if ((get_type() != SymbolType::trend && get_type() != SymbolType::logTrend) || lag == 0) return const_cast<VariableNode*>(this); @@ -2179,7 +2180,7 @@ bool VariableNode::isInStaticForm() const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->isInStaticForm(); + return datatree.getLocalVariable(symb_id, lag)->isInStaticForm(); return lag == 0; } @@ -2188,7 +2189,7 @@ bool VariableNode::isParamTimesEndogExpr() const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->isParamTimesEndogExpr(); + return datatree.getLocalVariable(symb_id, lag)->isParamTimesEndogExpr(); return false; } @@ -2225,7 +2226,8 @@ expr_t VariableNode::substituteLogTransform(int orig_symb_id, int aux_symb_id) const { if (get_type() == SymbolType::modelLocalVariable) - return datatree.getLocalVariable(symb_id)->substituteLogTransform(orig_symb_id, aux_symb_id); + return datatree.getLocalVariable(symb_id, lag) + ->substituteLogTransform(orig_symb_id, aux_symb_id); if (symb_id == orig_symb_id) return datatree.AddExp(datatree.AddVariable(aux_symb_id, lag)); diff --git a/src/ParsingDriver.cc b/src/ParsingDriver.cc index 953535a7..b3f43827 100644 --- a/src/ParsingDriver.cc +++ b/src/ParsingDriver.cc @@ -418,10 +418,6 @@ ParsingDriver::add_model_variable(int symb_id, int lag) error("Exogenous deterministic variable " + mod_file->symbol_table.getName(symb_id) + " cannot be given a lead or a lag."); - if (type == SymbolType::modelLocalVariable && lag != 0) - error("Model local variable " + mod_file->symbol_table.getName(symb_id) - + " cannot be given a lead or a lag."); - if (data_tree == planner_objective.get()) { if (lag != 0) -- GitLab