Commit 1b952a12 authored by Houtan Bastani's avatar Houtan Bastani
Browse files

fix bug in var max lag and simplify code

parent e532ed9b
......@@ -3355,23 +3355,19 @@ DynamicModel::checkVarMinLag(vector<int> &eqnumber) const
int
DynamicModel::getVarMaxLag(StaticModel &static_model, vector<int> &eqnumber) const
{
vector<expr_t> lhs;
set<expr_t> lhs;
for (vector<int>::const_iterator it = eqnumber.begin();
it != eqnumber.end(); it++)
equations[*it]->get_arg1()->collectVARLHSVariable(lhs);
if (eqnumber.size() != lhs.size())
{
set<expr_t> lhs_set;
equations[*it]->get_arg1()->collectVARLHSVariable(lhs_set);
if (lhs_set.size() != 1)
{
cerr << "ERROR: in Equation "
<< ". A VAR may only have one endogenous variable on the LHS. " << endl;
exit(EXIT_FAILURE);
}
lhs.push_back(*(lhs_set.begin()));
cerr << "The LHS variables of the VAR are not unique" << endl;
exit(EXIT_FAILURE);
}
set<expr_t> lhs_static;
for(vector<expr_t>::const_iterator it = lhs.begin();
for(set<expr_t>::const_iterator it = lhs.begin();
it != lhs.end(); it++)
lhs_static.insert((*it)->toStatic(static_model));
......@@ -3390,7 +3386,7 @@ DynamicModel::getVarLhsDiffAndInfo(vector<int> &eqnumber, vector<bool> &diff,
for (vector<int>::const_iterator it = eqnumber.begin();
it != eqnumber.end(); it++)
{
diff.push_back(equations[*it]->get_arg1()->isDiffPresent());
equations[*it]->get_arg1()->countDiffs() > 0 ? diff.push_back(true) : diff.push_back(false);
if (diff.back())
{
set<pair<int, int> > diff_set;
......
......@@ -294,12 +294,6 @@ ExprNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_ar
return false;
}
bool
ExprNode::isDiffPresent() const
{
return false;
}
void
ExprNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
{
......@@ -313,10 +307,10 @@ NumConstNode::NumConstNode(DataTree &datatree_arg, int id_arg) :
datatree.num_const_node_map[id] = this;
}
bool
NumConstNode::isDiffPresent() const
int
NumConstNode::countDiffs() const
{
return false;
return 0;
}
void
......@@ -389,6 +383,8 @@ NumConstNode::compile(ostream &CompileCode, unsigned int &instruction_number,
void
NumConstNode::collectVARLHSVariable(set<expr_t> &result) const
{
cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
exit(EXIT_FAILURE);
}
void
......@@ -1144,7 +1140,7 @@ VariableNode::collectVARLHSVariable(set<expr_t> &result) const
result.insert(const_cast<VariableNode *>(this));
else
{
cerr << "ERROR: A VAR must have one endogenous variable on the LHS." << endl;
cerr << "ERROR: you can only have endogenous variables or unary ops on LHS of VAR" << endl;
exit(EXIT_FAILURE);
}
}
......@@ -1730,10 +1726,10 @@ VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const
}
}
bool
VariableNode::isDiffPresent() const
int
VariableNode::countDiffs() const
{
return false;
return 0;
}
expr_t
......@@ -2965,19 +2961,20 @@ UnaryOpNode::VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs, int &
arg->VarMaxLag(static_datatree, static_lhs, max_lag);
else
{
for (set<expr_t>::const_iterator it = static_lhs.begin();
it != static_lhs.end(); it++)
if (*it == this->toStatic(static_datatree))
{
int max_lag_tmp = arg->maxLag();
if (max_lag_tmp > max_lag)
max_lag = max_lag_tmp;
return;
}
int max_lag_tmp = 0;
arg->VarMaxLag(static_datatree, static_lhs, max_lag_tmp);
if (max_lag_tmp + 1 > max_lag)
max_lag = max_lag_tmp + 1;
set<expr_t>::const_iterator it = static_lhs.find(this->toStatic(static_datatree));
if (it != static_lhs.end())
{
int max_lag_tmp = arg->maxLag() - arg->countDiffs();
if (max_lag_tmp > max_lag)
max_lag = max_lag_tmp;
}
else
{
int max_lag_tmp = 0;
arg->VarMaxLag(static_datatree, static_lhs, max_lag_tmp);
if (max_lag_tmp + 1 > max_lag)
max_lag = max_lag_tmp + 1;
}
}
}
......@@ -3027,12 +3024,12 @@ UnaryOpNode::substituteAdl() const
return retval;
}
bool
UnaryOpNode::isDiffPresent() const
int
UnaryOpNode::countDiffs() const
{
if (op_code == oDiff)
return true;
return arg->isDiffPresent();
return arg->countDiffs() + 1;
return arg->countDiffs();
}
bool
......@@ -4414,8 +4411,8 @@ BinaryOpNode::VarMaxLag(DataTree &static_datatree, set<expr_t> &static_lhs, int
void
BinaryOpNode::collectVARLHSVariable(set<expr_t> &result) const
{
arg1->collectVARLHSVariable(result);
arg2->collectVARLHSVariable(result);
cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
exit(EXIT_FAILURE);
}
void
......@@ -5011,10 +5008,10 @@ BinaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &no
return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
}
bool
BinaryOpNode::isDiffPresent() const
int
BinaryOpNode::countDiffs() const
{
return arg1->isDiffPresent() || arg2->isDiffPresent();
return arg1->countDiffs() + arg2->countDiffs();
}
expr_t
......@@ -5657,9 +5654,8 @@ TrinaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int
void
TrinaryOpNode::collectVARLHSVariable(set<expr_t> &result) const
{
arg1->collectVARLHSVariable(result);
arg2->collectVARLHSVariable(result);
arg3->collectVARLHSVariable(result);
cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
exit(EXIT_FAILURE);
}
void
......@@ -5923,10 +5919,10 @@ TrinaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &n
return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
}
bool
TrinaryOpNode::isDiffPresent() const
int
TrinaryOpNode::countDiffs() const
{
return arg1->isDiffPresent() || arg2->isDiffPresent() || arg3->isDiffPresent();
return arg1->countDiffs() + arg2->countDiffs() + arg3->countDiffs();
}
expr_t
......@@ -6127,9 +6123,8 @@ AbstractExternalFunctionNode::compileExternalFunctionArguments(ostream &CompileC
void
AbstractExternalFunctionNode::collectVARLHSVariable(set<expr_t> &result) const
{
for (vector<expr_t>::const_iterator it = arguments.begin();
it != arguments.end(); it++)
(*it)->collectVARLHSVariable(result);
cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
exit(EXIT_FAILURE);
}
void
......@@ -6362,13 +6357,13 @@ AbstractExternalFunctionNode::substituteUnaryOpNodes(DataTree &static_datatree,
return buildSimilarExternalFunctionNode(arguments_subst, datatree);
}
bool
AbstractExternalFunctionNode::isDiffPresent() const
int
AbstractExternalFunctionNode::countDiffs() const
{
bool result = false;
int ndiffs = 0;
for (vector<expr_t>::const_iterator it = arguments.begin(); it != arguments.end(); it++)
result = result || (*it)->isDiffPresent();
return result;
ndiffs += (*it)->countDiffs();
return ndiffs;
}
expr_t
......@@ -7829,10 +7824,10 @@ VarExpectationNode::eval(const eval_context_t &eval_context) const throw (EvalEx
return it->second;
}
bool
VarExpectationNode::isDiffPresent() const
int
VarExpectationNode::countDiffs() const
{
return false;
return 0;
}
void
......@@ -7843,6 +7838,8 @@ VarExpectationNode::computeXrefs(EquationInfo &ei) const
void
VarExpectationNode::collectVARLHSVariable(set<expr_t> &result) const
{
cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
exit(EXIT_FAILURE);
}
void
......@@ -8281,6 +8278,8 @@ PacExpectationNode::computeXrefs(EquationInfo &ei) const
void
PacExpectationNode::collectVARLHSVariable(set<expr_t> &result) const
{
cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
exit(EXIT_FAILURE);
}
void
......@@ -8306,10 +8305,10 @@ PacExpectationNode::compile(ostream &CompileCode, unsigned int &instruction_numb
exit(EXIT_FAILURE);
}
bool
PacExpectationNode::isDiffPresent() const
int
PacExpectationNode::countDiffs() const
{
return false;
return 0;
}
pair<int, expr_t >
......
......@@ -465,8 +465,8 @@ class ExprNode
//! Returns true if the expression contains one or several exogenous variable
virtual bool containsExogenous() const = 0;
//! Returns true if the expression contains a diff operator
virtual bool isDiffPresent(void) const = 0;
//! Returns the number of diffs present
virtual int countDiffs() const = 0;
//! Return true if the nodeID is a variable withe a type equal to type_arg, a specific variable id aqual to varfiable_id and a lag equal to lag_arg and false otherwise
/*!
......@@ -595,7 +595,7 @@ public:
virtual bool isNumConstNodeEqualTo(double value) const;
virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const;
virtual int countDiffs() const;
virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const;
virtual expr_t replaceTrendVar() const;
virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const;
......@@ -685,7 +685,7 @@ public:
virtual bool isNumConstNodeEqualTo(double value) const;
virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const;
virtual int countDiffs() const;
virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const;
virtual expr_t replaceTrendVar() const;
virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const;
......@@ -799,7 +799,7 @@ public:
virtual bool isNumConstNodeEqualTo(double value) const;
virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const;
virtual int countDiffs() const;
virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const;
virtual expr_t replaceTrendVar() const;
virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const;
......@@ -928,7 +928,7 @@ public:
virtual bool isNumConstNodeEqualTo(double value) const;
virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const;
virtual int countDiffs() const;
virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const;
virtual expr_t replaceTrendVar() const;
virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const;
......@@ -1033,7 +1033,7 @@ public:
virtual bool isNumConstNodeEqualTo(double value) const;
virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const;
virtual int countDiffs() const;
virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const;
virtual expr_t replaceTrendVar() const;
virtual expr_t detrend(int symb_id, bool log_trend, expr_t trend) const;
......@@ -1139,7 +1139,7 @@ public:
virtual bool isNumConstNodeEqualTo(double value) const;
virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const;
virtual int countDiffs() const;
virtual bool isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const;
virtual void writePrhs(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, deriv_node_temp_terms_t &tef_terms, const string &ending) const;
virtual expr_t replaceTrendVar() const;
......@@ -1338,7 +1338,7 @@ public:
virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const;
virtual int countDiffs() const;
virtual bool isNumConstNodeEqualTo(double value) const;
virtual expr_t differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
virtual expr_t decreaseLeadsLagsPredeterminedVariables() const;
......@@ -1423,7 +1423,7 @@ public:
virtual void collectDynamicVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
virtual bool containsEndogenous(void) const;
virtual bool containsExogenous() const;
virtual bool isDiffPresent(void) const;
virtual int countDiffs() const;
virtual bool isNumConstNodeEqualTo(double value) const;
virtual expr_t differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
virtual expr_t decreaseLeadsLagsPredeterminedVariables() const;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment