diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index f2d585b84110eac05609fce902889a3b60873273..efc38c0de8b3f8145e55dbb7caab816129aedb36 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -3107,9 +3107,20 @@ UnaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &r
 void
 UnaryOpNode::computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const
 {
-  arg->computeSubExprContainingVariable(symb_id, lag, contain_var);
-  if (contain_var.contains(arg))
-    contain_var.insert(const_cast<UnaryOpNode *>(this));
+  if (op_code == UnaryOpcode::diff)
+    {
+      expr_t lagged_arg {arg->decreaseLeadsLags(1)};
+      expr_t substitute {datatree.AddMinus(arg, lagged_arg)};
+      substitute->computeSubExprContainingVariable(symb_id, lag, contain_var);
+      if (contain_var.contains(arg) || contain_var.contains(lagged_arg))
+        contain_var.insert(const_cast<UnaryOpNode *>(this));
+    }
+  else
+    {
+      arg->computeSubExprContainingVariable(symb_id, lag, contain_var);
+      if (contain_var.contains(arg))
+        contain_var.insert(const_cast<UnaryOpNode *>(this));
+    }
 }
 
 BinaryOpNode *
@@ -3173,6 +3184,13 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
     case UnaryOpcode::cbrt:
       rhs = datatree.AddPower(rhs, datatree.Three);
       break;
+    case UnaryOpcode::diff:
+      /* Recursively call the function on arg-arg(-1).
+         This is necessary to deal with the 3 different possible cases:
+         — var in arg but not arg(-1);
+         — var in arg(-1) but not arg;
+         — var in both arg and arg(-1). */
+      return datatree.AddMinus(arg, arg->decreaseLeadsLags(1))->normalizeEquationHelper(contain_var, rhs);
     default:
       throw NormalizationFailed();
     }
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 1fdee2c66e56a0efce11ad3e2f6a9a74d3ce2e43..58b1024ed9d011c659bbfd090a493e00aa4a85af 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -281,6 +281,14 @@ protected:
      lag_equivalence_table_t for an explanation of these concepts. */
   pair<expr_t, int> getLagEquivalenceClass() const;
 
+  /* Computes the set of all sub-expressions that contain the variable
+     (symb_id, lag).
+     Note that if a diff operator is encountered:
+     - diff(expr) will be added to the output set if either expr or expr(-1)
+       contains the variable;
+     - the method will be called recursively on expr-expr(-1)  */
+  virtual void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const = 0;
+
 public:
   ExprNode(DataTree &datatree_arg, int idx_arg);
   virtual ~ExprNode() = default;
@@ -457,9 +465,6 @@ public:
   //  virtual void computeXrefs(set<int> &param, set<int> &endo, set<int> &exo, set<int> &exo_det) const = 0;
   virtual void computeXrefs(EquationInfo &ei) const = 0;
 
-  // Computes the set of all sub-expressions that contain the variable (symb_id, lag)
-  virtual void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const = 0;
-
   //! Helper for normalization of equations
   /*! Normalize the equation this = rhs.
       Must be called on a node containing the desired LHS variable.
@@ -810,6 +815,7 @@ private:
   expr_t computeDerivative(int deriv_id) override;
 protected:
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
   NumConstNode(DataTree &datatree_arg, int idx_arg, int id_arg);
   void prepareForDerivation() override;
@@ -823,7 +829,6 @@ public:
   void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
@@ -882,6 +887,7 @@ private:
   expr_t computeDerivative(int deriv_id) override;
 protected:
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
   VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg);
   void prepareForDerivation() override;
@@ -896,7 +902,6 @@ public:
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
   SymbolType get_type() const;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
@@ -950,6 +955,7 @@ class UnaryOpNode : public ExprNode
 {
 protected:
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
   const expr_t arg;
   //! Stores the information set. Only used for expectation operator
@@ -999,7 +1005,6 @@ public:
   void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
@@ -1055,6 +1060,7 @@ class BinaryOpNode : public ExprNode
 {
 protected:
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
   const expr_t arg1, arg2;
   const BinaryOpcode op_code;
@@ -1104,7 +1110,6 @@ public:
   expr_t Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const;
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   //! Try to normalize an equation with respect to a given dynamic variable.
   /*! Should only be called on Equal nodes. The variable must appear in the equation. */
@@ -1203,6 +1208,8 @@ class TrinaryOpNode : public ExprNode
 public:
   const expr_t arg1, arg2, arg3;
   const TrinaryOpcode op_code;
+protected:
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 private:
   expr_t computeDerivative(int deriv_id) override;
   int cost(int cost, bool is_matlab) const override;
@@ -1245,7 +1252,6 @@ public:
   void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
@@ -1325,6 +1331,7 @@ protected:
     function which is computed by the same external function call (i.e. it has
     the same so-called "Tef" index) */
   virtual function<bool (expr_t)> sameTefTermPredicate() const = 0;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
   AbstractExternalFunctionNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg,
                                vector<expr_t> arguments_arg);
@@ -1358,7 +1365,6 @@ public:
   void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override = 0;
   expr_t toStatic(DataTree &static_datatree) const override = 0;
   void computeXrefs(EquationInfo &ei) const override = 0;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
@@ -1565,7 +1571,6 @@ public:
   optional<int> findTargetVariable(int lhs_symb_id) const override;
   expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
-  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   void writeBytecodeOutput(BytecodeWriter &code_file, ExprNodeBytecodeOutputType output_type,
                            const temporary_terms_t &temporary_terms,
@@ -1585,6 +1590,8 @@ public:
   expr_t detrend(int symb_id, bool log_trend, expr_t trend) const override;
   expr_t removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const override;
   expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
+protected:
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 };
 
 class VarExpectationNode : public SubModelNode