From 9bc6c57a5c386d7f2fffcfd30ed793bf5e1de02a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Wed, 20 Mar 2019 17:54:23 +0100
Subject: [PATCH] Factorization: create new utility
 SymbolTable::getUltimateOrigSymbID()

---
 src/ExprNode.cc    | 43 ++++++++-----------------------------------
 src/SymbolTable.cc | 15 +++++++++++++++
 src/SymbolTable.hh |  5 +++++
 3 files changed, 28 insertions(+), 35 deletions(-)

diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index c4ff0ce8..7bbcc901 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -5152,16 +5152,7 @@ BinaryOpNode::getPacTargetSymbIdHelper(int lhs_symb_id, int undiff_lhs_symb_id,
   bool found_lagged_lhs = false;
   for (auto & it : endogs)
     {
-      int id = it.first;
-      while (datatree.symbol_table.isAuxiliaryVariable(id))
-        try
-          {
-            id = datatree.symbol_table.getOrigSymbIdForAuxVar(id);
-          }
-        catch (...)
-          {
-            break;
-          }
+      int id = datatree.symbol_table.getUltimateOrigSymbID(it.first);
       if (id == lhs_symb_id || id == undiff_lhs_symb_id)
         found_lagged_lhs = true;
       if (id != lhs_symb_id && id != undiff_lhs_symb_id)
@@ -5567,20 +5558,11 @@ BinaryOpNode::getPacEC(BinaryOpNode *bopn, int lhs_symb_id, int lhs_orig_symb_id
               exit(EXIT_FAILURE);
             }
           int id = vn->symb_id;
-          int orig_id = id;
+          int orig_id = datatree.symbol_table.getUltimateOrigSymbID(id);
           bool istarget = true;
-          while (datatree.symbol_table.isAuxiliaryVariable(id))
-            try
-              {
-                id = datatree.symbol_table.getOrigSymbIdForAuxVar(id);
-              }
-            catch (...)
-              {
-                break;
-              }
-          if (id == lhs_symb_id || id == lhs_orig_symb_id)
+          if (orig_id == lhs_symb_id || orig_id == lhs_orig_symb_id)
             istarget = false;
-          ordered_symb_ids.emplace_back(orig_id, istarget, scale);
+          ordered_symb_ids.emplace_back(id, istarget, scale);
         }
       ec_params_and_vars = make_pair(optim_param_symb_id, ordered_symb_ids);
     }
@@ -5633,17 +5615,8 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
           exit(EXIT_FAILURE);
         }
 
-      int vidorig = vid;
-      while (datatree.symbol_table.isAuxiliaryVariable(vid))
-        try
-          {
-            vid = datatree.symbol_table.getOrigSymbIdForAuxVar(vid);
-          }
-        catch (...)
-          {
-            break;
-          }
-      if (vid == lhs_symb_id || vid == lhs_orig_symb_id)
+      int vidorig = datatree.symbol_table.getUltimateOrigSymbID(vid);
+      if (vidorig == lhs_symb_id || vidorig == lhs_orig_symb_id)
         {
           // This is an autoregressive term
           if (constant != 1 || pid == -1)
@@ -5651,11 +5624,11 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
               cerr << "BinaryOpNode::getPacAREC: autoregressive terms must be of the form 'parameter*lagged_variable" << endl;
               exit(EXIT_FAILURE);
             }
-          ar_params_and_vars.insert({pid, { vidorig, lag }});
+          ar_params_and_vars.insert({pid, { vid, lag }});
         }
       else
         // This is a residual additive term
-        additive_vars_params_and_constants.push_back({ vidorig, lag, pid, constant});
+        additive_vars_params_and_constants.push_back({ vid, lag, pid, constant});
     }
 }
 
diff --git a/src/SymbolTable.cc b/src/SymbolTable.cc
index c7cdb933..912cbc70 100644
--- a/src/SymbolTable.cc
+++ b/src/SymbolTable.cc
@@ -1277,3 +1277,18 @@ SymbolTable::writeJsonVarVector(ostream &output, const vector<int> &varvec) cons
     }
   output << "]" << endl;
 }
+
+int
+SymbolTable::getUltimateOrigSymbID(int symb_id) const
+{
+  while (isAuxiliaryVariable(symb_id))
+    try
+      {
+        symb_id = getOrigSymbIdForAuxVar(symb_id);
+      }
+    catch (UnknownSymbolIDException &)
+      {
+        break;
+      }
+  return symb_id;
+}
diff --git a/src/SymbolTable.hh b/src/SymbolTable.hh
index 441b1c12..7c68534c 100644
--- a/src/SymbolTable.hh
+++ b/src/SymbolTable.hh
@@ -412,6 +412,11 @@ public:
   bool isDiffAuxiliaryVariable(int symb_id) const;
   //! Get list of endogenous variables without aux vars
   set <int> getOrigEndogenous() const;
+  //! Returns the original symbol corresponding to this variable
+  /* If symb_id is not an auxiliary var, returns symb_id. Otherwise,
+     repeatedly call getOrigSymbIDForAuxVar() until an original
+     (non-auxiliary) variable is found. */
+  int getUltimateOrigSymbID(int symb_id) const;
 };
 
 inline void
-- 
GitLab