Commit eb74d483 authored by Houtan Bastani's avatar Houtan Bastani

trend_component_model: replace `trends` option with `targets`

parent dfb0629c
Pipeline #11 passed with stage
in 1 minute and 33 seconds
......@@ -3466,7 +3466,7 @@ DynamicModel::updateVarAndTrendModel() const
else if (i == 1)
{
eqnums = trend_component_model_table.getEqNums();
trend_eqnums = trend_component_model_table.getTrendEqNums();
trend_eqnums = trend_component_model_table.getTargetEqNums();
}
map<string, vector<int>> trend_varr;
......@@ -3509,7 +3509,7 @@ DynamicModel::updateVarAndTrendModel() const
catch (...)
{
}
int trend_var_symb_id = equations[eqn]->get_arg2()->findTrendVariable(lhs_symb_id);
int trend_var_symb_id = equations[eqn]->get_arg2()->findTargetVariable(lhs_symb_id);
if (trend_var_symb_id >= 0)
{
if (symbol_table.isAuxiliaryVariable(trend_var_symb_id))
......@@ -3541,7 +3541,7 @@ DynamicModel::updateVarAndTrendModel() const
else if (i == 1)
{
trend_component_model_table.setRhs(rhsr);
trend_component_model_table.setTrendVar(trend_varr);
trend_component_model_table.setTargetVar(trend_varr);
}
}
}
......@@ -3717,7 +3717,7 @@ void
DynamicModel::fillAutoregressiveMatrix(map<string, map<tuple<int, int, int>, expr_t>> &ARr, bool is_trend_component_model) const
{
auto eqnums = is_trend_component_model ?
trend_component_model_table.getNonTrendEqNums() : var_model_table.getEqNums();
trend_component_model_table.getNonTargetEqNums() : var_model_table.getEqNums();
for (const auto & it : eqnums)
{
int i = 0;
......@@ -3738,7 +3738,7 @@ DynamicModel::fillTrendComponentModelTable() const
map<string, vector<bool>> nonstationaryr;
map<string, vector<set<pair<int, int>>>> rhsr;
map<string, vector<string>> eqtags = trend_component_model_table.getEqTags();
map<string, vector<string>> trend_eqtags = trend_component_model_table.getTrendEqTags();
map<string, vector<string>> trend_eqtags = trend_component_model_table.getTargetEqTags();
for (const auto & it : trend_eqtags)
{
vector<int> trend_eqnumber;
......@@ -3854,8 +3854,8 @@ DynamicModel::fillErrorComponentMatrix(map<string, map<tuple<int, int, int>, exp
{
int i = 0;
map<tuple<int, int, int>, expr_t> EC;
vector<int> trend_lhs = trend_component_model_table.getTrendLhs(it.first);
vector<int> nontrend_eqnums = trend_component_model_table.getNonTrendEqNums(it.first);
vector<int> trend_lhs = trend_component_model_table.getTargetLhs(it.first);
vector<int> nontrend_eqnums = trend_component_model_table.getNonTargetEqNums(it.first);
vector<int> undiff_nontrend_lhs = getUndiffLHSForPac(it.first, diff_subst_table);
vector<int> parsed_undiff_nontrend_lhs;
......@@ -4018,7 +4018,7 @@ DynamicModel::getUndiffLHSForPac(const string &aux_model_name,
vector<bool> diff = trend_component_model_table.getDiff(aux_model_name);
vector<int> orig_diff_var = trend_component_model_table.getOrigDiffVar(aux_model_name);
vector<int> eqnumber = trend_component_model_table.getEqNums(aux_model_name);
vector<int> nontrend_eqnums = trend_component_model_table.getNonTrendEqNums(aux_model_name);
vector<int> nontrend_eqnums = trend_component_model_table.getNonTargetEqNums(aux_model_name);
for (auto eqn : nontrend_eqnums)
{
......
......@@ -111,7 +111,7 @@ class ParsingDriver;
%token PRINT PRIOR_MC PRIOR_TRUNC PRIOR_MODE PRIOR_MEAN POSTERIOR_MODE POSTERIOR_MEAN POSTERIOR_MEDIAN MLE_MODE PRUNING
%token <string> QUOTED_STRING
%token QZ_CRITERIUM QZ_ZERO_THRESHOLD DSGE_VAR DSGE_VARLAG DSGE_PRIOR_WEIGHT TRUNCATE PIPE_E PIPE_X PIPE_P
%token RELATIVE_IRF REPLIC SIMUL_REPLIC RPLOT SAVE_PARAMS_AND_STEADY_STATE PARAMETER_UNCERTAINTY TRENDS
%token RELATIVE_IRF REPLIC SIMUL_REPLIC RPLOT SAVE_PARAMS_AND_STEADY_STATE PARAMETER_UNCERTAINTY TARGETS
%token SHOCKS SHOCK_DECOMPOSITION SHOCK_GROUPS USE_SHOCK_GROUPS SIGMA_E SIMUL SIMUL_ALGO SIMUL_SEED ENDOGENOUS_TERMINAL_PERIOD
%token SMOOTHER SMOOTHER2HISTVAL SQUARE_ROOT_SOLVER STACK_SOLVE_ALGO STEADY_STATE_MODEL SOLVE_ALGO SOLVER_PERIODS ROBUST_LIN_SOLVE
%token STDERR STEADY STOCH_SIMUL SYLVESTER SYLVESTER_FIXED_POINT_TOL REGIMES REGIME REALTIME_SHOCK_DECOMPOSITION
......@@ -385,7 +385,7 @@ trend_component_model_options_list : trend_component_model_options_list COMMA tr
;
trend_component_model_options : o_trend_component_model_name
| o_trend_component_model_trends
| o_trend_component_model_targets
| o_trend_component_model_eq_tags
;
......@@ -3209,7 +3209,7 @@ o_nobs : NOBS EQUAL vec_int
{ driver.option_vec_int("nobs", $3); }
;
o_trend_component_model_name : MODEL_NAME EQUAL symbol { driver.option_str("trend_component.name", $3); };
o_trend_component_model_trends : TRENDS EQUAL vec_str { driver.option_vec_str("trend_component.trends", $3); }
o_trend_component_model_targets : TARGETS EQUAL vec_str { driver.option_vec_str("trend_component.targets", $3); }
o_trend_component_model_eq_tags : EQTAGS EQUAL vec_str { driver.option_vec_str("trend_component.eqtags", $3); }
o_conditional_variance_decomposition : CONDITIONAL_VARIANCE_DECOMPOSITION EQUAL vec_int
{ driver.option_vec_int("conditional_variance_decomposition", $3); }
......
......@@ -490,7 +490,7 @@ DATE -?[0-9]+([YyAa]|[Mm]([1-9]|1[0-2])|[Qq][1-4]|[Ww]([1-9]{1}|[1-4][0-9]|5[0-2
}
<DYNARE_STATEMENT>write_equation_tags {return token::WRITE_EQUATION_TAGS;}
<DYNARE_STATEMENT>eqtags {return token::EQTAGS;}
<DYNARE_STATEMENT>trends {return token::TRENDS;}
<DYNARE_STATEMENT>targets {return token::TARGETS;}
<DYNARE_STATEMENT>indxap {return token::INDXAP;}
<DYNARE_STATEMENT>apband {return token::APBAND;}
<DYNARE_STATEMENT>indximf {return token::INDXIMF;}
......
......@@ -557,7 +557,7 @@ NumConstNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_
}
int
NumConstNode::findTrendVariable(int lhs_symb_id) const
NumConstNode::findTargetVariable(int lhs_symb_id) const
{
return -1;
}
......@@ -1506,7 +1506,7 @@ VariableNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_
}
int
VariableNode::findTrendVariable(int lhs_symb_id) const
VariableNode::findTargetVariable(int lhs_symb_id) const
{
return -1;
}
......@@ -3219,9 +3219,9 @@ UnaryOpNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table)
}
int
UnaryOpNode::findTrendVariable(int lhs_symb_id) const
UnaryOpNode::findTargetVariable(int lhs_symb_id) const
{
return arg->findTrendVariable(lhs_symb_id);
return arg->findTargetVariable(lhs_symb_id);
}
expr_t
......@@ -5146,7 +5146,7 @@ BinaryOpNode::isInStaticForm() const
}
bool
BinaryOpNode::findTrendVariableHelper1(int lhs_symb_id, int rhs_symb_id) const
BinaryOpNode::findTargetVariableHelper1(int lhs_symb_id, int rhs_symb_id) const
{
if (lhs_symb_id == rhs_symb_id)
return true;
......@@ -5164,8 +5164,8 @@ BinaryOpNode::findTrendVariableHelper1(int lhs_symb_id, int rhs_symb_id) const
}
int
BinaryOpNode::findTrendVariableHelper(const expr_t arg1, const expr_t arg2,
int lhs_symb_id) const
BinaryOpNode::findTargetVariableHelper(const expr_t arg1, const expr_t arg2,
int lhs_symb_id) const
{
set<int> params;
arg1->collectVariables(SymbolType::parameter, params);
......@@ -5182,9 +5182,9 @@ BinaryOpNode::findTrendVariableHelper(const expr_t arg1, const expr_t arg2,
auto *test_arg1 = dynamic_cast<VariableNode *>(testarg2->get_arg1());
auto *test_arg2 = dynamic_cast<VariableNode *>(testarg2->get_arg2());
if (test_arg1 != nullptr && test_arg2 != nullptr )
if (findTrendVariableHelper1(lhs_symb_id, endogs.begin()->first))
if (findTargetVariableHelper1(lhs_symb_id, endogs.begin()->first))
return endogs.rbegin()->first;
else if (findTrendVariableHelper1(lhs_symb_id, endogs.rbegin()->first))
else if (findTargetVariableHelper1(lhs_symb_id, endogs.rbegin()->first))
return endogs.begin()->first;
}
}
......@@ -5192,15 +5192,15 @@ BinaryOpNode::findTrendVariableHelper(const expr_t arg1, const expr_t arg2,
}
int
BinaryOpNode::findTrendVariable(int lhs_symb_id) const
BinaryOpNode::findTargetVariable(int lhs_symb_id) const
{
int retval = findTrendVariableHelper(arg1, arg2, lhs_symb_id);
int retval = findTargetVariableHelper(arg1, arg2, lhs_symb_id);
if (retval < 0)
retval = findTrendVariableHelper(arg2, arg1, lhs_symb_id);
retval = findTargetVariableHelper(arg2, arg1, lhs_symb_id);
if (retval < 0)
retval = arg1->findTrendVariable(lhs_symb_id);
retval = arg1->findTargetVariable(lhs_symb_id);
if (retval < 0)
retval = arg2->findTrendVariable(lhs_symb_id);
retval = arg2->findTargetVariable(lhs_symb_id);
return retval;
}
......@@ -6341,13 +6341,13 @@ TrinaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff
}
int
TrinaryOpNode::findTrendVariable(int lhs_symb_id) const
TrinaryOpNode::findTargetVariable(int lhs_symb_id) const
{
int retval = arg1->findTrendVariable(lhs_symb_id);
int retval = arg1->findTargetVariable(lhs_symb_id);
if (retval < 0)
retval = arg2->findTrendVariable(lhs_symb_id);
retval = arg2->findTargetVariable(lhs_symb_id);
if (retval < 0)
retval = arg3->findTrendVariable(lhs_symb_id);
retval = arg3->findTargetVariable(lhs_symb_id);
return retval;
}
......@@ -6831,11 +6831,11 @@ AbstractExternalFunctionNode::findUnaryOpNodesForAuxVarCreation(DataTree &static
}
int
AbstractExternalFunctionNode::findTrendVariable(int lhs_symb_id) const
AbstractExternalFunctionNode::findTargetVariable(int lhs_symb_id) const
{
for (auto argument : arguments)
{
int retval = argument->findTrendVariable(lhs_symb_id);
int retval = argument->findTargetVariable(lhs_symb_id);
if (retval >= 0)
return retval;
}
......@@ -8488,7 +8488,7 @@ VarExpectationNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree,
}
int
VarExpectationNode::findTrendVariable(int lhs_symb_id) const
VarExpectationNode::findTargetVariable(int lhs_symb_id) const
{
return -1;
}
......@@ -9033,7 +9033,7 @@ PacExpectationNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree,
}
int
PacExpectationNode::findTrendVariable(int lhs_symb_id) const
PacExpectationNode::findTargetVariable(int lhs_symb_id) const
{
return -1;
}
......
......@@ -523,7 +523,7 @@ class ExprNode
//! Substitute diff operator
virtual void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const = 0;
virtual void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const = 0;
virtual int findTrendVariable(int lhs_symb_id) const = 0;
virtual int findTargetVariable(int lhs_symb_id) const = 0;
virtual expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
virtual expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
......@@ -643,7 +643,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -744,7 +744,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -869,7 +869,7 @@ public:
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
bool createAuxVarForUnaryOpNode() const;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -1014,9 +1014,9 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
bool findTrendVariableHelper1(int lhs_symb_id, int rhs_symb_id) const;
int findTrendVariableHelper(const expr_t arg1, const expr_t arg2, int lhs_symb_id) const;
int findTrendVariable(int lhs_symb_id) const override;
bool findTargetVariableHelper1(int lhs_symb_id, int rhs_symb_id) const;
int findTargetVariableHelper(const expr_t arg1, const expr_t arg2, int lhs_symb_id) const;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -1136,7 +1136,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -1257,7 +1257,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -1462,7 +1462,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -1560,7 +1560,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
int findTargetVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......
......@@ -1489,12 +1489,12 @@ ParsingDriver::trend_component_model()
error("You must pass the eqtags option to the trend_component_model statement.");
auto eqtags = itvs->second;
const auto itvs1 = options_list.vector_str_options.find("trend_component.trends");
const auto itvs1 = options_list.vector_str_options.find("trend_component.targets");
if (itvs1 == options_list.vector_str_options.end())
error("You must pass the trends option to the trend_component_model statement.");
auto trends = itvs1->second;
error("You must pass the targets option to the trend_component_model statement.");
auto targets = itvs1->second;
mod_file->trend_component_model_table.addTrendComponentModel(name, eqtags, trends);
mod_file->trend_component_model_table.addTrendComponentModel(name, eqtags, targets);
options_list.clear();
}
......
......@@ -29,7 +29,7 @@ TrendComponentModelTable::TrendComponentModelTable(SymbolTable &symbol_table_arg
void
TrendComponentModelTable::addTrendComponentModel(string name_arg,
vector<string> eqtags_arg,
vector<string> trend_eqtags_arg)
vector<string> target_eqtags_arg)
{
if (isExistingTrendComponentModelName(name_arg))
{
......@@ -37,17 +37,17 @@ TrendComponentModelTable::addTrendComponentModel(string name_arg,
exit(EXIT_FAILURE);
}
eqtags[name_arg] = move(eqtags_arg);
trend_eqtags[name_arg] = move(trend_eqtags_arg);
target_eqtags[name_arg] = move(target_eqtags_arg);
names.insert(move(name_arg));
}
void
TrendComponentModelTable::setVals(map<string, vector<int>> eqnums_arg, map<string, vector<int>> trend_eqnums_arg,
TrendComponentModelTable::setVals(map<string, vector<int>> eqnums_arg, map<string, vector<int>> target_eqnums_arg,
map<string, vector<int>> lhs_arg,
map<string, vector<expr_t>> lhs_expr_t_arg, map<string, vector<bool>> nonstationary_arg)
{
eqnums = move(eqnums_arg);
trend_eqnums = move(trend_eqnums_arg);
target_eqnums = move(target_eqnums_arg);
lhs = move(lhs_arg);
lhs_expr_t = move(lhs_expr_t_arg);
nonstationary = move(nonstationary_arg);
......@@ -56,23 +56,23 @@ TrendComponentModelTable::setVals(map<string, vector<int>> eqnums_arg, map<strin
{
vector<int> nontrend_vec;
for (auto eq : it.second)
if (find(trend_eqnums[it.first].begin(), trend_eqnums[it.first].end(), eq) == trend_eqnums[it.first].end())
if (find(target_eqnums[it.first].begin(), target_eqnums[it.first].end(), eq) == target_eqnums[it.first].end())
nontrend_vec.push_back(eq);
nontrend_eqnums[it.first] = nontrend_vec;
nontarget_eqnums[it.first] = nontrend_vec;
}
for (const auto &name : names)
{
vector<int> nontrend_lhs_vec, trend_lhs_vec;
vector<int> nontarget_lhs_vec, target_lhs_vec;
vector<int> lhsv = getLhs(name);
vector<int> eqnumsv = getEqNums(name);
for (int nontrend_it : getNonTrendEqNums(name))
nontrend_lhs_vec.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), nontrend_it))));
nontrend_lhs[name] = nontrend_lhs_vec;
for (int nontrend_it : getNonTargetEqNums(name))
nontarget_lhs_vec.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), nontrend_it))));
nontarget_lhs[name] = nontarget_lhs_vec;
for (int trend_it : getTrendEqNums(name))
trend_lhs_vec.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), trend_it))));
trend_lhs[name] = trend_lhs_vec;
for (int trend_it : getTargetEqNums(name))
target_lhs_vec.push_back(lhsv.at(distance(eqnumsv.begin(), find(eqnumsv.begin(), eqnumsv.end(), trend_it))));
target_lhs[name] = target_lhs_vec;
}
}
......@@ -83,9 +83,9 @@ TrendComponentModelTable::setRhs(map<string, vector<set<pair<int, int>>>> rhs_ar
}
void
TrendComponentModelTable::setTrendVar(map<string, vector<int>> trend_vars_arg)
TrendComponentModelTable::setTargetVar(map<string, vector<int>> target_vars_arg)
{
trend_vars = move(trend_vars_arg);
target_vars = move(target_vars_arg);
}
void
......@@ -160,14 +160,14 @@ vector<int>
TrendComponentModelTable::getNontrendLhs(const string &name_arg) const
{
checkModelName(name_arg);
return nontrend_lhs.find(name_arg)->second;
return nontarget_lhs.find(name_arg)->second;
}
vector<int>
TrendComponentModelTable::getTrendLhs(const string &name_arg) const
TrendComponentModelTable::getTargetLhs(const string &name_arg) const
{
checkModelName(name_arg);
return trend_lhs.find(name_arg)->second;
return target_lhs.find(name_arg)->second;
}
vector<int>
......@@ -185,9 +185,9 @@ TrendComponentModelTable::getLhsExprT(const string &name_arg) const
}
map<string, vector<string>>
TrendComponentModelTable::getTrendEqTags() const
TrendComponentModelTable::getTargetEqTags() const
{
return trend_eqtags;
return target_eqtags;
}
map<string, vector<int>>
......@@ -197,29 +197,29 @@ TrendComponentModelTable::getEqNums() const
}
map<string, vector<int>>
TrendComponentModelTable::getTrendEqNums() const
TrendComponentModelTable::getTargetEqNums() const
{
return trend_eqnums;
return target_eqnums;
}
vector<int>
TrendComponentModelTable::getTrendEqNums(const string &name_arg) const
TrendComponentModelTable::getTargetEqNums(const string &name_arg) const
{
checkModelName(name_arg);
return trend_eqnums.find(name_arg)->second;
return target_eqnums.find(name_arg)->second;
}
map<string, vector<int>>
TrendComponentModelTable::getNonTrendEqNums() const
TrendComponentModelTable::getNonTargetEqNums() const
{
return nontrend_eqnums;
return nontarget_eqnums;
}
vector<int>
TrendComponentModelTable::getNonTrendEqNums(const string &name_arg) const
TrendComponentModelTable::getNonTargetEqNums(const string &name_arg) const
{
checkModelName(name_arg);
return nontrend_eqnums.find(name_arg)->second;
return nontarget_eqnums.find(name_arg)->second;
}
vector<int>
......@@ -285,14 +285,14 @@ TrendComponentModelTable::writeOutput(const string &basename, ostream &output) c
for (auto it : eqnums.at(name))
output << it + 1 << " ";
output << "];" << endl
<< "M_.trend_component." << name << ".trend_eqn = [";
for (auto it : trend_eqnums.at(name))
<< "M_.trend_component." << name << ".target_eqn = [";
for (auto it : target_eqnums.at(name))
output << it + 1 << " ";
output << "];" << endl
<< "M_.trend_component." << name << ".trends = [";
<< "M_.trend_component." << name << ".targets = [";
for (auto it : eqnums.at(name))
if (find(trend_eqnums.at(name).begin(), trend_eqnums.at(name).end(), it)
== trend_eqnums.at(name).end())
if (find(target_eqnums.at(name).begin(), target_eqnums.at(name).end(), it)
== target_eqnums.at(name).end())
output << "false ";
else
output << "true ";
......@@ -331,29 +331,29 @@ TrendComponentModelTable::writeOutput(const string &basename, ostream &output) c
i++;
}
output << "M_.trend_component." << name << ".trend_vars = [";
for (auto it : trend_vars.at(name))
output << "M_.trend_component." << name << ".target_vars = [";
for (auto it : target_vars.at(name))
output << (it >= 0 ? symbol_table.getTypeSpecificID(it) + 1 : -1) << " ";
output << "];" << endl;
vector<int> trend_lhs_vec = getTrendLhs(name);
vector<int> nontrend_lhs_vec = getNontrendLhs(name);
vector<int> target_lhs_vec = getTargetLhs(name);
vector<int> nontarget_lhs_vec = getNontrendLhs(name);
ar_ec_output << "if strcmp(model_name, '" << name << "')" << endl
<< " % AR" << endl
<< " ar = zeros(" << nontrend_lhs_vec.size() << ", " << nontrend_lhs_vec.size() << ", " << getMaxLag(name) << ");" << endl;
<< " ar = zeros(" << nontarget_lhs_vec.size() << ", " << nontarget_lhs_vec.size() << ", " << getMaxLag(name) << ");" << endl;
for (const auto & it : AR.at(name))
{
int eqn, lag, lhs_symb_id;
tie (eqn, lag, lhs_symb_id) = it.first;
int colidx = (int) distance(nontrend_lhs_vec.begin(), find(nontrend_lhs_vec.begin(), nontrend_lhs_vec.end(), lhs_symb_id));
int colidx = (int) distance(nontarget_lhs_vec.begin(), find(nontarget_lhs_vec.begin(), nontarget_lhs_vec.end(), lhs_symb_id));
ar_ec_output << " ar(" << eqn + 1 << ", " << colidx + 1 << ", " << lag << ") = ";
it.second->writeOutput(ar_ec_output, ExprNodeOutputType::matlabDynamicModel);
ar_ec_output << ";" << endl;
}
ar_ec_output << endl
<< " % EC" << endl
<< " ec = zeros(" << nontrend_lhs_vec.size() << ", " << nontrend_lhs_vec.size() << ", 1);" << endl;
<< " ec = zeros(" << nontarget_lhs_vec.size() << ", " << nontarget_lhs_vec.size() << ", 1);" << endl;
for (const auto & it : EC.at(name))
{
int eqn, lag, colidx;
......@@ -386,11 +386,11 @@ TrendComponentModelTable::writeJsonOutput(ostream &output) const
if (&it != &eqtags.at(name).back())
output << ", ";
}
output << "], \"trend_eqtags\": [";
for (const auto &it : trend_eqtags.at(name))
output << "], \"target_eqtags\": [";
for (const auto &it : target_eqtags.at(name))
{
output << "\"" << it << "\"";
if (&it != &trend_eqtags.at(name).back())
if (&it != &target_eqtags.at(name).back())
output << ", ";
}
output << "]}";
......
......@@ -11,7 +11,7 @@
* Dynare is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
* GNU General Public License for more details.SS
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
......@@ -39,12 +39,12 @@ class TrendComponentModelTable
private:
SymbolTable &symbol_table;
set<string> names;
map<string, vector<string>> eqtags, trend_eqtags;
map<string, vector<int>> eqnums, trend_eqnums, nontrend_eqnums, max_lags, lhs, trend_lhs, nontrend_lhs, orig_diff_var;
map<string, vector<string>> eqtags, target_eqtags;
map<string, vector<int>> eqnums, target_eqnums, nontarget_eqnums, max_lags, lhs, target_lhs, nontarget_lhs, orig_diff_var;
map<string, vector<set<pair<int, int>>>> rhs;
map<string, vector<bool>> diff, nonstationary;
map<string, vector<expr_t>> lhs_expr_t;
map<string, vector<int>> trend_vars;
map<string, vector<int>> target_vars;
map<string, map<tuple<int, int, int>, expr_t>> AR; // AR: name -> (eqn, lag, lhs_symb_id) -> expr_t
map<string, map<tuple<int, int, int>, expr_t>> EC; // EC: name -> (eqn, lag, col) -> expr_t
public:
......@@ -52,17 +52,17 @@ public:
//! Add a trend component model
void addTrendComponentModel(string name_arg, vector<string> eqtags_arg,
vector<string> trend_eqtags_arg);
vector<string> target_eqtags_arg);
inline bool isExistingTrendComponentModelName(const string &name_arg) const;
inline bool empty() const;
map<string, vector<string>> getEqTags() const;
vector<string> getEqTags(const string &name_arg) const;
map<string, vector<string>> getTrendEqTags() const;
map<string, vector<string>> getTargetEqTags() const;
map<string, vector<int>> getEqNums() const;
map<string, vector<int>> getTrendEqNums() const;
vector<int> getTrendEqNums(const string &name_arg) const;
map<string, vector<int>> getTargetEqNums() const;
vector<int> getTargetEqNums(const string &name_arg) const;
vector<int> getEqNums(const string &name_arg) const;
vector<int> getMaxLags(const string &name_arg) const;
int getMaxLag(const string &name_arg) const;
......@@ -70,20 +70,20 @@ public:
vector<expr_t> getLhsExprT(const string &name_arg) const;
vector<bool> getDiff(const string &name_arg) const;
vector<int> getOrigDiffVar(const string &name_arg) const;
map<string, vector<int>> getNonTrendEqNums() const;
vector<int> getNonTrendEqNums(const string &name_arg) const;
map<string, vector<int>> getNonTargetEqNums() const;
vector<int> getNonTargetEqNums(const string &name_arg) const;
vector<bool> getNonstationary(const string &name_arg) const;
vector<int> getNontrendLhs(const string &name_arg) const;
vector<int> getTrendLhs(const string &name_arg) const;
vector<int> getTargetLhs(const string &name_arg) const;
void setVals(map<string, vector<int>> eqnums_arg, map<string, vector<int>> trend_eqnums_arg,
void setVals(map<string, vector<int>> eqnums_arg, map<string, vector<int>> target_eqnums_arg,
map<string, vector<int>> lhs_arg,
map<string, vector<expr_t>> lhs_expr_t_arg, map<string, vector<bool>> nonstationary_arg);
void setRhs(map<string, vector<set<pair<int, int>>>> rhs_arg);
void setMaxLags(map<string, vector<int>> max_lags_arg);
void setDiff(map<string, vector<bool>> diff_arg);
void setOrigDiffVar(map<string, vector<int>> orig_diff_var_arg);
void setTrendVar(map<string, vector<int>> trend_vars_arg);
void setTargetVar(map<string, vector<int>> target_vars_arg);
void setAR(map<string, map<tuple<int, int, int>, expr_t>> AR_arg);
void setEC(map<string, map<tuple<int, int, int>, expr_t>> EC_arg);
......@@ -95,7 +95,7 @@ public:
private:
void checkModelName(const string &name_arg) const;
void setNonTrendEqnums();
void setNonTargetEqnums();