From 56581b1dd46315cefb62c906165942883f1e1570 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Mon, 19 Oct 2020 18:26:36 +0200
Subject: [PATCH] PAC model: make detection of non-optimizing part more robust

Introduce a new method for decomposing a product of factors, so that we can
identify expressions of the form (1-optim_share)*A*B.

Also enforce that the optim_share parameter be in a factor of the form
1-optim_share (previously it would accept any expression containing the
parameter).

Note that this fix does not yet allow to actually write non-optimizing parts of
the form (1-optim_share)*A*B, since at a later point the preprocessor imposes
that this part be a linear combination of variables (but in the future we could
think of expanding the A*B product into a linear combination if, for example, A
is a paramater or a constant and B is a linear combination).

Closes: #50
---
 src/ExprNode.cc | 73 +++++++++++++++++++++++++++++++++++--------------
 src/ExprNode.hh | 11 ++++++--
 2 files changed, 61 insertions(+), 23 deletions(-)

diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 4bb67ba7..a9df2293 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -5346,31 +5346,41 @@ BinaryOpNode::isParamTimesEndogExpr() const
   return false;
 }
 
-bool
-BinaryOpNode::getPacNonOptimizingPartHelper(BinaryOpNode *bopn, int optim_share) const
-{
-  set<int> params;
-  bopn->collectVariables(SymbolType::parameter, params);
-  if (params.size() == 1 && *params.begin() == optim_share)
-    return true;
-  return false;
-}
-
 expr_t
-BinaryOpNode::getPacNonOptimizingPart(BinaryOpNode *bopn, int optim_share) const
+BinaryOpNode::getPacNonOptimizingPart(int optim_share_symb_id) const
 {
-  auto a1 = dynamic_cast<BinaryOpNode *>(bopn->arg1);
-  auto a2 = dynamic_cast<BinaryOpNode *>(bopn->arg2);
-  if (!a1 && !a2)
-    return nullptr;
+  vector<pair<expr_t, int>> factors;
+  decomposeMultiplicativeFactors(factors);
 
-  if (a1 && getPacNonOptimizingPartHelper(a1, optim_share))
-    return bopn->arg2;
+  // Search for a factor of the form 1-optim_share
+  expr_t one_minus_optim_share = nullptr;
+  for (auto [factor, exponent] : factors)
+    {
+      auto bopn = dynamic_cast<BinaryOpNode *>(factor);
+      if (exponent != 1 || !bopn || bopn->op_code != BinaryOpcode::minus)
+        continue;
+      auto arg1 = dynamic_cast<NumConstNode *>(bopn->arg1);
+      auto arg2 = dynamic_cast<VariableNode *>(bopn->arg2);
+      if (arg1 && arg2 && arg1->eval({}) == 1 && arg2->symb_id == optim_share_symb_id)
+        {
+          one_minus_optim_share = factor;
+          break;
+        }
+    }
 
-  if (a2 && getPacNonOptimizingPartHelper(a2, optim_share))
-    return bopn->arg1;
+  if (!one_minus_optim_share)
+    return nullptr;
 
-  return nullptr;
+  // Construct the product formed by the other factors and return it
+  expr_t non_optim_part = datatree.One;
+  for (auto [factor, exponent] : factors)
+    if (factor != one_minus_optim_share)
+      if (exponent == 1)
+        non_optim_part = datatree.AddTimes(non_optim_part, factor);
+      else
+        non_optim_part = datatree.AddDivide(non_optim_part, factor);
+
+  return non_optim_part;
 }
 
 pair<int, expr_t>
@@ -5433,7 +5443,7 @@ BinaryOpNode::getPacOptimizingShareAndExprNodes(int lhs_symb_id, int lhs_orig_sy
   for (auto it = terms.begin(); it != terms.end(); ++it)
     if (auto bopn = dynamic_cast<BinaryOpNode *>(it->first); bopn)
       {
-        non_optim_part = getPacNonOptimizingPart(bopn, optim_share);
+        non_optim_part = bopn->getPacNonOptimizingPart(optim_share);
         if (non_optim_part)
           {
             terms.erase(it);
@@ -8797,6 +8807,27 @@ BinaryOpNode::decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int curre
     ExprNode::decomposeAdditiveTerms(terms, current_sign);
 }
 
+void
+ExprNode::decomposeMultiplicativeFactors(vector<pair<expr_t, int>> &factors, int current_exponent) const
+{
+  factors.emplace_back(const_cast<ExprNode *>(this), current_exponent);
+}
+
+void
+BinaryOpNode::decomposeMultiplicativeFactors(vector<pair<expr_t, int>> &factors, int current_exponent) const
+{
+  if (op_code == BinaryOpcode::times || op_code == BinaryOpcode::divide)
+    {
+      arg1->decomposeMultiplicativeFactors(factors, current_exponent);
+      if (op_code == BinaryOpcode::times)
+        arg2->decomposeMultiplicativeFactors(factors, current_exponent);
+      else
+        arg2->decomposeMultiplicativeFactors(factors, -current_exponent);
+    }
+  else
+    ExprNode::decomposeMultiplicativeFactors(factors, current_exponent);
+}
+
 tuple<int, int, int, double>
 ExprNode::matchVariableTimesConstantTimesParam(bool variable_obligatory) const
 {
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 32d7d510..9f8bdd2f 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -672,6 +672,13 @@ public:
     If current_sign == -1, then all signs are inverted */
   virtual void decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int current_sign = 1) const;
 
+  //! Decompose an expression into its multiplicative factors
+  /*! Returns a list of factors, with their exponents (either 1 or -1, depending
+    on whether the factors appear at the numerator or the denominator).
+    The current_exponent argument should normally be left to 1.
+    If current_exponent == -1, then all exponents are inverted */
+  virtual void decomposeMultiplicativeFactors(vector<pair<expr_t, int>> &factors, int current_exponent = 1) const;
+
   // Matches an expression of the form variable*constant*parameter
   /* Returns a tuple (variable_id, lag, param_id, constant).
      The variable must be an exogenous or an endogenous.
@@ -1093,8 +1100,7 @@ public:
   //! and the expr node associated with the non-optimizing part
   tuple<int, expr_t, expr_t, expr_t> getPacOptimizingShareAndExprNodes(int lhs_symb_id, int lhs_orig_symb_id) const;
   pair<int, expr_t> getPacOptimizingShareAndExprNodesHelper(int lhs_symb_id, int lhs_orig_symb_id) const;
-  expr_t getPacNonOptimizingPart(BinaryOpNode *bopn, int optim_share) const;
-  bool getPacNonOptimizingPartHelper(BinaryOpNode *bopn, int optim_share) const;
+  expr_t getPacNonOptimizingPart(int optim_share_symb_id) const;
   bool isParamTimesEndogExpr() const override;
   bool isVarModelReferenced(const string &model_info_name) const override;
   void getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const override;
@@ -1103,6 +1109,7 @@ public:
   //! Substitute auxiliary variables by their expression in static model auxiliary variable definition
   expr_t substituteStaticAuxiliaryDefinition() const;
   void decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int current_sign) const override;
+  void decomposeMultiplicativeFactors(vector<pair<expr_t, int>> &factors, int current_exponent = 1) const override;
   void matchMatchedMoment(vector<int> &symb_ids, vector<int> &lags, vector<int> &powers) const override;
 };
 
-- 
GitLab