Commit bb624ec6 authored by Houtan Bastani's avatar Houtan Bastani

epilogue: allow for simulations in epilogue block

parent 38152c34
Pipeline #503 passed with stage
in 1 minute and 27 seconds
......@@ -880,8 +880,8 @@ epilogue_equation_list : epilogue_equation_list epilogue_equation
| epilogue_equation
;
epilogue_equation : NAME EQUAL expression ';'
{ driver.add_epilogue_equal($1, $3); }
epilogue_equation : NAME { driver.add_epilogue_variable($1); } EQUAL expression ';'
{ driver.add_epilogue_equal($1, $4); }
;
model_options : BLOCK { driver.block(); }
......
......@@ -989,8 +989,10 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
break;
case ExprNodeOutputType::epilogueFile:
output << "dseries__." << datatree.symbol_table.getName(symb_id);
if (lag != 0)
output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
if (lag != 0)
output << lag;
output << RIGHT_ARRAY_SUBSCRIPT(output_type);
break;
default:
cerr << "VariableNode::writeOutput: should not reach this point" << endl;
......@@ -1047,8 +1049,10 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
break;
case ExprNodeOutputType::epilogueFile:
output << "dseries__." << datatree.symbol_table.getName(symb_id);
output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
if (lag != 0)
output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
output << lag;
output << RIGHT_ARRAY_SUBSCRIPT(output_type);
break;
default:
cerr << "VariableNode::writeOutput: should not reach this point" << endl;
......@@ -1105,8 +1109,10 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
break;
case ExprNodeOutputType::epilogueFile:
output << "dseries__." << datatree.symbol_table.getName(symb_id);
output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
if (lag != 0)
output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
output << lag;
output << RIGHT_ARRAY_SUBSCRIPT(output_type);
break;
default:
cerr << "VariableNode::writeOutput: should not reach this point" << endl;
......@@ -1117,8 +1123,10 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
if (output_type == ExprNodeOutputType::epilogueFile)
{
output << "dseries__." << datatree.symbol_table.getName(symb_id);
output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
if (lag != 0)
output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
output << lag;
output << RIGHT_ARRAY_SUBSCRIPT(output_type);
break;
}
else
......@@ -1523,6 +1531,7 @@ VariableNode::maxLagWithDiffsExpanded() const
case SymbolType::endogenous:
case SymbolType::exogenous:
case SymbolType::exogenousDet:
case SymbolType::epilogue:
return -lag;
case SymbolType::modelLocalVariable:
return datatree.getLocalVariable(symb_id)->maxLagWithDiffsExpanded();
......
......@@ -355,7 +355,8 @@ Epilogue::writeEpilogueFile(const string &basename) const
ExprNodeOutputType output_type = ExprNodeOutputType::epilogueFile;
output << "function dseries__ = epilogue(params, dseries__)" << endl
<< "% function dseries__ = epilogue(params, dseries__)" << endl
<< "% Epilogue file generated by Dynare preprocessor" << endl;
<< "% Epilogue file generated by Dynare preprocessor" << endl << endl
<< "simul_end_date = lastdate(dseries__);" << endl;
deriv_node_temp_terms_t tef_terms;
temporary_terms_t temporary_terms;
......@@ -366,9 +367,24 @@ Epilogue::writeEpilogueFile(const string &basename) const
output << endl;
for (const auto & it : def_table)
{
output << "dseries__." << symbol_table.getName(it.first) << " = ";
int max_lag = it.second->maxLagWithDiffsExpanded();
set<int> used_symbols;
it.second->collectVariables(SymbolType::endogenous, used_symbols);
it.second->collectVariables(SymbolType::exogenous, used_symbols);
it.second->collectVariables(SymbolType::epilogue, used_symbols);
output << "simul_begin_date = dseries__." << symbol_table.getName(*(used_symbols.begin())) << ".firstobservedperiod;" << endl;
for (auto it1 = next(used_symbols.begin()); it1 != used_symbols.end(); it1++)
output << "if simul_begin_date < dseries__." << symbol_table.getName(*it1) << ".firstobservedperiod" << endl
<< " simul_begin_date = dseries__." << symbol_table.getName(*it1) << ".firstobservedperiod;" << endl
<< "end" << endl;
output << "simul_begin_date = simul_begin_date + " << max_lag << " + 1;" << endl
<< "if ~dseries__.exist('" << symbol_table.getName(it.first) << "')" << endl
<< " dseries__ = [dseries__ dseries(NaN(dseries__.nobs,1), dseries__.firstdate, '" << symbol_table.getName(it.first)<< "')];" << endl
<< "end" << endl
<< "from simul_begin_date to simul_end_date do "
<< "dseries__." << symbol_table.getName(it.first) << "(t) = ";
it.second->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
output << ";" << endl;
output << ";" << endl << endl;
}
output << "end" << endl;
output.close();
......
......@@ -830,9 +830,14 @@ ParsingDriver::end_epilogue()
}
void
ParsingDriver::add_epilogue_equal(const string &name, expr_t expr)
ParsingDriver::add_epilogue_variable(const string &name)
{
declare_symbol(name, SymbolType::epilogue, "", {});
}
void
ParsingDriver::add_epilogue_equal(const string &name, expr_t expr)
{
mod_file->epilogue.addDefinition(mod_file->symbol_table.getID(name), expr);
}
......
......@@ -421,8 +421,10 @@ public:
void end_homotopy();
//! Begin epilogue block
void begin_epilogue();
//! Endepilogue block
//! End epilogue block
void end_epilogue();
//! Add epilogue variable
void add_epilogue_variable(const string &varname);
//! Add equation in epilogue block
void add_epilogue_equal(const string &varname, expr_t expr);
//! Begin a model block
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment