From 520876560d5bb4c886a11de681557dd05cbc200e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Tue, 19 Dec 2023 16:09:41 +0100
Subject: [PATCH] New matched_irfs and matched_irfs_weights blocks

Closes: #124
---
 src/ComputingTasks.cc | 100 ++++++++++++++++++++++++++++++++++++
 src/ComputingTasks.hh |  29 +++++++++++
 src/DynareBison.yy    | 114 +++++++++++++++++++++++++++++++++++++++++-
 src/DynareFlex.ll     |   5 +-
 src/ModFile.cc        |   8 +--
 src/ParsingDriver.cc  |  13 +++++
 src/ParsingDriver.hh  |   7 +++
 7 files changed, 270 insertions(+), 6 deletions(-)

diff --git a/src/ComputingTasks.cc b/src/ComputingTasks.cc
index fa65f44c..90964908 100644
--- a/src/ComputingTasks.cc
+++ b/src/ComputingTasks.cc
@@ -5483,3 +5483,103 @@ ResidStatement::writeJsonOutput(ostream& output) const
     }
   output << "}";
 }
+
+MatchedIrfsStatement::MatchedIrfsStatement(matched_irfs_t values_weights_arg, bool overwrite_arg) :
+    values_weights {move(values_weights_arg)}, overwrite {overwrite_arg}
+{
+}
+
+void
+MatchedIrfsStatement::writeOutput(ostream& output, [[maybe_unused]] const string& basename,
+                                  [[maybe_unused]] bool minimal_workspace) const
+{
+  if (overwrite)
+    output << "M_.matched_irfs = {};" << endl;
+
+  for (const auto& [key, vec] : values_weights)
+    {
+      const auto& [endo, exo] = key;
+      output << "M_.matched_irfs = [M_.matched_irfs; {'" << endo << "', '" << exo << "', {";
+      for (const auto& [p1, p2, value, weight] : vec)
+        {
+          output << p1 << ":" << p2 << ", ";
+          value->writeOutput(output);
+          output << ", ";
+          weight->writeOutput(output);
+          output << "; ";
+        }
+      output << "}}];" << endl;
+    }
+}
+
+void
+MatchedIrfsStatement::writeJsonOutput(ostream& output) const
+{
+  output << R"({"statementName": "matched_irfs")"
+         << R"(, "overwrite": )" << boolalpha << overwrite << R"(, "contents": [)";
+  for (bool printed_something {false}; const auto& [key, vec] : values_weights)
+    {
+      if (exchange(printed_something, true))
+        output << ", ";
+      const auto& [endo, exo] = key;
+      output << R"({"var": ")" << endo << R"(", "varexo": ")" << exo
+             << R"(", "periods_values_weights": [)";
+      for (bool printed_something2 {false}; const auto& [p1, p2, value, weight] : vec)
+        {
+          if (exchange(printed_something2, true))
+            output << ", ";
+          output << R"({"period1": )" << p1 << ", "
+                 << R"("period2": })" << p2 << ", "
+                 << R"("value": ")";
+          value->writeJsonOutput(output, {}, {});
+          output << R"(", "weight": ")";
+          weight->writeJsonOutput(output, {}, {});
+          output << R"("})";
+        }
+      output << "]}";
+    }
+  output << "]}";
+}
+
+MatchedIrfsWeightsStatement::MatchedIrfsWeightsStatement(matched_irfs_weights_t weights_arg,
+                                                         bool overwrite_arg) :
+    weights {move(weights_arg)}, overwrite {overwrite_arg}
+{
+}
+
+void
+MatchedIrfsWeightsStatement::writeOutput(ostream& output, [[maybe_unused]] const string& basename,
+                                         [[maybe_unused]] bool minimal_workspace) const
+{
+  if (overwrite)
+    output << "M_.matched_irfs_weights = {};" << endl;
+
+  for (const auto& [key, val] : weights)
+    {
+      const auto& [endo1, periods1, exo1, endo2, periods2, exo2] = key;
+      output << "M_.matched_irfs_weights = [M_.matched_irfs_weights; {'" << endo1 << "', "
+             << periods1 << ", '" << exo1 << "', '" << endo2 << "', " << periods2 << ", '" << exo2
+             << "', ";
+      val->writeOutput(output);
+      output << "}];" << endl;
+    }
+}
+
+void
+MatchedIrfsWeightsStatement::writeJsonOutput(ostream& output) const
+{
+  output << R"({"statementName": "matched_irfs_weights")"
+         << R"(, "overwrite": )" << boolalpha << overwrite << R"(, "contents": [)";
+  for (bool printed_something {false}; const auto& [key, val] : weights)
+    {
+      const auto& [endo1, periods1, exo1, endo2, periods2, exo2] = key;
+      if (exchange(printed_something, true))
+        output << ", ";
+      output << R"({"endo1": ")" << endo1 << R"(", "periods1": ")" << periods1 << R"(", "exo1": )"
+             << exo1 << R"(", "endo2": ")" << endo2 << R"(", "periods2": ")" << periods2
+             << R"(", "exo2": )" << exo2 << R"(", "weight": ")";
+      val->writeJsonOutput(output, {}, {});
+      output << R"("})";
+    }
+  output << "]}";
+}
diff --git a/src/ComputingTasks.hh b/src/ComputingTasks.hh
index 24aadf9a..8909a8c7 100644
--- a/src/ComputingTasks.hh
+++ b/src/ComputingTasks.hh
@@ -1332,4 +1332,33 @@ public:
   void writeJsonOutput(ostream& output) const override;
 };
 
+class MatchedIrfsStatement : public Statement
+{
+public:
+  // (endo name, exo name) → vector of (period start, period end, value, weight)
+  using matched_irfs_t = map<pair<string, string>, vector<tuple<int, int, expr_t, expr_t>>>;
+  MatchedIrfsStatement(matched_irfs_t values_weights_arg, bool overwrite_arg);
+  void writeOutput(ostream& output, const string& basename, bool minimal_workspace) const override;
+  void writeJsonOutput(ostream& output) const override;
+
+private:
+  const matched_irfs_t values_weights;
+  const bool overwrite;
+};
+
+class MatchedIrfsWeightsStatement : public Statement
+{
+public:
+  /* (endo1 name, period index or range for endo1, exo1 name, endo2 name, period index or range for
+     endo2, exo2 name) → weight */
+  using matched_irfs_weights_t = map<tuple<string, string, string, string, string, string>, expr_t>;
+  MatchedIrfsWeightsStatement(matched_irfs_weights_t weights_arg, bool overwrite_arg);
+  void writeOutput(ostream& output, const string& basename, bool minimal_workspace) const override;
+  void writeJsonOutput(ostream& output) const override;
+
+private:
+  const matched_irfs_weights_t weights;
+  const bool overwrite;
+};
+
 #endif
diff --git a/src/DynareBison.yy b/src/DynareBison.yy
index 8e0072c4..2e2087da 100644
--- a/src/DynareBison.yy
+++ b/src/DynareBison.yy
@@ -215,7 +215,7 @@ str_tolower(string s)
 %token ENDVAL_STEADY STEADY_SOLVE_ALGO STEADY_MAXIT STEADY_TOLF STEADY_TOLX STEADY_MARKOWITZ
 %token HOMOTOPY_MAX_COMPLETION_SHARE HOMOTOPY_MIN_STEP_SIZE HOMOTOPY_INITIAL_STEP_SIZE HOMOTOPY_STEP_SIZE_INCREASE_SUCCESS_COUNT
 %token HOMOTOPY_LINEARIZATION_FALLBACK HOMOTOPY_MARGINAL_LINEARIZATION_FALLBACK FROM_INITVAL_TO_ENDVAL
-%token STATIC_MFS RELATIVE_TO_INITVAL
+%token STATIC_MFS RELATIVE_TO_INITVAL MATCHED_IRFS MATCHED_IRFS_WEIGHTS WEIGHTS
 
 %token <vector<string>> SYMBOL_VEC
 
@@ -236,7 +236,7 @@ str_tolower(string s)
 %type <vector<int>> vec_int_elem vec_int_1 vec_int vec_int_number
 %type <PriorDistributions> prior_pdf prior_distribution
 %type <pair<expr_t,expr_t>> calibration_range
-%type <pair<string,string>> partition_elem subsamples_eq_opt integer_range_w_inf tag_pair
+%type <pair<string,string>> partition_elem subsamples_eq_opt integer_range_w_inf tag_pair matched_irfs_elem_var_varexo
 %type <vector<pair<string,string>>> partition partition_1 symbol_list_with_tex
 %type <vector<map<string, string>>> tag_pair_list_for_selection
 %type <map<string, string>> tag_pair_list
@@ -251,6 +251,12 @@ str_tolower(string s)
 %type <vector<tuple<string, string, vector<pair<string, string>>>>> symbol_list_with_tex_and_partition
 %type <map<string, variant<bool, string>>> mshocks_options_list
 %type <pair<string, variant<bool, string>>> mshocks_option
+%type <pair<vector<expr_t>, vector<expr_t>>> matched_irfs_elem_values_weights
+%type <pair<pair<string, string>, vector<tuple<int, int, expr_t, expr_t>>>> matched_irfs_elem
+%type <map<pair<string, string>, vector<tuple<int, int, expr_t, expr_t>>>> matched_irfs_list
+%type <tuple<string, string, string>> matched_irfs_weights_elem_var_varexo
+%type <pair<tuple<string, string, string, string, string, string>, expr_t>> matched_irfs_weights_elem
+%type <map<tuple<string, string, string, string, string, string>, expr_t>> matched_irfs_weights_list
 %%
 
 %start statement_list;
@@ -386,6 +392,8 @@ statement : parameters
           | var_remove
           | pac_target_info
           | resid
+          | matched_irfs
+          | matched_irfs_weights
           ;
 
 dsample : DSAMPLE INT_NUMBER ';'
@@ -3548,6 +3556,108 @@ init2shocks_element : symbol symbol ';' { driver.add_init2shocks($1, $2); }
                     | symbol COMMA symbol ';' { driver.add_init2shocks($1, $3); }
                     ;
 
+matched_irfs : MATCHED_IRFS ';' matched_irfs_list END ';'
+               { driver.matched_irfs($3, false); }
+             | MATCHED_IRFS '(' OVERWRITE ')' ';' matched_irfs_list END ';'
+               { driver.matched_irfs($6, true); }
+             ;
+
+matched_irfs_list : matched_irfs_elem
+                    { $$ = {$1}; }
+                  | matched_irfs_list matched_irfs_elem
+                    {
+                      $$ = $1;
+                      auto [it, success] = $$.insert($2);
+                      if (!success)
+                        driver.error("matched_irfs: the pair endogenous " + $2.first.first + " with exogenous " + $2.first.second + " appears two times");
+                    }
+                  ;
+
+matched_irfs_elem : matched_irfs_elem_var_varexo
+                    PERIODS period_list ';'
+                    matched_irfs_elem_values_weights
+                    {
+                      if ($3.size() != $5.first.size())
+                        driver.error("matched_irfs: the 'periods' and 'values' keywords are not followed by the same number of elements");
+                      if ($3.size() != $5.second.size())
+                        driver.error("matched_irfs: the 'periods' and 'values' keywords are not followed by the same number of elements");
+                      vector<tuple<int, int, expr_t, expr_t>> v;
+                      v.reserve($3.size());
+                      for (size_t i {0}; i < $3.size(); i++)
+                        v.emplace_back($3[i].first, $3[i].second, $5.first[i], $5.second[i]);
+                      $$ = {$1, v};
+                    }
+                  ;
+
+matched_irfs_elem_var_varexo : VAR symbol ';' VAREXO symbol ';'
+                               {
+                                 driver.check_symbol_is_endogenous($2);
+                                 driver.check_symbol_is_exogenous($5, false);
+                                 $$ = {$2, $5};
+                               }
+                             | VAREXO symbol ';' VAR symbol ';'
+                               {
+                                 driver.check_symbol_is_endogenous($5);
+                                 driver.check_symbol_is_exogenous($2, false);
+                                 $$ = {$5, $2};
+                               }
+                             ;
+
+matched_irfs_elem_values_weights : VALUES value_list ';'
+                                   {
+                                     $$ = {$2, vector($2.size(),
+                                                      driver.add_non_negative_constant("1"))};
+                                   }
+                                 | VALUES value_list ';' WEIGHTS value_list ';'
+                                   { $$ = {$2, $5}; }
+                                 | WEIGHTS value_list ';' VALUES value_list ';'
+                                   { $$ = {$5, $2}; }
+                                 ;
+
+matched_irfs_weights : MATCHED_IRFS_WEIGHTS ';' matched_irfs_weights_list END ';'
+                       { driver.matched_irfs_weights($3, false); }
+                     | MATCHED_IRFS_WEIGHTS '(' OVERWRITE ')' ';' matched_irfs_weights_list END ';'
+                       { driver.matched_irfs_weights($6, true); }
+                     ;
+
+matched_irfs_weights_list : matched_irfs_weights_elem
+                            { $$ = {$1}; }
+                          | matched_irfs_weights_list matched_irfs_weights_elem
+                            {
+                              $$ = $1;
+                              auto [it, success] = $$.insert($2);
+                              if (!success)
+                                driver.error("matched_irfs: the tuple (" + get<0>($2.first)
+                                             + "(" + get<1>($2.first) + ")," + get<2>($2.first)
+                                             + "," + get<3>($2.first) + "(" + get<4>($2.first) + "),"
+                                             + get<5>($2.first) + ") appears two times");
+                            }
+                          ;
+
+matched_irfs_weights_elem : matched_irfs_weights_elem_var_varexo COMMA
+                            matched_irfs_weights_elem_var_varexo COMMA
+                            expression ';'
+                            {
+                              $$ = {{get<0>($1), get<1>($1), get<2>($1),
+                                     get<0>($3), get<1>($3), get<2>($3)},
+                                    $5};
+                            }
+                          ;
+
+matched_irfs_weights_elem_var_varexo : symbol '(' INT_NUMBER ')' COMMA symbol
+                                       {
+                                         driver.check_symbol_is_endogenous($1);
+                                         driver.check_symbol_is_exogenous($6, false);
+                                         $$ = {$1, $3, $6};
+                                       }
+                                     | symbol '(' integer_range ')' COMMA symbol
+                                       {
+                                         driver.check_symbol_is_endogenous($1);
+                                         driver.check_symbol_is_exogenous($6, false);
+                                         $$ = {$1, $3, $6};
+                                       }
+                                     ;
+
 o_solve_algo : SOLVE_ALGO EQUAL INT_NUMBER { driver.option_num("solve_algo", $3); };
 o_stack_solve_algo : STACK_SOLVE_ALGO EQUAL INT_NUMBER { driver.option_num("stack_solve_algo", $3); };
 o_robust_lin_solve : ROBUST_LIN_SOLVE { driver.option_num("simul.robust_lin_solve", "true"); };
diff --git a/src/DynareFlex.ll b/src/DynareFlex.ll
index f38277a0..5fceacf4 100644
--- a/src/DynareFlex.ll
+++ b/src/DynareFlex.ll
@@ -233,6 +233,8 @@ DATE -?[0-9]+([ya]|m([1-9]|1[0-2])|q[1-4])
 <INITIAL>occbin_constraints {BEGIN DYNARE_BLOCK; return token::OCCBIN_CONSTRAINTS;}
 <INITIAL>model_replace {BEGIN DYNARE_BLOCK; return token::MODEL_REPLACE;}
 <INITIAL>pac_target_info {BEGIN DYNARE_BLOCK; return token::PAC_TARGET_INFO;}
+<INITIAL>matched_irfs {BEGIN DYNARE_BLOCK; return token::MATCHED_IRFS;}
+<INITIAL>matched_irfs_weights {BEGIN DYNARE_BLOCK; return token::MATCHED_IRFS_WEIGHTS;}
 
  /* For the semicolon after an "end" keyword */
 <INITIAL>; {return Dynare::parser::token_type (yytext[0]);}
@@ -787,6 +789,7 @@ DATE -?[0-9]+([ya]|m([1-9]|1[0-2])|q[1-4])
 
  /* Inside a Dynare block */
 <DYNARE_BLOCK>var {return token::VAR;}
+<DYNARE_BLOCK>varexo {return token::VAREXO;}
 <DYNARE_BLOCK>stderr {return token::STDERR;}
 <DYNARE_BLOCK>values {return token::VALUES;}
 <DYNARE_BLOCK>corr {return token::CORR;}
@@ -856,7 +859,7 @@ DATE -?[0-9]+([ya]|m([1-9]|1[0-2])|q[1-4])
   yylval->build<string>(yytext);
   return token::DD;
 }
-
+<DYNARE_BLOCK>weights {return token::WEIGHTS;}
 
  /* Inside Dynare statement */
 <DYNARE_STATEMENT>solve_algo {return token::SOLVE_ALGO;}
diff --git a/src/ModFile.cc b/src/ModFile.cc
index 1df12611..3201be71 100644
--- a/src/ModFile.cc
+++ b/src/ModFile.cc
@@ -953,14 +953,16 @@ ModFile::writeMOutput(const string& basename, bool clear_all, bool clear_global,
   // May be later modified by a shocks block
   mOutputFile << "M_.sigma_e_is_diagonal = true;" << endl;
 
-  // Initialize M_.det_shocks, M_.surprise_shocks, M_.learnt_shocks, M_.learnt_endval and
-  // M_.heteroskedastic_shocks
+  /* Initialize the structures created for several blocks, as part of the implementation of the
+     “overwrite” option */
   mOutputFile << "M_.det_shocks = [];" << endl
               << "M_.surprise_shocks = [];" << endl
               << "M_.learnt_shocks = [];" << endl
               << "M_.learnt_endval = [];" << endl
               << "M_.heteroskedastic_shocks.Qvalue_orig = [];" << endl
-              << "M_.heteroskedastic_shocks.Qscale_orig = [];" << endl;
+              << "M_.heteroskedastic_shocks.Qscale_orig = [];" << endl
+              << "M_.matched_irfs = {};" << endl
+              << "M_.matched_irfs_weights = {};" << endl;
 
   // NB: options_.{ramsey,discretionary}_policy should rather be fields of M_
   mOutputFile << boolalpha << "options_.linear = " << linear << ";" << endl
diff --git a/src/ParsingDriver.cc b/src/ParsingDriver.cc
index 67cf5557..94e27545 100644
--- a/src/ParsingDriver.cc
+++ b/src/ParsingDriver.cc
@@ -3875,3 +3875,16 @@ ParsingDriver::resid()
   mod_file->addStatement(make_unique<ResidStatement>(move(options_list)));
   options_list.clear();
 }
+
+void
+ParsingDriver::matched_irfs(MatchedIrfsStatement::matched_irfs_t values_weights, bool overwrite)
+{
+  mod_file->addStatement(make_unique<MatchedIrfsStatement>(move(values_weights), overwrite));
+}
+
+void
+ParsingDriver::matched_irfs_weights(MatchedIrfsWeightsStatement::matched_irfs_weights_t weights,
+                                    bool overwrite)
+{
+  mod_file->addStatement(make_unique<MatchedIrfsWeightsStatement>(move(weights), overwrite));
+}
diff --git a/src/ParsingDriver.hh b/src/ParsingDriver.hh
index be31e57d..724aabd8 100644
--- a/src/ParsingDriver.hh
+++ b/src/ParsingDriver.hh
@@ -98,6 +98,7 @@ private:
   //! message if it isn't
   void check_symbol_is_endogenous_or_exogenous(const string& name, bool allow_exo_det);
 
+public:
   //! Checks that a given symbol exists and is a endogenous, and stops with an error message if it
   //! isn't
   void check_symbol_is_endogenous(const string& name);
@@ -106,6 +107,7 @@ private:
   //! isn't
   void check_symbol_is_exogenous(const string& name, bool allow_exo_det);
 
+private:
   //! Checks for symbol existence in model block. If it doesn't exist, an error message is stored to
   //! be printed at the end of the model block
   void check_symbol_existence_in_model_block(const string& name);
@@ -954,6 +956,11 @@ public:
   void set_pac_target_info_component_kind(PacTargetKind kind);
   // Add a resid statement
   void resid();
+  // Add a matched_irfs block
+  void matched_irfs(MatchedIrfsStatement::matched_irfs_t values_weights, bool overwrite);
+  // Add a matched_irfs_weights block
+  void matched_irfs_weights(MatchedIrfsWeightsStatement::matched_irfs_weights_t weights,
+                            bool overwrite);
   // Returns true iff the string is a legal symbol identifier (see NAME token in lexer)
   static bool isSymbolIdentifier(const string& str);
   // Given an Occbin regime name, returns the corresponding auxiliary parameter
-- 
GitLab