diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 09c4b12a80d444a4c55a5dab2b724ab652ccae2d..5f8371602cb2dce70ecbe95afda521546ed23c03 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -767,13 +767,16 @@ VariableNode::prepareForDerivation() // Fill in non_null_derivatives switch (get_type()) { - case SymbolType::endogenous: case SymbolType::exogenous: case SymbolType::exogenousDet: - case SymbolType::parameter: case SymbolType::trend: case SymbolType::logTrend: - // For a variable or a parameter, the only non-null derivative is with respect to itself + // In static models, exogenous and trends do not have deriv IDs + if (dynamic_cast<StaticModel *>(&datatree)) + break; + [[fallthrough]]; + case SymbolType::endogenous: + case SymbolType::parameter: non_null_derivatives.insert(datatree.getDerivID(symb_id, lag)); break; case SymbolType::modelLocalVariable: @@ -803,12 +806,16 @@ VariableNode::computeDerivative(int deriv_id) { switch (get_type()) { - case SymbolType::endogenous: case SymbolType::exogenous: case SymbolType::exogenousDet: - case SymbolType::parameter: case SymbolType::trend: case SymbolType::logTrend: + // In static models, exogenous and trends do not have deriv IDs + if (dynamic_cast<StaticModel *>(&datatree)) + return datatree.Zero; + [[fallthrough]]; + case SymbolType::endogenous: + case SymbolType::parameter: if (deriv_id == datatree.getDerivID(symb_id, lag)) return datatree.One; else @@ -1338,12 +1345,16 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode * { switch (get_type()) { - case SymbolType::endogenous: case SymbolType::exogenous: case SymbolType::exogenousDet: - case SymbolType::parameter: case SymbolType::trend: case SymbolType::logTrend: + // In static models, exogenous and trends do not have deriv IDs + if (dynamic_cast<StaticModel *>(&datatree)) + return datatree.Zero; + [[fallthrough]]; + case SymbolType::endogenous: + case SymbolType::parameter: if (deriv_id == datatree.getDerivID(symb_id, lag)) return datatree.One; // If there is in the equation a recursive variable we could use a chaine rule derivation diff --git a/src/StaticModel.cc b/src/StaticModel.cc index 68352c32568a9418e2b31a17377cec79e36833dc..534fd3dc704559939bc2a81f0523e063bcd750ad 100644 --- a/src/StaticModel.cc +++ b/src/StaticModel.cc @@ -2051,7 +2051,9 @@ StaticModel::getDerivID(int symb_id, int lag) const noexcept(false) else if (symbol_table.getType(symb_id) == SymbolType::parameter) return symbol_table.getTypeSpecificID(symb_id) + symbol_table.endo_nbr(); else - return -1; + /* See the special treatment in VariableNode::prepareForDerivation(), + VariableNode::computeDerivative() and VariableNode::getChainRuleDerivative() */ + throw UnknownDerivIDException{}; } void