From a0f74f5c16b9a1826069797a5fa9ebce2ae7777f Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Mon, 13 May 2019 17:43:16 +0200
Subject: [PATCH] pac growth: allow linear combination

---
 src/ComputingTasks.cc | 136 ++++++++++++------------------------------
 src/ComputingTasks.hh |   4 +-
 src/DynamicModel.cc   |   9 ++-
 src/DynamicModel.hh   |   2 +-
 src/ModFile.cc        |   2 +-
 5 files changed, 46 insertions(+), 107 deletions(-)

diff --git a/src/ComputingTasks.cc b/src/ComputingTasks.cc
index 814eec45..fa84b3c7 100644
--- a/src/ComputingTasks.cc
+++ b/src/ComputingTasks.cc
@@ -275,66 +275,25 @@ PacModelStatement::PacModelStatement(string name_arg,
   original_growth{growth_arg},
   steady_state_growth_rate_number{steady_state_growth_rate_number_arg},
   steady_state_growth_rate_symb_id{steady_state_growth_rate_symb_id_arg},
-  symbol_table{symbol_table_arg},
-  growth_symb_id{-1},
-  growth_lag{0}
+  symbol_table{symbol_table_arg}
 {
 }
 
-void
-PacModelStatement::checkPass(ModFileStructure &mod_file_struct, WarningConsolidation &warnings)
-{
-  if (growth == nullptr)
-    return;
-
-  auto *vn = dynamic_cast<VariableNode *>(growth);
-  if (vn != nullptr)
-    {
-      mod_file_struct.pac_params.insert(vn->symb_id);
-      mod_file_struct.pac_params.insert(vn->lag);
-    }
-
-  auto *uon = dynamic_cast<UnaryOpNode *>(growth);
-  if (uon != nullptr)
-    if (uon->op_code == UnaryOpcode::diff)
-      {
-        auto *uonvn = dynamic_cast<VariableNode *>(uon->arg);
-        auto *uon1 = dynamic_cast<UnaryOpNode *>(uon->arg);
-        while (uonvn == nullptr && uon1 != nullptr)
-          {
-            uonvn = dynamic_cast<VariableNode *>(uon1->arg);
-            uon1 = dynamic_cast<UnaryOpNode *>(uon1->arg);
-          }
-        if (uonvn == nullptr)
-          {
-            cerr << "Pac growth parameter must be either a variable or a diff unary op of a variable" << endl;
-            exit(EXIT_FAILURE);
-          }
-         mod_file_struct.pac_params.insert(uonvn->symb_id);
-         mod_file_struct.pac_params.insert(uonvn->lag);
-      }
-
-  if (vn == nullptr && uon == nullptr)
-    {
-      cerr << "Pac growth parameter must be either a variable or a diff unary op of a variable" << endl;
-      exit(EXIT_FAILURE);
-    }
-}
-
 void
 PacModelStatement::overwriteGrowth(expr_t new_growth)
 {
   if (new_growth == nullptr || growth == nullptr)
     return;
-
   growth = new_growth;
-  auto *vn = dynamic_cast<VariableNode *>(growth);
-  if (vn == nullptr)
+  try
+    {
+      growth_info = growth->matchLinearCombinationOfVariables();
+    }
+  catch (ExprNode::MatchFailureException &e)
     {
-      cerr << "PacModelStatement::overwriteGrowth: Internal Dynare error: should not arrive here" << endl;
+      cerr << "Pac growth must be a linear combination of varibles" << endl;
+      exit(EXIT_FAILURE);
     }
-  growth_symb_id = vn->symb_id;
-  growth_lag = vn->lag;
 }
 
 void
@@ -349,8 +308,20 @@ PacModelStatement::writeOutput(ostream &output, const string &basename, bool min
     output << "M_.pac." << name << ".steady_state_growth_rate = "
            << symbol_table.getTypeSpecificID(steady_state_growth_rate_symb_id) + 1 << ";" << endl;
 
-  if (growth_symb_id >= 0)
+  size_t nlc = growth_info.size();
+  output << "M_.pac." << name << ".growth_index = repmat(-1, " << nlc << ", 1);" << endl
+         << "M_.pac." << name << ".growth_lag = zeros(" << nlc << ", 1);" << endl
+         << "M_.pac." << name << ".growth_param_id = repmat(-1, " << nlc << ", 1);" << endl
+         << "M_.pac." << name << ".growth_constant = zeros(" << nlc << ", 1);" << endl
+         << "M_.pac." << name << ".growth_type = repmat({''}, " <<  nlc << ", 1);" << endl
+         << "M_.pac." << name << ".growth_part_str = repmat({''}, " <<  nlc << ", 1);" << endl;
+  int i = 0;
+  for (auto & it : growth_info)
     {
+      i++;
+      int growth_symb_id, growth_lag, param_id = -1;
+      double constant = 0;
+      tie(growth_symb_id, growth_lag, param_id, constant) = it;
       string growth_type;
       switch (symbol_table.getType(growth_symb_id))
         {
@@ -372,12 +343,8 @@ PacModelStatement::writeOutput(ostream &output, const string &basename, bool min
         {
           // case when this is not the highest lag of the growth variable
           int aux_symb_id = symbol_table.searchAuxiliaryVars(growth_symb_id, growth_lag);
-          output << "M_.pac." << name << ".growth_index = " << symbol_table.getTypeSpecificID(aux_symb_id) + 1 << ";" << endl
-                 << "M_.pac." << name << ".growth_lag = 0;" << endl
-                 << "M_.pac." << name << ".growth_type = '" << growth_type << "';" << endl
-                 << "M_.pac." << name << ".growth_str = '";
-          original_growth->writeJsonOutput(output, {}, {}, true);
-          output << "';" << endl;
+          output << "M_.pac." << name << ".growth_index(" << i << ") = " << symbol_table.getTypeSpecificID(aux_symb_id) + 1 << ";" << endl
+                 << "M_.pac." << name << ".growth_lag(" << i << ") = 0;" << endl;
         }
       catch (...)
         {
@@ -386,25 +353,25 @@ PacModelStatement::writeOutput(ostream &output, const string &basename, bool min
               // case when this is the highest lag of the growth variable
               int tmp_growth_lag = growth_lag + 1;
               int aux_symb_id = symbol_table.searchAuxiliaryVars(growth_symb_id, tmp_growth_lag);
-              output << "M_.pac." << name << ".growth_index = " << symbol_table.getTypeSpecificID(aux_symb_id) + 1 << ";" << endl
-                     << "M_.pac." << name << ".growth_lag = -1;" << endl
-                     << "M_.pac." << name << ".growth_type = '" << growth_type << "';" << endl
-                     << "M_.pac." << name << ".growth_str = '";
-              original_growth->writeJsonOutput(output, {}, {}, true);
-              output << "';" << endl;
+              output << "M_.pac." << name << ".growth_index(" << i << ") = " << symbol_table.getTypeSpecificID(aux_symb_id) + 1 << ";" << endl
+                     << "M_.pac." << name << ".growth_lag(" << i << ") = -1;" << endl;
             }
           catch (...)
             {
               // case when there is no aux var for the variable
-              output << "M_.pac." << name << ".growth_index = " << symbol_table.getTypeSpecificID(growth_symb_id) + 1 << ";" << endl
-                     << "M_.pac." << name << ".growth_lag = " << growth_lag << ";" << endl
-                     << "M_.pac." << name << ".growth_type = '" << growth_type << "';" << endl
-                     << "M_.pac." << name << ".growth_str = '";
-              original_growth->writeJsonOutput(output, {}, {}, true);
-              output << "';" << endl;
+              output << "M_.pac." << name << ".growth_index(" << i << ") = " << symbol_table.getTypeSpecificID(growth_symb_id) + 1 << ";" << endl
+                     << "M_.pac." << name << ".growth_lag(" << i << ") = " << growth_lag << ";" << endl;
             }
         }
+
+      output << "M_.pac." << name << ".growth_param_id(" << i << ") = "
+             << (param_id == -1 ? -1 : symbol_table.getTypeSpecificID(param_id)) + 1 << ";" << endl
+             << "M_.pac." << name << ".growth_constant(" << i << ") = " << constant << ";" << endl
+             << "M_.pac." << name << ".growth_type{" << i << "} = '" << growth_type << "';" << endl;
     }
+  output << "M_.pac." << name << ".growth_str = '";
+  original_growth->writeJsonOutput(output, {}, {}, true);
+  output << "';" << endl;
 }
 
 void
@@ -413,35 +380,10 @@ PacModelStatement::writeJsonOutput(ostream &output) const
   output << R"({"statementName": "pac_model",)"
          << R"("model_name": ")" << name << R"(",)"
          << R"("auxiliary_model_name": ")" << aux_model_name << R"(",)"
-         << R"("discount_index": )" << symbol_table.getTypeSpecificID(discount) + 1;
-
-  if (growth_symb_id >= 0)
-    {
-      string growth_type;
-      switch (symbol_table.getType(growth_symb_id))
-        {
-        case SymbolType::endogenous:
-          growth_type = "endogenous";
-          break;
-        case SymbolType::exogenous:
-          growth_type = "exogenous";
-          break;
-        case SymbolType::parameter:
-          growth_type = "parameter";
-          break;
-        default:
-          {
-          }
-        }
-      output << ","
-             << R"("growth_index": )" << symbol_table.getTypeSpecificID(growth_symb_id) + 1 << ","
-             << R"("growth_lag": )" << growth_lag << ","
-             << R"("growth_type": ")" << growth_type << R"(",)" << endl
-             << R"("growth_str": ")";
-      original_growth->writeJsonOutput(output, {}, {}, true);
-      output << R"(")" << endl;
-    }
-  output << "}";
+         << R"("discount_index": )" << symbol_table.getTypeSpecificID(discount) + 1
+         << R"("growth_str": ")";
+  original_growth->writeJsonOutput(output, {}, {}, true);
+  output << R"("})" << endl;
 }
 
 VarEstimationStatement::VarEstimationStatement(OptionsList options_list_arg) :
diff --git a/src/ComputingTasks.hh b/src/ComputingTasks.hh
index 24e6de75..d6c19e4e 100644
--- a/src/ComputingTasks.hh
+++ b/src/ComputingTasks.hh
@@ -143,9 +143,8 @@ private:
   const double steady_state_growth_rate_number;
   const int steady_state_growth_rate_symb_id;
   const SymbolTable &symbol_table;
+  vector<tuple<int, int, int, double>> growth_info;
 public:
-  int growth_symb_id;
-  int growth_lag;
   PacModelStatement(string name_arg,
                     string aux_model_name_arg,
                     string discount_arg,
@@ -153,7 +152,6 @@ public:
                     double steady_state_growth_rate_number_arg,
                     int steady_state_growth_rate_symb_id_arg,
                     const SymbolTable &symbol_table_arg);
-  void checkPass(ModFileStructure &mod_file_struct, WarningConsolidation &warnings) override;
   void overwriteGrowth(expr_t new_growth);
   void writeOutput(ostream &output, const string &basename, bool minimal_workspace) const override;
   void writeJsonOutput(ostream &output) const override;
diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc
index 81917191..fd2b5a8d 100644
--- a/src/DynamicModel.cc
+++ b/src/DynamicModel.cc
@@ -4729,7 +4729,7 @@ DynamicModel::fillPacModelInfo(const string &pac_model_name,
                                string aux_model_type,
                                const map<pair<string, string>, pair<string, int>> &eqtag_and_lag,
                                const vector<bool> &nonstationary,
-                               int growth_symb_id, int growth_lag)
+                               expr_t growth)
 {
   pac_eqtag_and_lag.insert(eqtag_and_lag.begin(), eqtag_and_lag.end());
 
@@ -4745,7 +4745,7 @@ DynamicModel::fillPacModelInfo(const string &pac_model_name,
         stationary_vars_present = true;
 
   int growth_param_index = -1;
-  if (growth_symb_id >= 0)
+  if (growth != nullptr)
     growth_param_index = symbol_table.addSymbol(pac_model_name +
                                                 "_pac_growth_neutrality_correction",
                                                 SymbolType::parameter);
@@ -4789,10 +4789,9 @@ DynamicModel::fillPacModelInfo(const string &pac_model_name,
                                          AddVariable(lhsit, -i)));
             }
 
-      if (growth_symb_id >= 0)
+      if (growth != nullptr)
         subExpr = AddPlus(subExpr,
-                          AddTimes(AddVariable(growth_param_index),
-                                   AddVariable(growth_symb_id, growth_lag)));
+                          AddTimes(AddVariable(growth_param_index), growth));
 
       pac_expectation_substitution[{pac_model_name, eqtag}] = subExpr;
     }
diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh
index 324fa8b7..ddd991fe 100644
--- a/src/DynamicModel.hh
+++ b/src/DynamicModel.hh
@@ -352,7 +352,7 @@ public:
                         string aux_model_type,
                         const map<pair<string, string>, pair<string, int>> &eqtag_and_lag,
                         const vector<bool> &nonstationary,
-                        int growth_symb_id, int growth_lag);
+                        expr_t growth);
 
   //! Substitutes pac_expectation operator with expectation based on auxiliary model
   void substitutePacExpectation(const string & pac_model_name);
diff --git a/src/ModFile.cc b/src/ModFile.cc
index 159bf6b6..d395a38b 100644
--- a/src/ModFile.cc
+++ b/src/ModFile.cc
@@ -471,7 +471,7 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
                                                                     eqtag_and_lag, diff_subst_table);
            else
              dynamic_model.fillPacModelInfo(pms->name, lhs, max_lag, aux_model_type,
-                                            eqtag_and_lag, nonstationary, pms->growth_symb_id, pms->growth_lag);
+                                            eqtag_and_lag, nonstationary, pms->growth);
            dynamic_model.substitutePacExpectation(pms->name);
          }
      }
-- 
GitLab