diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 858c708788bc6eeee213b14c686ae6d1f5767662..4d687bf66d20163e05b96d7130637fb8e10b958a 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -4520,7 +4520,7 @@ DynamicModel::computeChainRuleJacobian() int nb_recursives = blocks[blk].getRecursiveSize(); // Create a map from recursive vars to their defining (normalized) equation - map<int, expr_t> recursive_vars; + map<int, BinaryOpNode *> recursive_vars; for (int i = 0; i < nb_recursives; i++) { int deriv_id = getDerivID(symbol_table.getID(SymbolType::endogenous, getBlockVariableID(blk, i)), 0); diff --git a/src/ExprNode.cc b/src/ExprNode.cc index dff1abb6872940be30866c21074dd7b520efdfc4..6de8f8447deb4a37ff626bdb50d381a69da8a857 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -483,7 +483,7 @@ NumConstNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs } expr_t -NumConstNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) +NumConstNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) { return datatree.Zero; } @@ -1307,7 +1307,7 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs } expr_t -VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) +VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) { switch (get_type()) { @@ -1322,11 +1322,7 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recur // If there is in the equation a recursive variable we could use a chaine rule derivation else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag)); it != recursive_variables.end()) - { - map<int, expr_t> recursive_vars2(recursive_variables); - recursive_vars2.erase(it->first); - return datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2)); - } + return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables); else return datatree.Zero; @@ -3088,7 +3084,7 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) } expr_t -UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) +UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) { expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables); return composeDerivatives(darg, deriv_id); @@ -4820,7 +4816,7 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const } expr_t -BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) +BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) { expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables); expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables); @@ -6089,7 +6085,7 @@ TrinaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rh } expr_t -TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) +TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) { expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables); expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables); @@ -6508,7 +6504,7 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id) } expr_t -AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) +AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) { assert(datatree.external_functions_table.getNargs(symb_id) > 0); vector<expr_t> dargs; @@ -8185,7 +8181,7 @@ VarExpectationNode::computeDerivative(int deriv_id) } expr_t -VarExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) +VarExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) { cerr << "VarExpectationNode::getChainRuleDerivative not implemented." << endl; exit(EXIT_FAILURE); @@ -8582,7 +8578,7 @@ PacExpectationNode::computeDerivative(int deriv_id) } expr_t -PacExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) +PacExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) { cerr << "PacExpectationNode::getChainRuleDerivative: shouldn't arrive here." << endl; exit(EXIT_FAILURE); diff --git a/src/ExprNode.hh b/src/ExprNode.hh index e7b1f910ec061928477d3f812f7f7de60d80a853..4e3b2cd4d86a3c4f8ffb04fa11e4465ddb6f8b66 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -268,7 +268,7 @@ public: \param deriv_id The derivation ID with respect to which we are derivating \param recursive_variables Contains the derivation ID for which chain rules must be applied. Keys are derivation IDs, values are equations of the form x=f(y) where x is the key variable and x doesn't appear in y */ - virtual expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) = 0; + virtual expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) = 0; //! Returns precedence of node /*! Equals 100 for constants, variables, unary ops, and temporary terms */ @@ -747,7 +747,7 @@ public: void computeXrefs(EquationInfo &ei) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override; + expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -820,7 +820,7 @@ public: SymbolType get_type() const; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override; + expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -921,7 +921,7 @@ public: void computeXrefs(EquationInfo &ei) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override; + expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -1028,7 +1028,7 @@ public: //! Try to normalize an equation with respect to a given dynamic variable. /*! Should only be called on Equal nodes. The variable must appear in the equation. */ BinaryOpNode *normalizeEquation(int symb_id, int lag) const; - expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override; + expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -1160,7 +1160,7 @@ public: void computeXrefs(EquationInfo &ei) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override; + expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -1272,7 +1272,7 @@ public: void computeXrefs(EquationInfo &ei) const override = 0; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override; + expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -1458,7 +1458,7 @@ public: expr_t decreaseLeadsLags(int n) const override; void prepareForDerivation() override; expr_t computeDerivative(int deriv_id) override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override; + expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; bool containsExternalFunction() const override; double eval(const eval_context_t &eval_context) const noexcept(false) override; void computeXrefs(EquationInfo &ei) const override; @@ -1531,7 +1531,7 @@ public: expr_t decreaseLeadsLags(int n) const override; void prepareForDerivation() override; expr_t computeDerivative(int deriv_id) override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override; + expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; bool containsExternalFunction() const override; double eval(const eval_context_t &eval_context) const noexcept(false) override; void computeXrefs(EquationInfo &ei) const override; diff --git a/src/StaticModel.cc b/src/StaticModel.cc index 68e6e4182bcff63aefbc75a454326d55ccd20996..cd506fed7dc140a3cd926945974db27057f3d6ef 100644 --- a/src/StaticModel.cc +++ b/src/StaticModel.cc @@ -2065,7 +2065,7 @@ StaticModel::computeChainRuleJacobian() { int nb_recursives = blocks[blk].getRecursiveSize(); - map<int, expr_t> recursive_vars; + map<int, BinaryOpNode *> recursive_vars; for (int i = 0; i < nb_recursives; i++) { int deriv_id = getDerivID(symbol_table.getID(SymbolType::endogenous, getBlockVariableID(blk, i)), 0);