From 84c2dc5f3621d47407397d519b623eb9d1fc1eb3 Mon Sep 17 00:00:00 2001 From: Houtan Bastani <houtan@dynare.org> Date: Thu, 7 Jun 2018 12:53:00 +0200 Subject: [PATCH] transform_unary_ops now introduces aux variables/equations for all unary ops specified by UnaryOpNode::createAuxVarForUnaryOpNode() In the absence of this option, if a var_model statement(s) is present, then aux vars/eqs are created for the same types of unary operators but only for equations specified in the var_model statement In the absence of both this option and var_model statements, no unary op auxiliary variables are created diffs continue to be substituted everywhere; for the moment auxiliary variables are created for diffs of expressions. A forthcoming change will allow auxiliary variables created for diffs of expressions to be linked with their lagged expressions as is currently the case for diffs of variables --- src/DynamicModel.cc | 47 +++++++++++++++++------------ src/DynamicModel.hh | 8 ++++- src/ExprNode.cc | 72 +++++++++++++++++++++------------------------ src/ExprNode.hh | 2 +- src/ModFile.cc | 4 ++- src/SymbolTable.cc | 26 +++++++++++++--- src/SymbolTable.hh | 2 +- 7 files changed, 96 insertions(+), 65 deletions(-) diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index aa703ff5..cb62cea2 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -25,6 +25,7 @@ #include <cerrno> #include <algorithm> #include <iterator> +#include <numeric> #include "DynamicModel.hh" // For mkdir() and chdir() @@ -5403,26 +5404,40 @@ DynamicModel::findPacExpectationEquationNumbers(vector<int> &eqnumbers) const } } +void +DynamicModel::substituteUnaryOps(StaticModel &static_model) +{ + vector<int> eqnumbers(equations.size()); + iota(eqnumbers.begin(), eqnumbers.end(), 0); + substituteUnaryOps(static_model, eqnumbers); +} + void DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_model_eqtags) +{ + vector<int> eqnumbers; + getEquationNumbersFromTags(eqnumbers, var_model_eqtags); + findPacExpectationEquationNumbers(eqnumbers); + substituteUnaryOps(static_model, eqnumbers); +} + +void +DynamicModel::substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbers) { diff_table_t nodes; - vector<int> eqnumber; - getEquationNumbersFromTags(eqnumber, var_model_eqtags); - findPacExpectationEquationNumbers(eqnumber); // Find matching unary ops that may be outside of diffs (i.e., those with different lags) set<int> used_local_vars; - for (int eqnn : eqnumber) - equations[eqnn]->collectVariables(eModelLocalVariable, used_local_vars); + for (int eqnumber : eqnumbers) + equations[eqnumber]->collectVariables(eModelLocalVariable, used_local_vars); // Only substitute unary ops in model local variables that appear in VAR equations for (auto & it : local_variables_table) if (used_local_vars.find(it.first) != used_local_vars.end()) it.second->findUnaryOpNodesForAuxVarCreation(static_model, nodes); - for (int eqnn : eqnumber) - equations[eqnn]->findUnaryOpNodesForAuxVarCreation(static_model, nodes); + for (int eqnumber : eqnumbers) + equations[eqnumber]->findUnaryOpNodesForAuxVarCreation(static_model, nodes); // Substitute in model local variables ExprNode::subst_table_t subst_table; @@ -5434,7 +5449,7 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_mod for (auto & equation : equations) { auto *substeq = dynamic_cast<BinaryOpNode *>(equation-> - substituteUnaryOpNodes(static_model, nodes, subst_table, neweqs)); + substituteUnaryOpNodes(static_model, nodes, subst_table, neweqs)); assert(substeq != nullptr); equation = substeq; } @@ -5450,15 +5465,11 @@ DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_mod } void -DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table, set<string> &var_model_eqtags) +DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table) { - vector<int> eqnumbers; - getEquationNumbersFromTags(eqnumbers, var_model_eqtags); - findPacExpectationEquationNumbers(eqnumbers); - set<int> used_local_vars; - for (int eqnumber : eqnumbers) - equations[eqnumber]->collectVariables(eModelLocalVariable, used_local_vars); + for (const auto & equation : equations) + equation->collectVariables(eModelLocalVariable, used_local_vars); // Only substitute diffs in model local variables that appear in VAR equations diff_table_t diff_table; @@ -5466,8 +5477,8 @@ DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t if (used_local_vars.find(it.first) != used_local_vars.end()) it.second->findDiffNodes(static_model, diff_table); - for (int eqnumber : eqnumbers) - equations[eqnumber]->findDiffNodes(static_model, diff_table); + for (const auto & equation : equations) + equation->findDiffNodes(static_model, diff_table); // Substitute in model local variables vector<BinaryOpNode *> neweqs; @@ -5478,7 +5489,7 @@ DynamicModel::substituteDiff(StaticModel &static_model, ExprNode::subst_table_t for (auto & equation : equations) { auto *substeq = dynamic_cast<BinaryOpNode *>(equation-> - substituteDiff(static_model, diff_table, diff_subst_table, neweqs)); + substituteDiff(static_model, diff_table, diff_subst_table, neweqs)); assert(substeq != nullptr); equation = substeq; } diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh index 578a9ad5..8eb69ecc 100644 --- a/src/DynamicModel.hh +++ b/src/DynamicModel.hh @@ -423,11 +423,17 @@ public: //! Substitutes adl operator void substituteAdl(); + //! Creates aux vars for all unary operators + void substituteUnaryOps(StaticModel &static_model); + //! Creates aux vars for certain unary operators: originally implemented for support of VARs void substituteUnaryOps(StaticModel &static_model, set<string> &eq_tags); + //! Creates aux vars for certain unary operators: originally implemented for support of VARs + void substituteUnaryOps(StaticModel &static_model, vector<int> &eqnumbers); + //! Substitutes diff operator - void substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table, set<string> &var_model_eqtags); + void substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table); //! Table to undiff LHS variables for pac vector z void getUndiffLHSForPac(vector<int> &lhs, vector<expr_t> &lhs_expr_t, vector<bool> &diff, vector<int> &orig_diff_var, diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 8619492f..422e6c1e 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -3045,7 +3045,7 @@ UnaryOpNode::countDiffs() const } bool -UnaryOpNode::createAuxVarForUnaryOpNodeInDiffOp() const +UnaryOpNode::createAuxVarForUnaryOpNode() const { switch (op_code) { @@ -3077,14 +3077,14 @@ UnaryOpNode::createAuxVarForUnaryOpNodeInDiffOp() const void UnaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const { - if (!this->createAuxVarForUnaryOpNodeInDiffOp()) - { - arg->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes); - return; - } + arg->findUnaryOpNodesForAuxVarCreation(static_datatree, nodes); + + if (!this->createAuxVarForUnaryOpNode()) + return; expr_t sthis = this->toStatic(static_datatree); int arg_max_lag = -arg->maxLag(); + // TODO: implement recursive expression comparison, ensuring that the difference in the lags is constant across nodes auto it = nodes.find(sthis); if (it != nodes.end()) { @@ -3101,13 +3101,14 @@ UnaryOpNode::findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_t void UnaryOpNode::findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const { + arg->findDiffNodes(static_datatree, diff_table); + if (op_code != oDiff) return; - arg->findDiffNodes(static_datatree, diff_table); - expr_t sthis = this->toStatic(static_datatree); int arg_max_lag = -arg->maxLag(); + // TODO: implement recursive expression comparison, ensuring that the difference in the lags is constant across nodes auto it = diff_table.find(sthis); if (it != diff_table.end()) { @@ -3125,11 +3126,9 @@ expr_t UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const { + expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs); if (op_code != oDiff) - { - expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs); - return buildSimilarUnaryOpNode(argsubst, datatree); - } + return buildSimilarUnaryOpNode(argsubst, datatree); subst_table_t::const_iterator sit = subst_table.find(this); if (sit != subst_table.end()) @@ -3137,13 +3136,19 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, expr_t sthis = dynamic_cast<UnaryOpNode *>(this->toStatic(static_datatree)); auto it = diff_table.find(sthis); + int symb_id; if (it == diff_table.end() || it->second[-arg->maxLag()] != this) { // diff does not appear in VAR equations - // so simply substitute diff(x) with x-x(-1) - expr_t argsubst = arg->substituteDiff(static_datatree, diff_table, subst_table, neweqs); - return dynamic_cast<BinaryOpNode *>(datatree.AddMinus(argsubst, - argsubst->decreaseLeadsLags(1))); + // so simply create aux var and return + // Once the comparison of expression nodes works, come back and remove this part, folding into the next loop. + symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst); + VariableNode *aux_var = datatree.AddVariable(symb_id, 0); + neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var, + datatree.AddMinus(argsubst, + argsubst->decreaseLeadsLags(1))))); + subst_table[this] = dynamic_cast<VariableNode *>(aux_var); + return const_cast<VariableNode *>(subst_table[this]); } int last_arg_max_lag = 0; @@ -3153,19 +3158,13 @@ UnaryOpNode::substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, { expr_t argsubst = dynamic_cast<UnaryOpNode *>(rit->second)-> get_arg()->substituteDiff(static_datatree, diff_table, subst_table, neweqs); - int symb_id; auto *vn = dynamic_cast<VariableNode *>(argsubst); if (rit == it->second.rbegin()) { if (vn != nullptr) symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst, vn->get_symb_id(), vn->get_lag()); else - { - // We know that the supported unary ops have already been substituted - cerr << "ERROR: You can only use the `diff` operator on variables and certain unary ops." << endl - << " Try passing the `transform_unary_ops` option on the dynare command line." << endl; - exit(EXIT_FAILURE); - } + symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, argsubst); // make originating aux var & equation last_arg_max_lag = rit->first; @@ -3210,35 +3209,30 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod auto *sthis = dynamic_cast<UnaryOpNode *>(this->toStatic(static_datatree)); auto it = nodes.find(sthis); + expr_t argsubst = arg->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs); if (it == nodes.end()) - { - expr_t argsubst = arg->substituteUnaryOpNodes(static_datatree, nodes, subst_table, neweqs); - return buildSimilarUnaryOpNode(argsubst, datatree); - } + return buildSimilarUnaryOpNode(argsubst, datatree); + int base_aux_lag; VariableNode *aux_var = nullptr; - for (auto rit = it->second.rbegin(); - rit != it->second.rend(); rit++) + for (auto rit = it->second.rbegin(); rit != it->second.rend(); rit++) if (rit == it->second.rbegin()) { - auto *vn = dynamic_cast<VariableNode *>(const_cast<UnaryOpNode *>(this)->get_arg()); + int symb_id; + auto *vn = dynamic_cast<VariableNode *>(argsubst); if (vn == nullptr) - { - cerr << "ERROR: You can only use a unary op on a variable node or another unary op node within a VAR." << endl; - exit(EXIT_FAILURE); - } - int symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, const_cast<UnaryOpNode *>(this), + symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second)); + else + symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second), vn->get_symb_id(), vn->get_lag()); aux_var = datatree.AddVariable(symb_id, 0); neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var, dynamic_cast<UnaryOpNode *>(rit->second)))); subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var); + base_aux_lag = rit->first; } else - { - auto *vn = dynamic_cast<VariableNode *>(dynamic_cast<UnaryOpNode *>(rit->second)->get_arg()); - subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(-vn->get_lag())); - } + subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(base_aux_lag - rit->first)); sit = subst_table.find(this); return const_cast<VariableNode *>(sit->second); diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 7b2b6280..13071eb0 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -811,7 +811,7 @@ public: expr_t substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const override; expr_t substituteAdl() const override; void findDiffNodes(DataTree &static_datatree, diff_table_t &diff_table) const override; - bool createAuxVarForUnaryOpNodeInDiffOp() const; + bool createAuxVarForUnaryOpNode() const; void findUnaryOpNodesForAuxVarCreation(DataTree &static_datatree, diff_table_t &nodes) const override; expr_t substituteDiff(DataTree &static_datatree, diff_table_t &diff_table, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; expr_t substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const override; diff --git a/src/ModFile.cc b/src/ModFile.cc index 3332cf47..1e237dc3 100644 --- a/src/ModFile.cc +++ b/src/ModFile.cc @@ -381,12 +381,14 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const } if (transform_unary_ops) + dynamic_model.substituteUnaryOps(diff_static_model); + else // substitute only those unary ops that appear in VAR equations dynamic_model.substituteUnaryOps(diff_static_model, eqtags); // Create auxiliary variable and equations for Diff operators that appear in VAR equations ExprNode::subst_table_t diff_subst_table; - dynamic_model.substituteDiff(diff_static_model, diff_subst_table, eqtags); + dynamic_model.substituteDiff(diff_static_model, diff_subst_table); // Var Model map<string, tuple<vector<int>, vector<expr_t>, vector<bool>, vector<int>, int, vector<bool>, vector<int>>> diff --git a/src/SymbolTable.cc b/src/SymbolTable.cc index f9f21917..3d56c1c2 100644 --- a/src/SymbolTable.cc +++ b/src/SymbolTable.cc @@ -354,10 +354,14 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false) case avEndoLag: case avExoLag: case avVarModel: - case avUnaryOp: output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; break; + case avUnaryOp: + if (aux_vars[i].get_orig_symb_id() >= 0) + output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl + << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; + break; case avMultiplier: output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].get_equation_number_for_multiplier() + 1 << ";" << endl; break; @@ -479,10 +483,14 @@ SymbolTable::writeCOutput(ostream &output) const noexcept(false) case avEndoLag: case avExoLag: case avVarModel: - case avUnaryOp: output << "av[" << i << "].orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl << "av[" << i << "].orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; break; + case avUnaryOp: + if (aux_vars[i].get_orig_symb_id() >= 0) + output << "av[" << i << "].orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl + << "av[" << i << "].orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; + break; case avDiff: case avDiffLag: if (aux_vars[i].get_orig_symb_id() >= 0) @@ -579,10 +587,14 @@ SymbolTable::writeCCOutput(ostream &output) const noexcept(false) case avEndoLag: case avExoLag: case avVarModel: - case avUnaryOp: output << "av" << i << ".orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl << "av" << i << ".orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; break; + case avUnaryOp: + if (aux_vars[i].get_orig_symb_id() >= 0) + output << "av" << i << ".orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl + << "av" << i << ".orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; + break; case avDiff: case avDiffLag: if (aux_vars[i].get_orig_symb_id() >= 0) @@ -1098,10 +1110,16 @@ SymbolTable::writeJuliaOutput(ostream &output) const noexcept(false) case avEndoLag: case avExoLag: case avVarModel: - case avUnaryOp: output << getTypeSpecificID(aux_var.get_orig_symb_id()) + 1 << ", " << aux_var.get_orig_lead_lag() << ", typemin(Int), string()"; break; + case avUnaryOp: + if (aux_var.get_orig_symb_id() >= 0) + output << getTypeSpecificID(aux_var.get_orig_symb_id()) + 1 << ", " << aux_var.get_orig_lead_lag(); + else + output << "typemin(Int), typemin(Int)"; + output << ", typemin(Int), string()"; + break; case avDiff: case avDiffLag: if (aux_var.get_orig_symb_id() >= 0) diff --git a/src/SymbolTable.hh b/src/SymbolTable.hh index 06fe9502..0ccdc1c1 100644 --- a/src/SymbolTable.hh +++ b/src/SymbolTable.hh @@ -295,7 +295,7 @@ public: //! Takes care of timing between diff statements int addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false); //! An Auxiliary variable for a unary op - int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false); + int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id = -1, int orig_lag = 0) noexcept(false); //! Returns the number of auxiliary variables int AuxVarsSize() const -- GitLab