From 2cd3aa95cc954eb8c208963b397313d297f9f515 Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Tue, 5 Jun 2018 16:38:37 +0200
Subject: [PATCH] When `transform_unary_ops` is passed, only substitute unary
 operators that appear in VAR equations

---
 src/DynamicModel.cc | 24 ++++++++++++++++++++----
 src/DynamicModel.hh |  2 +-
 src/ExprNode.cc     | 36 ++++++++++++++++++++----------------
 src/ModFile.cc      | 20 +++++++++++++++++---
 4 files changed, 58 insertions(+), 24 deletions(-)

diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc
index 117c48a9..535e7a08 100644
--- a/src/DynamicModel.cc
+++ b/src/DynamicModel.cc
@@ -5378,15 +5378,31 @@ DynamicModel::substituteAdl()
 }
 
 void
-DynamicModel::substituteUnaryOps(StaticModel &static_model)
+DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_model_eqtags)
 {
   diff_table_t nodes;
+  vector<int> eqnumber;
+  for (auto & eqtag : var_model_eqtags)
+    for (const auto & equation_tag : equation_tags)
+      if (equation_tag.second.first == "name"
+          && equation_tag.second.second == eqtag)
+        {
+          eqnumber.push_back(equation_tag.first);
+          break;
+        }
+
   // Find matching unary ops that may be outside of diffs (i.e., those with different lags)
+  set<int> used_local_vars;
+  for (int eqnn : eqnumber)
+    equations[eqnn]->collectVariables(eModelLocalVariable, used_local_vars);
+
+  // Only substitute unary ops in model local variables that appear in VAR equations
   for (auto & it : local_variables_table)
-    it.second->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
+    if (used_local_vars.find(it.first) != used_local_vars.end())
+      it.second->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
 
-  for (auto & equation : equations)
-    equation->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
+  for (int eqnn : eqnumber)
+    equations[eqnn]->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
 
   // Substitute in model local variables
   ExprNode::subst_table_t subst_table;
diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh
index 6ae70ab4..36141a2e 100644
--- a/src/DynamicModel.hh
+++ b/src/DynamicModel.hh
@@ -420,7 +420,7 @@ public:
   void substituteAdl();
 
   //! Creates aux vars for certain unary operators: originally implemented for support of VARs
-  void substituteUnaryOps(StaticModel &static_model);
+  void substituteUnaryOps(StaticModel &static_model, set<string> &eq_tags);
 
   //! Substitutes diff operator
   void substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table);
diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 191c3a59..9ab518f1 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -3204,22 +3204,26 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod
   VariableNode *aux_var = nullptr;
   for (auto rit = it->second.rbegin();
        rit != it->second.rend(); rit++)
-    {
-      if (rit == it->second.rbegin())
-        {
-          auto *vn = dynamic_cast<VariableNode *>(const_cast<UnaryOpNode *>(this)->get_arg());
-          int symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, const_cast<UnaryOpNode *>(this),
-                                                                               vn->get_symb_id(), vn->get_lag());
-          aux_var = datatree.AddVariable(symb_id, 0);
-          neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var, dynamic_cast<UnaryOpNode *>(rit->second))));
-          subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var);
-        }
-      else
-        {
-          auto *vn = dynamic_cast<VariableNode *>(dynamic_cast<UnaryOpNode *>(rit->second)->get_arg());
-          subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(-vn->get_lag()));
-        }
-    }
+    if (rit == it->second.rbegin())
+      {
+        auto *vn = dynamic_cast<VariableNode *>(const_cast<UnaryOpNode *>(this)->get_arg());
+        if (vn == nullptr)
+          {
+            cerr << "ERROR: You can only use a unary op on a variable node or another unary op node within a VAR." << endl;
+            exit(EXIT_FAILURE);
+          }
+        int symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, const_cast<UnaryOpNode *>(this),
+                                                                   vn->get_symb_id(), vn->get_lag());
+        aux_var = datatree.AddVariable(symb_id, 0);
+        neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var,
+                                                                        dynamic_cast<UnaryOpNode *>(rit->second))));
+        subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var);
+      }
+    else
+      {
+        auto *vn = dynamic_cast<VariableNode *>(dynamic_cast<UnaryOpNode *>(rit->second)->get_arg());
+        subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(-vn->get_lag()));
+      }
 
   sit = subst_table.find(this);
   return const_cast<VariableNode *>(sit->second);
diff --git a/src/ModFile.cc b/src/ModFile.cc
index 2f85f50d..323b80b1 100644
--- a/src/ModFile.cc
+++ b/src/ModFile.cc
@@ -364,16 +364,30 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
         }
     }
 
+  string var_model_name;
+  set<string> eqtags;
+  map<string, vector<string>> var_model_eq_tags;
+  map<string, pair<SymbolList, int>> var_model_info_var_expectation;
+  for (auto it = statements.begin(); it != statements.end(); it++)
+    {
+      auto *vms = dynamic_cast<VarModelStatement *>(*it);
+      if (vms != nullptr)
+        {
+          vms->getVarModelInfo(var_model_name, var_model_info_var_expectation, var_model_eq_tags);
+          for (auto & eqtag : var_model_eq_tags[var_model_name])
+            eqtags.insert(eqtag);
+        }
+    }
+
   if (transform_unary_ops)
-    dynamic_model.substituteUnaryOps(diff_static_model);
+    // substitute only those unary ops that appear in VAR equations
+    dynamic_model.substituteUnaryOps(diff_static_model, eqtags);
 
   // Create auxiliary variable and equations for Diff operator
   ExprNode::subst_table_t diff_subst_table;
   dynamic_model.substituteDiff(diff_static_model, diff_subst_table);
 
   // Var Model
-  map<string, pair<SymbolList, int>> var_model_info_var_expectation;
-  map<string, vector<string>> var_model_eq_tags;
   map<string, tuple<vector<int>, vector<expr_t>, vector<bool>, vector<int>, int, vector<bool>, vector<int>>>
     var_model_info_pac_expectation;
   for (auto it = statements.begin(); it != statements.end(); it++)
-- 
GitLab