Commit f5df7e75 authored by Houtan Bastani's avatar Houtan Bastani

when an equation is of the form `X` = `constant`, replace all occurrences of...

when an equation is of the form `X` = `constant`, replace all occurrences of `X` in other equations with `constant`
parent c5fc2e38
......@@ -719,6 +719,18 @@ NumConstNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c
{
}
void
NumConstNode::findConstantEquations(map<expr_t, expr_t> &table) const
{
return;
}
expr_t
NumConstNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
{
return const_cast<NumConstNode *>(this);
}
VariableNode::VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg) :
ExprNode{datatree_arg, idx_arg},
symb_id{symb_id_arg},
......@@ -2023,6 +2035,21 @@ VariableNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c
{
}
void
VariableNode::findConstantEquations(map<expr_t, expr_t> &table) const
{
return;
}
expr_t
VariableNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
{
for (auto & it : table)
if (dynamic_cast<VariableNode *>(it.first)->symb_id == symb_id)
return dynamic_cast<NumConstNode *>(it.second);
return const_cast<VariableNode *>(this);
}
UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, int idx_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, string adl_param_name_arg, vector<int> adl_lags_arg) :
ExprNode{datatree_arg, idx_arg},
arg{arg_arg},
......@@ -3847,6 +3874,19 @@ UnaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, co
arg->fillErrorCorrectionRow(eqn, nontrend_lhs, trend_lhs, EC);
}
void
UnaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
{
arg->findConstantEquations(table);
}
expr_t
UnaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
{
expr_t argsubst = arg->replaceVarsInEquation(table);
return buildSimilarUnaryOpNode(argsubst, datatree);
}
BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, int idx_arg, const expr_t arg1_arg,
BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder_arg) :
ExprNode{datatree_arg, idx_arg},
......@@ -5813,6 +5853,36 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
EC[make_tuple(eqn, -max_lag, colidx)] = arg1;
}
void
BinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
{
if (op_code == BinaryOpcode::equal)
if (dynamic_cast<VariableNode *>(arg1) != nullptr && dynamic_cast<NumConstNode *>(arg2) != nullptr)
  • Note that you can simplify this as:

    if (!dynamic_cast<VariableNode *>(arg1) && !dynamic_cast<NumConstNode *>(arg2))
  • Sorry, as:

        if (dynamic_cast<VariableNode *>(arg1) && dynamic_cast<NumConstNode *>(arg2))

    (without the negations)

  • @sebastien I have a preference for the explicit comparison to nullptr because it makes the code more readable for me. If you have a preference otherwise, then I'll make the change.

  • I personally find the compact version more readable. But ok for keeping it as it is, anyways we have both styles currently in use in the preprocessor.

  • Note that for smart pointers (std::unique_ptr and std::shared_ptr) we don't really have the choice (the only option is the compact syntax). So in my opinion it's also an argument in favor of the compact syntax.

  • Well, that's not entirely true. You can do:

    std::unique_ptr p;
    if (p.get() == nullptr)
      ...

    but that's quite heavy (you have to call get()), and clearly not the canonical way.

Please register or sign in to reply
table[arg1] = arg2;
else if (dynamic_cast<VariableNode *>(arg2) != nullptr && dynamic_cast<NumConstNode *>(arg1) != nullptr)
table[arg2] = arg1;
else
{
arg1->findConstantEquations(table);
arg2->findConstantEquations(table);
}
}
expr_t
BinaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
{
if (op_code == BinaryOpcode::equal)
for (auto & it : table)
if ((dynamic_cast<VariableNode *>(it.first) == arg1
&& dynamic_cast<NumConstNode *>(it.second) == arg2)
|| (dynamic_cast<VariableNode *>(it.first) == arg2
&& dynamic_cast<NumConstNode *>(it.second) == arg1))
Please register or sign in to reply
return const_cast<BinaryOpNode *>(this);
expr_t arg1subst = arg1->replaceVarsInEquation(table);
expr_t arg2subst = arg2->replaceVarsInEquation(table);
return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
}
void
BinaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const
{
......@@ -6817,6 +6887,23 @@ TrinaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs,
arg3->fillErrorCorrectionRow(eqn, nontrend_lhs, trend_lhs, EC);
}
void
TrinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
{
arg1->findConstantEquations(table);
arg2->findConstantEquations(table);
arg3->findConstantEquations(table);
}
expr_t
TrinaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
{
expr_t arg1subst = arg1->replaceVarsInEquation(table);
expr_t arg2subst = arg2->replaceVarsInEquation(table);
expr_t arg3subst = arg3->replaceVarsInEquation(table);
return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
}
AbstractExternalFunctionNode::AbstractExternalFunctionNode(DataTree &datatree_arg,
int idx_arg,
int symb_id_arg,
......@@ -7462,6 +7549,22 @@ AbstractExternalFunctionNode::fillErrorCorrectionRow(int eqn, const vector<int>
exit(EXIT_FAILURE);
}
void
AbstractExternalFunctionNode::findConstantEquations(map<expr_t, expr_t> &table) const
{
for (auto argument : arguments)
argument->findConstantEquations(table);
}
expr_t
AbstractExternalFunctionNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
{
vector<expr_t> arguments_subst;
for (auto argument : arguments)
arguments_subst.push_back(argument->replaceVarsInEquation(table));
return buildSimilarExternalFunctionNode(arguments_subst, datatree);
}
ExternalFunctionNode::ExternalFunctionNode(DataTree &datatree_arg,
int idx_arg,
int symb_id_arg,
......@@ -8982,6 +9085,18 @@ VarExpectationNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_
exit(EXIT_FAILURE);
}
void
VarExpectationNode::findConstantEquations(map<expr_t, expr_t> &table) const
{
return;
}
expr_t
VarExpectationNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
{
return const_cast<VarExpectationNode *>(this);
}
void
VarExpectationNode::writeJsonAST(ostream &output) const
{
......@@ -9496,6 +9611,18 @@ PacExpectationNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_
exit(EXIT_FAILURE);
}
void
PacExpectationNode::findConstantEquations(map<expr_t, expr_t> &table) const
{
return;
}
expr_t
PacExpectationNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
{
return const_cast<PacExpectationNode *>(this);
}
void
PacExpectationNode::writeJsonAST(ostream &output) const
{
......
......@@ -615,6 +615,12 @@ class ExprNode
virtual void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs,
map<tuple<int, int, int>, expr_t> &EC) const = 0;
//! Finds equations where a variable is equal to a constant
virtual void findConstantEquations(map<expr_t, expr_t> &table) const = 0;
//! Replaces variables found in findConstantEquations() with their constant values
virtual expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const = 0;
  • Why not using map<VariableNode*, NumConstNode*> as the type for table in the above two methods? That would simplify the code, and show more clearly what all this is about.

  • @sebastien I meant to make this change part way through but forgot by the time I pushed it and have been trying to figure something else out since. I'll make this change now however since I'm blocked on the other thing....

Please register or sign in to reply
//! Returns true if PacExpectationNode encountered
virtual bool containsPacExpectation(const string &pac_model_name = "") const = 0;
......@@ -730,6 +736,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -820,6 +828,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -938,6 +948,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -1074,6 +1086,8 @@ public:
int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs,
map<tuple<int, int, int>, expr_t> &AR) const;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -1189,6 +1203,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -1316,6 +1332,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -1531,6 +1549,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -1632,6 +1652,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......
......@@ -370,6 +370,8 @@ ModFile::transformPass(bool nostrict, bool stochastic, bool compute_xrefs, const
dynamic_model.setLeadsLagsOrig();
original_model = dynamic_model;
dynamic_model.simplifyEquations();
if (nostrict)
{
set<int> unusedEndogs = dynamic_model.findUnusedEndogenous();
......
/*
* Copyright (C) 2003-2018 Dynare Team
* Copyright (C) 2003-2019 Dynare Team
*
* This file is part of Dynare.
*
......@@ -1921,6 +1921,29 @@ ModelTree::addEquation(expr_t eq, int lineno)
equations_lineno.push_back(lineno);
}
void
ModelTree::simplifyEquations()
{
size_t last_subst_table_size = 0;
map<expr_t, expr_t> subst_table;
findConstantEquations(subst_table);
while (subst_table.size() != last_subst_table_size)
{
last_subst_table_size = subst_table.size();
for (auto & equation : equations)
equation = dynamic_cast<BinaryOpNode *>(equation->replaceVarsInEquation(subst_table));
subst_table.clear();
findConstantEquations(subst_table);
}
}
void
ModelTree::findConstantEquations(map<expr_t, expr_t> &subst_table) const
{
for (auto & equation : equations)
equation->findConstantEquations(subst_table);
}
void
ModelTree::addEquation(expr_t eq, int lineno, const vector<pair<string, string>> &eq_tags)
{
......
/*
* Copyright (C) 2003-2018 Dynare Team
* Copyright (C) 2003-2019 Dynare Team
*
* This file is part of Dynare.
*
......@@ -348,6 +348,12 @@ public:
void set_cutoff_to_zero();
//! Helper for writing the Jacobian elements in MATLAB and C
/*! Writes either (i+1,j+1) or [i+j*no_eq] */
//! Simplify model equations: if a variable is equal to a constant, replace that variable elsewhere in the model
void simplifyEquations();
//! Find equations where variable is equal to a constant
void findConstantEquations(map<expr_t, expr_t> &subst_table) const;
void jacobianHelper(ostream &output, int eq_nb, int col_nb, ExprNodeOutputType output_type) const;
//! Helper for writing the sparse Hessian or third derivatives in MATLAB and C
/*! If order=2, writes either v2(i+1,j+1) or v2[i+j*NNZDerivatives[2]]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment