Verified Commit f4837e6c authored by Houtan Bastani's avatar Houtan Bastani
Browse files

fix AR and EC matrices when `equation = C` results in a simplified equation

parent 6e680c01
Pipeline #814 passed with stage
in 1 minute and 22 seconds
......@@ -3897,7 +3897,7 @@ DynamicModel::fillAutoregressiveMatrix(map<string, map<tuple<int, int, int>, exp
int i = 0;
map<tuple<int, int, int>, expr_t> AR;
vector<int> lhs = is_trend_component_model ?
trend_component_model_table.getLhs(it.first) : var_model_table.getLhs(it.first);
trend_component_model_table.getNonTargetLhs(it.first) : var_model_table.getLhs(it.first);
for (auto eqn : it.second)
equations[eqn]->arg2->fillAutoregressiveRow(i++, lhs, AR);
ARr[it.first] = AR;
......@@ -4014,22 +4014,22 @@ DynamicModel::fillErrorComponentMatrix(map<string, map<tuple<int, int, int>, exp
{
int i = 0;
map<tuple<int, int, int>, expr_t> EC;
vector<int> trend_lhs = trend_component_model_table.getTargetLhs(it.first);
vector<int> nontrend_eqnums = trend_component_model_table.getNonTargetEqNums(it.first);
vector<int> undiff_nontrend_lhs = getUndiffLHSForPac(it.first, diff_subst_table);
vector<int> parsed_undiff_nontrend_lhs;
vector<int> target_lhs = trend_component_model_table.getTargetLhs(it.first);
vector<int> nontarget_eqnums = trend_component_model_table.getNonTargetEqNums(it.first);
vector<int> undiff_nontarget_lhs = getUndiffLHSForPac(it.first, diff_subst_table);
vector<int> parsed_undiff_nontarget_lhs;
for (auto eqn : it.second)
{
if (find(nontrend_eqnums.begin(), nontrend_eqnums.end(), eqn) != nontrend_eqnums.end())
parsed_undiff_nontrend_lhs.push_back(undiff_nontrend_lhs.at(i));
if (find(nontarget_eqnums.begin(), nontarget_eqnums.end(), eqn) != nontarget_eqnums.end())
parsed_undiff_nontarget_lhs.push_back(undiff_nontarget_lhs.at(i));
i++;
}
i = 0;
for (auto eqn : it.second)
if (find(nontrend_eqnums.begin(), nontrend_eqnums.end(), eqn) != nontrend_eqnums.end())
equations[eqn]->arg2->fillErrorCorrectionRow(i++, parsed_undiff_nontrend_lhs, trend_lhs, EC);
if (find(nontarget_eqnums.begin(), nontarget_eqnums.end(), eqn) != nontarget_eqnums.end())
equations[eqn]->arg2->fillErrorCorrectionRow(i++, parsed_undiff_nontarget_lhs, target_lhs, EC);
ECr[it.first] = EC;
}
}
......
......@@ -5715,15 +5715,30 @@ BinaryOpNode::fillAutoregressiveRowHelper(expr_t arg1, expr_t arg2,
return;
set<pair<int, int>> endogs, tmp;
arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
if (endogs.size() != 1)
return;
arg1->collectDynamicVariables(SymbolType::endogenous, tmp);
arg1->collectDynamicVariables(SymbolType::exogenous, tmp);
if (tmp.size() != 0)
return;
arg1->collectDynamicVariables(SymbolType::parameter, tmp);
if (tmp.size() != 1)
return;
auto *vn = dynamic_cast<VariableNode *>(arg2);
if (vn == nullptr)
return;
arg2->collectDynamicVariables(SymbolType::exogenous, endogs);
if (endogs.size() != 0)
{
cerr << "BinaryOpNode::fillAutoregressiveRowHelper: do not currently support param*exog;" << endl;
exit(EXIT_FAILURE);
}
arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
if (endogs.size() != 1)
return;
int lhs_symb_id = endogs.begin()->first;
int lag = endogs.begin()->second;
if (datatree.symbol_table.isAuxiliaryVariable(lhs_symb_id))
......@@ -5734,7 +5749,9 @@ BinaryOpNode::fillAutoregressiveRowHelper(expr_t arg1, expr_t arg2,
lag = -1 * datatree.symbol_table.getOrigLeadLagForDiffAuxVar(lhs_symb_id);
lhs_symb_id = orig_lhs_symb_id;
}
else
if (find(lhs.begin(), lhs.end(), lhs_symb_id) == lhs.end())
return;
if (AR.find(make_tuple(eqn, -lag, lhs_symb_id)) != AR.end())
{
......@@ -5756,8 +5773,8 @@ BinaryOpNode::fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<i
void
BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
int eqn,
const vector<int> &nontrend_lhs,
const vector<int> &trend_lhs,
const vector<int> &nontarget_lhs,
const vector<int> &target_lhs,
map<tuple<int, int, int>, expr_t> &EC) const
{
if (op_code != BinaryOpcode::times)
......@@ -5769,9 +5786,12 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
if (tmp.size() != 0)
return;
auto *multiplicandr = dynamic_cast<BinaryOpNode *>(arg2);
if (multiplicandr == nullptr
|| multiplicandr->op_code != BinaryOpcode::minus)
arg1->collectDynamicVariables(SymbolType::parameter, tmp);
if (tmp.size() != 1)
return;
auto *bopn = dynamic_cast<BinaryOpNode *>(arg2);
if (bopn == nullptr || bopn->op_code != BinaryOpcode::minus)
return;
arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
......@@ -5781,12 +5801,21 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
arg2->collectDynamicVariables(SymbolType::exogenous, endogs);
arg2->collectDynamicVariables(SymbolType::parameter, endogs);
if (endogs.size() != 2)
return;
{
cerr << "ERROR in model; expecting param*endog or param*(endog-endog)" << endl;
exit(EXIT_FAILURE);
}
int endog1, lag1, endog2, lag2;
tie(endog1, lag1) = *endogs.begin();
tie(endog2, lag2) = *next(endogs.begin(), 1);
int orig_endog1 = endog1;
auto *vn1 = dynamic_cast<VariableNode *>(bopn->arg1);
auto *vn2 = dynamic_cast<VariableNode *>(bopn->arg2);
if (vn1 == nullptr || vn2 == nullptr)
{
cerr << "ERROR in model; expecting param*endog or param*(endog-endog)" << endl;
exit(EXIT_FAILURE);
}
int endog1 = vn1->symb_id;
int endog2 = vn2->symb_id;
int orig_endog2 = endog2;
bool isauxvar1 = datatree.symbol_table.isAuxiliaryVariable(endog1);
......@@ -5799,24 +5828,16 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
int max_lag = 0;
int colidx = -1;
if (find(nontrend_lhs.begin(), nontrend_lhs.end(), endog1) != nontrend_lhs.end())
if (find(nontarget_lhs.begin(), nontarget_lhs.end(), endog1) != nontarget_lhs.end()
&& find(target_lhs.begin(), target_lhs.end(), endog2) != target_lhs.end())
{
colidx = (int) distance(nontrend_lhs.begin(), find(nontrend_lhs.begin(), nontrend_lhs.end(), endog1));
int tmp_lag = lag2;
colidx = (int) distance(target_lhs.begin(), find(target_lhs.begin(), target_lhs.end(), endog2));
int tmp_lag = vn2->lag;
if (isauxvar2)
tmp_lag = -1 * datatree.symbol_table.getOrigLeadLagForDiffAuxVar(orig_endog2);
if (tmp_lag < max_lag)
max_lag = tmp_lag;
}
else if (find(nontrend_lhs.begin(), nontrend_lhs.end(), endog2) != nontrend_lhs.end())
{
colidx = (int) distance(nontrend_lhs.begin(), find(nontrend_lhs.begin(), nontrend_lhs.end(), endog2));
int tmp_lag = lag1;
if (isauxvar1)
tmp_lag = -1 * datatree.symbol_table.getOrigLeadLagForDiffAuxVar(orig_endog1);
if (tmp_lag < max_lag)
max_lag = tmp_lag;
}
else
return;
......
/*
* Copyright (C) 2018 Dynare Team
* Copyright (C) 2018-2019 Dynare Team
*
* This file is part of Dynare.
*
......@@ -148,7 +148,7 @@ TrendComponentModelTable::checkModelName(const string &name_arg) const
}
vector<int>
TrendComponentModelTable::getNontrendLhs(const string &name_arg) const
TrendComponentModelTable::getNonTargetLhs(const string &name_arg) const
{
checkModelName(name_arg);
return nontarget_lhs.find(name_arg)->second;
......@@ -331,7 +331,7 @@ TrendComponentModelTable::writeOutput(const string &basename, ostream &output) c
output << "];" << endl;
vector<int> target_lhs_vec = getTargetLhs(name);
vector<int> nontarget_lhs_vec = getNontrendLhs(name);
vector<int> nontarget_lhs_vec = getNonTargetLhs(name);
ar_ec_output << "if strcmp(model_name, '" << name << "')" << endl
<< " % AR" << endl
......@@ -347,7 +347,7 @@ TrendComponentModelTable::writeOutput(const string &basename, ostream &output) c
}
ar_ec_output << endl
<< " % EC" << endl
<< " ec = zeros(" << nontarget_lhs_vec.size() << ", " << nontarget_lhs_vec.size() << ", 1);" << endl;
<< " ec = zeros(" << nontarget_lhs_vec.size() << ", " << target_lhs_vec.size() << ", 1);" << endl;
for (const auto & it : EC.at(name))
{
int eqn, lag, colidx;
......
/*
* Copyright (C) 2018 Dynare Team
* Copyright (C) 2018-2019 Dynare Team
*
* This file is part of Dynare.
*
......@@ -72,7 +72,7 @@ public:
vector<int> getOrigDiffVar(const string &name_arg) const;
map<string, vector<int>> getNonTargetEqNums() const;
vector<int> getNonTargetEqNums(const string &name_arg) const;
vector<int> getNontrendLhs(const string &name_arg) const;
vector<int> getNonTargetLhs(const string &name_arg) const;
vector<int> getTargetLhs(const string &name_arg) const;
void setVals(map<string, vector<int>> eqnums_arg, map<string, vector<int>> target_eqnums_arg,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment