From 3a820fffa2f7291be5732c33bde552526eab1602 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Tue, 3 May 2022 16:52:04 +0200
Subject: [PATCH] =?UTF-8?q?New=20+=3D=20and=20*=3D=20syntaxes=20in=20?=
 =?UTF-8?q?=E2=80=9Cendval(learnt=5Fin=3D=E2=80=A6)=E2=80=9D=20blocks?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Incidentally, forbid exogenous deterministic variables in “endval” blocks.
---
 src/DynareBison.yy             | 28 +++++++++++++++++-------
 src/DynareFlex.ll              |  2 ++
 src/NumericalInitialization.cc | 34 ++++++++++++++++++++++-------
 src/NumericalInitialization.hh | 15 +++++++++++--
 src/ParsingDriver.cc           | 40 +++++++++++++++++++++++++++++-----
 src/ParsingDriver.hh           | 11 +++++++---
 6 files changed, 103 insertions(+), 27 deletions(-)

diff --git a/src/DynareBison.yy b/src/DynareBison.yy
index cbc421cd..42c53d92 100644
--- a/src/DynareBison.yy
+++ b/src/DynareBison.yy
@@ -186,7 +186,7 @@ class ParsingDriver;
 %token NO_IDENTIFICATION_MINIMAL NO_IDENTIFICATION_SPECTRUM NORMALIZE_JACOBIANS GRID_NBR
 %token TOL_RANK TOL_DERIV TOL_SV CHECKS_VIA_SUBSETS MAX_DIM_SUBSETS_GROUPS ZERO_MOMENTS_TOLERANCE
 %token MAX_NROWS SQUEEZE_SHOCK_DECOMPOSITION WITH_EPILOGUE MODEL_REMOVE MODEL_REPLACE MODEL_OPTIONS
-%token VAR_REMOVE ESTIMATED_PARAMS_REMOVE STATIC INCIDENCE RESID NON_ZERO LEARNT_IN
+%token VAR_REMOVE ESTIMATED_PARAMS_REMOVE STATIC INCIDENCE RESID NON_ZERO LEARNT_IN PLUS_EQUAL TIMES_EQUAL
 
 %token <vector<string>> SYMBOL_VEC
 
@@ -729,6 +729,12 @@ initval : INITVAL ';' initval_list END ';'
           { driver.end_initval(true); }
         ;
 
+initval_list : initval_list initval_elem
+             | initval_elem
+             ;
+
+initval_elem : symbol EQUAL expression ';' { driver.init_val($1, $3); };
+
 histval_file : HISTVAL_FILE '(' h_options_list ')' ';'
               { driver.histval_file();};
 
@@ -753,19 +759,25 @@ h_options: o_filename
           | o_series
           ;
 
-endval : ENDVAL ';' initval_list END ';'
+endval : ENDVAL ';' endval_list END ';'
          { driver.end_endval(false); }
-       | ENDVAL '(' ALL_VALUES_REQUIRED ')' ';' initval_list END ';'
+       | ENDVAL '(' ALL_VALUES_REQUIRED ')' ';' endval_list END ';'
          { driver.end_endval(true); }
-       | ENDVAL '(' LEARNT_IN EQUAL INT_NUMBER ')' ';' initval_list END ';'
+       | ENDVAL '(' LEARNT_IN EQUAL INT_NUMBER ')' ';' endval_list END ';'
          { driver.end_endval_learnt_in($5); }
        ;
 
-initval_list : initval_list initval_elem
-             | initval_elem
-             ;
+endval_list : endval_list endval_elem
+            | endval_elem
+            ;
 
-initval_elem : symbol EQUAL expression ';' { driver.init_val($1, $3); };
+endval_elem : symbol EQUAL expression ';'
+              { driver.end_val(EndValLearntInStatement::LearntEndValType::level, $1, $3); };
+            | symbol PLUS_EQUAL expression ';'
+              { driver.end_val(EndValLearntInStatement::LearntEndValType::add, $1, $3); };
+            | symbol TIMES_EQUAL expression ';'
+              { driver.end_val(EndValLearntInStatement::LearntEndValType::multiply, $1, $3); };
+            ;
 
 histval : HISTVAL ';' histval_list END ';'
           { driver.end_histval(false); };
diff --git a/src/DynareFlex.ll b/src/DynareFlex.ll
index 8f0e1b1a..23de13ca 100644
--- a/src/DynareFlex.ll
+++ b/src/DynareFlex.ll
@@ -951,6 +951,8 @@ DATE -?[0-9]+([ya]|m([1-9]|1[0-2])|q[1-4])
 <DYNARE_STATEMENT,DYNARE_BLOCK><= {return token::LESS_EQUAL;}
 <DYNARE_STATEMENT,DYNARE_BLOCK>== {return token::EQUAL_EQUAL;}
 <DYNARE_STATEMENT,DYNARE_BLOCK>!= {return token::EXCLAMATION_EQUAL;}
+<DYNARE_BLOCK>\+= {return token::PLUS_EQUAL;}
+<DYNARE_BLOCK>\*= {return token::TIMES_EQUAL;}
 <DYNARE_STATEMENT,DYNARE_BLOCK>\^ {return token::POWER;}
 <DYNARE_STATEMENT,DYNARE_BLOCK>exp {return token::EXP;}
 <DYNARE_STATEMENT,DYNARE_BLOCK>log {return token::LOG;}
diff --git a/src/NumericalInitialization.cc b/src/NumericalInitialization.cc
index afe6667a..201b9168 100644
--- a/src/NumericalInitialization.cc
+++ b/src/NumericalInitialization.cc
@@ -300,10 +300,10 @@ EndValStatement::writeJsonOutput(ostream &output) const
 }
 
 EndValLearntInStatement::EndValLearntInStatement(int learnt_in_period_arg,
-                                                 const InitOrEndValStatement::init_values_t &init_values_arg,
+                                                 const learnt_end_values_t &learnt_end_values_arg,
                                                  const SymbolTable &symbol_table_arg) :
   learnt_in_period{learnt_in_period_arg},
-  init_values{move(init_values_arg)},
+  learnt_end_values{move(learnt_end_values_arg)},
   symbol_table{symbol_table_arg}
 {
 }
@@ -314,14 +314,30 @@ EndValLearntInStatement::checkPass(ModFileStructure &mod_file_struct, WarningCon
   mod_file_struct.endval_learnt_in_present = true;
 }
 
+string
+EndValLearntInStatement::typeToString(LearntEndValType type)
+{
+  switch (type)
+    {
+    case LearntEndValType::level:
+      return "level";
+    case LearntEndValType::add:
+      return "add";
+    case LearntEndValType::multiply:
+      return "multiply";
+    }
+  exit(EXIT_FAILURE); // Silence GCC warning
+}
+
 void
 EndValLearntInStatement::writeOutput(ostream &output, const string &basename, bool minimal_workspace) const
 {
   output << "M_.learnt_endval = [ M_.learnt_endval;" << endl;
-  for (auto [symb_id, value] : init_values)
+  for (auto [type, symb_id, value] : learnt_end_values)
     {
       output << "struct('learnt_in'," << learnt_in_period
              << ",'exo_id'," << symbol_table.getTypeSpecificID(symb_id)+1
+             << ",'type','" << typeToString(type) << "'"
              << ",'value',";
       value->writeOutput(output);
       output << ");" << endl;
@@ -334,13 +350,15 @@ EndValLearntInStatement::writeJsonOutput(ostream &output) const
 {
   output << R"({"statementName": "endval", "learnt_in": )"
          << learnt_in_period <<  R"(, "vals": [)";
-  for (auto it = init_values.begin();
-       it != init_values.end(); ++it)
+  for (auto it = learnt_end_values.begin();
+       it != learnt_end_values.end(); ++it)
     {
-      auto [symb_id, value] = *it;
-      if (it != init_values.begin())
+      auto [type, symb_id, value] = *it;
+      if (it != learnt_end_values.begin())
         output << ", ";
-      output << R"({"name": ")" << symbol_table.getName(symb_id) << R"(", )" << R"("value": ")";
+      output << R"({"name": ")" << symbol_table.getName(symb_id) << R"(", )"
+             << R"("type": ")" << typeToString(type) << R"(", )"
+             << R"("value": ")";
       value->writeJsonOutput(output, {}, {});
       output << R"("})";
     }
diff --git a/src/NumericalInitialization.hh b/src/NumericalInitialization.hh
index 5c5fd611..0779c950 100644
--- a/src/NumericalInitialization.hh
+++ b/src/NumericalInitialization.hh
@@ -100,10 +100,21 @@ class EndValLearntInStatement : public Statement
 {
 public:
   const int learnt_in_period;
-  const InitOrEndValStatement::init_values_t init_values;
+  enum class LearntEndValType
+    {
+      level,
+      add,
+      multiply
+    };
+  // The tuple is (type, symb_id, value)
+  using learnt_end_values_t = vector<tuple<LearntEndValType, int, expr_t>>;
+  const learnt_end_values_t learnt_end_values;
+private:
   const SymbolTable &symbol_table;
+  static string typeToString(LearntEndValType type);
+public:
   EndValLearntInStatement(int learnt_in_period_arg,
-                          const InitOrEndValStatement::init_values_t &init_values_arg,
+                          const learnt_end_values_t &learnt_end_values_arg,
                           const SymbolTable &symbol_table_arg);
   void checkPass(ModFileStructure &mod_file_struct, WarningConsolidation &warnings) override;
   void writeOutput(ostream &output, const string &basename, bool minimal_workspace) const override;
diff --git a/src/ParsingDriver.cc b/src/ParsingDriver.cc
index f6506cec..aa1696a3 100644
--- a/src/ParsingDriver.cc
+++ b/src/ParsingDriver.cc
@@ -544,7 +544,7 @@ ParsingDriver::init_val(const string &name, expr_t rhs)
   if (nostrict)
     if (!mod_file->symbol_table.exists(name))
       {
-        warning("discarding '" + name + "' as it was not recognized in the initval or endval statement");
+        warning("discarding '" + name + "' as it was not recognized in the initval statement");
         return;
       }
 
@@ -560,6 +560,21 @@ ParsingDriver::initval_file()
   options_list.clear(); 
 }
 
+void
+ParsingDriver::end_val(EndValLearntInStatement::LearntEndValType type, const string &name, expr_t rhs)
+{
+  if (nostrict)
+    if (!mod_file->symbol_table.exists(name))
+      {
+        warning("discarding '" + name + "' as it was not recognized in the endval statement");
+        return;
+      }
+
+  check_symbol_is_endogenous_or_exogenous(name, false);
+  int symb_id = mod_file->symbol_table.getID(name);
+  end_values.emplace_back(type, symb_id, rhs);
+}
+
 void
 ParsingDriver::hist_val(const string &name, const string &lag, expr_t rhs)
 {
@@ -743,8 +758,21 @@ ParsingDriver::end_initval(bool all_values_required)
 void
 ParsingDriver::end_endval(bool all_values_required)
 {
-  mod_file->addStatement(make_unique<EndValStatement>(init_values, mod_file->symbol_table, all_values_required));
-  init_values.clear();
+  InitOrEndValStatement::init_values_t end_values_new;
+  for (auto [type, symb_id, value] : end_values)
+    switch (type)
+      {
+      case EndValLearntInStatement::LearntEndValType::level:
+        end_values_new.emplace_back(symb_id, value);
+        break;
+      case EndValLearntInStatement::LearntEndValType::add:
+        error("endval: '" + mod_file->symbol_table.getName(symb_id) + " += ...' line not allowed unless 'learnt_in' option with value >1 is passed");
+      case EndValLearntInStatement::LearntEndValType::multiply:
+        error("endval: '" + mod_file->symbol_table.getName(symb_id) + " *= ...' line not allowed unless 'learnt_in' option with value >1 is passed");
+      }
+
+  mod_file->addStatement(make_unique<EndValStatement>(end_values_new, mod_file->symbol_table, all_values_required));
+  end_values.clear();
 }
 
 void
@@ -758,11 +786,11 @@ ParsingDriver::end_endval_learnt_in(const string &learnt_in_period)
       end_endval(false);
       return;
     }
-  for (auto [symb_id, value] : init_values)
+  for (auto [type, symb_id, value] : end_values)
     if (mod_file->symbol_table.getType(symb_id) != SymbolType::exogenous)
       error("endval(learnt_in=...): " + mod_file->symbol_table.getName(symb_id) + " is not an exogenous variable");
-  mod_file->addStatement(make_unique<EndValLearntInStatement>(learnt_in_period_int, init_values, mod_file->symbol_table));
-  init_values.clear();
+  mod_file->addStatement(make_unique<EndValLearntInStatement>(learnt_in_period_int, end_values, mod_file->symbol_table));
+  end_values.clear();
 }
 
 void
diff --git a/src/ParsingDriver.hh b/src/ParsingDriver.hh
index e53b5307..82923136 100644
--- a/src/ParsingDriver.hh
+++ b/src/ParsingDriver.hh
@@ -169,8 +169,11 @@ private:
   SigmaeStatement::row_t sigmae_row;
   //! Temporary storage for Sigma_e matrix
   SigmaeStatement::matrix_t sigmae_matrix;
-  //! Temporary storage for initval/endval blocks
+  //! Temporary storage for initval blocks
   InitOrEndValStatement::init_values_t init_values;
+  /* Temporary storage for endval blocks. Uses a type that encompasses both
+     regular “endval” blocks and “endval(learnt_in=…)” blocks. */
+  EndValLearntInStatement::learnt_end_values_t end_values;
   //! Temporary storage for histval blocks
   HistValStatement::hist_values_t hist_values;
   //! Temporary storage for homotopy_setup blocks
@@ -415,9 +418,11 @@ public:
   void dsample(const string &arg1, const string &arg2);
   //! Writes parameter intitialisation expression
   void init_param(const string &name, expr_t rhs);
-  //! Writes an initval block
+  //! Add a line inside an initval block
   void init_val(const string &name, expr_t rhs);
-  //! Writes an histval block
+  //! Add a line inside an endval block
+  void end_val(EndValLearntInStatement::LearntEndValType type, const string &name, expr_t rhs);
+  //! Add a line inside a histval block
   void hist_val(const string &name, const string &lag, expr_t rhs);
   //! Adds an entry in a homotopy_setup block
   /*! Second argument "val1" can be NULL if no initial value provided */
-- 
GitLab