From ad22e628749560d9c59b5f7002700a505102a234 Mon Sep 17 00:00:00 2001
From: Johannes Pfeifer <jpfeifer@gmx.de>
Date: Mon, 18 Jan 2021 16:42:20 +0100
Subject: [PATCH] filter_initial_state: Port lost preprocessor commits from
 https://git.dynare.org/JohannesPfeifer/dynare/-/commits/filter_initial_state

---
 src/ComputingTasks.cc | 65 +++++++++++++++++++++++++++++++++++++++++++
 src/ComputingTasks.hh | 14 ++++++++++
 src/DynareBison.yy    | 11 +++++++-
 src/DynareFlex.ll     |  1 +
 src/ParsingDriver.cc  | 32 +++++++++++++++++++++
 src/ParsingDriver.hh  |  6 ++++
 6 files changed, 128 insertions(+), 1 deletion(-)

diff --git a/src/ComputingTasks.cc b/src/ComputingTasks.cc
index aebdf321..f0f010d5 100644
--- a/src/ComputingTasks.cc
+++ b/src/ComputingTasks.cc
@@ -1700,6 +1700,71 @@ ObservationTrendsStatement::writeJsonOutput(ostream &output) const
          << "}";
 }
 
+FilterInitialStateStatement::FilterInitialStateStatement(filter_initial_state_elements_t filter_initial_state_elements_arg,
+                                                         const SymbolTable &symbol_table_arg) :
+  filter_initial_state_elements{move(filter_initial_state_elements_arg)},
+  symbol_table{symbol_table_arg}
+{
+}
+
+void
+FilterInitialStateStatement::writeOutput(ostream &output, const string &basename, bool minimal_workspace) const
+{
+  output << "M_.filter_initial_state = cell(M_.endo_nbr, 2);" << endl;
+  for (const auto &[key, val] : filter_initial_state_elements)
+    {
+      auto [symb_id, lag] = key;
+      SymbolType type = symbol_table.getType(symb_id);
+
+      if ((type == SymbolType::endogenous && lag < 0) || type == SymbolType::exogenous)
+        {
+          try
+            {
+              // This function call must remain the 1st statement in this block
+              symb_id = symbol_table.searchAuxiliaryVars(symb_id, lag);
+            }
+          catch (SymbolTable::SearchFailedException &e)
+            {
+              if (type == SymbolType::endogenous)
+                {
+                  cerr << "filter_initial_state: internal error, please contact the developers";
+                  exit(EXIT_FAILURE);
+                }
+              // We don't fail for exogenous, because they are not replaced by
+              // auxiliary variables in deterministic mode.
+            }
+        }
+
+      output << "M_.filter_initial_state("
+             << symbol_table.getTypeSpecificID(symb_id) + 1
+             << ",:) = {'" << symbol_table.getName(symb_id) << "', '";
+      val->writeOutput(output);
+      output << ";'};" << endl;
+    }
+}
+
+void
+FilterInitialStateStatement::writeJsonOutput(ostream &output) const
+{
+  output << R"({"statementName": "filter_initial_state", )"
+         << R"("states": [)";
+
+  for (auto it = filter_initial_state_elements.begin();
+       it != filter_initial_state_elements.end(); ++it)
+    {
+      if (it != filter_initial_state_elements.begin())
+        output << ", ";
+      auto &[key, val] = *it;
+      auto &[symb_id, lag] = key;
+      output << R"({ "var": ")" << symbol_table.getName(symb_id)
+             << R"(", "lag": )" << lag
+             << R"(, "value": ")";
+      val->writeJsonOutput(output, {}, {});
+      output << R"(" })";
+    }
+  output << "] }";
+}
+
 OsrParamsStatement::OsrParamsStatement(SymbolList symbol_list_arg, const SymbolTable &symbol_table_arg) :
   symbol_list{move(symbol_list_arg)},
   symbol_table{symbol_table_arg}
diff --git a/src/ComputingTasks.hh b/src/ComputingTasks.hh
index aac12ade..69ab9d70 100644
--- a/src/ComputingTasks.hh
+++ b/src/ComputingTasks.hh
@@ -301,6 +301,20 @@ public:
   void writeJsonOutput(ostream &output) const override;
 };
 
+class FilterInitialStateStatement : public Statement
+{
+public:
+  using filter_initial_state_elements_t = map<pair<int, int>, expr_t>;
+private:
+  const filter_initial_state_elements_t filter_initial_state_elements;
+  const SymbolTable &symbol_table;
+public:
+  FilterInitialStateStatement(filter_initial_state_elements_t filter_initial_state_elements_arg,
+                              const SymbolTable &symbol_table_arg);
+  void writeOutput(ostream &output, const string &basename, bool minimal_workspace) const override;
+  void writeJsonOutput(ostream &output) const override;
+};
+
 class OsrParamsStatement : public Statement
 {
 private:
diff --git a/src/DynareBison.yy b/src/DynareBison.yy
index 44d14904..bbcaa96e 100644
--- a/src/DynareBison.yy
+++ b/src/DynareBison.yy
@@ -100,7 +100,7 @@ class ParsingDriver;
 %token MODE_CHECK MODE_CHECK_NEIGHBOURHOOD_SIZE MODE_CHECK_SYMMETRIC_PLOTS MODE_CHECK_NUMBER_OF_POINTS MODE_COMPUTE MODE_FILE MODEL MODEL_COMPARISON MODEL_INFO MSHOCKS ABS SIGN
 %token MODEL_DIAGNOSTICS MODIFIEDHARMONICMEAN MOMENTS_VARENDO CONTEMPORANEOUS_CORRELATION DIFFUSE_FILTER SUB_DRAWS TAPER_STEPS GEWEKE_INTERVAL RAFTERY_LEWIS_QRS RAFTERY_LEWIS_DIAGNOSTICS MCMC_JUMPING_COVARIANCE MOMENT_CALIBRATION
 %token NUMBER_OF_PARTICLES RESAMPLING SYSTEMATIC GENERIC RESAMPLING_THRESHOLD RESAMPLING_METHOD KITAGAWA STRATIFIED SMOOTH
-%token CPF_WEIGHTS AMISANOTRISTANI MURRAYJONESPARSLOW WRITE_EQUATION_TAGS
+%token CPF_WEIGHTS AMISANOTRISTANI MURRAYJONESPARSLOW WRITE_EQUATION_TAGS FILTER_INITIAL_STATE
 %token NONLINEAR_FILTER_INITIALIZATION FILTER_ALGORITHM PROPOSAL_APPROXIMATION CUBATURE UNSCENTED MONTECARLO DISTRIBUTION_APPROXIMATION
 %token <string> NAME
 %token USE_PENALIZED_OBJECTIVE_FOR_HESSIAN INIT_STATE FAST_REALTIME RESCALE_PREDICTION_ERROR_COVARIANCE GENERATE_IRFS
@@ -244,6 +244,7 @@ statement : parameters
           | options_eq
           | varobs
           | observation_trends
+          | filter_initial_state
           | varexobs
           | unit_root_vars
           | dsample
@@ -2072,6 +2073,14 @@ trend_list : trend_list trend_element
 
 trend_element :  symbol '(' expression ')' ';' { driver.set_trend_element($1, $3); };
 
+filter_initial_state : FILTER_INITIAL_STATE ';' filter_initial_state_list END ';' { driver.set_filter_initial_state(); };
+
+filter_initial_state_list : filter_initial_state_list filter_initial_state_element
+                          | filter_initial_state_element
+                          ;
+
+filter_initial_state_element : symbol '(' signed_integer ')' EQUAL expression ';' { driver.set_filter_initial_state_element($1, $3, $6); };
+
 unit_root_vars : UNIT_ROOT_VARS symbol_list ';' { driver.set_unit_root_vars(); };
 
 optim_weights : OPTIM_WEIGHTS ';' optim_weights_list END ';' { driver.optim_weights(); };
diff --git a/src/DynareFlex.ll b/src/DynareFlex.ll
index 390e87bf..fabfb87b 100644
--- a/src/DynareFlex.ll
+++ b/src/DynareFlex.ll
@@ -200,6 +200,7 @@ DATE -?[0-9]+([ya]|m([1-9]|1[0-2])|q[1-4])
 <INITIAL>initval {BEGIN DYNARE_BLOCK; return token::INITVAL;}
 <INITIAL>endval {BEGIN DYNARE_BLOCK; return token::ENDVAL;}
 <INITIAL>histval {BEGIN DYNARE_BLOCK; return token::HISTVAL;}
+<INITIAL>filter_initial_state {BEGIN DYNARE_BLOCK; return token::FILTER_INITIAL_STATE;}
 <INITIAL>shocks {BEGIN DYNARE_BLOCK; return token::SHOCKS;}
 <INITIAL>shock_groups {BEGIN DYNARE_BLOCK; return token::SHOCK_GROUPS;}
 <INITIAL>init2shocks {BEGIN DYNARE_BLOCK; return token::INIT2SHOCKS;}
diff --git a/src/ParsingDriver.cc b/src/ParsingDriver.cc
index e9518117..9b0695c7 100644
--- a/src/ParsingDriver.cc
+++ b/src/ParsingDriver.cc
@@ -1935,6 +1935,13 @@ ParsingDriver::set_trends()
   trend_elements.clear();
 }
 
+void
+ParsingDriver::set_filter_initial_state()
+{
+  mod_file->addStatement(make_unique<FilterInitialStateStatement>(filter_initial_state_elements, mod_file->symbol_table));
+  filter_initial_state_elements.clear();
+}
+
 void
 ParsingDriver::set_trend_element(string arg1, expr_t arg2)
 {
@@ -1944,6 +1951,31 @@ ParsingDriver::set_trend_element(string arg1, expr_t arg2)
   trend_elements[move(arg1)] = arg2;
 }
 
+void
+ParsingDriver::set_filter_initial_state_element(const string &name, const string &lag, expr_t rhs)
+{
+  check_symbol_existence(name);
+  int symb_id = mod_file->symbol_table.getID(name);
+  SymbolType type = mod_file->symbol_table.getType(symb_id);
+  int ilag = stoi(lag);
+
+  if (type != SymbolType::endogenous
+      && type != SymbolType::exogenous
+      && type != SymbolType::exogenousDet)
+    error("filter_initial_state: " + name + " should be an endogenous or exogenous variable");
+
+  if ((type == SymbolType::exogenous || type == SymbolType::exogenousDet) && ilag == 0)
+    error("filter_initial_state: exogenous variable " + name + " must be provided with a lag");
+
+  if (filter_initial_state_elements.find({ symb_id, ilag }) != filter_initial_state_elements.end())
+    error("filter_initial_state: (" + name + ", " + lag + ") declared twice");
+
+  if (mod_file->dynamic_model.minLagForSymbol(symb_id) > ilag - 1)
+    error("filter_initial_state: variable " + name + " does not appear in the model with the lag " + to_string(ilag-1) + " (see the reference manual for the timing convention in 'filter_initial_state')");
+
+  filter_initial_state_elements[{ symb_id, ilag }] = rhs;
+}
+
 void
 ParsingDriver::set_optim_weights(string name, expr_t value)
 {
diff --git a/src/ParsingDriver.hh b/src/ParsingDriver.hh
index 04b791b4..ed4ef240 100644
--- a/src/ParsingDriver.hh
+++ b/src/ParsingDriver.hh
@@ -141,6 +141,8 @@ private:
   OptionsList options_list;
   //! Temporary storage for trend elements
   ObservationTrendsStatement::trend_elements_t trend_elements;
+  //! Temporary storage for filter_initial_state elements
+  FilterInitialStateStatement::filter_initial_state_elements_t filter_initial_state_elements;
   //! Temporary storage for filename list of ModelComparison (contains weights)
   ModelComparisonStatement::filename_list_t filename_list;
   //! Temporary storage for list of EstimationParams (from estimated_params* statements)
@@ -610,6 +612,10 @@ public:
   void forecast();
   void set_trends();
   void set_trend_element(string arg1, expr_t arg2);
+  //! filter_initial_state block
+  void set_filter_initial_state();
+  //! element for filter_initial_state block
+  void set_filter_initial_state_element(const string &name, const string &lag, expr_t rhs);
   void set_unit_root_vars();
   void optim_weights();
   void set_optim_weights(string name, expr_t value);
-- 
GitLab