From cec47cc78c9c1aa94c2416f1a86af6f0dc00d207 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org> Date: Fri, 20 May 2022 12:35:38 +0200 Subject: [PATCH] Implement bytecode compilation of 2nd deriv of external functions --- src/ExprNode.cc | 80 ++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 76 insertions(+), 4 deletions(-) diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 365bebff..5136d14b 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -8135,8 +8135,35 @@ SecondDerivExternalFunctionNode::compile(ostream &CompileCode, unsigned int &ins const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, const deriv_node_temp_terms_t &tef_terms) const { - cerr << "SecondDerivExternalFunctionNode::compile: not implemented." << endl; - exit(EXIT_FAILURE); + if (auto this2 = const_cast<SecondDerivExternalFunctionNode *>(this); + temporary_terms.contains(this2)) + { + if (dynamic) + { + FLDT_ fldt(temporary_terms_idxs.at(this2)); + fldt.write(CompileCode, instruction_number); + } + else + { + FLDST_ fldst(temporary_terms_idxs.at(this2)); + fldst.write(CompileCode, instruction_number); + } + return; + } + + int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id); + assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided); + + if (!lhs_rhs) + { + FLDTEFDD_ fldtefdd(getIndxInTefTerms(symb_id, tef_terms), inputIndex1, inputIndex2); + fldtefdd.write(CompileCode, instruction_number); + } + else + { + FSTPTEFDD_ fstptefdd(getIndxInTefTerms(symb_id, tef_terms), inputIndex1, inputIndex2); + fstptefdd.write(CompileCode, instruction_number); + } } void @@ -8145,8 +8172,53 @@ SecondDerivExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileC const temporary_terms_idxs_t &temporary_terms_idxs, bool dynamic, bool steady_dynamic, deriv_node_temp_terms_t &tef_terms) const { - cerr << "SecondDerivExternalFunctionNode::compileExternalFunctionOutput: not implemented." << endl; - exit(EXIT_FAILURE); + int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id); + assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided); + + /* For a node with derivs provided by the user function, call the method + on the non-derived node */ + if (second_deriv_symb_id == symb_id) + { + expr_t parent = datatree.AddExternalFunction(symb_id, arguments); + parent->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, + temporary_terms, temporary_terms_idxs, + dynamic, steady_dynamic, tef_terms); + return; + } + + if (alreadyWrittenAsTefTerm(second_deriv_symb_id, tef_terms)) + return; + + unsigned int nb_add_input_arguments = compileExternalFunctionArguments(CompileCode, instruction_number, lhs_rhs, temporary_terms, + temporary_terms_idxs, dynamic, steady_dynamic, tef_terms); + if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet) + { + unsigned int nb_input_arguments = 0; + unsigned int nb_output_arguments = 1; + int indx = getIndxInTefTerms(symb_id, tef_terms); + FCALL_ fcall(nb_output_arguments, nb_input_arguments, "hess_element", indx); + fcall.set_arg_func_name(datatree.symbol_table.getName(symb_id)); + fcall.set_row(inputIndex1); + fcall.set_col(inputIndex2); + fcall.set_nb_add_input_arguments(nb_add_input_arguments); + fcall.set_function_type(ExternalFunctionType::numericalSecondDerivative); + fcall.write(CompileCode, instruction_number); + FSTPTEFDD_ fstptefdd(indx, inputIndex1, inputIndex2); + fstptefdd.write(CompileCode, instruction_number); + } + else + { + tef_terms[{ second_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size()); + int indx = getIndxInTefTerms(symb_id, tef_terms); + + unsigned int nb_output_arguments = 1; + + FCALL_ fcall(nb_output_arguments, nb_add_input_arguments, datatree.symbol_table.getName(second_deriv_symb_id), indx); + fcall.set_function_type(ExternalFunctionType::secondDerivative); + fcall.write(CompileCode, instruction_number); + FSTPTEFDD_ fstptefdd(indx, inputIndex1, inputIndex2); + fstptefdd.write(CompileCode, instruction_number); + } } function<bool (expr_t)> -- GitLab