diff --git a/src/ComputingTasks.cc b/src/ComputingTasks.cc index b9157bf1d3ed4af46d45447a73bc9d926352a5b1..f3d34ae9b1e3bf455423829a3564cc08f63a8334 100644 --- a/src/ComputingTasks.cc +++ b/src/ComputingTasks.cc @@ -296,7 +296,7 @@ PacModelStatement::overwriteGrowth(expr_t new_growth) try { - growth_info = growth->matchLinearCombinationOfVariables(); + growth_info = growth->matchLinearCombinationOfVariables(false); } catch (ExprNode::MatchFailureException &e) { @@ -325,67 +325,56 @@ PacModelStatement::writeOutput(ostream &output, const string &basename, bool min if (growth) { - size_t nlc = growth_info.size(); - output << "M_.pac." << name << ".growth_index = repmat(-1, " << nlc << ", 1);" << endl - << "M_.pac." << name << ".growth_lag = zeros(" << nlc << ", 1);" << endl - << "M_.pac." << name << ".growth_param_id = repmat(-1, " << nlc << ", 1);" << endl - << "M_.pac." << name << ".growth_constant = zeros(" << nlc << ", 1);" << endl - << "M_.pac." << name << ".growth_type = repmat({''}, " << nlc << ", 1);" << endl - << "M_.pac." << name << ".growth_part_str = repmat({''}, " << nlc << ", 1);" << endl; + output << "M_.pac." << name << ".growth_str = '"; + original_growth->writeJsonOutput(output, {}, {}, true); + output << "';" << endl; int i = 0; for (auto [growth_symb_id, growth_lag, param_id, constant] : growth_info) { - i++; - string growth_type; - switch (symbol_table.getType(growth_symb_id)) - { - case SymbolType::endogenous: - growth_type = "endogenous"; - break; - case SymbolType::exogenous: - growth_type = "exogenous"; - break; - case SymbolType::parameter: - growth_type = "parameter"; - break; - default: - { - } - } - - try - { - // case when this is not the highest lag of the growth variable - int aux_symb_id = symbol_table.searchAuxiliaryVars(growth_symb_id, growth_lag); - output << "M_.pac." << name << ".growth_index(" << i << ") = " << symbol_table.getTypeSpecificID(aux_symb_id) + 1 << ";" << endl - << "M_.pac." << name << ".growth_lag(" << i << ") = 0;" << endl; - } - catch (...) + string structname = "M_.pac." + name + ".growth_linear_comb(" + to_string(++i) + ")."; + if (growth_symb_id >= 0) { + string var_field = "endo_id"; + if (symbol_table.getType(growth_symb_id) == SymbolType::exogenous) + { + var_field = "exo_id"; + output << structname << "endo_id = 0;" << endl; + } + else + output << structname << "exo_id = 0;" << endl; try { - // case when this is the highest lag of the growth variable - int tmp_growth_lag = growth_lag + 1; - int aux_symb_id = symbol_table.searchAuxiliaryVars(growth_symb_id, tmp_growth_lag); - output << "M_.pac." << name << ".growth_index(" << i << ") = " << symbol_table.getTypeSpecificID(aux_symb_id) + 1 << ";" << endl - << "M_.pac." << name << ".growth_lag(" << i << ") = -1;" << endl; + // case when this is not the highest lag of the growth variable + int aux_symb_id = symbol_table.searchAuxiliaryVars(growth_symb_id, growth_lag); + output << structname << var_field << " = " << symbol_table.getTypeSpecificID(aux_symb_id) + 1 << ";" << endl + << structname << "lag = 0;" << endl; } catch (...) { - // case when there is no aux var for the variable - output << "M_.pac." << name << ".growth_index(" << i << ") = " << symbol_table.getTypeSpecificID(growth_symb_id) + 1 << ";" << endl - << "M_.pac." << name << ".growth_lag(" << i << ") = " << growth_lag << ";" << endl; + try + { + // case when this is the highest lag of the growth variable + int tmp_growth_lag = growth_lag + 1; + int aux_symb_id = symbol_table.searchAuxiliaryVars(growth_symb_id, tmp_growth_lag); + output << structname << var_field << " = " << symbol_table.getTypeSpecificID(aux_symb_id) + 1 << ";" << endl + << structname << "lag = -1;" << endl; + } + catch (...) + { + // case when there is no aux var for the variable + output << structname << var_field << " = "<< symbol_table.getTypeSpecificID(growth_symb_id) + 1 << ";" << endl + << structname << "lag = " << growth_lag << ";" << endl; + } } } - - output << "M_.pac." << name << ".growth_param_id(" << i << ") = " - << (param_id == -1 ? -1 : symbol_table.getTypeSpecificID(param_id)) + 1 << ";" << endl - << "M_.pac." << name << ".growth_constant(" << i << ") = " << constant << ";" << endl - << "M_.pac." << name << ".growth_type{" << i << "} = '" << growth_type << "';" << endl; + else + output << structname << "endo_id = 0;" << endl + << structname << "exo_id = 0;" << endl + << structname << "lag = 0;" << endl; + output << structname << "param_id = " + << (param_id == -1 ? 0 : symbol_table.getTypeSpecificID(param_id) + 1) << ";" << endl + << structname << "constant = " << constant << ";" << endl; } - output << "M_.pac." << name << ".growth_str = '"; - original_growth->writeJsonOutput(output, {}, {}, true); - output << "';" << endl; } } diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 7e0a814f7930a1f90c24edaa166018b13dd36e37..8969d448e6dac67966d23b58ad1cec3558eea131 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -9537,12 +9537,12 @@ BinaryOpNode::decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int curre } tuple<int, int, int, double> -ExprNode::matchVariableTimesConstantTimesParam() const +ExprNode::matchVariableTimesConstantTimesParam(bool variable_obligatory) const { int variable_id = -1, lag = 0, param_id = -1; double constant = 1.0; matchVTCTPHelper(variable_id, lag, param_id, constant, false); - if (variable_id == -1) + if (variable_obligatory && variable_id == -1) throw MatchFailureException{"No variable in this expression"}; return {variable_id, lag, param_id, constant}; } @@ -9615,7 +9615,7 @@ BinaryOpNode::matchVTCTPHelper(int &var_id, int &lag, int ¶m_id, double &con } vector<tuple<int, int, int, double>> -ExprNode::matchLinearCombinationOfVariables() const +ExprNode::matchLinearCombinationOfVariables(bool variable_obligatory_in_each_term) const { vector<pair<expr_t, int>> terms; decomposeAdditiveTerms(terms); @@ -9626,7 +9626,7 @@ ExprNode::matchLinearCombinationOfVariables() const { expr_t term = it.first; int sign = it.second; - auto m = term->matchVariableTimesConstantTimesParam(); + auto m = term->matchVariableTimesConstantTimesParam(variable_obligatory_in_each_term); get<3>(m) *= sign; result.push_back(m); } diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 4ef3e0bd18fc5f0e8433e3cabdf24d01317da22e..c7326d02a8bfe76ecfc8a618171084cc34ffa975 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -604,8 +604,11 @@ class ExprNode /*! Returns a list of (variable_id, lag, param_id, constant) corresponding to the terms in the expression. When there is no parameter in a term, param_id == -1. - Can throw a MatchFailureException. */ - vector<tuple<int, int, int, double>> matchLinearCombinationOfVariables() const; + Can throw a MatchFailureException. + if `variable_obligatory_in_each_term` is true, then every part of the linear combination must contain a variable; + otherwise, if `variable_obligatory_in_each_term`, then any linear combination of constant/variable/param is matched + */ + vector<tuple<int, int, int, double>> matchLinearCombinationOfVariables(bool variable_obligatory_in_each_term = true) const; pair<int, vector<tuple<int, int, int, double>>> matchParamTimesLinearCombinationOfVariables() const; @@ -645,8 +648,11 @@ class ExprNode denominator (i.e. after a divide sign). The parameter is optional (in which case param_id == -1). If the expression is not of the expected form, throws a - MatchFailureException */ - tuple<int, int, int, double> matchVariableTimesConstantTimesParam() const; + MatchFailureException + if `variable_obligatory` is true, then the linear combination must contain a variable; + otherwise, if `variable_obligatory`, then an expression is matched that has any mix of constant/variable/param + */ + tuple<int, int, int, double> matchVariableTimesConstantTimesParam(bool variable_obligatory = true) const; //! Exception thrown by matchVariableTimesConstantTimesParam when matching fails class MatchFailureException