Commit 6d573547 authored by Houtan Bastani's avatar Houtan Bastani

preprocessor: add var_model info to restrictions class

parent 1d079d4e
......@@ -260,6 +260,7 @@ VarEstimationStatement::writeOutput(ostream &output, const string &basename, boo
}
VarRestrictionsStatement::VarRestrictionsStatement(const string &var_model_name_arg,
const map<string, vector<string> > &var_map_arg,
const map<int, map<int, SymbolList> > &exclusion_restrictions_arg,
const equation_restrictions_t &equation_restrictions_arg,
const crossequation_restrictions_t &crossequation_restrictions_arg,
......@@ -267,6 +268,7 @@ VarRestrictionsStatement::VarRestrictionsStatement(const string &var_model_name_
const map<pair<int, int>, pair<int, int> > &covariance_pair_restriction_arg,
const SymbolTable &symbol_table_arg ) :
var_model_name(var_model_name_arg),
var_map(var_map_arg),
exclusion_restrictions(exclusion_restrictions_arg),
equation_restrictions(equation_restrictions_arg),
crossequation_restrictions(crossequation_restrictions_arg),
......@@ -276,9 +278,40 @@ VarRestrictionsStatement::VarRestrictionsStatement(const string &var_model_name_
{
}
int
VarRestrictionsStatement::findIdxInVector(const vector<string> &vecvars, const string &var) const
{
int idx = 0;
bool setflag = false;
for (vector<string>::const_iterator itvs = vecvars.begin();
itvs != vecvars.end(); itvs++, idx++)
if (*itvs == var)
{
setflag = true;
break;
}
if (!setflag)
{
cerr << "ERROR: you are imposing an exclusion restriction on an equation or variable "
<< var << " that is not contained in VAR " << var_model_name;
exit(EXIT_FAILURE);
}
return idx;
}
void
VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, bool minimal_workspace) const
{
map<string, vector<string> >::const_iterator itvs = var_map.find(var_model_name);
if (itvs == var_map.end())
{
cerr << "ERROR: you are imposing restrictions on a VAR named " << var_model_name
<< " but this VAR has not been declared via thevar_model statement." << endl;
exit(EXIT_FAILURE);
}
vector<string> vars = itvs->second;
string Mstr ("M_.var." + var_model_name + ".restrictions.");
int nrestrictions = 0;
......@@ -295,9 +328,14 @@ VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, b
{
if (it1 != it->second.begin())
output << " ";
output << "{'" << symbol_table.getName(it1->first) << "', ";
it1->second.write(output);
output << "};";
output << "struct('eq', " << findIdxInVector(vars, symbol_table.getName(it1->first)) + 1
<< ", 'vars', [";
vector<string> excvars = it1->second.getSymbols();
for (vector<string>::const_iterator itvs1 = excvars.begin();
itvs1 != excvars.end(); itvs1++)
output << findIdxInVector(vars, *itvs1) + 1 << " ";
output << "])";
nrestrictions += it1->second.getSize();
}
output << "];" << endl;
......@@ -344,10 +382,10 @@ VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, b
<< it->second << ";" << endl;
var_restriction_eq_crosseq_t ls = it->first.first;
output << Mstr << "crossequation_restriction{" << idx << "}.lseq = '"
<< symbol_table.getName(ls.first.first) << "';" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.lsvar = '"
<< symbol_table.getName(ls.first.second.first) << "';" << endl
output << Mstr << "crossequation_restriction{" << idx << "}.lseq = "
<< findIdxInVector(vars, symbol_table.getName(ls.first.first)) + 1 << ";" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.lsvar = "
<< findIdxInVector(vars, symbol_table.getName(ls.first.second.first)) + 1 << ";" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.lslag = "
<< ls.first.second.second << ";" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.lscoeff = ";
......@@ -357,10 +395,10 @@ VarRestrictionsStatement::writeOutput(ostream &output, const string &basename, b
var_restriction_eq_crosseq_t rs = it->first.second;
if (rs.first.first >= 0)
{
output << Mstr << "crossequation_restriction{" << idx << "}.rseq = '"
<< symbol_table.getName(rs.first.first) << "';" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.rsvar = '"
<< symbol_table.getName(rs.first.second.first) << "';" << endl
output << Mstr << "crossequation_restriction{" << idx << "}.rseq = "
<< findIdxInVector(vars, symbol_table.getName(rs.first.first)) + 1 << ";" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.rsvar = "
<< findIdxInVector(vars, symbol_table.getName(rs.first.second.first)) + 1 << ";" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.rslag = "
<< rs.first.second.second << ";" << endl
<< Mstr << "crossequation_restriction{" << idx << "}.rscoeff = ";
......
......@@ -131,6 +131,7 @@ class VarRestrictionsStatement : public Statement
private:
typedef pair<pair<int, pair<int, int> >, expr_t> var_restriction_eq_crosseq_t;
const string &var_model_name;
const map<string, vector<string> > &var_map;
const map<int, map<int, SymbolList> > exclusion_restrictions;
typedef map<int, pair<pair<var_restriction_eq_crosseq_t, var_restriction_eq_crosseq_t>, double> > equation_restrictions_t;
const equation_restrictions_t equation_restrictions;
......@@ -139,8 +140,10 @@ private:
const map<pair<int, int>, double> covariance_number_restriction;
const map<pair<int, int>, pair<int, int> > covariance_pair_restriction;
const SymbolTable &symbol_table;
int findIdxInVector(const vector<string> &vecvars, const string &var) const;
public:
VarRestrictionsStatement(const string &var_model_name_arg,
const map<string, vector<string> > &var_map_arg,
const map<int, map<int, SymbolList> > &exclusion_restrictions_arg,
const equation_restrictions_t &equation_restrictions_arg,
const crossequation_restrictions_t &crossequation_restrictions_arg,
......
......@@ -491,6 +491,7 @@ void
ParsingDriver::end_VAR_restrictions(string *var_model_name)
{
mod_file->addStatement(new VarRestrictionsStatement(*var_model_name,
var_map,
exclusion_restrictions,
equation_restrictions,
crossequation_restrictions,
......@@ -1452,6 +1453,7 @@ ParsingDriver::var_model()
error("You must pass the model_name option to the var_model statement.");
const string *name = new string(it->second);
mod_file->addStatement(new VarModelStatement(symbol_list, options_list, *name));
var_map[it->second] = symbol_list.getSymbols();
symbol_list.clear();
options_list.clear();
}
......
......@@ -227,6 +227,9 @@ private:
//! Temporary storage for equation tags
vector<pair<string, string> > eq_tags;
//! Map Var name to variables
map<string, vector<string> > var_map;
//! The mod file representation constructed by this ParsingDriver
ModFile *mod_file;
......
......@@ -53,6 +53,12 @@ SymbolList::getSize() const
return symbols.size();
}
vector<string>
SymbolList::getSymbols() const
{
return symbols;
}
void
SymbolList::clear()
{
......
......@@ -49,6 +49,8 @@ public:
int empty() const { return symbols.empty(); };
//! Return the number of Symbols contained in the list
int getSize() const;
//! Return the list of symbols
vector<string> getSymbols() const;
};
#endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment