Commit 84c2dc5f authored by Houtan Bastani's avatar Houtan Bastani
Browse files

transform_unary_ops now introduces aux variables/equations for all unary ops...

transform_unary_ops now introduces aux variables/equations for all unary ops specified by UnaryOpNode::createAuxVarForUnaryOpNode()

In the absence of this option, if a var_model statement(s) is present, then aux vars/eqs are created for the same types of unary operators but only for equations specified in the var_model statement

In the absence of both this option and var_model statements, no unary op auxiliary variables are created

diffs continue to be substituted everywhere; for the moment auxiliary variables are created for diffs of expressions. A forthcoming change will allow auxiliary variables created for diffs of expressions to be linked with their lagged expressions as is currently the case for diffs of variables
parent c51487b9
......@@ -25,6 +25,7 @@
#include <cerrno>
#include <algorithm>
#include <iterator>
#include <numeric>
#include "DynamicModel.hh"
// For mkdir() and chdir()
......@@ -5403,26 +5404,40 @@ DynamicModel::findPacExpectationEquationNumbers(vector<int> &eqnumbers) const
}
}
void
DynamicModel::substituteUnaryOps(StaticModel &static_model)
{
vector<int> eqnumbers(equations.size());
iota(eqnumbers.begin(), eqnumbers.end(), 0);
substituteUnaryOps(static_model, eqnumbers);
}
void
DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_model_eqtags)
{
vector<int> eqnumbers;
getEquationNumbersFromTags(eqnumbers, var_model_eqtags);
findPacExpectationEquationNumbers(eqnumbers);
substituteUnaryOps(static_model, eqnumbers);
}
void
DynamicModel::substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbers)
{
diff_table_t nodes;
vector<int> eqnumber;
getEquationNumbersFromTags(eqnumber, var_model_eqtags);
findPacExpectationEquationNumbers(eqnumber);
// Find matching unary ops that may be outside of diffs (i.e., those with different lags)
set<int> used_local_vars;
for (int eqnn : eqnumber)
equations[eqnn]->collectVariables(eModelLocalVariable, used_local_vars);
for (int eqnumber : eqnumbers)
equations[eqnumber]->collectVariables(eModelLocalVariable, used_local_vars);
// Only substitute unary ops in model local variables that appear in VAR equations
for (auto & it : local_variables_table)
if (used_local_vars.find(it.first) != used_local_vars.end())
it.second->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
for (int eqnn : eqnumber)
equations[eqnn]->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
for (int eqnumber : eqnumbers)
equations[eqnumber]->findUnaryOpNodesForAuxVarCreation(static_model, nodes);
// Substitute in model local variables
ExprNode::subst_table_t subst_table;
......@@ -5434,7 +5449,7 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_mod
for (auto & equation : equations)
{
auto *substeq = dynamic_cast<BinaryOpNode *>(equation->
substituteUnaryOpNodes(static_model, nodes, subst_table, neweqs));
substituteUnaryOpNodes(static_model, nodes, subst_table, neweqs));
assert(substeq != nullptr);
equation = substeq;
}
......@@ -5450,15 +5465,11 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_mod
}
void
DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table, set<string> &var_model_eqtags)
DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table)
{
vector<int> eqnumbers;
getEquationNumbersFromTags(eqnumbers, var_model_eqtags);
findPacExpectationEquationNumbers(eqnumbers);
set<int> used_local_vars;
for (int eqnumber : eqnumbers)
equations[eqnumber]->collectVariables(eModelLocalVariable, used_local_vars);
for (const auto & equation : equations)
equation->collectVariables(eModelLocalVariable, used_local_vars);
// Only substitute diffs in model local variables that appear in VAR equations
diff_table_t diff_table;
......@@ -5466,8 +5477,8 @@ DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t
if (used_local_vars.find(it.first) != used_local_vars.end())
it.second->findDiffNodes(static_model, diff_table);
for (int eqnumber : eqnumbers)
equations[eqnumber]->findDiffNodes(static_model, diff_table);
for (const auto & equation : equations)
equation->findDiffNodes(static_model, diff_table);
// Substitute in model local variables
vector<BinaryOpNode *> neweqs;
......@@ -5478,7 +5489,7 @@ DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t
for (auto & equation : equations)
{
auto *substeq = dynamic_cast<BinaryOpNode *>(equation->
substituteDiff(static_model, diff_table, diff_subst_table, neweqs));
substituteDiff(static_model, diff_table, diff_subst_table, neweqs));
assert(substeq != nullptr);
equation = substeq;
}
......
......@@ -423,11 +423,17 @@ public:
//! Substitutes adl operator
void substituteAdl();
//! Creates aux vars for all unary operators
void substituteUnaryOps(StaticModel &static_model);
//! Creates aux vars for certain unary operators: originally implemented for support of VARs
void substituteUnaryOps(StaticModel &static_model, set<string> &eq_tags);
//! Creates aux vars for certain unary operators: originally implemented for support of VARs
void substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbers);
//! Substitutes diff operator
void substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table, set<string> &var_model_eqtags);
void substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table);
//! Table to undiff LHS variables for pac vector z
void getUndiffLHSForPac(vector<int> &lhs, vector<expr_t> &lhs_expr_t, vector<bool> &diff, vector<int> &orig_diff_var,
......
......@@ -3045,7 +3045,7 @@ UnaryOpNode::countDiffs() const
}
bool
UnaryOpNode::createAuxVarForUnaryOpNodeInDiffOp() const
UnaryOpNode::createAuxVarForUnaryOpNode() const
{
switch (op_code)
{
......@@ -3077,14 +3077,14 @@ UnaryOpNode::createAuxVarForUnaryOpNodeInDiffOp() const
void
UnaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const
{
if (!this->createAuxVarForUnaryOpNodeInDiffOp())
{
arg->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
return;
}
arg->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes);
if (!this->createAuxVarForUnaryOpNode())
return;
expr_t sthis = this->toStatic(static_datatree);
int arg_max_lag = -arg->maxLag();
// TODO: implement recursive expression comparison, ensuring that the difference in the lags is constant across nodes
auto it = nodes.find(sthis);
if (it != nodes.end())
{
......@@ -3101,13 +3101,14 @@ UnaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_t
void
UnaryOpNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const
{
arg->findDiffNodes(static_datatree, diff_table);
if (op_code != oDiff)
return;
arg->findDiffNodes(static_datatree, diff_table);
expr_t sthis = this->toStatic(static_datatree);
int arg_max_lag = -arg->maxLag();
// TODO: implement recursive expression comparison, ensuring that the difference in the lags is constant across nodes
auto it = diff_table.find(sthis);
if (it != diff_table.end())
{
......@@ -3125,11 +3126,9 @@ expr_t
UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
{
expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
if (op_code != oDiff)
{
expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
return buildSimilarUnaryOpNode(argsubst, datatree);
}
return buildSimilarUnaryOpNode(argsubst, datatree);
subst_table_t::const_iterator sit = subst_table.find(this);
if (sit != subst_table.end())
......@@ -3137,13 +3136,19 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
expr_t sthis = dynamic_cast<UnaryOpNode *>(this->toStatic(static_datatree));
auto it = diff_table.find(sthis);
int symb_id;
if (it == diff_table.end() || it->second[-arg->maxLag()] != this)
{
// diff does not appear in VAR equations
// so simply substitute diff(x) with x-x(-1)
expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
return dynamic_cast<BinaryOpNode *>(datatree.AddMinus(argsubst,
argsubst->decreaseLeadsLags(1)));
// so simply create aux var and return
// Once the comparison of expression nodes works, come back and remove this part, folding into the next loop.
symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst);
VariableNode *aux_var = datatree.AddVariable(symb_id, 0);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var,
datatree.AddMinus(argsubst,
argsubst->decreaseLeadsLags(1)))));
subst_table[this] = dynamic_cast<VariableNode *>(aux_var);
return const_cast<VariableNode *>(subst_table[this]);
}
int last_arg_max_lag = 0;
......@@ -3153,19 +3158,13 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table,
{
expr_t argsubst = dynamic_cast<UnaryOpNode *>(rit->second)->
get_arg()->substituteDiff(static_datatree, diff_table, subst_table, neweqs);
int symb_id;
auto *vn = dynamic_cast<VariableNode *>(argsubst);
if (rit == it->second.rbegin())
{
if (vn != nullptr)
symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst, vn->get_symb_id(), vn->get_lag());
else
{
// We know that the supported unary ops have already been substituted
cerr << "ERROR: You can only use the `diff` operator on variables and certain unary ops." << endl
<< " Try passing the `transform_unary_ops` option on the dynare command line." << endl;
exit(EXIT_FAILURE);
}
symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst);
// make originating aux var & equation
last_arg_max_lag = rit->first;
......@@ -3210,35 +3209,30 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod
auto *sthis = dynamic_cast<UnaryOpNode *>(this->toStatic(static_datatree));
auto it = nodes.find(sthis);
expr_t argsubst = arg->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
if (it == nodes.end())
{
expr_t argsubst = arg->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs);
return buildSimilarUnaryOpNode(argsubst, datatree);
}
return buildSimilarUnaryOpNode(argsubst, datatree);
int base_aux_lag;
VariableNode *aux_var = nullptr;
for (auto rit = it->second.rbegin();
rit != it->second.rend(); rit++)
for (auto rit = it->second.rbegin(); rit != it->second.rend(); rit++)
if (rit == it->second.rbegin())
{
auto *vn = dynamic_cast<VariableNode *>(const_cast<UnaryOpNode *>(this)->get_arg());
int symb_id;
auto *vn = dynamic_cast<VariableNode *>(argsubst);
if (vn == nullptr)
{
cerr << "ERROR: You can only use a unary op on a variable node or another unary op node within a VAR." << endl;
exit(EXIT_FAILURE);
}
int symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, const_cast<UnaryOpNode *>(this),
symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second));
else
symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second),
vn->get_symb_id(), vn->get_lag());
aux_var = datatree.AddVariable(symb_id, 0);
neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var,
dynamic_cast<UnaryOpNode *>(rit->second))));
subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var);
base_aux_lag = rit->first;
}
else
{
auto *vn = dynamic_cast<VariableNode *>(dynamic_cast<UnaryOpNode *>(rit->second)->get_arg());
subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(-vn->get_lag()));
}
subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(base_aux_lag - rit->first));
sit = subst_table.find(this);
return const_cast<VariableNode *>(sit->second);
......
......@@ -811,7 +811,7 @@ public:
expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override;
expr_t substituteAdl() const override;
void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override;
bool createAuxVarForUnaryOpNodeInDiffOp() const;
bool createAuxVarForUnaryOpNode() const;
void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override;
expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override;
......
......@@ -381,12 +381,14 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
}
if (transform_unary_ops)
dynamic_model.substituteUnaryOps(diff_static_model);
else
// substitute only those unary ops that appear in VAR equations
dynamic_model.substituteUnaryOps(diff_static_model, eqtags);
// Create auxiliary variable and equations for Diff operators that appear in VAR equations
ExprNode::subst_table_t diff_subst_table;
dynamic_model.substituteDiff(diff_static_model, diff_subst_table, eqtags);
dynamic_model.substituteDiff(diff_static_model, diff_subst_table);
// Var Model
map<string, tuple<vector<int>, vector<expr_t>, vector<bool>, vector<int>, int, vector<bool>, vector<int>>>
......
......@@ -354,10 +354,14 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
case avEndoLag:
case avExoLag:
case avVarModel:
case avUnaryOp:
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
break;
case avUnaryOp:
if (aux_vars[i].get_orig_symb_id() >= 0)
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
break;
case avMultiplier:
output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].get_equation_number_for_multiplier() + 1 << ";" << endl;
break;
......@@ -479,10 +483,14 @@ SymbolTable::writeCOutput(ostream &output) const noexcept(false)
case avEndoLag:
case avExoLag:
case avVarModel:
case avUnaryOp:
output << "av[" << i << "].orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl
<< "av[" << i << "].orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
break;
case avUnaryOp:
if (aux_vars[i].get_orig_symb_id() >= 0)
output << "av[" << i << "].orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl
<< "av[" << i << "].orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
break;
case avDiff:
case avDiffLag:
if (aux_vars[i].get_orig_symb_id() >= 0)
......@@ -579,10 +587,14 @@ SymbolTable::writeCCOutput(ostream &output) const noexcept(false)
case avEndoLag:
case avExoLag:
case avVarModel:
case avUnaryOp:
output << "av" << i << ".orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl
<< "av" << i << ".orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
break;
case avUnaryOp:
if (aux_vars[i].get_orig_symb_id() >= 0)
output << "av" << i << ".orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl
<< "av" << i << ".orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
break;
case avDiff:
case avDiffLag:
if (aux_vars[i].get_orig_symb_id() >= 0)
......@@ -1098,10 +1110,16 @@ SymbolTable::writeJuliaOutput(ostream &output) const noexcept(false)
case avEndoLag:
case avExoLag:
case avVarModel:
case avUnaryOp:
output << getTypeSpecificID(aux_var.get_orig_symb_id()) + 1 << ", "
<< aux_var.get_orig_lead_lag() << ", typemin(Int), string()";
break;
case avUnaryOp:
if (aux_var.get_orig_symb_id() >= 0)
output << getTypeSpecificID(aux_var.get_orig_symb_id()) + 1 << ", " << aux_var.get_orig_lead_lag();
else
output << "typemin(Int), typemin(Int)";
output << ", typemin(Int), string()";
break;
case avDiff:
case avDiffLag:
if (aux_var.get_orig_symb_id() >= 0)
......
......@@ -295,7 +295,7 @@ public:
//! Takes care of timing between diff statements
int addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false);
//! An Auxiliary variable for a unary op
int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false);
int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id = -1, int orig_lag = 0) noexcept(false);
//! Returns the number of auxiliary variables
int
AuxVarsSize() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment