Fix various potential bugs with model local variables

Many recursive function on ExprNodes were not correctly recursing into the
definition of model local variables.

(manually cherry picked from commit a377fc83)
parent b1779a95
Pipeline #4226 passed with stages
in 4 minutes and 43 seconds
......@@ -98,6 +98,11 @@ DataTree::operator=(const DataTree &d)
// Constants must be initialized first because they are used in some Add* methods
initConstants();
/* Model local variables must be next, because they can be evaluated in Add*
methods when the model equations are added */
for (const auto &it : d.local_variables_table)
local_variables_table[it.first] = it.second->clone(*this);
for (const auto &it : d.node_list)
it->clone(*this);
......@@ -105,9 +110,6 @@ DataTree::operator=(const DataTree &d)
local_variables_vector = d.local_variables_vector;
for (const auto &it : d.local_variables_table)
local_variables_table[it.first] = it.second->clone(*this);
return *this;
}
......
......@@ -887,6 +887,9 @@ VariableNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, t
bool
VariableNode::containsExternalFunction() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->containsExternalFunction();
return false;
}
......@@ -1245,6 +1248,9 @@ VariableNode::substituteStaticAuxiliaryVariable() const
double
VariableNode::eval(const eval_context_t &eval_context) const noexcept(false)
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->eval(eval_context);
auto it = eval_context.find(symb_id);
if (it == eval_context.end())
throw EvalException();
......@@ -1461,10 +1467,12 @@ VariableNode::computeXrefs(EquationInfo &ei) const
case SymbolType::parameter:
ei.param.emplace(symb_id, 0);
break;
case SymbolType::modFileLocalVariable:
datatree.getLocalVariable(symb_id)->computeXrefs(ei);
break;
case SymbolType::trend:
case SymbolType::logTrend:
case SymbolType::modelLocalVariable:
case SymbolType::modFileLocalVariable:
case SymbolType::statementDeclaredVariable:
case SymbolType::unusedEndogenous:
case SymbolType::externalFunction:
......@@ -1629,6 +1637,9 @@ VariableNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
int
VariableNode::PacMaxLag(int lhs_symb_id) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->PacMaxLag(lhs_symb_id);
if (lhs_symb_id == symb_id)
return -lag;
return 0;
......@@ -1643,28 +1654,41 @@ VariableNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
expr_t
VariableNode::substituteAdl() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteAdl();
return const_cast<VariableNode *>(this);
}
expr_t
VariableNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteVarExpectation(subst_table);
return const_cast<VariableNode *>(this);
}
void
VariableNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->findDiffNodes(nodes);
}
void
VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->findUnaryOpNodesForAuxVarCreation(nodes);
}
int
VariableNode::findTargetVariable(int lhs_symb_id) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->findTargetVariable(lhs_symb_id);
return -1;
}
......@@ -1672,18 +1696,27 @@ expr_t
VariableNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
vector<BinaryOpNode *> &neweqs) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteDiff(nodes, subst_table, neweqs);
return const_cast<VariableNode *>(this);
}
expr_t
VariableNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteUnaryOpNodes(nodes, subst_table, neweqs);
return const_cast<VariableNode *>(this);
}
expr_t
VariableNode::substitutePacExpectation(const string &name, expr_t subexpr)
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substitutePacExpectation(name, subexpr);
return const_cast<VariableNode *>(this);
}
......@@ -1708,6 +1741,9 @@ VariableNode::decreaseLeadsLags(int n) const
expr_t
VariableNode::decreaseLeadsLagsPredeterminedVariables() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->decreaseLeadsLagsPredeterminedVariables();
if (datatree.symbol_table.isPredetermined(symb_id))
return decreaseLeadsLags(1);
else
......@@ -1849,6 +1885,9 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *
expr_t
VariableNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->substituteExpectation(subst_table, neweqs, partial_information_model);
return const_cast<VariableNode *>(this);
}
......@@ -1907,12 +1946,18 @@ VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int la
bool
VariableNode::containsPacExpectation(const string &pac_model_name) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->containsPacExpectation(pac_model_name);
return false;
}
bool
VariableNode::containsEndogenous() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->containsEndogenous();
if (get_type() == SymbolType::endogenous)
return true;
else
......@@ -1922,12 +1967,18 @@ VariableNode::containsEndogenous() const
bool
VariableNode::containsExogenous() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->containsExogenous();
return get_type() == SymbolType::exogenous || get_type() == SymbolType::exogenousDet;
}
expr_t
VariableNode::replaceTrendVar() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->replaceTrendVar();
if (get_type() == SymbolType::trend)
return datatree.One;
else if (get_type() == SymbolType::logTrend)
......@@ -1939,6 +1990,9 @@ VariableNode::replaceTrendVar() const
expr_t
VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->detrend(symb_id, log_trend, trend);
if (this->symb_id != symb_id)
return const_cast<VariableNode *>(this);
......@@ -1961,12 +2015,18 @@ VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const
int
VariableNode::countDiffs() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->countDiffs();
return 0;
}
expr_t
VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->removeTrendLeadLag(trend_symbols_map);
if ((get_type() != SymbolType::trend && get_type() != SymbolType::logTrend) || lag == 0)
return const_cast<VariableNode *>(this);
......@@ -2012,24 +2072,36 @@ VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) cons
bool
VariableNode::isInStaticForm() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->isInStaticForm();
return lag == 0;
}
bool
VariableNode::isParamTimesEndogExpr() const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->isParamTimesEndogExpr();
return false;
}
bool
VariableNode::isVarModelReferenced(const string &model_info_name) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->isVarModelReferenced(model_info_name);
return false;
}
void
VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->getEndosAndMaxLags(model_endos_and_lags);
if (get_type() == SymbolType::endogenous)
if (string varname = datatree.symbol_table.getName(symb_id);
model_endos_and_lags.find(varname) == model_endos_and_lags.end())
......@@ -2047,6 +2119,9 @@ VariableNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table)
expr_t
VariableNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
if (get_type() == SymbolType::modelLocalVariable)
return datatree.getLocalVariable(symb_id)->replaceVarsInEquation(table);
for (auto &it : table)
if (it.first->symb_id == symb_id)
return it.second;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment