From 56581b1dd46315cefb62c906165942883f1e1570 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org> Date: Mon, 19 Oct 2020 18:26:36 +0200 Subject: [PATCH] PAC model: make detection of non-optimizing part more robust Introduce a new method for decomposing a product of factors, so that we can identify expressions of the form (1-optim_share)*A*B. Also enforce that the optim_share parameter be in a factor of the form 1-optim_share (previously it would accept any expression containing the parameter). Note that this fix does not yet allow to actually write non-optimizing parts of the form (1-optim_share)*A*B, since at a later point the preprocessor imposes that this part be a linear combination of variables (but in the future we could think of expanding the A*B product into a linear combination if, for example, A is a paramater or a constant and B is a linear combination). Closes: #50 --- src/ExprNode.cc | 73 +++++++++++++++++++++++++++++++++++-------------- src/ExprNode.hh | 11 ++++++-- 2 files changed, 61 insertions(+), 23 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 4bb67ba7..a9df2293 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -5346,31 +5346,41 @@ BinaryOpNode::isParamTimesEndogExpr() const return false; } -bool -BinaryOpNode::getPacNonOptimizingPartHelper(BinaryOpNode *bopn, int optim_share) const -{ - set<int> params; - bopn->collectVariables(SymbolType::parameter, params); - if (params.size() == 1 && *params.begin() == optim_share) - return true; - return false; -} - expr_t -BinaryOpNode::getPacNonOptimizingPart(BinaryOpNode *bopn, int optim_share) const +BinaryOpNode::getPacNonOptimizingPart(int optim_share_symb_id) const { - auto a1 = dynamic_cast<BinaryOpNode *>(bopn->arg1); - auto a2 = dynamic_cast<BinaryOpNode *>(bopn->arg2); - if (!a1 && !a2) - return nullptr; + vector<pair<expr_t, int>> factors; + decomposeMultiplicativeFactors(factors); - if (a1 && getPacNonOptimizingPartHelper(a1, optim_share)) - return bopn->arg2; + // Search for a factor of the form 1-optim_share + expr_t one_minus_optim_share = nullptr; + for (auto [factor, exponent] : factors) + { + auto bopn = dynamic_cast<BinaryOpNode *>(factor); + if (exponent != 1 || !bopn || bopn->op_code != BinaryOpcode::minus) + continue; + auto arg1 = dynamic_cast<NumConstNode *>(bopn->arg1); + auto arg2 = dynamic_cast<VariableNode *>(bopn->arg2); + if (arg1 && arg2 && arg1->eval({}) == 1 && arg2->symb_id == optim_share_symb_id) + { + one_minus_optim_share = factor; + break; + } + } - if (a2 && getPacNonOptimizingPartHelper(a2, optim_share)) - return bopn->arg1; + if (!one_minus_optim_share) + return nullptr; - return nullptr; + // Construct the product formed by the other factors and return it + expr_t non_optim_part = datatree.One; + for (auto [factor, exponent] : factors) + if (factor != one_minus_optim_share) + if (exponent == 1) + non_optim_part = datatree.AddTimes(non_optim_part, factor); + else + non_optim_part = datatree.AddDivide(non_optim_part, factor); + + return non_optim_part; } pair<int, expr_t> @@ -5433,7 +5443,7 @@ BinaryOpNode::getPacOptimizingShareAndExprNodes(int lhs_symb_id, int lhs_orig_sy for (auto it = terms.begin(); it != terms.end(); ++it) if (auto bopn = dynamic_cast<BinaryOpNode *>(it->first); bopn) { - non_optim_part = getPacNonOptimizingPart(bopn, optim_share); + non_optim_part = bopn->getPacNonOptimizingPart(optim_share); if (non_optim_part) { terms.erase(it); @@ -8797,6 +8807,27 @@ BinaryOpNode::decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int curre ExprNode::decomposeAdditiveTerms(terms, current_sign); } +void +ExprNode::decomposeMultiplicativeFactors(vector<pair<expr_t, int>> &factors, int current_exponent) const +{ + factors.emplace_back(const_cast<ExprNode *>(this), current_exponent); +} + +void +BinaryOpNode::decomposeMultiplicativeFactors(vector<pair<expr_t, int>> &factors, int current_exponent) const +{ + if (op_code == BinaryOpcode::times || op_code == BinaryOpcode::divide) + { + arg1->decomposeMultiplicativeFactors(factors, current_exponent); + if (op_code == BinaryOpcode::times) + arg2->decomposeMultiplicativeFactors(factors, current_exponent); + else + arg2->decomposeMultiplicativeFactors(factors, -current_exponent); + } + else + ExprNode::decomposeMultiplicativeFactors(factors, current_exponent); +} + tuple<int, int, int, double> ExprNode::matchVariableTimesConstantTimesParam(bool variable_obligatory) const { diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 32d7d510..9f8bdd2f 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -672,6 +672,13 @@ public: If current_sign == -1, then all signs are inverted */ virtual void decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int current_sign = 1) const; + //! Decompose an expression into its multiplicative factors + /*! Returns a list of factors, with their exponents (either 1 or -1, depending + on whether the factors appear at the numerator or the denominator). + The current_exponent argument should normally be left to 1. + If current_exponent == -1, then all exponents are inverted */ + virtual void decomposeMultiplicativeFactors(vector<pair<expr_t, int>> &factors, int current_exponent = 1) const; + // Matches an expression of the form variable*constant*parameter /* Returns a tuple (variable_id, lag, param_id, constant). The variable must be an exogenous or an endogenous. @@ -1093,8 +1100,7 @@ public: //! and the expr node associated with the non-optimizing part tuple<int, expr_t, expr_t, expr_t> getPacOptimizingShareAndExprNodes(int lhs_symb_id, int lhs_orig_symb_id) const; pair<int, expr_t> getPacOptimizingShareAndExprNodesHelper(int lhs_symb_id, int lhs_orig_symb_id) const; - expr_t getPacNonOptimizingPart(BinaryOpNode *bopn, int optim_share) const; - bool getPacNonOptimizingPartHelper(BinaryOpNode *bopn, int optim_share) const; + expr_t getPacNonOptimizingPart(int optim_share_symb_id) const; bool isParamTimesEndogExpr() const override; bool isVarModelReferenced(const string &model_info_name) const override; void getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const override; @@ -1103,6 +1109,7 @@ public: //! Substitute auxiliary variables by their expression in static model auxiliary variable definition expr_t substituteStaticAuxiliaryDefinition() const; void decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int current_sign) const override; + void decomposeMultiplicativeFactors(vector<pair<expr_t, int>> &factors, int current_exponent = 1) const override; void matchMatchedMoment(vector<int> &symb_ids, vector<int> &lags, vector<int> &powers) const override; }; -- GitLab