diff --git a/ExprNode.cc b/ExprNode.cc index 9a4ddd7168207eef8d296cd064de553f98740ee8..6176389e8321ea9871abf5677fbdc0bd69fc127b 100644 --- a/ExprNode.cc +++ b/ExprNode.cc @@ -4600,31 +4600,63 @@ BinaryOpNode::setVarExpectationIndex(map<string, pair<SymbolList, int> > &var_mo } void -BinaryOpNode::walkPacParameters(bool &pac_encountered, pair<int, int> &lhs, set<pair<int, pair<int, int> > > ¶ms_and_vals) const -{ - if (op_code == oTimes) - { - set<int> params; - set<pair<int, int> > endogs; - arg1->collectVariables(eParameter, params); - arg2->collectDynamicVariables(eEndogenous, endogs); - if (params.size() == 1 && endogs.size() == 1) - { - params_and_vals.insert(make_pair(*(params.begin()), *(endogs.begin()))); - return; - } - else +BinaryOpNode::walkPacParametersHelper(const expr_t arg1, const expr_t arg2, + pair<int, int> &lhs, + set<pair<int, pair<int, int> > > ¶ms_and_vals) const +{ + set<int> params; + set<pair<int, int> > endogs; + arg1->collectVariables(eParameter, params); + arg2->collectDynamicVariables(eEndogenous, endogs); + if (params.size() == 1) + if (endogs.size() == 1) + params_and_vals.insert(make_pair(*(params.begin()), *(endogs.begin()))); + else + if (endogs.size() == 2) { - params.clear(); - endogs.clear(); - arg1->collectDynamicVariables(eEndogenous, endogs); - arg2->collectVariables(eParameter, params); - if (params.size() == 1 && endogs.size() == 1) + BinaryOpNode *testarg2 = dynamic_cast<BinaryOpNode *>(arg2); + VariableNode *test_arg1 = dynamic_cast<VariableNode *>(testarg2->get_arg1()); + VariableNode *test_arg2 = dynamic_cast<VariableNode *>(testarg2->get_arg2()); + if (testarg2 != NULL && testarg2->get_op_code() == oMinus + && test_arg1 != NULL &&test_arg2 != NULL + && lhs.first != -1) { - params_and_vals.insert(make_pair(*(params.begin()), *(endogs.begin()))); - return; + int find_symb_id = -1; + try + { + // lhs is an aux var (diff) + find_symb_id = datatree.symbol_table.getOrigSymbIdForAuxVar(lhs.first); + } + catch (...) + { + //lhs is not an aux var + find_symb_id = lhs.first; + } + endogs.clear(); + + if (test_arg1->get_symb_id() == find_symb_id) + { + test_arg1->collectDynamicVariables(eEndogenous, endogs); + params_and_vals.insert(make_pair(*(params.begin()), *(endogs.begin()))); + } + else if (test_arg2->get_symb_id() == find_symb_id) + { + test_arg2->collectDynamicVariables(eEndogenous, endogs); + params_and_vals.insert(make_pair(*(params.begin()), *(endogs.begin()))); + } } } +} + +void +BinaryOpNode::walkPacParameters(bool &pac_encountered, pair<int, int> &lhs, set<pair<int, pair<int, int> > > ¶ms_and_vals) const +{ + if (op_code == oTimes) + { + int orig_params_and_vals_size = params_and_vals.size(); + walkPacParametersHelper(arg1, arg2, lhs, params_and_vals); + if (params_and_vals.size() == orig_params_and_vals_size) + walkPacParametersHelper(arg2, arg1, lhs, params_and_vals); } else if (op_code == oEqual) { @@ -7738,7 +7770,7 @@ PacExpectationNode::addParamInfoToPac(pair<int, int> &lhs_arg, set<pair<int, pai exit(EXIT_FAILURE); } - if (params_and_vals_arg.size() != 2) + if (params_and_vals_arg.size() != 3) { cerr << "Pac Expectation: error in obtaining RHS parameters." << endl; exit(EXIT_FAILURE); diff --git a/ExprNode.hh b/ExprNode.hh index f674887ebd4fd9693a198000dbd8e48cfaf2d6bb..cc8409adc536417a1f3947332c5af25dddcb4313 100644 --- a/ExprNode.hh +++ b/ExprNode.hh @@ -834,6 +834,9 @@ public: { return powerDerivOrder; } + void walkPacParametersHelper(const expr_t arg1, const expr_t arg2, + pair<int, int> &lhs, + set<pair<int, pair<int, int> > > ¶ms_and_vals) const; virtual expr_t toStatic(DataTree &static_datatree) const; virtual void computeXrefs(EquationInfo &ei) const; virtual pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<pair<int, pair<expr_t, expr_t> > > &List_of_Op_RHS) const; diff --git a/SymbolTable.cc b/SymbolTable.cc index 993fd7b73de965efa1c81e6f3ab3972c3cceabd9..ad61e3749c2b089a88d80af2b86d7bfd2cd7365d 100644 --- a/SymbolTable.cc +++ b/SymbolTable.cc @@ -809,6 +809,16 @@ SymbolTable::searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const thro throw SearchFailedException(orig_symb_id, orig_lead_lag); } +int +SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const throw (UnknownSymbolIDException) +{ + for (size_t i = 0; i < aux_vars.size(); i++) + if ((aux_vars[i].get_type() == avEndoLag || aux_vars[i].get_type() == avExoLag || aux_vars[i].get_type() == avDiff) + && aux_vars[i].get_symb_id() == aux_var_symb_id) + return aux_vars[i].get_orig_symb_id(); + throw UnknownSymbolIDException(aux_var_symb_id); +} + expr_t SymbolTable::getAuxiliaryVarsExprNode(int symb_id) const throw (SearchFailedException) // throw exception if it is a Lagrange multiplier diff --git a/SymbolTable.hh b/SymbolTable.hh index b754bd96a64454732342608074acb522813fcb3f..62804a59b644dc148c5a31749b44e1a7897f47e4 100644 --- a/SymbolTable.hh +++ b/SymbolTable.hh @@ -281,6 +281,8 @@ public: Throws an exception if match not found. */ int searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const throw (SearchFailedException); + //! Serches aux_vars for the aux var represented by aux_var_symb_id and returns its associated orig_symb_id + int getOrigSymbIdForAuxVar(int aux_var_symb_id) const throw (UnknownSymbolIDException); //! Adds an auxiliary variable when var_model is used with an order that is greater in absolute value //! than the largest lag present in the model. int addVarModelEndoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t expr_arg) throw (AlreadyDeclaredException, FrozenException);