Verified Commit f4837e6c authored by Houtan Bastani's avatar Houtan Bastani
Browse files

fix AR and EC matrices when `equation = C` results in a simplified equation

parent 6e680c01
Pipeline #814 passed with stage
in 1 minute and 22 seconds
...@@ -3897,7 +3897,7 @@ DynamicModel::fillAutoregressiveMatrix(map<string, map<tuple<int, int, int>, exp ...@@ -3897,7 +3897,7 @@ DynamicModel::fillAutoregressiveMatrix(map<string, map<tuple<int, int, int>, exp
int i = 0; int i = 0;
map<tuple<int, int, int>, expr_t> AR; map<tuple<int, int, int>, expr_t> AR;
vector<int> lhs = is_trend_component_model ? vector<int> lhs = is_trend_component_model ?
trend_component_model_table.getLhs(it.first) : var_model_table.getLhs(it.first); trend_component_model_table.getNonTargetLhs(it.first) : var_model_table.getLhs(it.first);
for (auto eqn : it.second) for (auto eqn : it.second)
equations[eqn]->arg2->fillAutoregressiveRow(i++, lhs, AR); equations[eqn]->arg2->fillAutoregressiveRow(i++, lhs, AR);
ARr[it.first] = AR; ARr[it.first] = AR;
...@@ -4014,22 +4014,22 @@ DynamicModel::fillErrorComponentMatrix(map<string, map<tuple<int, int, int>, exp ...@@ -4014,22 +4014,22 @@ DynamicModel::fillErrorComponentMatrix(map<string, map<tuple<int, int, int>, exp
{ {
int i = 0; int i = 0;
map<tuple<int, int, int>, expr_t> EC; map<tuple<int, int, int>, expr_t> EC;
vector<int> trend_lhs = trend_component_model_table.getTargetLhs(it.first); vector<int> target_lhs = trend_component_model_table.getTargetLhs(it.first);
vector<int> nontrend_eqnums = trend_component_model_table.getNonTargetEqNums(it.first); vector<int> nontarget_eqnums = trend_component_model_table.getNonTargetEqNums(it.first);
vector<int> undiff_nontrend_lhs = getUndiffLHSForPac(it.first, diff_subst_table); vector<int> undiff_nontarget_lhs = getUndiffLHSForPac(it.first, diff_subst_table);
vector<int> parsed_undiff_nontrend_lhs; vector<int> parsed_undiff_nontarget_lhs;
for (auto eqn : it.second) for (auto eqn : it.second)
{ {
if (find(nontrend_eqnums.begin(), nontrend_eqnums.end(), eqn) != nontrend_eqnums.end()) if (find(nontarget_eqnums.begin(), nontarget_eqnums.end(), eqn) != nontarget_eqnums.end())
parsed_undiff_nontrend_lhs.push_back(undiff_nontrend_lhs.at(i)); parsed_undiff_nontarget_lhs.push_back(undiff_nontarget_lhs.at(i));
i++; i++;
} }
i = 0; i = 0;
for (auto eqn : it.second) for (auto eqn : it.second)
if (find(nontrend_eqnums.begin(), nontrend_eqnums.end(), eqn) != nontrend_eqnums.end()) if (find(nontarget_eqnums.begin(), nontarget_eqnums.end(), eqn) != nontarget_eqnums.end())
equations[eqn]->arg2->fillErrorCorrectionRow(i++, parsed_undiff_nontrend_lhs, trend_lhs, EC); equations[eqn]->arg2->fillErrorCorrectionRow(i++, parsed_undiff_nontarget_lhs, target_lhs, EC);
ECr[it.first] = EC; ECr[it.first] = EC;
} }
} }
......
...@@ -5715,15 +5715,30 @@ BinaryOpNode::fillAutoregressiveRowHelper(expr_t arg1, expr_t arg2, ...@@ -5715,15 +5715,30 @@ BinaryOpNode::fillAutoregressiveRowHelper(expr_t arg1, expr_t arg2,
return; return;
set<pair<int, int>> endogs, tmp; set<pair<int, int>> endogs, tmp;
arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
if (endogs.size() != 1)
return;
arg1->collectDynamicVariables(SymbolType::endogenous, tmp); arg1->collectDynamicVariables(SymbolType::endogenous, tmp);
arg1->collectDynamicVariables(SymbolType::exogenous, tmp); arg1->collectDynamicVariables(SymbolType::exogenous, tmp);
if (tmp.size() != 0) if (tmp.size() != 0)
return; return;
arg1->collectDynamicVariables(SymbolType::parameter, tmp);
if (tmp.size() != 1)
return;
auto *vn = dynamic_cast<VariableNode *>(arg2);
if (vn == nullptr)
return;
arg2->collectDynamicVariables(SymbolType::exogenous, endogs);
if (endogs.size() != 0)
{
cerr << "BinaryOpNode::fillAutoregressiveRowHelper: do not currently support param*exog;" << endl;
exit(EXIT_FAILURE);
}
arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
if (endogs.size() != 1)
return;
int lhs_symb_id = endogs.begin()->first; int lhs_symb_id = endogs.begin()->first;
int lag = endogs.begin()->second; int lag = endogs.begin()->second;
if (datatree.symbol_table.isAuxiliaryVariable(lhs_symb_id)) if (datatree.symbol_table.isAuxiliaryVariable(lhs_symb_id))
...@@ -5734,7 +5749,9 @@ BinaryOpNode::fillAutoregressiveRowHelper(expr_t arg1, expr_t arg2, ...@@ -5734,7 +5749,9 @@ BinaryOpNode::fillAutoregressiveRowHelper(expr_t arg1, expr_t arg2,
lag = -1 * datatree.symbol_table.getOrigLeadLagForDiffAuxVar(lhs_symb_id); lag = -1 * datatree.symbol_table.getOrigLeadLagForDiffAuxVar(lhs_symb_id);
lhs_symb_id = orig_lhs_symb_id; lhs_symb_id = orig_lhs_symb_id;
} }
else
if (find(lhs.begin(), lhs.end(), lhs_symb_id) == lhs.end())
return;
if (AR.find(make_tuple(eqn, -lag, lhs_symb_id)) != AR.end()) if (AR.find(make_tuple(eqn, -lag, lhs_symb_id)) != AR.end())
{ {
...@@ -5756,8 +5773,8 @@ BinaryOpNode::fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<i ...@@ -5756,8 +5773,8 @@ BinaryOpNode::fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<i
void void
BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2, BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
int eqn, int eqn,
const vector<int> &nontrend_lhs, const vector<int> &nontarget_lhs,
const vector<int> &trend_lhs, const vector<int> &target_lhs,
map<tuple<int, int, int>, expr_t> &EC) const map<tuple<int, int, int>, expr_t> &EC) const
{ {
if (op_code != BinaryOpcode::times) if (op_code != BinaryOpcode::times)
...@@ -5769,9 +5786,12 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2, ...@@ -5769,9 +5786,12 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
if (tmp.size() != 0) if (tmp.size() != 0)
return; return;
auto *multiplicandr = dynamic_cast<BinaryOpNode *>(arg2); arg1->collectDynamicVariables(SymbolType::parameter, tmp);
if (multiplicandr == nullptr if (tmp.size() != 1)
|| multiplicandr->op_code != BinaryOpcode::minus) return;
auto *bopn = dynamic_cast<BinaryOpNode *>(arg2);
if (bopn == nullptr || bopn->op_code != BinaryOpcode::minus)
return; return;
arg2->collectDynamicVariables(SymbolType::endogenous, endogs); arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
...@@ -5781,12 +5801,21 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2, ...@@ -5781,12 +5801,21 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
arg2->collectDynamicVariables(SymbolType::exogenous, endogs); arg2->collectDynamicVariables(SymbolType::exogenous, endogs);
arg2->collectDynamicVariables(SymbolType::parameter, endogs); arg2->collectDynamicVariables(SymbolType::parameter, endogs);
if (endogs.size() != 2) if (endogs.size() != 2)
return; {
cerr << "ERROR in model; expecting param*endog or param*(endog-endog)" << endl;
exit(EXIT_FAILURE);
}
int endog1, lag1, endog2, lag2; auto *vn1 = dynamic_cast<VariableNode *>(bopn->arg1);
tie(endog1, lag1) = *endogs.begin(); auto *vn2 = dynamic_cast<VariableNode *>(bopn->arg2);
tie(endog2, lag2) = *next(endogs.begin(), 1); if (vn1 == nullptr || vn2 == nullptr)
int orig_endog1 = endog1; {
cerr << "ERROR in model; expecting param*endog or param*(endog-endog)" << endl;
exit(EXIT_FAILURE);
}
int endog1 = vn1->symb_id;
int endog2 = vn2->symb_id;
int orig_endog2 = endog2; int orig_endog2 = endog2;
bool isauxvar1 = datatree.symbol_table.isAuxiliaryVariable(endog1); bool isauxvar1 = datatree.symbol_table.isAuxiliaryVariable(endog1);
...@@ -5799,24 +5828,16 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2, ...@@ -5799,24 +5828,16 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
int max_lag = 0; int max_lag = 0;
int colidx = -1; int colidx = -1;
if (find(nontrend_lhs.begin(), nontrend_lhs.end(), endog1) != nontrend_lhs.end()) if (find(nontarget_lhs.begin(), nontarget_lhs.end(), endog1) != nontarget_lhs.end()
&& find(target_lhs.begin(), target_lhs.end(), endog2) != target_lhs.end())
{ {
colidx = (int) distance(nontrend_lhs.begin(), find(nontrend_lhs.begin(), nontrend_lhs.end(), endog1)); colidx = (int) distance(target_lhs.begin(), find(target_lhs.begin(), target_lhs.end(), endog2));
int tmp_lag = lag2; int tmp_lag = vn2->lag;
if (isauxvar2) if (isauxvar2)
tmp_lag = -1 * datatree.symbol_table.getOrigLeadLagForDiffAuxVar(orig_endog2); tmp_lag = -1 * datatree.symbol_table.getOrigLeadLagForDiffAuxVar(orig_endog2);
if (tmp_lag < max_lag) if (tmp_lag < max_lag)
max_lag = tmp_lag; max_lag = tmp_lag;
} }
else if (find(nontrend_lhs.begin(), nontrend_lhs.end(), endog2) != nontrend_lhs.end())
{
colidx = (int) distance(nontrend_lhs.begin(), find(nontrend_lhs.begin(), nontrend_lhs.end(), endog2));
int tmp_lag = lag1;
if (isauxvar1)
tmp_lag = -1 * datatree.symbol_table.getOrigLeadLagForDiffAuxVar(orig_endog1);
if (tmp_lag < max_lag)
max_lag = tmp_lag;
}
else else
return; return;
......
/* /*
* Copyright (C) 2018 Dynare Team * Copyright (C) 2018-2019 Dynare Team
* *
* This file is part of Dynare. * This file is part of Dynare.
* *
...@@ -148,7 +148,7 @@ TrendComponentModelTable::checkModelName(const string &name_arg) const ...@@ -148,7 +148,7 @@ TrendComponentModelTable::checkModelName(const string &name_arg) const
} }
vector<int> vector<int>
TrendComponentModelTable::getNontrendLhs(const string &name_arg) const TrendComponentModelTable::getNonTargetLhs(const string &name_arg) const
{ {
checkModelName(name_arg); checkModelName(name_arg);
return nontarget_lhs.find(name_arg)->second; return nontarget_lhs.find(name_arg)->second;
...@@ -331,7 +331,7 @@ TrendComponentModelTable::writeOutput(const string &basename, ostream &output) c ...@@ -331,7 +331,7 @@ TrendComponentModelTable::writeOutput(const string &basename, ostream &output) c
output << "];" << endl; output << "];" << endl;
vector<int> target_lhs_vec = getTargetLhs(name); vector<int> target_lhs_vec = getTargetLhs(name);
vector<int> nontarget_lhs_vec = getNontrendLhs(name); vector<int> nontarget_lhs_vec = getNonTargetLhs(name);
ar_ec_output << "if strcmp(model_name, '" << name << "')" << endl ar_ec_output << "if strcmp(model_name, '" << name << "')" << endl
<< " % AR" << endl << " % AR" << endl
...@@ -347,7 +347,7 @@ TrendComponentModelTable::writeOutput(const string &basename, ostream &output) c ...@@ -347,7 +347,7 @@ TrendComponentModelTable::writeOutput(const string &basename, ostream &output) c
} }
ar_ec_output << endl ar_ec_output << endl
<< " % EC" << endl << " % EC" << endl
<< " ec = zeros(" << nontarget_lhs_vec.size() << ", " << nontarget_lhs_vec.size() << ", 1);" << endl; << " ec = zeros(" << nontarget_lhs_vec.size() << ", " << target_lhs_vec.size() << ", 1);" << endl;
for (const auto & it : EC.at(name)) for (const auto & it : EC.at(name))
{ {
int eqn, lag, colidx; int eqn, lag, colidx;
......
/* /*
* Copyright (C) 2018 Dynare Team * Copyright (C) 2018-2019 Dynare Team
* *
* This file is part of Dynare. * This file is part of Dynare.
* *
...@@ -72,7 +72,7 @@ public: ...@@ -72,7 +72,7 @@ public:
vector<int> getOrigDiffVar(const string &name_arg) const; vector<int> getOrigDiffVar(const string &name_arg) const;
map<string, vector<int>> getNonTargetEqNums() const; map<string, vector<int>> getNonTargetEqNums() const;
vector<int> getNonTargetEqNums(const string &name_arg) const; vector<int> getNonTargetEqNums(const string &name_arg) const;
vector<int> getNontrendLhs(const string &name_arg) const; vector<int> getNonTargetLhs(const string &name_arg) const;
vector<int> getTargetLhs(const string &name_arg) const; vector<int> getTargetLhs(const string &name_arg) const;
void setVals(map<string, vector<int>> eqnums_arg, map<string, vector<int>> target_eqnums_arg, void setVals(map<string, vector<int>> eqnums_arg, map<string, vector<int>> target_eqnums_arg,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment