From 2cd3aa95cc954eb8c208963b397313d297f9f515 Mon Sep 17 00:00:00 2001 From: Houtan Bastani <houtan@dynare.org> Date: Tue, 5 Jun 2018 16:38:37 +0200 Subject: [PATCH] When `transform_unary_ops` is passed, only substitute unary operators that appear in VAR equations --- src/DynamicModel.cc | 24 ++++++++++++++++++++---- src/DynamicModel.hh | 2 +- src/ExprNode.cc | 36 ++++++++++++++++++++---------------- src/ModFile.cc | 20 +++++++++++++++++--- 4 files changed, 58 insertions(+), 24 deletions(-) diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index 117c48a9..535e7a08 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -5378,15 +5378,31 @@ DynamicModel::substituteAdl() } void -DynamicModel::substituteUnaryOps(StaticModel &static_model) +DynamicModel::substituteUnaryOps(StaticModel &static_model, set<string> &var_model_eqtags) { diff_table_t nodes; + vector<int> eqnumber; + for (auto & eqtag : var_model_eqtags) + for (const auto & equation_tag : equation_tags) + if (equation_tag.second.first == "name" + && equation_tag.second.second == eqtag) + { + eqnumber.push_back(equation_tag.first); + break; + } + // Find matching unary ops that may be outside of diffs (i.e., those with different lags) + set<int> used_local_vars; + for (int eqnn : eqnumber) + equations[eqnn]->collectVariables(eModelLocalVariable, used_local_vars); + + // Only substitute unary ops in model local variables that appear in VAR equations for (auto & it : local_variables_table) - it.second->findUnaryOpNodesForAuxVarCreation(static_model, nodes); + if (used_local_vars.find(it.first) != used_local_vars.end()) + it.second->findUnaryOpNodesForAuxVarCreation(static_model, nodes); - for (auto & equation : equations) - equation->findUnaryOpNodesForAuxVarCreation(static_model, nodes); + for (int eqnn : eqnumber) + equations[eqnn]->findUnaryOpNodesForAuxVarCreation(static_model, nodes); // Substitute in model local variables ExprNode::subst_table_t subst_table; diff --git a/src/DynamicModel.hh b/src/DynamicModel.hh index 6ae70ab4..36141a2e 100644 --- a/src/DynamicModel.hh +++ b/src/DynamicModel.hh @@ -420,7 +420,7 @@ public: void substituteAdl(); //! Creates aux vars for certain unary operators: originally implemented for support of VARs - void substituteUnaryOps(StaticModel &static_model); + void substituteUnaryOps(StaticModel &static_model, set<string> &eq_tags); //! Substitutes diff operator void substituteDiff(StaticModel &static_model, ExprNode::subst_table_t &diff_subst_table); diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 191c3a59..9ab518f1 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -3204,22 +3204,26 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod VariableNode *aux_var = nullptr; for (auto rit = it->second.rbegin(); rit != it->second.rend(); rit++) - { - if (rit == it->second.rbegin()) - { - auto *vn = dynamic_cast<VariableNode *>(const_cast<UnaryOpNode *>(this)->get_arg()); - int symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, const_cast<UnaryOpNode *>(this), - vn->get_symb_id(), vn->get_lag()); - aux_var = datatree.AddVariable(symb_id, 0); - neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var, dynamic_cast<UnaryOpNode *>(rit->second)))); - subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var); - } - else - { - auto *vn = dynamic_cast<VariableNode *>(dynamic_cast<UnaryOpNode *>(rit->second)->get_arg()); - subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(-vn->get_lag())); - } - } + if (rit == it->second.rbegin()) + { + auto *vn = dynamic_cast<VariableNode *>(const_cast<UnaryOpNode *>(this)->get_arg()); + if (vn == nullptr) + { + cerr << "ERROR: You can only use a unary op on a variable node or another unary op node within a VAR." << endl; + exit(EXIT_FAILURE); + } + int symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, const_cast<UnaryOpNode *>(this), + vn->get_symb_id(), vn->get_lag()); + aux_var = datatree.AddVariable(symb_id, 0); + neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var, + dynamic_cast<UnaryOpNode *>(rit->second)))); + subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var); + } + else + { + auto *vn = dynamic_cast<VariableNode *>(dynamic_cast<UnaryOpNode *>(rit->second)->get_arg()); + subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(-vn->get_lag())); + } sit = subst_table.find(this); return const_cast<VariableNode *>(sit->second); diff --git a/src/ModFile.cc b/src/ModFile.cc index 2f85f50d..323b80b1 100644 --- a/src/ModFile.cc +++ b/src/ModFile.cc @@ -364,16 +364,30 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const } } + string var_model_name; + set<string> eqtags; + map<string, vector<string>> var_model_eq_tags; + map<string, pair<SymbolList, int>> var_model_info_var_expectation; + for (auto it = statements.begin(); it != statements.end(); it++) + { + auto *vms = dynamic_cast<VarModelStatement *>(*it); + if (vms != nullptr) + { + vms->getVarModelInfo(var_model_name, var_model_info_var_expectation, var_model_eq_tags); + for (auto & eqtag : var_model_eq_tags[var_model_name]) + eqtags.insert(eqtag); + } + } + if (transform_unary_ops) - dynamic_model.substituteUnaryOps(diff_static_model); + // substitute only those unary ops that appear in VAR equations + dynamic_model.substituteUnaryOps(diff_static_model, eqtags); // Create auxiliary variable and equations for Diff operator ExprNode::subst_table_t diff_subst_table; dynamic_model.substituteDiff(diff_static_model, diff_subst_table); // Var Model - map<string, pair<SymbolList, int>> var_model_info_var_expectation; - map<string, vector<string>> var_model_eq_tags; map<string, tuple<vector<int>, vector<expr_t>, vector<bool>, vector<int>, int, vector<bool>, vector<int>>> var_model_info_pac_expectation; for (auto it = statements.begin(); it != statements.end(); it++) -- GitLab