diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 62230e78b8a86a0d0d5e0f105f2cd2a5965ba615..c874ef2c2f1125765e35896798351557225a8c61 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -2320,6 +2320,7 @@ DynamicModel::computeChainRuleJacobian() } // Compute the block derivatives + map<expr_t, set<int>> non_null_chain_rule_derivatives; map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache; for (const auto &[indices, derivType] : determineBlockDerivativesType(blk)) { @@ -2337,10 +2338,10 @@ DynamicModel::computeChainRuleJacobian() d = Zero; break; case BlockDerivativeType::chainRule: - d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache); + d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars, non_null_chain_rule_derivatives, chain_rule_deriv_cache); break; case BlockDerivativeType::normalizedChainRule: - d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache); + d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars, non_null_chain_rule_derivatives, chain_rule_deriv_cache); break; } diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 30591012329ca6b459cf5fe8f0adde1688c528a4..6f72885f28ae1c991ef7e7222b0b4c72e5866354 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -56,14 +56,24 @@ ExprNode::getDerivative(int deriv_id) expr_t ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, + map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) { + if (!non_null_chain_rule_derivatives.contains(this)) + prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives); + + // Return zero if derivative is necessarily null (using symbolic a priori) + if (!non_null_chain_rule_derivatives.at(this).contains(deriv_id)) + return datatree.Zero; + + // If derivative is in the cache, return that value pair key {this, deriv_id}; if (auto it = cache.find(key); it != cache.end()) return it->second; - auto r = computeChainRuleDerivative(deriv_id, recursive_variables, cache); + auto r = computeChainRuleDerivative(deriv_id, recursive_variables, + non_null_chain_rule_derivatives, cache); auto [ignore, success] = cache.emplace(key, r); assert(success); // The element should not already exist @@ -477,6 +487,13 @@ NumConstNode::prepareForDerivation() // All derivatives are null, so non_null_derivatives is left empty } +void +NumConstNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables, + map<expr_t, set<int>> &non_null_chain_rule_derivatives) const +{ + non_null_chain_rule_derivatives.try_emplace(const_cast<NumConstNode *>(this)); +} + expr_t NumConstNode::computeDerivative([[maybe_unused]] int deriv_id) { @@ -565,6 +582,7 @@ NumConstNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &contai expr_t NumConstNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id, [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables, + [[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives, [[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache) { return datatree.Zero; @@ -897,6 +915,56 @@ VariableNode::prepareForDerivation() } } +void +VariableNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, + map<expr_t, set<int>> &non_null_chain_rule_derivatives) const +{ + if (non_null_chain_rule_derivatives.contains(const_cast<VariableNode *>(this))) + return; + + switch (get_type()) + { + case SymbolType::endogenous: + { + set<int> &nnd { non_null_chain_rule_derivatives[const_cast<VariableNode *>(this)] }; + int my_deriv_id {datatree.getDerivID(symb_id, lag)}; + if (auto it = recursive_variables.find(my_deriv_id); + it != recursive_variables.end()) + { + it->second->arg2->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives); + nnd = non_null_chain_rule_derivatives.at(it->second->arg2); + } + nnd.insert(my_deriv_id); + } + break; + case SymbolType::exogenous: + case SymbolType::exogenousDet: + case SymbolType::parameter: + case SymbolType::trend: + case SymbolType::logTrend: + case SymbolType::modFileLocalVariable: + case SymbolType::statementDeclaredVariable: + case SymbolType::unusedEndogenous: + // Those variables are never derived using chain rule + non_null_chain_rule_derivatives.try_emplace(const_cast<VariableNode *>(this)); + break; + case SymbolType::modelLocalVariable: + { + expr_t def { datatree.getLocalVariable(symb_id) }; + // Non null derivatives are those of the value of the model local variable + def->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives); + non_null_chain_rule_derivatives.emplace(const_cast<VariableNode *>(this), + non_null_chain_rule_derivatives.at(def)); + } + break; + case SymbolType::externalFunction: + case SymbolType::epilogue: + case SymbolType::excludedVariable: + cerr << "VariableNode::prepareForChainRuleDerivation: impossible case" << endl; + exit(EXIT_FAILURE); + } +} + expr_t VariableNode::computeDerivative(int deriv_id) { @@ -1422,6 +1490,7 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs expr_t VariableNode::computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, + map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) { switch (get_type()) @@ -1442,12 +1511,12 @@ VariableNode::computeChainRuleDerivative(int deriv_id, // If there is in the equation a recursive variable we could use a chaine rule derivation else if (auto it = recursive_variables.find(my_deriv_id); it != recursive_variables.end()) - return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache); + return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache); else return datatree.Zero; case SymbolType::modelLocalVariable: - return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables, cache); + return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache); case SymbolType::modFileLocalVariable: cerr << "modFileLocalVariable is not derivable" << endl; exit(EXIT_FAILURE); @@ -2151,6 +2220,28 @@ UnaryOpNode::prepareForDerivation() } } +void +UnaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, + map<expr_t, set<int>> &non_null_chain_rule_derivatives) const +{ + if (non_null_chain_rule_derivatives.contains(const_cast<UnaryOpNode *>(this))) + return; + + /* Non-null derivatives are those of the argument (except for STEADY_STATE in + a dynamic context, in which case the potentially non-null derivatives are + all the parameters) */ + set<int> &nnd { non_null_chain_rule_derivatives[const_cast<UnaryOpNode *>(this)] }; + if ((op_code == UnaryOpcode::steadyState || op_code == UnaryOpcode::steadyStateParamDeriv + || op_code == UnaryOpcode::steadyStateParam2ndDeriv) + && datatree.isDynamic()) + datatree.addAllParamDerivId(nnd); + else + { + arg->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives); + nnd = non_null_chain_rule_derivatives.at(arg); + } +} + expr_t UnaryOpNode::composeDerivatives(expr_t darg, int deriv_id) { @@ -3271,9 +3362,10 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) expr_t UnaryOpNode::computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, + map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) { - expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, cache); + expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache); return composeDerivatives(darg, deriv_id); } @@ -3986,6 +4078,24 @@ BinaryOpNode::prepareForDerivation() inserter(non_null_derivatives, non_null_derivatives.begin())); } +void +BinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, + map<expr_t, set<int>> &non_null_chain_rule_derivatives) const +{ + if (non_null_chain_rule_derivatives.contains(const_cast<BinaryOpNode *>(this))) + return; + + arg1->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives); + arg2->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives); + + set<int> &nnd { non_null_chain_rule_derivatives[const_cast<BinaryOpNode *>(this)] }; + set_union(non_null_chain_rule_derivatives.at(arg1).begin(), + non_null_chain_rule_derivatives.at(arg1).end(), + non_null_chain_rule_derivatives.at(arg2).begin(), + non_null_chain_rule_derivatives.at(arg2).end(), + inserter(nnd, nnd.begin())); +} + expr_t BinaryOpNode::getNonZeroPartofEquation() const { @@ -5038,10 +5148,11 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const expr_t BinaryOpNode::computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, + map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) { - expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache); - expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache); + expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache); + expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache); return composeDerivatives(darg1, darg2); } @@ -5888,6 +5999,30 @@ TrinaryOpNode::prepareForDerivation() inserter(non_null_derivatives, non_null_derivatives.begin())); } +void +TrinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, + map<expr_t, set<int>> &non_null_chain_rule_derivatives) const +{ + if (non_null_chain_rule_derivatives.contains(const_cast<TrinaryOpNode *>(this))) + return; + + arg1->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives); + arg2->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives); + arg3->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives); + + set<int> &nnd { non_null_chain_rule_derivatives[const_cast<TrinaryOpNode *>(this)] }; + set<int> nnd_tmp; + set_union(non_null_chain_rule_derivatives.at(arg1).begin(), + non_null_chain_rule_derivatives.at(arg1).end(), + non_null_chain_rule_derivatives.at(arg2).begin(), + non_null_chain_rule_derivatives.at(arg2).end(), + inserter(nnd_tmp, nnd_tmp.begin())); + set_union(nnd_tmp.begin(), nnd_tmp.end(), + non_null_chain_rule_derivatives.at(arg3).begin(), + non_null_chain_rule_derivatives.at(arg3).end(), + inserter(nnd, nnd.begin())); +} + expr_t TrinaryOpNode::composeDerivatives(expr_t darg1, expr_t darg2, expr_t darg3) { @@ -6351,11 +6486,12 @@ TrinaryOpNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &conta expr_t TrinaryOpNode::computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, + map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) { - expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache); - expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache); - expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables, cache); + expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache); + expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache); + expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache); return composeDerivatives(darg1, darg2, darg3); } @@ -6772,6 +6908,30 @@ AbstractExternalFunctionNode::prepareForDerivation() preparedForDerivation = true; } +void +AbstractExternalFunctionNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, + map<expr_t, set<int>> &non_null_chain_rule_derivatives) const +{ + if (non_null_chain_rule_derivatives.contains(const_cast<AbstractExternalFunctionNode *>(this))) + return; + + for (auto argument : arguments) + argument->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives); + + non_null_chain_rule_derivatives.emplace(const_cast<AbstractExternalFunctionNode *>(this), + non_null_chain_rule_derivatives.at(arguments.at(0))); + set<int> &nnd { non_null_chain_rule_derivatives.at(const_cast<AbstractExternalFunctionNode *>(this)) }; + for (int i {1}; i < static_cast<int>(arguments.size()); i++) + { + set<int> nnd_tmp; + set_union(nnd.begin(), nnd.end(), + non_null_chain_rule_derivatives.at(arguments.at(i)).begin(), + non_null_chain_rule_derivatives.at(arguments.at(i)).end(), + inserter(nnd_tmp, nnd_tmp.begin())); + nnd = move(nnd_tmp); + } +} + expr_t AbstractExternalFunctionNode::computeDerivative(int deriv_id) { @@ -6785,12 +6945,13 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id) expr_t AbstractExternalFunctionNode::computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, + map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) { assert(datatree.external_functions_table.getNargs(symb_id) > 0); vector<expr_t> dargs; for (auto argument : arguments) - dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables, cache)); + dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache)); return composeDerivatives(dargs); } @@ -8364,6 +8525,14 @@ SubModelNode::prepareForDerivation() exit(EXIT_FAILURE); } +void +SubModelNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables, + [[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives) const +{ + cerr << "SubModelNode::prepareForChainRuleDerivation not implemented." << endl; + exit(EXIT_FAILURE); +} + expr_t SubModelNode::computeDerivative([[maybe_unused]] int deriv_id) { @@ -8374,6 +8543,7 @@ SubModelNode::computeDerivative([[maybe_unused]] int deriv_id) expr_t SubModelNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id, [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables, + [[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives, [[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache) { cerr << "SubModelNode::computeChainRuleDerivative not implemented." << endl; diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 5ecab5f445a20ff1cca73bf8828f017bc4bf367c..3d8543711936cf068e2e6a66de12bc1c1d1e18e2 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -249,7 +249,7 @@ private: /* Internal helper for getChainRuleDerivative(), that does the computation but assumes that the caching of this is handled elsewhere */ - virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) = 0; + virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) = 0; protected: //! Reference to the enclosing DataTree @@ -278,6 +278,14 @@ protected: //! Initializes data member non_null_derivatives virtual void prepareForDerivation() = 0; + /* Computes the derivatives which are potentially non-null, using symbolic a + priori, similarly to prepareForDerivation(), but in a chain rule + derivation context. See getChainRuleDerivation() for the meaning of + “recursive_variables”. Note that all non-endogenous variables are + automatically considered to have a zero derivative (since they’re never + used in a chain rule context) */ + virtual void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const = 0; + //! Cost of computing current node /*! Nodes included in temporary_terms are considered having a null cost */ virtual int cost(int cost, bool is_matlab) const; @@ -335,11 +343,17 @@ public: — “recursive_variables” contains the derivation ID for which chain rules must be applied. Keys are derivation IDs, values are equations of the form x=f(y) where x is the key variable and x doesn't appear in y + — “non_null_chain_rule_derivatives” is used to store the indices of + variables that are potentially non-null (using symbolic a priori), + similarly to ExprNode::non_null_derivatives. — “cache” is used to store already-computed derivatives (in a map - <expression, deriv_id> → derivative); this cache is specific to a given - value of “recursive_variables”, and thus should not be reused accross - calls that use different values of “recursive_variables”. */ - expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache); + <expression, deriv_id> → derivative) + NB: always returns zero when “deriv_id” corresponds to a non-endogenous + variable (since such variables are never used in a chain rule context). + NB 2: “non_null_chain_rule_derivatives” and “cache” are specific to a given + value of “recursive_variables”, and thus should not be reused accross + calls that use different values of “recursive_variables”. */ + expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache); //! Returns precedence of node /*! Equals 100 for constants, variables, unary ops, and temporary terms */ @@ -843,9 +857,10 @@ public: const int id; private: expr_t computeDerivative(int deriv_id) override; - expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override; protected: void prepareForDerivation() override; + void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override; void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: @@ -915,9 +930,10 @@ public: const int lag; private: expr_t computeDerivative(int deriv_id) override; - expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override; protected: void prepareForDerivation() override; + void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override; void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: @@ -985,6 +1001,7 @@ class UnaryOpNode : public ExprNode { protected: void prepareForDerivation() override; + void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override; void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: @@ -998,7 +1015,7 @@ public: const vector<int> adl_lags; private: expr_t computeDerivative(int deriv_id) override; - expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override; int cost(int cost, bool is_matlab) const override; int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override; int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override; @@ -1090,6 +1107,7 @@ class BinaryOpNode : public ExprNode { protected: void prepareForDerivation() override; + void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override; void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: @@ -1099,7 +1117,7 @@ public: const string adlparam; private: expr_t computeDerivative(int deriv_id) override; - expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override; int cost(int cost, bool is_matlab) const override; int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override; int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override; @@ -1240,10 +1258,11 @@ public: const TrinaryOpcode op_code; protected: void prepareForDerivation() override; + void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; private: expr_t computeDerivative(int deriv_id) override; - expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override; int cost(int cost, bool is_matlab) const override; int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override; int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override; @@ -1338,7 +1357,7 @@ public: const vector<expr_t> arguments; private: expr_t computeDerivative(int deriv_id) override; - expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override; virtual expr_t composeDerivatives(const vector<expr_t> &dargs) = 0; // Computes the maximum of f applied to all arguments (result will always be non-negative) int maxHelper(const function<int (expr_t)> &f) const; @@ -1348,6 +1367,7 @@ protected: { }; void prepareForDerivation() override; + void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override; //! Returns true if the given external function has been written as a temporary term bool alreadyWrittenAsTefTerm(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const; //! Returns the index in the tef_terms map of this external function @@ -1622,9 +1642,10 @@ public: expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override; protected: void prepareForDerivation() override; + void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; private: - expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override; }; class VarExpectationNode : public SubModelNode diff --git a/src/StaticModel.cc b/src/StaticModel.cc index f3cfc2da0e6aaeaa75147b6815df1d5bd4ce1368..92e5f2606942840a5e6b95820580625c1e2b132c 100644 --- a/src/StaticModel.cc +++ b/src/StaticModel.cc @@ -666,6 +666,7 @@ StaticModel::computeChainRuleJacobian() && simulation_type != BlockSimulationType::solveTwoBoundariesComplete); int size = blocks[blk].size; + map<expr_t, set<int>> non_null_chain_rule_derivatives; map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache; for (int eq = nb_recursives; eq < size; eq++) { @@ -673,7 +674,7 @@ StaticModel::computeChainRuleJacobian() for (int var = nb_recursives; var < size; var++) { int var_orig = getBlockVariableID(blk, var); - expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars, chain_rule_deriv_cache); + expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars, non_null_chain_rule_derivatives, chain_rule_deriv_cache); if (d1 != Zero) blocks_derivatives[blk][{ eq, var, 0 }] = d1; }