From 5f4bed9253404494d219d9147a7f6226ab52e99f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Fri, 24 Apr 2020 17:09:28 +0200
Subject: [PATCH] VariableNode::getChainRuleDerivative: do not cache values in
 ExprNode::derivatives

This field is used for standard derivatives. Using it also for chain rule
derivatives can only lead to wrong results.

(manually cherry picked from commit 45b260cf20496e219609d17a46d1b94de1d41be0)
---
 src/ExprNode.cc | 30 +++++++++---------------------
 1 file changed, 9 insertions(+), 21 deletions(-)

diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index fcc64733..89e0117f 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -1405,29 +1405,17 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recur
     case SymbolType::logTrend:
       if (deriv_id == datatree.getDerivID(symb_id, lag))
         return datatree.One;
-      else
+      // If there is in the equation a recursive variable we could use a chaine rule derivation
+      else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag));
+               it != recursive_variables.end())
         {
-          //if there is in the equation a recursive variable we could use a chaine rule derivation
-          if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag));
-              it != recursive_variables.end())
-            {
-              if (auto it2 = derivatives.find(deriv_id);
-                  it2 != derivatives.end())
-                return it2->second;
-              else
-                {
-                  map<int, expr_t> recursive_vars2(recursive_variables);
-                  recursive_vars2.erase(it->first);
-                  //expr_t c = datatree.AddNonNegativeConstant("1");
-                  expr_t d = datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2));
-                  //d = datatree.AddTimes(c, d);
-                  derivatives[deriv_id] = d;
-                  return d;
-                }
-            }
-          else
-            return datatree.Zero;
+          map<int, expr_t> recursive_vars2(recursive_variables);
+          recursive_vars2.erase(it->first);
+          return datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2));
         }
+      else
+        return datatree.Zero;
+
     case SymbolType::modelLocalVariable:
       return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables);
     case SymbolType::modFileLocalVariable:
-- 
GitLab