From 49277dbbf2c7c92af04a95ae35feda4bafdc84d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org> Date: Wed, 8 Jun 2022 12:45:33 +0200 Subject: [PATCH] StaticModel::getDerivID() now throws an exception when arg is not endo or parameter Previously it would return -1, which is bad practice. --- src/ExprNode.cc | 25 ++++++++++++++++++------- src/StaticModel.cc | 4 +++- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 09c4b12a..5f837160 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -767,13 +767,16 @@ VariableNode::prepareForDerivation() // Fill in non_null_derivatives switch (get_type()) { - case SymbolType::endogenous: case SymbolType::exogenous: case SymbolType::exogenousDet: - case SymbolType::parameter: case SymbolType::trend: case SymbolType::logTrend: - // For a variable or a parameter, the only non-null derivative is with respect to itself + // In static models, exogenous and trends do not have deriv IDs + if (dynamic_cast<StaticModel *>(&datatree)) + break; + [[fallthrough]]; + case SymbolType::endogenous: + case SymbolType::parameter: non_null_derivatives.insert(datatree.getDerivID(symb_id, lag)); break; case SymbolType::modelLocalVariable: @@ -803,12 +806,16 @@ VariableNode::computeDerivative(int deriv_id) { switch (get_type()) { - case SymbolType::endogenous: case SymbolType::exogenous: case SymbolType::exogenousDet: - case SymbolType::parameter: case SymbolType::trend: case SymbolType::logTrend: + // In static models, exogenous and trends do not have deriv IDs + if (dynamic_cast<StaticModel *>(&datatree)) + return datatree.Zero; + [[fallthrough]]; + case SymbolType::endogenous: + case SymbolType::parameter: if (deriv_id == datatree.getDerivID(symb_id, lag)) return datatree.One; else @@ -1338,12 +1345,16 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode * { switch (get_type()) { - case SymbolType::endogenous: case SymbolType::exogenous: case SymbolType::exogenousDet: - case SymbolType::parameter: case SymbolType::trend: case SymbolType::logTrend: + // In static models, exogenous and trends do not have deriv IDs + if (dynamic_cast<StaticModel *>(&datatree)) + return datatree.Zero; + [[fallthrough]]; + case SymbolType::endogenous: + case SymbolType::parameter: if (deriv_id == datatree.getDerivID(symb_id, lag)) return datatree.One; // If there is in the equation a recursive variable we could use a chaine rule derivation diff --git a/src/StaticModel.cc b/src/StaticModel.cc index 68352c32..534fd3dc 100644 --- a/src/StaticModel.cc +++ b/src/StaticModel.cc @@ -2051,7 +2051,9 @@ StaticModel::getDerivID(int symb_id, int lag) const noexcept(false) else if (symbol_table.getType(symb_id) == SymbolType::parameter) return symbol_table.getTypeSpecificID(symb_id) + symbol_table.endo_nbr(); else - return -1; + /* See the special treatment in VariableNode::prepareForDerivation(), + VariableNode::computeDerivative() and VariableNode::getChainRuleDerivative() */ + throw UnknownDerivIDException{}; } void -- GitLab