Commit aa120abd authored by Houtan Bastani's avatar Houtan Bastani

simplify code

parent a95cf811
......@@ -3841,12 +3841,8 @@ DynamicModel::fillTrendComponentModelTable() const
rhsr[it.first] = rhs;
nonstationaryr[it.first] = nonstationary;
}
trend_component_model_table.setEqNums(eqnums);
trend_component_model_table.setTrendEqNums(trend_eqnums);
trend_component_model_table.setLhs(lhsr);
trend_component_model_table.setRhs(rhsr);
trend_component_model_table.setLhsExprT(lhs_expr_tr);
trend_component_model_table.setNonstationary(nonstationaryr);
trend_component_model_table.setVals(eqnums, trend_eqnums, lhsr, lhs_expr_tr, nonstationaryr);
// Fill AR Matrix
map<string, map<tuple<int, int, int>, expr_t>> ARr, ECr;
......
......@@ -42,24 +42,15 @@ TrendComponentModelTable::addTrendComponentModel(string name_arg,
}
void
TrendComponentModelTable::setEqNums(map<string, vector<int>> eqnums_arg)
TrendComponentModelTable::setVals(map<string, vector<int>> eqnums_arg, map<string, vector<int>> trend_eqnums_arg,
map<string, vector<int>> lhs_arg,
map<string, vector<expr_t>> lhs_expr_t_arg, map<string, vector<bool>> nonstationary_arg)
{
eqnums = move(eqnums_arg);
setNonTrendEqnums();
}
void
TrendComponentModelTable::setTrendEqNums(map<string, vector<int>> trend_eqnums_arg)
{
trend_eqnums = move(trend_eqnums_arg);
setNonTrendEqnums();
}
void
TrendComponentModelTable::setNonTrendEqnums()
{
if (!nontrend_eqnums.empty() || eqnums.empty() || trend_eqnums.empty())
return;
lhs = move(lhs_arg);
lhs_expr_t = move(lhs_expr_t_arg);
nonstationary = move(nonstationary_arg);
for (const auto &it : eqnums)
{
......@@ -69,24 +60,20 @@ TrendComponentModelTable::setNonTrendEqnums()
nontrend_vec.push_back(eq);
nontrend_eqnums[it.first] = nontrend_vec;
}
}
void
TrendComponentModelTable::setNonstationary(map<string, vector<bool>> nonstationary_arg)
{
nonstationary = move(nonstationary_arg);
}
void
TrendComponentModelTable::setTrendVar(map<string, vector<int>> trend_vars_arg)
{
trend_vars = move(trend_vars_arg);
}
for (const auto &name : names)
{
vector<int> nontrend_lhs_vec, trend_lhs_vec;
vector<int> lhsv = getLhs(name);
vector<int> eqnumsv = getEqNums(name);
for (int nontrend_it : getNonTrendEqNums(name))
nontrend_lhs_vec.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), nontrend_it))));
nontrend_lhs[name] = nontrend_lhs_vec;
void
TrendComponentModelTable::setLhs(map<string, vector<int>> lhs_arg)
{
lhs = move(lhs_arg);
for (int trend_it : getTrendEqNums(name))
trend_lhs_vec.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), trend_it))));
trend_lhs[name] = trend_lhs_vec;
}
}
void
......@@ -96,9 +83,9 @@ TrendComponentModelTable::setRhs(map<string, vector<set<pair<int, int>>>> rhs_ar
}
void
TrendComponentModelTable::setLhsExprT(map<string, vector<expr_t>> lhs_expr_t_arg)
TrendComponentModelTable::setTrendVar(map<string, vector<int>> trend_vars_arg)
{
lhs_expr_t = move(lhs_expr_t_arg);
trend_vars = move(trend_vars_arg);
}
void
......@@ -169,6 +156,20 @@ TrendComponentModelTable::getNonstationary(const string &name_arg) const
return nonstationary.find(name_arg)->second;
}
vector<int>
TrendComponentModelTable::getNontrendLhs(const string &name_arg) const
{
checkModelName(name_arg);
return nontrend_lhs.find(name_arg)->second;
}
vector<int>
TrendComponentModelTable::getTrendLhs(const string &name_arg) const
{
checkModelName(name_arg);
return trend_lhs.find(name_arg)->second;
}
vector<int>
TrendComponentModelTable::getLhs(const string &name_arg) const
{
......@@ -335,35 +336,29 @@ TrendComponentModelTable::writeOutput(const string &basename, ostream &output) c
output << (it >= 0 ? symbol_table.getTypeSpecificID(it) + 1 : -1) << " ";
output << "];" << endl;
vector<int> nontrend_lhs, trend_lhs;
vector<int> lhsv = getLhs(name);
vector<int> eqnumsv = getEqNums(name);
for (int nontrend_it : getNonTrendEqNums(name))
nontrend_lhs.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), nontrend_it))));
for (int trend_it : getTrendEqNums(name))
trend_lhs.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), trend_it))));
vector<int> trend_lhs_vec = getTrendLhs(name);
vector<int> nontrend_lhs_vec = getNontrendLhs(name);
ar_ec_output << "if strcmp(model_name, '" << name << "')" << endl
<< " % AR" << endl
<< " ar = zeros(" << nontrend_lhs.size() << ", " << nontrend_lhs.size() << ", " << getMaxLag(name) << ");" << endl;
<< " ar = zeros(" << nontrend_lhs_vec.size() << ", " << nontrend_lhs_vec.size() << ", " << getMaxLag(name) << ");" << endl;
for (const auto & it : AR.at(name))
{
int eqn, lag, lhs_symb_id;
tie (eqn, lag, lhs_symb_id) = it.first;
int colidx = (int) distance(nontrend_lhs.begin(), find(nontrend_lhs.begin(), nontrend_lhs.end(), lhs_symb_id));
int colidx = (int) distance(nontrend_lhs_vec.begin(), find(nontrend_lhs_vec.begin(), nontrend_lhs_vec.end(), lhs_symb_id));
ar_ec_output << " ar(" << eqn + 1 << ", " << colidx + 1 << ", " << lag << ") = ";
it.second->writeOutput(ar_ec_output, ExprNodeOutputType::matlabDynamicModel);
ar_ec_output << ";" << endl;
}
ar_ec_output << endl
<< " % EC" << endl
<< " ec = zeros(" << nontrend_lhs.size() << ", " << nontrend_lhs.size() << ", 1);" << endl;
<< " ec = zeros(" << nontrend_lhs_vec.size() << ", " << nontrend_lhs_vec.size() << ", 1);" << endl;
for (const auto & it : EC.at(name))
{
int eqn, lag, lhs_symb_id;
tie (eqn, lag, lhs_symb_id) = it.first;
int colidx = (int) distance(trend_lhs.begin(), find(trend_lhs.begin(), trend_lhs.end(), lhs_symb_id));
int colidx = (int) distance(trend_lhs_vec.begin(), find(trend_lhs_vec.begin(), trend_lhs_vec.end(), lhs_symb_id));
ar_ec_output << " ec(" << eqn + 1 << ", " << colidx + 1 << ", 1) = ";
it.second->writeOutput(ar_ec_output, ExprNodeOutputType::matlabDynamicModel);
ar_ec_output << ";" << endl;
......
......@@ -40,7 +40,7 @@ private:
SymbolTable &symbol_table;
set<string> names;
map<string, vector<string>> eqtags, trend_eqtags;
map<string, vector<int>> eqnums, trend_eqnums, nontrend_eqnums, max_lags, lhs, orig_diff_var;
map<string, vector<int>> eqnums, trend_eqnums, nontrend_eqnums, max_lags, lhs, trend_lhs, nontrend_lhs, orig_diff_var;
map<string, vector<set<pair<int, int>>>> rhs;
map<string, vector<bool>> diff, nonstationary;
map<string, vector<expr_t>> lhs_expr_t;
......@@ -72,20 +72,19 @@ public:
map<string, vector<int>> getNonTrendEqNums() const;
vector<int> getNonTrendEqNums(const string &name_arg) const;
vector<bool> getNonstationary(const string &name_arg) const;
vector<int> getNontrendLhs(const string &name_arg) const;
vector<int> getTrendLhs(const string &name_arg) const;
void setEqNums(map<string, vector<int>> eqnums_arg);
void setTrendEqNums(map<string, vector<int>> trend_eqnums_arg);
void setLhs(map<string, vector<int>> lhs_arg);
void setVals(map<string, vector<int>> eqnums_arg, map<string, vector<int>> trend_eqnums_arg,
map<string, vector<int>> lhs_arg,
map<string, vector<expr_t>> lhs_expr_t_arg, map<string, vector<bool>> nonstationary_arg);
void setRhs(map<string, vector<set<pair<int, int>>>> rhs_arg);
void setLhsExprT(map<string, vector<expr_t>> lhs_expr_t_arg);
void setMaxLags(map<string, vector<int>> max_lags_arg);
void setDiff(map<string, vector<bool>> diff_arg);
void setOrigDiffVar(map<string, vector<int>> orig_diff_var_arg);
void setNonstationary(map<string, vector<bool>> nonstationary_arg);
void setTrendVar(map<string, vector<int>> trend_vars_arg);
void setAR(map<string, map<tuple<int, int, int>, expr_t>> AR_arg);
void setEC(map<string, map<tuple<int, int, int>, expr_t>> EC_arg);
void setNonTrendEqNums(map<string, vector<int>> trend_eqnums_arg);
//! Write output of this class
void writeOutput(const string &basename, ostream &output) const;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment