From 84c2dc5f3621d47407397d519b623eb9d1fc1eb3 Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Thu, 7 Jun 2018 12:53:00 +0200
Subject: [PATCH] transform_unary_ops now introduces aux variables/equations
 for all unary ops specified by UnaryOpNode::createAuxVarForUnaryOpNode()

In the absence of this option, if a var_model statement(s) is present, then aux vars/eqs are created for the same types of unary operators but only for equations specified in the var_model statement

In the absence of both this option and var_model statements, no unary op auxiliary variables are created

diffs continue to be substituted everywhere; for the moment auxiliary variables are created for diffs of expressions. A forthcoming change will allow auxiliary variables created for diffs of expressions to be linked with their lagged expressions as is currently the case for diffs of variables
---
 src/DynamicModel.cc | 47 +++++++++++++++++------------
 src/DynamicModel.hh |  8 ++++-
 src/ExprNode.cc     | 72 +++++++++++++++++++++------------------------
 src/ExprNode.hh     |  2 +-
 src/ModFile.cc      |  4 ++-
 src/SymbolTable.cc  | 26 +++++++++++++---
 src/SymbolTable.hh  |  2 +-
 7 files changed, 96 insertions(+), 65 deletions(-)

diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc
index aa703ff5..cb62cea2 100644
--- a/src/DynamicModel.cc
+++ b/src/DynamicModel.cc
@@ -25,6 +25,7 @@
 #include <cerrno>
 #include <algorithm>
 #include <iterator>
+#include <numeric>
 #include "DynamicModel.hh"
 
 // For mkdir() and chdir()
@@ -5403,26 +5404,40 @@ DynamicModel::findPacExpectationEquationNumbers(vector<int> &eqnumbers) const
     }
 }
 
+void
+DynamicModel::substituteUnaryOps(StaticModel &static_model)
+{
+  vector<int> eqnumbers(equations.size());
+  iota(eqnumbers.begin(), eqnumbers.end(), 0);
+  substituteUnaryOps(static_model, eqnumbers);
+}
+
 void
 DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_model_eqtags)
+{
+  vector<int> eqnumbers;
+  getEquationNumbersFromTags(eqnumbers, var_model_eqtags);
+  findPacExpectationEquationNumbers(eqnumbers);
+  substituteUnaryOps(static_model, eqnumbers);
+}
+
+void
+DynamicModel::substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbers)
 {
   diff_table_t nodes;
-  vector<int> eqnumber;
-  getEquationNumbersFromTags(eqnumber, var_model_eqtags);
-  findPacExpectationEquationNumbers(eqnumber);
 
   // Find matching unary ops that may be outside of diffs (i.e., those with different lags)
   set<int> used_local_vars;
-  for (int eqnn : eqnumber)
-    equations[eqnn]->collectVariables(eModelLocalVariable, used_local_vars);
+  for (int eqnumber : eqnumbers)
+    equations[eqnumber]->collectVariables(eModelLocalVariable, used_local_vars);
 
   // Only substitute unary ops in model local variables that appear in VAR equations
   for (auto & it : local_variables_table)
     if (used_local_vars.find(it.first) != used_local_vars.end())
       it.second->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
 
-  for (int eqnn : eqnumber)
-    equations[eqnn]->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
+  for (int eqnumber : eqnumbers)
+    equations[eqnumber]->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
 
   // Substitute in model local variables
   ExprNode::subst_table_t subst_table;
@@ -5434,7 +5449,7 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_mod
   for (auto & equation : equations)
     {
       auto *substeq = dynamic_cast<BinaryOpNode *>(equation->
-                                                           substituteUnaryOpNodes(static_model, nodes, subst_table, neweqs));
+                                                   substituteUnaryOpNodes(static_model, nodes, subst_table, neweqs));
       assert(substeq != nullptr);
       equation = substeq;
     }
@@ -5450,15 +5465,11 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_mod
 }
 
 void
-DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table, set<string> &var_model_eqtags)
+DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table)
 {
-  vector<int> eqnumbers;
-  getEquationNumbersFromTags(eqnumbers, var_model_eqtags);
-  findPacExpectationEquationNumbers(eqnumbers);
-
   set<int> used_local_vars;
-  for (int eqnumber : eqnumbers)
-    equations[eqnumber]->collectVariables(eModelLocalVariable, used_local_vars);
+  for (const auto & equation : equations)
+    equation->collectVariables(eModelLocalVariable, used_local_vars);
 
   // Only substitute diffs in model local variables that appear in VAR equations
   diff_table_t diff_table;
@@ -5466,8 +5477,8 @@ DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t
     if (used_local_vars.find(it.first) != used_local_vars.end())
       it.second->findDiffNodes(static_model, diff_table);
 
-  for (int eqnumber : eqnumbers)
-    equations[eqnumber]->findDiffNodes(static_model, diff_table);
+  for (const auto & equation : equations)
+    equation->findDiffNodes(static_model, diff_table);
 
   // Substitute in model local variables
   vector<BinaryOpNode *> neweqs;
@@ -5478,7 +5489,7 @@ DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t
   for (auto & equation : equations)
     {
       auto *substeq = dynamic_cast<BinaryOpNode *>(equation->
-                                                           substituteDiff(static_model, diff_table, diff_subst_table, neweqs));
+                                                   substituteDiff(static_model, diff_table, diff_subst_table, neweqs));
       assert(substeq != nullptr);
       equation = substeq;
     }
diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh
index 578a9ad5..8eb69ecc 100644
--- a/src/DynamicModel.hh
+++ b/src/DynamicModel.hh
@@ -423,11 +423,17 @@ public:
   //! Substitutes adl operator
   void substituteAdl();
 
+  //! Creates aux vars for all unary operators
+  void substituteUnaryOps(StaticModel &static_model);
+
   //! Creates aux vars for certain unary operators: originally implemented for support of VARs
   void substituteUnaryOps(StaticModel &static_model, set<string> &eq_tags);
 
+  //! Creates aux vars for certain unary operators: originally implemented for support of VARs
+  void substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbers);
+
   //! Substitutes diff operator
-  void substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table, set<string> &var_model_eqtags);
+  void substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table);
 
   //! Table to undiff LHS variables for pac vector z
   void getUndiffLHSForPac(vector<int> &lhs, vector<expr_t> &lhs_expr_t, vector<bool> &diff, vector<int> &orig_diff_var,
diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 8619492f..422e6c1e 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -3045,7 +3045,7 @@ UnaryOpNode::countDiffs() const
 }
 
 bool
-UnaryOpNode::createAuxVarForUnaryOpNodeInDiffOp() const
+UnaryOpNode::createAuxVarForUnaryOpNode() const
 {
   switch (op_code)
     {
@@ -3077,14 +3077,14 @@ UnaryOpNode::createAuxVarForUnaryOpNodeInDiffOp() const
 void
 UnaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const
 {
-  if (!this->createAuxVarForUnaryOpNodeInDiffOp())
-    {
-      arg->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
-      return;
-    }
+  arg->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
+
+  if (!this->createAuxVarForUnaryOpNode())
+    return;
 
   expr_t sthis = this->toStatic(static_datatree);
   int arg_max_lag = -arg->maxLag();
+  // TODO: implement recursive expression comparison, ensuring that the difference in the lags is constant across nodes
   auto it = nodes.find(sthis);
   if (it != nodes.end())
     {
@@ -3101,13 +3101,14 @@ UnaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_t
 void
 UnaryOpNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const
 {
+  arg->findDiffNodes(static_datatree, diff_table);
+
   if (op_code != oDiff)
     return;
 
-  arg->findDiffNodes(static_datatree, diff_table);
-
   expr_t sthis = this->toStatic(static_datatree);
   int arg_max_lag = -arg->maxLag();
+  // TODO: implement recursive expression comparison, ensuring that the difference in the lags is constant across nodes
   auto it = diff_table.find(sthis);
   if (it != diff_table.end())
     {
@@ -3125,11 +3126,9 @@ expr_t
 UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
                             vector<BinaryOpNode *> &neweqs) const
 {
+  expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
   if (op_code != oDiff)
-    {
-      expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
-      return buildSimilarUnaryOpNode(argsubst, datatree);
-    }
+    return buildSimilarUnaryOpNode(argsubst, datatree);
 
   subst_table_t::const_iterator sit = subst_table.find(this);
   if (sit != subst_table.end())
@@ -3137,13 +3136,19 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
 
   expr_t sthis = dynamic_cast<UnaryOpNode *>(this->toStatic(static_datatree));
   auto it = diff_table.find(sthis);
+  int symb_id;
   if (it == diff_table.end() || it->second[-arg->maxLag()] != this)
     {
       // diff does not appear in VAR equations
-      // so simply substitute diff(x) with x-x(-1)
-      expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
-      return dynamic_cast<BinaryOpNode *>(datatree.AddMinus(argsubst,
-                                                            argsubst->decreaseLeadsLags(1)));
+      // so simply create aux var and return
+      // Once the comparison of expression nodes works, come back and remove this part, folding into the next loop.
+      symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst);
+      VariableNode *aux_var = datatree.AddVariable(symb_id, 0);
+      neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var,
+                                                                      datatree.AddMinus(argsubst,
+                                                                                        argsubst->decreaseLeadsLags(1)))));
+      subst_table[this] = dynamic_cast<VariableNode *>(aux_var);
+      return const_cast<VariableNode *>(subst_table[this]);
     }
 
   int last_arg_max_lag = 0;
@@ -3153,19 +3158,13 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
     {
       expr_t argsubst = dynamic_cast<UnaryOpNode *>(rit->second)->
           get_arg()->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
-      int symb_id;
       auto *vn = dynamic_cast<VariableNode *>(argsubst);
       if (rit == it->second.rbegin())
         {
           if (vn != nullptr)
             symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst, vn->get_symb_id(), vn->get_lag());
           else
-            {
-              // We know that the supported unary ops have already been substituted
-              cerr << "ERROR: You can only use the `diff` operator on variables and certain unary ops." << endl
-                   << "       Try passing the `transform_unary_ops` option on the dynare command line." << endl;
-              exit(EXIT_FAILURE);
-            }
+            symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst);
 
           // make originating aux var & equation
           last_arg_max_lag = rit->first;
@@ -3210,35 +3209,30 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod
 
   auto *sthis = dynamic_cast<UnaryOpNode *>(this->toStatic(static_datatree));
   auto it = nodes.find(sthis);
+  expr_t argsubst = arg->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
   if (it == nodes.end())
-    {
-      expr_t argsubst = arg->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
-      return buildSimilarUnaryOpNode(argsubst, datatree);
-    }
+    return buildSimilarUnaryOpNode(argsubst, datatree);
 
+  int base_aux_lag;
   VariableNode *aux_var = nullptr;
-  for (auto rit = it->second.rbegin();
-       rit != it->second.rend(); rit++)
+  for (auto rit = it->second.rbegin(); rit != it->second.rend(); rit++)
     if (rit == it->second.rbegin())
       {
-        auto *vn = dynamic_cast<VariableNode *>(const_cast<UnaryOpNode *>(this)->get_arg());
+        int symb_id;
+        auto *vn = dynamic_cast<VariableNode *>(argsubst);
         if (vn == nullptr)
-          {
-            cerr << "ERROR: You can only use a unary op on a variable node or another unary op node within a VAR." << endl;
-            exit(EXIT_FAILURE);
-          }
-        int symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, const_cast<UnaryOpNode *>(this),
+            symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second));
+        else
+            symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second),
                                                                    vn->get_symb_id(), vn->get_lag());
         aux_var = datatree.AddVariable(symb_id, 0);
         neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var,
                                                                         dynamic_cast<UnaryOpNode *>(rit->second))));
         subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var);
+        base_aux_lag = rit->first;
       }
     else
-      {
-        auto *vn = dynamic_cast<VariableNode *>(dynamic_cast<UnaryOpNode *>(rit->second)->get_arg());
-        subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(-vn->get_lag()));
-      }
+      subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(base_aux_lag - rit->first));
 
   sit = subst_table.find(this);
   return const_cast<VariableNode *>(sit->second);
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 7b2b6280..13071eb0 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -811,7 +811,7 @@ public:
   expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
   expr_t substituteAdl() const override;
   void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
-  bool createAuxVarForUnaryOpNodeInDiffOp() const;
+  bool createAuxVarForUnaryOpNode() const;
   void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
   expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
   expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
diff --git a/src/ModFile.cc b/src/ModFile.cc
index 3332cf47..1e237dc3 100644
--- a/src/ModFile.cc
+++ b/src/ModFile.cc
@@ -381,12 +381,14 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
     }
 
   if (transform_unary_ops)
+    dynamic_model.substituteUnaryOps(diff_static_model);
+  else
     // substitute only those unary ops that appear in VAR equations
     dynamic_model.substituteUnaryOps(diff_static_model, eqtags);
 
   // Create auxiliary variable and equations for Diff operators that appear in VAR equations
   ExprNode::subst_table_t diff_subst_table;
-  dynamic_model.substituteDiff(diff_static_model, diff_subst_table, eqtags);
+  dynamic_model.substituteDiff(diff_static_model, diff_subst_table);
 
   // Var Model
   map<string, tuple<vector<int>, vector<expr_t>, vector<bool>, vector<int>, int, vector<bool>, vector<int>>>
diff --git a/src/SymbolTable.cc b/src/SymbolTable.cc
index f9f21917..3d56c1c2 100644
--- a/src/SymbolTable.cc
+++ b/src/SymbolTable.cc
@@ -354,10 +354,14 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
           case avEndoLag:
           case avExoLag:
           case avVarModel:
-          case avUnaryOp:
             output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl
                    << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
             break;
+          case avUnaryOp:
+            if (aux_vars[i].get_orig_symb_id() >= 0)
+              output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl
+                     << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
+            break;
           case avMultiplier:
             output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].get_equation_number_for_multiplier() + 1 << ";" << endl;
             break;
@@ -479,10 +483,14 @@ SymbolTable::writeCOutput(ostream &output) const noexcept(false)
             case avEndoLag:
             case avExoLag:
             case avVarModel:
-            case avUnaryOp:
               output << "av[" << i << "].orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl
                      << "av[" << i << "].orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
               break;
+            case avUnaryOp:
+              if (aux_vars[i].get_orig_symb_id() >= 0)
+                output << "av[" << i << "].orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl
+                       << "av[" << i << "].orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
+              break;
             case avDiff:
             case avDiffLag:
               if (aux_vars[i].get_orig_symb_id() >= 0)
@@ -579,10 +587,14 @@ SymbolTable::writeCCOutput(ostream &output) const noexcept(false)
         case avEndoLag:
         case avExoLag:
         case avVarModel:
-        case avUnaryOp:
           output << "av" << i << ".orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl
                  << "av" << i << ".orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
           break;
+        case avUnaryOp:
+          if (aux_vars[i].get_orig_symb_id() >= 0)
+            output << "av" << i << ".orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl
+                   << "av" << i << ".orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
+          break;
         case avDiff:
         case avDiffLag:
           if (aux_vars[i].get_orig_symb_id() >= 0)
@@ -1098,10 +1110,16 @@ SymbolTable::writeJuliaOutput(ostream &output) const noexcept(false)
             case avEndoLag:
             case avExoLag:
             case avVarModel:
-            case avUnaryOp:
               output << getTypeSpecificID(aux_var.get_orig_symb_id()) + 1 << ", "
                      << aux_var.get_orig_lead_lag() << ", typemin(Int), string()";
               break;
+            case avUnaryOp:
+              if (aux_var.get_orig_symb_id() >= 0)
+                output << getTypeSpecificID(aux_var.get_orig_symb_id()) + 1 << ", " << aux_var.get_orig_lead_lag();
+              else
+                output << "typemin(Int), typemin(Int)";
+              output << ", typemin(Int), string()";
+              break;
             case avDiff:
             case avDiffLag:
               if (aux_var.get_orig_symb_id() >= 0)
diff --git a/src/SymbolTable.hh b/src/SymbolTable.hh
index 06fe9502..0ccdc1c1 100644
--- a/src/SymbolTable.hh
+++ b/src/SymbolTable.hh
@@ -295,7 +295,7 @@ public:
   //! Takes care of timing between diff statements
   int addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false);
   //! An Auxiliary variable for a unary op
-  int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false);
+  int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id = -1, int orig_lag = 0) noexcept(false);
   //! Returns the number of auxiliary variables
   int
   AuxVarsSize() const
-- 
GitLab