From 7a438a3ce799e06b903317ba76639f04324483d6 Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Mon, 3 Sep 2018 15:05:30 +0200
Subject: [PATCH] ensure that trend_variable found is actually a trend variable
 as declared in the trend_component_model statement

---
 src/DynamicModel.cc | 55 +++++++++++++++++++++++++++++++++++++--------
 src/DynamicModel.hh |  5 +++--
 src/ModFile.cc      |  2 +-
 src/SubModel.cc     |  6 +++++
 src/SubModel.hh     |  1 +
 5 files changed, 57 insertions(+), 12 deletions(-)

diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc
index 5282909a..38c88eca 100644
--- a/src/DynamicModel.cc
+++ b/src/DynamicModel.cc
@@ -3456,31 +3456,48 @@ DynamicModel::runTrendTest(const eval_context_t &eval_context)
 }
 
 void
-DynamicModel::updateVarAndTrendModelRhs() const
+DynamicModel::updateVarAndTrendModel() const
 {
   for (int i = 0; i < 2; i++)
     {
-      map<string, vector<int>> eqnums;
+      map<string, vector<int>> eqnums, trend_eqnums;
       if (i == 0)
         eqnums = var_model_table.getEqNums();
       else if (i == 1)
-        eqnums = trend_component_model_table.getEqNums();
+        {
+          eqnums = trend_component_model_table.getEqNums();
+          trend_eqnums = trend_component_model_table.getTrendEqNums();
+        }
 
       map<string, vector<int>> trend_varr;
       map<string, vector<set<pair<int, int>>>> rhsr;
       for (const auto & it : eqnums)
         {
-          vector<int> lhs;
-          vector<int> trend_var;
+          vector<int> lhs, trend_var, trend_lhs;
           vector<set<pair<int, int>>> rhs;
-          int lhs_idx = 0;
+
           if (i == 1)
-            lhs = trend_component_model_table.getLhs(it.first);
+            {
+              lhs = trend_component_model_table.getLhs(it.first);
+              for (auto teqn : trend_eqnums.at(it.first))
+                {
+                  int eqnidx = 0;
+                  for (auto eqn : it.second)
+                    {
+                      if (eqn == teqn)
+                        trend_lhs.push_back(lhs[eqnidx]);
+                      eqnidx++;
+                    }
+                }
+            }
+
+          int lhs_idx = 0;
           for (auto eqn : it.second)
             {
               set<pair<int, int>> rhs_set;
               equations[eqn]->get_arg2()->collectDynamicVariables(SymbolType::endogenous, rhs_set);
               rhs.push_back(rhs_set);
+
               if (i == 1)
                 {
                   int lhs_symb_id = lhs[lhs_idx++];
@@ -3492,11 +3509,31 @@ DynamicModel::updateVarAndTrendModelRhs() const
                     catch (...)
                       {
                       }
-                  trend_var.push_back(equations[eqn]->get_arg2()->findTrendVariable(lhs_symb_id));
+                  int trend_var_symb_id = equations[eqn]->get_arg2()->findTrendVariable(lhs_symb_id);
+                  trend_var.push_back(trend_var_symb_id);
+                  if (trend_var_symb_id >= 0)
+                    {
+                      if (symbol_table.isAuxiliaryVariable(trend_var_symb_id))
+                        try
+                          {
+                            trend_var_symb_id = symbol_table.getOrigSymbIdForAuxVar(trend_var_symb_id);
+                          }
+                        catch (...)
+                          {
+                          }
+                      if (find(trend_lhs.begin(), trend_lhs.end(), trend_var_symb_id) == trend_lhs.end())
+                        {
+                          cerr << "ERROR: trend found in trend_component equation #" << eqn << " ("
+                               << symbol_table.getName(trend_var_symb_id) << ") does not correspond to a trend equation" << endl;
+                          exit(EXIT_FAILURE);
+                        }
+                    }
                 }
             }
+
           rhsr[it.first] = rhs;
-          trend_varr[it.first] = trend_var;
+          if (i == 1)
+            trend_varr[it.first] = trend_var;
         }
 
       if (i == 0)
diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh
index 30bdc162..9224149b 100644
--- a/src/DynamicModel.hh
+++ b/src/DynamicModel.hh
@@ -313,8 +313,9 @@ public:
   void fillVarModelTableFromOrigModel(StaticModel &static_model) const;
 
   //! Update the rhs references in the var model and trend component tables
-  //! after substitution of auxiliary variables
-  void updateVarAndTrendModelRhs() const;
+  //! after substitution of auxiliary variables and find the trend variables
+  //! in the trend_component model
+  void updateVarAndTrendModel() const;
 
   //! Add aux equations (and aux variables) for variables declared in var_model
   //! at max order if they don't already exist
diff --git a/src/ModFile.cc b/src/ModFile.cc
index ee3512c0..e7183d5d 100644
--- a/src/ModFile.cc
+++ b/src/ModFile.cc
@@ -580,7 +580,7 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
       dynamic_model.substituteEndoLagGreaterThanTwo(true);
     }
 
-  dynamic_model.updateVarAndTrendModelRhs();
+  dynamic_model.updateVarAndTrendModel();
 
   if (differentiate_forward_vars)
     dynamic_model.differentiateForwardVars(differentiate_forward_vars_subset);
diff --git a/src/SubModel.cc b/src/SubModel.cc
index 62613b39..f392c926 100644
--- a/src/SubModel.cc
+++ b/src/SubModel.cc
@@ -177,6 +177,12 @@ TrendComponentModelTable::getEqNums() const
   return eqnums;
 }
 
+map<string, vector<int>>
+TrendComponentModelTable::getTrendEqNums() const
+{
+  return trend_eqnums;
+}
+
 vector<int>
 TrendComponentModelTable::getNonTrendEqNums(const string &name_arg) const
 {
diff --git a/src/SubModel.hh b/src/SubModel.hh
index e5cba518..96432780 100644
--- a/src/SubModel.hh
+++ b/src/SubModel.hh
@@ -59,6 +59,7 @@ public:
   vector<string> getEqTags(const string &name_arg) const;
   map<string, vector<string>> getTrendEqTags() const;
   map<string, vector<int>> getEqNums() const;
+  map<string, vector<int>> getTrendEqNums() const;
   vector<int> getEqNums(const string &name_arg) const;
   vector<int> getMaxLags(const string &name_arg) const;
   int getMaxLag(const string &name_arg) const;
-- 
GitLab