From 23b0c12d8e351b8d2f8323c5bd9aca240af778bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org> Date: Tue, 8 Nov 2022 12:28:48 +0100 Subject: [PATCH] Performance improvement of chain rule derivation, using caching Useful for mfs > 0 on large models. --- src/DynamicModel.cc | 5 ++-- src/ExprNode.cc | 68 ++++++++++++++++++++++++++++++++------------- src/ExprNode.hh | 34 ++++++++++++++--------- src/StaticModel.cc | 4 +-- 4 files changed, 74 insertions(+), 37 deletions(-) diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 7292fc73..7fcd8314 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -3124,6 +3124,7 @@ DynamicModel::computeChainRuleJacobian() } // Compute the block derivatives + map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache; for (const auto &[indices, derivType] : determineBlockDerivativesType(blk)) { auto [lag, eq, var] = indices; @@ -3140,10 +3141,10 @@ DynamicModel::computeChainRuleJacobian() d = Zero; break; case BlockDerivativeType::chainRule: - d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars); + d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache); break; case BlockDerivativeType::normalizedChainRule: - d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars); + d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache); break; } diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 47a2107f..3beed93a 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -53,6 +53,22 @@ ExprNode::getDerivative(int deriv_id) } } +expr_t +ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, + map<pair<expr_t, int>, expr_t> &cache) +{ + pair key {this, deriv_id}; + if (auto it = cache.find(key); + it != cache.end()) + return it->second; + + auto r = computeChainRuleDerivative(deriv_id, recursive_variables, cache); + + auto [ignore, success] = cache.emplace(key, r); + assert(success); // The element should not already exist + return r; +} + int ExprNode::precedence([[maybe_unused]] ExprNodeOutputType output_type, [[maybe_unused]] const temporary_terms_t &temporary_terms) const @@ -546,8 +562,9 @@ NumConstNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &contai } expr_t -NumConstNode::getChainRuleDerivative([[maybe_unused]] int deriv_id, - [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables) +NumConstNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id, + [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables, + [[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache) { return datatree.Zero; } @@ -1402,7 +1419,9 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs } expr_t -VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) +VariableNode::computeChainRuleDerivative(int deriv_id, + const map<int, BinaryOpNode *> &recursive_variables, + map<pair<expr_t, int>, expr_t> &cache) { switch (get_type()) { @@ -1421,12 +1440,12 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode * // If there is in the equation a recursive variable we could use a chaine rule derivation else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag)); it != recursive_variables.end()) - return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables); + return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache); else return datatree.Zero; case SymbolType::modelLocalVariable: - return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables); + return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables, cache); case SymbolType::modFileLocalVariable: cerr << "modFileLocalVariable is not derivable" << endl; exit(EXIT_FAILURE); @@ -1439,7 +1458,7 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode * case SymbolType::externalFunction: case SymbolType::epilogue: case SymbolType::excludedVariable: - cerr << "VariableNode::getChainRuleDerivative: Impossible case" << endl; + cerr << "VariableNode::computeChainRuleDerivative: Impossible case" << endl; exit(EXIT_FAILURE); } // Suppress GCC warning @@ -3224,9 +3243,11 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) } expr_t -UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) +UnaryOpNode::computeChainRuleDerivative(int deriv_id, + const map<int, BinaryOpNode *> &recursive_variables, + map<pair<expr_t, int>, expr_t> &cache) { - expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables); + expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, cache); return composeDerivatives(darg, deriv_id); } @@ -4978,10 +4999,12 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const } expr_t -BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) +BinaryOpNode::computeChainRuleDerivative(int deriv_id, + const map<int, BinaryOpNode *> &recursive_variables, + map<pair<expr_t, int>, expr_t> &cache) { - expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables); - expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables); + expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache); + expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache); return composeDerivatives(darg1, darg2); } @@ -6289,11 +6312,13 @@ TrinaryOpNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &conta } expr_t -TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) +TrinaryOpNode::computeChainRuleDerivative(int deriv_id, + const map<int, BinaryOpNode *> &recursive_variables, + map<pair<expr_t, int>, expr_t> &cache) { - expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables); - expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables); - expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables); + expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache); + expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache); + expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables, cache); return composeDerivatives(darg1, darg2, darg3); } @@ -6717,12 +6742,14 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id) } expr_t -AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) +AbstractExternalFunctionNode::computeChainRuleDerivative(int deriv_id, + const map<int, BinaryOpNode *> &recursive_variables, + map<pair<expr_t, int>, expr_t> &cache) { assert(datatree.external_functions_table.getNargs(symb_id) > 0); vector<expr_t> dargs; for (auto argument : arguments) - dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables)); + dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables, cache)); return composeDerivatives(dargs); } @@ -8334,10 +8361,11 @@ SubModelNode::computeDerivative([[maybe_unused]] int deriv_id) } expr_t -SubModelNode::getChainRuleDerivative([[maybe_unused]] int deriv_id, - [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables) +SubModelNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id, + [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables, + [[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache) { - cerr << "SubModelNode::getChainRuleDerivative not implemented." << endl; + cerr << "SubModelNode::computeChainRuleDerivative not implemented." << endl; exit(EXIT_FAILURE); } diff --git a/src/ExprNode.hh b/src/ExprNode.hh index f7cbc96c..e5469ad7 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -247,6 +247,10 @@ private: /*! You shoud use getDerivative() to get the benefit of symbolic a priori and of caching */ virtual expr_t computeDerivative(int deriv_id) = 0; + /* Internal helper for getChainRuleDerivative(), that does the computation + but assumes that the caching of this is handled elsewhere */ + virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) = 0; + protected: //! Reference to the enclosing DataTree DataTree &datatree; @@ -327,12 +331,15 @@ public: For an equal node, returns the derivative of lhs minus rhs */ expr_t getDerivative(int deriv_id); - //! Computes derivatives by applying the chain rule for some variables - /*! - \param deriv_id The derivation ID with respect to which we are derivating - \param recursive_variables Contains the derivation ID for which chain rules must be applied. Keys are derivation IDs, values are equations of the form x=f(y) where x is the key variable and x doesn't appear in y - */ - virtual expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) = 0; + /* Computes derivatives by applying the chain rule for some variables. + — “recursive_variables” contains the derivation ID for which chain rules + must be applied. Keys are derivation IDs, values are equations of the + form x=f(y) where x is the key variable and x doesn't appear in y + — “cache” is used to store already-computed derivatives (in a map + <expression, deriv_id> → derivative); this cache is specific to a given + value of “recursive_variables”, and thus should not be reused accross + calls that use different values of “recursive_variables”. */ + expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache); //! Returns precedence of node /*! Equals 100 for constants, variables, unary ops, and temporary terms */ @@ -836,6 +843,7 @@ public: const int id; private: expr_t computeDerivative(int deriv_id) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; protected: void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; @@ -853,7 +861,6 @@ public: expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -908,6 +915,7 @@ public: const int lag; private: expr_t computeDerivative(int deriv_id) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; protected: void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; @@ -926,7 +934,6 @@ public: void computeXrefs(EquationInfo &ei) const override; SymbolType get_type() const; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -990,6 +997,7 @@ public: const vector<int> adl_lags; private: expr_t computeDerivative(int deriv_id) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; int cost(int cost, bool is_matlab) const override; int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override; int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override; @@ -1029,7 +1037,6 @@ public: expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -1091,6 +1098,7 @@ public: const string adlparam; private: expr_t computeDerivative(int deriv_id) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; int cost(int cost, bool is_matlab) const override; int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override; int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override; @@ -1137,7 +1145,6 @@ public: //! Try to normalize an equation with respect to a given dynamic variable. /*! Should only be called on Equal nodes. The variable must appear in the equation. */ BinaryOpNode *normalizeEquation(int symb_id, int lag) const; - expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -1235,6 +1242,7 @@ protected: void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; private: expr_t computeDerivative(int deriv_id) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; int cost(int cost, bool is_matlab) const override; int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override; int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override; @@ -1276,7 +1284,6 @@ public: expr_t toStatic(DataTree &static_datatree) const override; void computeXrefs(EquationInfo &ei) const override; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -1331,6 +1338,7 @@ public: const vector<expr_t> arguments; private: expr_t computeDerivative(int deriv_id) override; + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; virtual expr_t composeDerivatives(const vector<expr_t> &dargs) = 0; protected: //! Thrown when trying to access an unknown entry in external_function_node_map @@ -1389,7 +1397,6 @@ public: expr_t toStatic(DataTree &static_datatree) const override = 0; void computeXrefs(EquationInfo &ei) const override = 0; BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -1568,7 +1575,6 @@ public: expr_t toStatic(DataTree &static_datatree) const override; void prepareForDerivation() override; expr_t computeDerivative(int deriv_id) override; - expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override; int maxEndoLead() const override; int maxExoLead() const override; int maxEndoLag() const override; @@ -1615,6 +1621,8 @@ public: expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override; protected: void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; +private: + expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; }; class VarExpectationNode : public SubModelNode diff --git a/src/StaticModel.cc b/src/StaticModel.cc index 78f94ca7..e3adc0b3 100644 --- a/src/StaticModel.cc +++ b/src/StaticModel.cc @@ -952,14 +952,14 @@ StaticModel::computeChainRuleJacobian() && simulation_type != BlockSimulationType::solveTwoBoundariesComplete); int size = blocks[blk].size; - + map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache; for (int eq = nb_recursives; eq < size; eq++) { int eq_orig = getBlockEquationID(blk, eq); for (int var = nb_recursives; var < size; var++) { int var_orig = getBlockVariableID(blk, var); - expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars); + expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars, chain_rule_deriv_cache); if (d1 != Zero) blocks_derivatives[blk][{ eq, var, 0 }] = d1; } -- GitLab