From 50d5b916e24438cb63a7638b8fd031796f8aa63e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org> Date: Wed, 20 Jul 2022 14:32:57 +0200 Subject: [PATCH] Remove constructor and accessors for AuxVarInfo Rather make all data members public and const, and use aggregate-initialization. --- src/DynamicModel.cc | 10 ++-- src/ExprNode.cc | 14 +++--- src/SymbolTable.cc | 112 +++++++++++++++++++------------------------- src/SymbolTable.hh | 92 ++++++++++-------------------------- 4 files changed, 87 insertions(+), 141 deletions(-) diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 5a0d951f..8174bc13 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -2713,23 +2713,23 @@ DynamicModel::getVARDerivIDs(int lhs_symb_id, int lead_lag) const continue; } - if (avi->get_type() == AuxVarType::endoLag && avi->get_orig_symb_id().value() == lhs_symb_id - && avi->get_orig_lead_lag().value() + lead_lag2 == lead_lag) + if (avi->type == AuxVarType::endoLag && avi->orig_symb_id.value() == lhs_symb_id + && avi->orig_lead_lag.value() + lead_lag2 == lead_lag) deriv_ids.push_back(deriv_id2); // Handle diff lag auxvar, possibly nested several times int diff_lag_depth = 0; - while (avi->get_type() == AuxVarType::diffLag) + while (avi->type == AuxVarType::diffLag) { diff_lag_depth++; - if (avi->get_orig_symb_id() == lhs_symb_id && lead_lag2 - diff_lag_depth == lead_lag) + if (avi->orig_symb_id == lhs_symb_id && lead_lag2 - diff_lag_depth == lead_lag) { deriv_ids.push_back(deriv_id2); break; } try { - avi = &symbol_table.getAuxVarInfo(avi->get_orig_symb_id().value()); + avi = &symbol_table.getAuxVarInfo(avi->orig_symb_id.value()); } catch (SymbolTable::UnknownSymbolIDException) { diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 1bd72faf..3acfc967 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -9101,14 +9101,14 @@ ExprNode::matchParamTimesTargetMinusVariable(int symb_id) const return false; if (datatree.symbol_table.isAuxiliaryVariable(target->symb_id)) { - auto avi = datatree.symbol_table.getAuxVarInfo(target->symb_id); - if (avi.get_type() == AuxVarType::pacTargetNonstationary && target->lag == -1) + auto &avi = datatree.symbol_table.getAuxVarInfo(target->symb_id); + if (avi.type == AuxVarType::pacTargetNonstationary && target->lag == -1) return true; - return (avi.get_type() == AuxVarType::unaryOp - && avi.get_unary_op() == "log" - && avi.get_orig_symb_id() - && !datatree.symbol_table.isAuxiliaryVariable(*avi.get_orig_symb_id()) - && target->lag + avi.get_orig_lead_lag().value() == -1); + return (avi.type == AuxVarType::unaryOp + && avi.unary_op == "log" + && avi.orig_symb_id + && !datatree.symbol_table.isAuxiliaryVariable(*avi.orig_symb_id) + && target->lag + avi.orig_lead_lag.value() == -1); } else return target->lag == -1; diff --git a/src/SymbolTable.cc b/src/SymbolTable.cc index 949a7fee..055fedd5 100644 --- a/src/SymbolTable.cc +++ b/src/SymbolTable.cc @@ -29,20 +29,6 @@ #include "SymbolTable.hh" -AuxVarInfo::AuxVarInfo(int symb_id_arg, AuxVarType type_arg, optional<int> orig_symb_id_arg, - optional<int> orig_lead_lag_arg, int equation_number_for_multiplier_arg, - int information_set_arg, expr_t expr_node_arg, string unary_op_arg) : - symb_id{symb_id_arg}, - type{type_arg}, - orig_symb_id{orig_symb_id_arg}, - orig_lead_lag{orig_lead_lag_arg}, - equation_number_for_multiplier{equation_number_for_multiplier_arg}, - information_set{information_set_arg}, - expr_node{expr_node_arg}, - unary_op{move(unary_op_arg)} -{ -} - int SymbolTable::addSymbol(const string &name, SymbolType type, const string &tex_name, const vector<pair<string, string>> &partition_value) noexcept(false) { @@ -343,9 +329,9 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false) else for (int i = 0; i < static_cast<int>(aux_vars.size()); i++) { - output << "M_.aux_vars(" << i+1 << ").endo_index = " << getTypeSpecificID(aux_vars[i].get_symb_id())+1 << ";" << endl + output << "M_.aux_vars(" << i+1 << ").endo_index = " << getTypeSpecificID(aux_vars[i].symb_id)+1 << ";" << endl << "M_.aux_vars(" << i+1 << ").type = " << aux_vars[i].get_type_id() << ";" << endl; - switch (aux_vars[i].get_type()) + switch (aux_vars[i].type) { case AuxVarType::endoLead: case AuxVarType::exoLead: @@ -359,23 +345,23 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false) case AuxVarType::diffLag: case AuxVarType::diffLead: case AuxVarType::diffForward: - output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id().value())+1 << ";" << endl - << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag().value() << ";" << endl; + output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].orig_symb_id.value())+1 << ";" << endl + << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].orig_lead_lag.value() << ";" << endl; break; case AuxVarType::unaryOp: - output << "M_.aux_vars(" << i+1 << ").unary_op = '" << aux_vars[i].get_unary_op() << "';" << endl; + output << "M_.aux_vars(" << i+1 << ").unary_op = '" << aux_vars[i].unary_op << "';" << endl; [[fallthrough]]; case AuxVarType::diff: - if (aux_vars[i].get_orig_symb_id()) - output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(*aux_vars[i].get_orig_symb_id())+1 << ";" << endl - << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag().value() << ";" << endl; + if (aux_vars[i].orig_symb_id) + output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(*aux_vars[i].orig_symb_id)+1 << ";" << endl + << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].orig_lead_lag.value() << ";" << endl; break; case AuxVarType::multiplier: - output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].get_equation_number_for_multiplier() + 1 << ";" << endl; + output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].equation_number_for_multiplier + 1 << ";" << endl; break; } - if (expr_t orig_expr = aux_vars[i].get_expr_node(); + if (expr_t orig_expr = aux_vars[i].expr_node; orig_expr) { output << "M_.aux_vars(" << i+1 << ").orig_expr = '"; @@ -682,40 +668,40 @@ int SymbolTable::searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const noexcept(false) { for (const auto &aux_var : aux_vars) - if ((aux_var.get_type() == AuxVarType::endoLag || aux_var.get_type() == AuxVarType::exoLag) - && aux_var.get_orig_symb_id() == orig_symb_id && aux_var.get_orig_lead_lag() == orig_lead_lag) - return aux_var.get_symb_id(); + if ((aux_var.type == AuxVarType::endoLag || aux_var.type == AuxVarType::exoLag) + && aux_var.orig_symb_id == orig_symb_id && aux_var.orig_lead_lag == orig_lead_lag) + return aux_var.symb_id; throw SearchFailedException(orig_symb_id, orig_lead_lag); } int -SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false) +SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id_arg) const noexcept(false) { for (const auto &aux_var : aux_vars) - if ((aux_var.get_type() == AuxVarType::endoLag - || aux_var.get_type() == AuxVarType::exoLag - || aux_var.get_type() == AuxVarType::diff - || aux_var.get_type() == AuxVarType::diffLag - || aux_var.get_type() == AuxVarType::diffLead - || aux_var.get_type() == AuxVarType::diffForward - || aux_var.get_type() == AuxVarType::unaryOp) - && aux_var.get_symb_id() == aux_var_symb_id) - if (optional<int> r = aux_var.get_orig_symb_id(); r) + if ((aux_var.type == AuxVarType::endoLag + || aux_var.type == AuxVarType::exoLag + || aux_var.type == AuxVarType::diff + || aux_var.type == AuxVarType::diffLag + || aux_var.type == AuxVarType::diffLead + || aux_var.type == AuxVarType::diffForward + || aux_var.type == AuxVarType::unaryOp) + && aux_var.symb_id == aux_var_symb_id_arg) + if (optional<int> r = aux_var.orig_symb_id; r) return *r; else - throw UnknownSymbolIDException(aux_var_symb_id); // Some diff and unaryOp auxvars have orig_symb_id unset - throw UnknownSymbolIDException(aux_var_symb_id); + throw UnknownSymbolIDException(aux_var_symb_id_arg); // Some diff and unaryOp auxvars have orig_symb_id unset + throw UnknownSymbolIDException(aux_var_symb_id_arg); } pair<int, int> SymbolTable::unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false) { for (const auto &aux_var : aux_vars) - if (aux_var.get_symb_id() == symb_id) - if (aux_var.get_type() == AuxVarType::diffLag || aux_var.get_type() == AuxVarType::diffLead) + if (aux_var.symb_id == symb_id) + if (aux_var.type == AuxVarType::diffLag || aux_var.type == AuxVarType::diffLead) { - auto [orig_symb_id, orig_lag] = unrollDiffLeadLagChain(aux_var.get_orig_symb_id().value(), lag); - return { orig_symb_id, orig_lag + aux_var.get_orig_lead_lag().value() }; + auto [orig_symb_id, orig_lag] = unrollDiffLeadLagChain(aux_var.orig_symb_id.value(), lag); + return { orig_symb_id, orig_lag + aux_var.orig_lead_lag.value() }; } return { symb_id, lag }; } @@ -725,8 +711,8 @@ SymbolTable::getAuxiliaryVarsExprNode(int symb_id) const noexcept(false) // throw exception if it is a Lagrange multiplier { for (const auto &aux_var : aux_vars) - if (aux_var.get_symb_id() == symb_id) - if (expr_t expr_node = aux_var.get_expr_node(); + if (aux_var.symb_id == symb_id) + if (expr_t expr_node = aux_var.expr_node; expr_node) return expr_node; else @@ -874,7 +860,7 @@ bool SymbolTable::isAuxiliaryVariable(int symb_id) const { for (const auto &aux_var : aux_vars) - if (aux_var.get_symb_id() == symb_id) + if (aux_var.symb_id == symb_id) return true; return false; } @@ -883,7 +869,7 @@ bool SymbolTable::isAuxiliaryVariableButNotMultiplier(int symb_id) const { for (const auto &aux_var : aux_vars) - if (aux_var.get_symb_id() == symb_id && aux_var.get_type() != AuxVarType::multiplier) + if (aux_var.symb_id == symb_id && aux_var.type != AuxVarType::multiplier) return true; return false; } @@ -892,10 +878,10 @@ bool SymbolTable::isDiffAuxiliaryVariable(int symb_id) const { for (const auto &aux_var : aux_vars) - if (aux_var.get_symb_id() == symb_id - && (aux_var.get_type() == AuxVarType::diff - || aux_var.get_type() == AuxVarType::diffLag - || aux_var.get_type() == AuxVarType::diffLead)) + if (aux_var.symb_id == symb_id + && (aux_var.type == AuxVarType::diff + || aux_var.type == AuxVarType::diffLag + || aux_var.type == AuxVarType::diffLead)) return true; return false; } @@ -977,9 +963,9 @@ SymbolTable::writeJsonOutput(ostream &output) const { if (i != 0) output << ", "; - output << R"({"endo_index": )" << getTypeSpecificID(aux_vars[i].get_symb_id())+1 + output << R"({"endo_index": )" << getTypeSpecificID(aux_vars[i].symb_id)+1 << R"(, "type": )" << aux_vars[i].get_type_id(); - switch (aux_vars[i].get_type()) + switch (aux_vars[i].type) { case AuxVarType::endoLead: case AuxVarType::exoLead: @@ -993,23 +979,23 @@ SymbolTable::writeJsonOutput(ostream &output) const case AuxVarType::diffLag: case AuxVarType::diffLead: case AuxVarType::diffForward: - output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].get_orig_symb_id().value())+1 - << R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag().value(); + output << R"(, "orig_index": )" << getTypeSpecificID(aux_vars[i].orig_symb_id.value())+1 + << R"(, "orig_lead_lag": )" << aux_vars[i].orig_lead_lag.value(); break; case AuxVarType::unaryOp: - output << R"(, "unary_op": ")" << aux_vars[i].get_unary_op() << R"(")"; + output << R"(, "unary_op": ")" << aux_vars[i].unary_op << R"(")"; [[fallthrough]]; case AuxVarType::diff: - if (aux_vars[i].get_orig_symb_id()) - output << R"(, "orig_index": )" << getTypeSpecificID(*aux_vars[i].get_orig_symb_id())+1 - << R"(, "orig_lead_lag": )" << aux_vars[i].get_orig_lead_lag().value(); + if (aux_vars[i].orig_symb_id) + output << R"(, "orig_index": )" << getTypeSpecificID(*aux_vars[i].orig_symb_id)+1 + << R"(, "orig_lead_lag": )" << aux_vars[i].orig_lead_lag.value(); break; case AuxVarType::multiplier: - output << R"(, "eq_nbr": )" << aux_vars[i].get_equation_number_for_multiplier() + 1; + output << R"(, "eq_nbr": )" << aux_vars[i].equation_number_for_multiplier + 1; break; } - if (expr_t orig_expr = aux_vars[i].get_expr_node(); + if (expr_t orig_expr = aux_vars[i].expr_node; orig_expr) { output << R"(, "orig_expr": ")"; @@ -1058,8 +1044,8 @@ optional<int> SymbolTable::getEquationNumberForMultiplier(int symb_id) const { for (const auto &aux_var : aux_vars) - if (aux_var.get_symb_id() == symb_id && aux_var.get_type() == AuxVarType::multiplier) - return aux_var.get_equation_number_for_multiplier(); + if (aux_var.symb_id == symb_id && aux_var.type == AuxVarType::multiplier) + return aux_var.equation_number_for_multiplier; return nullopt; } diff --git a/src/SymbolTable.hh b/src/SymbolTable.hh index fb321d2a..818334fc 100644 --- a/src/SymbolTable.hh +++ b/src/SymbolTable.hh @@ -57,76 +57,36 @@ enum class AuxVarType }; //! Information on some auxiliary variables -class AuxVarInfo +struct AuxVarInfo { -private: - int symb_id; //!< Symbol ID of the auxiliary variable - AuxVarType type; //!< Its type - optional<int> orig_symb_id; /* Symbol ID of the (only) endo that appears on the RHS of - the definition of this auxvar. - Used by endoLag, exoLag, diffForward, logTransform, diff, diffLag, - diffLead and unaryOp. - For diff and unaryOp, if the argument expression is more complex - than than a simple variable, this value is unset - (hence the need for std::optional). */ - optional<int> orig_lead_lag; /* Lead/lag of the (only) endo as it appears on the RHS of the definition - of this auxvar. Only set if orig_symb_id is set - (in particular, for diff and unaryOp, unset - if orig_symb_id is unset). - For diff and diffForward, since the definition of the - auxvar is a time difference, the value corresponds to the - time index of the first term of that difference. */ - int equation_number_for_multiplier; //!< Stores the original constraint equation number associated with this aux var. Only used for avMultiplier. - int information_set; //! Argument of expectation operator. Only used for avExpectation. - expr_t expr_node; //! Auxiliary variable definition - string unary_op; //! Used with AuxUnaryOp -public: - AuxVarInfo(int symb_id_arg, AuxVarType type_arg, optional<int> orig_symb_id_arg, optional<int> orig_lead_lag_arg, int equation_number_for_multiplier_arg, int information_set_arg, expr_t expr_node_arg, string unary_op_arg); - int - get_symb_id() const - { - return symb_id; - }; - AuxVarType - get_type() const - { - return type; - }; + const int symb_id; // Symbol ID of the auxiliary variable + const AuxVarType type; // Its type + const optional<int> orig_symb_id; /* Symbol ID of the (only) endo that appears on the RHS of + the definition of this auxvar. + Used by endoLag, exoLag, diffForward, logTransform, diff, + diffLag, diffLead and unaryOp. + For diff and unaryOp, if the argument expression is more complex + than than a simple variable, this value is unset + (hence the need for std::optional). */ + const optional<int> orig_lead_lag; /* Lead/lag of the (only) endo as it appears on the RHS of the + definition of this auxvar. Only set if orig_symb_id is set + (in particular, for diff and unaryOp, unset + if orig_symb_id is unset). + For diff and diffForward, since the definition of the + auxvar is a time difference, the value corresponds to the + time index of the first term of that difference. */ + const int equation_number_for_multiplier; /* Stores the original constraint equation number + associated with this aux var. Only used for + avMultiplier. */ + const int information_set; // Argument of expectation operator. Only used for avExpectation. + const expr_t expr_node; // Auxiliary variable definition + const string unary_op; // Used with AuxUnaryOp + int get_type_id() const { return static_cast<int>(type); } - optional<int> - get_orig_symb_id() const - { - return orig_symb_id; - }; - optional<int> - get_orig_lead_lag() const - { - return orig_lead_lag; - }; - int - get_equation_number_for_multiplier() const - { - return equation_number_for_multiplier; - }; - int - get_information_set() const - { - return information_set; - }; - expr_t - get_expr_node() const - { - return expr_node; - }; - const string & - get_unary_op() const - { - return unary_op; - }; }; //! Stores the symbol table @@ -318,7 +278,7 @@ public: this auxvar (either because it’s of the wrong type, or because there is no such orig var for this specific auxvar, in case of complex expressions in diff or unaryOp). */ - int getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false); + int getOrigSymbIdForAuxVar(int aux_var_symb_id_arg) const noexcept(false); /* Unrolls a chain of diffLag or diffLead aux vars until it founds a (regular) diff aux var. In other words: - if the arg is a (regu) diff aux var, returns the arg @@ -583,7 +543,7 @@ inline const AuxVarInfo & SymbolTable::getAuxVarInfo(int symb_id) const { for (const auto &aux_var : aux_vars) - if (aux_var.get_symb_id() == symb_id) + if (aux_var.symb_id == symb_id) return aux_var; throw UnknownSymbolIDException(symb_id); } -- GitLab