Commit 6d573547 authored by Houtan Bastani's avatar Houtan Bastani

preprocessor: add var_model info to restrictions class

parent 1d079d4e
...@@ -260,6 +260,7 @@ VarEstimationStatement::writeOutput(ostream &output, const string &basename, boo ...@@ -260,6 +260,7 @@ VarEstimationStatement::writeOutput(ostream &output, const string &basename, boo
} }
VarRestrictionsStatement::VarRestrictionsStatement(const string &var_model_name_arg, VarRestrictionsStatement::VarRestrictionsStatement(const string &var_model_name_arg,
const map<string, vector<string> > &var_map_arg,
const map<int, map<int, SymbolList> > &exclusion_restrictions_arg, const map<int, map<int, SymbolList> > &exclusion_restrictions_arg,
const equation_restrictions_t &equation_restrictions_arg, const equation_restrictions_t &equation_restrictions_arg,
const crossequation_restrictions_t &crossequation_restrictions_arg, const crossequation_restrictions_t &crossequation_restrictions_arg,
...@@ -267,6 +268,7 @@ VarRestrictionsStatement::VarRestrictionsStatement(const string &var_model_name_ ...@@ -267,6 +268,7 @@ VarRestrictionsStatement::VarRestrictionsStatement(const string &var_model_name_
const map<pair<int, int>, pair<int, int> > &covariance_pair_restriction_arg, const map<pair<int, int>, pair<int, int> > &covariance_pair_restriction_arg,
const SymbolTable &symbol_table_arg ) : const SymbolTable &symbol_table_arg ) :
var_model_name(var_model_name_arg), var_model_name(var_model_name_arg),
var_map(var_map_arg),
exclusion_restrictions(exclusion_restrictions_arg), exclusion_restrictions(exclusion_restrictions_arg),
equation_restrictions(equation_restrictions_arg), equation_restrictions(equation_restrictions_arg),
crossequation_restrictions(crossequation_restrictions_arg), crossequation_restrictions(crossequation_restrictions_arg),
...@@ -276,9 +278,40 @@ VarRestrictionsStatement::VarRestrictionsStatement(const string &var_model_name_ ...@@ -276,9 +278,40 @@ VarRestrictionsStatement::VarRestrictionsStatement(const string &var_model_name_
{ {
} }
int
VarRestrictionsStatement::findIdxInVector(const vector<string> &vecvars, const string &var) const
{
int idx = 0;
bool setflag = false;
for (vector<string>::const_iterator itvs = vecvars.begin();
itvs != vecvars.end(); itvs++, idx++)
if (*itvs == var)
{
setflag = true;
break;
}
if (!setflag)
{
cerr << "ERROR: you are imposing an exclusion restriction on an equation or variable "
<< var << " that is not contained in VAR " << var_model_name;
exit(EXIT_FAILURE);
}
return idx;
}
void void
VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, bool minimal_workspace) const VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, bool minimal_workspace) const
{ {
map<string, vector<string> >::const_iterator itvs = var_map.find(var_model_name);
if (itvs == var_map.end())
{
cerr << "ERROR: you are imposing restrictions on a VAR named " << var_model_name
<< " but this VAR has not been declared via thevar_model statement." << endl;
exit(EXIT_FAILURE);
}
vector<string> vars = itvs->second;
string Mstr ("M_.var." + var_model_name + ".restrictions."); string Mstr ("M_.var." + var_model_name + ".restrictions.");
int nrestrictions = 0; int nrestrictions = 0;
...@@ -295,9 +328,14 @@ VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, b ...@@ -295,9 +328,14 @@ VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, b
{ {
if (it1 != it->second.begin()) if (it1 != it->second.begin())
output << " "; output << " ";
output << "{'" << symbol_table.getName(it1->first) << "', ";
it1->second.write(output); output << "struct('eq', " << findIdxInVector(vars, symbol_table.getName(it1->first)) + 1
output << "};"; << ", 'vars', [";
vector<string> excvars = it1->second.getSymbols();
for (vector<string>::const_iterator itvs1 = excvars.begin();
itvs1 != excvars.end(); itvs1++)
output << findIdxInVector(vars, *itvs1) + 1 << " ";
output << "])";
nrestrictions += it1->second.getSize(); nrestrictions += it1->second.getSize();
} }
output << "];" << endl; output << "];" << endl;
...@@ -344,10 +382,10 @@ VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, b ...@@ -344,10 +382,10 @@ VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, b
<< it->second << ";" << endl; << it->second << ";" << endl;
var_restriction_eq_crosseq_t ls = it->first.first; var_restriction_eq_crosseq_t ls = it->first.first;
output << Mstr << "crossequation_restriction{" << idx << "}.lseq = '" output << Mstr << "crossequation_restriction{" << idx << "}.lseq = "
<< symbol_table.getName(ls.first.first) << "';" << endl << findIdxInVector(vars, symbol_table.getName(ls.first.first)) + 1 << ";" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.lsvar = '" << Mstr << "crossequation_restriction{" << idx << "}.lsvar = "
<< symbol_table.getName(ls.first.second.first) << "';" << endl << findIdxInVector(vars, symbol_table.getName(ls.first.second.first)) + 1 << ";" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.lslag = " << Mstr << "crossequation_restriction{" << idx << "}.lslag = "
<< ls.first.second.second << ";" << endl << ls.first.second.second << ";" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.lscoeff = "; << Mstr << "crossequation_restriction{" << idx << "}.lscoeff = ";
...@@ -357,10 +395,10 @@ VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, b ...@@ -357,10 +395,10 @@ VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, b
var_restriction_eq_crosseq_t rs = it->first.second; var_restriction_eq_crosseq_t rs = it->first.second;
if (rs.first.first >= 0) if (rs.first.first >= 0)
{ {
output << Mstr << "crossequation_restriction{" << idx << "}.rseq = '" output << Mstr << "crossequation_restriction{" << idx << "}.rseq = "
<< symbol_table.getName(rs.first.first) << "';" << endl << findIdxInVector(vars, symbol_table.getName(rs.first.first)) + 1 << ";" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.rsvar = '" << Mstr << "crossequation_restriction{" << idx << "}.rsvar = "
<< symbol_table.getName(rs.first.second.first) << "';" << endl << findIdxInVector(vars, symbol_table.getName(rs.first.second.first)) + 1 << ";" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.rslag = " << Mstr << "crossequation_restriction{" << idx << "}.rslag = "
<< rs.first.second.second << ";" << endl << rs.first.second.second << ";" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.rscoeff = "; << Mstr << "crossequation_restriction{" << idx << "}.rscoeff = ";
......
...@@ -131,6 +131,7 @@ class VarRestrictionsStatement : public Statement ...@@ -131,6 +131,7 @@ class VarRestrictionsStatement : public Statement
private: private:
typedef pair<pair<int, pair<int, int> >, expr_t> var_restriction_eq_crosseq_t; typedef pair<pair<int, pair<int, int> >, expr_t> var_restriction_eq_crosseq_t;
const string &var_model_name; const string &var_model_name;
const map<string, vector<string> > &var_map;
const map<int, map<int, SymbolList> > exclusion_restrictions; const map<int, map<int, SymbolList> > exclusion_restrictions;
typedef map<int, pair<pair<var_restriction_eq_crosseq_t, var_restriction_eq_crosseq_t>, double> > equation_restrictions_t; typedef map<int, pair<pair<var_restriction_eq_crosseq_t, var_restriction_eq_crosseq_t>, double> > equation_restrictions_t;
const equation_restrictions_t equation_restrictions; const equation_restrictions_t equation_restrictions;
...@@ -139,8 +140,10 @@ private: ...@@ -139,8 +140,10 @@ private:
const map<pair<int, int>, double> covariance_number_restriction; const map<pair<int, int>, double> covariance_number_restriction;
const map<pair<int, int>, pair<int, int> > covariance_pair_restriction; const map<pair<int, int>, pair<int, int> > covariance_pair_restriction;
const SymbolTable &symbol_table; const SymbolTable &symbol_table;
int findIdxInVector(const vector<string> &vecvars, const string &var) const;
public: public:
VarRestrictionsStatement(const string &var_model_name_arg, VarRestrictionsStatement(const string &var_model_name_arg,
const map<string, vector<string> > &var_map_arg,
const map<int, map<int, SymbolList> > &exclusion_restrictions_arg, const map<int, map<int, SymbolList> > &exclusion_restrictions_arg,
const equation_restrictions_t &equation_restrictions_arg, const equation_restrictions_t &equation_restrictions_arg,
const crossequation_restrictions_t &crossequation_restrictions_arg, const crossequation_restrictions_t &crossequation_restrictions_arg,
......
...@@ -491,6 +491,7 @@ void ...@@ -491,6 +491,7 @@ void
ParsingDriver::end_VAR_restrictions(string *var_model_name) ParsingDriver::end_VAR_restrictions(string *var_model_name)
{ {
mod_file->addStatement(new VarRestrictionsStatement(*var_model_name, mod_file->addStatement(new VarRestrictionsStatement(*var_model_name,
var_map,
exclusion_restrictions, exclusion_restrictions,
equation_restrictions, equation_restrictions,
crossequation_restrictions, crossequation_restrictions,
...@@ -1452,6 +1453,7 @@ ParsingDriver::var_model() ...@@ -1452,6 +1453,7 @@ ParsingDriver::var_model()
error("You must pass the model_name option to the var_model statement."); error("You must pass the model_name option to the var_model statement.");
const string *name = new string(it->second); const string *name = new string(it->second);
mod_file->addStatement(new VarModelStatement(symbol_list, options_list, *name)); mod_file->addStatement(new VarModelStatement(symbol_list, options_list, *name));
var_map[it->second] = symbol_list.getSymbols();
symbol_list.clear(); symbol_list.clear();
options_list.clear(); options_list.clear();
} }
......
...@@ -227,6 +227,9 @@ private: ...@@ -227,6 +227,9 @@ private:
//! Temporary storage for equation tags //! Temporary storage for equation tags
vector<pair<string, string> > eq_tags; vector<pair<string, string> > eq_tags;
//! Map Var name to variables
map<string, vector<string> > var_map;
//! The mod file representation constructed by this ParsingDriver //! The mod file representation constructed by this ParsingDriver
ModFile *mod_file; ModFile *mod_file;
......
...@@ -53,6 +53,12 @@ SymbolList::getSize() const ...@@ -53,6 +53,12 @@ SymbolList::getSize() const
return symbols.size(); return symbols.size();
} }
vector<string>
SymbolList::getSymbols() const
{
return symbols;
}
void void
SymbolList::clear() SymbolList::clear()
{ {
......
...@@ -49,6 +49,8 @@ public: ...@@ -49,6 +49,8 @@ public:
int empty() const { return symbols.empty(); }; int empty() const { return symbols.empty(); };
//! Return the number of Symbols contained in the list //! Return the number of Symbols contained in the list
int getSize() const; int getSize() const;
//! Return the list of symbols
vector<string> getSymbols() const;
}; };
#endif #endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment