From daa8d016868f8502bd708d39fd39fba51141cdcb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Thu, 2 Apr 2020 14:36:26 +0200
Subject: [PATCH] Complete rewrite of the equation normalization symbolic
 engine

---
 src/ExprNode.cc  | 646 ++++++++++++++++-------------------------------
 src/ExprNode.hh  |  41 ++-
 src/ModelTree.cc |  20 +-
 3 files changed, 256 insertions(+), 451 deletions(-)

diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 66e111a1..218f0793 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -171,13 +171,6 @@ ExprNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
   // Nothing to do for a terminal node
 }
 
-pair<int, expr_t>
-ExprNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
-{
-  /* nothing to do */
-  return { 0, nullptr };
-}
-
 void
 ExprNode::writeOutput(ostream &output) const
 {
@@ -497,11 +490,16 @@ NumConstNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &
 {
 }
 
-pair<int, expr_t>
-NumConstNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
+void
+NumConstNode::computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const
+{
+}
+
+BinaryOpNode *
+NumConstNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const
 {
-  /* return the numercial constant */
-  return { 0, datatree.AddNonNegativeConstant(datatree.num_constants.get(id)) };
+  cerr << "NumConstNode::normalizeEquation: this should not happen" << endl;
+  exit(EXIT_FAILURE);
 }
 
 expr_t
@@ -1360,36 +1358,20 @@ VariableNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &
     datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result);
 }
 
-pair<int, expr_t>
-VariableNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
-{
-  /* The equation has to be normalized with respect to the current endogenous variable ascribed to it.
-     The two input arguments are :
-     - The ID of the endogenous variable associated to the equation.
-     - The list of operators and operands needed to normalize the equation*
-
-     The pair returned by NormalizeEquation is composed of
-     - a flag indicating if the expression returned contains (flag = 1) or not (flag = 0)
-     the endogenous variable related to the equation.
-     If the expression contains more than one occurence of the associated endogenous variable,
-     the flag is equal to 2.
-     - an expression equal to the RHS if flag = 0 and equal to NULL elsewhere
-  */
-  if (get_type() == SymbolType::endogenous)
-    {
-      if (datatree.symbol_table.getTypeSpecificID(symb_id) == var_endo && lag == 0)
-        /* the endogenous variable */
-        return { 1, nullptr };
-      else
-        return { 0, datatree.AddVariable(symb_id, lag) };
-    }
-  else
-    {
-      if (get_type() == SymbolType::parameter)
-        return { 0, datatree.AddVariable(symb_id, 0) };
-      else
-        return { 0, datatree.AddVariable(symb_id, lag) };
-    }
+void
+VariableNode::computeSubExprContainingVariable(int symb_id_arg, int lag_arg, set<expr_t> &contain_var) const
+{
+  if (symb_id == symb_id_arg && lag == lag_arg)
+    contain_var.insert(const_cast<VariableNode*>(this));
+}
+
+BinaryOpNode *
+VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const
+{
+  assert(contain_var.count(const_cast<VariableNode *>(this)) > 0);
+
+  // This is the LHS variable: we have finished the normalization
+  return datatree.AddEqual(const_cast<VariableNode *>(this), rhs);
 }
 
 expr_t
@@ -3073,144 +3055,80 @@ UnaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &r
   arg->collectDynamicVariables(type_arg, result);
 }
 
-pair<int, expr_t>
-UnaryOpNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
+void
+UnaryOpNode::computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const
+{
+  arg->computeSubExprContainingVariable(symb_id, lag, contain_var);
+  if (contain_var.count(arg) > 0)
+    contain_var.insert(const_cast<UnaryOpNode *>(this));
+}
+
+BinaryOpNode *
+UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const
 {
-  pair<int, expr_t> res = arg->normalizeEquation(var_endo, List_of_Op_RHS);
-  int is_endogenous_present = res.first;
-  expr_t New_expr_t = res.second;
+  assert(contain_var.count(const_cast<UnaryOpNode *>(this)) > 0);
 
-  if (is_endogenous_present == 2) /* The equation could not be normalized and the process is given-up*/
-    return { 2, nullptr };
-  else if (is_endogenous_present) /* The argument of the function contains the current values of
-                                     the endogenous variable associated to the equation.
-                                     In order to normalized, we have to apply the invert function to the RHS.*/
+  switch (op_code)
     {
-      switch (op_code)
-        {
-        case UnaryOpcode::uminus:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::uminus), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::exp:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::log), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::log:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::exp), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::log10:
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), nullptr, datatree.AddNonNegativeConstant("10"));
-          return { 1, nullptr };
-        case UnaryOpcode::cos:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::acos), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::sin:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::asin), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::tan:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::atan), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::acos:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::cos), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::asin:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::sin), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::atan:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::tan), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::cosh:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::acosh), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::sinh:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::asinh), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::tanh:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::atanh), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::acosh:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::cosh), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::asinh:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::sinh), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::atanh:
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::tanh), nullptr, nullptr);
-          return { 1, nullptr };
-        case UnaryOpcode::sqrt:
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), nullptr, datatree.Two);
-          return { 1, nullptr };
-        case UnaryOpcode::cbrt:
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), nullptr, datatree.Three);
-          return { 1, nullptr };
-        case UnaryOpcode::abs:
-          return { 2, nullptr };
-        case UnaryOpcode::sign:
-          return { 2, nullptr };
-        case UnaryOpcode::steadyState:
-          return { 2, nullptr };
-        case UnaryOpcode::erf:
-          return { 2, nullptr };
-        default:
-          cerr << "Unary operator not handled during the normalization process" << endl;
-          return { 2, nullptr }; // Could not be normalized
-        }
-    }
-  else
-    { /* If the argument of the function do not contain the current values of the endogenous variable
-         related to the equation, the function with its argument is stored in the RHS*/
-      switch (op_code)
-        {
-        case UnaryOpcode::uminus:
-          return { 0, datatree.AddUMinus(New_expr_t) };
-        case UnaryOpcode::exp:
-          return { 0, datatree.AddExp(New_expr_t) };
-        case UnaryOpcode::log:
-          return { 0, datatree.AddLog(New_expr_t) };
-        case UnaryOpcode::log10:
-          return { 0, datatree.AddLog10(New_expr_t) };
-        case UnaryOpcode::cos:
-          return { 0, datatree.AddCos(New_expr_t) };
-        case UnaryOpcode::sin:
-          return { 0, datatree.AddSin(New_expr_t) };
-        case UnaryOpcode::tan:
-          return { 0, datatree.AddTan(New_expr_t) };
-        case UnaryOpcode::acos:
-          return { 0, datatree.AddAcos(New_expr_t) };
-        case UnaryOpcode::asin:
-          return { 0, datatree.AddAsin(New_expr_t) };
-        case UnaryOpcode::atan:
-          return { 0, datatree.AddAtan(New_expr_t) };
-        case UnaryOpcode::cosh:
-          return { 0, datatree.AddCosh(New_expr_t) };
-        case UnaryOpcode::sinh:
-          return { 0, datatree.AddSinh(New_expr_t) };
-        case UnaryOpcode::tanh:
-          return { 0, datatree.AddTanh(New_expr_t) };
-        case UnaryOpcode::acosh:
-          return { 0, datatree.AddAcosh(New_expr_t) };
-        case UnaryOpcode::asinh:
-          return { 0, datatree.AddAsinh(New_expr_t) };
-        case UnaryOpcode::atanh:
-          return { 0, datatree.AddAtanh(New_expr_t) };
-        case UnaryOpcode::sqrt:
-          return { 0, datatree.AddSqrt(New_expr_t) };
-        case UnaryOpcode::cbrt:
-          return { 0, datatree.AddCbrt(New_expr_t) };
-        case UnaryOpcode::abs:
-          return { 0, datatree.AddAbs(New_expr_t) };
-        case UnaryOpcode::sign:
-          return { 0, datatree.AddSign(New_expr_t) };
-        case UnaryOpcode::steadyState:
-          return { 0, datatree.AddSteadyState(New_expr_t) };
-        case UnaryOpcode::erf:
-          return { 0, datatree.AddErf(New_expr_t) };
-        default:
-          cerr << "Unary operator not handled during the normalization process" << endl;
-          return { 2, nullptr }; // Could not be normalized
-        }
+    case UnaryOpcode::uminus:
+      rhs = datatree.AddUMinus(rhs);
+      break;
+    case UnaryOpcode::exp:
+      rhs = datatree.AddLog(rhs);
+      break;
+    case UnaryOpcode::log:
+      rhs = datatree.AddExp(rhs);
+      break;
+    case UnaryOpcode::log10:
+      rhs = datatree.AddPower(datatree.AddNonNegativeConstant("10"), rhs);
+      break;
+    case UnaryOpcode::cos:
+      rhs = datatree.AddAcos(rhs);
+      break;
+    case UnaryOpcode::sin:
+      rhs = datatree.AddAsin(rhs);
+      break;
+    case UnaryOpcode::tan:
+      rhs = datatree.AddAtan(rhs);
+      break;
+    case UnaryOpcode::acos:
+      rhs = datatree.AddCos(rhs);
+      break;
+    case UnaryOpcode::asin:
+      rhs = datatree.AddSin(rhs);
+      break;
+    case UnaryOpcode::atan:
+      rhs = datatree.AddTan(rhs);
+      break;
+    case UnaryOpcode::cosh:
+      rhs = datatree.AddAcosh(rhs);
+      break;
+    case UnaryOpcode::sinh:
+      rhs = datatree.AddAsinh(rhs);
+      break;
+    case UnaryOpcode::tanh:
+      rhs = datatree.AddAtanh(rhs);
+      break;
+    case UnaryOpcode::acosh:
+      rhs = datatree.AddCosh(rhs);
+      break;
+    case UnaryOpcode::asinh:
+      rhs = datatree.AddSinh(rhs);
+      break;
+    case UnaryOpcode::atanh:
+      rhs = datatree.AddTanh(rhs);
+      break;
+    case UnaryOpcode::sqrt:
+      rhs = datatree.AddPower(rhs, datatree.Two);
+      break;
+    case UnaryOpcode::cbrt:
+      rhs = datatree.AddPower(rhs, datatree.Three);
+      break;
+    default:
+      throw NormalizationFailed();
     }
-  cerr << "UnaryOpNode::normalizeEquation: impossible case" << endl;
-  exit(EXIT_FAILURE);
+
+  return arg->normalizeEquationHelper(contain_var, rhs);
 }
 
 expr_t
@@ -4852,24 +4770,19 @@ BinaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &
 expr_t
 BinaryOpNode::Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const
 {
-  temporary_terms_t temp;
   switch (op_type)
     {
     case 0: /*Unary Operator*/
       switch (static_cast<UnaryOpcode>(op))
         {
         case UnaryOpcode::uminus:
-          return (datatree.AddUMinus(arg1));
-          break;
+          return datatree.AddUMinus(arg1);
         case UnaryOpcode::exp:
-          return (datatree.AddExp(arg1));
-          break;
+          return datatree.AddExp(arg1);
         case UnaryOpcode::log:
-          return (datatree.AddLog(arg1));
-          break;
+          return datatree.AddLog(arg1);
         case UnaryOpcode::log10:
-          return (datatree.AddLog10(arg1));
-          break;
+          return datatree.AddLog10(arg1);
         default:
           cerr << "BinaryOpNode::Compute_RHS: case not handled";
           exit(EXIT_FAILURE);
@@ -4879,20 +4792,15 @@ BinaryOpNode::Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const
       switch (static_cast<BinaryOpcode>(op))
         {
         case BinaryOpcode::plus:
-          return (datatree.AddPlus(arg1, arg2));
-          break;
+          return datatree.AddPlus(arg1, arg2);
         case BinaryOpcode::minus:
-          return (datatree.AddMinus(arg1, arg2));
-          break;
+          return datatree.AddMinus(arg1, arg2);
         case BinaryOpcode::times:
-          return (datatree.AddTimes(arg1, arg2));
-          break;
+          return datatree.AddTimes(arg1, arg2);
         case BinaryOpcode::divide:
-          return (datatree.AddDivide(arg1, arg2));
-          break;
+          return datatree.AddDivide(arg1, arg2);
         case BinaryOpcode::power:
-          return (datatree.AddPower(arg1, arg2));
-          break;
+          return datatree.AddPower(arg1, arg2);
         default:
           cerr << "BinaryOpNode::Compute_RHS: case not handled";
           exit(EXIT_FAILURE);
@@ -4902,224 +4810,90 @@ BinaryOpNode::Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const
   return nullptr;
 }
 
-pair<int, expr_t>
-BinaryOpNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
-{
-  /* Checks if the current value of the endogenous variable related to the equation
-     is present in the arguments of the binary operator. */
-  vector<tuple<int, expr_t, expr_t>> List_of_Op_RHS1, List_of_Op_RHS2;
-  pair<int, expr_t> res = arg1->normalizeEquation(var_endo, List_of_Op_RHS1);
-  int is_endogenous_present_1 = res.first;
-  expr_t expr_t_1 = res.second;
-
-  res = arg2->normalizeEquation(var_endo, List_of_Op_RHS2);
-  int is_endogenous_present_2 = res.first;
-  expr_t expr_t_2 = res.second;
-
-  /* If the two expressions contains the current value of the endogenous variable associated to the equation
-     the equation could not be normalized and the process is given-up.*/
-  if (is_endogenous_present_1 == 2 || is_endogenous_present_2 == 2)
-    return { 2, nullptr };
-  else if (is_endogenous_present_1 && is_endogenous_present_2)
-    return { 2, nullptr };
-  else if (is_endogenous_present_1) /*If the current values of the endogenous variable associated to the equation
-                                      is present only in the first operand of the expression, we try to normalize the equation*/
-    {
-      if (op_code == BinaryOpcode::equal) /* The end of the normalization process :
-                                             All the operations needed to normalize the equation are applied. */
-        while (!List_of_Op_RHS1.empty())
-          {
-            tuple<int, expr_t, expr_t> it = List_of_Op_RHS1.back();
-            List_of_Op_RHS1.pop_back();
-            if (get<1>(it) && !get<2>(it)) /*Binary operator*/
-              expr_t_2 = Compute_RHS(expr_t_2, static_cast<BinaryOpNode *>(get<1>(it)), get<0>(it), 1);
-            else if (get<2>(it) && !get<1>(it)) /*Binary operator*/
-              expr_t_2 = Compute_RHS(get<2>(it), expr_t_2, get<0>(it), 1);
-            else if (get<2>(it) && get<1>(it)) /*Binary operator*/
-              expr_t_2 = Compute_RHS(get<1>(it), get<2>(it), get<0>(it), 1);
-            else /*Unary operator*/
-              expr_t_2 = Compute_RHS(static_cast<UnaryOpNode *>(expr_t_2), static_cast<UnaryOpNode *>(get<1>(it)), get<0>(it), 0);
-          }
-      else
-        List_of_Op_RHS = List_of_Op_RHS1;
-    }
-  else if (is_endogenous_present_2)
-    {
-      if (op_code == BinaryOpcode::equal)
-        while (!List_of_Op_RHS2.empty())
-          {
-            tuple<int, expr_t, expr_t> it = List_of_Op_RHS2.back();
-            List_of_Op_RHS2.pop_back();
-            if (get<1>(it) && !get<2>(it)) /*Binary operator*/
-              expr_t_1 = Compute_RHS(static_cast<BinaryOpNode *>(expr_t_1), static_cast<BinaryOpNode *>(get<1>(it)), get<0>(it), 1);
-            else if (get<2>(it) && !get<1>(it)) /*Binary operator*/
-              expr_t_1 = Compute_RHS(static_cast<BinaryOpNode *>(get<2>(it)), static_cast<BinaryOpNode *>(expr_t_1), get<0>(it), 1);
-            else if (get<2>(it) && get<1>(it)) /*Binary operator*/
-              expr_t_1 = Compute_RHS(get<1>(it), get<2>(it), get<0>(it), 1);
-            else
-              expr_t_1 = Compute_RHS(static_cast<UnaryOpNode *>(expr_t_1), static_cast<UnaryOpNode *>(get<1>(it)), get<0>(it), 0);
-          }
-      else
-        List_of_Op_RHS = List_of_Op_RHS2;
-    }
+void
+BinaryOpNode::computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const
+{
+  arg1->computeSubExprContainingVariable(symb_id, lag, contain_var);
+  arg2->computeSubExprContainingVariable(symb_id, lag, contain_var);
+  if (contain_var.count(arg1) > 0 || contain_var.count(arg2) > 0)
+    contain_var.insert(const_cast<BinaryOpNode *>(this));
+}
+
+BinaryOpNode *
+BinaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const
+{
+  assert(contain_var.count(const_cast<BinaryOpNode *>(this)) > 0);
+
+  bool arg1_contains_var = contain_var.count(arg1) > 0;
+  bool arg2_contains_var = contain_var.count(arg2) > 0;
+  assert(arg1_contains_var || arg2_contains_var);
+
+  if (arg1_contains_var && arg2_contains_var)
+    throw NormalizationFailed();
+
   switch (op_code)
     {
     case BinaryOpcode::plus:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddPlus(expr_t_1, expr_t_2) };
-      else if (is_endogenous_present_1 && is_endogenous_present_2)
-        return { 2, nullptr };
-      else if (!is_endogenous_present_1 && is_endogenous_present_2)
-        {
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::minus), expr_t_1, nullptr);
-          return { 1, expr_t_1 };
-        }
-      else if (is_endogenous_present_1 && !is_endogenous_present_2)
-        {
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::minus), expr_t_2, nullptr);
-          return { 1, expr_t_2 };
-        }
+      if (arg1_contains_var)
+        rhs = datatree.AddMinus(rhs, arg2);
+      else
+        rhs = datatree.AddMinus(rhs, arg1);
       break;
     case BinaryOpcode::minus:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddMinus(expr_t_1, expr_t_2) };
-      else if (is_endogenous_present_1 && is_endogenous_present_2)
-        return { 2, nullptr };
-      else if (!is_endogenous_present_1 && is_endogenous_present_2)
-        {
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::uminus), nullptr, nullptr);
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::minus), expr_t_1, nullptr);
-          return { 1, expr_t_1 };
-        }
-      else if (is_endogenous_present_1 && !is_endogenous_present_2)
-        {
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::plus), expr_t_2, nullptr);
-          return { 1, datatree.AddUMinus(expr_t_2) };
-        }
+      if (arg1_contains_var)
+        rhs = datatree.AddPlus(rhs, arg2);
+      else
+        rhs = datatree.AddMinus(arg1, rhs);
       break;
     case BinaryOpcode::times:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddTimes(expr_t_1, expr_t_2) };
-      else if (!is_endogenous_present_1 && is_endogenous_present_2)
-        {
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::divide), expr_t_1, nullptr);
-          return { 1, expr_t_1 };
-        }
-      else if (is_endogenous_present_1 && !is_endogenous_present_2)
-        {
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::divide), expr_t_2, nullptr);
-          return { 1, expr_t_2 };
-        }
+      if (arg1_contains_var)
+        rhs = datatree.AddDivide(rhs, arg2);
       else
-        return { 2, nullptr };
+        rhs = datatree.AddDivide(rhs, arg1);
       break;
     case BinaryOpcode::divide:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddDivide(expr_t_1, expr_t_2) };
-      else if (!is_endogenous_present_1 && is_endogenous_present_2)
-        {
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::divide), nullptr, expr_t_1);
-          return { 1, expr_t_1 };
-        }
-      else if (is_endogenous_present_1 && !is_endogenous_present_2)
-        {
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::times), expr_t_2, nullptr);
-          return { 1, expr_t_2 };
-        }
+      if (arg1_contains_var)
+        rhs = datatree.AddTimes(rhs, arg2);
       else
-        return { 2, nullptr };
+        rhs = datatree.AddDivide(arg1, rhs);
       break;
     case BinaryOpcode::power:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddPower(expr_t_1, expr_t_2) };
-      else if (is_endogenous_present_1 && !is_endogenous_present_2)
-        {
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), datatree.AddDivide(datatree.One, expr_t_2), nullptr);
-          return { 1, nullptr };
-        }
-      else if (!is_endogenous_present_1 && is_endogenous_present_2)
-        {
-          /* we have to nomalize a^f(X) = RHS */
-          /* First computes the ln(RHS)*/
-          List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::log), nullptr, nullptr);
-          /* Second  computes f(X) = ln(RHS) / ln(a)*/
-          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::divide), nullptr, datatree.AddLog(expr_t_1));
-          return { 1, nullptr };
-        }
-      break;
-    case BinaryOpcode::equal:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        {
-          return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), datatree.AddMinus(expr_t_2, expr_t_1)) };
-        }
-      else if (is_endogenous_present_1 && is_endogenous_present_2)
-        {
-          return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), datatree.Zero) };
-        }
-      else if (!is_endogenous_present_1 && is_endogenous_present_2)
-        {
-          return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), /*datatree.AddUMinus(expr_t_1)*/ expr_t_1) };
-        }
-      else if (is_endogenous_present_1 && !is_endogenous_present_2)
-        {
-          return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), expr_t_2) };
-        }
-      break;
-    case BinaryOpcode::max:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddMax(expr_t_1, expr_t_2) };
-      else
-        return { 2, nullptr };
-      break;
-    case BinaryOpcode::min:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddMin(expr_t_1, expr_t_2) };
-      else
-        return { 2, nullptr };
-      break;
-    case BinaryOpcode::less:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddLess(expr_t_1, expr_t_2) };
-      else
-        return { 2, nullptr };
-      break;
-    case BinaryOpcode::greater:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddGreater(expr_t_1, expr_t_2) };
+      if (arg1_contains_var)
+        rhs = datatree.AddPower(rhs, datatree.AddDivide(datatree.One, arg2));
       else
-        return { 2, nullptr };
-      break;
-    case BinaryOpcode::lessEqual:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddLessEqual(expr_t_1, expr_t_2) };
-      else
-        return { 2, nullptr };
-      break;
-    case BinaryOpcode::greaterEqual:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddGreaterEqual(expr_t_1, expr_t_2) };
-      else
-        return { 2, nullptr };
-      break;
-    case BinaryOpcode::equalEqual:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddEqualEqual(expr_t_1, expr_t_2) };
-      else
-        return { 2, nullptr };
-      break;
-    case BinaryOpcode::different:
-      if (!is_endogenous_present_1 && !is_endogenous_present_2)
-        return { 0, datatree.AddDifferent(expr_t_1, expr_t_2) };
-      else
-        return { 2, nullptr };
+        // a^f(X)=rhs is normalized into f(X)=ln(rhs)/ln(a)
+        rhs = datatree.AddDivide(datatree.AddLog(rhs), datatree.AddLog(arg1));
       break;
+    case BinaryOpcode::equal:
+      cerr << "BinaryOpCode::normalizeEquationHelper: this case should not happen" << endl;
+      exit(EXIT_FAILURE);
     default:
-      cerr << "Binary operator not handled during the normalization process" << endl;
-      return { 2, nullptr }; // Could not be normalized
+      throw NormalizationFailed();
     }
-  // Suppress GCC warning
-  cerr << "BinaryOpNode::normalizeEquation: impossible case" << endl;
-  exit(EXIT_FAILURE);
+
+  if (arg1_contains_var)
+    return arg1->normalizeEquationHelper(contain_var, rhs);
+  else
+    return arg2->normalizeEquationHelper(contain_var, rhs);
+}
+
+BinaryOpNode *
+BinaryOpNode::normalizeEquation(int symb_id, int lag) const
+{
+  assert(op_code == BinaryOpcode::equal);
+
+  set<expr_t> contain_var;
+  computeSubExprContainingVariable(symb_id, lag, contain_var);
+
+  bool arg1_contains_var = contain_var.count(arg1) > 0;
+  bool arg2_contains_var = contain_var.count(arg2) > 0;
+  assert(arg1_contains_var || arg2_contains_var);
+
+  if (arg1_contains_var && arg2_contains_var)
+    throw NormalizationFailed();
+
+  return arg1_contains_var ? arg1->normalizeEquationHelper(contain_var, arg2)
+    : arg2->normalizeEquationHelper(contain_var, arg1);
 }
 
 expr_t
@@ -6477,22 +6251,20 @@ TrinaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>>
   arg3->collectDynamicVariables(type_arg, result);
 }
 
-pair<int, expr_t>
-TrinaryOpNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
-{
-  pair<int, expr_t> res = arg1->normalizeEquation(var_endo, List_of_Op_RHS);
-  bool is_endogenous_present_1 = res.first;
-  expr_t expr_t_1 = res.second;
-  res = arg2->normalizeEquation(var_endo, List_of_Op_RHS);
-  bool is_endogenous_present_2 = res.first;
-  expr_t expr_t_2 = res.second;
-  res = arg3->normalizeEquation(var_endo, List_of_Op_RHS);
-  bool is_endogenous_present_3 = res.first;
-  expr_t expr_t_3 = res.second;
-  if (!is_endogenous_present_1 && !is_endogenous_present_2 && !is_endogenous_present_3)
-    return { 0, datatree.AddNormcdf(expr_t_1, expr_t_2, expr_t_3) };
-  else
-    return { 2, nullptr };
+void
+TrinaryOpNode::computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const
+{
+  arg1->computeSubExprContainingVariable(symb_id, lag, contain_var);
+  arg2->computeSubExprContainingVariable(symb_id, lag, contain_var);
+  arg3->computeSubExprContainingVariable(symb_id, lag, contain_var);
+  if (contain_var.count(arg1) > 0 || contain_var.count(arg2) > 0 || contain_var.count(arg3) > 0)
+    contain_var.insert(const_cast<TrinaryOpNode *>(this));
+}
+
+BinaryOpNode *
+TrinaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const
+{
+  throw NormalizationFailed();
 }
 
 expr_t
@@ -7405,22 +7177,24 @@ AbstractExternalFunctionNode::getEndosAndMaxLags(map<string, int> &model_endos_a
     argument->getEndosAndMaxLags(model_endos_and_lags);
 }
 
-pair<int, expr_t>
-AbstractExternalFunctionNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
+
+void
+AbstractExternalFunctionNode::computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const
 {
-  vector<pair<bool, expr_t>> V_arguments;
-  vector<expr_t> V_expr_t;
-  bool present = false;
-  for (auto argument : arguments)
+  bool var_present = false;
+  for (auto arg : arguments)
     {
-      V_arguments.emplace_back(argument->normalizeEquation(var_endo, List_of_Op_RHS));
-      present = present || V_arguments[V_arguments.size()-1].first;
-      V_expr_t.push_back(V_arguments[V_arguments.size()-1].second);
+      arg->computeSubExprContainingVariable(symb_id, lag, contain_var);
+      var_present = var_present || contain_var.count(arg) > 0;
     }
-  if (!present)
-    return { 0, datatree.AddExternalFunction(symb_id, V_expr_t) };
-  else
-    return { 2, nullptr };
+  if (var_present)
+    contain_var.insert(const_cast<AbstractExternalFunctionNode *>(this));
+}
+
+BinaryOpNode *
+AbstractExternalFunctionNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const
+{
+  throw NormalizationFailed();
 }
 
 void
@@ -8806,11 +8580,15 @@ VarExpectationNode::compile(ostream &CompileCode, unsigned int &instruction_numb
   exit(EXIT_FAILURE);
 }
 
-pair<int, expr_t>
-VarExpectationNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
+void
+VarExpectationNode::computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const
 {
-  cerr << "VarExpectationNode::normalizeEquation not implemented." << endl;
-  exit(EXIT_FAILURE);
+}
+
+BinaryOpNode *
+VarExpectationNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const
+{
+  throw NormalizationFailed();
 }
 
 expr_t
@@ -9238,11 +9016,15 @@ PacExpectationNode::countDiffs() const
   return 0;
 }
 
-pair<int, expr_t>
-PacExpectationNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
+void
+PacExpectationNode::computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const
 {
-  cerr << "PacExpectationNode::normalizeEquation not implemented." << endl;
-  exit(EXIT_FAILURE);
+}
+
+BinaryOpNode *
+PacExpectationNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const
+{
+  throw NormalizationFailed();
 }
 
 expr_t
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 210c0360..9f8bcd15 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -431,8 +431,18 @@ public:
   */
   //  virtual void computeXrefs(set<int> &param, set<int> &endo, set<int> &exo, set<int> &exo_det) const = 0;
   virtual void computeXrefs(EquationInfo &ei) const = 0;
-  //! Try to normalize an equation linear in its endogenous variable
-  virtual pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const = 0;
+
+  // Computes the set of all sub-expressions that contain the variable (symb_id, lag)
+  virtual void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const = 0;
+
+  //! Helper for normalization of equations
+  /*! Normalize the equation this = rhs.
+      Must be called on a node containing the desired LHS variable.
+      Returns an equal node of the form: LHS variable = new RHS.
+      Must be given the set of all subexpressions that contain the desired LHS variable.
+      Throws a NormalizationFailed() exception if normalization is not possible. */
+  virtual BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const = 0;
+  class NormalizationFailed {};
 
   //! Returns the maximum lead of endogenous in this expression
   /*! Always returns a non-negative value */
@@ -744,7 +754,8 @@ public:
   void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override;
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
-  pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
+  BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
@@ -827,7 +838,8 @@ public:
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
   SymbolType get_type() const;
-  pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
+  BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
@@ -935,7 +947,8 @@ public:
   void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override;
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
-  pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
+  BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
@@ -1047,7 +1060,11 @@ public:
   expr_t Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const;
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
-  pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
+  BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
+  //! Try to normalize an equation with respect to a given dynamic variable.
+  /*! Should only be called on Equal nodes. The variable must appear in the equation. */
+  BinaryOpNode *normalizeEquation(int symb_id, int lag) const;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
@@ -1178,7 +1195,8 @@ public:
   void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override;
   expr_t toStatic(DataTree &static_datatree) const override;
   void computeXrefs(EquationInfo &ei) const override;
-  pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
+  BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
@@ -1298,7 +1316,8 @@ public:
   void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override = 0;
   expr_t toStatic(DataTree &static_datatree) const override = 0;
   void computeXrefs(EquationInfo &ei) const override = 0;
-  pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
+  BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
@@ -1529,7 +1548,8 @@ public:
   expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substitutePacExpectation(const string &name, expr_t subexpr) override;
-  pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
+  BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   void compile(ostream &CompileCode, unsigned int &instruction_number,
                bool lhs_rhs, const temporary_terms_t &temporary_terms,
                const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
@@ -1610,7 +1630,8 @@ public:
   expr_t substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substitutePacExpectation(const string &name, expr_t subexpr) override;
-  pair<int, expr_t> normalizeEquation(int symb_id_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const override;
+  void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
+  BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
   void compile(ostream &CompileCode, unsigned int &instruction_number,
                bool lhs_rhs, const temporary_terms_t &temporary_terms,
                const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
diff --git a/src/ModelTree.cc b/src/ModelTree.cc
index a25b1539..769665f9 100644
--- a/src/ModelTree.cc
+++ b/src/ModelTree.cc
@@ -661,7 +661,7 @@ ModelTree::equationTypeDetermination(const map<tuple<int, int, int>, expr_t> &fi
       int var = variable_reordered[i];
       expr_t lhs = equations[eq]->arg1;
       EquationType Equation_Simulation_Type = EquationType::solve;
-      pair<int, expr_t> res;
+      BinaryOpNode *normalized_eq = nullptr;
       if (auto it = first_order_endo_derivatives.find({ eq, var, 0 });
           it != first_order_endo_derivatives.end())
         {
@@ -676,16 +676,18 @@ ModelTree::equationTypeDetermination(const map<tuple<int, int, int>, expr_t> &fi
               derivative->collectEndogenous(result);
               bool variable_not_in_derivative = result.find({ var, 0 }) == result.end();
 
-              vector<tuple<int, expr_t, expr_t>> List_of_Op_RHS;
-              res = equations[eq]->normalizeEquation(var, List_of_Op_RHS);
-
-              if (mfs == 2 && variable_not_in_derivative && res.second)
-                Equation_Simulation_Type = EquationType::evaluate_s;
-              else if (mfs == 3 && res.second) // The equation could be solved analytically
-                Equation_Simulation_Type = EquationType::evaluate_s;
+              try
+                {
+                  normalized_eq = equations[eq]->normalizeEquation(symbol_table.getID(SymbolType::endogenous, var), 0);
+                  if ((mfs == 2 && variable_not_in_derivative) || mfs == 3)
+                    Equation_Simulation_Type = EquationType::evaluate_s;
+                }
+              catch (ExprNode::NormalizationFailed &e)
+                {
+                }
             }
         }
-      equation_type_and_normalized_equation[eq] = { Equation_Simulation_Type, dynamic_cast<BinaryOpNode *>(res.second) };
+      equation_type_and_normalized_equation[eq] = { Equation_Simulation_Type, normalized_eq };
     }
 }
 
-- 
GitLab