Commit 67ac4bf8 authored by Sébastien Villemot's avatar Sébastien Villemot

Allow diff() and log() in "expression" option of var_expectation_model

parent e9341c71
Pipeline #399 passed with stage
in 1 minute and 19 seconds
......@@ -4924,14 +4924,51 @@ VarExpectationModelStatement::VarExpectationModelStatement(string model_name_arg
aux_model_name{move(aux_model_name_arg)}, horizon{move(horizon_arg)},
discount{discount_arg}, symbol_table{symbol_table_arg}
{
auto vpc = expression->matchLinearCombinationOfVariables();
for (const auto &it : vpc)
}
void
VarExpectationModelStatement::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, ExprNode::subst_table_t &subst_table)
{
vector<BinaryOpNode *> neweqs;
expression = expression->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
if (neweqs.size() > 0)
{
cerr << "ERROR: the 'expression' option of var_expectation_model contains a variable with a unary operator that is not present in the VAR model" << endl;
exit(EXIT_FAILURE);
}
}
void
VarExpectationModelStatement::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, ExprNode::subst_table_t &subst_table)
{
vector<BinaryOpNode *> neweqs;
expression = expression->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
if (neweqs.size() > 0)
{
if (get<1>(it) != 0)
throw ExprNode::MatchFailureException{"lead/lags are not allowed"};
if (symbol_table.getType(get<0>(it)) != SymbolType::endogenous)
throw ExprNode::MatchFailureException{"Variable is not an endogenous"};
vars_params_constants.emplace_back(get<0>(it), get<2>(it), get<3>(it));
cerr << "ERROR: the 'expression' option of var_expectation_model contains a diff'd variable that is not present in the VAR model" << endl;
exit(EXIT_FAILURE);
}
}
void
VarExpectationModelStatement::matchExpression()
{
try
{
auto vpc = expression->matchLinearCombinationOfVariables();
for (const auto &it : vpc)
{
if (get<1>(it) != 0)
throw ExprNode::MatchFailureException{"lead/lags are not allowed"};
if (symbol_table.getType(get<0>(it)) != SymbolType::endogenous)
throw ExprNode::MatchFailureException{"Variable is not an endogenous"};
vars_params_constants.emplace_back(get<0>(it), get<2>(it), get<3>(it));
}
}
catch (ExprNode::MatchFailureException &e)
{
cerr << "ERROR: expression in var_expectation_model is not of the expected form: " << e.message << endl;
exit(EXIT_FAILURE);
}
}
......@@ -4942,6 +4979,12 @@ VarExpectationModelStatement::writeOutput(ostream &output, const string &basenam
output << mstruct << ".auxiliary_model_name = '" << aux_model_name << "';" << endl
<< mstruct << ".horizon = " << horizon << ';' << endl;
if (!vars_params_constants.size())
{
cerr << "ERROR: VarExpectationModelStatement::writeOutput: matchExpression() has not been called" << endl;
exit(EXIT_FAILURE);
}
ostringstream vars_list, params_list, constants_list;
for (auto it = vars_params_constants.begin(); it != vars_params_constants.end(); ++it)
{
......
......@@ -1185,7 +1185,9 @@ class VarExpectationModelStatement : public Statement
{
public:
const string model_name;
const expr_t expression;
private:
expr_t expression;
public:
const string aux_model_name, horizon;
const expr_t discount;
const SymbolTable &symbol_table;
......@@ -1196,6 +1198,12 @@ private:
public:
VarExpectationModelStatement(string model_name_arg, expr_t expression_arg, string aux_model_name_arg,
string horizon_arg, expr_t discount_arg, const SymbolTable &symbol_table_arg);
void substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, ExprNode::subst_table_t &subst_table);
void substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, ExprNode::subst_table_t &subst_table);
// Analyzes the linear combination contained in the 'expression' option
/* Must be called after substituteUnaryOpNodes() and substituteDiff() (in
that order) */
void matchExpression();
void writeOutput(ostream &output, const string &basename, bool minimal_workspace) const override;
void writeJsonOutput(ostream &output) const override;
};
......
......@@ -5802,27 +5802,25 @@ DynamicModel::findPacExpectationEquationNumbers(vector<int> &eqnumbers) const
}
void
DynamicModel::substituteUnaryOps(StaticModel &static_model, bool nopreprocessoroutput)
DynamicModel::substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, bool nopreprocessoroutput)
{
vector<int> eqnumbers(equations.size());
iota(eqnumbers.begin(), eqnumbers.end(), 0);
substituteUnaryOps(static_model, eqnumbers, nopreprocessoroutput);
substituteUnaryOps(static_model, nodes, subst_table, eqnumbers, nopreprocessoroutput);
}
void
DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_model_eqtags, bool nopreprocessoroutput)
DynamicModel::substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, set<string> &var_model_eqtags, bool nopreprocessoroutput)
{
vector<int> eqnumbers;
getEquationNumbersFromTags(eqnumbers, var_model_eqtags);
findPacExpectationEquationNumbers(eqnumbers);
substituteUnaryOps(static_model, eqnumbers, nopreprocessoroutput);
substituteUnaryOps(static_model, nodes, subst_table, eqnumbers, nopreprocessoroutput);
}
void
DynamicModel::substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbers, bool nopreprocessoroutput)
DynamicModel::substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, vector<int> &eqnumbers, bool nopreprocessoroutput)
{
diff_table_t nodes;
// Find matching unary ops that may be outside of diffs (i.e., those with different lags)
set<int> used_local_vars;
for (int eqnumber : eqnumbers)
......@@ -5837,7 +5835,6 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbe
equations[eqnumber]->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
// Substitute in model local variables
ExprNode::subst_table_t subst_table;
vector<BinaryOpNode *> neweqs;
for (auto & it : local_variables_table)
it.second = it.second->substituteUnaryOpNodes(static_model, nodes, subst_table, neweqs);
......@@ -5862,14 +5859,13 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbe
}
void
DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table, bool nopreprocessoroutput)
DynamicModel::substituteDiff(StaticModel &static_model, diff_table_t &diff_table, ExprNode::subst_table_t &diff_subst_table, bool nopreprocessoroutput)
{
set<int> used_local_vars;
for (const auto & equation : equations)
equation->collectVariables(SymbolType::modelLocalVariable, used_local_vars);
// Only substitute diffs in model local variables that appear in VAR equations
diff_table_t diff_table;
for (auto & it : local_variables_table)
if (used_local_vars.find(it.first) != used_local_vars.end())
it.second->findDiffNodes(static_model, diff_table);
......
......@@ -437,16 +437,16 @@ public:
void substituteAdl();
//! Creates aux vars for all unary operators
void substituteUnaryOps(StaticModel &static_model, bool nopreprocessoroutput);
void substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, bool nopreprocessoroutput);
//! Creates aux vars for certain unary operators: originally implemented for support of VARs
void substituteUnaryOps(StaticModel &static_model, set<string> &eq_tags, bool nopreprocessoroutput);
void substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, set<string> &eq_tags, bool nopreprocessoroutput);
//! Creates aux vars for certain unary operators: originally implemented for support of VARs
void substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbers, bool nopreprocessoroutput);
void substituteUnaryOps(StaticModel &static_model, diff_table_t &nodes, ExprNode::subst_table_t &subst_table, vector<int> &eqnumbers, bool nopreprocessoroutput);
//! Substitutes diff operator
void substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table, bool nopreprocessoroutput);
void substituteDiff(StaticModel &static_model, diff_table_t &diff_table, ExprNode::subst_table_t &diff_subst_table, bool nopreprocessoroutput);
//! Substitute VarExpectation operators
void substituteVarExpectation(const map<string, expr_t> &subst_table);
......
......@@ -421,8 +421,11 @@ var_expectation_model_options_list : var_expectation_model_option
var_expectation_model_option : VARIABLE EQUAL symbol
{ driver.option_str("variable", $3); }
| EXPRESSION EQUAL expression
{ driver.var_expectation_model_expression = $3; }
| EXPRESSION EQUAL { driver.begin_model(); } hand_side
{
driver.var_expectation_model_expression = $4;
driver.reset_data_tree();
}
| AUXILIARY_MODEL_NAME EQUAL symbol
{ driver.option_str("auxiliary_model_name", $3); }
| HORIZON EQUAL INT_NUMBER
......
......@@ -388,7 +388,7 @@ DATE -?[0-9]+([YyAa]|[Mm]([1-9]|1[0-2])|[Qq][1-4]|[Ww]([1-9]{1}|[1-4][0-9]|5[0-2
<DYNARE_BLOCK>crossequations {return token::CROSSEQUATIONS;}
<DYNARE_BLOCK>covariance {return token::COVARIANCE;}
<DYNARE_BLOCK>adl {return token::ADL;}
<DYNARE_BLOCK>diff {return token::DIFF;}
<DYNARE_STATEMENT,DYNARE_BLOCK>diff {return token::DIFF;}
<DYNARE_STATEMENT>cross_restrictions {return token::CROSS_RESTRICTIONS;}
<DYNARE_STATEMENT>contemp_reduced_form {return token::CONTEMP_REDUCED_FORM;}
<DYNARE_STATEMENT>real_pseudo_forecast {return token::REAL_PSEUDO_FORECAST;}
......
......@@ -411,7 +411,7 @@ class ExprNode
*/
virtual expr_t decreaseLeadsLags(int n) const = 0;
//! Type for the substitution map used in the process of creating auxiliary vars for leads >= 2
//! Type for the substitution map used in the process of creating auxiliary vars
using subst_table_t = map<const ExprNode *, const VariableNode *>;
//! Type for the substitution map used in the process of substituting adl expressions
......
......@@ -391,15 +391,18 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
for (auto & it1 : it.second)
eqtags.insert(it1);
diff_table_t unary_ops_nodes;
ExprNode::subst_table_t unary_ops_subst_table;
if (transform_unary_ops)
dynamic_model.substituteUnaryOps(diff_static_model, nopreprocessoroutput);
dynamic_model.substituteUnaryOps(diff_static_model, unary_ops_nodes, unary_ops_subst_table, nopreprocessoroutput);
else
// substitute only those unary ops that appear in auxiliary model equations
dynamic_model.substituteUnaryOps(diff_static_model, eqtags, nopreprocessoroutput);
dynamic_model.substituteUnaryOps(diff_static_model, unary_ops_nodes, unary_ops_subst_table, eqtags, nopreprocessoroutput);
// Create auxiliary variable and equations for Diff operators that appear in VAR equations
diff_table_t diff_table;
ExprNode::subst_table_t diff_subst_table;
dynamic_model.substituteDiff(diff_static_model, diff_subst_table, nopreprocessoroutput);
dynamic_model.substituteDiff(diff_static_model, diff_table, diff_subst_table, nopreprocessoroutput);
// Fill Trend Component Model Table
dynamic_model.fillTrendComponentModelTable();
......@@ -544,6 +547,12 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
exit(EXIT_FAILURE);
}
/* Substitute unary and diff operators in the 'expression' option, then
match the linear combination in the expression option */
vems->substituteUnaryOpNodes(diff_static_model, unary_ops_nodes, unary_ops_subst_table);
vems->substituteDiff(diff_static_model, diff_table, diff_subst_table);
vems->matchExpression();
/* Create auxiliary parameters and the expression to be substituted into
the var_expectations statement */
auto subst_expr = dynamic_model.Zero;
......
......@@ -3383,16 +3383,9 @@ ParsingDriver::var_expectation_model()
else
var_expectation_model_discount = data_tree->One;
try
{
mod_file->addStatement(make_unique<VarExpectationModelStatement>(model_name, var_expectation_model_expression,
var_model_name, horizon,
var_expectation_model_discount, mod_file->symbol_table));
}
catch (ExprNode::MatchFailureException &e)
{
error("expression in var_expectation_model is not of the expected form: " + e.message);
}
mod_file->addStatement(make_unique<VarExpectationModelStatement>(model_name, var_expectation_model_expression,
var_model_name, horizon,
var_expectation_model_discount, mod_file->symbol_table));
options_list.clear();
var_expectation_model_discount = nullptr;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment