Commit 3f26933f authored by Sébastien Villemot's avatar Sébastien Villemot
Browse files

Replace ExternalFunctionSetOrNot enum by integer constants

This was not conceptually an enum, but rather a collection of unrelated
constants:
- two constants for use as placeholder for symbol IDs at different places
- one constant for the default number of arguments
parent 5e6b8f0a
......@@ -7515,7 +7515,7 @@ ExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsign
deriv_node_temp_terms_t &tef_terms) const
{
int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
assert(first_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
for (auto argument : arguments)
argument->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms,
......@@ -7526,7 +7526,7 @@ ExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsign
tef_terms[{ symb_id, arguments }] = (int) tef_terms.size();
int indx = getIndxInTefTerms(symb_id, tef_terms);
int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
unsigned int nb_output_arguments = 0;
if (symb_id == first_deriv_symb_id
......@@ -7618,7 +7618,7 @@ ExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutpu
deriv_node_temp_terms_t &tef_terms) const
{
int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
assert(first_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
for (auto argument : arguments)
argument->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
......@@ -7628,7 +7628,7 @@ ExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutpu
tef_terms[{ symb_id, arguments }] = (int) tef_terms.size();
int indx = getIndxInTefTerms(symb_id, tef_terms);
int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
if (isCOutput(output_type))
{
......@@ -7695,7 +7695,7 @@ ExternalFunctionNode::writeJsonExternalFunctionOutput(vector<string> &efout,
const bool isdynamic) const
{
int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
assert(first_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
for (auto argument : arguments)
argument->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
......@@ -7705,7 +7705,7 @@ ExternalFunctionNode::writeJsonExternalFunctionOutput(vector<string> &efout,
tef_terms[{ symb_id, arguments }] = (int) tef_terms.size();
int indx = getIndxInTefTerms(symb_id, tef_terms);
int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
stringstream ef;
ef << "{\"external_function\": {"
......@@ -7826,14 +7826,14 @@ FirstDerivExternalFunctionNode::writeJsonOutput(ostream &output,
}
const int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
assert(first_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
const int tmpIndx = inputIndex - 1;
if (first_deriv_symb_id == symb_id)
output << "TEFD_" << getIndxInTefTerms(symb_id, tef_terms)
<< "[" << tmpIndx << "]";
else if (first_deriv_symb_id == eExtFunNotSet)
else if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex;
else
output << "TEFD_def_" << getIndxInTefTerms(first_deriv_symb_id, tef_terms)
......@@ -7861,14 +7861,14 @@ FirstDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType
return;
const int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
assert(first_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
const int tmpIndx = inputIndex - 1 + ARRAY_SUBSCRIPT_OFFSET(output_type);
if (first_deriv_symb_id == symb_id)
output << "TEFD_" << getIndxInTefTerms(symb_id, tef_terms)
<< LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndx << RIGHT_ARRAY_SUBSCRIPT(output_type);
else if (first_deriv_symb_id == eExtFunNotSet)
else if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
{
if (isCOutput(output_type))
output << "*";
......@@ -7903,7 +7903,7 @@ FirstDerivExternalFunctionNode::compile(ostream &CompileCode, unsigned int &inst
return;
}
int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
assert(first_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
if (!lhs_rhs)
{
......@@ -7925,7 +7925,7 @@ FirstDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Exp
{
assert(output_type != ExprNodeOutputType::matlabOutsideModel);
int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
assert(first_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
/* For a node with derivs provided by the user function, call the method
on the non-derived node */
......@@ -7941,7 +7941,7 @@ FirstDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Exp
return;
if (isCOutput(output_type))
if (first_deriv_symb_id == eExtFunNotSet)
if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
{
stringstream ending;
ending << "_tefd_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex;
......@@ -8002,7 +8002,7 @@ FirstDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Exp
}
else
{
if (first_deriv_symb_id == eExtFunNotSet)
if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << " = jacob_element('"
<< datatree.symbol_table.getName(symb_id) << "'," << inputIndex << ",{";
else
......@@ -8014,7 +8014,7 @@ FirstDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Exp
writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
if (first_deriv_symb_id == eExtFunNotSet)
if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
output << "}";
output << ");" << endl;
}
......@@ -8027,7 +8027,7 @@ FirstDerivExternalFunctionNode::writeJsonExternalFunctionOutput(vector<string> &
const bool isdynamic) const
{
int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
assert(first_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
/* For a node with derivs provided by the user function, call the method
on the non-derived node */
......@@ -8042,7 +8042,7 @@ FirstDerivExternalFunctionNode::writeJsonExternalFunctionOutput(vector<string> &
return;
stringstream ef;
if (first_deriv_symb_id == eExtFunNotSet)
if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
ef << "{\"first_deriv_external_function\": {"
<< "\"external_function_term\": \"TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << "\""
<< ", \"analytic_derivative\": false"
......@@ -8069,14 +8069,14 @@ FirstDerivExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCo
deriv_node_temp_terms_t &tef_terms) const
{
int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
assert(first_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
if (first_deriv_symb_id == symb_id || alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms))
return;
unsigned int nb_add_input_arguments = compileExternalFunctionArguments(CompileCode, instruction_number, lhs_rhs, temporary_terms,
map_idx, dynamic, steady_dynamic, tef_terms);
if (first_deriv_symb_id == eExtFunNotSet)
if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
{
unsigned int nb_input_arguments = 0;
unsigned int nb_output_arguments = 1;
......@@ -8095,7 +8095,7 @@ FirstDerivExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCo
tef_terms[{ first_deriv_symb_id, arguments }] = (int) tef_terms.size();
int indx = getIndxInTefTerms(symb_id, tef_terms);
int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
unsigned int nb_output_arguments = 1;
......@@ -8215,7 +8215,7 @@ SecondDerivExternalFunctionNode::writeJsonOutput(ostream &output,
}
const int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
const int tmpIndex1 = inputIndex1 - 1;
const int tmpIndex2 = inputIndex2 - 1;
......@@ -8223,7 +8223,7 @@ SecondDerivExternalFunctionNode::writeJsonOutput(ostream &output,
if (second_deriv_symb_id == symb_id)
output << "TEFDD_" << getIndxInTefTerms(symb_id, tef_terms)
<< "[" << tmpIndex1 << "," << tmpIndex2 << "]";
else if (second_deriv_symb_id == eExtFunNotSet)
else if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2;
else
output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms)
......@@ -8251,7 +8251,7 @@ SecondDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType
return;
const int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
const int tmpIndex1 = inputIndex1 - 1 + ARRAY_SUBSCRIPT_OFFSET(output_type);
const int tmpIndex2 = inputIndex2 - 1 + ARRAY_SUBSCRIPT_OFFSET(output_type);
......@@ -8265,7 +8265,7 @@ SecondDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType
else
output << "TEFDD_" << getIndxInTefTerms(symb_id, tef_terms)
<< LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << "," << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type);
else if (second_deriv_symb_id == eExtFunNotSet)
else if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
{
if (isCOutput(output_type))
output << "*";
......@@ -8289,7 +8289,7 @@ SecondDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Ex
{
assert(output_type != ExprNodeOutputType::matlabOutsideModel);
int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
/* For a node with derivs provided by the user function, call the method
on the non-derived node */
......@@ -8305,7 +8305,7 @@ SecondDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Ex
return;
if (isCOutput(output_type))
if (second_deriv_symb_id == eExtFunNotSet)
if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
{
stringstream ending;
ending << "_tefdd_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2;
......@@ -8368,7 +8368,7 @@ SecondDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Ex
}
else
{
if (second_deriv_symb_id == eExtFunNotSet)
if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2
<< " = hess_element('" << datatree.symbol_table.getName(symb_id) << "',"
<< inputIndex1 << "," << inputIndex2 << ",{";
......@@ -8381,7 +8381,7 @@ SecondDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, Ex
writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
if (second_deriv_symb_id == eExtFunNotSet)
if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
output << "}";
output << ");" << endl;
}
......@@ -8394,7 +8394,7 @@ SecondDerivExternalFunctionNode::writeJsonExternalFunctionOutput(vector<string>
const bool isdynamic) const
{
int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
assert(second_deriv_symb_id != eExtFunSetButNoNameProvided);
assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
/* For a node with derivs provided by the user function, call the method
on the non-derived node */
......@@ -8409,7 +8409,7 @@ SecondDerivExternalFunctionNode::writeJsonExternalFunctionOutput(vector<string>
return;
stringstream ef;
if (second_deriv_symb_id == eExtFunNotSet)
if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
ef << "{\"second_deriv_external_function\": {"
<< "\"external_function_term\": \"TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 << "\""
<< ", \"analytic_derivative\": false"
......
......@@ -37,14 +37,14 @@ ExternalFunctionsTable::addExternalFunction(int symb_id, const external_function
// Change options to be saved so the table is consistent
external_function_options external_function_options_chng = external_function_options_arg;
if (external_function_options_arg.firstDerivSymbID == eExtFunSetButNoNameProvided)
if (external_function_options_arg.firstDerivSymbID == IDSetButNoNameProvided)
external_function_options_chng.firstDerivSymbID = symb_id;
if (external_function_options_arg.secondDerivSymbID == eExtFunSetButNoNameProvided)
if (external_function_options_arg.secondDerivSymbID == IDSetButNoNameProvided)
external_function_options_chng.secondDerivSymbID = symb_id;
if (!track_nargs)
external_function_options_chng.nargs = eExtFunNotSet;
external_function_options_chng.nargs = IDNotSet;
// Ensure 1st & 2nd deriv option consistency
if (external_function_options_chng.secondDerivSymbID == symb_id
......@@ -57,15 +57,15 @@ ExternalFunctionsTable::addExternalFunction(int symb_id, const external_function
if ((external_function_options_chng.secondDerivSymbID != symb_id
&& external_function_options_chng.firstDerivSymbID == symb_id)
&& external_function_options_chng.secondDerivSymbID != eExtFunNotSet)
&& external_function_options_chng.secondDerivSymbID != IDNotSet)
{
cerr << "ERROR: If the first derivative is provided by the top-level function, the "
<< "second derivative cannot be provided by any other external function." << endl;
exit(EXIT_FAILURE);
}
if (external_function_options_chng.secondDerivSymbID != eExtFunNotSet
&& external_function_options_chng.firstDerivSymbID == eExtFunNotSet)
if (external_function_options_chng.secondDerivSymbID != IDNotSet
&& external_function_options_chng.firstDerivSymbID == IDNotSet)
{
cerr << "ERROR: If the second derivative is provided, the first derivative must also be provided." << endl;
exit(EXIT_FAILURE);
......@@ -73,7 +73,7 @@ ExternalFunctionsTable::addExternalFunction(int symb_id, const external_function
if (external_function_options_chng.secondDerivSymbID == external_function_options_chng.firstDerivSymbID
&& external_function_options_chng.firstDerivSymbID != symb_id
&& external_function_options_chng.firstDerivSymbID != eExtFunNotSet)
&& external_function_options_chng.firstDerivSymbID != IDNotSet)
{
cerr << "ERROR: If the Jacobian and Hessian are provided by the same function, that "
<< "function must be the top-level function." << endl;
......@@ -84,7 +84,7 @@ ExternalFunctionsTable::addExternalFunction(int symb_id, const external_function
if (exists(symb_id))
{
bool ok_to_overwrite = false;
if (getNargs(symb_id) == eExtFunNotSet) // implies that the information stored about this function is not important
if (getNargs(symb_id) == IDNotSet) // implies that the information stored about this function is not important
ok_to_overwrite = true;
if (!ok_to_overwrite) // prevents multiple non-compatible calls to external_function(name=funcname)
......
......@@ -27,13 +27,6 @@ using namespace std;
#include <vector>
#include <map>
enum ExternalFunctionSetOrNot
{
eExtFunSetButNoNameProvided = -2, //! Signifies that the derivative is obtained from the top-level function
eExtFunNotSet = -1, //! Signifies that no external function exists that calculates the derivative
eExtFunSetDefaultNargs = 1 //! This is the default number of arguments when nargs is not specified
};
//! Handles external functions
class ExternalFunctionsTable
{
......@@ -59,6 +52,12 @@ public:
int nargs, firstDerivSymbID, secondDerivSymbID;
};
using external_function_table_type = map<int, external_function_options>;
//! Symbol ID used when no external function exists that calculates the derivative
const static int IDNotSet = -1;
//! Symbol ID used when the derivative is obtained from the top-level function
const static int IDSetButNoNameProvided = -2;
//! Default number of arguments when nargs is not specified
const static int defaultNargs = 1;
private:
//! Map containing options provided to external_functions()
external_function_table_type externalFunctionTable;
......
......@@ -80,10 +80,10 @@ ParsingDriver::reset_data_tree()
void
ParsingDriver::reset_current_external_function_options()
{
current_external_function_options.nargs = eExtFunSetDefaultNargs;
current_external_function_options.firstDerivSymbID = eExtFunNotSet;
current_external_function_options.secondDerivSymbID = eExtFunNotSet;
current_external_function_id = eExtFunNotSet;
current_external_function_options.nargs = ExternalFunctionsTable::defaultNargs;
current_external_function_options.firstDerivSymbID = ExternalFunctionsTable::IDNotSet;
current_external_function_options.secondDerivSymbID = ExternalFunctionsTable::IDNotSet;
current_external_function_id = ExternalFunctionsTable::IDNotSet;
}
unique_ptr<ModFile>
......@@ -2858,7 +2858,7 @@ ParsingDriver::external_function_option(const string &name_option, const string
else if (name_option == "first_deriv_provided")
{
if (opt.empty())
current_external_function_options.firstDerivSymbID = eExtFunSetButNoNameProvided;
current_external_function_options.firstDerivSymbID = ExternalFunctionsTable::IDSetButNoNameProvided;
else
{
declare_symbol(opt, SymbolType::externalFunction, "", {});
......@@ -2868,7 +2868,7 @@ ParsingDriver::external_function_option(const string &name_option, const string
else if (name_option == "second_deriv_provided")
{
if (opt.empty())
current_external_function_options.secondDerivSymbID = eExtFunSetButNoNameProvided;
current_external_function_options.secondDerivSymbID = ExternalFunctionsTable::IDSetButNoNameProvided;
else
{
declare_symbol(opt, SymbolType::externalFunction, "", {});
......@@ -2884,15 +2884,15 @@ ParsingDriver::external_function_option(const string &name_option, const string
void
ParsingDriver::external_function()
{
if (current_external_function_id == eExtFunNotSet)
if (current_external_function_id == ExternalFunctionsTable::IDNotSet)
error("The 'name' option must be passed to external_function().");
if (current_external_function_options.secondDerivSymbID >= 0
&& current_external_function_options.firstDerivSymbID == eExtFunNotSet)
&& current_external_function_options.firstDerivSymbID == ExternalFunctionsTable::IDNotSet)
error("If the second derivative is provided to the external_function command, the first derivative must also be provided.");
if (current_external_function_options.secondDerivSymbID == eExtFunSetButNoNameProvided
&& current_external_function_options.firstDerivSymbID != eExtFunSetButNoNameProvided)
if (current_external_function_options.secondDerivSymbID == ExternalFunctionsTable::IDSetButNoNameProvided
&& current_external_function_options.firstDerivSymbID != ExternalFunctionsTable::IDSetButNoNameProvided)
error("If the second derivative is provided in the top-level function, the first derivative must also be provided in that function.");
mod_file->external_functions_table.addExternalFunction(current_external_function_id, current_external_function_options, true);
......@@ -2992,7 +2992,7 @@ ParsingDriver::add_model_var_or_external_function(const string &function_name, b
error("Using a derivative of an external function (" + function_name + ") in the model block is currently not allowed.");
if (in_model_block || parsing_epilogue)
if (mod_file->external_functions_table.getNargs(symb_id) == eExtFunNotSet)
if (mod_file->external_functions_table.getNargs(symb_id) == ExternalFunctionsTable::IDNotSet)
error("Before using " + function_name
+"() in the model block, you must first declare it via the external_function() statement");
else if ((int) (stack_external_function_args.top().size()) != mod_file->external_functions_table.getNargs(symb_id))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment