diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc
index 858c708788bc6eeee213b14c686ae6d1f5767662..4d687bf66d20163e05b96d7130637fb8e10b958a 100644
--- a/src/DynamicModel.cc
+++ b/src/DynamicModel.cc
@@ -4520,7 +4520,7 @@ DynamicModel::computeChainRuleJacobian()
       int nb_recursives = blocks[blk].getRecursiveSize();
 
       // Create a map from recursive vars to their defining (normalized) equation
-      map<int, expr_t> recursive_vars;
+      map<int, BinaryOpNode *> recursive_vars;
       for (int i = 0; i < nb_recursives; i++)
         {
           int deriv_id = getDerivID(symbol_table.getID(SymbolType::endogenous, getBlockVariableID(blk, i)), 0);
diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index dff1abb6872940be30866c21074dd7b520efdfc4..6de8f8447deb4a37ff626bdb50d381a69da8a857 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -483,7 +483,7 @@ NumConstNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
 }
 
 expr_t
-NumConstNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+NumConstNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   return datatree.Zero;
 }
@@ -1307,7 +1307,7 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
 }
 
 expr_t
-VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   switch (get_type())
     {
@@ -1322,11 +1322,7 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recur
       // If there is in the equation a recursive variable we could use a chaine rule derivation
       else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag));
                it != recursive_variables.end())
-        {
-          map<int, expr_t> recursive_vars2(recursive_variables);
-          recursive_vars2.erase(it->first);
-          return datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2));
-        }
+        return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables);
       else
         return datatree.Zero;
 
@@ -3088,7 +3084,7 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
 }
 
 expr_t
-UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables);
   return composeDerivatives(darg, deriv_id);
@@ -4820,7 +4816,7 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const
 }
 
 expr_t
-BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
   expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
@@ -6089,7 +6085,7 @@ TrinaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rh
 }
 
 expr_t
-TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
   expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
@@ -6508,7 +6504,7 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id)
 }
 
 expr_t
-AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   assert(datatree.external_functions_table.getNargs(symb_id) > 0);
   vector<expr_t> dargs;
@@ -8185,7 +8181,7 @@ VarExpectationNode::computeDerivative(int deriv_id)
 }
 
 expr_t
-VarExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+VarExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   cerr << "VarExpectationNode::getChainRuleDerivative not implemented." << endl;
   exit(EXIT_FAILURE);
@@ -8582,7 +8578,7 @@ PacExpectationNode::computeDerivative(int deriv_id)
 }
 
 expr_t
-PacExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+PacExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   cerr << "PacExpectationNode::getChainRuleDerivative: shouldn't arrive here." << endl;
   exit(EXIT_FAILURE);
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index e7b1f910ec061928477d3f812f7f7de60d80a853..4e3b2cd4d86a3c4f8ffb04fa11e4465ddb6f8b66 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -268,7 +268,7 @@ public:
     \param deriv_id The derivation ID with respect to which we are derivating
     \param recursive_variables Contains the derivation ID for which chain rules must be applied. Keys are derivation IDs, values are equations of the form x=f(y) where x is the key variable and x doesn't appear in y
   */
-  virtual expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) = 0;
+  virtual expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) = 0;
 
   //! Returns precedence of node
   /*! Equals 100 for constants, variables, unary ops, and temporary terms */
@@ -747,7 +747,7 @@ public:
   void computeXrefs(EquationInfo &ei) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -820,7 +820,7 @@ public:
   SymbolType get_type() const;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -921,7 +921,7 @@ public:
   void computeXrefs(EquationInfo &ei) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1028,7 +1028,7 @@ public:
   //! Try to normalize an equation with respect to a given dynamic variable.
   /*! Should only be called on Equal nodes. The variable must appear in the equation. */
   BinaryOpNode *normalizeEquation(int symb_id, int lag) const;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1160,7 +1160,7 @@ public:
   void computeXrefs(EquationInfo &ei) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1272,7 +1272,7 @@ public:
   void computeXrefs(EquationInfo &ei) const override = 0;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1458,7 +1458,7 @@ public:
   expr_t decreaseLeadsLags(int n) const override;
   void prepareForDerivation() override;
   expr_t computeDerivative(int deriv_id) override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   bool containsExternalFunction() const override;
   double eval(const eval_context_t &eval_context) const noexcept(false) override;
   void computeXrefs(EquationInfo &ei) const override;
@@ -1531,7 +1531,7 @@ public:
   expr_t decreaseLeadsLags(int n) const override;
   void prepareForDerivation() override;
   expr_t computeDerivative(int deriv_id) override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   bool containsExternalFunction() const override;
   double eval(const eval_context_t &eval_context) const noexcept(false) override;
   void computeXrefs(EquationInfo &ei) const override;
diff --git a/src/StaticModel.cc b/src/StaticModel.cc
index 68e6e4182bcff63aefbc75a454326d55ccd20996..cd506fed7dc140a3cd926945974db27057f3d6ef 100644
--- a/src/StaticModel.cc
+++ b/src/StaticModel.cc
@@ -2065,7 +2065,7 @@ StaticModel::computeChainRuleJacobian()
     {
       int nb_recursives = blocks[blk].getRecursiveSize();
 
-      map<int, expr_t> recursive_vars;
+      map<int, BinaryOpNode *> recursive_vars;
       for (int i = 0; i < nb_recursives; i++)
         {
           int deriv_id = getDerivID(symbol_table.getID(SymbolType::endogenous, getBlockVariableID(blk, i)), 0);