Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • 4.6
  • 5.x
  • 6.x
  • aux_vars_fix
  • julia
  • julia-6.3.0
  • julia-6.4.0
  • julia-meson
  • llvm-15
  • master
  • python-codegen
  • rework_pac
  • uop
  • created_preprocessor_repo
  • julia-6.2.0
15 results

Target

Select target project
  • normann/preprocessor
  • Dynare/preprocessor
  • FerhatMihoubi/preprocessor
  • MichelJuillard/preprocessor
  • sebastien/preprocessor
  • lnsongxf/preprocessor
  • albop/preprocessor
  • DoraK/preprocessor
  • amg/preprocessor
  • wmutschl/preprocessor
  • JohannesPfeifer/preprocessor
11 results
Select Git revision
Show changes
Showing with 2728 additions and 1694 deletions
/*
* Copyright © 2018-2019 Dynare Team
* Copyright © 2018-2023 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,20 +14,26 @@
* GNU General Public License for more details.SS
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef _SUBMODEL_HH
#define _SUBMODEL_HH
#ifndef SUB_MODEL_HH
#define SUB_MODEL_HH
#include <set>
#include <iostream>
#include <map>
#include <optional>
#include <set>
#include <vector>
#include <iostream>
#include "ExprNode.hh"
#include "Statement.hh"
#include "SymbolList.hh"
#include "SymbolTable.hh"
// DynamicModel.hh can’t be included here, otherwise it would be a circular dependency
class DynamicModel;
using namespace std;
//! A table with all Trend Component Models in the .mod file
......@@ -40,13 +46,17 @@ private:
SymbolTable& symbol_table;
set<string> names;
map<string, vector<string>> eqtags, target_eqtags;
map<string, vector<int>> eqnums, target_eqnums, nontarget_eqnums, max_lags, lhs, target_lhs, nontarget_lhs, orig_diff_var;
map<string, vector<int>> eqnums, target_eqnums, nontarget_eqnums, max_lags, lhs, target_lhs,
nontarget_lhs;
map<string, vector<optional<int>>> orig_diff_var;
map<string, vector<set<pair<int, int>>>> rhs;
map<string, vector<bool>> diff;
map<string, vector<expr_t>> lhs_expr_t;
map<string, vector<int>> target_vars;
map<string, map<tuple<int, int, int>, expr_t>> AR; // AR: name -> (eqn, lag, lhs_symb_id) -> expr_t
map<string, map<tuple<int, int, int>, expr_t>> A0, A0star; // EC: name -> (eqn, lag, col) -> expr_t
map<string, vector<optional<int>>> target_vars;
map<string, map<tuple<int, int, int>, expr_t>> AR; // name -> (eqn, lag, lhs_symb_id) -> expr_t
/* Note that A0 in the trend-component model context is not the same thing as
in the structural VAR context. */
map<string, map<tuple<int, int>, expr_t>> A0, A0star; // name -> (eqn, col) -> expr_t
public:
explicit TrendComponentModelTable(SymbolTable& symbol_table_arg);
......@@ -54,38 +64,36 @@ public:
void addTrendComponentModel(string name_arg, vector<string> eqtags_arg,
vector<string> target_eqtags_arg);
inline bool isExistingTrendComponentModelName(const string &name_arg) const;
inline bool empty() const;
map<string, vector<string>> getEqTags() const;
vector<string> getEqTags(const string &name_arg) const;
map<string, vector<string>> getTargetEqTags() const;
map<string, vector<int>> getEqNums() const;
map<string, vector<int>> getTargetEqNums() const;
vector<int> getTargetEqNums(const string &name_arg) const;
vector<int> getEqNums(const string &name_arg) const;
vector<int> getMaxLags(const string &name_arg) const;
int getMaxLag(const string &name_arg) const;
vector<int> getLhs(const string &name_arg) const;
vector<expr_t> getLhsExprT(const string &name_arg) const;
vector<bool> getDiff(const string &name_arg) const;
vector<int> getOrigDiffVar(const string &name_arg) const;
map<string, vector<int>> getNonTargetEqNums() const;
vector<int> getNonTargetEqNums(const string &name_arg) const;
vector<int> getNonTargetLhs(const string &name_arg) const;
vector<int> getTargetLhs(const string &name_arg) const;
[[nodiscard]] inline bool isExistingTrendComponentModelName(const string& name_arg) const;
[[nodiscard]] inline bool empty() const;
[[nodiscard]] const map<string, vector<string>>& getEqTags() const;
[[nodiscard]] const vector<string>& getEqTags(const string& name_arg) const;
[[nodiscard]] const map<string, vector<string>>& getTargetEqTags() const;
[[nodiscard]] const map<string, vector<int>>& getEqNums() const;
[[nodiscard]] const map<string, vector<int>>& getTargetEqNums() const;
[[nodiscard]] const vector<int>& getTargetEqNums(const string& name_arg) const;
[[nodiscard]] const vector<int>& getEqNums(const string& name_arg) const;
[[nodiscard]] const vector<int>& getMaxLags(const string& name_arg) const;
[[nodiscard]] int getMaxLag(const string& name_arg) const;
[[nodiscard]] const vector<int>& getLhs(const string& name_arg) const;
[[nodiscard]] const vector<expr_t>& getLhsExprT(const string& name_arg) const;
[[nodiscard]] const vector<bool>& getDiff(const string& name_arg) const;
[[nodiscard]] const map<string, vector<int>>& getNonTargetEqNums() const;
[[nodiscard]] const vector<int>& getNonTargetEqNums(const string& name_arg) const;
[[nodiscard]] const vector<int>& getNonTargetLhs(const string& name_arg) const;
[[nodiscard]] const vector<int>& getTargetLhs(const string& name_arg) const;
void setVals(map<string, vector<int>> eqnums_arg, map<string, vector<int>> target_eqnums_arg,
map<string, vector<int>> lhs_arg,
map<string, vector<expr_t>> lhs_expr_t_arg);
map<string, vector<int>> lhs_arg, map<string, vector<expr_t>> lhs_expr_t_arg);
void setRhs(map<string, vector<set<pair<int, int>>>> rhs_arg);
void setMaxLags(map<string, vector<int>> max_lags_arg);
void setDiff(map<string, vector<bool>> diff_arg);
void setOrigDiffVar(map<string, vector<int>> orig_diff_var_arg);
void setTargetVar(map<string, vector<int>> target_vars_arg);
void setOrigDiffVar(map<string, vector<optional<int>>> orig_diff_var_arg);
void setTargetVar(map<string, vector<optional<int>>> target_vars_arg);
void setAR(map<string, map<tuple<int, int, int>, expr_t>> AR_arg);
void setA0(map<string, map<tuple<int, int, int>, expr_t>> A0_arg,
map<string, map<tuple<int, int, int>, expr_t>> A0star_arg);
void setA0(map<string, map<tuple<int, int>, expr_t>> A0_arg,
map<string, map<tuple<int, int>, expr_t>> A0star_arg);
//! Write output of this class
void writeOutput(const string& basename, ostream& output) const;
......@@ -101,7 +109,7 @@ private:
inline bool
TrendComponentModelTable::isExistingTrendComponentModelName(const string& name_arg) const
{
return names.find(name_arg) != names.end();
return names.contains(name_arg);
}
inline bool
......@@ -110,41 +118,48 @@ TrendComponentModelTable::empty() const
return names.empty();
}
class VarModelTable
{
private:
SymbolTable& symbol_table;
set<string> names;
map<string, pair<SymbolList, int>> symbol_list_and_order;
map<string, bool> structural; // Whether VARs are structural or reduced-form
map<string, vector<string>> eqtags;
map<string, vector<int>> eqnums, max_lags, lhs, lhs_orig_symb_ids, orig_diff_var;
map<string, vector<set<pair<int, int>>>> rhs;
map<string, vector<int>> eqnums, max_lags, lhs, lhs_orig_symb_ids;
map<string, vector<optional<int>>> orig_diff_var;
map<string, vector<set<pair<int, int>>>>
rhs; // name -> for each equation: set of pairs (var, lag)
map<string, vector<bool>> diff;
map<string, vector<expr_t>> lhs_expr_t;
map<string, map<tuple<int, int, int>, expr_t>> AR; // AR: name -> (eqn, lag, lhs_symb_id) -> param_expr_t
map<string, map<tuple<int, int, int>, expr_t>>
AR; // name -> (eqn, lag, lhs_symb_id) -> param_expr_t
/* The A0 matrix is mainly for structural VARs. For reduced-form VARs, it
will be equal to the identity matrix. Also note that A0 in the structural
VAR context is not the same thing as in the trend-component model
context. */
map<string, map<tuple<int, int>, expr_t>> A0; // name -> (eqn, lhs_symb_id) -> param_expr_t
map<string, map<int, expr_t>> constants; // name -> eqn -> constant
public:
explicit VarModelTable(SymbolTable& symbol_table_arg);
//! Add a VAR model
void addVarModel(string name, vector<string> eqtags,
pair<SymbolList, int> symbol_list_and_order_arg);
inline bool isExistingVarModelName(const string &name_arg) const;
inline bool empty() const;
map<string, vector<string>> getEqTags() const;
vector<string> getEqTags(const string &name_arg) const;
map<string, vector<int>> getEqNums() const;
vector<bool> getDiff(const string &name_arg) const;
vector<int> getEqNums(const string &name_arg) const;
vector<int> getMaxLags(const string &name_arg) const;
int getMaxLag(const string &name_arg) const;
vector<int> getLhs(const string &name_arg) const;
vector<int> getLhsOrigIds(const string &name_arg) const;
map<string, pair<SymbolList, int>> getSymbolListAndOrder() const;
vector<set<pair<int, int>>> getRhs(const string &name_arg) const;
vector<expr_t> getLhsExprT(const string &name_arg) const;
void addVarModel(string name, bool structural_arg, vector<string> eqtags);
[[nodiscard]] inline bool isExistingVarModelName(const string& name_arg) const;
[[nodiscard]] inline bool empty() const;
[[nodiscard]] const map<string, bool>& getStructural() const;
[[nodiscard]] const map<string, vector<string>>& getEqTags() const;
[[nodiscard]] const vector<string>& getEqTags(const string& name_arg) const;
[[nodiscard]] const map<string, vector<int>>& getEqNums() const;
[[nodiscard]] const vector<bool>& getDiff(const string& name_arg) const;
[[nodiscard]] const vector<int>& getEqNums(const string& name_arg) const;
[[nodiscard]] const vector<int>& getMaxLags(const string& name_arg) const;
[[nodiscard]] int getMaxLag(const string& name_arg) const;
[[nodiscard]] const vector<int>& getLhs(const string& name_arg) const;
[[nodiscard]] const vector<int>& getLhsOrigIds(const string& name_arg) const;
[[nodiscard]] const vector<set<pair<int, int>>>& getRhs(const string& name_arg) const;
[[nodiscard]] const vector<expr_t>& getLhsExprT(const string& name_arg) const;
void setEqNums(map<string, vector<int>> eqnums_arg);
void setLhs(map<string, vector<int>> lhs_arg);
......@@ -152,8 +167,10 @@ public:
void setLhsExprT(map<string, vector<expr_t>> lhs_expr_t_arg);
void setDiff(map<string, vector<bool>> diff_arg);
void setMaxLags(map<string, vector<int>> max_lags_arg);
void setOrigDiffVar(map<string, vector<int>> orig_diff_var_arg);
void setOrigDiffVar(map<string, vector<optional<int>>> orig_diff_var_arg);
void setAR(map<string, map<tuple<int, int, int>, expr_t>> AR_arg);
void setA0(map<string, map<tuple<int, int>, expr_t>> A0_arg);
void setConstants(map<string, map<int, expr_t>> constants_arg);
//! Write output of this class
void writeOutput(const string& basename, ostream& output) const;
......@@ -168,7 +185,7 @@ private:
inline bool
VarModelTable::isExistingVarModelName(const string& name_arg) const
{
return names.find(name_arg) != names.end();
return names.contains(name_arg);
}
inline bool
......@@ -177,4 +194,161 @@ VarModelTable::empty() const
return names.empty();
}
class VarExpectationModelTable
{
private:
SymbolTable& symbol_table;
set<string> names;
map<string, expr_t> expression;
map<string, string> aux_model_name;
map<string, string> horizon;
map<string, expr_t> discount;
map<string, int> time_shift;
// For each model, list of generated auxiliary param ids, in variable-major order
map<string, vector<int>> aux_param_symb_ids;
// Decomposition of the expression
map<string, vector<tuple<int, optional<int>, double>>> vars_params_constants;
public:
explicit VarExpectationModelTable(SymbolTable& symbol_table_arg);
void addVarExpectationModel(string name_arg, expr_t expression_arg, string aux_model_name_arg,
string horizon_arg, expr_t discount_arg, int time_shift_arg);
[[nodiscard]] bool isExistingVarExpectationModelName(const string& name_arg) const;
[[nodiscard]] bool empty() const;
void substituteUnaryOpsInExpression(const lag_equivalence_table_t& nodes,
ExprNode::subst_table_t& subst_table,
vector<BinaryOpNode*>& neweqs);
// Called by DynamicModel::substituteDiff()
void substituteDiffNodesInExpression(const lag_equivalence_table_t& diff_nodes,
ExprNode::subst_table_t& diff_subst_table,
vector<BinaryOpNode*>& neweqs);
void transformPass(ExprNode::subst_table_t& diff_subst_table, DynamicModel& dynamic_model,
const VarModelTable& var_model_table,
const TrendComponentModelTable& trend_component_model_table);
void writeOutput(ostream& output) const;
void writeJsonOutput(ostream& output) const;
};
class PacModelTable
{
private:
SymbolTable& symbol_table;
set<string> names;
map<string, string> aux_model_name;
map<string, string> discount;
/* The growth expressions belong to the main dynamic_model from the ModFile
instance. The growth expression is necessarily nullptr for a model with a
pac_target_info block. */
map<string, expr_t> growth, original_growth;
/* Information about the structure of growth expressions (which must be a
linear combination of variables, possibly with additional constants).
Each tuple represents a term: (endo_id, lag, param_id, constant) */
using growth_info_t = vector<tuple<optional<int>, int, optional<int>, double>>;
map<string, growth_info_t> growth_info;
// The “auxname” option of pac_model (empty if not passed)
map<string, string> auxname;
// The “kind” option of pac_model (“undefined” if not passed)
map<string, PacTargetKind> kind;
/* Stores the name of the PAC equation associated to the model.
pac_model_name → eq_name */
map<string, string> eq_name;
/* Stores symb_ids for auxiliary endogenous created for the expression
substituted to the pac_expectation operator:
- in the backward case, this auxiliary contains exactly the
pac_expectation value
- in the MCE case, this auxiliary represents Z₁ (i.e. without the growth
correction term)
Note that this structure is not used in the presence of the
pac_target_info block.
pac_model_name → symb_id */
map<string, int> aux_var_symb_ids;
/* Stores symb_ids for auxiliary parameters created for the expression
substituted to the pac_expectation operator (excluding the growth
neutrality correction):
- in the backward case, contains the “h” parameters
- in the MCE case, contains the “α” parameters
Note that this structure is not used in the presence of the
pac_target_info block.
pac_model_name → symb_ids */
map<string, vector<int>> aux_param_symb_ids;
/* Stores indices for growth neutrality parameters
pac_model_name → growth_neutrality_param_index.
This map is not used for PAC models with a pac_target_info block. */
map<string, int> growth_neutrality_params;
// Stores LHS vars (only for backward PAC models)
map<string, vector<int>> lhs;
// Stores auxiliary model type (only for backward PAC models)
map<string, string> aux_model_type;
public:
/* Stores info about PAC equations
pac_model_name →
(lhs, optim_share_index, ar_params_and_vars, ec_params_and_vars,
non_optim_vars_params_and_constants, additive_vars_params_and_constants,
optim_additive_vars_params_and_constants)
*/
using equation_info_t
= map<string,
tuple<pair<int, int>, optional<int>, vector<tuple<optional<int>, optional<int>, int>>,
pair<int, vector<tuple<int, bool, int>>>,
vector<tuple<int, int, optional<int>, double>>,
vector<tuple<int, int, optional<int>, double>>,
vector<tuple<int, int, optional<int>, double>>>>;
private:
equation_info_t equation_info;
public:
/* (component variable/expr, growth, auxname, kind, coeff. in the linear
combination, growth_param ID (unused if growth is nullptr), vector of h parameters,
original_growth, growth_info) */
using target_component_t = tuple<expr_t, expr_t, string, PacTargetKind, expr_t, int, vector<int>,
expr_t, growth_info_t>;
private:
// pac_model_name → (target variable/expr, auxname_target_nonstationary, target components)
map<string, tuple<expr_t, string, vector<target_component_t>>> target_info;
[[nodiscard]] int pacEquationMaxLag(const string& name_arg) const;
// Return a text representation of a kind (but fails on “unspecified” kind value)
static string kindToString(PacTargetKind kind);
public:
explicit PacModelTable(SymbolTable& symbol_table_arg);
void addPacModel(string name_arg, string aux_model_name_arg, string discount_arg,
expr_t growth_arg, string auxname_arg, PacTargetKind kind_arg);
[[nodiscard]] bool isExistingPacModelName(const string& name_arg) const;
[[nodiscard]] bool empty() const;
void checkPass(ModFileStructure& mod_file_struct);
// Called by DynamicModel::substituteUnaryOps()
void substituteUnaryOpsInGrowth(const lag_equivalence_table_t& nodes,
ExprNode::subst_table_t& subst_table,
vector<BinaryOpNode*>& neweqs);
void findDiffNodesInGrowth(lag_equivalence_table_t& diff_nodes) const;
// Called by DynamicModel::substituteDiff()
void substituteDiffNodesInGrowth(const lag_equivalence_table_t& diff_nodes,
ExprNode::subst_table_t& diff_subst_table,
vector<BinaryOpNode*>& neweqs);
// Must be called after substituteDiffNodesInGrowth() and substituteUnaryOpsInGrowth()
void transformPass(const lag_equivalence_table_t& unary_ops_nodes,
ExprNode::subst_table_t& unary_ops_subst_table,
const lag_equivalence_table_t& diff_nodes,
ExprNode::subst_table_t& diff_subst_table, DynamicModel& dynamic_model,
const VarModelTable& var_model_table,
const TrendComponentModelTable& trend_component_model_table);
void writeOutput(ostream& output) const;
void writeJsonOutput(ostream& output) const;
void setTargetExpr(const string& name_arg, expr_t target);
void setTargetAuxnameNonstationary(const string& name_arg, string auxname);
/* Only the first four elements of the tuple are expected to be set by the
caller. The other ones will be filled by this class. */
void addTargetComponent(const string& name_arg, target_component_t component);
void writeTargetCoefficientsFile(const string& basename) const;
};
#endif
/*
* Copyright © 2003-2019 Dynare Team
* Copyright © 2003-2024 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,27 +14,120 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#include <regex>
#include "SymbolList.hh"
SymbolList::SymbolList(vector<string> symbols_arg) : symbols {move(symbols_arg)}
{
}
void
SymbolList::addSymbol(const string &symbol)
SymbolList::checkPass(WarningConsolidation& warnings, const vector<SymbolType>& types,
const SymbolTable& symbol_table) const noexcept(false)
{
if (types.empty())
return;
smatch m;
string regex_str = "AUX_EXPECT_|MULT_";
for (auto type : types)
if (type == SymbolType::endogenous)
{
regex_str += "|AUX_ENDO_|LOG_";
break;
}
regex re("^(" + regex_str + ")");
for (const auto& symbol : symbols)
{
if (!symbol_table.exists(symbol))
{
if (regex_search(symbol, m, re))
{
warnings
<< "WARNING: symbol_list variable " << symbol << " has not yet been declared. "
<< "This is being ignored because the variable name corresponds to a possible "
<< "auxiliary variable name." << endl;
return;
}
else
throw SymbolListException {"Variable " + symbol + " was not declared."};
}
if (ranges::none_of(types,
[&](SymbolType type) { return symbol_table.getType(symbol) == type; }))
{
symbols.push_back(symbol);
string valid_types;
for (auto type : types)
switch (type)
{
case SymbolType::endogenous:
valid_types += "endogenous, ";
break;
case SymbolType::exogenous:
valid_types += "exogenous, ";
break;
case SymbolType::epilogue:
valid_types += "epilogue, ";
break;
case SymbolType::parameter:
valid_types += "parameter, ";
break;
case SymbolType::exogenousDet:
valid_types += "exogenousDet, ";
break;
case SymbolType::trend:
valid_types += "trend, ";
break;
case SymbolType::logTrend:
valid_types += "logTrend, ";
break;
case SymbolType::modFileLocalVariable:
valid_types += "modFileLocalVariable, ";
break;
case SymbolType::modelLocalVariable:
valid_types += "modelLocalVariable, ";
break;
case SymbolType::externalFunction:
valid_types += "externalFunction, ";
break;
case SymbolType::statementDeclaredVariable:
valid_types += "statementDeclaredVariable, ";
break;
case SymbolType::unusedEndogenous:
valid_types += "unusedEndogenous, ";
break;
case SymbolType::excludedVariable:
valid_types += "excludedVariable, ";
break;
case SymbolType::heterogeneousEndogenous:
valid_types += "heterogeneousEndogenous, ";
break;
case SymbolType::heterogeneousExogenous:
valid_types += "heterogeneousExogenous, ";
break;
case SymbolType::heterogeneousParameter:
valid_types += "heterogeneousParameter, ";
break;
}
valid_types = valid_types.erase(valid_types.size() - 2, 2);
throw SymbolListException {"Variable " + symbol + " is not one of {" + valid_types + "}"};
}
}
}
void
SymbolList::writeOutput(const string& varname, ostream& output) const
{
output << varname << " = {";
for (auto it = symbols.begin();
it != symbols.end(); ++it)
for (bool printed_something {false}; const auto& name : symbols)
{
if (it != symbols.begin())
if (exchange(printed_something, true))
output << ";";
output << "'" << *it << "'";
output << "'" << name << "'";
}
output << "};" << endl;
}
......@@ -43,28 +136,15 @@ void
SymbolList::writeJsonOutput(ostream& output) const
{
output << R"("symbol_list": [)";
for (auto it = symbols.begin();
it != symbols.end(); ++it)
for (bool printed_something {false}; const auto& name : symbols)
{
if (it != symbols.begin())
if (exchange(printed_something, true))
output << ",";
output << R"(")" << *it << R"(")";
output << R"(")" << name << R"(")";
}
output << "]";
}
void
SymbolList::clear()
{
symbols.clear();
}
int
SymbolList::getSize() const
{
return symbols.size();
}
vector<string>
SymbolList::getSymbols() const
{
......@@ -75,11 +155,12 @@ void
SymbolList::removeDuplicates(const string& dynare_command, WarningConsolidation& warnings)
{
vector<string> unique_symbols;
for (auto & it : symbols)
for (const auto& it : symbols)
if (find(unique_symbols.begin(), unique_symbols.end(), it) == unique_symbols.end())
unique_symbols.push_back(it);
else
warnings << "WARNING: In " << dynare_command << ": " << it
<< " found more than once in symbol list. Removing all but first occurence." << endl;
<< " found more than once in symbol list. Removing all but first occurrence."
<< endl;
symbols = unique_symbols;
}
/*
* Copyright © 2003-2019 Dynare Team
* Copyright © 2003-2023 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,16 +14,18 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef _SYMBOL_LIST_HH
#define _SYMBOL_LIST_HH
#ifndef SYMBOL_LIST_HH
#define SYMBOL_LIST_HH
#include <algorithm>
#include <ostream>
#include <string>
#include <vector>
#include <ostream>
#include "SymbolTable.hh"
#include "WarningConsolidation.hh"
using namespace std;
......@@ -33,13 +35,22 @@ using namespace std;
class SymbolList
{
private:
//! Internal container for symbol list
vector<string> symbols;
public:
//! Adds a symbol to the list
void addSymbol(const string &symbol);
//! Removed duplicate symbols
SymbolList() = default;
// This constructor is deliberately not marked explicit, to allow implicit conversion
SymbolList(vector<string> symbols_arg);
struct SymbolListException
{
const string message;
};
//! Remove duplicate symbols
void removeDuplicates(const string& dynare_command, WarningConsolidation& warnings);
//! Check symbols to ensure variables have been declared and are endogenous
void checkPass(WarningConsolidation& warnings, const vector<SymbolType>& types,
const SymbolTable& symbol_table) const noexcept(false);
//! Output content in Matlab format
/*! Creates a string array for Matlab, stored in variable "varname" */
void writeOutput(const string& varname, ostream& output) const;
......@@ -47,24 +58,14 @@ public:
void write(ostream& output) const;
//! Write JSON output
void writeJsonOutput(ostream& output) const;
//! Clears all content
void clear();
//! Get a copy of the string vector
vector<string>
get_symbols() const
{
return symbols;
};
//! Is Empty
int
[[nodiscard]] bool
empty() const
{
return symbols.empty();
};
//! Return the number of Symbols contained in the list
int getSize() const;
}
//! Return the list of symbols
vector<string> getSymbols() const;
[[nodiscard]] vector<string> getSymbols() const;
};
#endif
/*
* Copyright © 2003-2019 Dynare Team
* Copyright © 2003-2024 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,13 +14,13 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#include <algorithm>
#include <sstream>
#include <iostream>
#include <cassert>
#include <iostream>
#include <sstream>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wold-style-cast"
#include <boost/algorithm/string/replace.hpp>
......@@ -29,25 +29,10 @@
#include "SymbolTable.hh"
AuxVarInfo::AuxVarInfo(int symb_id_arg, AuxVarType type_arg, int orig_symb_id_arg, int orig_lead_lag_arg,
int equation_number_for_multiplier_arg, int information_set_arg,
expr_t expr_node_arg, string unary_op_arg) :
symb_id{symb_id_arg},
type{type_arg},
orig_symb_id{orig_symb_id_arg},
orig_lead_lag{orig_lead_lag_arg},
equation_number_for_multiplier{equation_number_for_multiplier_arg},
information_set{information_set_arg},
expr_node{expr_node_arg},
unary_op{move(unary_op_arg)}
{
}
SymbolTable::SymbolTable()
= default;
int
SymbolTable::addSymbol(const string &name, SymbolType type, const string &tex_name, const vector<pair<string, string>> &partition_value) noexcept(false)
SymbolTable::addSymbol(const string& name, SymbolType type, const string& tex_name,
const vector<pair<string, string>>& partition_value,
const optional<int>& heterogeneity_dimension) noexcept(false)
{
if (frozen)
throw FrozenException();
......@@ -55,9 +40,9 @@ SymbolTable::addSymbol(const string &name, SymbolType type, const string &tex_na
if (exists(name))
{
if (type_table[getID(name)] == type)
throw AlreadyDeclaredException(name, true);
throw AlreadyDeclaredException {name, true};
else
throw AlreadyDeclaredException(name, false);
throw AlreadyDeclaredException {name, false};
}
string final_tex_name = tex_name;
......@@ -74,7 +59,7 @@ SymbolTable::addSymbol(const string &name, SymbolType type, const string &tex_na
string final_long_name = name;
bool non_long_name_partition_exists = false;
for (auto it : partition_value)
for (const auto& it : partition_value)
if (it.first == "long_name")
final_long_name = it.second;
else
......@@ -90,17 +75,24 @@ SymbolTable::addSymbol(const string &name, SymbolType type, const string &tex_na
if (non_long_name_partition_exists)
{
map<string, string> pmv;
for (auto it : partition_value)
for (const auto& it : partition_value)
pmv[it.first] = it.second;
partition_value_map[id] = pmv;
}
assert(!isHeterogeneous(type)
|| (heterogeneity_dimension.has_value() && *heterogeneity_dimension >= 0
&& *heterogeneity_dimension < heterogeneity_table.size()));
if (isHeterogeneous(type))
heterogeneity_dimensions.emplace(id, *heterogeneity_dimension);
return id;
}
int
SymbolTable::addSymbol(const string& name, SymbolType type) noexcept(false)
{
return addSymbol(name, type, "", {});
return addSymbol(name, type, "", {}, {});
}
void
......@@ -111,6 +103,10 @@ SymbolTable::freeze() noexcept(false)
frozen = true;
het_endo_ids.resize(heterogeneity_table.size());
het_exo_ids.resize(heterogeneity_table.size());
het_param_ids.resize(heterogeneity_table.size());
for (int i = 0; i < static_cast<int>(symbol_table.size()); i++)
{
int tsi;
......@@ -132,11 +128,22 @@ SymbolTable::freeze() noexcept(false)
tsi = param_ids.size();
param_ids.push_back(i);
break;
default:
tsi = -1;
case SymbolType::heterogeneousEndogenous:
tsi = het_endo_ids.at(heterogeneity_dimensions.at(i)).size();
het_endo_ids.at(heterogeneity_dimensions.at(i)).push_back(i);
break;
case SymbolType::heterogeneousExogenous:
tsi = het_exo_ids.at(heterogeneity_dimensions.at(i)).size();
het_exo_ids.at(heterogeneity_dimensions.at(i)).push_back(i);
break;
case SymbolType::heterogeneousParameter:
tsi = het_param_ids.at(heterogeneity_dimensions.at(i)).size();
het_param_ids.at(heterogeneity_dimensions.at(i)).push_back(i);
break;
default:
continue;
}
type_specific_ids.push_back(tsi);
type_specific_ids[i] = tsi;
}
}
......@@ -148,12 +155,18 @@ SymbolTable::unfreeze()
exo_ids.clear();
exo_det_ids.clear();
param_ids.clear();
het_endo_ids.clear();
het_exo_ids.clear();
het_param_ids.clear();
type_specific_ids.clear();
}
void
SymbolTable::changeType(int id, SymbolType newtype) noexcept(false)
{
// FIXME: implement switch to heterogeneous variable; dimension will have to be provided
assert(!isHeterogeneous(newtype));
if (frozen)
throw FrozenException();
......@@ -163,7 +176,8 @@ SymbolTable::changeType(int id, SymbolType newtype) noexcept(false)
}
int
SymbolTable::getID(SymbolType type, int tsid) const noexcept(false)
SymbolTable::getID(SymbolType type, int tsid, const optional<int>& heterogeneity_dimension) const
noexcept(false)
{
if (!frozen)
throw NotYetFrozenException();
......@@ -172,26 +186,44 @@ SymbolTable::getID(SymbolType type, int tsid) const noexcept(false)
{
case SymbolType::endogenous:
if (tsid < 0 || tsid >= static_cast<int>(endo_ids.size()))
throw UnknownTypeSpecificIDException(tsid, type);
throw UnknownTypeSpecificIDException {tsid, type, {}};
else
return endo_ids[tsid];
case SymbolType::exogenous:
if (tsid < 0 || tsid >= static_cast<int>(exo_ids.size()))
throw UnknownTypeSpecificIDException(tsid, type);
throw UnknownTypeSpecificIDException {tsid, type, {}};
else
return exo_ids[tsid];
case SymbolType::exogenousDet:
if (tsid < 0 || tsid >= static_cast<int>(exo_det_ids.size()))
throw UnknownTypeSpecificIDException(tsid, type);
throw UnknownTypeSpecificIDException {tsid, type, {}};
else
return exo_det_ids[tsid];
case SymbolType::parameter:
if (tsid < 0 || tsid >= static_cast<int>(param_ids.size()))
throw UnknownTypeSpecificIDException(tsid, type);
throw UnknownTypeSpecificIDException {tsid, type, {}};
else
return param_ids[tsid];
case SymbolType::heterogeneousEndogenous:
assert(heterogeneity_dimension.has_value());
if (tsid < 0 || tsid >= static_cast<int>(het_endo_ids.at(*heterogeneity_dimension).size()))
throw UnknownTypeSpecificIDException {tsid, type, *heterogeneity_dimension};
else
return het_endo_ids.at(*heterogeneity_dimension).at(tsid);
case SymbolType::heterogeneousExogenous:
assert(heterogeneity_dimension.has_value());
if (tsid < 0 || tsid >= static_cast<int>(het_exo_ids.at(*heterogeneity_dimension).size()))
throw UnknownTypeSpecificIDException {tsid, type, *heterogeneity_dimension};
else
return het_exo_ids.at(*heterogeneity_dimension).at(tsid);
case SymbolType::heterogeneousParameter:
assert(heterogeneity_dimension.has_value());
if (tsid < 0 || tsid >= static_cast<int>(het_param_ids.at(*heterogeneity_dimension).size()))
throw UnknownTypeSpecificIDException {tsid, type, *heterogeneity_dimension};
else
return het_param_ids.at(*heterogeneity_dimension).at(tsid);
default:
throw UnknownTypeSpecificIDException(tsid, type);
throw UnknownTypeSpecificIDException {tsid, type, {}};
}
}
......@@ -201,13 +233,8 @@ SymbolTable::getPartitionsForType(SymbolType st) const noexcept(false)
map<string, map<int, string>> partitions;
for (const auto& it : partition_value_map)
if (getType(it.first) == st)
for (auto it1 = it.second.begin();
it1 != it.second.end(); it1++)
{
if (partitions.find(it1->first) == partitions.end())
partitions[it1->first] = map<int, string> ();
partitions[it1->first][it.first] = it1->second;
}
for (const auto& it1 : it.second)
partitions[it1.first][it.first] = it1.second;
return partitions;
}
......@@ -224,27 +251,25 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
output << "M_.exo_names_long = cell(" << exo_nbr() << ",1);" << endl;
for (int id = 0; id < exo_nbr(); id++)
output << "M_.exo_names(" << id + 1 << ") = {'" << getName(exo_ids[id]) << "'};" << endl
<< "M_.exo_names_tex(" << id+1 << ") = {'" << getTeXName(exo_ids[id]) << "'};" << endl
<< "M_.exo_names_long(" << id+1 << ") = {'" << getLongName(exo_ids[id]) << "'};" << endl;
map<string, map<int, string>> partitions = getPartitionsForType(SymbolType::exogenous);
for (map<string, map<int, string>>::const_iterator it = partitions.begin();
it != partitions.end(); it++)
if (it->first != "long_name")
{
map<int, string>::const_iterator it1;
output << "M_.exo_partitions." << it->first << " = { ";
<< "M_.exo_names_tex(" << id + 1 << ") = {'" << getTeXName(exo_ids[id]) << "'};"
<< endl
<< "M_.exo_names_long(" << id + 1 << ") = {'" << getLongName(exo_ids[id]) << "'};"
<< endl;
for (auto& partition : getPartitionsForType(SymbolType::exogenous))
if (partition.first != "long_name")
{
output << "M_.exo_partitions." << partition.first << " = { ";
for (int id = 0; id < exo_nbr(); id++)
{
output << "'";
it1 = it->second.find(exo_ids[id]);
if (it1 != it->second.end())
if (auto it1 = partition.second.find(exo_ids[id]); it1 != partition.second.end())
output << it1->second;
output << "' ";
}
output << "};" << endl;
if (it->first == "status")
if (partition.first == "status")
output << "M_ = set_observed_exogenous_variables(M_);" << endl;
if (it->first == "used")
if (partition.first == "used")
output << "M_ = set_exogenous_variables_for_simulation(M_);" << endl;
}
}
......@@ -261,22 +286,22 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
output << "M_.exo_det_names_tex = cell(" << exo_det_nbr() << ",1);" << endl;
output << "M_.exo_det_names_long = cell(" << exo_det_nbr() << ",1);" << endl;
for (int id = 0; id < exo_det_nbr(); id++)
output << "M_.exo_det_names(" << id+1 << ") = {'" << getName(exo_det_ids[id]) << "'};" << endl
<< "M_.exo_det_names_tex(" << id+1 << ") = {'" << getTeXName(exo_det_ids[id]) << "'};" << endl
<< "M_.exo_det_names_long(" << id+1 << ") = {'" << getLongName(exo_det_ids[id]) << "'};" << endl;
output << "M_.exo_det_names(" << id + 1 << ") = {'" << getName(exo_det_ids[id]) << "'};"
<< endl
<< "M_.exo_det_names_tex(" << id + 1 << ") = {'" << getTeXName(exo_det_ids[id])
<< "'};" << endl
<< "M_.exo_det_names_long(" << id + 1 << ") = {'" << getLongName(exo_det_ids[id])
<< "'};" << endl;
output << "M_.exo_det_partitions = struct();" << endl;
map<string, map<int, string>> partitions = getPartitionsForType(SymbolType::exogenousDet);
for (map<string, map<int, string>>::const_iterator it = partitions.begin();
it != partitions.end(); it++)
if (it->first != "long_name")
for (auto& partition : getPartitionsForType(SymbolType::exogenousDet))
if (partition.first != "long_name")
{
map<int, string>::const_iterator it1;
output << "M_.exo_det_partitions." << it->first << " = { ";
output << "M_.exo_det_partitions." << partition.first << " = { ";
for (int id = 0; id < exo_det_nbr(); id++)
{
output << "'";
it1 = it->second.find(exo_det_ids[id]);
if (it1 != it->second.end())
if (auto it1 = partition.second.find(exo_det_ids[id]);
it1 != partition.second.end())
output << it1->second;
output << "' ";
}
......@@ -291,21 +316,19 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
output << "M_.endo_names_long = cell(" << endo_nbr() << ",1);" << endl;
for (int id = 0; id < endo_nbr(); id++)
output << "M_.endo_names(" << id + 1 << ") = {'" << getName(endo_ids[id]) << "'};" << endl
<< "M_.endo_names_tex(" << id+1 << ") = {'" << getTeXName(endo_ids[id]) << "'};" << endl
<< "M_.endo_names_long(" << id+1 << ") = {'" << getLongName(endo_ids[id]) << "'};" << endl;
<< "M_.endo_names_tex(" << id + 1 << ") = {'" << getTeXName(endo_ids[id]) << "'};"
<< endl
<< "M_.endo_names_long(" << id + 1 << ") = {'" << getLongName(endo_ids[id]) << "'};"
<< endl;
output << "M_.endo_partitions = struct();" << endl;
map<string, map<int, string>> partitions = getPartitionsForType(SymbolType::endogenous);
for (map<string, map<int, string>>::const_iterator it = partitions.begin();
it != partitions.end(); it++)
if (it->first != "long_name")
for (auto& partition : getPartitionsForType(SymbolType::endogenous))
if (partition.first != "long_name")
{
map<int, string>::const_iterator it1;
output << "M_.endo_partitions." << it->first << " = { ";
output << "M_.endo_partitions." << partition.first << " = { ";
for (int id = 0; id < endo_nbr(); id++)
{
output << "'";
it1 = it->second.find(endo_ids[id]);
if (it1 != it->second.end())
if (auto it1 = partition.second.find(endo_ids[id]); it1 != partition.second.end())
output << it1->second;
output << "' ";
}
......@@ -320,25 +343,24 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
output << "M_.param_names_long = cell(" << param_nbr() << ",1);" << endl;
for (int id = 0; id < param_nbr(); id++)
{
output << "M_.param_names(" << id+1 << ") = {'" << getName(param_ids[id]) << "'};" << endl
<< "M_.param_names_tex(" << id+1 << ") = {'" << getTeXName(param_ids[id]) << "'};" << endl
<< "M_.param_names_long(" << id+1 << ") = {'" << getLongName(param_ids[id]) << "'};" << endl;
output << "M_.param_names(" << id + 1 << ") = {'" << getName(param_ids[id]) << "'};"
<< endl
<< "M_.param_names_tex(" << id + 1 << ") = {'" << getTeXName(param_ids[id])
<< "'};" << endl
<< "M_.param_names_long(" << id + 1 << ") = {'" << getLongName(param_ids[id])
<< "'};" << endl;
if (getName(param_ids[id]) == "dsge_prior_weight")
output << "options_.dsge_var = 1;" << endl;
}
output << "M_.param_partitions = struct();" << endl;
map<string, map<int, string>> partitions = getPartitionsForType(SymbolType::parameter);
for (map<string, map<int, string>>::const_iterator it = partitions.begin();
it != partitions.end(); it++)
if (it->first != "long_name")
for (auto& partition : getPartitionsForType(SymbolType::parameter))
if (partition.first != "long_name")
{
map<int, string>::const_iterator it1;
output << "M_.param_partitions." << it->first << " = { ";
output << "M_.param_partitions." << partition.first << " = { ";
for (int id = 0; id < param_nbr(); id++)
{
output << "'";
it1 = it->second.find(param_ids[id]);
if (it1 != it->second.end())
if (auto it1 = partition.second.find(param_ids[id]); it1 != partition.second.end())
output << it1->second;
output << "' ";
}
......@@ -364,46 +386,56 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
else
for (int i = 0; i < static_cast<int>(aux_vars.size()); i++)
{
output << "M_.aux_vars(" << i+1 << ").endo_index = " << getTypeSpecificID(aux_vars[i].get_symb_id())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").type = " << aux_vars[i].get_type_id() << ";" << endl;
switch (aux_vars[i].get_type())
output << "M_.aux_vars(" << i + 1
<< ").endo_index = " << getTypeSpecificID(aux_vars[i].symb_id) + 1 << ";" << endl
<< "M_.aux_vars(" << i + 1 << ").type = " << aux_vars[i].get_type_id() << ";"
<< endl;
switch (aux_vars[i].type)
{
case AuxVarType::endoLead:
case AuxVarType::exoLead:
case AuxVarType::expectation:
case AuxVarType::pacExpectation:
case AuxVarType::pacTargetNonstationary:
case AuxVarType::aggregationOp:
break;
case AuxVarType::endoLag:
case AuxVarType::exoLag:
case AuxVarType::varModel:
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
case AuxVarType::logTransform:
case AuxVarType::diffLag:
case AuxVarType::diffLead:
case AuxVarType::diffForward:
output << "M_.aux_vars(" << i + 1
<< ").orig_index = " << getTypeSpecificID(aux_vars[i].orig_symb_id.value()) + 1
<< ";" << endl
<< "M_.aux_vars(" << i + 1
<< ").orig_lead_lag = " << aux_vars[i].orig_lead_lag.value() << ";" << endl;
break;
case AuxVarType::unaryOp:
if (aux_vars[i].get_orig_symb_id() >= 0)
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
output << "M_.aux_vars(" << i+1 << ").unary_op = '" << aux_vars[i].get_unary_op() << "';" << endl;
output << "M_.aux_vars(" << i + 1 << ").unary_op = '" << aux_vars[i].unary_op << "';"
<< endl;
[[fallthrough]];
case AuxVarType::diff:
if (aux_vars[i].orig_symb_id)
output << "M_.aux_vars(" << i + 1
<< ").orig_index = " << getTypeSpecificID(*aux_vars[i].orig_symb_id) + 1 << ";"
<< endl
<< "M_.aux_vars(" << i + 1
<< ").orig_lead_lag = " << aux_vars[i].orig_lead_lag.value() << ";" << endl;
break;
case AuxVarType::multiplier:
output << "M_.aux_vars(" << i+1 << ").eq_nbr = " << aux_vars[i].get_equation_number_for_multiplier() + 1 << ";" << endl;
break;
case AuxVarType::diffForward:
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl;
break;
case AuxVarType::expectation:
output << "M_.aux_vars(" << i+1 << R"().orig_expr = '\mathbb{E}_{t)"
<< (aux_vars[i].get_information_set() < 0 ? "" : "+")
<< aux_vars[i].get_information_set() << "}(";
aux_vars[i].get_expr_node()->writeOutput(output, ExprNodeOutputType::latexDynamicModel);
output << ")';" << endl;
break;
case AuxVarType::diff:
case AuxVarType::diffLag:
case AuxVarType::diffLead:
if (aux_vars[i].get_orig_symb_id() >= 0)
output << "M_.aux_vars(" << i+1 << ").orig_index = " << getTypeSpecificID(aux_vars[i].get_orig_symb_id())+1 << ";" << endl
<< "M_.aux_vars(" << i+1 << ").orig_lead_lag = " << aux_vars[i].get_orig_lead_lag() << ";" << endl;
output << "M_.aux_vars(" << i + 1
<< ").eq_nbr = " << aux_vars[i].equation_number_for_multiplier + 1 << ";"
<< endl;
break;
}
if (expr_t orig_expr = aux_vars[i].expr_node; orig_expr)
{
output << "M_.aux_vars(" << i + 1 << ").orig_expr = '";
orig_expr->writeJsonOutput(output, {}, {});
output << "';" << endl;
}
}
if (predeterminedNbr() > 0)
......@@ -416,11 +448,9 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
if (observedVariablesNbr() > 0)
{
int ic = 1;
output << "options_.varobs = cell(" << observedVariablesNbr() << ", 1);" << endl;
for (auto it = varobs.begin();
it != varobs.end(); it++, ic++)
output << "options_.varobs(" << ic << ") = {'" << getName(*it) << "'};" << endl;
for (int ic {1}; int it : varobs)
output << "options_.varobs(" << ic++ << ") = {'" << getName(it) << "'};" << endl;
output << "options_.varobs_id = [ ";
for (int varob : varobs)
......@@ -430,66 +460,92 @@ SymbolTable::writeOutput(ostream &output) const noexcept(false)
if (observedExogenousVariablesNbr() > 0)
{
int ic = 1;
output << "options_.varexobs = cell(1);" << endl;
for (auto it = varexobs.begin();
it != varexobs.end(); it++, ic++)
output << "options_.varexobs(" << ic << ") = {'" << getName(*it) << "'};" << endl;
for (int ic {1}; int it : varexobs)
output << "options_.varexobs(" << ic++ << ") = {'" << getName(it) << "'};" << endl;
output << "options_.varexobs_id = [ ";
for (int varexob : varexobs)
output << getTypeSpecificID(varexob) + 1 << " ";
output << " ];" << endl;
}
// Heterogeneous symbols
// FIXME: the following helper could be used to simplify non-heterogenous variables
auto print_symb_names = [this, &output](const string& field, const auto& symb_ids) {
auto helper = [this, &output, &symb_ids](auto nameMethod) {
for (bool first_printed {false}; int symb_id : symb_ids)
{
if (exchange(first_printed, true))
output << "; ";
output << "'" << (this->*nameMethod)(symb_id) << "'";
}
};
output << field << " = {";
helper(&SymbolTable::getName);
output << "};" << endl << field << "_tex = {";
helper(&SymbolTable::getTeXName);
output << "};" << endl << field << "_long = {";
helper(&SymbolTable::getLongName);
output << "};" << endl;
};
for (int het_dim {0}; het_dim < heterogeneity_table.size(); het_dim++)
{
const string basefield {"M_.heterogeneity(" + to_string(het_dim + 1) + ")."};
output << basefield << "endo_nbr = " << het_endo_nbr(het_dim) << ";" << endl;
print_symb_names(basefield + "endo_names", het_endo_ids.at(het_dim));
output << basefield << "exo_nbr = " << het_exo_nbr(het_dim) << ";" << endl;
print_symb_names(basefield + "exo_names", het_exo_ids.at(het_dim));
output << basefield << "param_nbr = " << het_param_nbr(het_dim) << ";" << endl;
print_symb_names(basefield + "param_names", het_param_ids.at(het_dim));
}
}
int
SymbolTable::addLeadAuxiliaryVarInternal(bool endo, int index, expr_t expr_arg) noexcept(false)
{
ostringstream varname;
if (endo)
varname << "AUX_ENDO_LEAD_";
else
varname << "AUX_EXO_LEAD_";
varname << index;
string varname {(endo ? "AUX_ENDO_LEAD_" : "AUX_EXO_LEAD_") + to_string(index)};
int symb_id;
try
{
symb_id = addSymbol(varname.str(), SymbolType::endogenous);
symb_id = addSymbol(varname, SymbolType::endogenous);
}
catch (AlreadyDeclaredException& e)
{
cerr << "ERROR: you should rename your variable called " << varname.str() << ", this name is internally used by Dynare" << endl;
cerr << "ERROR: you should rename your variable called " << varname
<< ", this name is internally used by Dynare" << endl;
exit(EXIT_FAILURE);
}
aux_vars.emplace_back(symb_id, (endo ? AuxVarType::endoLead : AuxVarType::exoLead), 0, 0, 0, 0, expr_arg, "");
aux_vars.emplace_back(symb_id, (endo ? AuxVarType::endoLead : AuxVarType::exoLead), 0, 0, 0, 0,
expr_arg, "");
return symb_id;
}
int
SymbolTable::addLagAuxiliaryVarInternal(bool endo, int orig_symb_id, int orig_lead_lag, expr_t expr_arg) noexcept(false)
SymbolTable::addLagAuxiliaryVarInternal(bool endo, int orig_symb_id, int orig_lead_lag,
expr_t expr_arg) noexcept(false)
{
ostringstream varname;
if (endo)
varname << "AUX_ENDO_LAG_";
else
varname << "AUX_EXO_LAG_";
varname << orig_symb_id << "_" << -orig_lead_lag;
string varname {(endo ? "AUX_ENDO_LAG_" : "AUX_EXO_LAG_") + to_string(orig_symb_id) + "_"
+ to_string(-orig_lead_lag)};
int symb_id;
try
{
symb_id = addSymbol(varname.str(), SymbolType::endogenous);
symb_id = addSymbol(varname, SymbolType::endogenous);
}
catch (AlreadyDeclaredException& e)
{
cerr << "ERROR: you should rename your variable called " << varname.str() << ", this name is internally used by Dynare" << endl;
cerr << "ERROR: you should rename your variable called " << varname
<< ", this name is internally used by Dynare" << endl;
exit(EXIT_FAILURE);
}
aux_vars.emplace_back(symb_id, (endo ? AuxVarType::endoLag : AuxVarType::exoLag), orig_symb_id, orig_lead_lag, 0, 0, expr_arg, "");
aux_vars.emplace_back(symb_id, (endo ? AuxVarType::endoLag : AuxVarType::exoLag), orig_symb_id,
orig_lead_lag, 0, 0, expr_arg, "");
return symb_id;
}
......@@ -501,7 +557,8 @@ SymbolTable::addEndoLeadAuxiliaryVar(int index, expr_t expr_arg) noexcept(false)
}
int
SymbolTable::addEndoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t expr_arg) noexcept(false)
SymbolTable::addEndoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag,
expr_t expr_arg) noexcept(false)
{
return addLagAuxiliaryVarInternal(true, orig_symb_id, orig_lead_lag, expr_arg);
}
......@@ -513,27 +570,27 @@ SymbolTable::addExoLeadAuxiliaryVar(int index, expr_t expr_arg) noexcept(false)
}
int
SymbolTable::addExoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t expr_arg) noexcept(false)
SymbolTable::addExoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag,
expr_t expr_arg) noexcept(false)
{
return addLagAuxiliaryVarInternal(false, orig_symb_id, orig_lead_lag, expr_arg);
}
int
SymbolTable::addExpectationAuxiliaryVar(int information_set, int index, expr_t expr_arg) noexcept(false)
SymbolTable::addExpectationAuxiliaryVar(int information_set, int index,
expr_t expr_arg) noexcept(false)
{
ostringstream varname;
string varname {"AUX_EXPECT_"s + (information_set < 0 ? "LAG" : "LEAD") + "_"
+ to_string(abs(information_set)) + "_" + to_string(index)};
int symb_id;
varname << "AUX_EXPECT_" << (information_set < 0 ? "LAG" : "LEAD") << "_"
<< abs(information_set) << "_" << index;
try
{
symb_id = addSymbol(varname.str(), SymbolType::endogenous);
symb_id = addSymbol(varname, SymbolType::endogenous);
}
catch (AlreadyDeclaredException& e)
{
cerr << "ERROR: you should rename your variable called " << varname.str() << ", this name is internally used by Dynare" << endl;
cerr << "ERROR: you should rename your variable called " << varname
<< ", this name is internally used by Dynare" << endl;
exit(EXIT_FAILURE);
}
......@@ -543,20 +600,43 @@ SymbolTable::addExpectationAuxiliaryVar(int information_set, int index, expr_t e
}
int
SymbolTable::addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false)
SymbolTable::addLogTransformAuxiliaryVar(int orig_symb_id, int orig_lead_lag,
expr_t expr_arg) noexcept(false)
{
ostringstream varname;
string varname = "LOG_" + getName(orig_symb_id);
int symb_id;
try
{
symb_id = addSymbol(varname, SymbolType::endogenous);
}
catch (AlreadyDeclaredException& e)
{
cerr << "ERROR: you should rename your variable called " << varname
<< ", it conflicts with the auxiliary variable created for representing the log of "
<< getName(orig_symb_id) << endl;
exit(EXIT_FAILURE);
}
varname << "AUX_DIFF_LAG_" << index;
aux_vars.emplace_back(symb_id, AuxVarType::logTransform, orig_symb_id, orig_lead_lag, 0, 0,
expr_arg, "");
return symb_id;
}
int
SymbolTable::addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id,
int orig_lag) noexcept(false)
{
string varname {"AUX_DIFF_LAG_" + to_string(index)};
int symb_id;
try
{
symb_id = addSymbol(varname.str(), SymbolType::endogenous);
symb_id = addSymbol(varname, SymbolType::endogenous);
}
catch (AlreadyDeclaredException& e)
{
cerr << "ERROR: you should rename your variable called " << varname.str() << ", this name is internally used by Dynare" << endl;
cerr << "ERROR: you should rename your variable called " << varname
<< ", this name is internally used by Dynare" << endl;
exit(EXIT_FAILURE);
}
......@@ -566,20 +646,19 @@ SymbolTable::addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id
}
int
SymbolTable::addDiffLeadAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lead) noexcept(false)
SymbolTable::addDiffLeadAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id,
int orig_lead) noexcept(false)
{
ostringstream varname;
string varname {"AUX_DIFF_LEAD_" + to_string(index)};
int symb_id;
varname << "AUX_DIFF_LEAD_" << index;
try
{
symb_id = addSymbol(varname.str(), SymbolType::endogenous);
symb_id = addSymbol(varname, SymbolType::endogenous);
}
catch (AlreadyDeclaredException& e)
{
cerr << "ERROR: you should rename your variable called " << varname.str() << ", this name is internally used by Dynare" << endl;
cerr << "ERROR: you should rename your variable called " << varname
<< ", this name is internally used by Dynare" << endl;
exit(EXIT_FAILURE);
}
......@@ -589,20 +668,19 @@ SymbolTable::addDiffLeadAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_i
}
int
SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false)
SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, const optional<int>& orig_symb_id,
const optional<int>& orig_lag) noexcept(false)
{
ostringstream varname;
string varname {"AUX_DIFF_" + to_string(index)};
int symb_id;
varname << "AUX_DIFF_" << index;
try
{
symb_id = addSymbol(varname.str(), SymbolType::endogenous);
symb_id = addSymbol(varname, SymbolType::endogenous);
}
catch (AlreadyDeclaredException& e)
{
cerr << "ERROR: you should rename your variable called " << varname.str() << ", this name is internally used by Dynare" << endl;
cerr << "ERROR: you should rename your variable called " << varname
<< ", this name is internally used by Dynare" << endl;
exit(EXIT_FAILURE);
}
......@@ -612,159 +690,175 @@ SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, i
}
int
SymbolTable::addDiffAuxiliaryVar(int index, expr_t expr_arg) noexcept(false)
SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op,
const optional<int>& orig_symb_id,
const optional<int>& orig_lag) noexcept(false)
{
return addDiffAuxiliaryVar(index, expr_arg, -1, 0);
string varname {"AUX_UOP_" + to_string(index)};
int symb_id;
try
{
symb_id = addSymbol(varname, SymbolType::endogenous);
}
catch (AlreadyDeclaredException& e)
{
cerr << "ERROR: you should rename your variable called " << varname
<< ", this name is internally used by Dynare" << endl;
exit(EXIT_FAILURE);
}
aux_vars.emplace_back(symb_id, AuxVarType::unaryOp, orig_symb_id, orig_lag, 0, 0, expr_arg,
move(unary_op));
return symb_id;
}
int
SymbolTable::addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, int orig_symb_id, int orig_lag) noexcept(false)
SymbolTable::addMultiplierAuxiliaryVar(int index) noexcept(false)
{
ostringstream varname;
string varname {"MULT_" + to_string(index + 1)};
int symb_id;
varname << "AUX_UOP_" << index;
try
{
symb_id = addSymbol(varname.str(), SymbolType::endogenous);
symb_id = addSymbol(varname, SymbolType::endogenous);
}
catch (AlreadyDeclaredException& e)
{
cerr << "ERROR: you should rename your variable called " << varname.str() << ", this name is internally used by Dynare" << endl;
cerr << "ERROR: you should rename your variable called " << varname
<< ", this name is internally used by Dynare" << endl;
exit(EXIT_FAILURE);
}
aux_vars.emplace_back(symb_id, AuxVarType::unaryOp, orig_symb_id, orig_lag, 0, 0, expr_arg, unary_op);
aux_vars.emplace_back(symb_id, AuxVarType::multiplier, 0, 0, index, 0, nullptr, "");
return symb_id;
}
int
SymbolTable::addVarModelEndoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t expr_arg) noexcept(false)
SymbolTable::addDiffForwardAuxiliaryVar(int orig_symb_id, int orig_lead_lag,
expr_t expr_arg) noexcept(false)
{
string varname {"AUX_DIFF_FWRD_" + to_string(orig_symb_id + 1)};
int symb_id;
ostringstream varname;
varname << "AUX_VARMODEL_" << orig_symb_id << "_" << -orig_lead_lag;
try
{
symb_id = addSymbol(varname.str(), SymbolType::endogenous);
symb_id = addSymbol(varname, SymbolType::endogenous);
}
catch (AlreadyDeclaredException& e)
{
cerr << "ERROR: you should rename your variable called " << varname.str() << ", this name is internally used by Dynare" << endl;
cerr << "ERROR: you should rename your variable called " << varname
<< ", this name is internally used by Dynare" << endl;
exit(EXIT_FAILURE);
}
aux_vars.emplace_back(symb_id, AuxVarType::varModel, orig_symb_id, orig_lead_lag, 0, 0, expr_arg, "");
aux_vars.emplace_back(symb_id, AuxVarType::diffForward, orig_symb_id, orig_lead_lag, 0, 0,
expr_arg, "");
return symb_id;
}
int
SymbolTable::addMultiplierAuxiliaryVar(int index) noexcept(false)
SymbolTable::addPacExpectationAuxiliaryVar(const string& name, expr_t expr_arg)
{
ostringstream varname;
int symb_id;
varname << "MULT_" << index+1;
try
{
symb_id = addSymbol(varname.str(), SymbolType::endogenous);
symb_id = addSymbol(name, SymbolType::endogenous);
}
catch (AlreadyDeclaredException& e)
{
cerr << "ERROR: you should rename your variable called " << varname.str() << ", this name is internally used by Dynare" << endl;
cerr << "ERROR: the variable/parameter '" << name
<< "' conflicts with a variable that will be generated for a 'pac_expectation' "
"expression. Please rename it."
<< endl;
exit(EXIT_FAILURE);
}
aux_vars.emplace_back(symb_id, AuxVarType::multiplier, 0, 0, index, 0, nullptr, "");
aux_vars.emplace_back(symb_id, AuxVarType::pacExpectation, 0, 0, 0, 0, expr_arg, "");
return symb_id;
}
int
SymbolTable::addDiffForwardAuxiliaryVar(int orig_symb_id, expr_t expr_arg) noexcept(false)
SymbolTable::addPacTargetNonstationaryAuxiliaryVar(const string& name, expr_t expr_arg)
{
ostringstream varname;
int symb_id;
varname << "AUX_DIFF_FWRD_" << orig_symb_id+1;
try
{
symb_id = addSymbol(varname.str(), SymbolType::endogenous);
symb_id = addSymbol(name, SymbolType::endogenous);
}
catch (AlreadyDeclaredException& e)
{
cerr << "ERROR: you should rename your variable called " << varname.str() << ", this name is internally used by Dynare" << endl;
cerr << "ERROR: the variable/parameter '" << name
<< "' conflicts with a variable that will be generated for a 'pac_target_nonstationary' "
"expression. Please rename it."
<< endl;
exit(EXIT_FAILURE);
}
aux_vars.emplace_back(symb_id, AuxVarType::diffForward, orig_symb_id, 0, 0, 0, expr_arg, "");
aux_vars.emplace_back(symb_id, AuxVarType::pacTargetNonstationary, 0, 0, 0, 0, expr_arg, "");
return symb_id;
}
int
SymbolTable::searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const noexcept(false)
SymbolTable::addAggregationOpAuxiliaryVar(const string& name, expr_t expr_arg)
{
for (const auto & aux_var : aux_vars)
if ((aux_var.get_type() == AuxVarType::endoLag || aux_var.get_type() == AuxVarType::exoLag)
&& aux_var.get_orig_symb_id() == orig_symb_id && aux_var.get_orig_lead_lag() == orig_lead_lag)
return aux_var.get_symb_id();
throw SearchFailedException(orig_symb_id, orig_lead_lag);
int symb_id {[&] {
try
{
return addSymbol(name, SymbolType::endogenous);
}
int
SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false)
catch (AlreadyDeclaredException&)
{
for (const auto & aux_var : aux_vars)
if ((aux_var.get_type() == AuxVarType::endoLag
|| aux_var.get_type() == AuxVarType::exoLag
|| aux_var.get_type() == AuxVarType::diff
|| aux_var.get_type() == AuxVarType::diffLag
|| aux_var.get_type() == AuxVarType::diffLead)
&& aux_var.get_symb_id() == aux_var_symb_id)
return aux_var.get_orig_symb_id();
throw UnknownSymbolIDException(aux_var_symb_id);
cerr << "ERROR: the variable/parameter '" << name
<< "' conflicts with a variable that will be generated for an aggregation operator. "
"Please rename it."
<< endl;
exit(EXIT_FAILURE);
}
}()};
aux_vars.emplace_back(symb_id, AuxVarType::aggregationOp, 0, 0, 0, 0, expr_arg, "");
return symb_id;
}
int
SymbolTable::getOrigLeadLagForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false)
SymbolTable::searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const noexcept(false)
{
int lag = 0;
for (const auto& aux_var : aux_vars)
if ((aux_var.get_type() == AuxVarType::diffLag || aux_var.get_type() == AuxVarType::diffLead)
&& aux_var.get_symb_id() == diff_aux_var_symb_id)
lag += 1 + getOrigLeadLagForDiffAuxVar(aux_var.get_orig_symb_id());
return lag;
if ((aux_var.type == AuxVarType::endoLag || aux_var.type == AuxVarType::exoLag)
&& aux_var.orig_symb_id == orig_symb_id && aux_var.orig_lead_lag == orig_lead_lag)
return aux_var.symb_id;
throw SearchFailedException {orig_symb_id, orig_lead_lag};
}
int
SymbolTable::getOrigSymbIdForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false)
SymbolTable::getOrigSymbIdForAuxVar(int aux_var_symb_id_arg) const noexcept(false)
{
int orig_symb_id = -1;
for (const auto& aux_var : aux_vars)
if (aux_var.get_symb_id() == diff_aux_var_symb_id)
if (aux_var.get_type() == AuxVarType::diff)
orig_symb_id = diff_aux_var_symb_id;
else if (aux_var.get_type() == AuxVarType::diffLag || aux_var.get_type() == AuxVarType::diffLead)
orig_symb_id = getOrigSymbIdForDiffAuxVar(aux_var.get_orig_symb_id());
return orig_symb_id;
if ((aux_var.type == AuxVarType::endoLag || aux_var.type == AuxVarType::exoLag
|| aux_var.type == AuxVarType::diff || aux_var.type == AuxVarType::diffLag
|| aux_var.type == AuxVarType::diffLead || aux_var.type == AuxVarType::diffForward
|| aux_var.type == AuxVarType::unaryOp)
&& aux_var.symb_id == aux_var_symb_id_arg)
{
if (optional<int> r = aux_var.orig_symb_id; r)
return *r;
else
throw UnknownSymbolIDException {
aux_var_symb_id_arg}; // Some diff and unaryOp auxvars have orig_symb_id unset
}
throw UnknownSymbolIDException {aux_var_symb_id_arg};
}
expr_t
SymbolTable::getAuxiliaryVarsExprNode(int symb_id) const noexcept(false)
// throw exception if it is a Lagrange multiplier
pair<int, int>
SymbolTable::unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false)
{
for (const auto& aux_var : aux_vars)
if (aux_var.get_symb_id() == symb_id)
if (aux_var.symb_id == symb_id)
if (aux_var.type == AuxVarType::diffLag || aux_var.type == AuxVarType::diffLead)
{
expr_t expr_node = aux_var.get_expr_node();
if (expr_node != nullptr)
return expr_node;
else
throw SearchFailedException(symb_id);
auto [orig_symb_id, orig_lag] = unrollDiffLeadLagChain(aux_var.orig_symb_id.value(), lag);
return {orig_symb_id, orig_lag + aux_var.orig_lead_lag.value()};
}
throw SearchFailedException(symb_id);
return {symb_id, lag};
}
void
......@@ -780,17 +874,30 @@ SymbolTable::markPredetermined(int symb_id) noexcept(false)
predetermined_variables.insert(symb_id);
}
void
SymbolTable::markWithLogTransform(int symb_id) noexcept(false)
{
validateSymbID(symb_id);
if (frozen)
throw FrozenException();
assert(getType(symb_id) == SymbolType::endogenous);
with_log_transform.insert(symb_id);
}
bool
SymbolTable::isPredetermined(int symb_id) const noexcept(false)
{
validateSymbID(symb_id);
return (predetermined_variables.find(symb_id) != predetermined_variables.end());
return predetermined_variables.contains(symb_id);
}
int
SymbolTable::predeterminedNbr() const
{
return (predetermined_variables.size());
return predetermined_variables.size();
}
void
......@@ -810,13 +917,13 @@ SymbolTable::observedVariablesNbr() const
bool
SymbolTable::isObservedVariable(int symb_id) const
{
return (find(varobs.begin(), varobs.end(), symb_id) != varobs.end());
return ranges::find(varobs, symb_id) != varobs.end();
}
int
SymbolTable::getObservedVariableIndex(int symb_id) const
{
auto it = find(varobs.begin(), varobs.end(), symb_id);
auto it = ranges::find(varobs, symb_id);
assert(it != varobs.end());
return static_cast<int>(it - varobs.begin());
}
......@@ -825,7 +932,7 @@ void
SymbolTable::addObservedExogenousVariable(int symb_id) noexcept(false)
{
validateSymbID(symb_id);
assert(getType(symb_id) != SymbolType::endogenous);
assert(getType(symb_id) == SymbolType::exogenous);
varexobs.push_back(symb_id);
}
......@@ -838,13 +945,13 @@ SymbolTable::observedExogenousVariablesNbr() const
bool
SymbolTable::isObservedExogenousVariable(int symb_id) const
{
return (find(varexobs.begin(), varexobs.end(), symb_id) != varexobs.end());
return ranges::find(varexobs, symb_id) != varexobs.end();
}
int
SymbolTable::getObservedExogenousVariableIndex(int symb_id) const
{
auto it = find(varexobs.begin(), varexobs.end(), symb_id);
auto it = ranges::find(varexobs, symb_id);
assert(it != varexobs.end());
return static_cast<int>(it - varexobs.begin());
}
......@@ -893,31 +1000,17 @@ SymbolTable::getEndogenous() const
bool
SymbolTable::isAuxiliaryVariable(int symb_id) const
{
for (const auto & aux_var : aux_vars)
if (aux_var.get_symb_id() == symb_id)
return true;
return false;
}
bool
SymbolTable::isAuxiliaryVariableButNotMultiplier(int symb_id) const
{
for (const auto & aux_var : aux_vars)
if (aux_var.get_symb_id() == symb_id && aux_var.get_type() != AuxVarType::multiplier)
return true;
return false;
return ranges::any_of(aux_vars, [=](const auto& av) { return av.symb_id == symb_id; });
}
bool
SymbolTable::isDiffAuxiliaryVariable(int symb_id) const
{
for (const auto & aux_var : aux_vars)
if (aux_var.get_symb_id() == symb_id &&
(aux_var.get_type() == AuxVarType::diff
|| aux_var.get_type() == AuxVarType::diffLag
|| aux_var.get_type() == AuxVarType::diffLead))
return true;
return false;
return ranges::any_of(aux_vars, [=](const auto& av) {
return av.symb_id == symb_id
&& (av.type == AuxVarType::diff || av.type == AuxVarType::diffLag
|| av.type == AuxVarType::diffLead);
});
}
set<int>
......@@ -931,151 +1024,135 @@ SymbolTable::getOrigEndogenous() const
}
void
SymbolTable::writeJuliaOutput(ostream &output) const noexcept(false)
SymbolTable::writeJsonOutput(ostream& output) const
{
if (!frozen)
throw NotYetFrozenException();
output << R"("endogenous": )";
writeJsonVarVector(output, endo_ids);
output << "# Endogenous Variables" << endl
<< "model_.endo = [" << endl;
if (endo_nbr() > 0)
for (int id = 0; id < endo_nbr(); id++)
output << R"( DynareModel.Endo(")"
<< getName(endo_ids[id]) << R"(", raw")"
<< getTeXName(endo_ids[id]) << R"(", ")"
<< getLongName(endo_ids[id]) << R"("))" << endl;
output << " ]" << endl;
output << "model_.endo_nbr = " << endo_nbr() << ";" << endl;
output << R"(, "exogenous":)";
writeJsonVarVector(output, exo_ids);
output << "# Exogenous Variables" << endl
<< "model_.exo = [" << endl;
if (exo_nbr() > 0)
for (int id = 0; id < exo_nbr(); id++)
output << R"( DynareModel.Exo(")"
<< getName(exo_ids[id]) << R"(", raw")"
<< getTeXName(exo_ids[id]) << R"(", ")"
<< getLongName(exo_ids[id]) << R"("))" << endl;
output << R"(, "exogenous_deterministic": )";
writeJsonVarVector(output, exo_det_ids);
output << R"(, "parameters": )";
writeJsonVarVector(output, param_ids);
if (observedVariablesNbr() > 0)
{
output << R"(, "varobs": [)";
for (size_t i = 0; i < varobs.size(); i++)
{
if (i != 0)
output << ", ";
output << R"(")" << getName(varobs[i]) << R"(")";
}
output << "]" << endl;
output << "model_.exo_nbr = " << exo_nbr() << ";" << endl;
if (exo_det_nbr() > 0)
output << R"(, "varobs_ids": [)";
for (size_t i = 0; i < varobs.size(); i++)
{
output << "# Exogenous Deterministic Variables" << endl
<< "model_.exo_det = [" << endl;
if (exo_det_nbr() > 0)
for (int id = 0; id < exo_det_nbr(); id++)
output << R"( DynareModel.ExoDet(")"
<< getName(exo_det_ids[id]) << R"(", raw")"
<< getTeXName(exo_det_ids[id]) << R"(", ")"
<< getLongName(exo_det_ids[id]) << R"("))" << endl;
if (i != 0)
output << ", ";
output << getTypeSpecificID(varobs[i]) + 1;
}
output << "]" << endl;
output << "model_.exo_det_nbr = " << exo_det_nbr() << ";" << endl;
}
output << "# Parameters" << endl
<< "model_.param = [" << endl;
if (param_nbr() > 0)
for (int id = 0; id < param_nbr(); id++)
output << R"( DynareModel.Param(")"
<< getName(param_ids[id]) << R"(", raw")"
<< getTeXName(param_ids[id]) << R"(", ")"
<< getLongName(param_ids[id]) << R"("))" << endl;
if (observedExogenousVariablesNbr() > 0)
{
output << R"(, "varexobs": [)";
for (size_t i = 0; i < varexobs.size(); i++)
{
if (i != 0)
output << ", ";
output << R"(")" << getName(varexobs[i]) << R"(")";
}
output << "]" << endl;
output << "model_.param_nbr = " << param_nbr() << ";" << endl;
output << "model_.orig_endo_nbr = " << orig_endo_nbr() << endl;
if (aux_vars.size() > 0)
output << R"(, "varexobs_ids": [)";
for (size_t i = 0; i < varexobs.size(); i++)
{
output << "# Auxiliary Variables" << endl
<< "model_.aux_vars = [" << endl;
for (const auto & aux_var : aux_vars)
if (i != 0)
output << ", ";
output << getTypeSpecificID(varexobs[i]) + 1;
}
output << "]" << endl;
}
// Write the auxiliary variable table
output << R"(, "orig_endo_nbr": )" << orig_endo_nbr() << endl;
if (aux_vars.size() == 0)
output << R"(, "aux_vars": [])";
else
{
output << " DynareModel.AuxVars("
<< getTypeSpecificID(aux_var.get_symb_id()) + 1 << ", "
<< aux_var.get_type_id() << ", ";
switch (aux_var.get_type())
output << R"(, "aux_vars": [)" << endl;
for (int i = 0; i < static_cast<int>(aux_vars.size()); i++)
{
if (i != 0)
output << ", ";
output << R"({"endo_index": )" << getTypeSpecificID(aux_vars[i].symb_id) + 1
<< R"(, "type": )" << aux_vars[i].get_type_id();
switch (aux_vars[i].type)
{
case AuxVarType::endoLead:
case AuxVarType::exoLead:
case AuxVarType::expectation:
case AuxVarType::pacExpectation:
case AuxVarType::pacTargetNonstationary:
case AuxVarType::aggregationOp:
break;
case AuxVarType::endoLag:
case AuxVarType::exoLag:
case AuxVarType::varModel:
output << getTypeSpecificID(aux_var.get_orig_symb_id()) + 1 << ", "
<< aux_var.get_orig_lead_lag() << ", typemin(Int), string(), string()";
case AuxVarType::logTransform:
case AuxVarType::diffLag:
case AuxVarType::diffLead:
case AuxVarType::diffForward:
output << R"(, "orig_index": )"
<< getTypeSpecificID(aux_vars[i].orig_symb_id.value()) + 1
<< R"(, "orig_lead_lag": )" << aux_vars[i].orig_lead_lag.value();
break;
case AuxVarType::unaryOp:
if (aux_var.get_orig_symb_id() >= 0)
output << getTypeSpecificID(aux_var.get_orig_symb_id()) + 1 << ", " << aux_var.get_orig_lead_lag();
else
output << "typemin(Int), typemin(Int)";
output << ", typemin(Int), string(), "
<< R"(")" << aux_var.get_unary_op() << R"(")" << endl;
break;
output << R"(, "unary_op": ")" << aux_vars[i].unary_op << R"(")";
[[fallthrough]];
case AuxVarType::diff:
case AuxVarType::diffLag:
case AuxVarType::diffLead:
if (aux_var.get_orig_symb_id() >= 0)
output << getTypeSpecificID(aux_var.get_orig_symb_id()) + 1 << ", "
<< aux_var.get_orig_lead_lag() << ", typemin(Int), string(), string()";
if (aux_vars[i].orig_symb_id)
output << R"(, "orig_index": )" << getTypeSpecificID(*aux_vars[i].orig_symb_id) + 1
<< R"(, "orig_lead_lag": )" << aux_vars[i].orig_lead_lag.value();
break;
case AuxVarType::multiplier:
output << "typemin(Int), typemin(Int), " << aux_var.get_equation_number_for_multiplier() + 1
<< ", string(), string()";
output << R"(, "eq_nbr": )" << aux_vars[i].equation_number_for_multiplier + 1;
break;
case AuxVarType::diffForward:
output << getTypeSpecificID(aux_var.get_orig_symb_id())+1 << ", typemin(Int), typemin(Int), string(), string()";
break;
case AuxVarType::expectation:
output << R"(typemin(Int), typemin(Int), typemin(Int), "\mathbb{E}_{t)"
<< (aux_var.get_information_set() < 0 ? "" : "+")
<< aux_var.get_information_set() << "}(";
aux_var.get_expr_node()->writeOutput(output, ExprNodeOutputType::latexDynamicModel);
output << R"lit()")lit";
break;
default:
output << " typemin(Int), typemin(Int), typemin(Int), string(), string()";
}
output << ")" << endl;
}
output << "]" << endl;
}
if (predeterminedNbr() > 0)
if (expr_t orig_expr = aux_vars[i].expr_node; orig_expr)
{
output << "# Predetermined Variables" << endl
<< "model_.pred_vars = [ " << endl;
for (int predetermined_variable : predetermined_variables)
output << " DynareModel.PredVars("
<< getTypeSpecificID(predetermined_variable)+1 << ")" << endl;
output << " ]" << endl;
output << R"(, "orig_expr": ")";
orig_expr->writeJsonOutput(output, {}, {});
output << R"(")";
}
if (observedVariablesNbr() > 0)
{
output << "# Observed Variables" << endl
<< "options_.obs_vars = [" << endl;
for (int varob : varobs)
output << " DynareModel.ObsVars("
<< getTypeSpecificID(varob)+1 << ")" << endl;
output << " ]" << endl;
output << '}' << endl;
}
output << "]" << endl;
}
void
SymbolTable::writeJsonOutput(ostream &output) const
if (!heterogeneity_table.empty())
{
output << R"("endogenous": )";
writeJsonVarVector(output, endo_ids);
output << R"(, "heterogeneous_symbols": [)";
for (int i {0}; i < heterogeneity_table.size(); i++)
{
if (i != 0)
output << ", ";
output << R"({ "dimension": ")" << heterogeneity_table.getName(i)
<< R"(", "endogenous": )";
writeJsonVarVector(output, het_endo_ids.at(i));
output << R"(, "exogenous": )";
writeJsonVarVector(output, exo_ids);
output << R"(, "exogenous_deterministic": )";
writeJsonVarVector(output, exo_det_ids);
writeJsonVarVector(output, het_exo_ids.at(i));
output << R"(, "parameters": )";
writeJsonVarVector(output, param_ids);
writeJsonVarVector(output, het_param_ids.at(i));
output << "}";
}
output << "]" << endl;
}
}
void
......@@ -1088,9 +1165,10 @@ SymbolTable::writeJsonVarVector(ostream &output, const vector<int> &varvec) cons
output << ", ";
output << "{"
<< R"("name":")" << getName(varvec[i]) << R"(", )"
<< R"("texName":")" << boost::replace_all_copy(getTeXName(varvec[i]), R"(\)", R"(\\)") << R"(", )"
<< R"("longName":")" << boost::replace_all_copy(getLongName(varvec[i]), R"(\)", R"(\\)") << R"("})"
<< endl;
<< R"("texName":")" << boost::replace_all_copy(getTeXName(varvec[i]), R"(\)", R"(\\)")
<< R"(", )"
<< R"("longName":")"
<< boost::replace_all_copy(getLongName(varvec[i]), R"(\)", R"(\\)") << R"("})" << endl;
}
output << "]" << endl;
}
......@@ -1109,3 +1187,39 @@ SymbolTable::getUltimateOrigSymbID(int symb_id) const
}
return symb_id;
}
optional<int>
SymbolTable::getEquationNumberForMultiplier(int symb_id) const
{
for (const auto& aux_var : aux_vars)
if (aux_var.symb_id == symb_id && aux_var.type == AuxVarType::multiplier)
return aux_var.equation_number_for_multiplier;
return nullopt;
}
const set<int>&
SymbolTable::getVariablesWithLogTransform() const
{
return with_log_transform;
}
set<int>
SymbolTable::getLagrangeMultipliers() const
{
set<int> r;
for (const auto& aux_var : aux_vars)
if (aux_var.type == AuxVarType::multiplier)
r.insert(aux_var.symb_id);
return r;
}
int
SymbolTable::getHeterogeneityDimension(int symb_id) const
{
validateSymbID(symb_id);
auto it = heterogeneity_dimensions.find(symb_id);
if (it != heterogeneity_dimensions.end())
return it->second;
else
throw NonHeteregeneousSymbolException {symb_id};
}
/*
* Copyright © 2003-2019 Dynare Team
* Copyright © 2003-2024 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,25 +14,25 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef _SYMBOLTABLE_HH
#define _SYMBOLTABLE_HH
using namespace std;
#ifndef SYMBOL_TABLE_HH
#define SYMBOL_TABLE_HH
#include <map>
#include <optional>
#include <ostream>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include <set>
#include <ostream>
#include "CodeInterpreter.hh"
#include "CommonEnums.hh"
#include "ExprNode.hh"
#include "HeterogeneityTable.hh"
using expr_t = class ExprNode *;
using namespace std;
//! Types of auxiliary variables
enum class AuxVarType
......@@ -42,79 +42,63 @@ enum class AuxVarType
exoLead = 2, //!< Substitute for exo leads >= 1
exoLag = 3, //!< Substitute for exo lags >= 1
expectation = 4, //!< Substitute for Expectation Operator
diffForward = 5, //!< Substitute for the differentiate of a forward variable
diffForward = 5, /* Substitute for the differentiate of a forward variable,
for the differentiate_forward_vars option.
N.B.: nothing to do with the diff() operator! */
multiplier = 6, //!< Multipliers for FOC of Ramsey Problem
varModel = 7, //!< Variable for var_model with order > abs(min_lag()) present in model
logTransform = 7, //!< Log-transformation of a variable declared with “var(log)”
diff = 8, //!< Variable for Diff operator
diffLag = 9, //!< Variable for timing between Diff operators
unaryOp = 10, //!< Variable for allowing the undiff operator to work when diff was taken of unary op, eg diff(log(x))
diffLead = 11 //!< Variable for timing between Diff operators
diffLag = 9, //!< Variable for timing between Diff operators (lag)
unaryOp = 10, //!< Variable for allowing the undiff operator to work when diff was taken of unary
//!< op, eg diff(log(x))
diffLead = 11, //!< Variable for timing between Diff operators (lead)
pacExpectation = 12, //!< Variable created for the substitution of the pac_expectation operator
pacTargetNonstationary
= 13, //!< Variable created for the substitution of the pac_target_nonstationary operator
aggregationOp
= 14 // Substitute for an aggregation operator in a heterogeneous setup, such as SUM()
};
//! Information on some auxiliary variables
class AuxVarInfo
{
private:
int symb_id; //!< Symbol ID of the auxiliary variable
AuxVarType type; //!< Its type
int orig_symb_id; //!< Symbol ID of the endo of the original model represented by this aux var. Only used for avEndoLag and avExoLag.
int orig_lead_lag; //!< Lead/lag of the endo of the original model represented by this aux var. Only used for avEndoLag and avExoLag.
int equation_number_for_multiplier; //!< Stores the original constraint equation number associated with this aux var. Only used for avMultiplier.
int information_set; //! Argument of expectation operator. Only used for avExpectation.
expr_t expr_node; //! Auxiliary variable definition
string unary_op; //! Used with AuxUnaryOp
public:
AuxVarInfo(int symb_id_arg, AuxVarType type_arg, int orig_symb_id, int orig_lead_lag, int equation_number_for_multiplier_arg, int information_set_arg, expr_t expr_node_arg, string unary_op_arg);
int
get_symb_id() const
{
return symb_id;
};
AuxVarType
get_type() const
{
return type;
};
int
struct AuxVarInfo
{
const int symb_id; // Symbol ID of the auxiliary variable
const AuxVarType type; // Its type
const optional<int> orig_symb_id; /* Symbol ID of the (only) endo that appears on the RHS of
the definition of this auxvar.
Used by endoLag, exoLag, diffForward, logTransform, diff,
diffLag, diffLead and unaryOp.
For diff and unaryOp, if the argument expression is more
complex than than a simple variable, this value is unset
(hence the need for std::optional). */
const optional<int> orig_lead_lag; /* Lead/lag of the (only) endo as it appears on the RHS of the
definition of this auxvar. Only set if orig_symb_id is set
(in particular, for diff and unaryOp, unset
if orig_symb_id is unset).
For diff and diffForward, since the definition of the
auxvar is a time difference, the value corresponds to the
time index of the first term of that difference. */
const int equation_number_for_multiplier; /* Stores the original constraint equation number
associated with this aux var. Only used for
avMultiplier. */
const int information_set; // Argument of expectation operator. Only used for avExpectation.
const expr_t expr_node; // Auxiliary variable definition
const string unary_op; // Used with AuxUnaryOp
[[nodiscard]] int
get_type_id() const
{
return static_cast<int>(type);
}
int
get_orig_symb_id() const
{
return orig_symb_id;
};
int
get_orig_lead_lag() const
{
return orig_lead_lag;
};
int
get_equation_number_for_multiplier() const
{
return equation_number_for_multiplier;
};
int
get_information_set() const
{
return information_set;
};
expr_t
get_expr_node() const
{
return expr_node;
};
string
get_unary_op() const
{
return unary_op;
};
};
//! Stores the symbol table
/*!
A symbol is given by its name, and is internally represented by a unique integer.
A symbol is given by its name, and is internally represented by a unique
integer, called a symbol ID.
There is a guarantee that symbol IDs are increasing, i.e. if symbol A is
added after symbol B, then the ID of A is greater than the ID of B.
When method freeze() is called, computes a distinct sequence of IDs for some types
(endogenous, exogenous, parameters), which are used by the Matlab/Octave functions.
......@@ -125,6 +109,8 @@ public:
class SymbolTable
{
private:
HeterogeneityTable& heterogeneity_table;
//! Has method freeze() been called?
bool frozen {false};
......@@ -142,9 +128,11 @@ private:
map<int, map<string, string>> partition_value_map;
//! Maps IDs to types
vector<SymbolType> type_table;
// Maps IDs of heterogenous symbols to heterogeneity dimension IDs
map<int, int> heterogeneity_dimensions;
//! Maps symbol IDs to type specific IDs
vector<int> type_specific_ids;
map<int, int> type_specific_ids;
//! Maps type specific IDs of endogenous to symbol IDs
vector<int> endo_ids;
......@@ -154,6 +142,16 @@ private:
vector<int> exo_det_ids;
//! Maps type specific IDs of parameters to symbol IDs
vector<int> param_ids;
/* Maps type specific IDs of heterogeneous endogenous to symbol IDs (outer vector is for
heterogeneity dimensions) */
vector<vector<int>> het_endo_ids;
/* Maps type specific IDs of heterogeneous exogenous to symbol IDs (outer vector is for
heterogeneity dimensions) */
vector<vector<int>> het_exo_ids;
/* Maps type specific IDs of heterogeneous parameters to symbol IDs (outer vector is for
heterogeneity dimensions) */
vector<vector<int>> het_param_ids;
//! Information about auxiliary variables
vector<AuxVarInfo> aux_vars;
......@@ -166,49 +164,42 @@ private:
//! Stores the list of observed exogenous variables
vector<int> varexobs;
//! Stores the endogenous variables declared with “var(log)”
set<int> with_log_transform;
public:
SymbolTable();
//! Thrown when trying to access an unknown symbol (by name)
class UnknownSymbolNameException
struct UnknownSymbolNameException
{
public:
//! Symbol name
string name;
explicit UnknownSymbolNameException(string name_arg) : name{move(name_arg)}
{
}
const string name;
};
//! Thrown when trying to access an unknown symbol (by id)
class UnknownSymbolIDException
struct UnknownSymbolIDException
{
public:
//! Symbol ID
int id;
explicit UnknownSymbolIDException(int id_arg) : id{id_arg}
{
}
const int id;
};
//! Thrown when trying to access an unknown type specific ID
class UnknownTypeSpecificIDException
struct UnknownTypeSpecificIDException
{
public:
int tsid;
SymbolType type;
UnknownTypeSpecificIDException(int tsid_arg, SymbolType type_arg) : tsid{tsid_arg}, type{type_arg}
const int tsid;
const SymbolType type;
const optional<int> heterogeneity_dimension;
};
/* Thrown when requesting the type specific ID of a symbol which doesn’t
have one */
struct NoTypeSpecificIDException
{
}
const int symb_id;
};
//! Thrown when trying to declare a symbol twice
class AlreadyDeclaredException
struct AlreadyDeclaredException
{
public:
//! Symbol name
string name;
const string name;
//! Was the previous declaration done with the same symbol type ?
bool same_type;
AlreadyDeclaredException(string name_arg, bool same_type_arg) : name{move(name_arg)}, same_type{same_type_arg}
{
}
const bool same_type;
};
//! Thrown when table is frozen and trying to modify it
class FrozenException
......@@ -222,30 +213,44 @@ public:
class SearchFailedException
{
public:
int orig_symb_id, orig_lead_lag, symb_id;
SearchFailedException(int orig_symb_id_arg, int orig_lead_lag_arg) : orig_symb_id{orig_symb_id_arg},
orig_lead_lag{orig_lead_lag_arg}
int orig_symb_id, orig_lead_lag;
SearchFailedException(int orig_symb_id_arg, int orig_lead_lag_arg) :
orig_symb_id {orig_symb_id_arg}, orig_lead_lag {orig_lead_lag_arg}
{
}
explicit SearchFailedException(int symb_id_arg) : symb_id{symb_id_arg}
};
// Thrown by getHeterogeneityDimension() on non-heterogeneous symbols
struct NonHeteregeneousSymbolException
{
}
const int id;
};
private:
//! Factorized code for adding aux lag variables
int addLagAuxiliaryVarInternal(bool endo, int orig_symb_id, int orig_lead_lag, expr_t arg) noexcept(false);
int addLagAuxiliaryVarInternal(bool endo, int orig_symb_id, int orig_lead_lag,
expr_t arg) noexcept(false);
//! Factorized code for adding aux lead variables
int addLeadAuxiliaryVarInternal(bool endo, int index, expr_t arg) noexcept(false);
//! Factorized code for Json writing
void writeJsonVarVector(ostream& output, const vector<int>& varvec) const;
//! Factorized code for asserting that 0 <= symb_id <= symbol_table.size()
inline void validateSymbID(int symb_id) const noexcept(false);
public:
SymbolTable(HeterogeneityTable& heterogeneity_table_arg) :
heterogeneity_table {heterogeneity_table_arg}
{
}
//! Add a symbol
/*! Returns the symbol ID */
int addSymbol(const string &name, SymbolType type, const string &tex_name, const vector<pair<string, string>> &partition_value) noexcept(false);
//! Add a symbol without its TeX name (will be equal to its name)
/* Returns the symbol ID.
heterogeneity_dimension must be defined if this is a heterogeneous symbol (otherwise it is
ignored) */
int addSymbol(const string& name, SymbolType type, const string& tex_name,
const vector<pair<string, string>>& partition_value,
const optional<int>& heterogeneity_dimension) noexcept(false);
//! Add a (non-heterogenous) symbol without its TeX name (will be equal to its name)
/*! Returns the symbol ID */
int addSymbol(const string& name, SymbolType type) noexcept(false);
//! Adds an auxiliary variable for endogenous with lead >= 2
......@@ -255,9 +260,9 @@ public:
int addEndoLeadAuxiliaryVar(int index, expr_t arg) noexcept(false);
//! Adds an auxiliary variable for endogenous with lag >= 2
/*!
\param[in] orig_symb_id symbol ID of the endogenous declared by the user that this new variable will represent
\param[in] orig_lead_lag lag value such that this new variable will be equivalent to orig_symb_id(orig_lead_lag)
\return the symbol ID of the new symbol */
\param[in] orig_symb_id symbol ID of the endogenous declared by the user that this new variable
will represent \param[in] orig_lead_lag lag value such that this new variable will be equivalent
to orig_symb_id(orig_lead_lag) \return the symbol ID of the new symbol */
int addEndoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t arg) noexcept(false);
//! Adds an auxiliary variable for endogenous with lead >= 1
/*!
......@@ -266,9 +271,9 @@ public:
int addExoLeadAuxiliaryVar(int index, expr_t arg) noexcept(false);
//! Adds an auxiliary variable for exogenous with lag >= 1
/*!
\param[in] orig_symb_id symbol ID of the exogenous declared by the user that this new variable will represent
\param[in] orig_lead_lag lag value such that this new variable will be equivalent to orig_symb_id(orig_lead_lag)
\return the symbol ID of the new symbol */
\param[in] orig_symb_id symbol ID of the exogenous declared by the user that this new variable
will represent \param[in] orig_lead_lag lag value such that this new variable will be equivalent
to orig_symb_id(orig_lead_lag) \return the symbol ID of the new symbol */
int addExoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t arg) noexcept(false);
//! Adds an auxiliary variable for the expectation operator
/*!
......@@ -283,65 +288,95 @@ public:
\return the symbol ID of the new symbol
*/
int addMultiplierAuxiliaryVar(int index) noexcept(false);
/* Adds an auxiliary variable associated to an endogenous declared with
“var(log)”.
– orig_symb_id is the symbol ID of the original variable
– orig_lead_lag is typically 0
– expr_arg is typically log(orig_symb_id)
*/
int addLogTransformAuxiliaryVar(int orig_symb_id, int orig_lead_lag,
expr_t expr_arg) noexcept(false);
//! Adds an auxiliary variable for the (time) differentiate of a forward var
/*!
\param[in] orig_symb_id The symb_id of the forward variable
\return the symbol ID of the new symbol
*/
int addDiffForwardAuxiliaryVar(int orig_symb_id, expr_t arg) noexcept(false);
int addDiffForwardAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t arg) noexcept(false);
//! Searches auxiliary variables which are substitutes for a given symbol_id and lead/lag
/*!
The search is only performed among auxiliary variables of endo/exo lag.
\return the symbol ID of the auxiliary variable
Throws an exception if match not found.
*/
int searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const noexcept(false);
//! Serches aux_vars for the aux var represented by aux_var_symb_id and returns its associated orig_symb_id
int getOrigSymbIdForAuxVar(int aux_var_symb_id) const noexcept(false);
//! Searches for diff aux var and finds the original lag associated with this variable
int getOrigLeadLagForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false);
//! Searches for diff aux var and finds the symb id associated with this variable
int getOrigSymbIdForDiffAuxVar(int diff_aux_var_symb_id) const noexcept(false);
//! Adds an auxiliary variable when var_model is used with an order that is greater in absolute value
//! than the largest lag present in the model.
int addVarModelEndoLagAuxiliaryVar(int orig_symb_id, int orig_lead_lag, expr_t expr_arg) noexcept(false);
[[nodiscard]] int searchAuxiliaryVars(int orig_symb_id, int orig_lead_lag) const noexcept(false);
/* Searches aux_vars for the aux var represented by aux_var_symb_id and
returns its associated orig_symb_id.
Throws an UnknownSymbolIDException if there is no orig_symb_id associated to
this auxvar (either because it’s of the wrong type, or because there is
no such orig var for this specific auxvar, in case of complex expressions
in diff or unaryOp). */
[[nodiscard]] int getOrigSymbIdForAuxVar(int aux_var_symb_id_arg) const noexcept(false);
/* Unrolls a chain of diffLag or diffLead aux vars until it founds a (regular) diff aux
var. In other words:
- if the arg is a (regu) diff aux var, returns the arg
- if the arg is a diffLag/diffLead, get its orig symb ID, and call the
method recursively
- if the arg is something else, throw an UnknownSymbolIDException
exception
The 2nd input/output arguments are used to track leads/lags. The 2nd
output argument is equal to the 2nd input argument, shifted by as many
lead/lags were encountered in the chain (a diffLag decreases it, a
diffLead increases it). */
[[nodiscard]] pair<int, int> unrollDiffLeadLagChain(int symb_id, int lag) const noexcept(false);
//! Adds an auxiliary variable when the diff operator is encountered
int addDiffAuxiliaryVar(int index, expr_t expr_arg) noexcept(false);
int addDiffAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false);
int addDiffAuxiliaryVar(int index, expr_t expr_arg, const optional<int>& orig_symb_id = nullopt,
const optional<int>& orig_lag = nullopt) noexcept(false);
//! Takes care of timing between diff statements
int addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lag) noexcept(false);
int addDiffLagAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id,
int orig_lag) noexcept(false);
//! Takes care of timing between diff statements
int addDiffLeadAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id, int orig_lead) noexcept(false);
int addDiffLeadAuxiliaryVar(int index, expr_t expr_arg, int orig_symb_id,
int orig_lead) noexcept(false);
//! An Auxiliary variable for a unary op
int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op, int orig_symb_id = -1, int orig_lag = 0) noexcept(false);
int addUnaryOpAuxiliaryVar(int index, expr_t expr_arg, string unary_op,
const optional<int>& orig_symb_id = nullopt,
const optional<int>& orig_lag = nullopt) noexcept(false);
//! An auxiliary variable for a pac_expectation operator
int addPacExpectationAuxiliaryVar(const string& name, expr_t expr_arg);
//! An auxiliary variable for a pac_target_nonstationary operator
int addPacTargetNonstationaryAuxiliaryVar(const string& name, expr_t expr_arg);
// An auxiliary variable for an aggregation operator (e.g. SUM(yh) where yh is heterogeneous)
int addAggregationOpAuxiliaryVar(const string& name, expr_t expr_arg);
//! Returns the number of auxiliary variables
int
[[nodiscard]] int
AuxVarsSize() const
{
return aux_vars.size();
};
//! Retruns expr_node for an auxiliary variable
expr_t getAuxiliaryVarsExprNode(int symb_id) const noexcept(false);
}
//! Tests if symbol already exists
inline bool exists(const string &name) const;
[[nodiscard]] inline bool exists(const string& name) const;
//! Get symbol name (by ID)
inline string getName(int id) const noexcept(false);
[[nodiscard]] inline string getName(int id) const noexcept(false);
//! Get TeX name
inline string getTeXName(int id) const noexcept(false);
[[nodiscard]] inline string getTeXName(int id) const noexcept(false);
//! Get long name
inline string getLongName(int id) const noexcept(false);
//! Returns true if the partition name is the first encountered for the type of variable represented by id
bool isFirstOfPartitionForType(int id) const noexcept(false);
[[nodiscard]] inline string getLongName(int id) const noexcept(false);
//! Returns true if the partition name is the first encountered for the type of variable
//! represented by id
[[nodiscard]] bool isFirstOfPartitionForType(int id) const noexcept(false);
//! Returns a list of partitions and symbols that belong to that partition
map<string, map<int, string>> getPartitionsForType(SymbolType st) const noexcept(false);
[[nodiscard]] map<string, map<int, string>> getPartitionsForType(SymbolType st) const
noexcept(false);
//! Get type (by ID)
inline SymbolType getType(int id) const noexcept(false);
[[nodiscard]] inline SymbolType getType(int id) const noexcept(false);
//! Get type (by name)
inline SymbolType getType(const string &name) const noexcept(false);
[[nodiscard]] inline SymbolType getType(const string& name) const noexcept(false);
//! Get ID (by name)
inline int getID(const string &name) const noexcept(false);
[[nodiscard]] inline int getID(const string& name) const noexcept(false);
//! Get ID (by type specific ID)
int getID(SymbolType type, int tsid) const noexcept(false);
[[nodiscard]] int getID(SymbolType type, int tsid,
const optional<int>& heterogeneity_dimension = nullopt) const
noexcept(false);
//! Freeze symbol table
void freeze() noexcept(false);
//! unreeze symbol table
......@@ -350,83 +385,102 @@ public:
//! Change the type of a symbol
void changeType(int id, SymbolType newtype) noexcept(false);
//! Get type specific ID (by symbol ID)
inline int getTypeSpecificID(int id) const noexcept(false);
[[nodiscard]] inline int getTypeSpecificID(int id) const noexcept(false);
//! Get type specific ID (by symbol name)
inline int getTypeSpecificID(const string &name) const noexcept(false);
[[nodiscard]] inline int getTypeSpecificID(const string& name) const noexcept(false);
//! Get number of endogenous variables
inline int endo_nbr() const noexcept(false);
[[nodiscard]] inline int endo_nbr() const noexcept(false);
//! Get number of exogenous variables
inline int exo_nbr() const noexcept(false);
[[nodiscard]] inline int exo_nbr() const noexcept(false);
//! Get number of exogenous deterministic variables
inline int exo_det_nbr() const noexcept(false);
[[nodiscard]] inline int exo_det_nbr() const noexcept(false);
//! Get number of parameters
inline int param_nbr() const noexcept(false);
[[nodiscard]] inline int param_nbr() const noexcept(false);
//! Get number of heterogeneous endogenous variables along a given dimension
[[nodiscard]] inline int het_endo_nbr(int het_dim) const noexcept(false);
//! Get number of heterogeneous exogenous variables along a given dimension
[[nodiscard]] inline int het_exo_nbr(int het_dim) const noexcept(false);
//! Get number of heterogeneous parameters along a given dimension
[[nodiscard]] inline int het_param_nbr(int het_dim) const noexcept(false);
//! Returns the greatest symbol ID (the smallest is zero)
inline int maxID();
[[nodiscard]] inline int maxID() const;
//! Get number of user-declared endogenous variables (without the auxiliary variables)
inline int orig_endo_nbr() const noexcept(false);
[[nodiscard]] inline int orig_endo_nbr() const noexcept(false);
//! Write output of this class
void writeOutput(ostream& output) const noexcept(false);
//! Write JSON Output
void writeJsonOutput(ostream& output) const;
//! Write Julia output of this class
void writeJuliaOutput(ostream &output) const noexcept(false);
//! Mark a symbol as predetermined variable
void markPredetermined(int symb_id) noexcept(false);
//! Mark an endogenous as having been declared with “var(log)”
void markWithLogTransform(int symb_id) noexcept(false);
//! Test if a given symbol is a predetermined variable
bool isPredetermined(int symb_id) const noexcept(false);
[[nodiscard]] bool isPredetermined(int symb_id) const noexcept(false);
//! Return the number of predetermined variables
int predeterminedNbr() const;
[[nodiscard]] int predeterminedNbr() const;
//! Add an observed variable
void addObservedVariable(int symb_id) noexcept(false);
//! Return the number of observed variables
int observedVariablesNbr() const;
[[nodiscard]] int observedVariablesNbr() const;
//! Is a given symbol in the set of observed variables
bool isObservedVariable(int symb_id) const;
[[nodiscard]] bool isObservedVariable(int symb_id) const;
//! Return the index of a given observed variable in the vector of all observed variables
int getObservedVariableIndex(int symb_id) const;
[[nodiscard]] int getObservedVariableIndex(int symb_id) const;
//! Add an observed exogenous variable
void addObservedExogenousVariable(int symb_id) noexcept(false);
//! Return the number of observed exogenous variables
int observedExogenousVariablesNbr() const;
[[nodiscard]] int observedExogenousVariablesNbr() const;
//! Is a given symbol in the set of observed exogenous variables
bool isObservedExogenousVariable(int symb_id) const;
//! Return the index of a given observed exogenous variable in the vector of all observed variables
int getObservedExogenousVariableIndex(int symb_id) const;
vector <int> getTrendVarIds() const;
[[nodiscard]] bool isObservedExogenousVariable(int symb_id) const;
//! Return the index of a given observed exogenous variable in the vector of all observed
//! variables
[[nodiscard]] int getObservedExogenousVariableIndex(int symb_id) const;
[[nodiscard]] vector<int> getTrendVarIds() const;
//! Get list of exogenous variables
set <int> getExogenous() const;
[[nodiscard]] set<int> getExogenous() const;
//! Get list of exogenous variables
set <int> getObservedExogenous() const;
[[nodiscard]] set<int> getObservedExogenous() const;
//! Get list of endogenous variables
set <int> getEndogenous() const;
[[nodiscard]] set<int> getEndogenous() const;
//! Is a given symbol an auxiliary variable
bool isAuxiliaryVariable(int symb_id) const;
//! Is a given symbol an auxiliary variable but not a Lagrange multiplier
bool isAuxiliaryVariableButNotMultiplier(int symb_id) const;
[[nodiscard]] bool isAuxiliaryVariable(int symb_id) const;
//! Is a given symbol a diff, diff lead, or diff lag auxiliary variable
bool isDiffAuxiliaryVariable(int symb_id) const;
[[nodiscard]] bool isDiffAuxiliaryVariable(int symb_id) const;
//! Get list of endogenous variables without aux vars
set <int> getOrigEndogenous() const;
[[nodiscard]] set<int> getOrigEndogenous() const;
//! Returns the original symbol corresponding to this variable
/* If symb_id is not an auxiliary var, returns symb_id. Otherwise,
repeatedly call getOrigSymbIDForAuxVar() until an original
(non-auxiliary) variable is found. */
int getUltimateOrigSymbID(int symb_id) const;
/* If symb_id has no original variable, returns symb_id. Otherwise,
repeatedly call getOrigSymbIDForAuxVar() until an original variable is
found. Note that the result may be an auxiliary variable if the latter has
no original variable (e.g. aux var for lead, Lagrange Multiplier or diff
associated to a complex expression). */
[[nodiscard]] int getUltimateOrigSymbID(int symb_id) const;
//! If this is a Lagrange multiplier, return its associated equation number; otherwise return
//! nullopt
[[nodiscard]] optional<int> getEquationNumberForMultiplier(int symb_id) const;
/* Return all the information about a given auxiliary variable. Throws
UnknownSymbolIDException if it is not an aux var */
[[nodiscard]] const AuxVarInfo& getAuxVarInfo(int symb_id) const;
// Returns the set of all endogenous declared with “var(log)”
[[nodiscard]] const set<int>& getVariablesWithLogTransform() const;
// Returns all Lagrange multipliers
[[nodiscard]] set<int> getLagrangeMultipliers() const;
/* Get heterogeneity dimension of a given symbol. Throws NonHeterogeneousSymbolException
if there is no such dimension. */
[[nodiscard]] int getHeterogeneityDimension(int symb_id) const;
};
inline void
SymbolTable::validateSymbID(int symb_id) const noexcept(false)
{
if (symb_id < 0 || symb_id > static_cast<int>(symbol_table.size()))
throw UnknownSymbolIDException(symb_id);
throw UnknownSymbolIDException {symb_id};
}
inline bool
SymbolTable::exists(const string& name) const
{
auto iter = symbol_table.find(name);
return (iter != symbol_table.end());
return symbol_table.contains(name);
}
inline string
......@@ -466,11 +520,10 @@ SymbolTable::getType(const string &name) const noexcept(false)
inline int
SymbolTable::getID(const string& name) const noexcept(false)
{
auto iter = symbol_table.find(name);
if (iter != symbol_table.end())
if (auto iter = symbol_table.find(name); iter != symbol_table.end())
return iter->second;
else
throw UnknownSymbolNameException(name);
throw UnknownSymbolNameException {name};
}
inline int
......@@ -481,7 +534,10 @@ SymbolTable::getTypeSpecificID(int id) const noexcept(false)
validateSymbID(id);
return type_specific_ids[id];
if (auto it = type_specific_ids.find(id); it != type_specific_ids.end())
return it->second;
else
throw NoTypeSpecificIDException {id};
}
inline int
......@@ -527,7 +583,34 @@ SymbolTable::param_nbr() const noexcept(false)
}
inline int
SymbolTable::maxID()
SymbolTable::het_endo_nbr(int het_dim) const noexcept(false)
{
if (!frozen)
throw NotYetFrozenException();
return het_endo_ids.at(het_dim).size();
}
inline int
SymbolTable::het_exo_nbr(int het_dim) const noexcept(false)
{
if (!frozen)
throw NotYetFrozenException();
return het_exo_ids.at(het_dim).size();
}
inline int
SymbolTable::het_param_nbr(int het_dim) const noexcept(false)
{
if (!frozen)
throw NotYetFrozenException();
return het_param_ids.at(het_dim).size();
}
inline int
SymbolTable::maxID() const
{
return symbol_table.size() - 1;
}
......@@ -535,7 +618,16 @@ SymbolTable::maxID()
inline int
SymbolTable::orig_endo_nbr() const noexcept(false)
{
return (endo_nbr() - aux_vars.size());
return endo_nbr() - aux_vars.size();
}
inline const AuxVarInfo&
SymbolTable::getAuxVarInfo(int symb_id) const
{
for (const auto& aux_var : aux_vars)
if (aux_var.symb_id == symb_id)
return aux_var;
throw UnknownSymbolIDException {symb_id};
}
#endif
/*
* Copyright © 2009-2020 Dynare Team
*
* This file is part of Dynare.
*
* Dynare is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dynare is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#include <algorithm>
#include <iostream>
#include "VariableDependencyGraph.hh"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wold-style-cast"
#pragma GCC diagnostic ignored "-Wsign-compare"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include <boost/graph/strong_components.hpp>
#include <boost/graph/topological_sort.hpp>
#pragma GCC diagnostic pop
#include <ranges>
using namespace boost;
VariableDependencyGraph::VariableDependencyGraph(int n) : base(n)
{
/* It is necessary to manually initialize the vertex_index property since
this graph uses listS and not vecS as underlying vertex container */
auto v_index = get(vertex_index, *this);
for (int i = 0; i < n; i++)
put(v_index, vertex(i, *this), i);
}
void
VariableDependencyGraph::suppress(vertex_descriptor vertex_to_eliminate)
{
clear_vertex(vertex_to_eliminate, *this);
remove_vertex(vertex_to_eliminate, *this);
}
void
VariableDependencyGraph::suppress(int vertex_num)
{
suppress(vertex(vertex_num, *this));
}
void
VariableDependencyGraph::eliminate(vertex_descriptor vertex_to_eliminate)
{
if (in_degree(vertex_to_eliminate, *this) > 0 && out_degree(vertex_to_eliminate, *this) > 0)
for (auto [it_in, in_end] = in_edges(vertex_to_eliminate, *this); it_in != in_end; ++it_in)
for (auto [it_out, out_end] = out_edges(vertex_to_eliminate, *this); it_out != out_end;
++it_out)
if (auto [ed, exist] = edge(source(*it_in, *this), target(*it_out, *this), *this); !exist)
add_edge(source(*it_in, *this), target(*it_out, *this), *this);
suppress(vertex_to_eliminate);
}
bool
VariableDependencyGraph::hasCycleDFS(vertex_descriptor u, color_t& color,
vector<int>& circuit_stack) const
{
auto v_index = get(vertex_index, *this);
color[u] = gray_color;
for (auto [vi, vi_end] = out_edges(u, *this); vi != vi_end; ++vi)
if (color[target(*vi, *this)] == white_color
&& hasCycleDFS(target(*vi, *this), color, circuit_stack))
{
// cycle detected, return immediately
circuit_stack.push_back(v_index[target(*vi, *this)]);
return true;
}
else if (color[target(*vi, *this)] == gray_color)
{
// *vi is an ancestor!
circuit_stack.push_back(v_index[target(*vi, *this)]);
return true;
}
color[u] = black_color;
return false;
}
bool
VariableDependencyGraph::hasCycle() const
{
// Initialize color map to white
color_t color;
vector<int> circuit_stack;
for (auto [vi, vi_end] = vertices(*this); vi != vi_end; ++vi)
color[*vi] = white_color;
// Perform depth-first search
for (auto [vi, vi_end] = vertices(*this); vi != vi_end; ++vi)
if (color[*vi] == white_color && hasCycleDFS(*vi, color, circuit_stack))
return true;
return false;
}
void
VariableDependencyGraph::print() const
{
auto v_index = get(vertex_index, *this);
cout << "Graph\n"
<< "-----\n";
for (auto [it, it_end] = vertices(*this); it != it_end; ++it)
{
cout << "vertex[" << v_index[*it] + 1 << "] <-";
for (auto [it_in, in_end] = in_edges(*it, *this); it_in != in_end; ++it_in)
cout << v_index[source(*it_in, *this)] + 1 << " ";
cout << "\n ->";
for (auto [it_out, out_end] = out_edges(*it, *this); it_out != out_end; ++it_out)
cout << v_index[target(*it_out, *this)] + 1 << " ";
cout << "\n";
}
}
VariableDependencyGraph
VariableDependencyGraph::extractSubgraph(const vector<int>& select_index) const
{
int n = select_index.size();
VariableDependencyGraph G(n);
auto v_index = get(vertex_index, *this);
auto v_index1_G = get(vertex_index1, G); // Maps new vertices to original indices
map<int, int> reverse_index; // Maps orig indices to new ones
for (int i = 0; i < n; i++)
{
reverse_index[select_index[i]] = i;
v_index1_G[vertex(i, G)] = select_index[i];
}
for (int i = 0; i < n; i++)
{
auto vi = vertex(select_index[i], *this);
for (auto [it_out, out_end] = out_edges(vi, *this); it_out != out_end; ++it_out)
if (auto it = reverse_index.find(v_index[target(*it_out, *this)]);
it != reverse_index.end())
add_edge(vertex(i, G), vertex(it->second, G), G);
}
return G;
}
bool
VariableDependencyGraph::vertexBelongsToAClique(vertex_descriptor vertex) const
{
vector<vertex_descriptor> liste;
bool agree = true;
auto [it_in, in_end] = in_edges(vertex, *this);
auto [it_out, out_end] = out_edges(vertex, *this);
while (it_in != in_end && it_out != out_end && agree)
{
agree = (source(*it_in, *this) == target(*it_out, *this)
&& source(*it_in, *this) != target(*it_in, *this)); // not a loop
liste.push_back(source(*it_in, *this));
++it_in;
++it_out;
}
if (agree)
{
if (it_in != in_end || it_out != out_end)
agree = false;
int i = 1;
while (i < static_cast<int>(liste.size()) && agree)
{
int j = i + 1;
while (j < static_cast<int>(liste.size()) && agree)
{
auto [ed1, exist1] = edge(liste[i], liste[j], *this);
auto [ed2, exist2] = edge(liste[j], liste[i], *this);
agree = exist1 && exist2;
j++;
}
i++;
}
}
return agree;
}
bool
VariableDependencyGraph::eliminationOfVerticesWithOneOrLessIndegreeOrOutdegree()
{
bool something_has_been_done = false;
bool not_a_loop;
int i;
vertex_iterator it, ita, it_end;
for (tie(it, it_end) = vertices(*this), i = 0; it != it_end; ++it, i++)
{
int in_degree_n = in_degree(*it, *this);
int out_degree_n = out_degree(*it, *this);
if (in_degree_n <= 1 || out_degree_n <= 1)
{
not_a_loop = true;
if (in_degree_n >= 1
&& out_degree_n >= 1) // Do not eliminate a vertex if it loops on itself!
for (auto [it_in, in_end] = in_edges(*it, *this); it_in != in_end; ++it_in)
if (source(*it_in, *this) == target(*it_in, *this))
not_a_loop = false;
if (not_a_loop)
{
eliminate(*it);
something_has_been_done = true;
if (i > 0)
it = ita;
else
{
tie(it, it_end) = vertices(*this);
i--;
}
}
}
ita = it;
}
return something_has_been_done;
}
bool
VariableDependencyGraph::eliminationOfVerticesBelongingToAClique()
{
vertex_iterator it, ita, it_end;
bool something_has_been_done = false;
int i;
for (tie(it, it_end) = vertices(*this), i = 0; it != it_end; ++it, i++)
{
if (vertexBelongsToAClique(*it))
{
eliminate(*it);
something_has_been_done = true;
if (i > 0)
it = ita;
else
{
tie(it, it_end) = vertices(*this);
i--;
}
}
ita = it;
}
return something_has_been_done;
}
bool
VariableDependencyGraph::suppressionOfVerticesWithLoop(set<int>& feed_back_vertices)
{
bool something_has_been_done = false;
vertex_iterator ita;
int i = 0;
for (auto [it, it_end] = vertices(*this); it != it_end; ++it, i++)
{
auto [ed, exist] = edge(*it, *it, *this);
if (exist)
{
auto v_index = get(vertex_index, *this);
feed_back_vertices.insert(v_index[*it]);
suppress(*it);
something_has_been_done = true;
if (i > 0)
it = ita;
else
{
tie(it, it_end) = vertices(*this);
i--;
}
}
ita = it;
}
return something_has_been_done;
}
set<int>
VariableDependencyGraph::minimalSetOfFeedbackVertices() const
{
set<int> feed_back_vertices;
VariableDependencyGraph G(*this);
while (num_vertices(G) > 0)
{
bool something_has_been_done = true;
while (something_has_been_done && num_vertices(G) > 0)
{
something_has_been_done = G.eliminationOfVerticesWithOneOrLessIndegreeOrOutdegree();
something_has_been_done
= G.eliminationOfVerticesBelongingToAClique() || something_has_been_done;
something_has_been_done
= G.suppressionOfVerticesWithLoop(feed_back_vertices) || something_has_been_done;
}
if (!G.hasCycle())
return feed_back_vertices;
if (num_vertices(G) > 0)
{
/* If nothing has been done in the five previous rule then cut the
vertex with the maximum in_degree+out_degree */
int max_degree = 0, num = 0;
vertex_iterator max_degree_index;
for (auto [it, it_end] = vertices(G); it != it_end; ++it, num++)
if (static_cast<int>(in_degree(*it, G) + out_degree(*it, G)) > max_degree)
{
max_degree = in_degree(*it, G) + out_degree(*it, G);
max_degree_index = it;
}
auto v_index = get(vertex_index, G);
feed_back_vertices.insert(v_index[*max_degree_index]);
G.suppress(*max_degree_index);
}
}
return feed_back_vertices;
}
vector<int>
VariableDependencyGraph::reorderRecursiveVariables(const set<int>& feedback_vertices) const
{
vector<int> reordered_vertices;
VariableDependencyGraph G(*this);
auto v_index = get(vertex_index, G);
// Suppress feedback vertices, in decreasing order
for (int feedback_vertex : ranges::reverse_view(feedback_vertices))
G.suppress(feedback_vertex);
bool something_has_been_done = true;
while (something_has_been_done)
{
something_has_been_done = false;
vertex_iterator it, it_end, ita;
int i;
for (tie(it, it_end) = vertices(G), i = 0; it != it_end; ++it, i++)
{
if (in_degree(*it, G) == 0)
{
reordered_vertices.push_back(v_index[*it]);
G.suppress(*it);
something_has_been_done = true;
if (i > 0)
it = ita;
else
{
tie(it, it_end) = vertices(G);
i--;
}
}
ita = it;
}
}
if (num_vertices(G))
cout << "Error in the computation of feedback vertex set\n";
return reordered_vertices;
}
pair<int, vector<int>>
VariableDependencyGraph::sortedStronglyConnectedComponents() const
{
vector<int> vertex2scc(num_vertices(*this));
auto v_index = get(vertex_index, *this);
// Compute SCCs and create mapping from vertices to unordered SCCs
int num_scc = strong_components(static_cast<base>(*this),
make_iterator_property_map(vertex2scc.begin(), v_index));
// Create directed acyclic graph (DAG) associated to the SCCs
adjacency_list<vecS, vecS, directedS> dag(num_scc);
for (int i = 0; i < static_cast<int>(num_vertices(*this)); i++)
{
auto vi = vertex(i, *this);
for (auto [it_out, out_end] = out_edges(vi, *this); it_out != out_end; ++it_out)
if (int t_b = vertex2scc[v_index[target(*it_out, *this)]],
s_b = vertex2scc[v_index[source(*it_out, *this)]];
s_b != t_b)
add_edge(s_b, t_b, dag);
}
/* Compute topological sort of DAG (ordered list of unordered SCC)
Note: the order is reversed. */
vector<int> reverseOrdered2unordered;
topological_sort(dag, back_inserter(reverseOrdered2unordered));
// Construct mapping from unordered SCC to ordered SCC
vector<int> unordered2ordered(num_scc);
for (int j = 0; j < num_scc; j++)
unordered2ordered[reverseOrdered2unordered[num_scc - j - 1]] = j;
// Update the mapping of vertices to (now sorted) SCCs
for (int i = 0; i < static_cast<int>(num_vertices(*this)); i++)
vertex2scc[i] = unordered2ordered[vertex2scc[i]];
return {num_scc, vertex2scc};
}
/*
* Copyright © 2009-2023 Dynare Team
*
* This file is part of Dynare.
*
* Dynare is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dynare is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef VARIABLE_DEPENDENCY_GRAPH_HH
#define VARIABLE_DEPENDENCY_GRAPH_HH
#include <map>
#include <vector>
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wold-style-cast"
#include <boost/graph/adjacency_list.hpp>
#pragma GCC diagnostic pop
using namespace std;
using VertexProperty_t = boost::property<
boost::vertex_index_t, int,
boost::property<
boost::vertex_index1_t, int,
boost::property<boost::vertex_degree_t, int,
boost::property<boost::vertex_in_degree_t, int,
boost::property<boost::vertex_out_degree_t, int>>>>>;
/* Class used to store a graph representing dependencies between variables.
Used in the block decomposition. */
class VariableDependencyGraph :
public boost::adjacency_list<boost::listS, boost::listS, boost::bidirectionalS,
VertexProperty_t>
{
public:
using color_t = map<boost::graph_traits<VariableDependencyGraph>::vertex_descriptor,
boost::default_color_type>;
using base
= boost::adjacency_list<boost::listS, boost::listS, boost::bidirectionalS, VertexProperty_t>;
VariableDependencyGraph(int n);
//! Extracts a subgraph
/*!
\param[in] select_index The vertex indices to select
\return The subgraph
The property vertex_index1 of the subgraph contains indices of the original
graph.
*/
[[nodiscard]] VariableDependencyGraph extractSubgraph(const vector<int>& select_index) const;
//! Return the feedback set
[[nodiscard]] set<int> minimalSetOfFeedbackVertices() const;
//! Reorder the recursive variables
/*! They appear first in a quasi triangular form and they are followed by the feedback variables
*/
[[nodiscard]] vector<int> reorderRecursiveVariables(const set<int>& feedback_vertices) const;
/* Computes the strongly connected components (SCCs) of the graph, and sort them
topologically.
Returns the number of SCCs, and a mapping of vertex indices to sorted SCC
indices. */
[[nodiscard]] pair<int, vector<int>> sortedStronglyConnectedComponents() const;
// Print on stdout a description of the graph
void print() const;
private:
// Remove a vertex (including all edges to and from it); takes a vertex descriptor
void suppress(vertex_descriptor vertex_to_eliminate);
// Remove a vertex (including all edges to and from it); takes a vertex index
void suppress(int vertex_num);
/* Remove a vertex, but keeping the paths that go through it (i.e. by adding
edges that directly connect vertices that would otherwise be connected
through the vertex to be removed) */
void eliminate(vertex_descriptor vertex_to_eliminate);
// Internal helper for hasCycle()
bool hasCycleDFS(vertex_descriptor u, color_t& color, vector<int>& circuit_stack) const;
// Determine whether the graph has a cycle
[[nodiscard]] bool hasCycle() const;
bool vertexBelongsToAClique(vertex_descriptor vertex) const;
bool eliminationOfVerticesWithOneOrLessIndegreeOrOutdegree();
bool eliminationOfVerticesBelongingToAClique();
// The suppressed vertices are stored in feedback set
bool suppressionOfVerticesWithLoop(set<int>& feed_back_vertices);
};
#endif
/*
* Copyright © 2012-2017 Dynare Team
* Copyright © 2012-2024 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,46 +14,37 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#include "WarningConsolidation.hh"
#include <ostream>
WarningConsolidation
&
operator<<(WarningConsolidation &wcc, const string &warning)
ostream&
operator<<(ostream& stream, const Dynare::location& l)
{
if (wcc.no_warn)
return wcc;
stream << *l.begin.filename << ": line " << l.begin.line;
if (l.begin.line == l.end.line)
if (l.begin.column == l.end.column - 1)
stream << ", col " << l.begin.column;
else
stream << ", cols " << l.begin.column << "-" << l.end.column - 1;
else
stream << ", col " << l.begin.column << " -"
<< " line " << l.end.line << ", col " << l.end.column - 1;
cerr << warning;
wcc.addWarning(warning);
return wcc;
};
return stream;
}
WarningConsolidation &
operator<<(WarningConsolidation &wcc, const Dynare::location &loc)
void
WarningConsolidation::incrementWarnings(const string& msg)
{
if (wcc.no_warn)
return wcc;
stringstream ostr;
Dynare::position last = loc.end - 1;
ostr << loc.begin;
if (last.filename
&& (!loc.begin.filename
|| *loc.begin.filename != *last.filename))
ostr << '-' << last;
else if (loc.begin.line != last.line)
ostr << '-' << last.line << '.' << last.column;
else if (loc.begin.column != last.column)
ostr << '-' << last.column;
cerr << ostr.str();
wcc.addWarning(ostr.str());
return wcc;
};
size_t p {0};
while ((p = msg.find('\n', p)) != string::npos)
{
p++;
num_warnings++;
}
}
WarningConsolidation&
operator<<(WarningConsolidation& wcc, ostream& (*pf)(ostream&))
......@@ -61,49 +52,12 @@ operator<<(WarningConsolidation &wcc, ostream & (*pf)(ostream &))
if (wcc.no_warn)
return wcc;
cerr << pf;
wcc.addWarning(pf);
return wcc;
}
void
WarningConsolidation::writeOutput(ostream &output) const
{
if (warnings.str().empty())
return;
ostringstream ostr;
ostr << pf;
output << "disp([char(10) 'Dynare Preprocessor Warning(s) Encountered:']);" << endl;
cerr << ostr.str();
bool writedisp = true;
string warningsstr = warnings.str();
for (size_t i = 0; i < warningsstr.length(); i++)
{
if (writedisp)
{
output << "disp(' ";
writedisp = false;
}
wcc.incrementWarnings(ostr.str());
if (warningsstr[i] != '\n')
output << warningsstr[i];
else
{
output << "');" << endl;
if (i+1 < warningsstr.length())
writedisp = true;
}
}
}
int
WarningConsolidation::countWarnings() const
{
size_t p = 0;
int n = 0;
while ((p = warnings.str().find('\n', p)) != string::npos)
{
p++;
n++;
}
return n;
return wcc;
}
/*
* Copyright © 2012-2017 Dynare Team
* Copyright © 2012-2024 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,52 +14,71 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef _WARNINGCONSOLIDATION_HH
#define _WARNINGCONSOLIDATION_HH
#ifndef WARNING_CONSOLIDATION_HH
#define WARNING_CONSOLIDATION_HH
#include "DynareBisonLocation.hh"
#include <iostream>
#include <ostream>
#include <sstream>
#include <string>
#include "location.hh"
using namespace std;
//! Stores Warnings issued by the Preprocessor
/* Provide our implementation of operator<< with locations in DynareBison.hh. Note that the
following is a template specialization of the version provided in DynareBisonLocation.hh.
Ideally it should go into DynareBisonLocation.hh, but there does not seem to be a way to achieve
that. */
ostream& operator<<(ostream& stream, const Dynare::location& l);
class WarningConsolidation
{
private:
stringstream warnings;
bool no_warn;
const bool no_warn;
int num_warnings {0};
// Increases the warning counter by as many newlines as there are in the message
void incrementWarnings(const string& msg);
public:
explicit WarningConsolidation(bool no_warn_arg) : no_warn {no_warn_arg}
{
};
}
// Generic function to print something to the warning stream
friend WarningConsolidation& operator<<(WarningConsolidation& wcc, auto&& warning);
//! Add A Warning to the StringStream
friend WarningConsolidation &operator<<(WarningConsolidation &wcc, const string &warning);
friend WarningConsolidation &operator<<(WarningConsolidation &wcc, const Dynare::location &loc);
/* Print std::endl to the warning stream. Unfortunately, since std::endl is a template of
functions, it cannot be bound to the universal reference of the generic function, hence the
need for this specialization. */
friend WarningConsolidation& operator<<(WarningConsolidation& wcc, ostream& (*pf)(ostream&));
inline void
addWarning(const string &w)
int
numWarnings() const
{
warnings << w;
return num_warnings;
}
};
inline void
addWarning(ostream & (*pf)(ostream &))
WarningConsolidation&
operator<<(WarningConsolidation& wcc, auto&& warning)
{
warnings << pf;
};
if (wcc.no_warn)
return wcc;
//! Write Warnings to m file
void writeOutput(ostream &output) const;
//! Count warnings
/*! This is done in a very lousy way, by counting newlines in the
stringstream... */
int countWarnings() const;
};
ostringstream ostr;
ostr << warning;
cerr << ostr.str();
wcc.incrementWarnings(ostr.str());
return wcc;
}
#endif
# GDB pretty-printer for ExprNode class hierarchy
# Copyright © 2022-2023 Dynare Team
#
# This file is part of Dynare.
#
# Dynare is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Dynare is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Dynare. If not, see <https://www.gnu.org/licenses/>.
class ExprNodePrinter:
'''Pretty-prints an ExprNode value'''
def __init__(self, val):
self.val = val
def to_string(self):
# Call the toString() method on the pointer.
# We must use the raw pretty-printer for the pointer itself, otherwise
# we enter an infinite loop.
# We retrieve a C string, because with C++ strings the gdb pretty-printer
# insists on keeping the quotes around it.
r = gdb.parse_and_eval("((ExprNode *) " + self.val.format_string(raw = True) + ")->toString().c_str()")
typestr = "(" + str(self.val.type) + ") ";
# Add dynamic type information between brackets, if different from static type
if str(self.val.type) != str(self.val.dynamic_type):
typestr += "[" + str(self.val.dynamic_type) + "] "
return typestr + r.string()
class ExprNodePrinterControl(gdb.printing.PrettyPrinter):
'''Determines whether a value can be pretty printed with ExprNodePrinter. To be directly registered within the GDB API.'''
def __init__(self):
# The name below will appear in “info pretty-printer”, and can be used with “enable/disable pretty-printer”
super().__init__('ExprNode')
def __call__(self, val):
# Check if the value is a subtype of ExprNode *.
# Doing a dynamic_cast on a non-pointer type triggers an exception, so we first check
# whether it’s a pointer (after resolving for typedefs, such as “expr_t”).
if val.type.strip_typedefs().code == gdb.TYPE_CODE_PTR and val.dynamic_cast(gdb.lookup_type('ExprNode').pointer()) != 0:
return ExprNodePrinter(val)
# Register the pretty printer
gdb.pretty_printers.append(ExprNodePrinterControl())
/*
* Copyright © 2019 Dynare Team
* Copyright © 2019-2023 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,22 +14,23 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#include "Directives.hh"
#include "Driver.hh"
#include <fstream>
#include <utility>
using namespace macro;
void
Eval::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths)
Eval::interpret(ostream& output, Environment& env, [[maybe_unused]] vector<filesystem::path>& paths)
{
try
{
output << expr->eval()->to_string();
output << expr->eval(env)->to_string();
}
catch (StackTrace& ex)
{
......@@ -43,12 +44,12 @@ Eval::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &p
}
void
Include::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths)
Include::interpret(ostream& output, Environment& env, vector<filesystem::path>& paths)
{
using namespace filesystem;
try
{
StringPtr msp = dynamic_pointer_cast<String>(expr->eval());
StringPtr msp = dynamic_pointer_cast<String>(expr->eval(env));
if (!msp)
throw StackTrace("File name does not evaluate to a string");
path filename = msp->to_string();
......@@ -67,15 +68,18 @@ Include::interpret(ostream &output, bool no_line_macro, vector<filesystem::path>
errmsg << " * " << current_path().string() << endl;
for (const auto& dir : paths)
errmsg << " * " << absolute(dir).string() << endl;
error(StackTrace("@#includepath", "Could not open " + filename.string() +
". The following directories were searched:\n" + errmsg.str(), location));
error(StackTrace("@#include",
"Could not open " + filename.string()
+ ". The following directories were searched:\n" + errmsg.str(),
location));
}
}
Driver m(env, no_line_macro);
// Calling `string()` method on filename and filename.stem() because of bug in
// MinGW 8.3.0 that ignores implicit conversion to string from filename::path.
// Test if bug exists when version of MinGW is upgraded on Debian runners
m.parse(filename.string(), filename.stem().string(), incfile, output, false, vector<pair<string, string>>{}, paths);
Driver m;
/* Calling string() method on filename: not necessary on GNU/Linux and macOS because there is
an implicit conversion from from filesystem:path to string (i.e. basic_string<char>), but
needed on Windows because the implicit conversion is only to wstring (i.e.
basic_string<wchar_t>). */
m.parse(filename.string(), incfile, false, {}, env, paths, output);
}
catch (StackTrace& ex)
{
......@@ -86,15 +90,17 @@ Include::interpret(ostream &output, bool no_line_macro, vector<filesystem::path>
{
error(StackTrace("@#include", e.what(), location));
}
printEndLineInfo(output);
}
void
IncludePath::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths)
IncludePath::interpret([[maybe_unused]] ostream& output, Environment& env,
vector<filesystem::path>& paths)
{
using namespace filesystem;
try
{
StringPtr msp = dynamic_pointer_cast<String>(expr->eval());
StringPtr msp = dynamic_pointer_cast<String>(expr->eval(env));
if (!msp)
throw StackTrace("File name does not evaluate to a string");
path ip = static_cast<string>(*msp);
......@@ -113,10 +119,12 @@ IncludePath::interpret(ostream &output, bool no_line_macro, vector<filesystem::p
{
error(StackTrace("@#includepath", e.what(), location));
}
printEndLineInfo(output);
}
void
Define::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths)
Define::interpret([[maybe_unused]] ostream& output, Environment& env,
[[maybe_unused]] vector<filesystem::path>& paths)
{
try
{
......@@ -136,14 +144,15 @@ Define::interpret(ostream &output, bool no_line_macro, vector<filesystem::path>
{
error(StackTrace("@#define", e.what(), location));
}
printEndLineInfo(output);
}
void
Echo::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths)
Echo::interpret(ostream& output, Environment& env, [[maybe_unused]] vector<filesystem::path>& paths)
{
try
{
cout << "@#echo (" << getLocation() << "): " << expr->eval()->to_string() << endl;
cout << "@#echo (" << getLocation() << "): " << expr->eval(env)->to_string() << endl;
}
catch (StackTrace& ex)
{
......@@ -154,15 +163,16 @@ Echo::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &p
{
error(StackTrace("@#echo", e.what(), location));
}
printEndLineInfo(output, no_line_macro);
printEndLineInfo(output);
}
void
Error::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths)
Error::interpret([[maybe_unused]] ostream& output, Environment& env,
[[maybe_unused]] vector<filesystem::path>& paths)
{
try
{
throw StackTrace(expr->eval()->to_string());
throw StackTrace(expr->eval(env)->to_string());
}
catch (StackTrace& ex)
{
......@@ -176,22 +186,20 @@ Error::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &
}
void
EchoMacroVars::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths)
EchoMacroVars::interpret(ostream& output, Environment& env,
[[maybe_unused]] vector<filesystem::path>& paths)
{
if (save)
env.print(output, vars, location.begin.line, true);
else
env.print(cout, vars);
printEndLineInfo(output, no_line_macro);
env.print(save ? output : cout, vars, location.begin.line, save);
printEndLineInfo(output);
}
void
For::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths)
For::interpret(ostream& output, Environment& env, vector<filesystem::path>& paths)
{
ArrayPtr ap;
try
{
ap = dynamic_pointer_cast<Array>(index_vals->eval());
ap = dynamic_pointer_cast<Array>(index_vals->eval(env));
if (!ap)
throw StackTrace("The index must loop through an array");
}
......@@ -219,8 +227,11 @@ For::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &pa
{
TuplePtr mtp = dynamic_pointer_cast<Tuple>(btp);
if (index_vec.size() != mtp->size())
error(StackTrace("@#for", "Encountered tuple of size " + to_string(mtp->size())
+ " but only have " + to_string(index_vec.size()) + " index variables", location));
error(StackTrace("@#for",
"Encountered tuple of size " + to_string(mtp->size())
+ " but only have " + to_string(index_vec.size())
+ " index variables",
location));
else
for (size_t j = 0; j < index_vec.size(); j++)
env.define(index_vec.at(j), mtp->at(j));
......@@ -232,33 +243,49 @@ For::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &pa
{
if (printLine)
{
statement->printLineInfo(output, no_line_macro);
statement->printLineInfo(output);
printLine = false;
}
statement->interpret(output, no_line_macro, paths);
statement->interpret(output, env, paths);
}
}
printEndLineInfo(output, no_line_macro);
printEndLineInfo(output);
}
void
If::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths)
If::interpret(ostream& output, Environment& env, vector<filesystem::path>& paths)
{
for (const auto & [expr, body] : expr_and_body)
for (bool first_clause {true}; const auto& [expr, body] : expr_and_body)
try
{
auto tmp = expr->eval();
if ((ifdef || ifndef) && exchange(first_clause, false))
{
VariablePtr vp = dynamic_pointer_cast<Variable>(expr);
if (!vp)
error(StackTrace(ifdef ? "@#ifdef" : "@#ifndef",
"The condition must be a variable name", location));
if ((ifdef && env.isVariableDefined(vp->getName()))
|| (ifndef && !env.isVariableDefined(vp->getName())))
{
interpretBody(body, output, env, paths);
break;
}
}
else
{
auto tmp = expr->eval(env);
RealPtr dp = dynamic_pointer_cast<Real>(tmp);
BoolPtr bp = dynamic_pointer_cast<Bool>(tmp);
if (!bp && !dp)
error(StackTrace("@#if",
"The condition must evaluate to a boolean or a double", location));
error(StackTrace("@#if", "The condition must evaluate to a boolean or a double",
location));
if ((bp && *bp) || (dp && *dp))
{
interpretBody(body, output, no_line_macro, paths);
interpretBody(body, output, env, paths);
break;
}
}
}
catch (StackTrace& ex)
{
ex.push("@#if", location);
......@@ -268,48 +295,17 @@ If::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &pat
{
error(StackTrace("@#if", e.what(), location));
}
printEndLineInfo(output, no_line_macro);
printEndLineInfo(output);
}
void
If::interpretBody(const vector<DirectivePtr> &body, ostream &output, bool no_line_macro, vector<filesystem::path> &paths)
If::interpretBody(const vector<DirectivePtr>& body, ostream& output, Environment& env,
vector<filesystem::path>& paths)
{
bool printLine = !no_line_macro;
for (const auto & statement : body)
{
if (printLine)
for (bool printLine {true}; const auto& statement : body)
{
statement->printLineInfo(output, no_line_macro);
printLine = false;
}
statement->interpret(output, no_line_macro, paths);
}
}
void
Ifdef::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths)
{
for (const auto & [expr, body] : expr_and_body)
if (VariablePtr vp = dynamic_pointer_cast<Variable>(expr);
dynamic_pointer_cast<BaseType>(expr)
|| (vp && env.isVariableDefined(vp->getName())))
{
interpretBody(body, output, no_line_macro, paths);
break;
}
printEndLineInfo(output, no_line_macro);
}
void
Ifndef::interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths)
{
for (const auto & [expr, body] : expr_and_body)
if (VariablePtr vp = dynamic_pointer_cast<Variable>(expr);
!(dynamic_pointer_cast<BaseType>(expr)
|| (vp && env.isVariableDefined(vp->getName()))))
{
interpretBody(body, output, no_line_macro, paths);
break;
if (exchange(printLine, false))
statement->printLineInfo(output);
statement->interpret(output, env, paths);
}
printEndLineInfo(output, no_line_macro);
}
/*
* Copyright (C) 2019 Dynare Team
* Copyright © 2019-2023 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,11 +14,11 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef _DIRECTIVES_HH
#define _DIRECTIVES_HH
#ifndef DIRECTIVES_HH
#define DIRECTIVES_HH
#include "Expressions.hh"
......@@ -30,12 +30,13 @@ namespace macro
{
// A Parent class just for clarity
public:
Directive(Environment &env_arg, Tokenizer::location location_arg) : Node(env_arg, move(location_arg)) { }
explicit Directive(Tokenizer::location location_arg) : Node(move(location_arg))
{
}
// Directives can be interpreted
virtual void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) = 0;
virtual void interpret(ostream& output, Environment& env, vector<filesystem::path>& paths) = 0;
};
class TextNode : public Directive
{
// Class for text not interpreted by macroprocessor
......@@ -43,13 +44,20 @@ namespace macro
// Treated as such as the output is only to be interpreted
private:
const string text;
public:
TextNode(string text_arg, Environment &env_arg, Tokenizer::location location_arg) :
Directive(env_arg, move(location_arg)), text{move(text_arg)} { }
inline void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) override { output << text; }
TextNode(string text_arg, Tokenizer::location location_arg) :
Directive(move(location_arg)), text {move(text_arg)}
{
}
void
interpret(ostream& output, [[maybe_unused]] Environment& env,
[[maybe_unused]] vector<filesystem::path>& paths) override
{
output << text;
}
};
class Eval : public Directive
{
// Class for @{} statements
......@@ -57,111 +65,123 @@ namespace macro
// Treated as such as the output is only to be interpreted
private:
const ExpressionPtr expr;
public:
Eval(ExpressionPtr expr_arg, Environment &env_arg, Tokenizer::location location_arg) :
Directive(env_arg, move(location_arg)), expr{move(expr_arg)} { }
void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) override;
Eval(ExpressionPtr expr_arg, Tokenizer::location location_arg) :
Directive(move(location_arg)), expr {move(expr_arg)}
{
}
void interpret(ostream& output, Environment& env, vector<filesystem::path>& paths) override;
};
class Include : public Directive
{
private:
const ExpressionPtr expr;
public:
Include(ExpressionPtr expr_arg, Environment &env_arg, Tokenizer::location location_arg) :
Directive(env_arg, move(location_arg)), expr{move(expr_arg)} { }
void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) override;
Include(ExpressionPtr expr_arg, Tokenizer::location location_arg) :
Directive(move(location_arg)), expr {move(expr_arg)}
{
}
void interpret(ostream& output, Environment& env, vector<filesystem::path>& paths) override;
};
class IncludePath : public Directive
{
private:
const ExpressionPtr expr;
public:
IncludePath(ExpressionPtr expr_arg, Environment &env_arg, Tokenizer::location location_arg) :
Directive(env_arg, move(location_arg)), expr{move(expr_arg)} { }
void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) override;
IncludePath(ExpressionPtr expr_arg, Tokenizer::location location_arg) :
Directive(move(location_arg)), expr {move(expr_arg)}
{
}
void interpret(ostream& output, Environment& env, vector<filesystem::path>& paths) override;
};
class Define : public Directive
{
private:
const VariablePtr var;
const FunctionPtr func;
const ExpressionPtr value;
public:
Define(VariablePtr var_arg,
ExpressionPtr value_arg,
Environment &env_arg, Tokenizer::location location_arg) :
Directive(env_arg, move(location_arg)), var{move(var_arg)}, value{move(value_arg)} { }
Define(FunctionPtr func_arg,
ExpressionPtr value_arg,
Environment &env_arg, Tokenizer::location location_arg) :
Directive(env_arg, move(location_arg)), func{move(func_arg)}, value{move(value_arg)} { }
void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) override;
Define(VariablePtr var_arg, ExpressionPtr value_arg, Tokenizer::location location_arg) :
Directive(move(location_arg)), var {move(var_arg)}, value {move(value_arg)}
{
}
Define(FunctionPtr func_arg, ExpressionPtr value_arg, Tokenizer::location location_arg) :
Directive(move(location_arg)), func {move(func_arg)}, value {move(value_arg)}
{
}
void interpret(ostream& output, Environment& env, vector<filesystem::path>& paths) override;
};
class Echo : public Directive
{
private:
const ExpressionPtr expr;
public:
Echo(ExpressionPtr expr_arg,
Environment &env_arg, Tokenizer::location location_arg) :
Directive(env_arg, move(location_arg)), expr{move(expr_arg)} { }
void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) override;
Echo(ExpressionPtr expr_arg, Tokenizer::location location_arg) :
Directive(move(location_arg)), expr {move(expr_arg)}
{
}
void interpret(ostream& output, Environment& env, vector<filesystem::path>& paths) override;
};
class Error : public Directive
{
private:
const ExpressionPtr expr;
public:
Error(ExpressionPtr expr_arg,
Environment &env_arg, Tokenizer::location location_arg) :
Directive(env_arg, move(location_arg)), expr{move(expr_arg)} { }
void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) override;
Error(ExpressionPtr expr_arg, Tokenizer::location location_arg) :
Directive(move(location_arg)), expr {move(expr_arg)}
{
}
void interpret(ostream& output, Environment& env, vector<filesystem::path>& paths) override;
};
class EchoMacroVars : public Directive
{
private:
const bool save;
const vector<string> vars;
public:
EchoMacroVars(bool save_arg,
Environment &env_arg, Tokenizer::location location_arg) :
Directive(env_arg, move(location_arg)), save{save_arg} { }
EchoMacroVars(bool save_arg, vector<string> vars_arg,
Environment &env_arg, Tokenizer::location location_arg) :
Directive(env_arg, move(location_arg)), save{save_arg}, vars{move(vars_arg)} { }
void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) override;
EchoMacroVars(bool save_arg, Tokenizer::location location_arg) :
Directive(move(location_arg)), save {save_arg}
{
}
EchoMacroVars(bool save_arg, vector<string> vars_arg, Tokenizer::location location_arg) :
Directive(move(location_arg)), save {save_arg}, vars {move(vars_arg)}
{
}
void interpret(ostream& output, Environment& env, vector<filesystem::path>& paths) override;
};
class For : public Directive
{
private:
const vector<VariablePtr> index_vec;
const ExpressionPtr index_vals;
const vector<DirectivePtr> statements;
public:
For(vector<VariablePtr> index_vec_arg,
ExpressionPtr index_vals_arg,
vector<DirectivePtr> statements_arg,
Environment &env_arg, Tokenizer::location location_arg) :
Directive(env_arg, move(location_arg)), index_vec{move(index_vec_arg)},
index_vals{move(index_vals_arg)}, statements{move(statements_arg)} { }
void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) override;
For(vector<VariablePtr> index_vec_arg, ExpressionPtr index_vals_arg,
vector<DirectivePtr> statements_arg, Tokenizer::location location_arg) :
Directive(move(location_arg)),
index_vec {move(index_vec_arg)},
index_vals {move(index_vals_arg)},
statements {move(statements_arg)}
{
}
void interpret(ostream& output, Environment& env, vector<filesystem::path>& paths) override;
};
class If : public Directive
{
protected:
......@@ -174,33 +194,42 @@ namespace macro
* If there is an `else` statement it is the last element in the vector. Its condition is true.
*/
const vector<pair<ExpressionPtr, vector<DirectivePtr>>> expr_and_body;
const bool ifdef, ifndef;
public:
If(vector<pair<ExpressionPtr, vector<DirectivePtr>>> expr_and_body_arg,
Environment &env_arg, Tokenizer::location location_arg) :
Directive(env_arg, move(location_arg)), expr_and_body{move(expr_and_body_arg)} { }
void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) override;
Tokenizer::location location_arg, bool ifdef_arg = false, bool ifndef_arg = false) :
Directive(move(location_arg)),
expr_and_body {move(expr_and_body_arg)},
ifdef {ifdef_arg},
ifndef {ifndef_arg}
{
}
void interpret(ostream& output, Environment& env, vector<filesystem::path>& paths) override;
protected:
void interpretBody(const vector<DirectivePtr> &body, ostream &output, bool no_line_macro, vector<filesystem::path> &paths);
void interpretBody(const vector<DirectivePtr>& body, ostream& output, Environment& env,
vector<filesystem::path>& paths);
};
class Ifdef : public If
{
public:
Ifdef(vector<pair<ExpressionPtr, vector<DirectivePtr>>> expr_and_body_arg,
Environment &env_arg, Tokenizer::location location_arg) :
If(move(expr_and_body_arg), env_arg, move(location_arg)) { }
void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) override;
Tokenizer::location location_arg) :
If(move(expr_and_body_arg), move(location_arg), true, false)
{
}
};
class Ifndef : public If
{
public:
Ifndef(vector<pair<ExpressionPtr, vector<DirectivePtr>>> expr_and_body_arg,
Environment &env_arg, Tokenizer::location location_arg) :
If(move(expr_and_body_arg), env_arg, move(location_arg)) { }
void interpret(ostream &output, bool no_line_macro, vector<filesystem::path> &paths) override;
Tokenizer::location location_arg) :
If(move(expr_and_body_arg), move(location_arg), false, true)
{
}
};
}
#endif
/*
* Copyright © 2019 Dynare Team
* Copyright © 2019-2023 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,47 +14,33 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#include "Driver.hh"
#include <regex>
#include <utility>
using namespace macro;
void
Driver::parse(const string &file_arg, const string &basename_arg, istream &modfile,
ostream &output, bool debug, const vector<pair<string, string>> &defines,
vector<filesystem::path> &paths)
Driver::parse(const string& file_arg, const istream& modfile, bool debug,
const vector<pair<string, string>>& defines, Environment& env,
vector<filesystem::path>& paths, ostream& output)
{
file = file_arg;
basename = basename_arg;
if (!defines.empty())
{
stringstream command_line_defines_with_endl;
for (const auto& [var, val] : defines)
try
{
stoi(val);
command_line_defines_with_endl << "@#define " << var << " = " << val << endl;
}
catch (const invalid_argument &)
{
if (!val.empty() && val.at(0) == '[' && val.at(val.length()-1) == ']')
// If the input is an array. Issue #1578
command_line_defines_with_endl << "@#define " << var << " = " << val << endl;
else
command_line_defines_with_endl << "@#define " << var << " = \"" << val << "\"" << endl;
}
Driver m(env, true);
Driver m;
istream is(command_line_defines_with_endl.rdbuf());
m.parse("command_line_defines", "command_line_defines", is, output, debug, vector<pair<string, string>>{}, paths);
m.parse("command_line_defines", is, debug, {}, env, paths, output);
}
// Handle empty files
if (modfile.rdbuf()->in_avail() == 0)
return;
stringstream file_with_endl;
file_with_endl << modfile.rdbuf() << endl;
......@@ -68,15 +54,11 @@ Driver::parse(const string &file_arg, const string &basename_arg, istream &modfi
parser.parse();
// Interpret parsed statements
bool printLine = true;
for (const auto & statement : statements)
for (bool printLine {true}; const auto& statement : statements)
{
if (printLine)
{
statement->printLineInfo(output, no_line_macro);
printLine = false;
}
statement->interpret(output, no_line_macro, paths);
if (exchange(printLine, false))
statement->printLineInfo(output);
statement->interpret(output, env, paths);
}
}
......
/*
* Copyright © 2019 Dynare Team
* Copyright © 2019-2023 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,19 +14,19 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef _MACRO_DRIVER_HH
#define _MACRO_DRIVER_HH
#ifndef MACRO_DRIVER_HH
#define MACRO_DRIVER_HH
#ifdef _PARSING_DRIVER_HH
#ifdef PARSING_DRIVER_HH
# error Impossible to include both ../ParsingDriver.hh and Driver.hh
#endif
#include "Parser.hh"
#include "Environment.hh"
#include "Expressions.hh"
#include "Parser.hh"
#include <stack>
......@@ -46,71 +46,83 @@ namespace macro
class TokenizerFlex : public TokenizerFlexLexer
{
public:
TokenizerFlex(istream *in) : TokenizerFlexLexer{in} { }
TokenizerFlex(istream* in) : TokenizerFlexLexer {in}
{
}
TokenizerFlex(const TokenizerFlex&) = delete;
TokenizerFlex(TokenizerFlex &&) = delete;
TokenizerFlex& operator=(const TokenizerFlex&) = delete;
TokenizerFlex & operator=(TokenizerFlex &&) = delete;
//! The main lexing function
Tokenizer::parser::token_type lex(Tokenizer::parser::semantic_type* yylval,
Tokenizer::parser::location_type* yylloc,
macro::Driver& driver);
static void location_increment(Tokenizer::parser::location_type* yylloc, const char* yytext);
};
//! Implements the macro expansion using a Flex scanner and a Bison parser
class Driver
{
public:
Environment &env;
private:
bool no_line_macro;
vector<DirectivePtr> statements;
stack<vector<DirectivePtr>> directive_stack;
public:
Driver(Environment &env_arg, bool no_line_macro_arg) :
env{env_arg}, no_line_macro(no_line_macro_arg) { }
Driver() = default;
Driver(const Driver&) = delete;
Driver(Driver &&) = delete;
Driver& operator=(const Driver&) = delete;
Driver & operator=(Driver &&) = delete;
//! Exception thrown when value of an unknown variable is requested
class UnknownVariable
struct UnknownVariable
{
public:
const string name;
explicit UnknownVariable(string name_arg) : name{move(name_arg)}
{
}
};
//! Starts parsing a file, returns output in out
void parse(const string &file_arg, const string &basename_arg, istream &modfile,
ostream &output, bool debug, const vector<pair<string, string>> &defines,
vector<filesystem::path> &paths_arg);
//! Starts parsing a file, modifies `env`, `paths` and `output`
//! as they are modified by various macro directives
void parse(const string& file, const istream& modfile, bool debug,
const vector<pair<string, string>>& defines, Environment& env,
vector<filesystem::path>& paths, ostream& output);
//! Name of main file being parsed
//! Name of main file being parsed (for error reporting purposes)
string file;
//! Basename of main file being parsed
string basename;
//! Reference to the lexer
unique_ptr<TokenizerFlex> lexer;
//! Error handler
void error(const Tokenizer::parser::location_type& location, const string& message) const;
inline bool inContext() const { return !directive_stack.empty(); }
[[nodiscard]] bool
inContext() const
{
return !directive_stack.empty();
}
inline void pushContext() { directive_stack.emplace(vector<DirectivePtr>()); }
void
pushContext()
{
directive_stack.emplace();
}
inline void pushContextTop(DirectivePtr statement) { directive_stack.top().emplace_back(move(statement)); }
void
pushContextTop(DirectivePtr statement)
{
directive_stack.top().emplace_back(move(statement));
}
inline void pushStatements(DirectivePtr statement) { statements.emplace_back(move(statement)); }
void
pushStatements(DirectivePtr statement)
{
statements.emplace_back(move(statement));
}
inline vector<DirectivePtr> popContext() { auto top = move(directive_stack.top()); directive_stack.pop(); return top; }
vector<DirectivePtr>
popContext()
{
auto top = move(directive_stack.top());
directive_stack.pop();
return top;
}
};
}
#endif
/*
* Copyright © 2019 Dynare Team
* Copyright © 2019-2024 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,28 +14,32 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#include <algorithm>
#include <cassert>
#include <ranges>
#include "Environment.hh"
#include "Expressions.hh"
using namespace macro;
void
Environment::define(VariablePtr var, ExpressionPtr value)
Environment::define(const VariablePtr& var, const ExpressionPtr& value)
{
string name = var->getName();
if (functions.find(name) != functions.end())
if (functions.contains(name))
throw StackTrace("Variable " + name + " was previously defined as a function");
variables[move(name)] = value->eval();
variables[move(name)] = value->eval(*this);
}
void
Environment::define(FunctionPtr func, ExpressionPtr value)
{
string name = func->getName();
if (variables.find(name) != variables.end())
if (variables.contains(name))
throw StackTrace("Variable " + name + " was previously defined as a variable");
functions[name] = {move(func), move(value)};
}
......@@ -43,8 +47,7 @@ Environment::define(FunctionPtr func, ExpressionPtr value)
ExpressionPtr
Environment::getVariable(const string& name) const
{
auto it = variables.find(name);
if (it != variables.end())
if (auto it = variables.find(name); it != variables.end())
return it->second;
if (!parent)
......@@ -53,11 +56,10 @@ Environment::getVariable(const string &name) const
return getGlobalEnv()->getVariable(name);
}
tuple<FunctionPtr, ExpressionPtr>
pair<FunctionPtr, ExpressionPtr>
Environment::getFunction(const string& name) const
{
auto it = functions.find(name);
if (it != functions.end())
if (auto it = functions.find(name); it != functions.end())
return it->second;
if (!parent)
......@@ -67,9 +69,9 @@ Environment::getFunction(const string &name) const
}
codes::BaseType
Environment::getType(const string &name)
Environment::getType(const string& name) const
{
return getVariable(name)->eval()->getType();
return getVariable(name)->eval(const_cast<Environment&>(*this))->getType();
}
bool
......@@ -104,26 +106,42 @@ void
Environment::print(ostream& output, const vector<string>& vars, int line, bool save) const
{
if (!save && !variables.empty())
output << "Macro Variables:" << endl;
output << "Macro Variables (at line " << line << "):" << endl;
// For sorting the symbols in a case-insensitive way, see #128
auto case_insensitive_string_less = [](const string& a, const string& b) {
return ranges::lexicographical_compare(
a, b, [](char c1, char c2) { return tolower(c1) < tolower(c2); });
};
if (vars.empty())
for (const auto & it : variables)
printVariable(output, it.first, line, save);
{
vector<string> variables_sorted;
ranges::copy(views::keys(variables), back_inserter(variables_sorted));
ranges::sort(variables_sorted, case_insensitive_string_less);
for (const auto& it : variables_sorted)
printVariable(output, it, line, save);
}
else
for (const auto& it : vars)
if (isVariableDefined(it))
printVariable(output, it, line, save);
if (!save && !functions.empty())
output << "Macro Functions:" << endl;
output << "Macro Functions (at line " << line << "):" << endl;
if (vars.empty())
for (const auto & it : functions)
printFunction(output, it.second, line, save);
{
vector<string> functions_sorted;
ranges::copy(views::keys(functions), back_inserter(functions_sorted));
ranges::sort(functions_sorted, case_insensitive_string_less);
for (const auto& it : functions_sorted)
printFunction(output, it, line, save);
}
else
for (const auto& it : vars)
if (isFunctionDefined(it))
printFunction(output, functions.find(it)->second, line, save);
printFunction(output, it, line, save);
if (parent)
parent->print(output, vars, line, save);
......@@ -132,27 +150,27 @@ Environment::print(ostream &output, const vector<string> &vars, int line, bool s
void
Environment::printVariable(ostream& output, const string& name, int line, bool save) const
{
output << (save ? "options_.macrovars_line_" + to_string(line) + "." : " " )
<< name << " = ";
getVariable(name)->eval()->print(output, save);
output << (save ? "options_.macrovars_line_" + to_string(line) + "." : " ") << name << " = ";
getVariable(name)->eval(const_cast<Environment&>(*this))->print(output, save);
if (save)
output << ";";
output << endl;
}
void
Environment::printFunction(ostream &output, const tuple<FunctionPtr, ExpressionPtr> & function, int line, bool save) const
Environment::printFunction(ostream& output, const string& name, int line, bool save) const
{
auto [func_signature, func_body] = getFunction(name);
output << (save ? "options_.macrovars_line_" + to_string(line) + ".function." : " ");
if (save)
{
get<0>(function)->printName(output);
func_signature->printName(output);
output << " = '";
}
get<0>(function)->print(output);
func_signature->print(output);
output << " = ";
get<1>(function)->print(output);
func_body->print(output);
if (save)
output << "';";
......
/*
* Copyright © 2019 Dynare Team
* Copyright © 2019-2024 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,15 +14,16 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef _ENVIRONMENT_HH
#define _ENVIRONMENT_HH
#ifndef ENVIRONMENT_HH
#define ENVIRONMENT_HH
#include "ForwardDeclarationsAndEnums.hh"
#include <map>
#include <optional>
#include <vector>
namespace macro
......@@ -30,25 +31,43 @@ namespace macro
class Environment
{
private:
const Environment *parent;
const Environment* parent {nullptr};
map<string, ExpressionPtr> variables;
map<string, tuple<FunctionPtr, ExpressionPtr>> functions;
map<string, pair<FunctionPtr, ExpressionPtr>> functions;
public:
Environment() : parent{nullptr} { }
Environment(const Environment *parent_arg) : parent{parent_arg} { }
void define(VariablePtr var, ExpressionPtr value);
Environment() = default;
Environment(const Environment* parent_arg) : parent {parent_arg}
{
}
void define(const VariablePtr& var, const ExpressionPtr& value);
void define(FunctionPtr func, ExpressionPtr value);
ExpressionPtr getVariable(const string &name) const;
tuple<FunctionPtr, ExpressionPtr> getFunction(const string &name) const;
codes::BaseType getType(const string &name);
bool isVariableDefined(const string &name) const noexcept;
bool isFunctionDefined(const string &name) const noexcept;
inline bool isSymbolDefined(const string &name) const noexcept { return isVariableDefined(name) || isFunctionDefined(name); }
void print(ostream &output, const vector<string> &vars, int line = -1, bool save = false) const;
/* The following two functions are not marked [[nodiscard]], because they are used without output
to check whether they return an exception or not. */
ExpressionPtr getVariable(const string& name) const; // NOLINT(modernize-use-nodiscard)
pair<FunctionPtr, ExpressionPtr> // NOLINT(modernize-use-nodiscard)
getFunction(const string& name) const;
[[nodiscard]] codes::BaseType getType(const string& name) const;
[[nodiscard]] bool isVariableDefined(const string& name) const noexcept;
[[nodiscard]] bool isFunctionDefined(const string& name) const noexcept;
[[nodiscard]] bool
isSymbolDefined(const string& name) const noexcept
{
return isVariableDefined(name) || isFunctionDefined(name);
}
void print(ostream& output, const vector<string>& vars, int line, bool save) const;
void printVariable(ostream& output, const string& name, int line, bool save) const;
void printFunction(ostream &output, const tuple<FunctionPtr, ExpressionPtr> & function, int line, bool save) const;
inline size_t size() const noexcept { return variables.size() + functions.size(); }
inline const Environment *getGlobalEnv() const noexcept { return parent == nullptr ? this : parent->getGlobalEnv(); }
void printFunction(ostream& output, const string& name, int line, bool save) const;
[[nodiscard]] size_t
size() const noexcept
{
return variables.size() + functions.size();
}
[[nodiscard]] const Environment*
getGlobalEnv() const noexcept
{
return parent == nullptr ? this : parent->getGlobalEnv();
}
};
}
#endif
/*
* Copyright © 2019 Dynare Team
* Copyright © 2019-2024 Dynare Team
*
* This file is part of Dynare.
*
......@@ -14,9 +14,12 @@
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Dynare. If not, see <http://www.gnu.org/licenses/>.
* along with Dynare. If not, see <https://www.gnu.org/licenses/>.
*/
#include <numbers>
#include <utility>
#include "Expressions.hh"
using namespace macro;
......@@ -25,8 +28,8 @@ BoolPtr
BaseType::is_different(const BaseTypePtr& btp) const
{
if (*(this->is_equal(btp)))
return make_shared<Bool>(false, env);
return make_shared<Bool>(true, env);
return make_shared<Bool>(false);
return make_shared<Bool>(true);
}
BoolPtr
......@@ -34,30 +37,38 @@ Bool::is_equal(const BaseTypePtr &btp) const
{
auto btp2 = dynamic_pointer_cast<Bool>(btp);
if (!btp2)
return make_shared<Bool>(false, env);
return make_shared<Bool>(value == btp2->value, env);
return make_shared<Bool>(false);
return make_shared<Bool>(value == btp2->value);
}
BoolPtr
Bool::logical_and(const BaseTypePtr &btp) const
Bool::logical_and(const ExpressionPtr& ep, Environment& env) const
{
if (!value)
return make_shared<Bool>(false);
auto btp = ep->eval(env);
if (auto btp2 = dynamic_pointer_cast<Bool>(btp); btp2)
return make_shared<Bool>(value && *btp2, env);
return make_shared<Bool>(*btp2);
if (auto btp2 = dynamic_pointer_cast<Real>(btp); btp2)
return make_shared<Bool>(value && *btp2, env);
return make_shared<Bool>(*btp2);
throw StackTrace("Type mismatch for operands of && operator");
}
BoolPtr
Bool::logical_or(const BaseTypePtr &btp) const
Bool::logical_or(const ExpressionPtr& ep, Environment& env) const
{
if (value)
return make_shared<Bool>(true);
auto btp = ep->eval(env);
if (auto btp2 = dynamic_pointer_cast<Bool>(btp); btp2)
return make_shared<Bool>(value || *btp2, env);
return make_shared<Bool>(*btp2);
if (auto btp2 = dynamic_pointer_cast<Real>(btp); btp2)
return make_shared<Bool>(value || *btp2, env);
return make_shared<Bool>(*btp2);
throw StackTrace("Type mismatch for operands of || operator");
}
......@@ -65,7 +76,7 @@ Bool::logical_or(const BaseTypePtr &btp) const
BoolPtr
Bool::logical_not() const
{
return make_shared<Bool>(!value, env);
return make_shared<Bool>(!value);
}
BaseTypePtr
......@@ -74,7 +85,7 @@ Real::plus(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of + operator");
return make_shared<Real>(value + btp2->value, env);
return make_shared<Real>(value + btp2->value);
}
BaseTypePtr
......@@ -83,7 +94,7 @@ Real::minus(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of - operator");
return make_shared<Real>(value - btp2->value, env);
return make_shared<Real>(value - btp2->value);
}
BaseTypePtr
......@@ -92,7 +103,7 @@ Real::times(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of * operator");
return make_shared<Real>(value * btp2->value, env);
return make_shared<Real>(value * btp2->value);
}
BaseTypePtr
......@@ -101,7 +112,7 @@ Real::divide(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of / operator");
return make_shared<Real>(value / btp2->value, env);
return make_shared<Real>(value / btp2->value);
}
BaseTypePtr
......@@ -110,7 +121,7 @@ Real::power(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of ^ operator");
return make_shared<Real>(pow(value, btp2->value), env);
return make_shared<Real>(pow(value, btp2->value));
}
BoolPtr
......@@ -119,7 +130,7 @@ Real::is_less(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of < operator");
return make_shared<Bool>(isless(value, btp2->value), env);
return make_shared<Bool>(isless(value, btp2->value));
}
BoolPtr
......@@ -128,7 +139,7 @@ Real::is_greater(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of > operator");
return make_shared<Bool>(isgreater(value, btp2->value), env);
return make_shared<Bool>(isgreater(value, btp2->value));
}
BoolPtr
......@@ -137,7 +148,7 @@ Real::is_less_equal(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of <= operator");
return make_shared<Bool>(islessequal(value, btp2->value), env);
return make_shared<Bool>(islessequal(value, btp2->value));
}
BoolPtr
......@@ -146,7 +157,7 @@ Real::is_greater_equal(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of >= operator");
return make_shared<Bool>(isgreaterequal(value, btp2->value), env);
return make_shared<Bool>(isgreaterequal(value, btp2->value));
}
BoolPtr
......@@ -154,30 +165,38 @@ Real::is_equal(const BaseTypePtr &btp) const
{
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
return make_shared<Bool>(false, env);
return make_shared<Bool>(value == btp2->value, env);
return make_shared<Bool>(false);
return make_shared<Bool>(value == btp2->value);
}
BoolPtr
Real::logical_and(const BaseTypePtr &btp) const
Real::logical_and(const ExpressionPtr& ep, Environment& env) const
{
if (!value)
return make_shared<Bool>(false);
auto btp = ep->eval(env);
if (auto btp2 = dynamic_pointer_cast<Real>(btp); btp2)
return make_shared<Bool>(value && *btp2, env);
return make_shared<Bool>(*btp2);
if (auto btp2 = dynamic_pointer_cast<Bool>(btp); btp2)
return make_shared<Bool>(value && *btp2, env);
return make_shared<Bool>(*btp2);
throw StackTrace("Type mismatch for operands of && operator");
}
BoolPtr
Real::logical_or(const BaseTypePtr &btp) const
Real::logical_or(const ExpressionPtr& ep, Environment& env) const
{
if (value)
return make_shared<Bool>(true);
auto btp = ep->eval(env);
if (auto btp2 = dynamic_pointer_cast<Real>(btp); btp2)
return make_shared<Bool>(value || *btp2, env);
return make_shared<Bool>(*btp2);
if (auto btp2 = dynamic_pointer_cast<Bool>(btp); btp2)
return make_shared<Bool>(value || *btp2, env);
return make_shared<Bool>(*btp2);
throw StackTrace("Type mismatch for operands of || operator");
}
......@@ -185,7 +204,7 @@ Real::logical_or(const BaseTypePtr &btp) const
BoolPtr
Real::logical_not() const
{
return make_shared<Bool>(!value, env);
return make_shared<Bool>(!value);
}
RealPtr
......@@ -194,7 +213,7 @@ Real::max(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of `max` operator");
return make_shared<Real>(std::max(value, btp2->value), env);
return make_shared<Real>(std::max(value, btp2->value));
}
RealPtr
......@@ -203,7 +222,7 @@ Real::min(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of `min` operator");
return make_shared<Real>(std::min(value, btp2->value), env);
return make_shared<Real>(std::min(value, btp2->value));
}
RealPtr
......@@ -212,7 +231,7 @@ Real::mod(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<Real>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of `mod` operator");
return make_shared<Real>(std::fmod(value, btp2->value), env);
return make_shared<Real>(std::fmod(value, btp2->value));
}
RealPtr
......@@ -222,7 +241,9 @@ Real::normpdf(const BaseTypePtr &btp1, const BaseTypePtr &btp2) const
auto btp22 = dynamic_pointer_cast<Real>(btp2);
if (!btp12 || !btp22)
throw StackTrace("Type mismatch for operands of `normpdf` operator");
return make_shared<Real>((1/(btp22->value*std::sqrt(2*M_PI)*std::exp(pow((value-btp12->value)/btp22->value, 2)/2))), env);
return make_shared<Real>(1
/ (btp22->value * std::sqrt(2 * numbers::pi)
* std::exp(pow((value - btp12->value) / btp22->value, 2) / 2)));
}
RealPtr
......@@ -232,7 +253,8 @@ Real::normcdf(const BaseTypePtr &btp1, const BaseTypePtr &btp2) const
auto btp22 = dynamic_pointer_cast<Real>(btp2);
if (!btp12 || !btp22)
throw StackTrace("Type mismatch for operands of `normpdf` operator");
return make_shared<Real>((0.5*(1+std::erf((value-btp12->value)/btp22->value/M_SQRT2))), env);
return make_shared<Real>(
0.5 * (1 + std::erf((value - btp12->value) / btp22->value / numbers::sqrt2)));
}
BaseTypePtr
......@@ -241,7 +263,7 @@ String::plus(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<String>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of + operator");
return make_shared<String>(value + btp2->value, env);
return make_shared<String>(value + btp2->value);
}
BoolPtr
......@@ -250,7 +272,7 @@ String::is_less(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<String>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of < operator");
return make_shared<Bool>(value < btp2->value, env);
return make_shared<Bool>(value < btp2->value);
}
BoolPtr
......@@ -259,7 +281,7 @@ String::is_greater(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<String>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of > operator");
return make_shared<Bool>(value > btp2->value, env);
return make_shared<Bool>(value > btp2->value);
}
BoolPtr
......@@ -268,7 +290,7 @@ String::is_less_equal(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<String>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of <= operator");
return make_shared<Bool>(value <= btp2->value, env);
return make_shared<Bool>(value <= btp2->value);
}
BoolPtr
......@@ -277,7 +299,7 @@ String::is_greater_equal(const BaseTypePtr &btp) const
auto btp2 = dynamic_pointer_cast<String>(btp);
if (!btp2)
throw StackTrace("Type mismatch for operands of >= operator");
return make_shared<Bool>(value >= btp2->value, env);
return make_shared<Bool>(value >= btp2->value);
}
BoolPtr
......@@ -285,20 +307,20 @@ String::is_equal(const BaseTypePtr &btp) const
{
auto btp2 = dynamic_pointer_cast<String>(btp);
if (!btp2)
return make_shared<Bool>(false, env);
return make_shared<Bool>(value == btp2->value, env);
return make_shared<Bool>(false);
return make_shared<Bool>(value == btp2->value);
}
BoolPtr
String::cast_bool() const
String::cast_bool([[maybe_unused]] Environment& env) const
{
auto f = [](const char& a, const char& b) { return (tolower(a) == tolower(b)); };
auto f = [](const char& a, const char& b) { return tolower(a) == tolower(b); };
if (string tf = "true"; equal(value.begin(), value.end(), tf.begin(), tf.end(), f))
return make_shared<Bool>(true, env);
if (ranges::equal(value, "true"s, f))
return make_shared<Bool>(true);
if (string tf = "false"; equal(value.begin(), value.end(), tf.begin(), tf.end(), f))
return make_shared<Bool>(false, env);
if (ranges::equal(value, "false"s, f))
return make_shared<Bool>(false);
try
{
......@@ -306,7 +328,7 @@ String::cast_bool() const
double value_d = stod(value, &pos);
if (pos != value.length())
throw StackTrace("Entire string not converted");
return make_shared<Bool>(static_cast<bool>(value_d), env);
return make_shared<Bool>(static_cast<bool>(value_d));
}
catch (...)
{
......@@ -315,7 +337,7 @@ String::cast_bool() const
}
RealPtr
String::cast_real() const
String::cast_real([[maybe_unused]] Environment& env) const
{
try
{
......@@ -323,7 +345,7 @@ String::cast_real() const
double value_d = stod(value, &pos);
if (pos != value.length())
throw StackTrace("Entire string not converted");
return make_shared<Real>(value_d, env);
return make_shared<Real>(value_d);
}
catch (...)
{
......@@ -340,7 +362,7 @@ Array::plus(const BaseTypePtr &btp) const
vector<ExpressionPtr> arr_copy {arr};
arr_copy.insert(arr_copy.end(), btp2->arr.begin(), btp2->arr.end());
return make_shared<Array>(arr_copy, env);
return make_shared<Array>(arr_copy);
}
BaseTypePtr
......@@ -363,7 +385,7 @@ Array::minus(const BaseTypePtr &btp) const
if (it2 == btp2->arr.cend())
arr_copy.emplace_back(itbtp);
}
return make_shared<Array>(arr_copy, env);
return make_shared<Array>(arr_copy);
}
BaseTypePtr
......@@ -393,10 +415,10 @@ Array::times(const BaseTypePtr &btp) const
else
throw StackTrace("Array::times: unsupported type on rhs");
values.emplace_back(make_shared<Tuple>(new_tuple, env));
values.emplace_back(make_shared<Tuple>(new_tuple));
}
return make_shared<Array>(values, env);
return make_shared<Array>(values);
}
BaseTypePtr
......@@ -406,11 +428,11 @@ Array::power(const BaseTypePtr &btp) const
if (!btp2 || !*(btp2->isinteger()))
throw StackTrace("The second argument of the power operator (^) must be an integer");
auto retval = make_shared<Array>(arr, env);
auto retval = make_shared<Array>(arr);
for (int i = 1; i < *btp2; i++)
{
auto btpv = retval->times(make_shared<Array>(arr, env));
retval = make_shared<Array>(dynamic_pointer_cast<Array>(btpv)->getValue(), env);
auto btpv = retval->times(make_shared<Array>(arr));
retval = make_shared<Array>(dynamic_pointer_cast<Array>(btpv)->getValue());
}
return retval;
}
......@@ -420,19 +442,19 @@ Array::is_equal(const BaseTypePtr &btp) const
{
auto btp2 = dynamic_pointer_cast<Array>(btp);
if (!btp2)
return make_shared<Bool>(false, env);
return make_shared<Bool>(false);
if (arr.size() != btp2->arr.size())
return make_shared<Bool>(false, env);
return make_shared<Bool>(false);
for (size_t i = 0; i < arr.size(); i++)
{
auto bt = dynamic_pointer_cast<BaseType>(arr[i]);
auto bt2 = dynamic_pointer_cast<BaseType>(btp2->arr[i]);
if (!*(bt->is_equal(bt2)))
return make_shared<Bool>(false, env);
return make_shared<Bool>(false);
}
return make_shared<Bool>(true, env);
return make_shared<Bool>(true);
}
ArrayPtr
......@@ -463,7 +485,7 @@ Array::set_union(const BaseTypePtr &btp) const
if (!found)
new_values.push_back(it);
}
return make_shared<Array>(new_values, env);
return make_shared<Array>(new_values);
}
ArrayPtr
......@@ -491,7 +513,7 @@ Array::set_intersection(const BaseTypePtr &btp) const
}
}
}
return make_shared<Array>(new_values, env);
return make_shared<Array>(new_values);
}
BoolPtr
......@@ -503,9 +525,9 @@ Array::contains(const BaseTypePtr &btp) const
if (!v2)
throw StackTrace("Type mismatch for operands of in operator");
if (*(v2->is_equal(btp)))
return make_shared<Bool>(true, env);
return make_shared<Bool>(true);
}
return make_shared<Bool>(false, env);
return make_shared<Bool>(false);
}
RealPtr
......@@ -519,23 +541,23 @@ Array::sum() const
throw StackTrace("Type mismatch for operands of in operator");
retval += *v2;
}
return make_shared<Real>(retval, env);
return make_shared<Real>(retval);
}
BoolPtr
Array::cast_bool() const
Array::cast_bool(Environment& env) const
{
if (arr.size() != 1)
throw StackTrace("Array must be of size 1 to be cast to a boolean");
return arr.at(0)->eval()->cast_bool();
return arr.at(0)->eval(env)->cast_bool(env);
}
RealPtr
Array::cast_real() const
Array::cast_real(Environment& env) const
{
if (arr.size() != 1)
throw StackTrace("Array must be of size 1 to be cast to a real");
return arr.at(0)->eval()->cast_real();
return arr.at(0)->eval(env)->cast_real(env);
}
BoolPtr
......@@ -543,19 +565,19 @@ Tuple::is_equal(const BaseTypePtr &btp) const
{
auto btp2 = dynamic_pointer_cast<Tuple>(btp);
if (!btp2)
return make_shared<Bool>(false, env);
return make_shared<Bool>(false);
if (tup.size() != btp2->tup.size())
return make_shared<Bool>(false, env);
return make_shared<Bool>(false);
for (size_t i = 0; i < tup.size(); i++)
{
auto bt = dynamic_pointer_cast<BaseType>(tup[i]);
auto bt2 = dynamic_pointer_cast<BaseType>(btp2->tup[i]);
if (!*(bt->is_equal(bt2)))
return make_shared<Bool>(false, env);
return make_shared<Bool>(false);
}
return make_shared<Bool>(true, env);
return make_shared<Bool>(true);
}
BoolPtr
......@@ -567,84 +589,91 @@ Tuple::contains(const BaseTypePtr &btp) const
if (!v2)
throw StackTrace("Type mismatch for operands of in operator");
if (*(v2->is_equal(btp)))
return make_shared<Bool>(true, env);
return make_shared<Bool>(true);
}
return make_shared<Bool>(false, env);
return make_shared<Bool>(false);
}
BoolPtr
Tuple::cast_bool() const
Tuple::cast_bool(Environment& env) const
{
if (tup.size() != 1)
throw StackTrace("Tuple must be of size 1 to be cast to a boolean");
return tup.at(0)->eval()->cast_bool();
return tup.at(0)->eval(env)->cast_bool(env);
}
RealPtr
Tuple::cast_real() const
Tuple::cast_real(Environment& env) const
{
if (tup.size() != 1)
throw StackTrace("Tuple must be of size 1 to be cast to a real");
return tup.at(0)->eval()->cast_real();
return tup.at(0)->eval(env)->cast_real(env);
}
BaseTypePtr
Range::eval()
Range::eval(Environment& env) const
{
RealPtr incdbl = make_shared<Real>(1, env);
RealPtr incdbl = make_shared<Real>(1);
if (inc)
incdbl = dynamic_pointer_cast<Real>(inc->eval());
RealPtr startdbl = dynamic_pointer_cast<Real>(start->eval());
RealPtr enddbl = dynamic_pointer_cast<Real>(end->eval());
incdbl = dynamic_pointer_cast<Real>(inc->eval(env));
RealPtr startdbl = dynamic_pointer_cast<Real>(start->eval(env));
RealPtr enddbl = dynamic_pointer_cast<Real>(end->eval(env));
if (!startdbl || !enddbl || !incdbl)
throw StackTrace("To create an array from a range using the colon operator, "
"the arguments must evaluate to reals");
vector<ExpressionPtr> arr;
// We do want a float counter, because that’s the macro-language semantics
// NOLINTBEGIN(clang-analyzer-security.FloatLoopCounter)
if (*incdbl > 0 && *startdbl <= *enddbl)
for (double i = *startdbl; i <= *enddbl; i += *incdbl)
arr.emplace_back(make_shared<Real>(i, env));
arr.emplace_back(make_shared<Real>(i));
else if (*startdbl >= *enddbl && *incdbl < 0)
for (double i = *startdbl; i >= *enddbl; i += *incdbl)
arr.emplace_back(make_shared<Real>(i, env));
arr.emplace_back(make_shared<Real>(i));
// NOLINTEND(clang-analyzer-security.FloatLoopCounter)
return make_shared<Array>(arr, env, location);
return make_shared<Array>(arr, location);
}
BaseTypePtr
Array::eval()
Array::eval(Environment& env) const
{
vector<ExpressionPtr> retval;
retval.reserve(arr.size());
for (const auto& it : arr)
retval.emplace_back(it->eval());
return make_shared<Array>(retval, env);
retval.emplace_back(it->eval(env));
return make_shared<Array>(retval);
}
BaseTypePtr
Tuple::eval()
Tuple::eval(Environment& env) const
{
vector<ExpressionPtr> retval;
retval.reserve(tup.size());
for (const auto& it : tup)
retval.emplace_back(it->eval());
return make_shared<Tuple>(retval, env);
retval.emplace_back(it->eval(env));
return make_shared<Tuple>(retval);
}
BaseTypePtr
Variable::eval()
Variable::eval(Environment& env) const
{
if (indices && !indices->empty())
{
ArrayPtr map = dynamic_pointer_cast<Array>(indices->eval());
vector<ExpressionPtr> index = map->getValue();
ArrayPtr map = dynamic_pointer_cast<Array>(indices->eval(env));
vector<int> ind;
for (const auto & it : index)
for (const auto& it : map->getValue())
// Necessary to handle indexes like: y[1:2,2]
// In general this evaluates to [[1:2],2] but when subscripting we want to expand it to [1,2,2]
// In general this evaluates to [[1:2],2] but when subscripting we want to expand it to
// [1,2,2]
if (auto db = dynamic_pointer_cast<Real>(it); db)
{
if (!*(db->isinteger()))
throw StackTrace("variable", "When indexing a variable you must pass "
"an int or an int array", location);
throw StackTrace("variable",
"When indexing a variable you must pass "
"an int or an int array",
location);
ind.emplace_back(*db);
}
else if (dynamic_pointer_cast<Array>(it))
......@@ -652,16 +681,22 @@ Variable::eval()
if (db = dynamic_pointer_cast<Real>(it1); db)
{
if (!*(db->isinteger()))
throw StackTrace("variable", "When indexing a variable you must pass "
"an int or an int array", location);
throw StackTrace("variable",
"When indexing a variable you must pass "
"an int or an int array",
location);
ind.emplace_back(*db);
}
else
throw StackTrace("variable", "You cannot index a variable with a "
"nested array", location);
throw StackTrace("variable",
"You cannot index a variable with a "
"nested array",
location);
else
throw StackTrace("variable", "You can only index a variable with an int or "
"an int array", location);
throw StackTrace("variable",
"You can only index a variable with an int or "
"an int array",
location);
switch (env.getType(name))
{
......@@ -675,8 +710,7 @@ Variable::eval()
throw StackTrace("variable", "Internal Error: Range: should not arrive here", location);
case codes::BaseType::String:
{
string orig_string =
dynamic_pointer_cast<String>(env.getVariable(name))->to_string();
string orig_string = dynamic_pointer_cast<String>(env.getVariable(name))->to_string();
string retvals;
for (auto it : ind)
try
......@@ -687,7 +721,7 @@ Variable::eval()
{
throw StackTrace("variable", "Index out of range", location);
}
return make_shared<String>(retvals, env);
return make_shared<String>(retvals);
}
case codes::BaseType::Array:
{
......@@ -696,7 +730,7 @@ Variable::eval()
for (auto it : ind)
try
{
retval.emplace_back(ap->at(it - 1)->eval());
retval.emplace_back(ap->at(it - 1)->eval(env));
}
catch (const out_of_range& ex)
{
......@@ -706,15 +740,15 @@ Variable::eval()
if (retval.size() == 1)
return retval.at(0);
vector<ExpressionPtr> retvala(retval.begin(), retval.end());
return make_shared<Array>(retvala, env);
return make_shared<Array>(retvala);
}
}
}
return env.getVariable(name)->eval();
return env.getVariable(name)->eval(env);
}
BaseTypePtr
Function::eval()
Function::eval(Environment& env) const
{
FunctionPtr func;
ExpressionPtr body;
......@@ -731,17 +765,19 @@ Function::eval()
}
if (func->args.size() != args.size())
throw StackTrace("Function", "The number of arguments used to call " + name +
" does not match the number used in its definition", location);
throw StackTrace("Function",
"The number of arguments used to call " + name
+ " does not match the number used in its definition",
location);
try
{
for (size_t i = 0; i < func->args.size(); i++)
{
VariablePtr mvp = dynamic_pointer_cast<Variable>(func->args.at(i));
env.define(mvp, args.at(i)->eval());
env.define(mvp, args.at(i)->eval(env));
}
auto retval = body->eval();
auto retval = body->eval(env);
env = env_orig;
return retval;
}
......@@ -753,91 +789,90 @@ Function::eval()
}
BaseTypePtr
UnaryOp::eval()
UnaryOp::eval(Environment& env) const
{
try
{
auto argbt = arg->eval();
switch (op_code)
{
case codes::UnaryOp::cast_bool:
return argbt->cast_bool();
return arg->eval(env)->cast_bool(env);
case codes::UnaryOp::cast_real:
return argbt->cast_real();
return arg->eval(env)->cast_real(env);
case codes::UnaryOp::cast_string:
return argbt->cast_string();
return arg->eval(env)->cast_string();
case codes::UnaryOp::cast_tuple:
return argbt->cast_tuple();
return arg->eval(env)->cast_tuple();
case codes::UnaryOp::cast_array:
return argbt->cast_array();
return arg->eval(env)->cast_array();
case codes::UnaryOp::logical_not:
return argbt->logical_not();
return arg->eval(env)->logical_not();
case codes::UnaryOp::unary_minus:
return argbt->unary_minus();
return arg->eval(env)->unary_minus();
case codes::UnaryOp::unary_plus:
return argbt->unary_plus();
return arg->eval(env)->unary_plus();
case codes::UnaryOp::length:
return argbt->length();
return arg->eval(env)->length();
case codes::UnaryOp::isempty:
return argbt->isempty();
return arg->eval(env)->isempty();
case codes::UnaryOp::isboolean:
return argbt->isboolean();
return arg->eval(env)->isboolean();
case codes::UnaryOp::isreal:
return argbt->isreal();
return arg->eval(env)->isreal();
case codes::UnaryOp::isstring:
return argbt->isstring();
return arg->eval(env)->isstring();
case codes::UnaryOp::istuple:
return argbt->istuple();
return arg->eval(env)->istuple();
case codes::UnaryOp::isarray:
return argbt->isarray();
return arg->eval(env)->isarray();
case codes::UnaryOp::exp:
return argbt->exp();
return arg->eval(env)->exp();
case codes::UnaryOp::ln:
return argbt->ln();
return arg->eval(env)->ln();
case codes::UnaryOp::log10:
return argbt->log10();
return arg->eval(env)->log10();
case codes::UnaryOp::sin:
return argbt->sin();
return arg->eval(env)->sin();
case codes::UnaryOp::cos:
return argbt->cos();
return arg->eval(env)->cos();
case codes::UnaryOp::tan:
return argbt->tan();
return arg->eval(env)->tan();
case codes::UnaryOp::asin:
return argbt->asin();
return arg->eval(env)->asin();
case codes::UnaryOp::acos:
return argbt->acos();
return arg->eval(env)->acos();
case codes::UnaryOp::atan:
return argbt->atan();
return arg->eval(env)->atan();
case codes::UnaryOp::sqrt:
return argbt->sqrt();
return arg->eval(env)->sqrt();
case codes::UnaryOp::cbrt:
return argbt->cbrt();
return arg->eval(env)->cbrt();
case codes::UnaryOp::sign:
return argbt->sign();
return arg->eval(env)->sign();
case codes::UnaryOp::floor:
return argbt->floor();
return arg->eval(env)->floor();
case codes::UnaryOp::ceil:
return argbt->ceil();
return arg->eval(env)->ceil();
case codes::UnaryOp::trunc:
return argbt->trunc();
return arg->eval(env)->trunc();
case codes::UnaryOp::sum:
return argbt->sum();
return arg->eval(env)->sum();
case codes::UnaryOp::erf:
return argbt->erf();
return arg->eval(env)->erf();
case codes::UnaryOp::erfc:
return argbt->erfc();
return arg->eval(env)->erfc();
case codes::UnaryOp::gamma:
return argbt->gamma();
return arg->eval(env)->gamma();
case codes::UnaryOp::lgamma:
return argbt->lgamma();
return arg->eval(env)->lgamma();
case codes::UnaryOp::round:
return argbt->round();
return arg->eval(env)->round();
case codes::UnaryOp::normpdf:
return argbt->normpdf();
return arg->eval(env)->normpdf();
case codes::UnaryOp::normcdf:
return argbt->normcdf();
return arg->eval(env)->normcdf();
case codes::UnaryOp::defined:
return argbt->defined();
return arg->eval(env)->defined(env);
}
}
catch (StackTrace& ex)
......@@ -854,52 +889,50 @@ UnaryOp::eval()
}
BaseTypePtr
BinaryOp::eval()
BinaryOp::eval(Environment& env) const
{
try
{
auto arg1bt = arg1->eval();
auto arg2bt = arg2->eval();
switch (op_code)
{
case codes::BinaryOp::plus:
return arg1bt->plus(arg2bt);
return arg1->eval(env)->plus(arg2->eval(env));
case codes::BinaryOp::minus:
return arg1bt->minus(arg2bt);
return arg1->eval(env)->minus(arg2->eval(env));
case codes::BinaryOp::times:
return arg1bt->times(arg2bt);
return arg1->eval(env)->times(arg2->eval(env));
case codes::BinaryOp::divide:
return arg1bt->divide(arg2bt);
return arg1->eval(env)->divide(arg2->eval(env));
case codes::BinaryOp::power:
return arg1bt->power( arg2bt);
return arg1->eval(env)->power(arg2->eval(env));
case codes::BinaryOp::equal_equal:
return arg1bt->is_equal(arg2bt);
return arg1->eval(env)->is_equal(arg2->eval(env));
case codes::BinaryOp::not_equal:
return arg1bt->is_different(arg2bt);
return arg1->eval(env)->is_different(arg2->eval(env));
case codes::BinaryOp::less:
return arg1bt->is_less(arg2bt);
return arg1->eval(env)->is_less(arg2->eval(env));
case codes::BinaryOp::greater:
return arg1bt->is_greater(arg2bt);
return arg1->eval(env)->is_greater(arg2->eval(env));
case codes::BinaryOp::less_equal:
return arg1bt->is_less_equal(arg2bt);
return arg1->eval(env)->is_less_equal(arg2->eval(env));
case codes::BinaryOp::greater_equal:
return arg1bt->is_greater_equal(arg2bt);
return arg1->eval(env)->is_greater_equal(arg2->eval(env));
case codes::BinaryOp::logical_and:
return arg1bt->logical_and(arg2bt);
return arg1->eval(env)->logical_and(arg2, env);
case codes::BinaryOp::logical_or:
return arg1bt->logical_or(arg2bt);
return arg1->eval(env)->logical_or(arg2, env);
case codes::BinaryOp::in:
return arg2bt->contains(arg1bt);
return arg2->eval(env)->contains(arg1->eval(env));
case codes::BinaryOp::set_union:
return arg1bt->set_union(arg2bt);
return arg1->eval(env)->set_union(arg2->eval(env));
case codes::BinaryOp::set_intersection:
return arg1bt->set_intersection(arg2bt);
return arg1->eval(env)->set_intersection(arg2->eval(env));
case codes::BinaryOp::max:
return arg1bt->max(arg2bt);
return arg1->eval(env)->max(arg2->eval(env));
case codes::BinaryOp::min:
return arg1bt->min(arg2bt);
return arg1->eval(env)->min(arg2->eval(env));
case codes::BinaryOp::mod:
return arg1bt->mod(arg2bt);
return arg1->eval(env)->mod(arg2->eval(env));
}
}
catch (StackTrace& ex)
......@@ -916,19 +949,16 @@ BinaryOp::eval()
}
BaseTypePtr
TrinaryOp::eval()
TrinaryOp::eval(Environment& env) const
{
try
{
auto arg1bt = arg1->eval();
auto arg2bt = arg2->eval();
auto arg3bt = arg3->eval();
switch (op_code)
{
case codes::TrinaryOp::normpdf:
return arg1bt->normpdf(arg2bt, arg3bt);
return arg1->eval(env)->normpdf(arg2->eval(env), arg3->eval(env));
case codes::TrinaryOp::normcdf:
return arg1bt->normcdf(arg2bt, arg3bt);
return arg1->eval(env)->normcdf(arg2->eval(env), arg3->eval(env));
}
}
catch (StackTrace& ex)
......@@ -945,21 +975,23 @@ TrinaryOp::eval()
}
BaseTypePtr
Comprehension::eval()
Comprehension::eval(Environment& env) const
{
ArrayPtr input_set;
VariablePtr vp;
TuplePtr mt;
try
{
input_set = dynamic_pointer_cast<Array>(c_set->eval());
input_set = dynamic_pointer_cast<Array>(c_set->eval(env));
if (!input_set)
throw StackTrace("Comprehension", "The input set must evaluate to an array", location);
vp = dynamic_pointer_cast<Variable>(c_vars);
mt = dynamic_pointer_cast<Tuple>(c_vars);
if ((!vp && !mt) || (vp && mt))
throw StackTrace("Comprehension", "the loop variables must be either "
"a tuple or a variable", location);
throw StackTrace("Comprehension",
"the loop variables must be either "
"a tuple or a variable",
location);
}
catch (StackTrace& ex)
{
......@@ -973,40 +1005,45 @@ Comprehension::eval()
auto btp = dynamic_pointer_cast<BaseType>(input_set->at(i));
if (vp)
env.define(vp, btp);
else
if (btp->getType() == codes::BaseType::Tuple)
else if (btp->getType() == codes::BaseType::Tuple)
{
auto mt2 = dynamic_pointer_cast<Tuple>(btp);
if (mt->size() != mt2->size())
throw StackTrace("Comprehension", "The number of elements in the input "
throw StackTrace("Comprehension",
"The number of elements in the input "
" set tuple are not the same as the number of elements in "
"the output expression tuple", location);
"the output expression tuple",
location);
for (size_t j = 0; j < mt->size(); j++)
{
auto vp2 = dynamic_pointer_cast<Variable>(mt->at(j));
if (!vp2)
throw StackTrace("Comprehension", "Output expression tuple must be "
"comprised of variable names", location);
throw StackTrace("Comprehension",
"Output expression tuple must be "
"comprised of variable names",
location);
env.define(vp2, mt2->at(j));
}
}
else
throw StackTrace("Comprehension", "assigning to tuple in output expression "
"but input expression does not contain tuples", location);
throw StackTrace("Comprehension",
"assigning to tuple in output expression "
"but input expression does not contain tuples",
location);
if (!c_when)
if (!c_expr)
throw StackTrace("Comprehension", "Internal Error: Impossible case", location);
else
values.emplace_back(c_expr->clone()->eval());
values.emplace_back(c_expr->eval(env));
else
{
RealPtr dp;
BoolPtr bp;
try
{
auto tmp = c_when->eval();
auto tmp = c_when->eval(env);
dp = dynamic_pointer_cast<Real>(tmp);
bp = dynamic_pointer_cast<Bool>(tmp);
if (!bp && !dp)
......@@ -1018,51 +1055,15 @@ Comprehension::eval()
throw;
}
if ((bp && *bp) || (dp && *dp))
{
if (c_expr)
values.emplace_back(c_expr->clone()->eval());
values.emplace_back(c_expr->eval(env));
else
values.emplace_back(btp);
}
}
return make_shared<Array>(values, env);
}
ExpressionPtr
Tuple::clone() const noexcept
{
vector<ExpressionPtr> tup_copy;
for (const auto & it : tup)
tup_copy.emplace_back(it->clone());
return make_shared<Tuple>(tup_copy, env, location);
}
ExpressionPtr
Array::clone() const noexcept
{
vector<ExpressionPtr> arr_copy;
for (const auto & it : arr)
arr_copy.emplace_back(it->clone());
return make_shared<Array>(arr_copy, env, location);
}
ExpressionPtr
Function::clone() const noexcept
{
vector<ExpressionPtr> args_copy;
for (const auto & it : args)
args_copy.emplace_back(it->clone());
return make_shared<Function>(name, args_copy, env, location);
}
ExpressionPtr
Comprehension::clone() const noexcept
{
if (c_expr && c_when)
return make_shared<Comprehension>(c_expr->clone(), c_vars->clone(), c_set->clone(), c_when->clone(), env, location);
else if (c_expr)
return make_shared<Comprehension>(c_expr->clone(), c_vars->clone(), c_set->clone(), env, location);
else
return make_shared<Comprehension>(true, c_vars->clone(), c_set->clone(), c_when->clone(), env, location);
return make_shared<Array>(values);
}
string
......@@ -1251,9 +1252,11 @@ TrinaryOp::to_string() const noexcept
switch (op_code)
{
case codes::TrinaryOp::normpdf:
return "normpdf(" + arg1->to_string() + ", " + arg2->to_string() + ", " + arg3->to_string() + ")";
return "normpdf(" + arg1->to_string() + ", " + arg2->to_string() + ", " + arg3->to_string()
+ ")";
case codes::TrinaryOp::normcdf:
return "normcdf(" + arg1->to_string() + ", " + arg2->to_string() + ", " + arg3->to_string() + ")";
return "normcdf(" + arg1->to_string() + ", " + arg2->to_string() + ", " + arg3->to_string()
+ ")";
}
// Suppress GCC warning
exit(EXIT_FAILURE);
......@@ -1274,20 +1277,18 @@ Comprehension::to_string() const noexcept
void
String::print(ostream& output, bool matlab_output) const noexcept
{
output << (matlab_output ? "'" : R"(")")
<< value
<< (matlab_output ? "'" : R"(")");
output << (matlab_output ? "'" : R"(")") << value << (matlab_output ? "'" : R"(")");
}
void
Array::print(ostream& output, bool matlab_output) const noexcept
{
output << (matlab_output ? "{" : "[");
for (auto it = arr.begin(); it != arr.end(); it++)
for (bool printed_something {false}; const auto& e : arr)
{
if (it != arr.begin())
if (exchange(printed_something, true))
output << ", ";
(*it)->print(output, matlab_output);
e->print(output, matlab_output);
}
output << (matlab_output ? "}" : "]");
}
......@@ -1296,11 +1297,11 @@ void
Tuple::print(ostream& output, bool matlab_output) const noexcept
{
output << (matlab_output ? "{" : "(");
for (auto it = tup.begin(); it != tup.end(); it++)
for (bool printed_something {false}; const auto& e : tup)
{
if (it != tup.begin())
if (exchange(printed_something, true))
output << ", ";
(*it)->print(output, matlab_output);
e->print(output, matlab_output);
}
output << (matlab_output ? "}" : ")");
}
......@@ -1309,11 +1310,11 @@ void
Function::printArgs(ostream& output) const noexcept
{
output << "(";
for (auto it = args.begin(); it != args.end(); it++)
for (bool printed_something {false}; const auto& e : args)
{
if (it != args.begin())
if (exchange(printed_something, true))
output << ", ";
(*it)->print(output);
e->print(output);
}
output << ")";
}
......@@ -1444,24 +1445,18 @@ UnaryOp::print(ostream &output, bool matlab_output) const noexcept
arg->print(output, matlab_output);
if (op_code != codes::UnaryOp::cast_bool
&& op_code != codes::UnaryOp::cast_real
&& op_code != codes::UnaryOp::cast_string
&& op_code != codes::UnaryOp::cast_tuple
&& op_code != codes::UnaryOp::cast_array
&& op_code != codes::UnaryOp::logical_not
&& op_code != codes::UnaryOp::unary_plus
&& op_code != codes::UnaryOp::unary_minus)
if (op_code != codes::UnaryOp::cast_bool && op_code != codes::UnaryOp::cast_real
&& op_code != codes::UnaryOp::cast_string && op_code != codes::UnaryOp::cast_tuple
&& op_code != codes::UnaryOp::cast_array && op_code != codes::UnaryOp::logical_not
&& op_code != codes::UnaryOp::unary_plus && op_code != codes::UnaryOp::unary_minus)
output << ")";
}
void
BinaryOp::print(ostream& output, bool matlab_output) const noexcept
{
if (op_code == codes::BinaryOp::set_union
|| op_code == codes::BinaryOp::set_intersection
|| op_code == codes::BinaryOp::max
|| op_code == codes::BinaryOp::min
if (op_code == codes::BinaryOp::set_union || op_code == codes::BinaryOp::set_intersection
|| op_code == codes::BinaryOp::max || op_code == codes::BinaryOp::min
|| op_code == codes::BinaryOp::mod)
{
switch (op_code)
......@@ -1588,4 +1583,3 @@ Comprehension::print(ostream &output, bool matlab_output) const noexcept
}
output << "]";
}