From a93e264c2cddf01baeae0b203992e74ba2833e47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org> Date: Fri, 28 Jan 2022 15:36:04 +0100 Subject: [PATCH] =?UTF-8?q?Harmonize=20=E2=80=9CdiffForward=E2=80=9D=20aux?= =?UTF-8?q?var=20with=20=E2=80=9Cdiff=E2=80=9D=20auxvar=20by=20giving=20it?= =?UTF-8?q?=20an=20orig=5Flead=5Flag=20as=20well?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit By the way: – Fix and improve the explanation of the purpose of the orig_symb_id and orig_lead_lag fields for auxvars – Factorize the code that prints those fields in MATLAB and JSON output --- src/ExprNode.cc | 8 +++---- src/SymbolTable.cc | 54 ++++++++++++++++++---------------------------- src/SymbolTable.hh | 21 +++++++++++------- 3 files changed, 38 insertions(+), 45 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 579fc5a3..77a9e38c 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -1791,10 +1791,10 @@ VariableNode::differentiateForwardVars(const vector<string> &subset, subst_table diffvar = const_cast<VariableNode *>(it->second); else { - int aux_symb_id = datatree.symbol_table.addDiffForwardAuxiliaryVar(symb_id, datatree.AddMinus(datatree.AddVariable(symb_id, 0), - datatree.AddVariable(symb_id, -1))); - neweqs.push_back(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), datatree.AddMinus(datatree.AddVariable(symb_id, 0), - datatree.AddVariable(symb_id, -1)))); + expr_t substexpr = datatree.AddMinus(datatree.AddVariable(symb_id, 0), + datatree.AddVariable(symb_id, -1)); + int aux_symb_id = datatree.symbol_table.addDiffForwardAuxiliaryVar(symb_id, 0, substexpr); + neweqs.push_back(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), substexpr)); diffvar = datatree.AddVariable(aux_symb_id, 1); subst_table[this] = diffvar; } diff --git a/src/SymbolTable.cc b/src/SymbolTable.cc index 87b7daaf..d76a7d9c 100644 --- a/src/SymbolTable.cc +++ b/src/SymbolTable.cc @@ -357,35 +357,29 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false) { case AuxVarType::endoLead: case AuxVarType::exoLead: + case AuxVarType::expectation: + case AuxVarType::pacExpectation: + case AuxVarType::pacTargetNonstationary: break; case AuxVarType::endoLag: case AuxVarType::exoLag: + case AuxVarType::diffLag: + case AuxVarType::diffLead: + case AuxVarType::diffForward: output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; break; case AuxVarType::unaryOp: + output << "M_.aux_vars(" << i+1 << ").unary_op = '" << aux_vars[i].get_unary_op() << "';" << endl; + // NB: Fallback! + case AuxVarType::diff: if (aux_vars[i].get_orig_symb_id() >= 0) output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; - output << "M_.aux_vars(" << i+1 << ").unary_op = '" << aux_vars[i].get_unary_op() << "';" << endl; break; case AuxVarType::multiplier: output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].get_equation_number_for_multiplier() + 1 << ";" << endl; break; - case AuxVarType::diffForward: - output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl; - break; - case AuxVarType::expectation: - case AuxVarType::pacExpectation: - case AuxVarType::pacTargetNonstationary: - break; - case AuxVarType::diff: - case AuxVarType::diffLag: - case AuxVarType::diffLead: - if (aux_vars[i].get_orig_symb_id() >= 0) - output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl - << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; - break; } if (expr_t orig_expr = aux_vars[i].get_expr_node(); @@ -650,7 +644,7 @@ SymbolTable::addMultiplierAuxiliaryVar(int index) noexcept(false) } int -SymbolTable::addDiffForwardAuxiliaryVar(int orig_symb_id, expr_t expr_arg) noexcept(false) +SymbolTable::addDiffForwardAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t expr_arg) noexcept(false) { ostringstream varname; int symb_id; @@ -666,7 +660,7 @@ SymbolTable::addDiffForwardAuxiliaryVar(int orig_symb_id, expr_t expr_arg) noexc exit(EXIT_FAILURE); } - aux_vars.emplace_back(symb_id, AuxVarType::diffForward, orig_symb_id, 0, 0, 0, expr_arg, ""); + aux_vars.emplace_back(symb_id, AuxVarType::diffForward, orig_symb_id, orig_lead_lag, 0, 0, expr_arg, ""); return symb_id; } @@ -996,35 +990,29 @@ SymbolTable::writeJsonOutput(ostream &output) const { case AuxVarType::endoLead: case AuxVarType::exoLead: + case AuxVarType::expectation: + case AuxVarType::pacExpectation: + case AuxVarType::pacTargetNonstationary: break; case AuxVarType::endoLag: case AuxVarType::exoLag: + case AuxVarType::diffLag: + case AuxVarType::diffLead: + case AuxVarType::diffForward: output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag(); break; case AuxVarType::unaryOp: + output << R"(, "unary_op": ")" << aux_vars[i].get_unary_op() << R"(")"; + // NB: Fallback! + case AuxVarType::diff: if (aux_vars[i].get_orig_symb_id() >= 0) output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 - << R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag() - << R"(, "unary_op": ")" << aux_vars[i].get_unary_op() << R"(")"; + << R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag(); break; case AuxVarType::multiplier: output << R"(, "eq_nbr": )" << aux_vars[i].get_equation_number_for_multiplier() + 1; break; - case AuxVarType::diffForward: - output << R"(, orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1; - break; - case AuxVarType::expectation: - case AuxVarType::pacExpectation: - case AuxVarType::pacTargetNonstationary: - break; - case AuxVarType::diff: - case AuxVarType::diffLag: - case AuxVarType::diffLead: - if (aux_vars[i].get_orig_symb_id() >= 0) - output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 - << R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag(); - break; } if (expr_t orig_expr = aux_vars[i].get_expr_node(); diff --git a/src/SymbolTable.hh b/src/SymbolTable.hh index a0f4f325..373fe32b 100644 --- a/src/SymbolTable.hh +++ b/src/SymbolTable.hh @@ -63,13 +63,18 @@ class AuxVarInfo private: int symb_id; //!< Symbol ID of the auxiliary variable AuxVarType type; //!< Its type - int orig_symb_id; /* Symbol ID of the endo of the original model represented - by this aux var. Used by endoLag, endoLead, exoLag, - exoLead, diffForward, varModel, diff, diffLag, diffLead - and unaryOp */ - int orig_lead_lag; /* Lead/lag of the endo of the original model represented - by this aux var. Used by endoLag, endoLead, exoLag, - exoLead, varModel, unaryOp, diff, diffLag, diffLead */ + int orig_symb_id; /* Symbol ID of the (only) endo that appears on the RHS of + the definition of this auxvar. + Used by endoLag, exoLag, diffForward, diff, diffLag, + diffLead and unaryOp. + For diff and unaryOp, if the argument expression is more complex + than than a simple variable, this value is equal to -1. */ + int orig_lead_lag; /* Lead/lag of the (only) endo as it appears on the RHS of the definition + of this auxvar. Only used if orig_symb_id is used. + (in particular, for diff and unaryOp, unused if orig_symb_id == -1). + For diff and diffForward, since the definition of the + auxvar is a time difference, the value corresponds to the + time index of the first term of that difference. */ int equation_number_for_multiplier; //!< Stores the original constraint equation number associated with this aux var. Only used for avMultiplier. int information_set; //! Argument of expectation operator. Only used for avExpectation. expr_t expr_node; //! Auxiliary variable definition @@ -308,7 +313,7 @@ public: \param[in] orig_symb_id The symb_id of the forward variable \return the symbol ID of the new symbol */ - int addDiffForwardAuxiliaryVar(int orig_symb_id, expr_t arg) noexcept(false); + int addDiffForwardAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t arg) noexcept(false); //! Searches auxiliary variables which are substitutes for a given symbol_id and lead/lag /*! The search is only performed among auxiliary variables of endo/exo lag. -- GitLab