Commit fde836d7 authored by Houtan Bastani's avatar Houtan Bastani

trend_component_model: find trend_vars associated with equation

parent 52da3ecf
......@@ -3466,23 +3466,46 @@ DynamicModel::updateVarAndTrendModelRhs() const
else if (i == 1)
eqnums = trend_component_model_table.getEqNums();
map<string, vector<int>> trend_varr;
map<string, vector<set<pair<int, int>>>> rhsr;
for (const auto & it : eqnums)
{
vector<int> lhs;
vector<int> trend_var;
vector<set<pair<int, int>>> rhs;
int lhs_idx = 0;
if (i == 1)
lhs = trend_component_model_table.getLhs(it.first);
for (auto eqn : it.second)
{
set<pair<int, int>> rhs_set;
equations[eqn]->get_arg2()->collectDynamicVariables(SymbolType::endogenous, rhs_set);
rhs.push_back(rhs_set);
if (i == 1)
{
int lhs_symb_id = lhs[lhs_idx++];
if (symbol_table.isAuxiliaryVariable(lhs_symb_id))
try
{
lhs_symb_id = symbol_table.getOrigSymbIdForAuxVar(lhs_symb_id);
}
catch (...)
{
}
trend_var.push_back(equations[eqn]->get_arg2()->findTrendVariable(lhs_symb_id));
}
}
rhsr[it.first] = rhs;
trend_varr[it.first] = trend_var;
}
if (i == 0)
var_model_table.setRhs(rhsr);
else if (i == 1)
trend_component_model_table.setRhs(rhsr);
{
trend_component_model_table.setRhs(rhsr);
trend_component_model_table.setTrendVar(trend_varr);
}
}
}
......
......@@ -563,6 +563,12 @@ NumConstNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_
{
}
int
NumConstNode::findTrendVariable(int lhs_symb_id) const
{
return -1;
}
expr_t
NumConstNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
......@@ -1458,6 +1464,12 @@ VariableNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_
{
}
int
VariableNode::findTrendVariable(int lhs_symb_id) const
{
return -1;
}
expr_t
VariableNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
......@@ -3157,6 +3169,12 @@ UnaryOpNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table)
diff_table[sthis][arg_max_lag] = const_cast<UnaryOpNode *>(this);
}
int
UnaryOpNode::findTrendVariable(int lhs_symb_id) const
{
return arg->findTrendVariable(lhs_symb_id);
}
expr_t
UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
......@@ -5078,6 +5096,65 @@ BinaryOpNode::isInStaticForm() const
return arg1->isInStaticForm() && arg2->isInStaticForm();
}
bool
BinaryOpNode::findTrendVariableHelper1(int lhs_symb_id, int rhs_symb_id) const
{
if (lhs_symb_id == rhs_symb_id)
return true;
try
{
if (datatree.symbol_table.isAuxiliaryVariable(rhs_symb_id)
&& lhs_symb_id == datatree.symbol_table.getOrigSymbIdForAuxVar(rhs_symb_id))
return true;
}
catch (...)
{
}
return false;
}
int
BinaryOpNode::findTrendVariableHelper(const expr_t arg1, const expr_t arg2,
int lhs_symb_id) const
{
set<int> params;
arg1->collectVariables(SymbolType::parameter, params);
if (params.size() != 1)
return -1;
set<pair<int, int>> endogs;
arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
if (endogs.size() == 2)
{
auto *testarg2 = dynamic_cast<BinaryOpNode *>(arg2);
if (testarg2 != nullptr && testarg2->get_op_code() == BinaryOpcode::minus)
{
auto *test_arg1 = dynamic_cast<VariableNode *>(testarg2->get_arg1());
auto *test_arg2 = dynamic_cast<VariableNode *>(testarg2->get_arg2());
if (test_arg1 != nullptr && test_arg2 != nullptr )
if (findTrendVariableHelper1(lhs_symb_id, endogs.begin()->first))
return endogs.rbegin()->first;
else if (findTrendVariableHelper1(lhs_symb_id, endogs.rbegin()->first))
return endogs.begin()->first;
}
}
return -1;
}
int
BinaryOpNode::findTrendVariable(int lhs_symb_id) const
{
int retval = findTrendVariableHelper(arg1, arg2, lhs_symb_id);
if (retval < 0)
retval = findTrendVariableHelper(arg2, arg1, lhs_symb_id);
if (retval < 0)
retval = arg1->findTrendVariable(lhs_symb_id);
if (retval < 0)
retval = arg2->findTrendVariable(lhs_symb_id);
return retval;
}
void
BinaryOpNode::getPacOptimizingPartHelper(const expr_t arg1, const expr_t arg2,
pair<int, vector<int>> &ec_params_and_vars,
......@@ -6073,6 +6150,17 @@ TrinaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff
arg3->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
}
int
TrinaryOpNode::findTrendVariable(int lhs_symb_id) const
{
int retval = arg1->findTrendVariable(lhs_symb_id);
if (retval < 0)
retval = arg2->findTrendVariable(lhs_symb_id);
if (retval < 0)
retval = arg3->findTrendVariable(lhs_symb_id);
return retval;
}
expr_t
TrinaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
......@@ -6535,6 +6623,18 @@ AbstractExternalFunctionNode::findUnaryOpNodesForAuxVarCreation(DataTree &static
argument->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
}
int
AbstractExternalFunctionNode::findTrendVariable(int lhs_symb_id) const
{
for (auto argument : arguments)
{
int retval = argument->findTrendVariable(lhs_symb_id);
if (retval >= 0)
return retval;
}
return -1;
}
expr_t
AbstractExternalFunctionNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
......@@ -8168,6 +8268,12 @@ VarExpectationNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree,
{
}
int
VarExpectationNode::findTrendVariable(int lhs_symb_id) const
{
return -1;
}
expr_t
VarExpectationNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
......@@ -8689,6 +8795,12 @@ PacExpectationNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree,
{
}
int
PacExpectationNode::findTrendVariable(int lhs_symb_id) const
{
return -1;
}
expr_t
PacExpectationNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
......
......@@ -505,6 +505,7 @@ class ExprNode
//! Substitute diff operator
virtual void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const = 0;
virtual void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const = 0;
virtual int findTrendVariable(int lhs_symb_id) const = 0;
virtual expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
virtual expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
......@@ -617,6 +618,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -715,6 +717,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -837,6 +840,7 @@ public:
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
bool createAuxVarForUnaryOpNode() const;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -980,6 +984,9 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
bool findTrendVariableHelper1(int lhs_symb_id, int rhs_symb_id) const;
int findTrendVariableHelper(const expr_t arg1, const expr_t arg2, int lhs_symb_id) const;
int findTrendVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -1092,6 +1099,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -1210,6 +1218,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -1412,6 +1421,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......@@ -1507,6 +1517,7 @@ public:
expr_t substituteVarExpectation(const map<string, expr_t> &subst_table) const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
int findTrendVariable(int lhs_symb_id) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substitutePacExpectation(map<const PacExpectationNode *, const BinaryOpNode *> &subst_table) override;
......
......@@ -78,6 +78,12 @@ TrendComponentModelTable::setNonstationary(map<string, vector<bool>> nonstationa
nonstationary = move(nonstationary_arg);
}
void
TrendComponentModelTable::setTrendVar(map<string, vector<int>> trend_vars_arg)
{
trend_vars = move(trend_vars_arg);
}
void
TrendComponentModelTable::setLhs(map<string, vector<int>> lhs_arg)
{
......@@ -275,6 +281,10 @@ TrendComponentModelTable::writeOutput(ostream &output) const
i++;
}
output << "M_.trend_component." << name << ".trend_vars = [";
for (auto it : trend_vars.at(name))
output << (it >= 0 ? symbol_table.getTypeSpecificID(it) + 1 : -1) << " ";
output << "];" << endl;
}
}
......
......@@ -44,6 +44,7 @@ private:
map<string, vector<set<pair<int, int>>>> rhs;
map<string, vector<bool>> diff, nonstationary;
map<string, vector<expr_t>> lhs_expr_t;
map<string, vector<int>> trend_vars;
public:
TrendComponentModelTable(SymbolTable &symbol_table_arg);
......@@ -77,6 +78,7 @@ public:
void setDiff(map<string, vector<bool>> diff_arg);
void setOrigDiffVar(map<string, vector<int>> orig_diff_var_arg);
void setNonstationary(map<string, vector<bool>> nonstationary_arg);
void setTrendVar(map<string, vector<int>> trend_vars_arg);
//! Write output of this class
void writeOutput(ostream &output) const;
......
......@@ -865,7 +865,10 @@ int
SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false)
{
for (const auto & aux_var : aux_vars)
if ((aux_var.get_type() == AuxVarType::endoLag || aux_var.get_type() == AuxVarType::exoLag || aux_var.get_type() == AuxVarType::diff)
if ((aux_var.get_type() == AuxVarType::endoLag
|| aux_var.get_type() == AuxVarType::exoLag
|| aux_var.get_type() == AuxVarType::diff
|| aux_var.get_type() == AuxVarType::diffLag)
&& aux_var.get_symb_id() == aux_var_symb_id)
return aux_var.get_orig_symb_id();
throw UnknownSymbolIDException(aux_var_symb_id);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment