From a31ef6069cb1aeab408f1080fec3ad3eaaa8088f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Tue, 18 Oct 2022 17:24:51 +0200
Subject: [PATCH] Correctly handle diff operator in equation renormalization
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Should have no impact though, since diff nodes are already substituted out at
that point. But it’s better to implement it properly, in case we change the
substitution rules later.

By the way, make the computeSubExprContainingVariable method protected.
---
 src/ExprNode.cc | 24 +++++++++++++++++++++---
 src/ExprNode.hh | 27 +++++++++++++++++----------
 2 files changed, 38 insertions(+), 13 deletions(-)

diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index f2d585b8..efc38c0d 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -3107,9 +3107,20 @@ UnaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &r
 void
 UnaryOpNode::computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const
 {
-  arg->computeSubExprContainingVariable(symb_id, lag, contain_var);
-  if (contain_var.contains(arg))
-    contain_var.insert(const_cast<UnaryOpNode *>(this));
+  if (op_code == UnaryOpcode::diff)
+    {
+      expr_t lagged_arg {arg->decreaseLeadsLags(1)};
+      expr_t substitute {datatree.AddMinus(arg, lagged_arg)};
+      substitute->computeSubExprContainingVariable(symb_id, lag, contain_var);
+      if (contain_var.contains(arg) || contain_var.contains(lagged_arg))
+        contain_var.insert(const_cast<UnaryOpNode *>(this));
+    }
+  else
+    {
+      arg->computeSubExprContainingVariable(symb_id, lag, contain_var);
+      if (contain_var.contains(arg))
+        contain_var.insert(const_cast<UnaryOpNode *>(this));
+    }
 }
 
 BinaryOpNode *
@@ -3173,6 +3184,13 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
     case UnaryOpcode::cbrt:
       rhs = datatree.AddPower(rhs, datatree.Three);
       break;
+    case UnaryOpcode::diff:
+      /* Recursively call the function on arg-arg(-1).
+         This is necessary to deal with the 3 different possible cases:
+         β€” var in arg but not arg(-1);
+         β€” var in arg(-1) but not arg;
+         β€” var in both arg and arg(-1). */
+      return datatree.AddMinus(arg, arg->decreaseLeadsLags(1))->normalizeEquationHelper(contain_var, rhs);
     default:
       throw NormalizationFailed();
     }
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 1fdee2c6..58b1024e 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -281,6 +281,14 @@ protected:
      lag_equivalence_table_t for an explanation of these concepts. */
   pair<expr_t, int> getLagEquivalenceClass() const;
 
+  /* Computes the set of all sub-expressions that contain the variable
+     (symb_id, lag).
+     Note that if a diff operator is encountered:
+     - diff(expr) will be added to the output set if either expr or expr(-1)
+       contains the variable;
+     - the method will be called recursively on expr-expr(-1)  */
+  virtual void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const = 0;
+
 public:
   ExprNode(DataTree &datatree_arg, int idx_arg);
   virtual ~ExprNode() = default;
@@ -457,9 +465,6 @@ public:
   //  virtual void computeXrefs(set<int> &param, set<int> &endo, set<int> &exo, set<int> &exo_det) const = 0;
   virtual void computeXrefs(EquationInfo &ei) const = 0;
 
-  // Computes the set of all sub-expressions that contain the variable (symb_id, lag)
-  virtual void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const = 0;
-
   //! Helper for normalization of equations
   /*! Normalize the equation this = rhs.
       Must be called on a node containing the desired LHS variable.
@@ -810,6 +815,7 @@ private:
   expr_t computeDerivative(int deriv_id) override;
 protected:
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
   NumConstNode(DataTree &datatree_arg, int idx_arg, int id_arg);
   void prepareForDerivation() override;
@@ -823,7 +829,6 @@ public:
   void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
@@ -882,6 +887,7 @@ private:
   expr_t computeDerivative(int deriv_id) override;
 protected:
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
   VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg);
   void prepareForDerivation() override;
@@ -896,7 +902,6 @@ public:
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
   SymbolType get_type() const;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
@@ -950,6 +955,7 @@ class UnaryOpNode : public ExprNode
 {
 protected:
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
   const expr_t arg;
   //! Stores the information set. Only used for expectation operator
@@ -999,7 +1005,6 @@ public:
   void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
@@ -1055,6 +1060,7 @@ class BinaryOpNode : public ExprNode
 {
 protected:
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
   const expr_t arg1, arg2;
   const BinaryOpcode op_code;
@@ -1104,7 +1110,6 @@ public:
   expr_t Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const;
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   //! Try to normalize an equation with respect to a given dynamic variable.
   /*! Should only be called on Equal nodes. The variable must appear in the equation. */
@@ -1203,6 +1208,8 @@ class TrinaryOpNode : public ExprNode
 public:
   const expr_t arg1, arg2, arg3;
   const TrinaryOpcode op_code;
+protected:
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 private:
   expr_t computeDerivative(int deriv_id) override;
   int cost(int cost, bool is_matlab) const override;
@@ -1245,7 +1252,6 @@ public:
   void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
@@ -1325,6 +1331,7 @@ protected:
     function which is computed by the same external function call (i.e. it has
     the same so-called "Tef" index) */
   virtual function<bool (expr_t)> sameTefTermPredicate() const = 0;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
   AbstractExternalFunctionNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg,
                                vector<expr_t> arguments_arg);
@@ -1358,7 +1365,6 @@ public:
   void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override = 0;
   expr_t toStatic(DataTree &static_datatree) const override = 0;
   void computeXrefs(EquationInfo &ei) const override = 0;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
@@ -1565,7 +1571,6 @@ public:
   optional<int> findTargetVariable(int lhs_symb_id) const override;
   expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type,
                            const temporary_terms_t &temporary_terms,
@@ -1585,6 +1590,8 @@ public:
   expr_t detrend(int symb_id, bool log_trend, expr_t trend) const override;
   expr_t removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const override;
   expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
+protected:
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 };
 
 class VarExpectationNode : public SubModelNode
-- 
GitLab