Unverified Commit 5d461c0b authored by Sébastien Villemot's avatar Sébastien Villemot
Browse files

Fix various potential bugs with model local variables

Many recursive function on ExprNodes were not correctly recursing into the
definition of model local variables.

(manually cherry picked from commit a377fc83)
parent b1779a95
Pipeline #4226 passed with stages
in 4 minutes and 43 seconds
...@@ -98,6 +98,11 @@ DataTree::operator=(const DataTree &d) ...@@ -98,6 +98,11 @@ DataTree::operator=(const DataTree &d)
// Constants must be initialized first because they are used in some Add* methods // Constants must be initialized first because they are used in some Add* methods
initConstants(); initConstants();
/* Model local variables must be next, because they can be evaluated in Add*
methods when the model equations are added */
for (const auto &it : d.local_variables_table)
local_variables_table[it.first] = it.second->clone(*this);
for (const auto &it : d.node_list) for (const auto &it : d.node_list)
it->clone(*this); it->clone(*this);
...@@ -105,9 +110,6 @@ DataTree::operator=(const DataTree &d) ...@@ -105,9 +110,6 @@ DataTree::operator=(const DataTree &d)
local_variables_vector = d.local_variables_vector; local_variables_vector = d.local_variables_vector;
for (const auto &it : d.local_variables_table)
local_variables_table[it.first] = it.second->clone(*this);
return *this; return *this;
} }
......
...@@ -887,6 +887,9 @@ VariableNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, t ...@@ -887,6 +887,9 @@ VariableNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, t
bool bool
VariableNode::containsExternalFunction() const VariableNode::containsExternalFunction() const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->containsExternalFunction();
return false; return false;
} }
...@@ -1245,6 +1248,9 @@ VariableNode::substituteStaticAuxiliaryVariable() const ...@@ -1245,6 +1248,9 @@ VariableNode::substituteStaticAuxiliaryVariable() const
double double
VariableNode::eval(const eval_context_t &eval_context) const noexcept(false) VariableNode::eval(const eval_context_t &eval_context) const noexcept(false)
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->eval(eval_context);
auto it = eval_context.find(symb_id); auto it = eval_context.find(symb_id);
if (it == eval_context.end()) if (it == eval_context.end())
throw EvalException(); throw EvalException();
...@@ -1461,10 +1467,12 @@ VariableNode::computeXrefs(EquationInfo &ei) const ...@@ -1461,10 +1467,12 @@ VariableNode::computeXrefs(EquationInfo &ei) const
case SymbolType::parameter: case SymbolType::parameter:
ei.param.emplace(symb_id, 0); ei.param.emplace(symb_id, 0);
break; break;
case SymbolType::modFileLocalVariable:
datatree.getLocalVariable(symb_id)->computeXrefs(ei);
break;
case SymbolType::trend: case SymbolType::trend:
case SymbolType::logTrend: case SymbolType::logTrend:
case SymbolType::modelLocalVariable: case SymbolType::modelLocalVariable:
case SymbolType::modFileLocalVariable:
case SymbolType::statementDeclaredVariable: case SymbolType::statementDeclaredVariable:
case SymbolType::unusedEndogenous: case SymbolType::unusedEndogenous:
case SymbolType::externalFunction: case SymbolType::externalFunction:
...@@ -1629,6 +1637,9 @@ VariableNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const ...@@ -1629,6 +1637,9 @@ VariableNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
int int
VariableNode::PacMaxLag(int lhs_symb_id) const VariableNode::PacMaxLag(int lhs_symb_id) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->PacMaxLag(lhs_symb_id);
if (lhs_symb_id == symb_id) if (lhs_symb_id == symb_id)
return -lag; return -lag;
return 0; return 0;
...@@ -1643,28 +1654,41 @@ VariableNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const ...@@ -1643,28 +1654,41 @@ VariableNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
expr_t expr_t
VariableNode::substituteAdl() const VariableNode::substituteAdl() const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteAdl();
return const_cast<VariableNode *>(this); return const_cast<VariableNode *>(this);
} }
expr_t expr_t
VariableNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const VariableNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteVarExpectation(subst_table);
return const_cast<VariableNode *>(this); return const_cast<VariableNode *>(this);
} }
void void
VariableNode::findDiffNodes(lag_equivalence_table_t &nodes) const VariableNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->findDiffNodes(nodes);
} }
void void
VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->findUnaryOpNodesForAuxVarCreation(nodes);
} }
int int
VariableNode::findTargetVariable(int lhs_symb_id) const VariableNode::findTargetVariable(int lhs_symb_id) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->findTargetVariable(lhs_symb_id);
return -1; return -1;
} }
...@@ -1672,18 +1696,27 @@ expr_t ...@@ -1672,18 +1696,27 @@ expr_t
VariableNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, VariableNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const vector<BinaryOpNode *> &neweqs) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteDiff(nodes, subst_table, neweqs);
return const_cast<VariableNode *>(this); return const_cast<VariableNode *>(this);
} }
expr_t expr_t
VariableNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const VariableNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteUnaryOpNodes(nodes, subst_table, neweqs);
return const_cast<VariableNode *>(this); return const_cast<VariableNode *>(this);
} }
expr_t expr_t
VariableNode::substitutePacExpectation(const string &name, expr_t subexpr) VariableNode::substitutePacExpectation(const string &name, expr_t subexpr)
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substitutePacExpectation(name, subexpr);
return const_cast<VariableNode *>(this); return const_cast<VariableNode *>(this);
} }
...@@ -1708,6 +1741,9 @@ VariableNode::decreaseLeadsLags(int n) const ...@@ -1708,6 +1741,9 @@ VariableNode::decreaseLeadsLags(int n) const
expr_t expr_t
VariableNode::decreaseLeadsLagsPredeterminedVariables() const VariableNode::decreaseLeadsLagsPredeterminedVariables() const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->decreaseLeadsLagsPredeterminedVariables();
if (datatree.symbol_table.isPredetermined(symb_id)) if (datatree.symbol_table.isPredetermined(symb_id))
return decreaseLeadsLags(1); return decreaseLeadsLags(1);
else else
...@@ -1849,6 +1885,9 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode * ...@@ -1849,6 +1885,9 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *
expr_t expr_t
VariableNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const VariableNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteExpectation(subst_table, neweqs, partial_information_model);
return const_cast<VariableNode *>(this); return const_cast<VariableNode *>(this);
} }
...@@ -1907,12 +1946,18 @@ VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int la ...@@ -1907,12 +1946,18 @@ VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int la
bool bool
VariableNode::containsPacExpectation(const string &pac_model_name) const VariableNode::containsPacExpectation(const string &pac_model_name) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->containsPacExpectation(pac_model_name);
return false; return false;
} }
bool bool
VariableNode::containsEndogenous() const VariableNode::containsEndogenous() const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->containsEndogenous();
if (get_type() == SymbolType::endogenous) if (get_type() == SymbolType::endogenous)
return true; return true;
else else
...@@ -1922,12 +1967,18 @@ VariableNode::containsEndogenous() const ...@@ -1922,12 +1967,18 @@ VariableNode::containsEndogenous() const
bool bool
VariableNode::containsExogenous() const VariableNode::containsExogenous() const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->containsExogenous();
return get_type() == SymbolType::exogenous || get_type() == SymbolType::exogenousDet; return get_type() == SymbolType::exogenous || get_type() == SymbolType::exogenousDet;
} }
expr_t expr_t
VariableNode::replaceTrendVar() const VariableNode::replaceTrendVar() const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->replaceTrendVar();
if (get_type() == SymbolType::trend) if (get_type() == SymbolType::trend)
return datatree.One; return datatree.One;
else if (get_type() == SymbolType::logTrend) else if (get_type() == SymbolType::logTrend)
...@@ -1939,6 +1990,9 @@ VariableNode::replaceTrendVar() const ...@@ -1939,6 +1990,9 @@ VariableNode::replaceTrendVar() const
expr_t expr_t
VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->detrend(symb_id, log_trend, trend);
if (this->symb_id != symb_id) if (this->symb_id != symb_id)
return const_cast<VariableNode *>(this); return const_cast<VariableNode *>(this);
...@@ -1961,12 +2015,18 @@ VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const ...@@ -1961,12 +2015,18 @@ VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const
int int
VariableNode::countDiffs() const VariableNode::countDiffs() const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->countDiffs();
return 0; return 0;
} }
expr_t expr_t
VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->removeTrendLeadLag(trend_symbols_map);
if ((get_type() != SymbolType::trend && get_type() != SymbolType::logTrend) || lag == 0) if ((get_type() != SymbolType::trend && get_type() != SymbolType::logTrend) || lag == 0)
return const_cast<VariableNode *>(this); return const_cast<VariableNode *>(this);
...@@ -2012,24 +2072,36 @@ VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) cons ...@@ -2012,24 +2072,36 @@ VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) cons
bool bool
VariableNode::isInStaticForm() const VariableNode::isInStaticForm() const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->isInStaticForm();
return lag == 0; return lag == 0;
} }
bool bool
VariableNode::isParamTimesEndogExpr() const VariableNode::isParamTimesEndogExpr() const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->isParamTimesEndogExpr();
return false; return false;
} }
bool bool
VariableNode::isVarModelReferenced(const string &model_info_name) const VariableNode::isVarModelReferenced(const string &model_info_name) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->isVarModelReferenced(model_info_name);
return false; return false;
} }
void void
VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->getEndosAndMaxLags(model_endos_and_lags);
if (get_type() == SymbolType::endogenous) if (get_type() == SymbolType::endogenous)
if (string varname = datatree.symbol_table.getName(symb_id); if (string varname = datatree.symbol_table.getName(symb_id);
model_endos_and_lags.find(varname) == model_endos_and_lags.end()) model_endos_and_lags.find(varname) == model_endos_and_lags.end())
...@@ -2047,6 +2119,9 @@ VariableNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) ...@@ -2047,6 +2119,9 @@ VariableNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table)
expr_t expr_t
VariableNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const VariableNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{ {
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->replaceVarsInEquation(table);
for (auto &it : table) for (auto &it : table)
if (it.first->symb_id == symb_id) if (it.first->symb_id == symb_id)
return it.second; return it.second;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment