diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 4dfefec017835e9a976b61ff45066dfc55a8c99b..e65a6bb22e37a17c874e381d9e1a0c4a71fc117f 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -3983,7 +3983,7 @@ DynamicModel::computePacModelConsistentExpectationSubstitution(const string &nam { expr_t this_diff_node = AddDiff(create_target_lag(i)); int symb_id = symbol_table.addDiffLeadAuxiliaryVar(this_diff_node->idx, this_diff_node, - last_aux_var->symb_id, last_aux_var->lag); + last_aux_var->symb_id, 1); VariableNode *current_aux_var = AddVariable(symb_id); auto neweq = AddEqual(current_aux_var, AddVariable(last_aux_var->symb_id, 1)); addEquation(neweq, -1); diff --git a/src/ExprNode.cc b/src/ExprNode.cc index c88c36be39b3b16772a7074038c305c43a9aa8f3..579fc5a3a18df1e23aa353a79fc636f4a8e72285 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -312,18 +312,12 @@ ExprNode::fillErrorCorrectionRow(int eqn, continue; } - // Helper function - auto one_step_orig = [this](int symb_id) { - return datatree.symbol_table.isAuxiliaryVariable(symb_id) ? - datatree.symbol_table.getOrigSymbIdForDiffAuxVar(symb_id) : symb_id; - }; - /* Verify that all variables belong to the error-correction term. FIXME: same remark as above about skipping terms. */ bool not_ec = false; for (const auto &t : m.second) { - int vid = one_step_orig(get<0>(t)); + auto [vid, vlag] = datatree.symbol_table.unrollDiffLeadLagChain(get<0>(t), get<1>(t)); not_ec = not_ec || (find(target_lhs.begin(), target_lhs.end(), vid) == target_lhs.end() && find(nontarget_lhs.begin(), nontarget_lhs.end(), vid) == nontarget_lhs.end()); } @@ -333,8 +327,7 @@ ExprNode::fillErrorCorrectionRow(int eqn, // Now fill the matrices for (auto [var_id, lag, param_id, constant] : m.second) { - int orig_vid = one_step_orig(var_id); - int orig_lag = datatree.symbol_table.isAuxiliaryVariable(var_id) ? -datatree.symbol_table.getOrigLeadLagForDiffAuxVar(var_id) : lag; + auto [orig_vid, orig_lag] = datatree.symbol_table.unrollDiffLeadLagChain(var_id, lag); if (find(target_lhs.begin(), target_lhs.end(), orig_vid) == target_lhs.end()) { // This an LHS variable, so fill A0 @@ -3418,10 +3411,10 @@ UnaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t { if (i == last_index) symb_id = datatree.symbol_table.addDiffLagAuxiliaryVar(argsubst->idx, rit->second, - last_aux_var->symb_id, last_aux_var->lag - 1); + last_aux_var->symb_id, -1); else symb_id = datatree.symbol_table.addDiffLagAuxiliaryVar(new_aux_var->idx, rit->second, - last_aux_var->symb_id, last_aux_var->lag - 1); + last_aux_var->symb_id, -1); new_aux_var = datatree.AddVariable(symb_id, 0); neweqs.push_back(datatree.AddEqual(new_aux_var, @@ -5332,8 +5325,8 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id, exit(EXIT_FAILURE); } - if (int vidorig = datatree.symbol_table.getUltimateOrigSymbID(vid); - vidorig == lhs_symb_id || vidorig == lhs_orig_symb_id) + if (auto [vidorig, vlagorig] = datatree.symbol_table.unrollDiffLeadLagChain(vid, vlag); + vidorig == lhs_symb_id) { // This is an autoregressive term if (constant != 1 || pid == -1 || !datatree.symbol_table.isDiffAuxiliaryVariable(vid)) @@ -5341,10 +5334,9 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id, cerr << "BinaryOpNode::getPacAREC: autoregressive terms must be of the form 'parameter*diff_lagged_variable" << endl; exit(EXIT_FAILURE); } - int ar_lag = datatree.symbol_table.getOrigLeadLagForDiffAuxVar(vid); - if (static_cast<int>(ar_params_and_vars.size()) < ar_lag) - ar_params_and_vars.resize(ar_lag, { -1, -1, 0 }); - ar_params_and_vars[ar_lag-1] = { pid, vid, vlag }; + if (static_cast<int>(ar_params_and_vars.size()) < -vlagorig) + ar_params_and_vars.resize(-vlagorig, { -1, -1, 0 }); + ar_params_and_vars[-vlagorig-1] = { pid, vid, vlag }; } else // This is a residual additive term @@ -5533,11 +5525,7 @@ BinaryOpNode::fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<i continue; } - if (datatree.symbol_table.isDiffAuxiliaryVariable(vid)) - { - lag = -datatree.symbol_table.getOrigLeadLagForDiffAuxVar(vid); - vid = datatree.symbol_table.getOrigSymbIdForDiffAuxVar(vid); - } + tie(vid, lag) = datatree.symbol_table.unrollDiffLeadLagChain(vid, lag); if (find(lhs.begin(), lhs.end(), vid) == lhs.end()) continue; diff --git a/src/SymbolTable.cc b/src/SymbolTable.cc index e9899fcaadd3e5567b441525a94937e793559f7c..87b7daafe3a5dda4a9b80a51cd5f7902ed7201bd 100644 --- a/src/SymbolTable.cc +++ b/src/SymbolTable.cc @@ -733,27 +733,17 @@ SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false) throw UnknownSymbolIDException(aux_var_symb_id); } -int -SymbolTable::getOrigLeadLagForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false) +pair<int, int> +SymbolTable::unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false) { for (const auto &aux_var : aux_vars) - if ((aux_var.get_type() == AuxVarType::diffLag || aux_var.get_type() == AuxVarType::diffLead) - && aux_var.get_symb_id() == diff_aux_var_symb_id) - return (aux_var.get_type() == AuxVarType::diffLag ? 1 : -1) + getOrigLeadLagForDiffAuxVar(aux_var.get_orig_symb_id()); - return 0; -} - -int -SymbolTable::getOrigSymbIdForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false) -{ - int orig_symb_id = -1; - for (const auto &aux_var : aux_vars) - if (aux_var.get_symb_id() == diff_aux_var_symb_id) - if (aux_var.get_type() == AuxVarType::diff) - orig_symb_id = diff_aux_var_symb_id; - else if (aux_var.get_type() == AuxVarType::diffLag || aux_var.get_type() == AuxVarType::diffLead) - orig_symb_id = getOrigSymbIdForDiffAuxVar(aux_var.get_orig_symb_id()); - return orig_symb_id; + if (aux_var.get_symb_id() == symb_id) + if (aux_var.get_type() == AuxVarType::diffLag || aux_var.get_type() == AuxVarType::diffLead) + { + auto [orig_symb_id, orig_lag] = unrollDiffLeadLagChain(aux_var.get_orig_symb_id(), lag); + return { orig_symb_id, orig_lag + aux_var.get_orig_lead_lag() }; + } + return { symb_id, lag }; } expr_t diff --git a/src/SymbolTable.hh b/src/SymbolTable.hh index adab581a4ef2d76516e4745683fee35093e93667..a0f4f3252f3cd5112e4b6dc8de6513e7abba7943 100644 --- a/src/SymbolTable.hh +++ b/src/SymbolTable.hh @@ -325,10 +325,18 @@ public: N.B.: some code might rely on the fact that, in particular, it does not work on unaryOp type (to be verified) */ int getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false); - //! Searches for diff aux var and finds the original lag associated with this variable - int getOrigLeadLagForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false); - //! Searches for diff aux var and finds the symb id associated with this variable - int getOrigSymbIdForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false); + /* Unrolls a chain of diffLag or diffLead aux vars until it founds a (regular) diff aux + var. In other words: + - if the arg is a (regu) diff aux var, returns the arg + - if the arg is a diffLag/diffLead, get its orig symb ID, and call the + method recursively + - if the arg is something else, throw an UnknownSymbolIDException + exception + The 2nd input/output arguments are used to track leads/lags. The 2nd + output argument is equal to the 2nd input argument, shifted by as many + lead/lags were encountered in the chain (a diffLag decreases it, a + diffLead increases it). */ + pair<int, int> unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false); //! Adds an auxiliary variable when the diff operator is encountered int addDiffAuxiliaryVar(int index, expr_t expr_arg) noexcept(false); int addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false);