From 23b0c12d8e351b8d2f8323c5bd9aca240af778bd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Tue, 8 Nov 2022 12:28:48 +0100
Subject: [PATCH] Performance improvement of chain rule derivation, using
 caching

Useful for mfs > 0 on large models.
---
 src/DynamicModel.cc |  5 ++--
 src/ExprNode.cc     | 68 ++++++++++++++++++++++++++++++++-------------
 src/ExprNode.hh     | 34 ++++++++++++++---------
 src/StaticModel.cc  |  4 +--
 4 files changed, 74 insertions(+), 37 deletions(-)

diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc
index 7292fc73..7fcd8314 100644
--- a/src/DynamicModel.cc
+++ b/src/DynamicModel.cc
@@ -3124,6 +3124,7 @@ DynamicModel::computeChainRuleJacobian()
         }
 
       // Compute the block derivatives
+      map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
       for (const auto &[indices, derivType] : determineBlockDerivativesType(blk))
         {
           auto [lag, eq, var] = indices;
@@ -3140,10 +3141,10 @@ DynamicModel::computeChainRuleJacobian()
                 d = Zero;
               break;
             case BlockDerivativeType::chainRule:
-              d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars);
+              d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache);
               break;
             case BlockDerivativeType::normalizedChainRule:
-              d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars);
+              d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache);
               break;
             }
 
diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 47a2107f..3beed93a 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -53,6 +53,22 @@ ExprNode::getDerivative(int deriv_id)
     }
 }
 
+expr_t
+ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables,
+                                 map<pair<expr_t, int>, expr_t> &cache)
+{
+  pair key {this, deriv_id};
+  if (auto it = cache.find(key);
+      it != cache.end())
+    return it->second;
+
+  auto r = computeChainRuleDerivative(deriv_id, recursive_variables, cache);
+
+  auto [ignore, success] = cache.emplace(key, r);
+  assert(success); // The element should not already exist
+  return r;
+}
+
 int
 ExprNode::precedence([[maybe_unused]] ExprNodeOutputType output_type,
                      [[maybe_unused]] const temporary_terms_t &temporary_terms) const
@@ -546,8 +562,9 @@ NumConstNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &contai
 }
 
 expr_t
-NumConstNode::getChainRuleDerivative([[maybe_unused]] int deriv_id,
-                                     [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables)
+NumConstNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
+                                         [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
+                                         [[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
 {
   return datatree.Zero;
 }
@@ -1402,7 +1419,9 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
 }
 
 expr_t
-VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
+VariableNode::computeChainRuleDerivative(int deriv_id,
+                                         const map<int, BinaryOpNode *> &recursive_variables,
+                                         map<pair<expr_t, int>, expr_t> &cache)
 {
   switch (get_type())
     {
@@ -1421,12 +1440,12 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *
       // If there is in the equation a recursive variable we could use a chaine rule derivation
       else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag));
                it != recursive_variables.end())
-        return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables);
+        return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
       else
         return datatree.Zero;
 
     case SymbolType::modelLocalVariable:
-      return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables);
+      return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables, cache);
     case SymbolType::modFileLocalVariable:
       cerr << "modFileLocalVariable is not derivable" << endl;
       exit(EXIT_FAILURE);
@@ -1439,7 +1458,7 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *
     case SymbolType::externalFunction:
     case SymbolType::epilogue:
     case SymbolType::excludedVariable:
-      cerr << "VariableNode::getChainRuleDerivative: Impossible case" << endl;
+      cerr << "VariableNode::computeChainRuleDerivative: Impossible case" << endl;
       exit(EXIT_FAILURE);
     }
   // Suppress GCC warning
@@ -3224,9 +3243,11 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
 }
 
 expr_t
-UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
+UnaryOpNode::computeChainRuleDerivative(int deriv_id,
+                                        const map<int, BinaryOpNode *> &recursive_variables,
+                                        map<pair<expr_t, int>, expr_t> &cache)
 {
-  expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables);
+  expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, cache);
   return composeDerivatives(darg, deriv_id);
 }
 
@@ -4978,10 +4999,12 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const
 }
 
 expr_t
-BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
+BinaryOpNode::computeChainRuleDerivative(int deriv_id,
+                                         const map<int, BinaryOpNode *> &recursive_variables,
+                                         map<pair<expr_t, int>, expr_t> &cache)
 {
-  expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
-  expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
+  expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache);
+  expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
   return composeDerivatives(darg1, darg2);
 }
 
@@ -6289,11 +6312,13 @@ TrinaryOpNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &conta
 }
 
 expr_t
-TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
+TrinaryOpNode::computeChainRuleDerivative(int deriv_id,
+                                          const map<int, BinaryOpNode *> &recursive_variables,
+                                          map<pair<expr_t, int>, expr_t> &cache)
 {
-  expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
-  expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
-  expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables);
+  expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache);
+  expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
+  expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables, cache);
   return composeDerivatives(darg1, darg2, darg3);
 }
 
@@ -6717,12 +6742,14 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id)
 }
 
 expr_t
-AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
+AbstractExternalFunctionNode::computeChainRuleDerivative(int deriv_id,
+                                                         const map<int, BinaryOpNode *> &recursive_variables,
+                                                         map<pair<expr_t, int>, expr_t> &cache)
 {
   assert(datatree.external_functions_table.getNargs(symb_id) > 0);
   vector<expr_t> dargs;
   for (auto argument : arguments)
-    dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables));
+    dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables, cache));
   return composeDerivatives(dargs);
 }
 
@@ -8334,10 +8361,11 @@ SubModelNode::computeDerivative([[maybe_unused]] int deriv_id)
 }
 
 expr_t
-SubModelNode::getChainRuleDerivative([[maybe_unused]] int deriv_id,
-                                     [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables)
+SubModelNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
+                                         [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
+                                         [[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
 {
-  cerr << "SubModelNode::getChainRuleDerivative not implemented." << endl;
+  cerr << "SubModelNode::computeChainRuleDerivative not implemented." << endl;
   exit(EXIT_FAILURE);
 }
 
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index f7cbc96c..e5469ad7 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -247,6 +247,10 @@ private:
   /*! You shoud use getDerivative() to get the benefit of symbolic a priori and of caching */
   virtual expr_t computeDerivative(int deriv_id) = 0;
 
+  /* Internal helper for getChainRuleDerivative(), that does the computation
+     but assumes that the caching of this is handled elsewhere */
+  virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) = 0;
+
 protected:
   //! Reference to the enclosing DataTree
   DataTree &datatree;
@@ -327,12 +331,15 @@ public:
     For an equal node, returns the derivative of lhs minus rhs */
   expr_t getDerivative(int deriv_id);
 
-  //! Computes derivatives by applying the chain rule for some variables
-  /*!
-    \param deriv_id The derivation ID with respect to which we are derivating
-    \param recursive_variables Contains the derivation ID for which chain rules must be applied. Keys are derivation IDs, values are equations of the form x=f(y) where x is the key variable and x doesn't appear in y
-  */
-  virtual expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) = 0;
+  /* Computes derivatives by applying the chain rule for some variables.
+     — “recursive_variables” contains the derivation ID for which chain rules
+       must be applied. Keys are derivation IDs, values are equations of the
+       form x=f(y) where x is the key variable and x doesn't appear in y
+     — “cache” is used to store already-computed derivatives (in a map
+       <expression, deriv_id> → derivative); this cache is specific to a given
+       value of “recursive_variables”, and thus should not be reused across
+       calls that use different values of “recursive_variables”. */
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache);
 
   //! Returns precedence of node
   /*! Equals 100 for constants, variables, unary ops, and temporary terms */
@@ -836,6 +843,7 @@ public:
   const int id;
 private:
   expr_t computeDerivative(int deriv_id) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
 protected:
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
@@ -853,7 +861,6 @@ public:
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -908,6 +915,7 @@ public:
   const int lag;
 private:
   expr_t computeDerivative(int deriv_id) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
 protected:
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
@@ -926,7 +934,6 @@ public:
   void computeXrefs(EquationInfo &ei) const override;
   SymbolType get_type() const;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -990,6 +997,7 @@ public:
   const vector<int> adl_lags;
 private:
   expr_t computeDerivative(int deriv_id) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
   int cost(int cost, bool is_matlab) const override;
   int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
   int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@@ -1029,7 +1037,6 @@ public:
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1091,6 +1098,7 @@ public:
   const string adlparam;
 private:
   expr_t computeDerivative(int deriv_id) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
   int cost(int cost, bool is_matlab) const override;
   int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
   int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@@ -1137,7 +1145,6 @@ public:
   //! Try to normalize an equation with respect to a given dynamic variable.
   /*! Should only be called on Equal nodes. The variable must appear in the equation. */
   BinaryOpNode *normalizeEquation(int symb_id, int lag) const;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1235,6 +1242,7 @@ protected:
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 private:
   expr_t computeDerivative(int deriv_id) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
   int cost(int cost, bool is_matlab) const override;
   int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
   int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@@ -1276,7 +1284,6 @@ public:
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1331,6 +1338,7 @@ public:
   const vector<expr_t> arguments;
 private:
   expr_t computeDerivative(int deriv_id) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
   virtual expr_t composeDerivatives(const vector<expr_t> &dargs) = 0;
 protected:
   //! Thrown when trying to access an unknown entry in external_function_node_map
@@ -1389,7 +1397,6 @@ public:
   expr_t toStatic(DataTree &static_datatree) const override = 0;
   void computeXrefs(EquationInfo &ei) const override = 0;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1568,7 +1575,6 @@ public:
   expr_t toStatic(DataTree &static_datatree) const override;
   void prepareForDerivation() override;
   expr_t computeDerivative(int deriv_id) override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1615,6 +1621,8 @@ public:
   expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
 protected:
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
+private:
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
 };
 
 class VarExpectationNode : public SubModelNode
diff --git a/src/StaticModel.cc b/src/StaticModel.cc
index 78f94ca7..e3adc0b3 100644
--- a/src/StaticModel.cc
+++ b/src/StaticModel.cc
@@ -952,14 +952,14 @@ StaticModel::computeChainRuleJacobian()
              && simulation_type != BlockSimulationType::solveTwoBoundariesComplete);
 
       int size = blocks[blk].size;
-
+      map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
       for (int eq = nb_recursives; eq < size; eq++)
         {
           int eq_orig = getBlockEquationID(blk, eq);
           for (int var = nb_recursives; var < size; var++)
             {
               int var_orig = getBlockVariableID(blk, var);
-              expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars);
+              expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars, chain_rule_deriv_cache);
               if (d1 != Zero)
                 blocks_derivatives[blk][{ eq, var, 0 }] = d1;
             }
-- 
GitLab