Commit 21528300 authored by Sébastien Villemot's avatar Sébastien Villemot
Browse files

No longer store symbol type in VariableNode

This facilitates switching variable types on the fly. In particular, this
allows removing the hack in DynamicModel::updateAfterVariableChange() that way
basically recreating all the nodes after the type change.
parent c47b6e6e
......@@ -4827,21 +4827,6 @@ DynamicModel::writeAuxVarRecursiveDefinitions(ostream &output, ExprNodeOutputTyp
}
}
void
DynamicModel::updateAfterVariableChange(DynamicModel &dm)
{
variable_node_map.clear();
unary_op_node_map.clear();
binary_op_node_map.clear();
trinary_op_node_map.clear();
external_function_node_map.clear();
first_deriv_external_function_node_map.clear();
second_deriv_external_function_node_map.clear();
cloneDynamic(dm);
dm.replaceMyEquations(*this);
}
void
DynamicModel::cloneDynamic(DynamicModel &dynamic_model) const
{
......
......@@ -378,9 +378,6 @@ public:
/*! It assumes that the dynamic model given in argument has just been allocated */
void cloneDynamic(DynamicModel &dynamic_model) const;
//! update equations after variable type change in model block
void updateAfterVariableChange(DynamicModel &dynamic_model);
//! Replaces model equations with derivatives of Lagrangian w.r.t. endogenous
void computeRamseyPolicyFOCs(const StaticModel &static_model, const bool nopreprocessoroutput);
//! Replaces the model equations in dynamic_model with those in this model
......
......@@ -714,12 +714,11 @@ NumConstNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c
VariableNode::VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg) :
ExprNode{datatree_arg, idx_arg},
symb_id{symb_id_arg},
type{datatree.symbol_table.getType(symb_id_arg)},
lag{lag_arg}
{
// It makes sense to allow a lead/lag on parameters: during steady state calibration, endogenous and parameters can be swapped
assert(type != SymbolType::externalFunction
&& (lag == 0 || (type != SymbolType::modelLocalVariable && type != SymbolType::modFileLocalVariable)));
assert(get_type() != SymbolType::externalFunction
&& (lag == 0 || (get_type() != SymbolType::modelLocalVariable && get_type() != SymbolType::modFileLocalVariable)));
}
void
......@@ -731,7 +730,7 @@ VariableNode::prepareForDerivation()
preparedForDerivation = true;
// Fill in non_null_derivatives
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
case SymbolType::exogenous:
......@@ -765,7 +764,7 @@ VariableNode::prepareForDerivation()
expr_t
VariableNode::computeDerivative(int deriv_id)
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
case SymbolType::exogenous:
......@@ -806,7 +805,7 @@ VariableNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, t
auto it = temporary_terms.find(const_cast<VariableNode *>(this));
if (it != temporary_terms.end())
temporary_terms_inuse.insert(idx);
if (type == SymbolType::modelLocalVariable)
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
}
......@@ -821,7 +820,7 @@ VariableNode::writeJsonAST(ostream &output) const
{
output << "{\"node_type\" : \"VariableNode\", "
<< "\"name\" : \"" << datatree.symbol_table.getName(symb_id) << "\", \"type\" : \"";
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
output << "endogenous";
......@@ -896,6 +895,7 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
const temporary_terms_idxs_t &temporary_terms_idxs,
const deriv_node_temp_terms_t &tef_terms) const
{
auto type = get_type();
if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
return;
......@@ -1150,7 +1150,7 @@ VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
expr_t
VariableNode::substituteStaticAuxiliaryVariable() const
{
if (type == SymbolType::endogenous)
if (get_type() == SymbolType::endogenous)
{
try
{
......@@ -1179,6 +1179,7 @@ VariableNode::compile(ostream &CompileCode, unsigned int &instruction_number,
const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
const deriv_node_temp_terms_t &tef_terms) const
{
auto type = get_type();
if (type == SymbolType::modelLocalVariable || type == SymbolType::modFileLocalVariable)
datatree.getLocalVariable(symb_id)->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
else
......@@ -1255,14 +1256,14 @@ VariableNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
vector<vector<temporary_terms_t>> &v_temporary_terms,
int equation) const
{
if (type == SymbolType::modelLocalVariable)
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
}
void
VariableNode::collectVARLHSVariable(set<expr_t> &result) const
{
if (type == SymbolType::endogenous && lag == 0)
if (get_type() == SymbolType::endogenous && lag == 0)
result.insert(const_cast<VariableNode *>(this));
else
{
......@@ -1274,9 +1275,9 @@ VariableNode::collectVARLHSVariable(set<expr_t> &result) const
void
VariableNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
{
if (type == type_arg)
if (get_type() == type_arg)
result.emplace(symb_id, lag);
if (type == SymbolType::modelLocalVariable)
if (get_type() == SymbolType::modelLocalVariable)
datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result);
}
......@@ -1295,7 +1296,7 @@ VariableNode::normalizeEquation(int var_endo, vector<pair<int, pair<expr_t, expr
the flag is equal to 2.
- an expression equal to the RHS if flag = 0 and equal to NULL elsewhere
*/
if (type == SymbolType::endogenous)
if (get_type() == SymbolType::endogenous)
{
if (datatree.symbol_table.getTypeSpecificID(symb_id) == var_endo && lag == 0)
/* the endogenous variable */
......@@ -1305,7 +1306,7 @@ VariableNode::normalizeEquation(int var_endo, vector<pair<int, pair<expr_t, expr
}
else
{
if (type == SymbolType::parameter)
if (get_type() == SymbolType::parameter)
return { 0, datatree.AddVariable(symb_id, 0) };
else
return { 0, datatree.AddVariable(symb_id, lag) };
......@@ -1315,7 +1316,7 @@ VariableNode::normalizeEquation(int var_endo, vector<pair<int, pair<expr_t, expr
expr_t
VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
case SymbolType::exogenous:
......@@ -1380,7 +1381,7 @@ VariableNode::toStatic(DataTree &static_datatree) const
void
VariableNode::computeXrefs(EquationInfo &ei) const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
ei.endo.emplace(symb_id, lag);
......@@ -1409,6 +1410,12 @@ VariableNode::computeXrefs(EquationInfo &ei) const
}
}
SymbolType
VariableNode::get_type() const
{
return datatree.symbol_table.getType(symb_id);
}
expr_t
VariableNode::cloneDynamic(DataTree &dynamic_datatree) const
{
......@@ -1418,7 +1425,7 @@ VariableNode::cloneDynamic(DataTree &dynamic_datatree) const
int
VariableNode::maxEndoLead() const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
return max(lag, 0);
......@@ -1432,7 +1439,7 @@ VariableNode::maxEndoLead() const
int
VariableNode::maxExoLead() const
{
switch (type)
switch (get_type())
{
case SymbolType::exogenous:
return max(lag, 0);
......@@ -1446,7 +1453,7 @@ VariableNode::maxExoLead() const
int
VariableNode::maxEndoLag() const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
return max(-lag, 0);
......@@ -1460,7 +1467,7 @@ VariableNode::maxEndoLag() const
int
VariableNode::maxExoLag() const
{
switch (type)
switch (get_type())
{
case SymbolType::exogenous:
return max(-lag, 0);
......@@ -1474,7 +1481,7 @@ VariableNode::maxExoLag() const
int
VariableNode::maxLead() const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
return lag;
......@@ -1490,7 +1497,7 @@ VariableNode::maxLead() const
int
VariableNode::VarMinLag() const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
return -lag;
......@@ -1509,7 +1516,7 @@ VariableNode::VarMinLag() const
int
VariableNode::maxLag() const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
return -lag;
......@@ -1595,7 +1602,7 @@ VariableNode::substitutePacExpectation(map<const PacExpectationNode *, const Bin
expr_t
VariableNode::decreaseLeadsLags(int n) const
{
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
case SymbolType::exogenous:
......@@ -1623,7 +1630,7 @@ expr_t
VariableNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
{
expr_t value;
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
if (lag <= 1)
......@@ -1648,7 +1655,7 @@ VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector
expr_t value;
subst_table_t::const_iterator it;
int cur_lag;
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
if (lag >= -1)
......@@ -1696,7 +1703,7 @@ expr_t
VariableNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
{
expr_t value;
switch (type)
switch (get_type())
{
case SymbolType::exogenous:
if (lag <= 0)
......@@ -1721,7 +1728,7 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *
expr_t value;
subst_table_t::const_iterator it;
int cur_lag;
switch (type)
switch (get_type())
{
case SymbolType::exogenous:
if (lag >= 0)
......@@ -1775,7 +1782,7 @@ expr_t
VariableNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
expr_t value;
switch (type)
switch (get_type())
{
case SymbolType::endogenous:
assert(lag <= 1);
......@@ -1820,7 +1827,7 @@ VariableNode::isNumConstNodeEqualTo(double value) const
bool
VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
{
if (type == type_arg && datatree.symbol_table.getTypeSpecificID(symb_id) == variable_id && lag == lag_arg)
if (get_type() == type_arg && datatree.symbol_table.getTypeSpecificID(symb_id) == variable_id && lag == lag_arg)
return true;
else
return false;
......@@ -1835,7 +1842,7 @@ VariableNode::containsPacExpectation(const string &pac_model_name) const
bool
VariableNode::containsEndogenous() const
{
if (type == SymbolType::endogenous)
if (get_type() == SymbolType::endogenous)
return true;
else
return false;
......@@ -1844,7 +1851,7 @@ VariableNode::containsEndogenous() const
bool
VariableNode::containsExogenous() const
{
return (type == SymbolType::exogenous || type == SymbolType::exogenousDet);
return (get_type() == SymbolType::exogenous || get_type() == SymbolType::exogenousDet);
}
expr_t
......@@ -1947,8 +1954,8 @@ void
VariableNode::getPacNonOptimizingPart(set<pair<int, pair<pair<int, int>, double>>>
&params_vars_and_scaling_factor) const
{
if (type != SymbolType::endogenous
&& type != SymbolType::exogenous)
if (get_type() != SymbolType::endogenous
&& get_type() != SymbolType::exogenous)
{
cerr << "ERROR VariableNode::getPacNonOptimizingPart: Error in parsing PAC equation"
<< endl;
......@@ -1993,7 +2000,7 @@ void
VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
{
string varname = datatree.symbol_table.getName(symb_id);
if (type == SymbolType::endogenous)
if (get_type() == SymbolType::endogenous)
if (model_endos_and_lags.find(varname) == model_endos_and_lags.end())
model_endos_and_lags[varname] = min(model_endos_and_lags[varname], lag);
else
......
......@@ -693,7 +693,6 @@ class VariableNode : public ExprNode
private:
//! Id from the symbol table
const int symb_id;
const SymbolType type;
//! A positive value is a lead, a negative is a lag
const int lag;
expr_t computeDerivative(int deriv_id) override;
......@@ -717,11 +716,7 @@ public:
void compile(ostream &CompileCode, unsigned int &instruction_number, bool lhs_rhs, const temporary_terms_t &temporary_terms, const map_idx_t &map_idx, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const override;
expr_t toStatic(DataTree &static_datatree) const override;
void computeXrefs(EquationInfo &ei) const override;
SymbolType
get_type() const
{
return type;
};
SymbolType get_type() const;
int
get_symb_id() const
{
......
......@@ -369,14 +369,6 @@ ParsingDriver::declare_or_change_type(SymbolType new_type, const string &name)
symb_id = mod_file->symbol_table.getID(name);
mod_file->symbol_table.changeType(symb_id, new_type);
// change in equations in ModelTree
auto dm = make_unique<DynamicModel>(mod_file->symbol_table,
mod_file->num_constants,
mod_file->external_functions_table,
mod_file->trend_component_model_table,
mod_file->var_model_table);
mod_file->dynamic_model.updateAfterVariableChange(*dm);
// remove error messages
undeclared_model_vars.erase(name);
for (auto it = undeclared_model_variable_errors.begin();
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment