From 45b260cf20496e219609d17a46d1b94de1d41be0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org> Date: Fri, 24 Apr 2020 17:09:28 +0200 Subject: [PATCH] VariableNode::getChainRuleDerivative: do not cache values in ExprNode::derivatives This field is used for standard derivatives. Using it also for chain rule derivatives can only lead to wrong results. --- src/ExprNode.cc | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 6fb3c586..7de1de68 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -1387,27 +1387,17 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recur case SymbolType::logTrend: if (deriv_id == datatree.getDerivID(symb_id, lag)) return datatree.One; - else + // If there is in the equation a recursive variable we could use a chaine rule derivation + else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag)); + it != recursive_variables.end()) { - // If there is in the equation a recursive variable we could use a chaine rule derivation - if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag)); - it != recursive_variables.end()) - { - if (auto it2 = derivatives.find(deriv_id); - it2 != derivatives.end()) - return it2->second; - else - { - map<int, expr_t> recursive_vars2(recursive_variables); - recursive_vars2.erase(it->first); - expr_t d = datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2)); - derivatives[deriv_id] = d; - return d; - } - } - else - return datatree.Zero; + map<int, expr_t> recursive_vars2(recursive_variables); + recursive_vars2.erase(it->first); + return datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2)); } + else + return datatree.Zero; + case SymbolType::modelLocalVariable: return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables); case SymbolType::modFileLocalVariable: -- GitLab