diff --git a/src/ModelTree.cc b/src/ModelTree.cc index 266e965a7e05fbfee3a6eff1c8a7a18979fe1103..a25b15398bb76104ae3e95f4919e25c0d31c3bf4 100644 --- a/src/ModelTree.cc +++ b/src/ModelTree.cc @@ -653,42 +653,36 @@ ModelTree::computePrologueAndEpilogue(const jacob_map_t &static_jacobian) void ModelTree::equationTypeDetermination(const map<tuple<int, int, int>, expr_t> &first_order_endo_derivatives, int mfs) { - expr_t lhs; - BinaryOpNode *eq_node; - EquationType Equation_Simulation_Type; equation_type_and_normalized_equation.clear(); equation_type_and_normalized_equation.resize(equations.size()); for (int i = 0; i < static_cast<int>(equations.size()); i++) { int eq = equation_reordered[i]; int var = variable_reordered[i]; - eq_node = equations[eq]; - lhs = eq_node->arg1; - Equation_Simulation_Type = EquationType::solve; - pair<bool, expr_t> res; - if (auto derivative = first_order_endo_derivatives.find({ eq, var, 0 }); - derivative != first_order_endo_derivatives.end()) + expr_t lhs = equations[eq]->arg1; + EquationType Equation_Simulation_Type = EquationType::solve; + pair<int, expr_t> res; + if (auto it = first_order_endo_derivatives.find({ eq, var, 0 }); + it != first_order_endo_derivatives.end()) { - set<pair<int, int>> result; - derivative->second->collectEndogenous(result); - auto d_endo_variable = result.find({ var, 0 }); - //Determine whether the equation could be evaluated rather than to be solved - if (lhs->isVariableNodeEqualTo(SymbolType::endogenous, variable_reordered[i], 0) && derivative->second->isNumConstNodeEqualTo(1)) + expr_t derivative = it->second; + // Determine whether the equation can be evaluated rather than solved + if (lhs->isVariableNodeEqualTo(SymbolType::endogenous, variable_reordered[i], 0) + && derivative->isNumConstNodeEqualTo(1)) Equation_Simulation_Type = EquationType::evaluate; else { + set<pair<int, int>> result; + derivative->collectEndogenous(result); + bool variable_not_in_derivative = result.find({ var, 0 }) == result.end(); + vector<tuple<int, expr_t, expr_t>> List_of_Op_RHS; res = equations[eq]->normalizeEquation(var, List_of_Op_RHS); - if (mfs == 2) - { - if (d_endo_variable == result.end() && res.second) - Equation_Simulation_Type = EquationType::evaluate_s; - } - else if (mfs == 3) - { - if (res.second) // The equation could be solved analytically - Equation_Simulation_Type = EquationType::evaluate_s; - } + + if (mfs == 2 && variable_not_in_derivative && res.second) + Equation_Simulation_Type = EquationType::evaluate_s; + else if (mfs == 3 && res.second) // The equation could be solved analytically + Equation_Simulation_Type = EquationType::evaluate_s; } } equation_type_and_normalized_equation[eq] = { Equation_Simulation_Type, dynamic_cast<BinaryOpNode *>(res.second) };