From a88ac75488d86a51f8febbead4692769722881c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org> Date: Wed, 13 Nov 2024 15:01:27 +0100 Subject: [PATCH] =?UTF-8?q?Allow=20dates=20in=20=E2=80=9Clearnt=5Fin?= =?UTF-8?q?=E2=80=9D=20option?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit NB: the “overwrite†option only works for the same type of “learnt_in†value, i.e. an integer-valued “learnt_in†cannot replace a date-valued “learnt_in†and vice versa. --- src/DynareBison.yy | 23 +++++--- src/NumericalInitialization.cc | 26 ++++++--- src/NumericalInitialization.hh | 8 +-- src/ParsingDriver.cc | 97 +++++++++++++++++++--------------- src/ParsingDriver.hh | 7 +-- src/Shocks.cc | 42 +++++++++++---- src/Shocks.hh | 4 +- 7 files changed, 132 insertions(+), 75 deletions(-) diff --git a/src/DynareBison.yy b/src/DynareBison.yy index 0048713f..ee619c17 100644 --- a/src/DynareBison.yy +++ b/src/DynareBison.yy @@ -254,8 +254,9 @@ str_tolower(string s) %type <pair<string, expr_t>> occbin_constraints_regime_option %type <PacTargetKind> pac_target_kind %type <vector<tuple<string, string, vector<pair<string, string>>>>> symbol_list_with_tex_and_partition -%type <map<string, variant<bool, string>>> mshocks_options_list -%type <pair<string, variant<bool, string>>> mshocks_option +%type <variant<int, string>> integer_or_date +%type <map<string, variant<bool, variant<int, string>>>> mshocks_options_list +%type <pair<string, variant<bool, variant<int, string>>>> mshocks_option %type <pair<vector<expr_t>, vector<expr_t>>> matched_irfs_elem_values_weights %type <pair<pair<string, string>, vector<tuple<int, int, expr_t, expr_t>>>> matched_irfs_elem %type <map<pair<string, string>, vector<tuple<int, int, expr_t, expr_t>>>> matched_irfs_list @@ -807,11 +808,17 @@ h_options: o_filename | o_series ; +integer_or_date : INT_NUMBER + { $$.emplace<int>(stoi($1)); } + | date_expr + { $$.emplace<string>($1); } + ; + endval : ENDVAL ';' endval_list END ';' { driver.end_endval(false); } | ENDVAL '(' ALL_VALUES_REQUIRED ')' ';' endval_list END ';' { driver.end_endval(true); } - | ENDVAL '(' LEARNT_IN EQUAL INT_NUMBER ')' ';' endval_list END ';' + | ENDVAL '(' LEARNT_IN EQUAL integer_or_date ')' ';' endval_list END ';' { driver.end_endval_learnt_in($5); } ; @@ -1220,9 +1227,9 @@ shocks : SHOCKS ';' shock_list END ';' { driver.end_shocks(false); } | SHOCKS '(' SURPRISE ')' ';' det_shock_list END ';' { driver.end_shocks_surprise(false); } | SHOCKS '(' SURPRISE COMMA OVERWRITE ')' ';' det_shock_list END ';' { driver.end_shocks_surprise(true); } | SHOCKS '(' OVERWRITE COMMA SURPRISE ')' ';' det_shock_list END ';' { driver.end_shocks_surprise(true); } - | SHOCKS '(' LEARNT_IN EQUAL INT_NUMBER ')' ';' det_shock_list END ';' { driver.end_shocks_learnt_in($5, false); } - | SHOCKS '(' LEARNT_IN EQUAL INT_NUMBER COMMA OVERWRITE ')' ';' det_shock_list END ';' { driver.end_shocks_learnt_in($5, true); } - | SHOCKS '(' OVERWRITE COMMA LEARNT_IN EQUAL INT_NUMBER ')' ';' det_shock_list END ';' { driver.end_shocks_learnt_in($7, true); } + | SHOCKS '(' LEARNT_IN EQUAL integer_or_date ')' ';' det_shock_list END ';' { driver.end_shocks_learnt_in($5, false); } + | SHOCKS '(' LEARNT_IN EQUAL integer_or_date COMMA OVERWRITE ')' ';' det_shock_list END ';' { driver.end_shocks_learnt_in($5, true); } + | SHOCKS '(' OVERWRITE COMMA LEARNT_IN EQUAL integer_or_date ')' ';' det_shock_list END ';' { driver.end_shocks_learnt_in($7, true); } | SHOCKS '(' HETEROGENEITY EQUAL symbol ')' ';' stoch_shock_list END ';' { driver.end_heterogeneous_shocks($5, false); } | SHOCKS '(' HETEROGENEITY EQUAL symbol COMMA OVERWRITE ')' ';' stoch_shock_list END ';' @@ -1372,7 +1379,7 @@ mshocks : MSHOCKS ';' mshock_list END ';' alternative in the variant, so that default initialization of the variant by the [] operator will give false */ if ($3.contains("learnt_in")) - driver.end_mshocks_learnt_in(get<string>($3.at("learnt_in")), + driver.end_mshocks_learnt_in(get<variant<int, string>>($3.at("learnt_in")), get<bool>($3["overwrite"]), get<bool>($3["relative_to_initval"])); else @@ -1393,7 +1400,7 @@ mshocks_options_list : mshocks_option mshocks_option : OVERWRITE { $$ = {"overwrite", true}; } - | LEARNT_IN EQUAL INT_NUMBER + | LEARNT_IN EQUAL integer_or_date { $$ = {"learnt_in", $3}; } | RELATIVE_TO_INITVAL { $$ = {"relative_to_initval", true}; } diff --git a/src/NumericalInitialization.cc b/src/NumericalInitialization.cc index 74a86199..0e2734e6 100644 --- a/src/NumericalInitialization.cc +++ b/src/NumericalInitialization.cc @@ -1,5 +1,5 @@ /* - * Copyright © 2003-2023 Dynare Team + * Copyright © 2003-2024 Dynare Team * * This file is part of Dynare. * @@ -303,10 +303,10 @@ EndValStatement::writeJsonOutput(ostream& output) const output << "]}"; } -EndValLearntInStatement::EndValLearntInStatement(int learnt_in_period_arg, +EndValLearntInStatement::EndValLearntInStatement(variant<int, string> learnt_in_period_arg, learnt_end_values_t learnt_end_values_arg, const SymbolTable& symbol_table_arg) : - learnt_in_period {learnt_in_period_arg}, + learnt_in_period {move(learnt_in_period_arg)}, learnt_end_values {move(learnt_end_values_arg)}, symbol_table {symbol_table_arg} { @@ -343,9 +343,10 @@ EndValLearntInStatement::writeOutput(ostream& output, [[maybe_unused]] const str { if (symbol_table.getType(symb_id) == SymbolType::unusedEndogenous) // See #82 continue; - output << "struct('learnt_in'," << learnt_in_period << ",'exo_id'," - << symbol_table.getTypeSpecificID(symb_id) + 1 << ",'type','" << typeToString(type) - << "'" + output << "struct('learnt_in',"; + visit([&](const auto& p) { output << p; }, learnt_in_period); + output << ",'exo_id'," << symbol_table.getTypeSpecificID(symb_id) + 1 << ",'type','" + << typeToString(type) << "'" << ",'value',"; value->writeOutput(output); output << ");" << endl; @@ -356,7 +357,18 @@ EndValLearntInStatement::writeOutput(ostream& output, [[maybe_unused]] const str void EndValLearntInStatement::writeJsonOutput(ostream& output) const { - output << R"({"statementName": "endval", "learnt_in": )" << learnt_in_period << R"(, "vals": [)"; + output << R"({"statementName": "endval", "learnt_in": )"; + visit( + [&]<class T>(const T& p) { + if constexpr (is_same_v<T, int>) + output << p; + else if constexpr (is_same_v<T, string>) + output << '"' << p << '"'; + else + static_assert(always_false_v<T>, "Non-exhaustive visitor!"); + }, + learnt_in_period); + output << R"(, "vals": [)"; for (bool printed_something {false}; auto& [type, symb_id, value] : learnt_end_values) { if (symbol_table.getType(symb_id) == SymbolType::unusedEndogenous) // See #82 diff --git a/src/NumericalInitialization.hh b/src/NumericalInitialization.hh index dcb6dec1..69a4710a 100644 --- a/src/NumericalInitialization.hh +++ b/src/NumericalInitialization.hh @@ -1,5 +1,5 @@ /* - * Copyright © 2003-2023 Dynare Team + * Copyright © 2003-2024 Dynare Team * * This file is part of Dynare. * @@ -23,6 +23,7 @@ #include <filesystem> #include <map> #include <string> +#include <variant> #include <vector> #include "ExprNode.hh" @@ -101,7 +102,7 @@ public: class EndValLearntInStatement : public Statement { public: - const int learnt_in_period; + const variant<int, string> learnt_in_period; enum class LearntEndValType { level, @@ -117,7 +118,8 @@ private: static string typeToString(LearntEndValType type); public: - EndValLearntInStatement(int learnt_in_period_arg, learnt_end_values_t learnt_end_values_arg, + EndValLearntInStatement(variant<int, string> learnt_in_period_arg, + learnt_end_values_t learnt_end_values_arg, const SymbolTable& symbol_table_arg); void checkPass(ModFileStructure& mod_file_struct, WarningConsolidation& warnings) override; void writeOutput(ostream& output, const string& basename, bool minimal_workspace) const override; diff --git a/src/ParsingDriver.cc b/src/ParsingDriver.cc index 8852d319..84441cf9 100644 --- a/src/ParsingDriver.cc +++ b/src/ParsingDriver.cc @@ -786,22 +786,26 @@ ParsingDriver::end_endval(bool all_values_required) } void -ParsingDriver::end_endval_learnt_in(const string& learnt_in_period) +ParsingDriver::end_endval_learnt_in(variant<int, string> learnt_in_period) { - int learnt_in_period_int = stoi(learnt_in_period); - if (learnt_in_period_int < 1) - error("endval: value '" + learnt_in_period + "' is not allowed for 'learnt_in' option"); - if (learnt_in_period_int == 1) + if (holds_alternative<int>(learnt_in_period)) { - end_endval(false); - return; + int learnt_in_period_int = get<int>(learnt_in_period); + if (learnt_in_period_int < 1) + error("endval: value '" + to_string(learnt_in_period_int) + + "' is not allowed for 'learnt_in' option"); + if (learnt_in_period_int == 1) + { + end_endval(false); + return; + } } for (auto [type, symb_id, value] : end_values) if (mod_file->symbol_table.getType(symb_id) != SymbolType::exogenous) error("endval(learnt_in=...): " + mod_file->symbol_table.getName(symb_id) + " is not an exogenous variable"); mod_file->addStatement(make_unique<EndValLearntInStatement>( - learnt_in_period_int, move(end_values), mod_file->symbol_table)); + move(learnt_in_period), move(end_values), mod_file->symbol_table)); end_values.clear(); } @@ -963,25 +967,30 @@ ParsingDriver::end_shocks_surprise(bool overwrite) } void -ParsingDriver::end_shocks_learnt_in(const string& learnt_in_period, bool overwrite) +ParsingDriver::end_shocks_learnt_in(variant<int, string> learnt_in_period, bool overwrite) { - int learnt_in_period_int = stoi(learnt_in_period); - if (learnt_in_period_int < 1) - error("shocks: value '" + learnt_in_period + "' is not allowed for 'learnt_in' option"); - if (learnt_in_period_int == 1) + if (holds_alternative<int>(learnt_in_period)) { - end_shocks(overwrite); - return; + int learnt_in_period_int = get<int>(learnt_in_period); + if (learnt_in_period_int < 1) + error("shocks: value '" + to_string(learnt_in_period_int) + + "' is not allowed for 'learnt_in' option"); + if (learnt_in_period_int == 1) + { + end_shocks(overwrite); + return; + } + for (auto& storage : {det_shocks, learnt_shocks_add, learnt_shocks_multiply}) + for (auto& [symb_id, vals] : storage) + for (const auto& [period_range, expr] : vals) + if (holds_alternative<pair<int, int>>(period_range)) + if (int period1 = get<pair<int, int>>(period_range).first; + period1 < learnt_in_period_int) + error("shocks: for variable " + mod_file->symbol_table.getName(symb_id) + + ", shock period (" + to_string(period1) + + ") is earlier than the period in which the shock is learnt (" + + to_string(learnt_in_period_int) + ")"); } - for (auto& storage : {det_shocks, learnt_shocks_add, learnt_shocks_multiply}) - for (auto& [symb_id, vals] : storage) - for (const auto& [period_range, expr] : vals) - if (holds_alternative<pair<int, int>>(period_range)) - if (int period1 = get<pair<int, int>>(period_range).first; period1 < learnt_in_period_int) - error("shocks: for variable " + mod_file->symbol_table.getName(symb_id) - + ", shock period (" + to_string(period1) - + ") is earlier than the period in which the shock is learnt (" + learnt_in_period - + ")"); // Aggregate the three types of shocks ShocksLearntInStatement::learnt_shocks_t learnt_shocks; @@ -1014,34 +1023,38 @@ ParsingDriver::end_shocks_learnt_in(const string& learnt_in_period, bool overwri } mod_file->addStatement(make_unique<ShocksLearntInStatement>( - learnt_in_period_int, overwrite, move(learnt_shocks), mod_file->symbol_table)); + move(learnt_in_period), overwrite, move(learnt_shocks), mod_file->symbol_table)); det_shocks.clear(); learnt_shocks_add.clear(); learnt_shocks_multiply.clear(); } void -ParsingDriver::end_mshocks_learnt_in(const string& learnt_in_period, bool overwrite, +ParsingDriver::end_mshocks_learnt_in(variant<int, string> learnt_in_period, bool overwrite, bool relative_to_initval) { - int learnt_in_period_int = stoi(learnt_in_period); - if (learnt_in_period_int < 1) - error("mshocks: value '" + learnt_in_period + "' is not allowed for 'learnt_in' option"); - if (learnt_in_period_int == 1) + if (holds_alternative<int>(learnt_in_period)) { - end_mshocks(overwrite, relative_to_initval); - return; + int learnt_in_period_int = get<int>(learnt_in_period); + if (learnt_in_period_int < 1) + error("mshocks: value '" + to_string(learnt_in_period_int) + + "' is not allowed for 'learnt_in' option"); + if (learnt_in_period_int == 1) + { + end_mshocks(overwrite, relative_to_initval); + return; + } + for (auto& [symb_id, vals] : det_shocks) + for (const auto& [period_range, expr] : vals) + if (holds_alternative<pair<int, int>>(period_range)) + if (int period1 = get<pair<int, int>>(period_range).first; + period1 < learnt_in_period_int) + error("mshocks: for variable " + mod_file->symbol_table.getName(symb_id) + + ", shock period (" + to_string(period1) + + ") is earlier than the period in which the shock is learnt (" + + to_string(learnt_in_period_int) + ")"); } - for (auto& [symb_id, vals] : det_shocks) - for (const auto& [period_range, expr] : vals) - if (holds_alternative<pair<int, int>>(period_range)) - if (int period1 = get<pair<int, int>>(period_range).first; period1 < learnt_in_period_int) - error("mshocks: for variable " + mod_file->symbol_table.getName(symb_id) - + ", shock period (" + to_string(period1) - + ") is earlier than the period in which the shock is learnt (" + learnt_in_period - + ")"); - ShocksLearntInStatement::learnt_shocks_t learnt_shocks; const auto type {relative_to_initval ? ShocksLearntInStatement::LearntShockType::multiplyInitialSteadyState @@ -1057,7 +1070,7 @@ ParsingDriver::end_mshocks_learnt_in(const string& learnt_in_period, bool overwr } mod_file->addStatement(make_unique<ShocksLearntInStatement>( - learnt_in_period_int, overwrite, move(learnt_shocks), mod_file->symbol_table)); + move(learnt_in_period), overwrite, move(learnt_shocks), mod_file->symbol_table)); det_shocks.clear(); if (!learnt_shocks_add.empty()) error("mshocks: 'add' keyword not allowed"); diff --git a/src/ParsingDriver.hh b/src/ParsingDriver.hh index fffa1a39..df251e79 100644 --- a/src/ParsingDriver.hh +++ b/src/ParsingDriver.hh @@ -29,6 +29,7 @@ #include <stack> #include <string> #include <string_view> +#include <variant> #include <vector> #include "ModFile.hh" @@ -437,7 +438,7 @@ public: //! Writes end of an endval block void end_endval(bool all_values_required); //! Writes end of an endval(learnt_in=…) block - void end_endval_learnt_in(const string& learnt_in_period); + void end_endval_learnt_in(variant<int, string> learnt_in_period); //! Writes end of an histval block void end_histval(bool all_values_required); //! Writes end of an homotopy_setup block @@ -464,11 +465,11 @@ public: //! Writes a shocks(surprise) statement void end_shocks_surprise(bool overwrite); //! Writes a shocks(learnt_in=…) block - void end_shocks_learnt_in(const string& learnt_in_period, bool overwrite); + void end_shocks_learnt_in(variant<int, string> learnt_in_period, bool overwrite); // For a shocks(heterogeneity=…) block void end_heterogeneous_shocks(const string& heterogeneity_dimension, bool overwrite); //! Writes a mshocks(learnt_in=…) block - void end_mshocks_learnt_in(const string& learnt_in_period, bool overwrite, + void end_mshocks_learnt_in(variant<int, string> learnt_in_period, bool overwrite, bool relative_to_initval); //! Writes a heteroskedastic_shocks statement void end_heteroskedastic_shocks(bool overwrite); diff --git a/src/Shocks.cc b/src/Shocks.cc index 644b2b06..6a80a84c 100644 --- a/src/Shocks.cc +++ b/src/Shocks.cc @@ -523,10 +523,11 @@ ShocksSurpriseStatement::writeJsonOutput(ostream& output) const output << "]}"; } -ShocksLearntInStatement::ShocksLearntInStatement(int learnt_in_period_arg, bool overwrite_arg, +ShocksLearntInStatement::ShocksLearntInStatement(variant<int, string> learnt_in_period_arg, + bool overwrite_arg, learnt_shocks_t learnt_shocks_arg, const SymbolTable& symbol_table_arg) : - learnt_in_period {learnt_in_period_arg}, + learnt_in_period {move(learnt_in_period_arg)}, overwrite {overwrite_arg}, learnt_shocks {move(learnt_shocks_arg)}, symbol_table {symbol_table_arg} @@ -563,18 +564,29 @@ void ShocksLearntInStatement::writeOutput(ostream& output, [[maybe_unused]] const string& basename, [[maybe_unused]] bool minimal_workspace) const { + auto print_matlab_learnt_in = [&](const auto& p) { output << p; }; if (overwrite) - output << "if ~isempty(M_.learnt_shocks)" << endl - << " M_.learnt_shocks = M_.learnt_shocks([M_.learnt_shocks.learnt_in] ~= " - << learnt_in_period << ");" << endl - << "end" << endl; + { + output << "if ~isempty(M_.learnt_shocks)" << endl + << " M_.learnt_shocks = M_.learnt_shocks(cellfun(@(x) ~isa(x, '"; + if (holds_alternative<int>(learnt_in_period)) + output << "numeric"; + else + output << "dates"; + output << "') || x ~= "; + /* NB: date expression not parenthesized since it can only contain a + operator, which has + higher precedence than ~= and || */ + visit(print_matlab_learnt_in, learnt_in_period); + output << ", {M_.learnt_shocks.learnt_in}));" << endl << "end" << endl; + } output << "M_.learnt_shocks = [ M_.learnt_shocks;" << endl; for (const auto& [id, shock_vec] : learnt_shocks) for (const auto& [type, period_range, value] : shock_vec) { - output << "struct('learnt_in'," << learnt_in_period << ",'exo_id'," - << symbol_table.getTypeSpecificID(id) + 1 << ",'periods',"; + output << "struct('learnt_in',"; + visit(print_matlab_learnt_in, learnt_in_period); + output << ",'exo_id'," << symbol_table.getTypeSpecificID(id) + 1 << ",'periods',"; visit(bind(print_matlab_period_range, ref(output), placeholders::_1), period_range); output << ",'type','" << typeToString(type) << "'" << ",'value',"; @@ -588,8 +600,18 @@ void ShocksLearntInStatement::writeJsonOutput(ostream& output) const { output << R"({"statementName": "shocks")" - << R"(, "learnt_in": )" << learnt_in_period << R"(, "overwrite": )" << boolalpha - << overwrite << R"(, "learnt_shocks": [)"; + << R"(, "learnt_in": )"; + visit( + [&]<class T>(const T& p) { + if constexpr (is_same_v<T, int>) + output << p; + else if constexpr (is_same_v<T, string>) + output << '"' << p << '"'; + else + static_assert(always_false_v<T>, "Non-exhaustive visitor!"); + }, + learnt_in_period); + output << R"(, "overwrite": )" << boolalpha << overwrite << R"(, "learnt_shocks": [)"; for (bool printed_something {false}; const auto& [id, shock_vec] : learnt_shocks) { if (exchange(printed_something, true)) diff --git a/src/Shocks.hh b/src/Shocks.hh index 2f8d335e..96138111 100644 --- a/src/Shocks.hh +++ b/src/Shocks.hh @@ -126,7 +126,7 @@ public: class ShocksLearntInStatement : public Statement { public: - const int learnt_in_period; + const variant<int, string> learnt_in_period; //! Does this “shocks(learnt_in=…)†or “mshocks(learnt_in=…)†block replace the previous ones? const bool overwrite; enum class LearntShockType @@ -153,7 +153,7 @@ private: static string typeToString(LearntShockType type); public: - ShocksLearntInStatement(int learnt_in_period_arg, bool overwrite_arg, + ShocksLearntInStatement(variant<int, string> learnt_in_period_arg, bool overwrite_arg, learnt_shocks_t learnt_shocks_arg, const SymbolTable& symbol_table_arg); void checkPass(ModFileStructure& mod_file_struct, WarningConsolidation& warnings) override; void writeOutput(ostream& output, const string& basename, bool minimal_workspace) const override; -- GitLab