From 67ac4bf8ea1de96a1e66c662eb09d8f0c90765c2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Thu, 29 Nov 2018 16:01:49 +0100
Subject: [PATCH] Allow diff() and log() in "expression" option of
 var_expectation_model

---
 src/ComputingTasks.cc | 57 +++++++++++++++++++++++++++++++++++++------
 src/ComputingTasks.hh | 10 +++++++-
 src/DynamicModel.cc   | 16 +++++-------
 src/DynamicModel.hh   |  8 +++---
 src/DynareBison.yy    |  7 ++++--
 src/DynareFlex.ll     |  2 +-
 src/ExprNode.hh       |  2 +-
 src/ModFile.cc        | 15 +++++++++---
 src/ParsingDriver.cc  | 13 +++-------
 9 files changed, 91 insertions(+), 39 deletions(-)

diff --git a/src/ComputingTasks.cc b/src/ComputingTasks.cc
index cb94d749..686e4194 100644
--- a/src/ComputingTasks.cc
+++ b/src/ComputingTasks.cc
@@ -4924,14 +4924,51 @@ VarExpectationModelStatement::VarExpectationModelStatement(string model_name_arg
   aux_model_name{move(aux_model_name_arg)}, horizon{move(horizon_arg)},
   discount{discount_arg}, symbol_table{symbol_table_arg}
 {
-  auto vpc = expression->matchLinearCombinationOfVariables();
-  for (const auto &it : vpc)
+}
+
+void
+VarExpectationModelStatement::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, ExprNode::subst_table_t &subst_table)
+{
+  vector<BinaryOpNode *> neweqs;
+  expression = expression->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
+  if (neweqs.size() > 0)
+    {
+      cerr << "ERROR: the 'expression' option of var_expectation_model contains a variable with a unary operator that is not present in the VAR model" << endl;
+      exit(EXIT_FAILURE);
+    }
+}
+
+void
+VarExpectationModelStatement::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, ExprNode::subst_table_t &subst_table)
+{
+  vector<BinaryOpNode *> neweqs;
+  expression = expression->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
+  if (neweqs.size() > 0)
     {
-      if (get<1>(it) != 0)
-        throw ExprNode::MatchFailureException{"lead/lags are not allowed"};
-      if (symbol_table.getType(get<0>(it)) != SymbolType::endogenous)
-        throw ExprNode::MatchFailureException{"Variable is not an endogenous"};
-      vars_params_constants.emplace_back(get<0>(it), get<2>(it), get<3>(it));
+      cerr << "ERROR: the 'expression' option of var_expectation_model contains a diff'd variable that is not present in the VAR model" << endl;
+      exit(EXIT_FAILURE);
+    }
+}
+
+void
+VarExpectationModelStatement::matchExpression()
+{
+  try
+    {
+      auto vpc = expression->matchLinearCombinationOfVariables();
+      for (const auto &it : vpc)
+        {
+          if (get<1>(it) != 0)
+            throw ExprNode::MatchFailureException{"lead/lags are not allowed"};
+          if (symbol_table.getType(get<0>(it)) != SymbolType::endogenous)
+            throw ExprNode::MatchFailureException{"Variable is not an endogenous"};
+          vars_params_constants.emplace_back(get<0>(it), get<2>(it), get<3>(it));
+        }
+    }
+  catch (ExprNode::MatchFailureException &e)
+    {
+      cerr << "ERROR: expression in var_expectation_model is not of the expected form: " << e.message << endl;
+      exit(EXIT_FAILURE);
     }
 }
 
@@ -4942,6 +4979,12 @@ VarExpectationModelStatement::writeOutput(ostream &output, const string &basenam
   output << mstruct << ".auxiliary_model_name = '" << aux_model_name << "';" << endl
          << mstruct << ".horizon = " << horizon << ';' << endl;
 
+  if (!vars_params_constants.size())
+    {
+      cerr << "ERROR: VarExpectationModelStatement::writeOutput: matchExpression() has not been called" << endl;
+      exit(EXIT_FAILURE);
+    }
+
   ostringstream vars_list, params_list, constants_list;
   for (auto it = vars_params_constants.begin(); it != vars_params_constants.end(); ++it)
     {
diff --git a/src/ComputingTasks.hh b/src/ComputingTasks.hh
index 2afcbf76..8b882ab1 100644
--- a/src/ComputingTasks.hh
+++ b/src/ComputingTasks.hh
@@ -1185,7 +1185,9 @@ class VarExpectationModelStatement : public Statement
 {
 public:
   const string model_name;
-  const expr_t expression;
+private:
+  expr_t expression;
+public:
   const string aux_model_name, horizon;
   const expr_t discount;
   const SymbolTable &symbol_table;
@@ -1196,6 +1198,12 @@ private:
 public:
   VarExpectationModelStatement(string model_name_arg, expr_t expression_arg, string aux_model_name_arg,
                                string horizon_arg, expr_t discount_arg, const SymbolTable &symbol_table_arg);
+  void substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, ExprNode::subst_table_t &subst_table);
+  void substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, ExprNode::subst_table_t &subst_table);
+  // Analyzes the linear combination contained in the 'expression' option
+  /* Must be called after substituteUnaryOpNodes() and substituteDiff() (in
+     that order) */
+  void matchExpression();
   void writeOutput(ostream &output, const string &basename, bool minimal_workspace) const override;
   void writeJsonOutput(ostream &output) const override;
 };
diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc
index ad9e6909..41398450 100644
--- a/src/DynamicModel.cc
+++ b/src/DynamicModel.cc
@@ -5802,27 +5802,25 @@ DynamicModel::findPacExpectationEquationNumbers(vector<int> &eqnumbers) const
 }
 
 void
-DynamicModel::substituteUnaryOps(StaticModel &static_model, bool nopreprocessoroutput)
+DynamicModel::substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, bool nopreprocessoroutput)
 {
   vector<int> eqnumbers(equations.size());
   iota(eqnumbers.begin(), eqnumbers.end(), 0);
-  substituteUnaryOps(static_model, eqnumbers, nopreprocessoroutput);
+  substituteUnaryOps(static_model, nodes, subst_table, eqnumbers, nopreprocessoroutput);
 }
 
 void
-DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_model_eqtags, bool nopreprocessoroutput)
+DynamicModel::substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, set<string> &var_model_eqtags, bool nopreprocessoroutput)
 {
   vector<int> eqnumbers;
   getEquationNumbersFromTags(eqnumbers, var_model_eqtags);
   findPacExpectationEquationNumbers(eqnumbers);
-  substituteUnaryOps(static_model, eqnumbers, nopreprocessoroutput);
+  substituteUnaryOps(static_model, nodes, subst_table, eqnumbers, nopreprocessoroutput);
 }
 
 void
-DynamicModel::substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbers, bool nopreprocessoroutput)
+DynamicModel::substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, vector<int> &eqnumbers, bool nopreprocessoroutput)
 {
-  diff_table_t nodes;
-
   // Find matching unary ops that may be outside of diffs (i.e., those with different lags)
   set<int> used_local_vars;
   for (int eqnumber : eqnumbers)
@@ -5837,7 +5835,6 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbe
     equations[eqnumber]->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
 
   // Substitute in model local variables
-  ExprNode::subst_table_t subst_table;
   vector<BinaryOpNode *> neweqs;
   for (auto & it : local_variables_table)
     it.second = it.second->substituteUnaryOpNodes(static_model, nodes, subst_table, neweqs);
@@ -5862,14 +5859,13 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbe
 }
 
 void
-DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table, bool nopreprocessoroutput)
+DynamicModel::substituteDiff(StaticModel &static_model, diff_table_t &diff_table, ExprNode::subst_table_t &diff_subst_table, bool nopreprocessoroutput)
 {
   set<int> used_local_vars;
   for (const auto & equation : equations)
     equation->collectVariables(SymbolType::modelLocalVariable, used_local_vars);
 
   // Only substitute diffs in model local variables that appear in VAR equations
-  diff_table_t diff_table;
   for (auto & it : local_variables_table)
     if (used_local_vars.find(it.first) != used_local_vars.end())
       it.second->findDiffNodes(static_model, diff_table);
diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh
index cf161507..2fd42318 100644
--- a/src/DynamicModel.hh
+++ b/src/DynamicModel.hh
@@ -437,16 +437,16 @@ public:
   void substituteAdl();
 
   //! Creates aux vars for all unary operators
-  void substituteUnaryOps(StaticModel &static_model, bool nopreprocessoroutput);
+  void substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, bool nopreprocessoroutput);
 
   //! Creates aux vars for certain unary operators: originally implemented for support of VARs
-  void substituteUnaryOps(StaticModel &static_model, set<string> &eq_tags, bool nopreprocessoroutput);
+  void substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, set<string> &eq_tags, bool nopreprocessoroutput);
 
   //! Creates aux vars for certain unary operators: originally implemented for support of VARs
-  void substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbers, bool nopreprocessoroutput);
+  void substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, vector<int> &eqnumbers, bool nopreprocessoroutput);
 
   //! Substitutes diff operator
-  void substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table, bool nopreprocessoroutput);
+  void substituteDiff(StaticModel &static_model, diff_table_t &diff_table, ExprNode::subst_table_t &diff_subst_table, bool nopreprocessoroutput);
 
   //! Substitute VarExpectation operators
   void substituteVarExpectation(const map<string, expr_t> &subst_table);
diff --git a/src/DynareBison.yy b/src/DynareBison.yy
index fac0a1b1..e76c4c8b 100644
--- a/src/DynareBison.yy
+++ b/src/DynareBison.yy
@@ -421,8 +421,11 @@ var_expectation_model_options_list : var_expectation_model_option
 
 var_expectation_model_option : VARIABLE EQUAL symbol
                                { driver.option_str("variable", $3); }
-                             | EXPRESSION EQUAL expression
-                               { driver.var_expectation_model_expression = $3; }
+                             | EXPRESSION EQUAL { driver.begin_model(); } hand_side
+                               {
+                                 driver.var_expectation_model_expression = $4;
+                                 driver.reset_data_tree();
+                               }
                              | AUXILIARY_MODEL_NAME EQUAL symbol
                                { driver.option_str("auxiliary_model_name", $3); }
                              | HORIZON EQUAL INT_NUMBER
diff --git a/src/DynareFlex.ll b/src/DynareFlex.ll
index 186889af..c703738c 100644
--- a/src/DynareFlex.ll
+++ b/src/DynareFlex.ll
@@ -388,7 +388,7 @@ DATE -?[0-9]+([YyAa]|[Mm]([1-9]|1[0-2])|[Qq][1-4]|[Ww]([1-9]{1}|[1-4][0-9]|5[0-2
 <DYNARE_BLOCK>crossequations {return token::CROSSEQUATIONS;}
 <DYNARE_BLOCK>covariance {return token::COVARIANCE;}
 <DYNARE_BLOCK>adl {return token::ADL;}
-<DYNARE_BLOCK>diff {return token::DIFF;}
+<DYNARE_STATEMENT,DYNARE_BLOCK>diff {return token::DIFF;}
 <DYNARE_STATEMENT>cross_restrictions {return token::CROSS_RESTRICTIONS;}
 <DYNARE_STATEMENT>contemp_reduced_form {return token::CONTEMP_REDUCED_FORM;}
 <DYNARE_STATEMENT>real_pseudo_forecast {return token::REAL_PSEUDO_FORECAST;}
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 3927c6d4..d3f281be 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -411,7 +411,7 @@ class ExprNode
       */
       virtual expr_t decreaseLeadsLags(int n) const = 0;
 
-      //! Type for the substitution map used in the process of creating auxiliary vars for leads >= 2
+      //! Type for the substitution map used in the process of creating auxiliary vars
       using subst_table_t = map<const ExprNode *, const VariableNode *>;
 
       //! Type for the substitution map used in the process of substituting adl expressions
diff --git a/src/ModFile.cc b/src/ModFile.cc
index 63593f8e..43d89595 100644
--- a/src/ModFile.cc
+++ b/src/ModFile.cc
@@ -391,15 +391,18 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
     for (auto & it1 : it.second)
       eqtags.insert(it1);
 
+  diff_table_t unary_ops_nodes;
+  ExprNode::subst_table_t unary_ops_subst_table;
   if (transform_unary_ops)
-    dynamic_model.substituteUnaryOps(diff_static_model, nopreprocessoroutput);
+    dynamic_model.substituteUnaryOps(diff_static_model, unary_ops_nodes, unary_ops_subst_table, nopreprocessoroutput);
   else
     // substitute only those unary ops that appear in auxiliary model equations
-    dynamic_model.substituteUnaryOps(diff_static_model, eqtags, nopreprocessoroutput);
+    dynamic_model.substituteUnaryOps(diff_static_model, unary_ops_nodes, unary_ops_subst_table, eqtags, nopreprocessoroutput);
 
   // Create auxiliary variable and equations for Diff operators that appear in VAR equations
+  diff_table_t diff_table;
   ExprNode::subst_table_t diff_subst_table;
-  dynamic_model.substituteDiff(diff_static_model, diff_subst_table, nopreprocessoroutput);
+  dynamic_model.substituteDiff(diff_static_model, diff_table, diff_subst_table, nopreprocessoroutput);
 
   // Fill Trend Component Model Table
   dynamic_model.fillTrendComponentModelTable();
@@ -544,6 +547,12 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
           exit(EXIT_FAILURE);
         }
 
+      /* Substitute unary and diff operators in the 'expression' option, then
+         match the linear combination in the expression option */
+      vems->substituteUnaryOpNodes(diff_static_model, unary_ops_nodes, unary_ops_subst_table);
+      vems->substituteDiff(diff_static_model, diff_table, diff_subst_table);
+      vems->matchExpression();
+
       /* Create auxiliary parameters and the expression to be substituted into
          the var_expectations statement */
       auto subst_expr = dynamic_model.Zero;
diff --git a/src/ParsingDriver.cc b/src/ParsingDriver.cc
index 482e5f90..7a1122b0 100644
--- a/src/ParsingDriver.cc
+++ b/src/ParsingDriver.cc
@@ -3383,16 +3383,9 @@ ParsingDriver::var_expectation_model()
   else
     var_expectation_model_discount = data_tree->One;
 
-  try
-    {
-      mod_file->addStatement(make_unique<VarExpectationModelStatement>(model_name, var_expectation_model_expression,
-                                                                       var_model_name, horizon,
-                                                                       var_expectation_model_discount, mod_file->symbol_table));
-    }
-  catch (ExprNode::MatchFailureException &e)
-    {
-      error("expression in var_expectation_model is not of the expected form: " + e.message);
-    }
+  mod_file->addStatement(make_unique<VarExpectationModelStatement>(model_name, var_expectation_model_expression,
+                                                                   var_model_name, horizon,
+                                                                   var_expectation_model_discount, mod_file->symbol_table));
 
   options_list.clear();
   var_expectation_model_discount = nullptr;
-- 
GitLab