From aa120abd028c4ffd911cbbc81ffec1421a45be2d Mon Sep 17 00:00:00 2001 From: Houtan Bastani <houtan@dynare.org> Date: Wed, 12 Sep 2018 11:56:04 +0200 Subject: [PATCH] simplify code --- src/DynamicModel.cc | 6 +--- src/SubModel.cc | 85 +++++++++++++++++++++------------------------ src/SubModel.hh | 13 ++++--- 3 files changed, 47 insertions(+), 57 deletions(-) diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 755f4345..0ed546ae 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -3841,12 +3841,8 @@ DynamicModel::fillTrendComponentModelTable() const rhsr[it.first] = rhs; nonstationaryr[it.first] = nonstationary; } - trend_component_model_table.setEqNums(eqnums); - trend_component_model_table.setTrendEqNums(trend_eqnums); - trend_component_model_table.setLhs(lhsr); trend_component_model_table.setRhs(rhsr); - trend_component_model_table.setLhsExprT(lhs_expr_tr); - trend_component_model_table.setNonstationary(nonstationaryr); + trend_component_model_table.setVals(eqnums, trend_eqnums, lhsr, lhs_expr_tr, nonstationaryr); // Fill AR Matrix map<string, map<tuple<int, int, int>, expr_t>> ARr, ECr; diff --git a/src/SubModel.cc b/src/SubModel.cc index a4da2b3e..ed6f43ff 100644 --- a/src/SubModel.cc +++ b/src/SubModel.cc @@ -42,24 +42,15 @@ TrendComponentModelTable::addTrendComponentModel(string name_arg, } void -TrendComponentModelTable::setEqNums(map<string, vector<int>> eqnums_arg) +TrendComponentModelTable::setVals(map<string, vector<int>> eqnums_arg, map<string, vector<int>> trend_eqnums_arg, + map<string, vector<int>> lhs_arg, + map<string, vector<expr_t>> lhs_expr_t_arg, map<string, vector<bool>> nonstationary_arg) { eqnums = move(eqnums_arg); - setNonTrendEqnums(); -} - -void -TrendComponentModelTable::setTrendEqNums(map<string, vector<int>> trend_eqnums_arg) -{ trend_eqnums = move(trend_eqnums_arg); - setNonTrendEqnums(); -} - -void -TrendComponentModelTable::setNonTrendEqnums() -{ - if (!nontrend_eqnums.empty() || eqnums.empty() || trend_eqnums.empty()) - return; + lhs = move(lhs_arg); + lhs_expr_t = move(lhs_expr_t_arg); + nonstationary = move(nonstationary_arg); for (const auto &it : eqnums) { @@ -69,24 +60,20 @@ TrendComponentModelTable::setNonTrendEqnums() nontrend_vec.push_back(eq); nontrend_eqnums[it.first] = nontrend_vec; } -} -void -TrendComponentModelTable::setNonstationary(map<string, vector<bool>> nonstationary_arg) -{ - nonstationary = move(nonstationary_arg); -} - -void -TrendComponentModelTable::setTrendVar(map<string, vector<int>> trend_vars_arg) -{ - trend_vars = move(trend_vars_arg); -} + for (const auto &name : names) + { + vector<int> nontrend_lhs_vec, trend_lhs_vec; + vector<int> lhsv = getLhs(name); + vector<int> eqnumsv = getEqNums(name); + for (int nontrend_it : getNonTrendEqNums(name)) + nontrend_lhs_vec.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), nontrend_it)))); + nontrend_lhs[name] = nontrend_lhs_vec; -void -TrendComponentModelTable::setLhs(map<string, vector<int>> lhs_arg) -{ - lhs = move(lhs_arg); + for (int trend_it : getTrendEqNums(name)) + trend_lhs_vec.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), trend_it)))); + trend_lhs[name] = trend_lhs_vec; + } } void @@ -96,9 +83,9 @@ TrendComponentModelTable::setRhs(map<string, vector<set<pair<int, int>>>> rhs_ar } void -TrendComponentModelTable::setLhsExprT(map<string, vector<expr_t>> lhs_expr_t_arg) +TrendComponentModelTable::setTrendVar(map<string, vector<int>> trend_vars_arg) { - lhs_expr_t = move(lhs_expr_t_arg); + trend_vars = move(trend_vars_arg); } void @@ -169,6 +156,20 @@ TrendComponentModelTable::getNonstationary(const string &name_arg) const return nonstationary.find(name_arg)->second; } +vector<int> +TrendComponentModelTable::getNontrendLhs(const string &name_arg) const +{ + checkModelName(name_arg); + return nontrend_lhs.find(name_arg)->second; +} + +vector<int> +TrendComponentModelTable::getTrendLhs(const string &name_arg) const +{ + checkModelName(name_arg); + return trend_lhs.find(name_arg)->second; +} + vector<int> TrendComponentModelTable::getLhs(const string &name_arg) const { @@ -335,35 +336,29 @@ TrendComponentModelTable::writeOutput(const string &basename, ostream &output) c output << (it >= 0 ? symbol_table.getTypeSpecificID(it) + 1 : -1) << " "; output << "];" << endl; - vector<int> nontrend_lhs, trend_lhs; - vector<int> lhsv = getLhs(name); - vector<int> eqnumsv = getEqNums(name); - for (int nontrend_it : getNonTrendEqNums(name)) - nontrend_lhs.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), nontrend_it)))); - - for (int trend_it : getTrendEqNums(name)) - trend_lhs.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), trend_it)))); + vector<int> trend_lhs_vec = getTrendLhs(name); + vector<int> nontrend_lhs_vec = getNontrendLhs(name); ar_ec_output << "if strcmp(model_name, '" << name << "')" << endl << " % AR" << endl - << " ar = zeros(" << nontrend_lhs.size() << ", " << nontrend_lhs.size() << ", " << getMaxLag(name) << ");" << endl; + << " ar = zeros(" << nontrend_lhs_vec.size() << ", " << nontrend_lhs_vec.size() << ", " << getMaxLag(name) << ");" << endl; for (const auto & it : AR.at(name)) { int eqn, lag, lhs_symb_id; tie (eqn, lag, lhs_symb_id) = it.first; - int colidx = (int) distance(nontrend_lhs.begin(), find(nontrend_lhs.begin(), nontrend_lhs.end(), lhs_symb_id)); + int colidx = (int) distance(nontrend_lhs_vec.begin(), find(nontrend_lhs_vec.begin(), nontrend_lhs_vec.end(), lhs_symb_id)); ar_ec_output << " ar(" << eqn + 1 << ", " << colidx + 1 << ", " << lag << ") = "; it.second->writeOutput(ar_ec_output, ExprNodeOutputType::matlabDynamicModel); ar_ec_output << ";" << endl; } ar_ec_output << endl << " % EC" << endl - << " ec = zeros(" << nontrend_lhs.size() << ", " << nontrend_lhs.size() << ", 1);" << endl; + << " ec = zeros(" << nontrend_lhs_vec.size() << ", " << nontrend_lhs_vec.size() << ", 1);" << endl; for (const auto & it : EC.at(name)) { int eqn, lag, lhs_symb_id; tie (eqn, lag, lhs_symb_id) = it.first; - int colidx = (int) distance(trend_lhs.begin(), find(trend_lhs.begin(), trend_lhs.end(), lhs_symb_id)); + int colidx = (int) distance(trend_lhs_vec.begin(), find(trend_lhs_vec.begin(), trend_lhs_vec.end(), lhs_symb_id)); ar_ec_output << " ec(" << eqn + 1 << ", " << colidx + 1 << ", 1) = "; it.second->writeOutput(ar_ec_output, ExprNodeOutputType::matlabDynamicModel); ar_ec_output << ";" << endl; diff --git a/src/SubModel.hh b/src/SubModel.hh index c3a0bc57..5d3c0a5b 100644 --- a/src/SubModel.hh +++ b/src/SubModel.hh @@ -40,7 +40,7 @@ private: SymbolTable &symbol_table; set<string> names; map<string, vector<string>> eqtags, trend_eqtags; - map<string, vector<int>> eqnums, trend_eqnums, nontrend_eqnums, max_lags, lhs, orig_diff_var; + map<string, vector<int>> eqnums, trend_eqnums, nontrend_eqnums, max_lags, lhs, trend_lhs, nontrend_lhs, orig_diff_var; map<string, vector<set<pair<int, int>>>> rhs; map<string, vector<bool>> diff, nonstationary; map<string, vector<expr_t>> lhs_expr_t; @@ -72,20 +72,19 @@ public: map<string, vector<int>> getNonTrendEqNums() const; vector<int> getNonTrendEqNums(const string &name_arg) const; vector<bool> getNonstationary(const string &name_arg) const; + vector<int> getNontrendLhs(const string &name_arg) const; + vector<int> getTrendLhs(const string &name_arg) const; - void setEqNums(map<string, vector<int>> eqnums_arg); - void setTrendEqNums(map<string, vector<int>> trend_eqnums_arg); - void setLhs(map<string, vector<int>> lhs_arg); + void setVals(map<string, vector<int>> eqnums_arg, map<string, vector<int>> trend_eqnums_arg, + map<string, vector<int>> lhs_arg, + map<string, vector<expr_t>> lhs_expr_t_arg, map<string, vector<bool>> nonstationary_arg); void setRhs(map<string, vector<set<pair<int, int>>>> rhs_arg); - void setLhsExprT(map<string, vector<expr_t>> lhs_expr_t_arg); void setMaxLags(map<string, vector<int>> max_lags_arg); void setDiff(map<string, vector<bool>> diff_arg); void setOrigDiffVar(map<string, vector<int>> orig_diff_var_arg); - void setNonstationary(map<string, vector<bool>> nonstationary_arg); void setTrendVar(map<string, vector<int>> trend_vars_arg); void setAR(map<string, map<tuple<int, int, int>, expr_t>> AR_arg); void setEC(map<string, map<tuple<int, int, int>, expr_t>> EC_arg); - void setNonTrendEqNums(map<string, vector<int>> trend_eqnums_arg); //! Write output of this class void writeOutput(const string &basename, ostream &output) const; -- GitLab