diff --git a/preprocessor/DynamicModel.cc b/preprocessor/DynamicModel.cc
index 2750a2c193a8e85f03f69330268b787e95f6c739..c328d0d8cfd8149cbe13457fa2b97bbe7ea07b1f 100644
--- a/preprocessor/DynamicModel.cc
+++ b/preprocessor/DynamicModel.cc
@@ -664,9 +664,7 @@ end:
                         //cout << "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n";
                         //cout << "derivaive eq=" << eq << " var=" << var << " k0=" << k << "\n";
                         int deriv_id = getDerivID(symbol_table.getID(eEndogenous, var),0);
-                        map<int, NodeID>  recursive_variables_save(recursive_variables);
-                        NodeID ChaineRule_Derivative = tmp_n->getChaineRuleDerivative(deriv_id ,recursive_variables, var, 0);
-                        recursive_variables = recursive_variables_save;
+                        NodeID ChaineRule_Derivative = tmp_n->getChainRuleDerivative(deriv_id, recursive_variables);
                         ChaineRule_Derivative->writeOutput(output, oMatlabDynamicModelSparse, temporary_terms);
                         output << ";";
                         output << " %2 variable=" << symbol_table.getName(symbol_table.getID(eEndogenous, var))
diff --git a/preprocessor/ExprNode.cc b/preprocessor/ExprNode.cc
index 046a10d6092ed51da06ed2028f32b31c20def086..952965143d063fc9a5d5cb67f517925872b8dafc 100644
--- a/preprocessor/ExprNode.cc
+++ b/preprocessor/ExprNode.cc
@@ -60,32 +60,6 @@ ExprNode::getDerivative(int deriv_id)
     }
 }
 
-
-NodeID
-ExprNode::getChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
-{
-  // Return zero if derivative is necessarily null (using symbolic a priori)
-  /*set<int>::const_iterator it = non_null_derivatives.find(deriv_id);
-  if (it == non_null_derivatives.end())
-    {
-      cout << "0\n";
-      return datatree.Zero;
-    }
-  */
-
-  // If derivative is stored in cache, use the cached value, otherwise compute it (and cache it)
-  /*map<int, NodeID>::const_iterator it2 = derivatives.find(deriv_id);
-  if (it2 != derivatives.end())
-    return it2->second;
-  else*/
-    {
-      NodeID d = computeChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
-      //derivatives[deriv_id] = d;
-      return d;
-    }
-}
-
-
 int
 ExprNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
   {
@@ -213,20 +187,11 @@ NumConstNode::normalizeLinearInEndoEquation(int var_endo, NodeID Derivative) con
   }
 
 NodeID
-NumConstNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
+NumConstNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
 {
   return datatree.Zero;
 }
 
-
-/*
-pair<bool, NodeID>
-NumConstNode::computeDerivativeRespectToFeedbackVariable(int equ, int var, int varr, int lag_, int max_lag, vector<int> &recursive_variables, vector<int> &feeback_variables) const
-  {
-    return(make_pair(false, datatree.Zero));
-  }
-*/
-
 NodeID
 NumConstNode::toStatic(DataTree &static_datatree) const
   {
@@ -606,7 +571,7 @@ pair<InputIterator, int> find_r ( InputIterator first, InputIterator last, const
 
 
 NodeID
-VariableNode::computeChaineRuleDerivative(int deriv_id_arg, map<int, NodeID> &recursive_variables, int var, int lag_)
+VariableNode::getChainRuleDerivative(int deriv_id_arg, const map<int, NodeID> &recursive_variables)
 {
   switch (type)
     {
@@ -614,28 +579,23 @@ VariableNode::computeChaineRuleDerivative(int deriv_id_arg, map<int, NodeID> &re
     case eExogenous:
     case eExogenousDet:
     case eParameter:
-      //cout << "deriv_id=" << deriv_id << " deriv_id_arg=" << deriv_id_arg << " symb_id=" << symb_id << " type=" << type << " lag=" << lag << " var=" << var << " lag_ = " << lag_ << "\n";
       if (deriv_id == deriv_id_arg)
         return datatree.One;
       else
         {
           //if there is in the equation a recursive variable we could use a chaine rule derivation
-          if(lag == lag_)
+          map<int, NodeID>::const_iterator it = recursive_variables.find(deriv_id);
+          if (it != recursive_variables.end())
             {
-              map<int, NodeID>::const_iterator it = recursive_variables.find(deriv_id);
-              if (it != recursive_variables.end())
-                {
-                  recursive_variables.erase(it->first);
-                  return datatree.AddUMinus(it->second->getChaineRuleDerivative(deriv_id_arg, recursive_variables, var, lag_));
-                }
-              else
-                return datatree.Zero;
+              map<int, NodeID> recursive_vars2(recursive_variables);
+              recursive_vars2.erase(it->first);
+              return datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id_arg, recursive_vars2));
             }
           else
             return datatree.Zero;
         }
     case eModelLocalVariable:
-      return datatree.local_variables_table[symb_id]->getChaineRuleDerivative(deriv_id_arg, recursive_variables, var, lag_);
+      return datatree.local_variables_table[symb_id]->getChainRuleDerivative(deriv_id_arg, recursive_variables);
     case eModFileLocalVariable:
       cerr << "ModFileLocalVariable is not derivable" << endl;
       exit(EXIT_FAILURE);
@@ -669,10 +629,8 @@ UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const
 }
 
 NodeID
-UnaryOpNode::computeDerivative(int deriv_id)
+UnaryOpNode::composeDerivatives(NodeID darg)
 {
-  NodeID darg = arg->getDerivative(deriv_id);
-
   NodeID t11, t12, t13;
 
   switch (op_code)
@@ -738,6 +696,13 @@ UnaryOpNode::computeDerivative(int deriv_id)
   exit(EXIT_FAILURE);
 }
 
+NodeID
+UnaryOpNode::computeDerivative(int deriv_id)
+{
+  NodeID darg = arg->getDerivative(deriv_id);
+  return composeDerivatives(darg);
+}
+
 int
 UnaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab) const
   {
@@ -1117,76 +1082,12 @@ UnaryOpNode::normalizeLinearInEndoEquation(int var_endo, NodeID Derivative) cons
 
 
 NodeID
-UnaryOpNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
+UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
 {
-  NodeID darg = arg->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
-
-  NodeID t11, t12, t13;
-
-  switch (op_code)
-    {
-    case oUminus:
-      return datatree.AddUMinus(darg);
-    case oExp:
-      return datatree.AddTimes(darg, this);
-    case oLog:
-      return datatree.AddDivide(darg, arg);
-    case oLog10:
-      t11 = datatree.AddExp(datatree.One);
-      t12 = datatree.AddLog10(t11);
-      t13 = datatree.AddDivide(darg, arg);
-      return datatree.AddTimes(t12, t13);
-    case oCos:
-      t11 = datatree.AddSin(arg);
-      t12 = datatree.AddUMinus(t11);
-      return datatree.AddTimes(darg, t12);
-    case oSin:
-      t11 = datatree.AddCos(arg);
-      return datatree.AddTimes(darg, t11);
-    case oTan:
-      t11 = datatree.AddTimes(this, this);
-      t12 = datatree.AddPlus(t11, datatree.One);
-      return datatree.AddTimes(darg, t12);
-    case oAcos:
-      t11 = datatree.AddSin(this);
-      t12 = datatree.AddDivide(darg, t11);
-      return datatree.AddUMinus(t12);
-    case oAsin:
-      t11 = datatree.AddCos(this);
-      return datatree.AddDivide(darg, t11);
-    case oAtan:
-      t11 = datatree.AddTimes(arg, arg);
-      t12 = datatree.AddPlus(datatree.One, t11);
-      return datatree.AddDivide(darg, t12);
-    case oCosh:
-      t11 = datatree.AddSinh(arg);
-      return datatree.AddTimes(darg, t11);
-    case oSinh:
-      t11 = datatree.AddCosh(arg);
-      return datatree.AddTimes(darg, t11);
-    case oTanh:
-      t11 = datatree.AddTimes(this, this);
-      t12 = datatree.AddMinus(datatree.One, t11);
-      return datatree.AddTimes(darg, t12);
-    case oAcosh:
-      t11 = datatree.AddSinh(this);
-      return datatree.AddDivide(darg, t11);
-    case oAsinh:
-      t11 = datatree.AddCosh(this);
-      return datatree.AddDivide(darg, t11);
-    case oAtanh:
-      t11 = datatree.AddTimes(arg, arg);
-      t12 = datatree.AddMinus(datatree.One, t11);
-      return datatree.AddTimes(darg, t12);
-    case oSqrt:
-      t11 = datatree.AddPlus(this, this);
-      return datatree.AddDivide(darg, t11);
-    }
-  // Suppress GCC warning
-  exit(EXIT_FAILURE);
+  NodeID darg = arg->getChainRuleDerivative(deriv_id, recursive_variables);
+  return composeDerivatives(darg);
 }
 
-
 NodeID
 UnaryOpNode::toStatic(DataTree &static_datatree) const
   {
@@ -1252,11 +1153,8 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
 }
 
 NodeID
-BinaryOpNode::computeDerivative(int deriv_id)
+BinaryOpNode::composeDerivatives(NodeID darg1, NodeID darg2)
 {
-  NodeID darg1 = arg1->getDerivative(deriv_id);
-  NodeID darg2 = arg2->getDerivative(deriv_id);
-
   NodeID t11, t12, t13, t14, t15;
 
   switch (op_code)
@@ -1328,6 +1226,14 @@ BinaryOpNode::computeDerivative(int deriv_id)
   exit(EXIT_FAILURE);
 }
 
+NodeID
+BinaryOpNode::computeDerivative(int deriv_id)
+{
+  NodeID darg1 = arg1->getDerivative(deriv_id);
+  NodeID darg2 = arg2->getDerivative(deriv_id);
+  return composeDerivatives(darg1, darg2);
+}
+
 int
 BinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
   {
@@ -1880,80 +1786,11 @@ BinaryOpNode::normalizeLinearInEndoEquation(int var_endo, NodeID Derivative) con
 
 
 NodeID
-BinaryOpNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
+BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
 {
-  NodeID darg1 = arg1->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
-  NodeID darg2 = arg2->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
-
-  NodeID t11, t12, t13, t14, t15;
-
-  switch (op_code)
-    {
-    case oPlus:
-      return datatree.AddPlus(darg1, darg2);
-    case oMinus:
-      return datatree.AddMinus(darg1, darg2);
-    case oTimes:
-      t11 = datatree.AddTimes(darg1, arg2);
-      t12 = datatree.AddTimes(darg2, arg1);
-      return datatree.AddPlus(t11, t12);
-    case oDivide:
-      if (darg2!=datatree.Zero)
-        {
-          t11 = datatree.AddTimes(darg1, arg2);
-          t12 = datatree.AddTimes(darg2, arg1);
-          t13 = datatree.AddMinus(t11, t12);
-          t14 = datatree.AddTimes(arg2, arg2);
-          return datatree.AddDivide(t13, t14);
-        }
-      else
-        return datatree.AddDivide(darg1, arg2);
-    case oLess:
-    case oGreater:
-    case oLessEqual:
-    case oGreaterEqual:
-    case oEqualEqual:
-    case oDifferent:
-      return datatree.Zero;
-    case oPower:
-      if (darg2 == datatree.Zero)
-        {
-          if (darg1 == datatree.Zero)
-            return datatree.Zero;
-          else
-            {
-              t11 = datatree.AddMinus(arg2, datatree.One);
-              t12 = datatree.AddPower(arg1, t11);
-              t13 = datatree.AddTimes(arg2, t12);
-              return datatree.AddTimes(darg1, t13);
-            }
-        }
-      else
-        {
-          t11 = datatree.AddLog(arg1);
-          t12 = datatree.AddTimes(darg2, t11);
-          t13 = datatree.AddTimes(darg1, arg2);
-          t14 = datatree.AddDivide(t13, arg1);
-          t15 = datatree.AddPlus(t12, t14);
-          return datatree.AddTimes(t15, this);
-        }
-    case oMax:
-      t11 = datatree.AddGreater(arg1,arg2);
-      t12 = datatree.AddTimes(t11,darg1);
-      t13 = datatree.AddMinus(datatree.One,t11);
-      t14 = datatree.AddTimes(t13,darg2);
-      return datatree.AddPlus(t14,t12);
-    case oMin:
-      t11 = datatree.AddGreater(arg2,arg1);
-      t12 = datatree.AddTimes(t11,darg1);
-      t13 = datatree.AddMinus(datatree.One,t11);
-      t14 = datatree.AddTimes(t13,darg2);
-      return datatree.AddPlus(t14,t12);
-    case oEqual:
-      return datatree.AddMinus(darg1, darg2);
-    }
-  // Suppress GCC warning
-  exit(EXIT_FAILURE);
+  NodeID darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
+  NodeID darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
+  return composeDerivatives(darg1, darg2);
 }
 
 NodeID
@@ -2023,11 +1860,8 @@ TrinaryOpNode::TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
 }
 
 NodeID
-TrinaryOpNode::computeDerivative(int deriv_id)
+TrinaryOpNode::composeDerivatives(NodeID darg1, NodeID darg2, NodeID darg3)
 {
-  NodeID darg1 = arg1->getDerivative(deriv_id);
-  NodeID darg2 = arg2->getDerivative(deriv_id);
-  NodeID darg3 = arg3->getDerivative(deriv_id);
 
   NodeID t11, t12, t13, t14, t15;
 
@@ -2073,6 +1907,15 @@ TrinaryOpNode::computeDerivative(int deriv_id)
   exit(EXIT_FAILURE);
 }
 
+NodeID
+TrinaryOpNode::computeDerivative(int deriv_id)
+{
+  NodeID darg1 = arg1->getDerivative(deriv_id);
+  NodeID darg2 = arg2->getDerivative(deriv_id);
+  NodeID darg3 = arg3->getDerivative(deriv_id);
+  return composeDerivatives(darg1, darg2, darg3);
+}
+
 int
 TrinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
   {
@@ -2297,54 +2140,12 @@ TrinaryOpNode::normalizeLinearInEndoEquation(int var_endo, NodeID Derivative) co
   }
 
 NodeID
-TrinaryOpNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
+TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
 {
-  NodeID darg1 = arg1->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
-  NodeID darg2 = arg2->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
-  NodeID darg3 = arg3->getChaineRuleDerivative(deriv_id, recursive_variables, var, lag_);
-
-  NodeID t11, t12, t13, t14, t15;
-
-  switch (op_code)
-    {
-    case oNormcdf:
-      // normal pdf is inlined in the tree
-      NodeID y;
-      // sqrt(2*pi)
-      t14 = datatree.AddSqrt(datatree.AddTimes(datatree.Two, datatree.Pi));
-      // x - mu
-      t12 = datatree.AddMinus(arg1,arg2);
-      // y = (x-mu)/sigma
-      y = datatree.AddDivide(t12,arg3);
-      // (x-mu)^2/sigma^2
-      t12 = datatree.AddTimes(y,y);
-      // -(x-mu)^2/sigma^2
-      t13 = datatree.AddUMinus(t12);
-      // -((x-mu)^2/sigma^2)/2
-      t12 = datatree.AddDivide(t13, datatree.Two);
-      // exp(-((x-mu)^2/sigma^2)/2)
-      t13 = datatree.AddExp(t12);
-      // derivative of a standardized normal
-      // t15 = (1/sqrt(2*pi))*exp(-y^2/2)
-      t15 = datatree.AddDivide(t13,t14);
-      // derivatives thru x
-      t11 = datatree.AddDivide(darg1,arg3);
-      // derivatives thru mu
-      t12 = datatree.AddDivide(darg2,arg3);
-      // intermediary sum
-      t14 = datatree.AddMinus(t11,t12);
-      // derivatives thru sigma
-      t11 = datatree.AddDivide(y,arg3);
-      t12 = datatree.AddTimes(t11,darg3);
-      //intermediary sum
-      t11 = datatree.AddMinus(t14,t12);
-      // total derivative:
-      // (darg1/sigma - darg2/sigma - darg3*(x-mu)/sigma^2) * t15
-      // where t15 is the derivative of a standardized normal
-      return datatree.AddTimes(t11, t15);
-    }
-  // Suppress GCC warning
-  exit(EXIT_FAILURE);
+  NodeID darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
+  NodeID darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
+  NodeID darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables);
+  return composeDerivatives(darg1, darg2, darg3);
 }
 
 NodeID
@@ -2380,9 +2181,9 @@ UnknownFunctionNode::computeDerivative(int deriv_id)
 }
 
 NodeID
-UnknownFunctionNode::computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_)
+UnknownFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables)
 {
-  cerr << "UnknownFunctionNode::computeDerivative: operation impossible!" << endl;
+  cerr << "UnknownFunctionNode::getChainRuleDerivative: operation impossible!" << endl;
   exit(EXIT_FAILURE);
 }
 
diff --git a/preprocessor/ExprNode.hh b/preprocessor/ExprNode.hh
index 70050271fbf4d4bca926beef56e0161cb0e823c6..caff75ee755819351e66eff49d1c90b5590aa35a 100644
--- a/preprocessor/ExprNode.hh
+++ b/preprocessor/ExprNode.hh
@@ -109,10 +109,6 @@ private:
   //! Computes derivative w.r. to a derivation ID (but doesn't store it in derivatives map)
   /*! You shoud use getDerivative() to get the benefit of symbolic a priori and of caching */
   virtual NodeID computeDerivative(int deriv_id) = 0;
-  //! Computes derivative w.r. to a derivation ID and use chaine rule derivatives (but doesn't store it in derivatives map)
-  /*! You shoud use getDerivative() to get the benefit of symbolic a priori and of caching */
-  virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_) = 0;
-
 
 protected:
   //! Reference to the enclosing DataTree
@@ -140,9 +136,12 @@ public:
     For an equal node, returns the derivative of lhs minus rhs */
   NodeID getDerivative(int deriv_id);
 
-  //! Returns derivative w.r. to derivation ID and use if it possible chaine rule derivatives
-  NodeID getChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
-
+  //! Computes derivatives by applying the chain rule for some variables
+  /*!
+    \param deriv_id The derivation ID with respect to which we are derivating
+    \param recursive_variables Contains the derivation ID for which chain rules must be applied. Keys are derivation IDs, values are equations of the form x=f(y) where x is the key variable and x doesn't appear in y
+  */
+  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables) = 0;
 
   //! Returns precedence of node
   /*! Equals 100 for constants, variables, unary ops, and temporary terms */
@@ -215,7 +214,6 @@ private:
   //! Id from numerical constants table
   const int id;
   virtual NodeID computeDerivative(int deriv_id);
-  virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
 public:
   NumConstNode(DataTree &datatree_arg, int id_arg);
   virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
@@ -226,6 +224,7 @@ public:
   virtual void compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const;
   virtual NodeID toStatic(DataTree &static_datatree) const;
   virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
+  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
 };
 
 //! Symbol or variable node
@@ -239,7 +238,6 @@ private:
   //! Derivation ID
   const int deriv_id;
   virtual NodeID computeDerivative(int deriv_id_arg);
-  virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
 public:
   VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg, int deriv_id_arg);
   virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms = temporary_terms_type()) const;
@@ -258,6 +256,7 @@ public:
   virtual NodeID toStatic(DataTree &static_datatree) const;
   int get_symb_id() const { return symb_id; };
   virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
+  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
 };
 
 //! Unary operator node
@@ -267,9 +266,9 @@ private:
   const NodeID arg;
   const UnaryOpcode op_code;
   virtual NodeID computeDerivative(int deriv_id);
-  virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
-
   virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
+  //! Returns the derivative of this node if darg is the derivative of the argument
+  NodeID composeDerivatives(NodeID darg);
 public:
   UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg);
   virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
@@ -293,6 +292,7 @@ public:
   UnaryOpcode get_op_code() const { return(op_code); };
   virtual NodeID toStatic(DataTree &static_datatree) const;
   virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
+  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
 };
 
 //! Binary operator node
@@ -302,9 +302,9 @@ private:
   const NodeID arg1, arg2;
   const BinaryOpcode op_code;
   virtual NodeID computeDerivative(int deriv_id);
-  virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
-
   virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
+  //! Returns the derivative of this node if darg1 and darg2 are the derivatives of the arguments
+  NodeID composeDerivatives(NodeID darg1, NodeID darg2);
 public:
   BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
                BinaryOpcode op_code_arg, const NodeID arg2_arg);
@@ -332,6 +332,7 @@ public:
   BinaryOpcode get_op_code() const { return(op_code); };
   virtual NodeID toStatic(DataTree &static_datatree) const;
   pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
+  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
 };
 
 //! Trinary operator node
@@ -342,9 +343,9 @@ private:
   const NodeID arg1, arg2, arg3;
   const TrinaryOpcode op_code;
   virtual NodeID computeDerivative(int deriv_id);
-  virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
-
   virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
+  //! Returns the derivative of this node if darg1, darg2 and darg3 are the derivatives of the arguments
+  NodeID composeDerivatives(NodeID darg1, NodeID darg2, NodeID darg3);
 public:
   TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
 		TrinaryOpcode op_code_arg, const NodeID arg2_arg, const NodeID arg3_arg);
@@ -366,6 +367,7 @@ public:
   virtual void compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const;
   virtual NodeID toStatic(DataTree &static_datatree) const;
   virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
+  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
 };
 
 //! Unknown function node
@@ -375,7 +377,6 @@ private:
   const int symb_id;
   const vector<NodeID> arguments;
   virtual NodeID computeDerivative(int deriv_id);
-  virtual NodeID computeChaineRuleDerivative(int deriv_id, map<int, NodeID> &recursive_variables, int var, int lag_);
 public:
   UnknownFunctionNode(DataTree &datatree_arg, int symb_id_arg,
                       const vector<NodeID> &arguments_arg);
@@ -395,6 +396,7 @@ public:
   virtual void compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const;
   virtual NodeID toStatic(DataTree &static_datatree) const;
   virtual pair<bool, NodeID> normalizeLinearInEndoEquation(int symb_id_endo, NodeID Derivative) const;
+  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
 };
 
 //! For one lead/lag of one block, stores mapping of information between original model and block-decomposed model