diff --git a/src/ExprNode.cc b/src/ExprNode.cc index fcc64733638fd0750bf8164200355c255cd8624c..89e0117fc3178ab5fbcacbd487b10aea99a592dc 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -1405,29 +1405,17 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recur case SymbolType::logTrend: if (deriv_id == datatree.getDerivID(symb_id, lag)) return datatree.One; - else + // If there is in the equation a recursive variable we could use a chaine rule derivation + else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag)); + it != recursive_variables.end()) { - //if there is in the equation a recursive variable we could use a chaine rule derivation - if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag)); - it != recursive_variables.end()) - { - if (auto it2 = derivatives.find(deriv_id); - it2 != derivatives.end()) - return it2->second; - else - { - map<int, expr_t> recursive_vars2(recursive_variables); - recursive_vars2.erase(it->first); - //expr_t c = datatree.AddNonNegativeConstant("1"); - expr_t d = datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2)); - //d = datatree.AddTimes(c, d); - derivatives[deriv_id] = d; - return d; - } - } - else - return datatree.Zero; + map<int, expr_t> recursive_vars2(recursive_variables); + recursive_vars2.erase(it->first); + return datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2)); } + else + return datatree.Zero; + case SymbolType::modelLocalVariable: return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables); case SymbolType::modFileLocalVariable: