From 9b98da424a6d6e35b667f3aa99637b6ae232eee4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org> Date: Tue, 19 Mar 2019 18:31:27 +0100 Subject: [PATCH] Simplification in BinaryOpNode::getPacAREC() --- src/ExprNode.cc | 84 +++++++++++++++++++------------------------------ 1 file changed, 33 insertions(+), 51 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index f7e146ca..c4ff0ce8 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -5617,63 +5617,45 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id, for (const auto & it : terms) { - auto bopn = dynamic_cast<BinaryOpNode *>(it.first); - auto pen = dynamic_cast<PacExpectationNode *>(it.first); - if (pen) + if (dynamic_cast<PacExpectationNode *>(it.first)) continue; - if (bopn != nullptr) + + int pid, vid, lag; + double constant; + try + { + tie(vid, lag, pid, constant) = it.first->matchVariableTimesConstantTimesParam(); + constant *= it.second; + } + catch (MatchFailureException &e) { - auto vn1 = dynamic_cast<VariableNode *>(bopn->arg1); - auto vn2 = dynamic_cast<VariableNode *>(bopn->arg2); - if (vn1 && vn2) + cerr << "Unsupported expression in PAC equation" << endl; + exit(EXIT_FAILURE); + } + + int vidorig = vid; + while (datatree.symbol_table.isAuxiliaryVariable(vid)) + try + { + vid = datatree.symbol_table.getOrigSymbIdForAuxVar(vid); + } + catch (...) + { + break; + } + if (vid == lhs_symb_id || vid == lhs_orig_symb_id) + { + // This is an autoregressive term + if (constant != 1 || pid == -1) { - int pid, vid, lag; - pid = vid = lag = -1; - if (datatree.symbol_table.getType(vn1->symb_id) == SymbolType::parameter - && (datatree.symbol_table.getType(vn2->symb_id) == SymbolType::endogenous - || datatree.symbol_table.getType(vn2->symb_id) == SymbolType::exogenous)) - { - pid = vn1->symb_id; - vid = vn2->symb_id; - lag = vn2->lag; - } - else if (datatree.symbol_table.getType(vn2->symb_id) == SymbolType::parameter - && (datatree.symbol_table.getType(vn1->symb_id) == SymbolType::endogenous - || datatree.symbol_table.getType(vn1->symb_id) == SymbolType::exogenous)) - { - pid = vn2->symb_id; - vid = vn1->symb_id; - lag = vn1->lag; - } - if (pid >= 0 && vid >= 0) - { - int vidorig = vid; - while (datatree.symbol_table.isAuxiliaryVariable(vid)) - try - { - vid = datatree.symbol_table.getOrigSymbIdForAuxVar(vid); - } - catch (...) - { - break; - } - if (vid == lhs_symb_id || vid == lhs_orig_symb_id) - ar_params_and_vars.insert({pid, {vidorig, lag}}); - else - { - auto m = it.first->matchVariableTimesConstantTimesParam(); - get<3>(m) *= it.second; - additive_vars_params_and_constants.push_back(m); - } - } + cerr << "BinaryOpNode::getPacAREC: autoregressive terms must be of the form 'parameter*lagged_variable" << endl; + exit(EXIT_FAILURE); } + ar_params_and_vars.insert({pid, { vidorig, lag }}); } else - { - auto m = it.first->matchVariableTimesConstantTimesParam(); - get<3>(m) *= it.second; - additive_vars_params_and_constants.push_back(m); - } + // This is a residual additive term + additive_vars_params_and_constants.push_back({ vidorig, lag, pid, constant}); } } -- GitLab