diff --git a/src/CodeInterpreter.hh b/src/CodeInterpreter.hh index de7cf9dcfd1d20c82aad4dfbf72e99ad8df917ce..46bcc1a27381043a1ab784fa16bebc5f61ca73c5 100644 --- a/src/CodeInterpreter.hh +++ b/src/CodeInterpreter.hh @@ -268,6 +268,33 @@ enum NodeTreeReference eHessianParamsDeriv = 8 }; +/*! Lists elements of the NodeTreeReference enum that come “before” the argument. + Used in AbstractExternalFunctionNode::computeTemporaryTerms */ +inline auto +nodeTreeReferencesBefore(NodeTreeReference tr) +{ + vector<NodeTreeReference> v; + + // Should be same order as the one appearing in ModelTree::computeTemporaryTerms() + for (auto tr2 : { eResiduals, eFirstDeriv, eSecondDeriv, eThirdDeriv }) + if (tr == tr2) + return v; + else + v.push_back(tr2); + v.clear(); + + // Should be same order as the one appearing in ModelTree::computeParamsDerivativesTemporaryTerms() + for (auto tr2 : { eResidualsParamsDeriv, eJacobianParamsDeriv, eResidualsParamsSecondDeriv, + eJacobianParamsSecondDeriv, eHessianParamsDeriv}) + if (tr == tr2) + return v; + else + v.push_back(tr2); + + cerr << "nodeTreeReferencesBelow: impossible case" << endl; + exit(EXIT_FAILURE); +} + struct Block_contain_type { int Equation, Variable, Own_Derivative; diff --git a/src/ExprNode.cc b/src/ExprNode.cc index 4f9ebdf7f401a3b965092553d7c8f34c3bfa13bb..0dc1a6f2b14455483d6d134f26775c9e3eb139f6 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -6242,6 +6242,36 @@ AbstractExternalFunctionNode::getIndxInTefTerms(int the_symb_id, const deriv_nod throw UnknownFunctionNameAndArgs(); } +void +AbstractExternalFunctionNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference> > &reference_count, + map<NodeTreeReference, temporary_terms_t> &temp_terms_map, + bool is_matlab, NodeTreeReference tr) const +{ + /* All external function nodes are declared as temporary terms. + + Given that temporary terms are separated in several functions (residuals, + jacobian, …), we must make sure that all temporary terms derived from a + given external function call are assigned just after that call. + + As a consequence, we need to “promote” some terms to a previous level (in + the sense that residuals come before jacobian), if a temporary term + corresponding to the same external function call is present in that + previous level. */ + + for (auto tr2 : nodeTreeReferencesBefore(tr)) + { + auto it = find_if(temp_terms_map[tr2].cbegin(), temp_terms_map[tr2].cend(), + sameTefTermPredicate()); + if (it != temp_terms_map[tr2].cend()) + { + temp_terms_map[tr2].insert(const_cast<AbstractExternalFunctionNode *>(this)); + return; + } + } + + temp_terms_map[tr].insert(const_cast<AbstractExternalFunctionNode *>(this)); +} + bool AbstractExternalFunctionNode::isNumConstNodeEqualTo(double value) const { @@ -6459,14 +6489,6 @@ ExternalFunctionNode::composeDerivatives(const vector<expr_t> &dargs) return theDeriv; } -void -ExternalFunctionNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference> > &reference_count, - map<NodeTreeReference, temporary_terms_t> &temp_terms_map, - bool is_matlab, NodeTreeReference tr) const -{ - temp_terms_map[tr].insert(const_cast<ExternalFunctionNode *>(this)); -} - void ExternalFunctionNode::computeTemporaryTerms(map<expr_t, int> &reference_count, temporary_terms_t &temporary_terms, @@ -6762,6 +6784,15 @@ ExternalFunctionNode::buildSimilarExternalFunctionNode(vector<expr_t> &alt_args, return alt_datatree.AddExternalFunction(symb_id, alt_args); } +function<bool (expr_t)> +ExternalFunctionNode::sameTefTermPredicate() const +{ + return [this](expr_t e) { + auto e2 = dynamic_cast<ExternalFunctionNode *>(e); + return (e2 != nullptr && e2->symb_id == symb_id); + }; +} + FirstDerivExternalFunctionNode::FirstDerivExternalFunctionNode(DataTree &datatree_arg, int top_level_symb_id_arg, const vector<expr_t> &arguments_arg, @@ -6773,14 +6804,6 @@ FirstDerivExternalFunctionNode::FirstDerivExternalFunctionNode(DataTree &datatre datatree.first_deriv_external_function_node_map[make_pair(make_pair(arguments, inputIndex), symb_id)] = this; } -void -FirstDerivExternalFunctionNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference> > &reference_count, - map<NodeTreeReference, temporary_terms_t> &temp_terms_map, - bool is_matlab, NodeTreeReference tr) const -{ - temp_terms_map[tr].insert(const_cast<FirstDerivExternalFunctionNode *>(this)); -} - void FirstDerivExternalFunctionNode::computeTemporaryTerms(map<expr_t, int> &reference_count, temporary_terms_t &temporary_terms, @@ -7142,6 +7165,22 @@ FirstDerivExternalFunctionNode::computeXrefs(EquationInfo &ei) const (*it)->computeXrefs(ei); } +function<bool (expr_t)> +FirstDerivExternalFunctionNode::sameTefTermPredicate() const +{ + int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id); + if (first_deriv_symb_id == symb_id) + return [this](expr_t e) { + auto e2 = dynamic_cast<ExternalFunctionNode *>(e); + return (e2 != nullptr && e2->symb_id == symb_id); + }; + else + return [this](expr_t e) { + auto e2 = dynamic_cast<FirstDerivExternalFunctionNode *>(e); + return (e2 != nullptr && e2->symb_id == symb_id); + }; +} + SecondDerivExternalFunctionNode::SecondDerivExternalFunctionNode(DataTree &datatree_arg, int top_level_symb_id_arg, const vector<expr_t> &arguments_arg, @@ -7155,14 +7194,6 @@ SecondDerivExternalFunctionNode::SecondDerivExternalFunctionNode(DataTree &datat datatree.second_deriv_external_function_node_map[make_pair(make_pair(arguments, make_pair(inputIndex1, inputIndex2)), symb_id)] = this; } -void -SecondDerivExternalFunctionNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference> > &reference_count, - map<NodeTreeReference, temporary_terms_t> &temp_terms_map, - bool is_matlab, NodeTreeReference tr) const -{ - temp_terms_map[tr].insert(const_cast<SecondDerivExternalFunctionNode *>(this)); -} - void SecondDerivExternalFunctionNode::computeTemporaryTerms(map<expr_t, int> &reference_count, temporary_terms_t &temporary_terms, @@ -7473,6 +7504,22 @@ SecondDerivExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileC exit(EXIT_FAILURE); } +function<bool (expr_t)> +SecondDerivExternalFunctionNode::sameTefTermPredicate() const +{ + int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id); + if (second_deriv_symb_id == symb_id) + return [this](expr_t e) { + auto e2 = dynamic_cast<ExternalFunctionNode *>(e); + return (e2 != nullptr && e2->symb_id == symb_id); + }; + else + return [this](expr_t e) { + auto e2 = dynamic_cast<SecondDerivExternalFunctionNode *>(e); + return (e2 != nullptr && e2->symb_id == symb_id); + }; +} + VarExpectationNode::VarExpectationNode(DataTree &datatree_arg, int symb_id_arg, int forecast_horizon_arg, diff --git a/src/ExprNode.hh b/src/ExprNode.hh index c04c2676fe7a6d702bd4c8e354ba242236674184..6714c48c318035d68de0ff846295ab5244c729ee 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -24,6 +24,7 @@ #include <map> #include <vector> #include <ostream> +#include <functional> using namespace std; @@ -1076,13 +1077,17 @@ protected: //! Helper function to write output arguments of any given external function void writeExternalFunctionArguments(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; void writeJsonExternalFunctionArguments(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic) const; + /*! Returns a predicate that tests whether an other ExprNode is an external + function which is computed by the same external function call (i.e. it has + the same so-called "Tef" index) */ + virtual function<bool (expr_t)> sameTefTermPredicate() const = 0; public: AbstractExternalFunctionNode(DataTree &datatree_arg, int symb_id_arg, const vector<expr_t> &arguments_arg); virtual void prepareForDerivation(); virtual void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference> > &reference_count, map<NodeTreeReference, temporary_terms_t> &temp_terms_map, - bool is_matlab, NodeTreeReference tr) const = 0; + bool is_matlab, NodeTreeReference tr) const; virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const = 0; virtual void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic = true) const = 0; virtual bool containsExternalFunction() const; @@ -1164,14 +1169,15 @@ public: class ExternalFunctionNode : public AbstractExternalFunctionNode { + friend class FirstDerivExternalFunctionNode; + friend class SecondDerivExternalFunctionNode; private: virtual expr_t composeDerivatives(const vector<expr_t> &dargs); +protected: + function<bool (expr_t)> sameTefTermPredicate() const override; public: ExternalFunctionNode(DataTree &datatree_arg, int symb_id_arg, const vector<expr_t> &arguments_arg); - virtual void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference> > &reference_count, - map<NodeTreeReference, temporary_terms_t> &temp_terms_map, - bool is_matlab, NodeTreeReference tr) const; virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const; virtual void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic) const; virtual void writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type, @@ -1204,14 +1210,13 @@ class FirstDerivExternalFunctionNode : public AbstractExternalFunctionNode private: const int inputIndex; virtual expr_t composeDerivatives(const vector<expr_t> &dargs); +protected: + function<bool (expr_t)> sameTefTermPredicate() const override; public: FirstDerivExternalFunctionNode(DataTree &datatree_arg, int top_level_symb_id_arg, const vector<expr_t> &arguments_arg, int inputIndex_arg); - virtual void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference> > &reference_count, - map<NodeTreeReference, temporary_terms_t> &temp_terms_map, - bool is_matlab, NodeTreeReference tr) const; virtual void computeTemporaryTerms(map<expr_t, int> &reference_count, temporary_terms_t &temporary_terms, map<expr_t, pair<int, int> > &first_occurence, @@ -1248,15 +1253,14 @@ private: const int inputIndex1; const int inputIndex2; virtual expr_t composeDerivatives(const vector<expr_t> &dargs); +protected: + function<bool (expr_t)> sameTefTermPredicate() const override; public: SecondDerivExternalFunctionNode(DataTree &datatree_arg, int top_level_symb_id_arg, const vector<expr_t> &arguments_arg, int inputIndex1_arg, int inputIndex2_arg); - virtual void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference> > &reference_count, - map<NodeTreeReference, temporary_terms_t> &temp_terms_map, - bool is_matlab, NodeTreeReference tr) const; virtual void computeTemporaryTerms(map<expr_t, int> &reference_count, temporary_terms_t &temporary_terms, map<expr_t, pair<int, int> > &first_occurence,