diff --git a/src/ExprNode.cc b/src/ExprNode.cc index c627c9c58d126278efae379307689a198c94ccb8..2df5e3bff725570b38d849d32e0e6e7b03844d24 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -3486,6 +3486,73 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod if (it == nodes.end()) return buildSimilarUnaryOpNode(argsubst, datatree); + string unary_op = ""; + switch (op_code) + { + case UnaryOpcode::exp: + unary_op = "exp"; + break; + case UnaryOpcode::log: + unary_op = "log"; + break; + case UnaryOpcode::log10: + unary_op = "log10"; + break; + case UnaryOpcode::cos: + unary_op = "cos"; + break; + case UnaryOpcode::sin: + unary_op = "sin"; + break; + case UnaryOpcode::tan: + unary_op = "tan"; + break; + case UnaryOpcode::acos: + unary_op = "acos"; + break; + case UnaryOpcode::asin: + unary_op = "asin"; + break; + case UnaryOpcode::atan: + unary_op = "atan"; + break; + case UnaryOpcode::cosh: + unary_op = "cosh"; + break; + case UnaryOpcode::sinh: + unary_op = "sinh"; + break; + case UnaryOpcode::tanh: + unary_op = "tanh"; + break; + case UnaryOpcode::acosh: + unary_op = "acosh"; + break; + case UnaryOpcode::asinh: + unary_op = "asinh"; + break; + case UnaryOpcode::atanh: + unary_op = "atanh"; + break; + case UnaryOpcode::sqrt: + unary_op = "sqrt"; + break; + case UnaryOpcode::abs: + unary_op = "abs"; + break; + case UnaryOpcode::sign: + unary_op = "sign"; + break; + case UnaryOpcode::erf: + unary_op = "erf"; + break; + default: + { + cerr << "UnaryOpNode::substituteUnaryOpNodes: Shouldn't arrive here" << endl; + exit(EXIT_FAILURE); + } + } + int base_aux_lag = 0; VariableNode *aux_var = nullptr; for (auto rit = it->second.rbegin(); rit != it->second.rend(); rit++) @@ -3494,10 +3561,10 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod int symb_id; auto *vn = dynamic_cast<VariableNode *>(argsubst); if (vn == nullptr) - symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second)); + symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second), unary_op); else - symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second), - vn->get_symb_id(), vn->get_lag()); + symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second), unary_op, + vn->get_symb_id(), vn->get_lag()); aux_var = datatree.AddVariable(symb_id, 0); neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var, dynamic_cast<UnaryOpNode *>(rit->second)))); diff --git a/src/SymbolTable.cc b/src/SymbolTable.cc index 46d7ef6396f354b676778cd58ed27948db98ecc4..d5e08580e7c17cdde8fd3a401728c77dcb5afa90 100644 --- a/src/SymbolTable.cc +++ b/src/SymbolTable.cc @@ -27,14 +27,15 @@ AuxVarInfo::AuxVarInfo(int symb_id_arg, AuxVarType type_arg, int orig_symb_id_arg, int orig_lead_lag_arg, int equation_number_for_multiplier_arg, int information_set_arg, - expr_t expr_node_arg) : + expr_t expr_node_arg, string unary_op_arg) : symb_id{symb_id_arg}, type{type_arg}, orig_symb_id{orig_symb_id_arg}, orig_lead_lag{orig_lead_lag_arg}, equation_number_for_multiplier{equation_number_for_multiplier_arg}, information_set{information_set_arg}, - expr_node{expr_node_arg} + expr_node{expr_node_arg}, + unary_op{unary_op_arg} { } @@ -360,6 +361,7 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false) if (aux_vars[i].get_orig_symb_id() >= 0) output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl << "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; + output << "M_.aux_vars(" << i+1 << ").unary_op = '" << aux_vars[i].get_unary_op() << "';" << endl; break; case AuxVarType::multiplier: output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].get_equation_number_for_multiplier() + 1 << ";" << endl; @@ -489,6 +491,7 @@ SymbolTable::writeCOutput(ostream &output) const noexcept(false) if (aux_vars[i].get_orig_symb_id() >= 0) output << "av[" << i << "].orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl << "av[" << i << "].orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; + output << "av[" << i << "].unary_op = \"" << aux_vars[i].get_unary_op() << "\";" << endl; break; case AuxVarType::diff: case AuxVarType::diffLag: @@ -593,6 +596,7 @@ SymbolTable::writeCCOutput(ostream &output) const noexcept(false) if (aux_vars[i].get_orig_symb_id() >= 0) output << "av" << i << ".orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id()) << ";" << endl << "av" << i << ".orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl; + output << "av" << i << ".unary_op = \"" << aux_vars[i].get_unary_op() << "\";" << endl; break; case AuxVarType::diff: case AuxVarType::diffLag: @@ -634,7 +638,7 @@ SymbolTable::addLeadAuxiliaryVarInternal(bool endo, int index, expr_t expr_arg) exit(EXIT_FAILURE); } - aux_vars.emplace_back(symb_id, (endo ? AuxVarType::endoLead : AuxVarType::exoLead), 0, 0, 0, 0, expr_arg); + aux_vars.emplace_back(symb_id, (endo ? AuxVarType::endoLead : AuxVarType::exoLead), 0, 0, 0, 0, expr_arg, ""); return symb_id; } @@ -660,7 +664,7 @@ SymbolTable::addLagAuxiliaryVarInternal(bool endo, int orig_symb_id, int orig_le exit(EXIT_FAILURE); } - aux_vars.emplace_back(symb_id, (endo ? AuxVarType::endoLag : AuxVarType::exoLag), orig_symb_id, orig_lead_lag, 0, 0, expr_arg); + aux_vars.emplace_back(symb_id, (endo ? AuxVarType::endoLag : AuxVarType::exoLag), orig_symb_id, orig_lead_lag, 0, 0, expr_arg, ""); return symb_id; } @@ -708,7 +712,7 @@ SymbolTable::addExpectationAuxiliaryVar(int information_set, int index, expr_t e exit(EXIT_FAILURE); } - aux_vars.emplace_back(symb_id, AuxVarType::expectation, 0, 0, 0, information_set, expr_arg); + aux_vars.emplace_back(symb_id, AuxVarType::expectation, 0, 0, 0, information_set, expr_arg, ""); return symb_id; } @@ -731,7 +735,7 @@ SymbolTable::addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id exit(EXIT_FAILURE); } - aux_vars.emplace_back(symb_id, AuxVarType::diffLag, orig_symb_id, orig_lag, 0, 0, expr_arg); + aux_vars.emplace_back(symb_id, AuxVarType::diffLag, orig_symb_id, orig_lag, 0, 0, expr_arg, ""); return symb_id; } @@ -754,7 +758,7 @@ SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, i exit(EXIT_FAILURE); } - aux_vars.emplace_back(symb_id, AuxVarType::diff, orig_symb_id, orig_lag, 0, 0, expr_arg); + aux_vars.emplace_back(symb_id, AuxVarType::diff, orig_symb_id, orig_lag, 0, 0, expr_arg, ""); return symb_id; } @@ -766,7 +770,7 @@ SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg) noexcept(false) } int -SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false) +SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, int orig_symb_id, int orig_lag) noexcept(false) { ostringstream varname; int symb_id; @@ -782,7 +786,7 @@ SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id exit(EXIT_FAILURE); } - aux_vars.emplace_back(symb_id, AuxVarType::unaryOp, orig_symb_id, orig_lag, 0, 0, expr_arg); + aux_vars.emplace_back(symb_id, AuxVarType::unaryOp, orig_symb_id, orig_lag, 0, 0, expr_arg, unary_op); return symb_id; } @@ -804,7 +808,7 @@ SymbolTable::addVarModelEndoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag, exit(EXIT_FAILURE); } - aux_vars.emplace_back(symb_id, AuxVarType::varModel, orig_symb_id, orig_lead_lag, 0, 0, expr_arg); + aux_vars.emplace_back(symb_id, AuxVarType::varModel, orig_symb_id, orig_lead_lag, 0, 0, expr_arg, ""); return symb_id; } @@ -826,7 +830,7 @@ SymbolTable::addMultiplierAuxiliaryVar(int index) noexcept(false) exit(EXIT_FAILURE); } - aux_vars.emplace_back(symb_id, AuxVarType::multiplier, 0, 0, index, 0, nullptr); + aux_vars.emplace_back(symb_id, AuxVarType::multiplier, 0, 0, index, 0, nullptr, ""); return symb_id; } @@ -847,7 +851,7 @@ SymbolTable::addDiffForwardAuxiliaryVar(int orig_symb_id, expr_t expr_arg) noexc exit(EXIT_FAILURE); } - aux_vars.emplace_back(symb_id, AuxVarType::diffForward, orig_symb_id, 0, 0, 0, expr_arg); + aux_vars.emplace_back(symb_id, AuxVarType::diffForward, orig_symb_id, 0, 0, 0, expr_arg, ""); return symb_id; } @@ -1137,27 +1141,28 @@ SymbolTable::writeJuliaOutput(ostream &output) const noexcept(false) case AuxVarType::exoLag: case AuxVarType::varModel: output << getTypeSpecificID(aux_var.get_orig_symb_id()) + 1 << ", " - << aux_var.get_orig_lead_lag() << ", typemin(Int), string()"; + << aux_var.get_orig_lead_lag() << ", typemin(Int), string(), string()"; break; case AuxVarType::unaryOp: if (aux_var.get_orig_symb_id() >= 0) output << getTypeSpecificID(aux_var.get_orig_symb_id()) + 1 << ", " << aux_var.get_orig_lead_lag(); else output << "typemin(Int), typemin(Int)"; - output << ", typemin(Int), string()"; + output << ", typemin(Int), string(), " + << "\"" << aux_var.get_unary_op() << "\"" << endl; break; case AuxVarType::diff: case AuxVarType::diffLag: if (aux_var.get_orig_symb_id() >= 0) output << getTypeSpecificID(aux_var.get_orig_symb_id()) + 1 << ", " - << aux_var.get_orig_lead_lag() << ", typemin(Int), string()"; + << aux_var.get_orig_lead_lag() << ", typemin(Int), string(), string()"; break; case AuxVarType::multiplier: output << "typemin(Int), typemin(Int), " << aux_var.get_equation_number_for_multiplier() + 1 - << ", string()"; + << ", string(), string()"; break; case AuxVarType::diffForward: - output << getTypeSpecificID(aux_var.get_orig_symb_id())+1 << ", typemin(Int), typemin(Int), string()"; + output << getTypeSpecificID(aux_var.get_orig_symb_id())+1 << ", typemin(Int), typemin(Int), string(), string()"; break; case AuxVarType::expectation: output << "typemin(Int), typemin(Int), typemin(Int), \"\\mathbb{E}_{t" @@ -1167,7 +1172,7 @@ SymbolTable::writeJuliaOutput(ostream &output) const noexcept(false) output << ")\""; break; default: - output << " typemin(Int), typemin(Int), typemin(Int), string()"; + output << " typemin(Int), typemin(Int), typemin(Int), string(), string()"; } output << ")" << endl; } diff --git a/src/SymbolTable.hh b/src/SymbolTable.hh index 19488e68453e221bb106a37e2c5e37b20257be37..49027f48a1124df16e0bd879a1188f53c7eab9db 100644 --- a/src/SymbolTable.hh +++ b/src/SymbolTable.hh @@ -61,8 +61,9 @@ private: int equation_number_for_multiplier; //!< Stores the original constraint equation number associated with this aux var. Only used for avMultiplier. int information_set; //! Argument of expectation operator. Only used for avExpectation. expr_t expr_node; //! Auxiliary variable definition + string unary_op; //! Used with AuxUnaryOp public: - AuxVarInfo(int symb_id_arg, AuxVarType type_arg, int orig_symb_id, int orig_lead_lag, int equation_number_for_multiplier_arg, int information_set_arg, expr_t expr_node_arg); + AuxVarInfo(int symb_id_arg, AuxVarType type_arg, int orig_symb_id, int orig_lead_lag, int equation_number_for_multiplier_arg, int information_set_arg, expr_t expr_node_arg, string unary_op_arg); int get_symb_id() const { @@ -103,6 +104,11 @@ public: { return expr_node; }; + string + get_unary_op() const + { + return unary_op; + }; }; //! Stores the symbol table @@ -304,7 +310,7 @@ public: //! Takes care of timing between diff statements int addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false); //! An Auxiliary variable for a unary op - int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id = -1, int orig_lag = 0) noexcept(false); + int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, int orig_symb_id = -1, int orig_lag = 0) noexcept(false); //! Returns the number of auxiliary variables int AuxVarsSize() const