diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 31c3bbefeb6715f13e9a72dd4d6c3cf668b2421f..73ef9978eccdceeefe468aa7be9bdf68c277aada 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -476,6 +476,12 @@ NumConstNode::maxLag() const
   return 0;
 }
 
+int
+NumConstNode::maxLagWithDiffsExpanded() const
+{
+  return 0;
+}
+
 expr_t
 NumConstNode::undiff() const
 {
@@ -1480,8 +1486,8 @@ VariableNode::maxLead() const
   switch (get_type())
     {
     case SymbolType::endogenous:
-      return lag;
     case SymbolType::exogenous:
+    case SymbolType::exogenousDet:
       return lag;
     case SymbolType::modelLocalVariable:
       return datatree.getLocalVariable(symb_id)->maxLead();
@@ -1515,8 +1521,8 @@ VariableNode::maxLag() const
   switch (get_type())
     {
     case SymbolType::endogenous:
-      return -lag;
     case SymbolType::exogenous:
+    case SymbolType::exogenousDet:
       return -lag;
     case SymbolType::modelLocalVariable:
       return datatree.getLocalVariable(symb_id)->maxLag();
@@ -1525,6 +1531,22 @@ VariableNode::maxLag() const
     }
 }
 
+int
+VariableNode::maxLagWithDiffsExpanded() const
+{
+  switch (get_type())
+    {
+    case SymbolType::endogenous:
+    case SymbolType::exogenous:
+    case SymbolType::exogenousDet:
+      return -lag;
+    case SymbolType::modelLocalVariable:
+      return datatree.getLocalVariable(symb_id)->maxLagWithDiffsExpanded();
+    default:
+      return 0;
+    }
+}
+
 expr_t
 VariableNode::undiff() const
 {
@@ -3211,11 +3233,17 @@ UnaryOpNode::maxLead() const
 int
 UnaryOpNode::maxLag() const
 {
-  if (op_code == UnaryOpcode::diff)
-    return arg->maxLag() + 1;
   return arg->maxLag();
 }
 
+int
+UnaryOpNode::maxLagWithDiffsExpanded() const
+{
+  if (op_code == UnaryOpcode::diff)
+    return arg->maxLagWithDiffsExpanded() + 1;
+  return arg->maxLagWithDiffsExpanded();
+}
+
 expr_t
 UnaryOpNode::undiff() const
 {
@@ -3230,7 +3258,7 @@ UnaryOpNode::VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs) const
   auto it = static_lhs.find(this->toStatic(static_datatree));
   if (it == static_lhs.end())
     return 0;
-  return arg->maxLag() - arg->countDiffs();
+  return arg->maxLag();
 }
 
 int
@@ -3333,7 +3361,7 @@ UnaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_t
     return;
 
   expr_t sthis = this->toStatic(static_datatree);
-  int arg_max_lag = -arg->maxLag();
+  int arg_max_lag = -arg->maxLagWithDiffsExpanded();
   // TODO: implement recursive expression comparison, ensuring that the difference in the lags is constant across nodes
   auto it = nodes.find(sthis);
   if (it != nodes.end())
@@ -3357,7 +3385,7 @@ UnaryOpNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table)
     return;
 
   expr_t sthis = this->toStatic(static_datatree);
-  int arg_max_lag = -arg->maxLag();
+  int arg_max_lag = -arg->maxLagWithDiffsExpanded();
   // TODO: implement recursive expression comparison, ensuring that the difference in the lags is constant across nodes
   auto it = diff_table.find(sthis);
   if (it != diff_table.end())
@@ -3393,7 +3421,7 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
   expr_t sthis = dynamic_cast<UnaryOpNode *>(this->toStatic(static_datatree));
   auto it = diff_table.find(sthis);
   int symb_id;
-  if (it == diff_table.end() || it->second[-arg->maxLag()] != this)
+  if (it == diff_table.end() || it->second[-arg->maxLagWithDiffsExpanded()] != this)
     {
       // diff does not appear in VAR equations
       // so simply create aux var and return
@@ -5130,6 +5158,12 @@ BinaryOpNode::maxLag() const
   return max(arg1->maxLag(), arg2->maxLag());
 }
 
+int
+BinaryOpNode::maxLagWithDiffsExpanded() const
+{
+  return max(arg1->maxLagWithDiffsExpanded(), arg2->maxLagWithDiffsExpanded());
+}
+
 expr_t
 BinaryOpNode::undiff() const
 {
@@ -6411,6 +6445,13 @@ TrinaryOpNode::maxLag() const
   return max(arg1->maxLag(), max(arg2->maxLag(), arg3->maxLag()));
 }
 
+int
+TrinaryOpNode::maxLagWithDiffsExpanded() const
+{
+  return max(arg1->maxLagWithDiffsExpanded(),
+             max(arg2->maxLagWithDiffsExpanded(), arg3->maxLagWithDiffsExpanded()));
+}
+
 expr_t
 TrinaryOpNode::undiff() const
 {
@@ -6902,6 +6943,15 @@ AbstractExternalFunctionNode::maxLag() const
   return val;
 }
 
+int
+AbstractExternalFunctionNode::maxLagWithDiffsExpanded() const
+{
+  int val = 0;
+  for (auto argument : arguments)
+    val = max(val, argument->maxLagWithDiffsExpanded());
+  return val;
+}
+
 expr_t
 AbstractExternalFunctionNode::undiff() const
 {
@@ -8547,6 +8597,13 @@ VarExpectationNode::maxLag() const
   exit(EXIT_FAILURE);
 }
 
+int
+VarExpectationNode::maxLagWithDiffsExpanded() const
+{
+  cerr << "VarExpectationNode::maxLagWithDiffsExpanded not implemented." << endl;
+  exit(EXIT_FAILURE);
+}
+
 expr_t
 VarExpectationNode::undiff() const
 {
@@ -9106,6 +9163,12 @@ PacExpectationNode::maxLag() const
   return 0;
 }
 
+int
+PacExpectationNode::maxLagWithDiffsExpanded() const
+{
+  return 0;
+}
+
 expr_t
 PacExpectationNode::undiff() const
 {
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index d65b3d27bedca067e63e0f746aa6551546cfed3d..e9095fef9e54120f161b92b35849b20bc04062ae 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -404,14 +404,22 @@ class ExprNode
       /*! Always returns a non-negative value */
       virtual int maxExoLag() const = 0;
 
-      //! Returns the relative period of the most forward term in this expression
-      /*! A negative value means that the expression contains only lagged variables */
+      //! Returns the maximum lead of endo/exo/exodet in this expression
+      /*! A negative value means that the expression contains only lagged
+          variables. */
       virtual int maxLead() const = 0;
 
-      //! Returns the relative period of the most backward term in this expression
-      /*! A negative value means that the expression contains only leaded variables */
+      //! Returns the maximum lag of endo/exo/exodet in this expression
+      /*! A negative value means that the expression contains only leaded
+          variables. */
       virtual int maxLag() const = 0;
 
+      //! Returns the maximum lag of endo/exo/exodet, as if diffs were expanded
+      /*! This function behaves as maxLag(), except that it treats diff()
+          differently. For e.g., on diff(diff(x(-1))), maxLag() returns 1 while
+          maxLagWithDiffsExpanded() returns 3. */
+      virtual int maxLagWithDiffsExpanded() const = 0;
+
       //! Get Max lag of var associated with Pac model
       //! Takes account of undiffed LHS variables in calculating the max lag
       virtual int PacMaxLag(int lhs_symb_id) const = 0;
@@ -683,6 +691,7 @@ public:
   int maxExoLag() const override;
   int maxLead() const override;
   int maxLag() const override;
+  int maxLagWithDiffsExpanded() const override;
   int VarMinLag() const override;
   int VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs) const override;
   int PacMaxLag(int lhs_symb_id) const override;
@@ -771,6 +780,7 @@ public:
   int maxExoLag() const override;
   int maxLead() const override;
   int maxLag() const override;
+  int maxLagWithDiffsExpanded() const override;
   int VarMinLag() const override;
   int VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs) const override;
   int PacMaxLag(int lhs_symb_id) const override;
@@ -884,6 +894,7 @@ public:
   int maxExoLag() const override;
   int maxLead() const override;
   int maxLag() const override;
+  int maxLagWithDiffsExpanded() const override;
   int VarMinLag() const override;
   int VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs) const override;
   int PacMaxLag(int lhs_symb_id) const override;
@@ -1006,6 +1017,7 @@ public:
   int maxExoLag() const override;
   int maxLead() const override;
   int maxLag() const override;
+  int maxLagWithDiffsExpanded() const override;
   int VarMinLag() const override;
   int VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs) const override;
   int PacMaxLag(int lhs_symb_id) const override;
@@ -1132,6 +1144,7 @@ public:
   int maxExoLag() const override;
   int maxLead() const override;
   int maxLag() const override;
+  int maxLagWithDiffsExpanded() const override;
   int VarMinLag() const override;
   int VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs) const override;
   int PacMaxLag(int lhs_symb_id) const override;
@@ -1257,6 +1270,7 @@ public:
   int maxExoLag() const override;
   int maxLead() const override;
   int maxLag() const override;
+  int maxLagWithDiffsExpanded() const override;
   int VarMinLag() const override;
   int VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs) const override;
   int PacMaxLag(int lhs_symb_id) const override;
@@ -1459,6 +1473,7 @@ public:
   int maxExoLag() const override;
   int maxLead() const override;
   int maxLag() const override;
+  int maxLagWithDiffsExpanded() const override;
   int VarMinLag() const override;
   int VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs) const override;
   int PacMaxLag(int lhs_symb_id) const override;
@@ -1558,6 +1573,7 @@ public:
   int maxExoLag() const override;
   int maxLead() const override;
   int maxLag() const override;
+  int maxLagWithDiffsExpanded() const override;
   int VarMinLag() const override;
   int VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs) const override;
   int PacMaxLag(int lhs_symb_id) const override;