Unverified Commit a377fc83 authored by Sébastien Villemot's avatar Sébastien Villemot
Browse files

Fix various potential bugs with model local variables

Many recursive function on ExprNodes were not correctly recursing into the
definition of model local variables.
parent e4687bb9
......@@ -98,6 +98,11 @@ DataTree::operator=(const DataTree &d)
// Constants must be initialized first because they are used in some Add* methods
initConstants();
/* Model local variables must be next, because they can be evaluated in Add*
methods when the model equations are added */
for (const auto &it : d.local_variables_table)
local_variables_table[it.first] = it.second->clone(*this);
for (const auto &it : d.node_list)
it->clone(*this);
......@@ -105,9 +110,6 @@ DataTree::operator=(const DataTree &d)
local_variables_vector = d.local_variables_vector;
for (const auto &it : d.local_variables_table)
local_variables_table[it.first] = it.second->clone(*this);
return *this;
}
......
......@@ -859,6 +859,9 @@ VariableNode::computeDerivative(int deriv_id)
bool
VariableNode::containsExternalFunction() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->containsExternalFunction();
return false;
}
......@@ -1203,6 +1206,9 @@ VariableNode::substituteStaticAuxiliaryVariable() const
double
VariableNode::eval(const eval_context_t &eval_context) const noexcept(false)
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->eval(eval_context);
auto it = eval_context.find(symb_id);
if (it == eval_context.end())
throw EvalException();
......@@ -1311,6 +1317,8 @@ VariableNode::computeSubExprContainingVariable(int symb_id_arg, int lag_arg, set
{
if (symb_id == symb_id_arg && lag == lag_arg)
contain_var.insert(const_cast<VariableNode*>(this));
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->computeSubExprContainingVariable(symb_id_arg, lag_arg, contain_var);
}
BinaryOpNode *
......@@ -1318,6 +1326,9 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
{
assert(contain_var.count(const_cast<VariableNode *>(this)) > 0);
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->normalizeEquationHelper(contain_var, rhs);
// This the LHS variable: we have finished the normalization
return datatree.AddEqual(const_cast<VariableNode *>(this), rhs);
}
......@@ -1391,10 +1402,12 @@ VariableNode::computeXrefs(EquationInfo &ei) const
case SymbolType::parameter:
ei.param.emplace(symb_id, 0);
break;
case SymbolType::modFileLocalVariable:
datatree.getLocalVariable(symb_id)->computeXrefs(ei);
break;
case SymbolType::trend:
case SymbolType::logTrend:
case SymbolType::modelLocalVariable:
case SymbolType::modFileLocalVariable:
case SymbolType::statementDeclaredVariable:
case SymbolType::unusedEndogenous:
case SymbolType::externalFunction:
......@@ -1559,6 +1572,9 @@ VariableNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
int
VariableNode::PacMaxLag(int lhs_symb_id) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->PacMaxLag(lhs_symb_id);
if (lhs_symb_id == symb_id)
return -lag;
return 0;
......@@ -1567,28 +1583,41 @@ VariableNode::PacMaxLag(int lhs_symb_id) const
expr_t
VariableNode::substituteAdl() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteAdl();
return const_cast<VariableNode *>(this);
}
expr_t
VariableNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteVarExpectation(subst_table);
return const_cast<VariableNode *>(this);
}
void
VariableNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->findDiffNodes(nodes);
}
void
VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->findUnaryOpNodesForAuxVarCreation(nodes);
}
int
VariableNode::findTargetVariable(int lhs_symb_id) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->findTargetVariable(lhs_symb_id);
return -1;
}
......@@ -1596,18 +1625,27 @@ expr_t
VariableNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteDiff(nodes, subst_table, neweqs);
return const_cast<VariableNode *>(this);
}
expr_t
VariableNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteUnaryOpNodes(nodes, subst_table, neweqs);
return const_cast<VariableNode *>(this);
}
expr_t
VariableNode::substitutePacExpectation(const string &name, expr_t subexpr)
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substitutePacExpectation(name, subexpr);
return const_cast<VariableNode *>(this);
}
......@@ -1632,6 +1670,9 @@ VariableNode::decreaseLeadsLags(int n) const
expr_t
VariableNode::decreaseLeadsLagsPredeterminedVariables() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->decreaseLeadsLagsPredeterminedVariables();
if (datatree.symbol_table.isPredetermined(symb_id))
return decreaseLeadsLags(1);
else
......@@ -1773,6 +1814,9 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *
expr_t
VariableNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteExpectation(subst_table, neweqs, partial_information_model);
return const_cast<VariableNode *>(this);
}
......@@ -1831,12 +1875,18 @@ VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int la
bool
VariableNode::containsPacExpectation(const string &pac_model_name) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->containsPacExpectation(pac_model_name);
return false;
}
bool
VariableNode::containsEndogenous() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->containsEndogenous();
if (get_type() == SymbolType::endogenous)
return true;
else
......@@ -1846,12 +1896,18 @@ VariableNode::containsEndogenous() const
bool
VariableNode::containsExogenous() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->containsExogenous();
return get_type() == SymbolType::exogenous || get_type() == SymbolType::exogenousDet;
}
expr_t
VariableNode::replaceTrendVar() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->replaceTrendVar();
if (get_type() == SymbolType::trend)
return datatree.One;
else if (get_type() == SymbolType::logTrend)
......@@ -1863,6 +1919,9 @@ VariableNode::replaceTrendVar() const
expr_t
VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->detrend(symb_id, log_trend, trend);
if (this->symb_id != symb_id)
return const_cast<VariableNode *>(this);
......@@ -1885,12 +1944,18 @@ VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const
int
VariableNode::countDiffs() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->countDiffs();
return 0;
}
expr_t
VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->removeTrendLeadLag(trend_symbols_map);
if ((get_type() != SymbolType::trend && get_type() != SymbolType::logTrend) || lag == 0)
return const_cast<VariableNode *>(this);
......@@ -1936,24 +2001,36 @@ VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) cons
bool
VariableNode::isInStaticForm() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->isInStaticForm();
return lag == 0;
}
bool
VariableNode::isParamTimesEndogExpr() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->isParamTimesEndogExpr();
return false;
}
bool
VariableNode::isVarModelReferenced(const string &model_info_name) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->isVarModelReferenced(model_info_name);
return false;
}
void
VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->getEndosAndMaxLags(model_endos_and_lags);
if (get_type() == SymbolType::endogenous)
if (string varname = datatree.symbol_table.getName(symb_id);
model_endos_and_lags.find(varname) == model_endos_and_lags.end())
......@@ -1971,6 +2048,9 @@ VariableNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table)
expr_t
VariableNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->replaceVarsInEquation(table);
for (auto &it : table)
if (it.first->symb_id == symb_id)
return it.second;
......@@ -1980,6 +2060,9 @@ VariableNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table)
void
VariableNode::matchMatchedMoment(vector<int> &symb_ids, vector<int> &lags, vector<int> &powers) const
{
/* Used for simple expression outside model block, so no need to special-case
model local variables */
if (get_type() != SymbolType::endogenous)
throw MatchFailureException{"Variable " + datatree.symbol_table.getName(symb_id) + " is not an endogenous"};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment