diff --git a/src/ExprNode.cc b/src/ExprNode.cc index f2d585b84110eac05609fce902889a3b60873273..efc38c0de8b3f8145e55dbb7caab816129aedb36 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -3107,9 +3107,20 @@ UnaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &r void UnaryOpNode::computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const { - arg->computeSubExprContainingVariable(symb_id, lag, contain_var); - if (contain_var.contains(arg)) - contain_var.insert(const_cast<UnaryOpNode *>(this)); + if (op_code == UnaryOpcode::diff) + { + expr_t lagged_arg {arg->decreaseLeadsLags(1)}; + expr_t substitute {datatree.AddMinus(arg, lagged_arg)}; + substitute->computeSubExprContainingVariable(symb_id, lag, contain_var); + if (contain_var.contains(arg) || contain_var.contains(lagged_arg)) + contain_var.insert(const_cast<UnaryOpNode *>(this)); + } + else + { + arg->computeSubExprContainingVariable(symb_id, lag, contain_var); + if (contain_var.contains(arg)) + contain_var.insert(const_cast<UnaryOpNode *>(this)); + } } BinaryOpNode * @@ -3173,6 +3184,13 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) case UnaryOpcode::cbrt: rhs = datatree.AddPower(rhs, datatree.Three); break; + case UnaryOpcode::diff: + /* Recursively call the function on arg-arg(-1). + This is necessary to deal with the 3 different possible cases: + — var in arg but not arg(-1); + — var in arg(-1) but not arg; + — var in both arg and arg(-1). */ + return datatree.AddMinus(arg, arg->decreaseLeadsLags(1))->normalizeEquationHelper(contain_var, rhs); default: throw NormalizationFailed(); } diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 1fdee2c66e56a0efce11ad3e2f6a9a74d3ce2e43..58b1024ed9d011c659bbfd090a493e00aa4a85af 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -281,6 +281,14 @@ protected: lag_equivalence_table_t for an explanation of these concepts. */ pair<expr_t, int> getLagEquivalenceClass() const; + /* Computes the set of all sub-expressions that contain the variable + (symb_id, lag). + Note that if a diff operator is encountered: + - diff(expr) will be added to the output set if either expr or expr(-1) + contains the variable; + - the method will be called recursively on expr-expr(-1) */ + virtual void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const = 0; + public: ExprNode(DataTree &datatree_arg, int idx_arg); virtual ~ExprNode() = default; @@ -457,9 +465,6 @@ public: // virtual void computeXrefs(set<int> ¶m, set<int> &endo, set<int> &exo, set<int> &exo_det) const = 0; virtual void computeXrefs(EquationInfo &ei) const = 0; - // Computes the set of all sub-expressions that contain the variable (symb_id, lag) - virtual void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const = 0; - //! Helper for normalization of equations /*! Normalize the equation this = rhs. Must be called on a node containing the desired LHS variable. @@ -810,6 +815,7 @@ private: expr_t computeDerivative(int deriv_id) override; protected: void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; + void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: NumConstNode(DataTree &datatree_arg, int idx_arg, int id_arg); void prepareForDerivation() override; @@ -823,7 +829,6 @@ public: void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; - void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; @@ -882,6 +887,7 @@ private: expr_t computeDerivative(int deriv_id) override; protected: void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; + void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg); void prepareForDerivation() override; @@ -896,7 +902,6 @@ public: expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; SymbolType get_type() const; - void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; @@ -950,6 +955,7 @@ class UnaryOpNode : public ExprNode { protected: void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; + void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: const expr_t arg; //! Stores the information set. Only used for expectation operator @@ -999,7 +1005,6 @@ public: void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; - void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; @@ -1055,6 +1060,7 @@ class BinaryOpNode : public ExprNode { protected: void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; + void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: const expr_t arg1, arg2; const BinaryOpcode op_code; @@ -1104,7 +1110,6 @@ public: expr_t Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; - void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; //! Try to normalize an equation with respect to a given dynamic variable. /*! Should only be called on Equal nodes. The variable must appear in the equation. */ @@ -1203,6 +1208,8 @@ class TrinaryOpNode : public ExprNode public: const expr_t arg1, arg2, arg3; const TrinaryOpcode op_code; +protected: + void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; private: expr_t computeDerivative(int deriv_id) override; int cost(int cost, bool is_matlab) const override; @@ -1245,7 +1252,6 @@ public: void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override; expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; - void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; @@ -1325,6 +1331,7 @@ protected: function which is computed by the same external function call (i.e. it has the same so-called "Tef" index) */ virtual function<bool (expr_t)> sameTefTermPredicate() const = 0; + void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: AbstractExternalFunctionNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, vector<expr_t> arguments_arg); @@ -1358,7 +1365,6 @@ public: void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override = 0; expr_t toStatic(DataTree &static_datatree) const override = 0; void computeXrefs(EquationInfo &ei) const override = 0; - void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; @@ -1565,7 +1571,6 @@ public: optional<int> findTargetVariable(int lhs_symb_id) const override; expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; - void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, @@ -1585,6 +1590,8 @@ public: expr_t detrend(int symb_id, bool log_trend, expr_t trend) const override; expr_t removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const override; expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override; +protected: + void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; }; class VarExpectationNode : public SubModelNode