From 30428aeb17162db8328d941491162fe3a6895041 Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Tue, 3 Mar 2015 15:08:33 +0100
Subject: [PATCH] preprocessor: add joint prior syntax, #824

---
 matlab/global_initialization.m |   2 +
 preprocessor/ComputingTasks.cc | 109 +++++++++++++++++++++++++++++++++
 preprocessor/ComputingTasks.hh |  16 +++++
 preprocessor/DynareBison.yy    |  21 +++++++
 preprocessor/DynareFlex.ll     |  32 ++++++++++
 preprocessor/ParsingDriver.cc  |  20 ++++++
 preprocessor/ParsingDriver.hh  |   6 ++
 tests/ms-dsge/test_ms_dsge.mod |   4 +-
 8 files changed, 209 insertions(+), 1 deletion(-)

diff --git a/matlab/global_initialization.m b/matlab/global_initialization.m
index 36a0ccf9fd..bcd3cff407 100644
--- a/matlab/global_initialization.m
+++ b/matlab/global_initialization.m
@@ -362,6 +362,8 @@ estimation_info.measurement_error_corr.range_index = {};
 estimation_info.structural_innovation_corr_prior_index = {};
 estimation_info.structural_innovation_corr_options_index = {};
 estimation_info.structural_innovation_corr.range_index = {};
+estimation_info.joint_parameter_prior_index = {};
+estimation_info.joint_parameter = cell2table(cell(0,11));
 options_.initial_period = NaN; %dates(1,1);
 options_.dataset.file = [];
 options_.dataset.series = [];
diff --git a/preprocessor/ComputingTasks.cc b/preprocessor/ComputingTasks.cc
index 2c0c5d2978..2d7b1ffc3e 100644
--- a/preprocessor/ComputingTasks.cc
+++ b/preprocessor/ComputingTasks.cc
@@ -2082,6 +2082,115 @@ SubsamplesEqualStatement::writeOutput(ostream &output, const string &basename) c
          << endl;
 }
 
+JointPriorStatement::JointPriorStatement(const vector<string> joint_parameters_arg,
+                                         const PriorDistributions &prior_shape_arg,
+                                         const OptionsList &options_list_arg) :
+  joint_parameters(joint_parameters_arg),
+  prior_shape(prior_shape_arg),
+  options_list(options_list_arg)
+{
+}
+
+void
+JointPriorStatement::checkPass(ModFileStructure &mod_file_struct, WarningConsolidation &warnings)
+{
+  if (joint_parameters.size() < 2)
+    {
+      cerr << "ERROR: you must pass at least two parameters to the joint prior statement" << endl;
+      exit(EXIT_FAILURE);
+    }
+
+  if (prior_shape == eNoShape)
+    {
+      cerr << "ERROR: You must pass the shape option to the prior statement." << endl;
+      exit(EXIT_FAILURE);
+    }
+
+  if (options_list.num_options.find("mean") == options_list.num_options.end() &&
+      options_list.num_options.find("mode") == options_list.num_options.end())
+    {
+      cerr << "ERROR: You must pass at least one of mean and mode to the prior statement." << endl;
+      exit(EXIT_FAILURE);
+    }
+
+  OptionsList::num_options_t::const_iterator it_num = options_list.num_options.find("domain");
+  if (it_num != options_list.num_options.end())
+    {
+      using namespace boost;
+      vector<string> tokenizedDomain;
+      split(tokenizedDomain, it_num->second, is_any_of("[ ]"), token_compress_on);
+      if (tokenizedDomain.size() != 4)
+        {
+          cerr << "ERROR: You must pass exactly two values to the domain option." << endl;
+          exit(EXIT_FAILURE);
+        }
+    }
+}
+
+void
+JointPriorStatement::writeOutput(ostream &output, const string &basename) const
+{
+  for (vector<string>::const_iterator it = joint_parameters.begin() ; it != joint_parameters.end(); it++)
+    output << "eifind = get_new_or_existing_ei_index('joint_parameter_prior_index', '"
+           << *it << "', '');" << endl
+           << "estimation_info.joint_parameter_prior_index(eifind) = {'" << *it << "'};" << endl;
+
+  output << "key = {[";
+  for (vector<string>::const_iterator it = joint_parameters.begin() ; it != joint_parameters.end(); it++)
+    output << "get_new_or_existing_ei_index('joint_parameter_prior_index', '" << *it << "', '') ..."
+           << endl << "    ";
+  output << "]};" << endl;
+
+  string lhs_field("estimation_info.joint_parameter_tmp");
+
+  writeOutputHelper(output, "domain", lhs_field);
+  writeOutputHelper(output, "interval", lhs_field);
+  writeOutputHelper(output, "mean", lhs_field);
+  writeOutputHelper(output, "median", lhs_field);
+  writeOutputHelper(output, "mode", lhs_field);
+
+  assert(prior_shape != eNoShape);
+  output << lhs_field << ".shape = " << prior_shape << ";" << endl;
+
+  writeOutputHelper(output, "shift", lhs_field);
+  writeOutputHelper(output, "stdev", lhs_field);
+  writeOutputHelper(output, "truncate", lhs_field);
+  writeOutputHelper(output, "variance", lhs_field);
+
+  output << "estimation_info.joint_parameter_tmp = table(key, ..." << endl
+         << "    " << lhs_field << ".domain , ..." << endl
+         << "    " << lhs_field << ".interval , ..." << endl
+         << "    " << lhs_field << ".mean , ..." << endl
+         << "    " << lhs_field << ".median , ..." << endl
+         << "    " << lhs_field << ".mode , ..." << endl
+         << "    " << lhs_field << ".shape , ..." << endl
+         << "    " << lhs_field << ".shift , ..." << endl
+         << "    " << lhs_field << ".stdev , ..." << endl
+         << "    " << lhs_field << ".truncate , ..." << endl
+         << "    " << lhs_field << ".variance, ..." << endl
+         << "    'VariableNames',{'index','domain','interval','mean','median','mode','shape','shift','stdev','truncate','variance'});" << endl;
+
+  output << "if height(estimation_info.joint_parameter)" << endl
+         << "  estimation_info.joint_parameter = [estimation_info.joint_parameter; estimation_info.joint_parameter_tmp];" << endl
+         << "else" << endl
+         << "    estimation_info.joint_parameter = estimation_info.joint_parameter_tmp;" << endl
+         << "end" << endl
+         << "clear estimation_info.joint_parameter_tmp;" << endl;
+}
+
+void
+JointPriorStatement::writeOutputHelper(ostream &output, const string &field, const string &lhs_field) const
+{
+  OptionsList::num_options_t::const_iterator itn = options_list.num_options.find(field);
+  output << lhs_field << "." << field << " = {";
+  if (itn != options_list.num_options.end())
+    output << itn->second;
+  else
+    output << "{}";
+  output << "};" << endl;
+}
+
+
 BasicPriorStatement::~BasicPriorStatement()
 {
 }
diff --git a/preprocessor/ComputingTasks.hh b/preprocessor/ComputingTasks.hh
index 36732cc249..d19f3e4c27 100644
--- a/preprocessor/ComputingTasks.hh
+++ b/preprocessor/ComputingTasks.hh
@@ -679,6 +679,22 @@ public:
   virtual void writeOutput(ostream &output, const string &basename) const;
 };
 
+class JointPriorStatement : public Statement
+{
+private:
+  const vector<string> joint_parameters;
+  const PriorDistributions prior_shape;
+  const OptionsList options_list;
+public:
+  JointPriorStatement(const vector<string> joint_parameters_arg,
+                      const PriorDistributions &prior_shape_arg,
+                      const OptionsList &options_list_arg);
+  virtual void checkPass(ModFileStructure &mod_file_struct, WarningConsolidation &warnings);
+  virtual void writeOutput(ostream &output, const string &basename) const;
+  void writeOutputHelper(ostream &output, const string &field, const string &lhs_field) const;
+};
+
+
 class BasicPriorStatement : public Statement
 {
 public:
diff --git a/preprocessor/DynareBison.yy b/preprocessor/DynareBison.yy
index 3760a46ede..748cd78042 100644
--- a/preprocessor/DynareBison.yy
+++ b/preprocessor/DynareBison.yy
@@ -171,6 +171,7 @@ class ParsingDriver;
 %token PARAMETER_CONVERGENCE_CRITERION NUMBER_OF_LARGE_PERTURBATIONS NUMBER_OF_SMALL_PERTURBATIONS
 %token NUMBER_OF_POSTERIOR_DRAWS_AFTER_PERTURBATION MAX_NUMBER_OF_STAGES
 %token RANDOM_FUNCTION_CONVERGENCE_CRITERION RANDOM_PARAMETER_CONVERGENCE_CRITERION
+%token <vector_string_val> SYMBOL_VEC
 
 %type <node_val> expression expression_or_empty
 %type <node_val> equation hand_side
@@ -1422,6 +1423,8 @@ prior : symbol '.' PRIOR { driver.set_prior_variance(); driver.prior_shape = eNo
         { driver.set_prior($1, new string ("")); }
       | symbol '.' symbol '.' PRIOR { driver.set_prior_variance(); driver.prior_shape = eNoShape; } '(' prior_options_list ')' ';'
         { driver.set_prior($1, $3); }
+      | SYMBOL_VEC '.' PRIOR { driver.set_prior_variance(); driver.prior_shape = eNoShape; }  '(' joint_prior_options_list ')' ';'
+        { driver.set_joint_prior($1); }
       | STD '(' symbol ')' '.' PRIOR { driver.set_prior_variance(); driver.prior_shape = eNoShape; } '(' prior_options_list ')' ';'
         { driver.set_std_prior($3, new string ("")); }
       | STD '(' symbol ')' '.' symbol '.' PRIOR { driver.set_prior_variance(); driver.prior_shape = eNoShape; } '(' prior_options_list ')' ';'
@@ -1448,6 +1451,22 @@ prior_options : o_shift
               | o_domain
               ;
 
+joint_prior_options_list : joint_prior_options_list COMMA joint_prior_options
+                         | joint_prior_options
+                         ;
+
+joint_prior_options : o_shift
+                    | o_mean_vec
+                    | o_median
+                    | o_stdev
+                    | o_truncate
+                    | o_variance_mat
+                    | o_mode
+                    | o_interval
+                    | o_shape
+                    | o_domain
+                    ;
+
 prior_eq : prior_eq_opt EQUAL prior_eq_opt ';'
            {
              driver.copy_prior($1->at(0), $1->at(1), $1->at(2), $1->at(3),
@@ -2544,6 +2563,7 @@ o_shift : SHIFT EQUAL signed_number { driver.option_num("shift", $3); };
 o_shape : SHAPE EQUAL prior_distribution { driver.prior_shape = $3; };
 o_mode : MODE EQUAL signed_number { driver.option_num("mode", $3); };
 o_mean : MEAN EQUAL signed_number { driver.option_num("mean", $3); };
+o_mean_vec : MEAN EQUAL vec_value { driver.option_num("mean", $3); };
 o_truncate : TRUNCATE EQUAL vec_value { driver.option_num("truncate", $3); };
 o_stdev : STDEV EQUAL non_negative_number { driver.option_num("stdev", $3); };
 o_jscale : JSCALE EQUAL non_negative_number { driver.option_num("jscale", $3); };
@@ -2552,6 +2572,7 @@ o_bounds : BOUNDS EQUAL vec_value_w_inf { driver.option_num("bounds", $3); };
 o_domain : DOMAINN EQUAL vec_value { driver.option_num("domain", $3); };
 o_interval : INTERVAL EQUAL vec_value { driver.option_num("interval", $3); };
 o_variance : VARIANCE EQUAL expression { driver.set_prior_variance($3); }
+o_variance_mat : VARIANCE EQUAL vec_of_vec_value { driver.option_num("variance",$3); }
 o_prefilter : PREFILTER EQUAL INT_NUMBER { driver.option_num("prefilter", $3); };
 o_presample : PRESAMPLE EQUAL INT_NUMBER { driver.option_num("presample", $3); };
 o_lik_algo : LIK_ALGO EQUAL INT_NUMBER { driver.option_num("lik_algo", $3); };
diff --git a/preprocessor/DynareFlex.ll b/preprocessor/DynareFlex.ll
index cf65ff855c..793de8c0e6 100644
--- a/preprocessor/DynareFlex.ll
+++ b/preprocessor/DynareFlex.ll
@@ -839,6 +839,38 @@ DATE -?[0-9]+([YyAa]|[Mm]([1-9]|1[0-2])|[Qq][1-4]|[Ww]([1-9]{1}|[1-4][0-9]|5[0-2
     }
 }
 
+ /* For joint prior statement, match [symbol, symbol, ...]
+   If no match, begin native and push everything back on stack
+ */
+<INITIAL>\[([[:space:]]*[A-Za-z_][A-Za-z0-9_]*[[:space:]]*,{1}[[:space:]]*)*([[:space:]]*[A-Za-z_][A-Za-z0-9_]*[[:space:]]*){1}\] {
+  string yytextcpy = string(yytext);
+  yytextcpy.erase(remove(yytextcpy.begin(), yytextcpy.end(), '['), yytextcpy.end());
+  yytextcpy.erase(remove(yytextcpy.begin(), yytextcpy.end(), ']'), yytextcpy.end());
+  yytextcpy.erase(remove(yytextcpy.begin(), yytextcpy.end(), ' '), yytextcpy.end());
+  istringstream ss(yytextcpy);
+  string token;
+  yylval->vector_string_val = new vector<string *>;
+
+  while(getline(ss, token, ','))
+    if (driver.symbol_exists_and_is_not_modfile_local_or_external_function(token.c_str()))
+      yylval->vector_string_val->push_back(new string(token));
+    else
+      {
+        for (vector<string *>::iterator it=yylval->vector_string_val->begin();
+            it != yylval->vector_string_val->end(); it++)
+          delete *it;
+        delete yylval->vector_string_val;
+        BEGIN NATIVE;
+        yyless(0);
+        break;
+      }
+  if (yylval->vector_string_val->size() > 0)
+    {
+      BEGIN DYNARE_STATEMENT;
+      return token::SYMBOL_VEC;
+    }
+}
+
  /* Enter a native block */
 <INITIAL>. { BEGIN NATIVE; yyless(0); }
 
diff --git a/preprocessor/ParsingDriver.cc b/preprocessor/ParsingDriver.cc
index acf0c8a914..cd9b0101c3 100644
--- a/preprocessor/ParsingDriver.cc
+++ b/preprocessor/ParsingDriver.cc
@@ -1410,6 +1410,26 @@ ParsingDriver::set_prior(string *name, string *subsample_name)
   delete subsample_name;
 }
 
+void
+ParsingDriver::set_joint_prior(vector<string *>*symbol_vec)
+{
+  for (vector<string *>::const_iterator it=symbol_vec->begin(); it != symbol_vec->end(); it++)
+    add_joint_parameter(*it);
+  mod_file->addStatement(new JointPriorStatement(joint_parameters, prior_shape, options_list));
+  joint_parameters.clear();
+  options_list.clear();
+  prior_shape = eNoShape;
+  delete symbol_vec;
+}
+
+void
+ParsingDriver::add_joint_parameter(string *name)
+{
+  check_symbol_is_parameter(name);
+  joint_parameters.push_back(*name);
+  delete name;
+}
+
 void
 ParsingDriver::set_prior_variance(expr_t variance)
 {
diff --git a/preprocessor/ParsingDriver.hh b/preprocessor/ParsingDriver.hh
index ace9d1223d..d4ebc05dc4 100644
--- a/preprocessor/ParsingDriver.hh
+++ b/preprocessor/ParsingDriver.hh
@@ -186,6 +186,8 @@ private:
 
   //! Temporary storage for argument list of external function
   stack<vector<expr_t> >  stack_external_function_args;
+  //! Temporary storage for parameters in joint prior statement
+  vector<string> joint_parameters;
   //! Temporary storage for the symb_id associated with the "name" symbol of the current external_function statement
   int current_external_function_id;
   //! Temporary storage for option list provided to external_function()
@@ -411,6 +413,10 @@ public:
   void estimation_data();
   //! Sets the prior for a parameter
   void set_prior(string *arg1, string *arg2);
+  //! Sets the joint prior for a set of parameters
+  void set_joint_prior(vector<string *>*symbol_vec);
+  //! Adds a parameters to the list of joint parameters
+  void add_joint_parameter(string *name);
   //! Adds the variance option to its temporary holding place
   void set_prior_variance(expr_t variance=NULL);
   //! Copies the prior from_name to_name
diff --git a/tests/ms-dsge/test_ms_dsge.mod b/tests/ms-dsge/test_ms_dsge.mod
index e11715b3e4..51b132adf9 100644
--- a/tests/ms-dsge/test_ms_dsge.mod
+++ b/tests/ms-dsge/test_ms_dsge.mod
@@ -35,4 +35,6 @@ alpha.options(init=1);
 rho.options(init=1);
 beta.options(init=0.2);
 std(u).options(init=3);
-corr(y,c).options(init=.02);
\ No newline at end of file
+corr(y,c).options(init=.02);
+
+[alpha , beta , rho].prior(shape=beta, mean=[2 3 4], variance=[[1 2 3],[2 3 4]]);
\ No newline at end of file
-- 
GitLab