diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 7dc7fb04da53f5871cdf538c3bb9033efc6d9de2..933ae44a0b1fe7b8d42e57a6865bfec6ef501d9d 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -4111,7 +4111,22 @@ DynamicModel::addPacModelConsistentExpectationEquation(const string &name, int d // Add diff nodes and eqs for pac_target_symb_id const VariableNode *target_base_diff_node; - expr_t diff_node_to_search = AddDiff(AddVariable(pac_target_symb_id)); + auto create_target_lag = [&](int lag) + { + if (symbol_table.isAuxiliaryVariable(pac_target_symb_id)) + { + // We know it is a log, see ExprNode::matchParamTimesTargetMinusVariable() + /* We don’t use SymbolTable::getOrigSymbIdForAuxVar(), because it + does not work for unary ops, and changing this behaviour might + break stuff that relies on an exception in this case. */ + auto avi = symbol_table.getAuxVarInfo(pac_target_symb_id); + return AddLog(AddVariable(avi.get_orig_symb_id(), lag)); + } + else + return dynamic_cast<ExprNode *>(AddVariable(pac_target_symb_id, lag)); + }; + + expr_t diff_node_to_search = AddDiff(create_target_lag(0)); if (auto sit = diff_subst_table.find(diff_node_to_search); sit != diff_subst_table.end()) target_base_diff_node = sit->second; @@ -4120,8 +4135,8 @@ DynamicModel::addPacModelConsistentExpectationEquation(const string &name, int d int symb_id = symbol_table.addDiffAuxiliaryVar(diff_node_to_search->idx, diff_node_to_search); target_base_diff_node = AddVariable(symb_id); auto neweq = AddEqual(const_cast<VariableNode *>(target_base_diff_node), - AddMinus(AddVariable(pac_target_symb_id), - AddVariable(pac_target_symb_id, -1))); + AddMinus(create_target_lag(0), + create_target_lag(-1))); addEquation(neweq, -1); addAuxEquation(neweq); neqs++; @@ -4131,7 +4146,7 @@ DynamicModel::addPacModelConsistentExpectationEquation(const string &name, int d const VariableNode *last_aux_var = target_base_diff_node; for (int i = 1; i <= pac_max_lag_m - 1; i++, neqs++) { - expr_t this_diff_node = AddDiff(AddVariable(pac_target_symb_id, i)); + expr_t this_diff_node = AddDiff(create_target_lag(i)); int symb_id = symbol_table.addDiffLeadAuxiliaryVar(this_diff_node->idx, this_diff_node, last_aux_var->symb_id, last_aux_var->lag); VariableNode *current_aux_var = AddVariable(symb_id); diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 4f2095e398cf6c69d23c56a4e7230f223480a5fd..6074fa2760c648350d8abd3fa56e6547b36853fc 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -9094,10 +9094,27 @@ ExprNode::matchParamTimesTargetMinusVariable(int symb_id) const auto lhs_level = dynamic_cast<const VariableNode *>(bminus->arg2); auto target = dynamic_cast<const VariableNode *>(bminus->arg1); - if (lhs_level && lhs_level->symb_id == symb_id && target - && (target->get_type() == SymbolType::endogenous - || target->get_type() == SymbolType::exogenous)) + + auto check_target = [&]() + { + if (target->get_type() != SymbolType::endogenous + && target->get_type() != SymbolType::exogenous) + return false; + if (datatree.symbol_table.isAuxiliaryVariable(target->symb_id)) + { + auto avi = datatree.symbol_table.getAuxVarInfo(target->symb_id); + return (avi.get_type() == AuxVarType::unaryOp + && avi.get_unary_op() == "log" + && avi.get_orig_symb_id() != -1 + && !datatree.symbol_table.isAuxiliaryVariable(avi.get_orig_symb_id()) + && target->lag + avi.get_orig_lead_lag() == -1); + } + else + return target->lag == -1; + }; + + if (lhs_level && lhs_level->symb_id == symb_id && target && check_target()) return { dynamic_cast<VariableNode *>(param)->symb_id, target->symb_id }; else - throw MatchFailureException{"Neither factor is of the form (target-variable) where target is endo or exo"}; + throw MatchFailureException{"Neither factor is of the form (target-variable) where target is endo or exo (possibly logged), and has one lag"}; } diff --git a/src/ExprNode.hh b/src/ExprNode.hh index e44092163afa7cdf413723185c24c5eb4e263b8a..ab602113b8d8f2e39ba6ff69f52a8500c7d56543 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -659,7 +659,8 @@ public: /* Matches an expression of the form parameter*(var1-endo2). endo2 must correspond to symb_id. var1 must be an endogenous or an - exogenous. + exogenous; it must be of the form X(-1) or log(X(-1)) or log(X)(-1) (unary ops aux var), + where X itself is *not* an aux var. Returns the symbol IDs of the parameter and of var1. Throws a MatchFailureException otherwise */ pair<int, int> matchParamTimesTargetMinusVariable(int symb_id) const; diff --git a/src/SymbolTable.hh b/src/SymbolTable.hh index 1d31964a23bf9fe5bea2232ae00e6d59f0dcedcc..8c3eef32f6ccafb0126853811e754e03fec5e5ff 100644 --- a/src/SymbolTable.hh +++ b/src/SymbolTable.hh @@ -42,13 +42,15 @@ enum class AuxVarType exoLead = 2, //!< Substitute for exo leads >= 1 exoLag = 3, //!< Substitute for exo lags >= 1 expectation = 4, //!< Substitute for Expectation Operator - diffForward = 5, //!< Substitute for the differentiate of a forward variable + diffForward = 5, /* Substitute for the differentiate of a forward variable, + for the differentiate_forward_vars option. + N.B.: nothing to do with the diff() operator! */ multiplier = 6, //!< Multipliers for FOC of Ramsey Problem varModel = 7, //!< Variable for var_model with order > abs(min_lag()) present in model diff = 8, //!< Variable for Diff operator - diffLag = 9, //!< Variable for timing between Diff operators + diffLag = 9, //!< Variable for timing between Diff operators (lag) unaryOp = 10, //!< Variable for allowing the undiff operator to work when diff was taken of unary op, eg diff(log(x)) - diffLead = 11 //!< Variable for timing between Diff operators + diffLead = 11 //!< Variable for timing between Diff operators (lead) }; //! Information on some auxiliary variables @@ -57,8 +59,13 @@ class AuxVarInfo private: int symb_id; //!< Symbol ID of the auxiliary variable AuxVarType type; //!< Its type - int orig_symb_id; //!< Symbol ID of the endo of the original model represented by this aux var. Only used for avEndoLag and avExoLag. - int orig_lead_lag; //!< Lead/lag of the endo of the original model represented by this aux var. Only used for avEndoLag and avExoLag. + int orig_symb_id; /* Symbol ID of the endo of the original model represented + by this aux var. Used by endoLag, endoLead, exoLag, + exoLead, diffForward, varModel, diff, diffLag, diffLead + and unaryOp */ + int orig_lead_lag; /* Lead/lag of the endo of the original model represented + by this aux var. Used by endoLag, endoLead, exoLag, + exoLead, varModel, unaryOp, diff, diffLag, diffLead */ int equation_number_for_multiplier; //!< Stores the original constraint equation number associated with this aux var. Only used for avMultiplier. int information_set; //! Argument of expectation operator. Only used for avExpectation. expr_t expr_node; //! Auxiliary variable definition @@ -105,7 +112,7 @@ public: { return expr_node; }; - string + const string & get_unary_op() const { return unary_op; @@ -295,7 +302,12 @@ public: Throws an exception if match not found. */ int searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const noexcept(false); - //! Serches aux_vars for the aux var represented by aux_var_symb_id and returns its associated orig_symb_id + /* Searches aux_vars for the aux var represented by aux_var_symb_id and + returns its associated orig_symb_id. + Works only for endoLag, exoLag, diff, diffLag, diffLead. + Throws an UnknownSymbolIDException otherwise. + N.B.: some code might rely on the fact that, in particular, it does not work on unaryOp + type (to be verified) */ int getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false); //! Searches for diff aux var and finds the original lag associated with this variable int getOrigLeadLagForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false); @@ -414,6 +426,9 @@ public: int getUltimateOrigSymbID(int symb_id) const; //! If this is a Lagrange multiplier, return its associated equation number; otherwise return -1 int getEquationNumberForMultiplier(int symb_id) const; + /* Return all the information about a given auxiliary variable. Throws + UnknownSymbolIDException if it is not an aux var */ + const AuxVarInfo &getAuxVarInfo(int symb_id) const; }; inline void @@ -538,4 +553,13 @@ SymbolTable::orig_endo_nbr() const noexcept(false) return endo_nbr() - aux_vars.size(); } +inline const AuxVarInfo & +SymbolTable::getAuxVarInfo(int symb_id) const +{ + for (const auto &aux_var : aux_vars) + if (aux_var.get_symb_id() == symb_id) + return aux_var; + throw UnknownSymbolIDException(symb_id); +} + #endif