change map type for readability

parent 1e071ca4
......@@ -714,13 +714,13 @@ NumConstNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c
}
void
NumConstNode::findConstantEquations(map<expr_t, expr_t> &table) const
NumConstNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
{
return;
}
expr_t
NumConstNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
NumConstNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
return const_cast<NumConstNode *>(this);
}
......@@ -2024,17 +2024,17 @@ VariableNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, c
}
void
VariableNode::findConstantEquations(map<expr_t, expr_t> &table) const
VariableNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
{
return;
}
expr_t
VariableNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
VariableNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
for (auto & it : table)
if (dynamic_cast<VariableNode *>(it.first)->symb_id == symb_id)
return dynamic_cast<NumConstNode *>(it.second);
if (it.first->symb_id == symb_id)
return it.second;
return const_cast<VariableNode *>(this);
}
......@@ -3857,13 +3857,13 @@ UnaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, co
}
void
UnaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
UnaryOpNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
{
arg->findConstantEquations(table);
}
expr_t
UnaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
UnaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
expr_t argsubst = arg->replaceVarsInEquation(table);
return buildSimilarUnaryOpNode(argsubst, datatree);
......@@ -5830,13 +5830,13 @@ BinaryOpNode::fillErrorCorrectionRowHelper(expr_t arg1, expr_t arg2,
}
void
BinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
BinaryOpNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
{
if (op_code == BinaryOpcode::equal)
if (dynamic_cast<VariableNode *>(arg1) != nullptr && dynamic_cast<NumConstNode *>(arg2) != nullptr)
table[arg1] = arg2;
table[dynamic_cast<VariableNode *>(arg1)] = dynamic_cast<NumConstNode *>(arg2);
else if (dynamic_cast<VariableNode *>(arg2) != nullptr && dynamic_cast<NumConstNode *>(arg1) != nullptr)
table[arg2] = arg1;
table[dynamic_cast<VariableNode *>(arg2)] = dynamic_cast<NumConstNode *>(arg1);
else
{
arg1->findConstantEquations(table);
......@@ -5845,7 +5845,7 @@ BinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
}
expr_t
BinaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
BinaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
if (op_code == BinaryOpcode::equal)
for (auto & it : table)
......@@ -6858,7 +6858,7 @@ TrinaryOpNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs,
}
void
TrinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
TrinaryOpNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
{
arg1->findConstantEquations(table);
arg2->findConstantEquations(table);
......@@ -6866,7 +6866,7 @@ TrinaryOpNode::findConstantEquations(map<expr_t, expr_t> &table) const
}
expr_t
TrinaryOpNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
TrinaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
expr_t arg1subst = arg1->replaceVarsInEquation(table);
expr_t arg2subst = arg2->replaceVarsInEquation(table);
......@@ -7511,14 +7511,14 @@ AbstractExternalFunctionNode::fillErrorCorrectionRow(int eqn, const vector<int>
}
void
AbstractExternalFunctionNode::findConstantEquations(map<expr_t, expr_t> &table) const
AbstractExternalFunctionNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
{
for (auto argument : arguments)
argument->findConstantEquations(table);
}
expr_t
AbstractExternalFunctionNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
AbstractExternalFunctionNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
vector<expr_t> arguments_subst;
for (auto argument : arguments)
......@@ -9040,13 +9040,13 @@ VarExpectationNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_
}
void
VarExpectationNode::findConstantEquations(map<expr_t, expr_t> &table) const
VarExpectationNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
{
return;
}
expr_t
VarExpectationNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
VarExpectationNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
return const_cast<VarExpectationNode *>(this);
}
......@@ -9560,13 +9560,13 @@ PacExpectationNode::fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_
}
void
PacExpectationNode::findConstantEquations(map<expr_t, expr_t> &table) const
PacExpectationNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
{
return;
}
expr_t
PacExpectationNode::replaceVarsInEquation(map<expr_t, expr_t> &table) const
PacExpectationNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
return const_cast<PacExpectationNode *>(this);
}
......
......@@ -34,6 +34,7 @@ using namespace std;
#include "SymbolList.hh"
class DataTree;
class NumConstNode;
class VariableNode;
class UnaryOpNode;
class BinaryOpNode;
......@@ -613,10 +614,10 @@ class ExprNode
map<tuple<int, int, int>, expr_t> &EC) const = 0;
//! Finds equations where a variable is equal to a constant
virtual void findConstantEquations(map<expr_t, expr_t> &table) const = 0;
virtual void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const = 0;
//! Replaces variables found in findConstantEquations() with their constant values
virtual expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const = 0;
virtual expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const = 0;
//! Returns true if PacExpectationNode encountered
virtual bool containsPacExpectation(const string &pac_model_name = "") const = 0;
......@@ -732,8 +733,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -823,8 +824,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -942,8 +943,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -1079,8 +1080,8 @@ public:
int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs,
map<tuple<int, int, int>, expr_t> &AR) const;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -1195,8 +1196,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -1323,8 +1324,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -1539,8 +1540,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......@@ -1641,8 +1642,8 @@ public:
void fillPacExpectationVarInfo(string &model_name_arg, vector<int> &lhs_arg, int max_lag_arg, int pac_max_lag_arg, vector<bool> &nonstationary_arg, int growth_symb_id_arg, int growth_lag, int equation_number_arg) override;
void fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const override;
void fillErrorCorrectionRow(int eqn, const vector<int> &nontrend_lhs, const vector<int> &trend_lhs, map<tuple<int, int, int>, expr_t> &EC) const override;
void findConstantEquations(map<expr_t, expr_t> &table) const override;
expr_t replaceVarsInEquation(map<expr_t, expr_t> &table) const override;
void findConstantEquations(map<VariableNode *, NumConstNode *> &table) const override;
expr_t replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const override;
bool containsPacExpectation(const string &pac_model_name = "") const override;
void getPacOptimizingPart(int lhs_orig_symb_id, pair<int, pair<vector<int>, vector<bool>>> &ec_params_and_vars,
set<pair<int, pair<int, int>>> &params_and_vars) const override;
......
......@@ -1925,7 +1925,7 @@ void
ModelTree::simplifyEquations()
{
size_t last_subst_table_size = 0;
map<expr_t, expr_t> subst_table;
map<VariableNode *, NumConstNode *> subst_table;
findConstantEquations(subst_table);
while (subst_table.size() != last_subst_table_size)
{
......@@ -1938,7 +1938,7 @@ ModelTree::simplifyEquations()
}
void
ModelTree::findConstantEquations(map<expr_t, expr_t> &subst_table) const
ModelTree::findConstantEquations(map<VariableNode *, NumConstNode *> &subst_table) const
{
for (auto & equation : equations)
equation->findConstantEquations(subst_table);
......
......@@ -352,7 +352,7 @@ public:
//! Simplify model equations: if a variable is equal to a constant, replace that variable elsewhere in the model
void simplifyEquations();
//! Find equations where variable is equal to a constant
void findConstantEquations(map<expr_t, expr_t> &subst_table) const;
void findConstantEquations(map<VariableNode *, NumConstNode *> &subst_table) const;
void jacobianHelper(ostream &output, int eq_nb, int col_nb, ExprNodeOutputType output_type) const;
//! Helper for writing the sparse Hessian or third derivatives in MATLAB and C
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment