Commit 46709ada authored by Houtan Bastani's avatar Houtan Bastani

output AR matrix in file for trend component models

parent 13a0115b
......@@ -3709,19 +3709,22 @@ DynamicModel::fillVarModelTableFromOrigModel(StaticModel &static_model) const
// Fill AR Matrix
map<string, map<tuple<int, int, int>, expr_t>> ARr;
fillAutoregressiveMatrix(ARr);
fillAutoregressiveMatrix(ARr, false);
var_model_table.setAR(ARr);
}
void
DynamicModel::fillAutoregressiveMatrix(map<string, map<tuple<int, int, int>, expr_t>> &ARr) const
DynamicModel::fillAutoregressiveMatrix(map<string, map<tuple<int, int, int>, expr_t>> &ARr, bool is_trend_component_model) const
{
for (const auto & it : var_model_table.getEqNums())
auto eqnums = is_trend_component_model ? trend_component_model_table.getEqNums() : var_model_table.getEqNums();
for (const auto & it : eqnums)
{
int i = 0;
map<tuple<int, int, int>, expr_t> AR;
vector<int> lhs = is_trend_component_model ?
trend_component_model_table.getLhs(it.first) : var_model_table.getLhs(it.first);
for (auto eqn : it.second)
equations[eqn]->get_arg2()->fillAutoregressiveRow(i++, var_model_table.getLhs(it.first), AR);
equations[eqn]->get_arg2()->fillAutoregressiveRow(i++, lhs, AR);
ARr[it.first] = AR;
}
}
......@@ -3844,6 +3847,11 @@ DynamicModel::fillTrendComponentModelTable() const
trend_component_model_table.setRhs(rhsr);
trend_component_model_table.setLhsExprT(lhs_expr_tr);
trend_component_model_table.setNonstationary(nonstationaryr);
// Fill AR Matrix
map<string, map<tuple<int, int, int>, expr_t>> ARr;
fillAutoregressiveMatrix(ARr, true);
trend_component_model_table.setAR(ARr);
}
void
......
......@@ -305,7 +305,7 @@ public:
void setNonZeroHessianEquations(map<int, string> &eqs);
//! Fill Autoregressive Matrix for var_model
void fillAutoregressiveMatrix(map<string, map<tuple<int, int, int>, expr_t>> &ARr) const;
void fillAutoregressiveMatrix(map<string, map<tuple<int, int, int>, expr_t>> &ARr, bool is_trend_component_model) const;
//! Fill the Trend Component Model Table
void fillTrendComponentModelTable() const;
......
......@@ -5440,21 +5440,30 @@ BinaryOpNode::fillAutoregressiveRowHelper(expr_t arg1, expr_t arg2,
const vector<int> &lhs,
map<tuple<int, int, int>, expr_t> &AR) const
{
if (op_code != BinaryOpcode::times)
return;
set<pair<int, int>> endogs, tmp;
arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
if (endogs.size() != 1)
return;
int lhs_symb_id = endogs.begin()->first;
if (find(lhs.begin(), lhs.end(), lhs_symb_id) == lhs.end())
return;
int lag = endogs.begin()->second;
if (datatree.symbol_table.isAuxiliaryVariable(lhs_symb_id))
{
int orig_lhs_symb_id = datatree.symbol_table.getOrigSymbIdForDiffAuxVar(lhs_symb_id);
if (find(lhs.begin(), lhs.end(), orig_lhs_symb_id) == lhs.end())
return;
lag = -(datatree.symbol_table.getOrigLeadLagForDiffAuxVar(lhs_symb_id) - 1);
lhs_symb_id = orig_lhs_symb_id;
}
arg1->collectDynamicVariables(SymbolType::endogenous, tmp);
arg1->collectDynamicVariables(SymbolType::exogenous, tmp);
if (tmp.size() != 0)
return;
int lag = endogs.begin()->second;
if (AR.find(make_tuple(eqn, -lag, lhs_symb_id)) != AR.end())
{
cerr << "BinaryOpNode::fillAutoregressiveRowHelper: Error filling AR matrix: lag/symb_id encountered more than once in equtaion" << endl;
......@@ -5466,11 +5475,8 @@ BinaryOpNode::fillAutoregressiveRowHelper(expr_t arg1, expr_t arg2,
void
BinaryOpNode::fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const
{
if (op_code == BinaryOpcode::times)
{
fillAutoregressiveRowHelper(arg1, arg2, eqn, lhs, AR);
fillAutoregressiveRowHelper(arg2, arg1, eqn, lhs, AR);
}
fillAutoregressiveRowHelper(arg1, arg2, eqn, lhs, AR);
fillAutoregressiveRowHelper(arg2, arg1, eqn, lhs, AR);
arg1->fillAutoregressiveRow(eqn, lhs, AR);
arg2->fillAutoregressiveRow(eqn, lhs, AR);
}
......
......@@ -878,7 +878,7 @@ ModFile::writeOutputFiles(const string &basename, bool clear_all, bool clear_glo
symbol_table.writeOutput(mOutputFile);
var_model_table.writeOutput(basename, mOutputFile);
trend_component_model_table.writeOutput(mOutputFile);
trend_component_model_table.writeOutput(basename, mOutputFile);
// Initialize M_.Sigma_e, M_.Correlation_matrix, M_.H, and M_.Correlation_matrix_ME
mOutputFile << "M_.Sigma_e = zeros(" << symbol_table.exo_nbr() << ", "
......
......@@ -45,18 +45,18 @@ void
TrendComponentModelTable::setEqNums(map<string, vector<int>> eqnums_arg)
{
eqnums = move(eqnums_arg);
setUndiffEqnums();
setNonTrendEqnums();
}
void
TrendComponentModelTable::setTrendEqNums(map<string, vector<int>> trend_eqnums_arg)
{
trend_eqnums = move(trend_eqnums_arg);
setUndiffEqnums();
setNonTrendEqnums();
}
void
TrendComponentModelTable::setUndiffEqnums()
TrendComponentModelTable::setNonTrendEqnums()
{
if (!nontrend_eqnums.empty() || eqnums.empty() || trend_eqnums.empty())
return;
......@@ -65,8 +65,7 @@ TrendComponentModelTable::setUndiffEqnums()
{
vector<int> nontrend_vec;
for (auto eq : it.second)
if (find(trend_eqnums[it.first].begin(), trend_eqnums[it.first].end(), eq)
== trend_eqnums[it.first].end())
if (find(trend_eqnums[it.first].begin(), trend_eqnums[it.first].end(), eq) == trend_eqnums[it.first].end())
nontrend_vec.push_back(eq);
nontrend_eqnums[it.first] = nontrend_vec;
}
......@@ -120,6 +119,12 @@ TrendComponentModelTable::setOrigDiffVar(map<string, vector<int>> orig_diff_var_
orig_diff_var = move(orig_diff_var_arg);
}
void
TrendComponentModelTable::setAR(map<string, map<tuple<int, int, int>, expr_t>> AR_arg)
{
AR = move(AR_arg);
}
map<string, vector<string>>
TrendComponentModelTable::getEqTags() const
{
......@@ -228,8 +233,20 @@ TrendComponentModelTable::getOrigDiffVar(const string &name_arg) const
}
void
TrendComponentModelTable::writeOutput(ostream &output) const
TrendComponentModelTable::writeOutput(const string &basename, ostream &output) const
{
string filename = "+" + basename + "/trend_component_ar.m";
ofstream ar_output;
ar_output.open(filename, ios::out | ios::binary);
if (!ar_output.is_open())
{
cerr << "Error: Can't open file " << filename << " for writing" << endl;
exit(EXIT_FAILURE);
}
ar_output << "function ar = trend_component_ar(model_name, params)" << endl
<< "%function ar = trend_component_ar(model_name, params)" << endl
<< "% File automatically generated by the Dynare preprocessor" << endl << endl;
for (const auto &name : names)
{
output << "M_.trend_component." << name << ".model_name = '" << name << "';" << endl
......@@ -291,7 +308,30 @@ TrendComponentModelTable::writeOutput(ostream &output) const
for (auto it : trend_vars.at(name))
output << (it >= 0 ? symbol_table.getTypeSpecificID(it) + 1 : -1) << " ";
output << "];" << endl;
vector<int> nontrend_lhs;
vector<int> lhsv = getLhs(name);
vector<int> eqnumsv = getEqNums(name);
for (int nontrend_it : getNonTrendEqNums(name))
nontrend_lhs.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), nontrend_it))));
ar_output << "if strcmp(model_name, '" << name << "')" << endl
<< " ar = zeros(" << nontrend_lhs.size() << ", " << nontrend_lhs.size() << ", " << getMaxLag(name) << ");" << endl;
for (const auto & it : AR.at(name))
{
int eqn, lag, lhs_symb_id;
tie (eqn, lag, lhs_symb_id) = it.first;
int colidx = (int) distance(nontrend_lhs.begin(), find(nontrend_lhs.begin(), nontrend_lhs.end(), lhs_symb_id));
ar_output << " ar(" << eqn + 1 << ", " << colidx + 1 << ", " << lag << ") = ";
it.second->writeOutput(ar_output, ExprNodeOutputType::matlabDynamicModel);
ar_output << ";" << endl;
}
ar_output << " return" << endl
<< "end" << endl << endl;
}
ar_output << "error([model_name ' is not a valid trend_component_model name'])" << endl
<< "end" << endl;
ar_output.close();
}
void
......
......@@ -45,6 +45,7 @@ private:
map<string, vector<bool>> diff, nonstationary;
map<string, vector<expr_t>> lhs_expr_t;
map<string, vector<int>> trend_vars;
map<string, map<tuple<int, int, int>, expr_t>> AR; // AR: name -> (eqn, lag, lhs_symb_id) -> param_expr_t
public:
TrendComponentModelTable(SymbolTable &symbol_table_arg);
......@@ -80,16 +81,18 @@ public:
void setOrigDiffVar(map<string, vector<int>> orig_diff_var_arg);
void setNonstationary(map<string, vector<bool>> nonstationary_arg);
void setTrendVar(map<string, vector<int>> trend_vars_arg);
void setAR(map<string, map<tuple<int, int, int>, expr_t>> AR_arg);
void setNonTrendEqNums(map<string, vector<int>> trend_eqnums_arg);
//! Write output of this class
void writeOutput(ostream &output) const;
void writeOutput(const string &basename, ostream &output) const;
//! Write JSON Output
void writeJsonOutput(ostream &output) const;
private:
void checkModelName(const string &name_arg) const;
void setUndiffEqnums();
void setNonTrendEqnums();
};
inline bool
......@@ -120,7 +123,7 @@ private:
public:
VarModelTable(SymbolTable &symbol_table_arg);
//! Add a trend component model
//! Add a VAR model
void addVarModel(string name, vector<string> eqtags,
pair<SymbolList, int> symbol_list_and_order_arg);
......
......@@ -874,6 +874,31 @@ SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false)
throw UnknownSymbolIDException(aux_var_symb_id);
}
int
SymbolTable::getOrigLeadLagForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false)
{
int lag = 0;
for (const auto & aux_var : aux_vars)
if ((aux_var.get_type() == AuxVarType::diff
|| aux_var.get_type() == AuxVarType::diffLag)
&& aux_var.get_symb_id() == diff_aux_var_symb_id)
lag += 1 + getOrigLeadLagForDiffAuxVar(aux_var.get_orig_symb_id());
return lag;
}
int
SymbolTable::getOrigSymbIdForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false)
{
int orig_symb_id = -1;
for (const auto & aux_var : aux_vars)
if (aux_var.get_symb_id() == diff_aux_var_symb_id)
if (aux_var.get_type() == AuxVarType::diff)
orig_symb_id = diff_aux_var_symb_id;
else if (aux_var.get_type() == AuxVarType::diffLag)
orig_symb_id = getOrigSymbIdForDiffAuxVar(aux_var.get_orig_symb_id());
return orig_symb_id;
}
expr_t
SymbolTable::getAuxiliaryVarsExprNode(int symb_id) const noexcept(false)
// throw exception if it is a Lagrange multiplier
......
......@@ -291,6 +291,10 @@ public:
int searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const noexcept(false);
//! Serches aux_vars for the aux var represented by aux_var_symb_id and returns its associated orig_symb_id
int getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false);
//! Searches for diff aux var and finds the original lag associated with this variable
int getOrigLeadLagForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false);
//! Searches for diff aux var and finds the symb id associated with this variable
int getOrigSymbIdForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false);
//! Adds an auxiliary variable when var_model is used with an order that is greater in absolute value
//! than the largest lag present in the model.
int addVarModelEndoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t expr_arg) noexcept(false);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment