diff --git a/ComputingTasks.cc b/ComputingTasks.cc index e792360a2ebd4c135eb2f712a9f61ae22e58ca97..6ab9e27c82edf0e8f7d7cf40b18c96d57c76ba9e 100644 --- a/ComputingTasks.cc +++ b/ComputingTasks.cc @@ -260,6 +260,7 @@ VarEstimationStatement::writeOutput(ostream &output, const string &basename, boo } VarRestrictionsStatement::VarRestrictionsStatement(const string &var_model_name_arg, + const map<string, vector<string> > &var_map_arg, const map<int, map<int, SymbolList> > &exclusion_restrictions_arg, const equation_restrictions_t &equation_restrictions_arg, const crossequation_restrictions_t &crossequation_restrictions_arg, @@ -267,6 +268,7 @@ VarRestrictionsStatement::VarRestrictionsStatement(const string &var_model_name_ const map<pair<int, int>, pair<int, int> > &covariance_pair_restriction_arg, const SymbolTable &symbol_table_arg ) : var_model_name(var_model_name_arg), + var_map(var_map_arg), exclusion_restrictions(exclusion_restrictions_arg), equation_restrictions(equation_restrictions_arg), crossequation_restrictions(crossequation_restrictions_arg), @@ -276,9 +278,40 @@ VarRestrictionsStatement::VarRestrictionsStatement(const string &var_model_name_ { } +int +VarRestrictionsStatement::findIdxInVector(const vector<string> &vecvars, const string &var) const +{ + int idx = 0; + bool setflag = false; + for (vector<string>::const_iterator itvs = vecvars.begin(); + itvs != vecvars.end(); itvs++, idx++) + if (*itvs == var) + { + setflag = true; + break; + } + + if (!setflag) + { + cerr << "ERROR: you are imposing an exclusion restriction on an equation or variable " + << var << " that is not contained in VAR " << var_model_name; + exit(EXIT_FAILURE); + } + return idx; +} + void VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, bool minimal_workspace) const { + map<string, vector<string> >::const_iterator itvs = var_map.find(var_model_name); + if (itvs == var_map.end()) + { + cerr << "ERROR: you are imposing restrictions on a VAR named " << var_model_name + << " but this VAR has not been declared via thevar_model statement." << endl; + exit(EXIT_FAILURE); + } + vector<string> vars = itvs->second; + string Mstr ("M_.var." + var_model_name + ".restrictions."); int nrestrictions = 0; @@ -295,9 +328,14 @@ VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, b { if (it1 != it->second.begin()) output << " "; - output << "{'" << symbol_table.getName(it1->first) << "', "; - it1->second.write(output); - output << "};"; + + output << "struct('eq', " << findIdxInVector(vars, symbol_table.getName(it1->first)) + 1 + << ", 'vars', ["; + vector<string> excvars = it1->second.getSymbols(); + for (vector<string>::const_iterator itvs1 = excvars.begin(); + itvs1 != excvars.end(); itvs1++) + output << findIdxInVector(vars, *itvs1) + 1 << " "; + output << "])"; nrestrictions += it1->second.getSize(); } output << "];" << endl; @@ -344,10 +382,10 @@ VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, b << it->second << ";" << endl; var_restriction_eq_crosseq_t ls = it->first.first; - output << Mstr << "crossequation_restriction{" << idx << "}.lseq = '" - << symbol_table.getName(ls.first.first) << "';" << endl - << Mstr << "crossequation_restriction{" << idx << "}.lsvar = '" - << symbol_table.getName(ls.first.second.first) << "';" << endl + output << Mstr << "crossequation_restriction{" << idx << "}.lseq = " + << findIdxInVector(vars, symbol_table.getName(ls.first.first)) + 1 << ";" << endl + << Mstr << "crossequation_restriction{" << idx << "}.lsvar = " + << findIdxInVector(vars, symbol_table.getName(ls.first.second.first)) + 1 << ";" << endl << Mstr << "crossequation_restriction{" << idx << "}.lslag = " << ls.first.second.second << ";" << endl << Mstr << "crossequation_restriction{" << idx << "}.lscoeff = "; @@ -357,10 +395,10 @@ VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, b var_restriction_eq_crosseq_t rs = it->first.second; if (rs.first.first >= 0) { - output << Mstr << "crossequation_restriction{" << idx << "}.rseq = '" - << symbol_table.getName(rs.first.first) << "';" << endl - << Mstr << "crossequation_restriction{" << idx << "}.rsvar = '" - << symbol_table.getName(rs.first.second.first) << "';" << endl + output << Mstr << "crossequation_restriction{" << idx << "}.rseq = " + << findIdxInVector(vars, symbol_table.getName(rs.first.first)) + 1 << ";" << endl + << Mstr << "crossequation_restriction{" << idx << "}.rsvar = " + << findIdxInVector(vars, symbol_table.getName(rs.first.second.first)) + 1 << ";" << endl << Mstr << "crossequation_restriction{" << idx << "}.rslag = " << rs.first.second.second << ";" << endl << Mstr << "crossequation_restriction{" << idx << "}.rscoeff = "; diff --git a/ComputingTasks.hh b/ComputingTasks.hh index 94a0ecfb36229ba5f077a96cc67f4459a8abe2cb..3c675718f9ab0565ed570e83db065d18bf716e57 100644 --- a/ComputingTasks.hh +++ b/ComputingTasks.hh @@ -131,6 +131,7 @@ class VarRestrictionsStatement : public Statement private: typedef pair<pair<int, pair<int, int> >, expr_t> var_restriction_eq_crosseq_t; const string &var_model_name; + const map<string, vector<string> > &var_map; const map<int, map<int, SymbolList> > exclusion_restrictions; typedef map<int, pair<pair<var_restriction_eq_crosseq_t, var_restriction_eq_crosseq_t>, double> > equation_restrictions_t; const equation_restrictions_t equation_restrictions; @@ -139,8 +140,10 @@ private: const map<pair<int, int>, double> covariance_number_restriction; const map<pair<int, int>, pair<int, int> > covariance_pair_restriction; const SymbolTable &symbol_table; + int findIdxInVector(const vector<string> &vecvars, const string &var) const; public: VarRestrictionsStatement(const string &var_model_name_arg, + const map<string, vector<string> > &var_map_arg, const map<int, map<int, SymbolList> > &exclusion_restrictions_arg, const equation_restrictions_t &equation_restrictions_arg, const crossequation_restrictions_t &crossequation_restrictions_arg, diff --git a/ParsingDriver.cc b/ParsingDriver.cc index 9fd948870cdad5e970c594b1456e7167a5f5e303..a7e81352c01403388682d7d43987089997242949 100644 --- a/ParsingDriver.cc +++ b/ParsingDriver.cc @@ -491,6 +491,7 @@ void ParsingDriver::end_VAR_restrictions(string *var_model_name) { mod_file->addStatement(new VarRestrictionsStatement(*var_model_name, + var_map, exclusion_restrictions, equation_restrictions, crossequation_restrictions, @@ -1452,6 +1453,7 @@ ParsingDriver::var_model() error("You must pass the model_name option to the var_model statement."); const string *name = new string(it->second); mod_file->addStatement(new VarModelStatement(symbol_list, options_list, *name)); + var_map[it->second] = symbol_list.getSymbols(); symbol_list.clear(); options_list.clear(); } diff --git a/ParsingDriver.hh b/ParsingDriver.hh index 0198f52dbabbd658636f0581b039127cd4677f7d..a54407c2a73db58d537c99fea2b603194092591d 100644 --- a/ParsingDriver.hh +++ b/ParsingDriver.hh @@ -227,6 +227,9 @@ private: //! Temporary storage for equation tags vector<pair<string, string> > eq_tags; + //! Map Var name to variables + map<string, vector<string> > var_map; + //! The mod file representation constructed by this ParsingDriver ModFile *mod_file; diff --git a/SymbolList.cc b/SymbolList.cc index 692517083f8f9008d3943c748d16ce77825d8467..8a8f8c62b4bbd1801bd6e0962971d235d663f10f 100644 --- a/SymbolList.cc +++ b/SymbolList.cc @@ -53,6 +53,12 @@ SymbolList::getSize() const return symbols.size(); } +vector<string> +SymbolList::getSymbols() const +{ + return symbols; +} + void SymbolList::clear() { diff --git a/SymbolList.hh b/SymbolList.hh index 7ec5a13b8ff19784b3ec6b617730d370eefe7714..988c984eb82b4def0f1cf99ac470ed50bec9feee 100644 --- a/SymbolList.hh +++ b/SymbolList.hh @@ -49,6 +49,8 @@ public: int empty() const { return symbols.empty(); }; //! Return the number of Symbols contained in the list int getSize() const; + //! Return the list of symbols + vector<string> getSymbols() const; }; #endif