From 5f4bed9253404494d219d9147a7f6226ab52e99f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org> Date: Fri, 24 Apr 2020 17:09:28 +0200 Subject: [PATCH] VariableNode::getChainRuleDerivative: do not cache values in ExprNode::derivatives This field is used for standard derivatives. Using it also for chain rule derivatives can only lead to wrong results. (manually cherry picked from commit 45b260cf20496e219609d17a46d1b94de1d41be0) --- src/ExprNode.cc | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index fcc64733..89e0117f 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -1405,29 +1405,17 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recur case SymbolType::logTrend: if (deriv_id == datatree.getDerivID(symb_id, lag)) return datatree.One; - else + // If there is in the equation a recursive variable we could use a chaine rule derivation + else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag)); + it != recursive_variables.end()) { - //if there is in the equation a recursive variable we could use a chaine rule derivation - if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag)); - it != recursive_variables.end()) - { - if (auto it2 = derivatives.find(deriv_id); - it2 != derivatives.end()) - return it2->second; - else - { - map<int, expr_t> recursive_vars2(recursive_variables); - recursive_vars2.erase(it->first); - //expr_t c = datatree.AddNonNegativeConstant("1"); - expr_t d = datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2)); - //d = datatree.AddTimes(c, d); - derivatives[deriv_id] = d; - return d; - } - } - else - return datatree.Zero; + map<int, expr_t> recursive_vars2(recursive_variables); + recursive_vars2.erase(it->first); + return datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2)); } + else + return datatree.Zero; + case SymbolType::modelLocalVariable: return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables); case SymbolType::modFileLocalVariable: -- GitLab