From 7acf278370f2166c36ba832f97d09c6813a24e73 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Thu, 2 Mar 2023 17:49:16 +0100
Subject: [PATCH] Performance improvement of chain rule derivation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Commit 23b0c12d8e351b8d2f8323c5bd9aca240af778bd introduced caching in chain
rule derivation (used by block decomposition), which increased speed for mfs >
0, but actually decreased it for mfs=0.

This patch introduces the pre-computation of derivatives which are known to be
zero using symbolic a priori (similarly to what is done in the non-chain rule
context). The algorithms are now identical between the two contexts (both
symbolic a priori + caching), the difference being that in the chain rule
context, the symbolic a priori and the cache are not stored within the ExprNode
class, since they depend on the list of recursive variables.

This patch brings a significant performant improvement for all values of the
“mfs” option (the improvement is greater for small values of “mfs”).
---
 src/DynamicModel.cc |   5 +-
 src/ExprNode.cc     | 190 +++++++++++++++++++++++++++++++++++++++++---
 src/ExprNode.hh     |  45 ++++++++---
 src/StaticModel.cc  |   3 +-
 4 files changed, 218 insertions(+), 25 deletions(-)

diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc
index 62230e78..c874ef2c 100644
--- a/src/DynamicModel.cc
+++ b/src/DynamicModel.cc
@@ -2320,6 +2320,7 @@ DynamicModel::computeChainRuleJacobian()
         }
 
       // Compute the block derivatives
+      map<expr_t, set<int>> non_null_chain_rule_derivatives;
       map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
       for (const auto &[indices, derivType] : determineBlockDerivativesType(blk))
         {
@@ -2337,10 +2338,10 @@ DynamicModel::computeChainRuleJacobian()
                 d = Zero;
               break;
             case BlockDerivativeType::chainRule:
-              d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache);
+              d = equations[eq_orig]->getChainRuleDerivative(deriv_id, recursive_vars, non_null_chain_rule_derivatives, chain_rule_deriv_cache);
               break;
             case BlockDerivativeType::normalizedChainRule:
-              d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars, chain_rule_deriv_cache);
+              d = equation_type_and_normalized_equation[eq_orig].second->getChainRuleDerivative(deriv_id, recursive_vars, non_null_chain_rule_derivatives, chain_rule_deriv_cache);
               break;
             }
 
diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 30591012..6f72885f 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -56,14 +56,24 @@ ExprNode::getDerivative(int deriv_id)
 
 expr_t
 ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables,
+                                 map<expr_t, set<int>> &non_null_chain_rule_derivatives,
                                  map<pair<expr_t, int>, expr_t> &cache)
 {
+  if (!non_null_chain_rule_derivatives.contains(this))
+    prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
+
+  // Return zero if derivative is necessarily null (using symbolic a priori)
+  if (!non_null_chain_rule_derivatives.at(this).contains(deriv_id))
+    return datatree.Zero;
+
+  // If derivative is in the cache, return that value
   pair key {this, deriv_id};
   if (auto it = cache.find(key);
       it != cache.end())
     return it->second;
 
-  auto r = computeChainRuleDerivative(deriv_id, recursive_variables, cache);
+  auto r = computeChainRuleDerivative(deriv_id, recursive_variables,
+                                      non_null_chain_rule_derivatives, cache);
 
   auto [ignore, success] = cache.emplace(key, r);
   assert(success); // The element should not already exist
@@ -477,6 +487,13 @@ NumConstNode::prepareForDerivation()
   // All derivatives are null, so non_null_derivatives is left empty
 }
 
+void
+NumConstNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
+                                            map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+{
+  non_null_chain_rule_derivatives.try_emplace(const_cast<NumConstNode *>(this));
+}
+
 expr_t
 NumConstNode::computeDerivative([[maybe_unused]] int deriv_id)
 {
@@ -565,6 +582,7 @@ NumConstNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &contai
 expr_t
 NumConstNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
                                          [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
+                                         [[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives,
                                          [[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
 {
   return datatree.Zero;
@@ -897,6 +915,56 @@ VariableNode::prepareForDerivation()
     }
 }
 
+void
+VariableNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
+                                            map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+{
+  if (non_null_chain_rule_derivatives.contains(const_cast<VariableNode *>(this)))
+    return;
+
+  switch (get_type())
+    {
+    case SymbolType::endogenous:
+      {
+        set<int> &nnd { non_null_chain_rule_derivatives[const_cast<VariableNode *>(this)] };
+        int my_deriv_id {datatree.getDerivID(symb_id, lag)};
+        if (auto it = recursive_variables.find(my_deriv_id);
+            it != recursive_variables.end())
+          {
+            it->second->arg2->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
+            nnd = non_null_chain_rule_derivatives.at(it->second->arg2);
+          }
+        nnd.insert(my_deriv_id);
+      }
+      break;
+    case SymbolType::exogenous:
+    case SymbolType::exogenousDet:
+    case SymbolType::parameter:
+    case SymbolType::trend:
+    case SymbolType::logTrend:
+    case SymbolType::modFileLocalVariable:
+    case SymbolType::statementDeclaredVariable:
+    case SymbolType::unusedEndogenous:
+      // Those variables are never derived using chain rule
+      non_null_chain_rule_derivatives.try_emplace(const_cast<VariableNode *>(this));
+      break;
+    case SymbolType::modelLocalVariable:
+      {
+        expr_t def { datatree.getLocalVariable(symb_id) };
+        // Non null derivatives are those of the value of the model local variable
+        def->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
+        non_null_chain_rule_derivatives.emplace(const_cast<VariableNode *>(this),
+                                                non_null_chain_rule_derivatives.at(def));
+      }
+      break;
+    case SymbolType::externalFunction:
+    case SymbolType::epilogue:
+    case SymbolType::excludedVariable:
+      cerr << "VariableNode::prepareForChainRuleDerivation: impossible case" << endl;
+      exit(EXIT_FAILURE);
+    }
+}
+
 expr_t
 VariableNode::computeDerivative(int deriv_id)
 {
@@ -1422,6 +1490,7 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
 expr_t
 VariableNode::computeChainRuleDerivative(int deriv_id,
                                          const map<int, BinaryOpNode *> &recursive_variables,
+                                         map<expr_t, set<int>> &non_null_chain_rule_derivatives,
                                          map<pair<expr_t, int>, expr_t> &cache)
 {
   switch (get_type())
@@ -1442,12 +1511,12 @@ VariableNode::computeChainRuleDerivative(int deriv_id,
       // If there is in the equation a recursive variable we could use a chaine rule derivation
       else if (auto it = recursive_variables.find(my_deriv_id);
                it != recursive_variables.end())
-        return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
+        return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
       else
         return datatree.Zero;
 
     case SymbolType::modelLocalVariable:
-      return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables, cache);
+      return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
     case SymbolType::modFileLocalVariable:
       cerr << "modFileLocalVariable is not derivable" << endl;
       exit(EXIT_FAILURE);
@@ -2151,6 +2220,28 @@ UnaryOpNode::prepareForDerivation()
     }
 }
 
+void
+UnaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
+                                           map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+{
+  if (non_null_chain_rule_derivatives.contains(const_cast<UnaryOpNode *>(this)))
+    return;
+
+  /* Non-null derivatives are those of the argument (except for STEADY_STATE in
+     a dynamic context, in which case the potentially non-null derivatives are
+     all the parameters) */
+  set<int> &nnd { non_null_chain_rule_derivatives[const_cast<UnaryOpNode *>(this)] };
+  if ((op_code == UnaryOpcode::steadyState || op_code == UnaryOpcode::steadyStateParamDeriv
+       || op_code == UnaryOpcode::steadyStateParam2ndDeriv)
+      && datatree.isDynamic())
+    datatree.addAllParamDerivId(nnd);
+  else
+    {
+      arg->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
+      nnd = non_null_chain_rule_derivatives.at(arg);
+    }
+}
+
 expr_t
 UnaryOpNode::composeDerivatives(expr_t darg, int deriv_id)
 {
@@ -3271,9 +3362,10 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
 expr_t
 UnaryOpNode::computeChainRuleDerivative(int deriv_id,
                                         const map<int, BinaryOpNode *> &recursive_variables,
+                                        map<expr_t, set<int>> &non_null_chain_rule_derivatives,
                                         map<pair<expr_t, int>, expr_t> &cache)
 {
-  expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, cache);
+  expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
   return composeDerivatives(darg, deriv_id);
 }
 
@@ -3986,6 +4078,24 @@ BinaryOpNode::prepareForDerivation()
             inserter(non_null_derivatives, non_null_derivatives.begin()));
 }
 
+void
+BinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
+                                            map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+{
+  if (non_null_chain_rule_derivatives.contains(const_cast<BinaryOpNode *>(this)))
+    return;
+
+  arg1->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
+  arg2->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
+
+  set<int> &nnd { non_null_chain_rule_derivatives[const_cast<BinaryOpNode *>(this)] };
+  set_union(non_null_chain_rule_derivatives.at(arg1).begin(),
+            non_null_chain_rule_derivatives.at(arg1).end(),
+            non_null_chain_rule_derivatives.at(arg2).begin(),
+            non_null_chain_rule_derivatives.at(arg2).end(),
+            inserter(nnd, nnd.begin()));
+}
+
 expr_t
 BinaryOpNode::getNonZeroPartofEquation() const
 {
@@ -5038,10 +5148,11 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const
 expr_t
 BinaryOpNode::computeChainRuleDerivative(int deriv_id,
                                          const map<int, BinaryOpNode *> &recursive_variables,
+                                         map<expr_t, set<int>> &non_null_chain_rule_derivatives,
                                          map<pair<expr_t, int>, expr_t> &cache)
 {
-  expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache);
-  expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
+  expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
+  expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
   return composeDerivatives(darg1, darg2);
 }
 
@@ -5888,6 +5999,30 @@ TrinaryOpNode::prepareForDerivation()
             inserter(non_null_derivatives, non_null_derivatives.begin()));
 }
 
+void
+TrinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
+                                             map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+{
+  if (non_null_chain_rule_derivatives.contains(const_cast<TrinaryOpNode *>(this)))
+    return;
+
+  arg1->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
+  arg2->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
+  arg3->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
+
+  set<int> &nnd { non_null_chain_rule_derivatives[const_cast<TrinaryOpNode *>(this)] };
+  set<int> nnd_tmp;
+  set_union(non_null_chain_rule_derivatives.at(arg1).begin(),
+            non_null_chain_rule_derivatives.at(arg1).end(),
+            non_null_chain_rule_derivatives.at(arg2).begin(),
+            non_null_chain_rule_derivatives.at(arg2).end(),
+            inserter(nnd_tmp, nnd_tmp.begin()));
+  set_union(nnd_tmp.begin(), nnd_tmp.end(),
+            non_null_chain_rule_derivatives.at(arg3).begin(),
+            non_null_chain_rule_derivatives.at(arg3).end(),
+            inserter(nnd, nnd.begin()));
+}
+
 expr_t
 TrinaryOpNode::composeDerivatives(expr_t darg1, expr_t darg2, expr_t darg3)
 {
@@ -6351,11 +6486,12 @@ TrinaryOpNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &conta
 expr_t
 TrinaryOpNode::computeChainRuleDerivative(int deriv_id,
                                           const map<int, BinaryOpNode *> &recursive_variables,
+                                          map<expr_t, set<int>> &non_null_chain_rule_derivatives,
                                           map<pair<expr_t, int>, expr_t> &cache)
 {
-  expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, cache);
-  expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, cache);
-  expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables, cache);
+  expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
+  expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
+  expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
   return composeDerivatives(darg1, darg2, darg3);
 }
 
@@ -6772,6 +6908,30 @@ AbstractExternalFunctionNode::prepareForDerivation()
   preparedForDerivation = true;
 }
 
+void
+AbstractExternalFunctionNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
+                                                            map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+{
+  if (non_null_chain_rule_derivatives.contains(const_cast<AbstractExternalFunctionNode *>(this)))
+    return;
+
+  for (auto argument : arguments)
+    argument->prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
+
+  non_null_chain_rule_derivatives.emplace(const_cast<AbstractExternalFunctionNode *>(this),
+                                          non_null_chain_rule_derivatives.at(arguments.at(0)));
+  set<int> &nnd { non_null_chain_rule_derivatives.at(const_cast<AbstractExternalFunctionNode *>(this)) };
+  for (int i {1}; i < static_cast<int>(arguments.size()); i++)
+    {
+      set<int> nnd_tmp;
+      set_union(nnd.begin(), nnd.end(),
+                non_null_chain_rule_derivatives.at(arguments.at(i)).begin(),
+                non_null_chain_rule_derivatives.at(arguments.at(i)).end(),
+                inserter(nnd_tmp, nnd_tmp.begin()));
+      nnd = move(nnd_tmp);
+    }
+}
+
 expr_t
 AbstractExternalFunctionNode::computeDerivative(int deriv_id)
 {
@@ -6785,12 +6945,13 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id)
 expr_t
 AbstractExternalFunctionNode::computeChainRuleDerivative(int deriv_id,
                                                          const map<int, BinaryOpNode *> &recursive_variables,
+                                                         map<expr_t, set<int>> &non_null_chain_rule_derivatives,
                                                          map<pair<expr_t, int>, expr_t> &cache)
 {
   assert(datatree.external_functions_table.getNargs(symb_id) > 0);
   vector<expr_t> dargs;
   for (auto argument : arguments)
-    dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables, cache));
+    dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache));
   return composeDerivatives(dargs);
 }
 
@@ -8364,6 +8525,14 @@ SubModelNode::prepareForDerivation()
   exit(EXIT_FAILURE);
 }
 
+void
+SubModelNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
+                                            [[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+{
+  cerr << "SubModelNode::prepareForChainRuleDerivation not implemented." << endl;
+  exit(EXIT_FAILURE);
+}
+
 expr_t
 SubModelNode::computeDerivative([[maybe_unused]] int deriv_id)
 {
@@ -8374,6 +8543,7 @@ SubModelNode::computeDerivative([[maybe_unused]] int deriv_id)
 expr_t
 SubModelNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
                                          [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
+                                         [[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives,
                                          [[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
 {
   cerr << "SubModelNode::computeChainRuleDerivative not implemented." << endl;
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 5ecab5f4..3d854371 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -249,7 +249,7 @@ private:
 
   /* Internal helper for getChainRuleDerivative(), that does the computation
      but assumes that the caching of this is handled elsewhere */
-  virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) = 0;
+  virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) = 0;
 
 protected:
   //! Reference to the enclosing DataTree
@@ -278,6 +278,14 @@ protected:
   //! Initializes data member non_null_derivatives
   virtual void prepareForDerivation() = 0;
 
+  /* Computes the derivatives which are potentially non-null, using symbolic a
+     priori, similarly to prepareForDerivation(), but in a chain rule
+     derivation context. See getChainRuleDerivation() for the meaning of
+     “recursive_variables”. Note that all non-endogenous variables are
+     automatically considered to have a zero derivative (since they’re never
+     used in a chain rule context) */
+  virtual void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const = 0;
+
   //! Cost of computing current node
   /*! Nodes included in temporary_terms are considered having a null cost */
   virtual int cost(int cost, bool is_matlab) const;
@@ -335,11 +343,17 @@ public:
      — “recursive_variables” contains the derivation ID for which chain rules
        must be applied. Keys are derivation IDs, values are equations of the
        form x=f(y) where x is the key variable and x doesn't appear in y
+     — “non_null_chain_rule_derivatives” is used to store the indices of
+       variables that are potentially non-null (using symbolic a priori),
+       similarly to ExprNode::non_null_derivatives.
      — “cache” is used to store already-computed derivatives (in a map
-       <expression, deriv_id> → derivative); this cache is specific to a given
-       value of “recursive_variables”, and thus should not be reused accross
-       calls that use different values of “recursive_variables”. */
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache);
+       <expression, deriv_id> → derivative)
+     NB: always returns zero when “deriv_id” corresponds to a non-endogenous
+     variable (since such variables are never used in a chain rule context).
+     NB 2: “non_null_chain_rule_derivatives” and “cache” are specific to a given
+     value of “recursive_variables”, and thus should not be reused accross
+     calls that use different values of “recursive_variables”. */
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache);
 
   //! Returns precedence of node
   /*! Equals 100 for constants, variables, unary ops, and temporary terms */
@@ -843,9 +857,10 @@ public:
   const int id;
 private:
   expr_t computeDerivative(int deriv_id) override;
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
 protected:
   void prepareForDerivation() override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
@@ -915,9 +930,10 @@ public:
   const int lag;
 private:
   expr_t computeDerivative(int deriv_id) override;
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
 protected:
   void prepareForDerivation() override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
@@ -985,6 +1001,7 @@ class UnaryOpNode : public ExprNode
 {
 protected:
   void prepareForDerivation() override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
@@ -998,7 +1015,7 @@ public:
   const vector<int> adl_lags;
 private:
   expr_t computeDerivative(int deriv_id) override;
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
   int cost(int cost, bool is_matlab) const override;
   int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
   int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@@ -1090,6 +1107,7 @@ class BinaryOpNode : public ExprNode
 {
 protected:
   void prepareForDerivation() override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
@@ -1099,7 +1117,7 @@ public:
   const string adlparam;
 private:
   expr_t computeDerivative(int deriv_id) override;
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
   int cost(int cost, bool is_matlab) const override;
   int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
   int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@@ -1240,10 +1258,11 @@ public:
   const TrinaryOpcode op_code;
 protected:
   void prepareForDerivation() override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 private:
   expr_t computeDerivative(int deriv_id) override;
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
   int cost(int cost, bool is_matlab) const override;
   int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
   int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@@ -1338,7 +1357,7 @@ public:
   const vector<expr_t> arguments;
 private:
   expr_t computeDerivative(int deriv_id) override;
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
   virtual expr_t composeDerivatives(const vector<expr_t> &dargs) = 0;
   // Computes the maximum of f applied to all arguments (result will always be non-negative)
   int maxHelper(const function<int (expr_t)> &f) const;
@@ -1348,6 +1367,7 @@ protected:
   {
   };
   void prepareForDerivation() override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   //! Returns true if the given external function has been written as a temporary term
   bool alreadyWrittenAsTefTerm(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const;
   //! Returns the index in the tef_terms map of this external function
@@ -1622,9 +1642,10 @@ public:
   expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
 protected:
   void prepareForDerivation() override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 private:
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
 };
 
 class VarExpectationNode : public SubModelNode
diff --git a/src/StaticModel.cc b/src/StaticModel.cc
index f3cfc2da..92e5f260 100644
--- a/src/StaticModel.cc
+++ b/src/StaticModel.cc
@@ -666,6 +666,7 @@ StaticModel::computeChainRuleJacobian()
              && simulation_type != BlockSimulationType::solveTwoBoundariesComplete);
 
       int size = blocks[blk].size;
+      map<expr_t, set<int>> non_null_chain_rule_derivatives;
       map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
       for (int eq = nb_recursives; eq < size; eq++)
         {
@@ -673,7 +674,7 @@ StaticModel::computeChainRuleJacobian()
           for (int var = nb_recursives; var < size; var++)
             {
               int var_orig = getBlockVariableID(blk, var);
-              expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars, chain_rule_deriv_cache);
+              expr_t d1 = equations[eq_orig]->getChainRuleDerivative(getDerivID(symbol_table.getID(SymbolType::endogenous, var_orig), 0), recursive_vars, non_null_chain_rule_derivatives, chain_rule_deriv_cache);
               if (d1 != Zero)
                 blocks_derivatives[blk][{ eq, var, 0 }] = d1;
             }
-- 
GitLab