Verified Commit 571b5d08 authored by Sébastien Villemot's avatar Sébastien Villemot
Browse files

Computation of temporary terms generalized to any derivation order

parent 67ac4bf8
Pipeline #400 passed with stage
in 1 minute and 29 seconds
...@@ -257,46 +257,6 @@ enum class PriorDistributions ...@@ -257,46 +257,6 @@ enum class PriorDistributions
weibull = 8 weibull = 8
}; };
enum class NodeTreeReference
{
residuals,
firstDeriv,
secondDeriv,
thirdDeriv,
residualsParamsDeriv,
jacobianParamsDeriv,
residualsParamsSecondDeriv,
jacobianParamsSecondDeriv,
hessianParamsDeriv
};
/*! Lists elements of the NodeTreeReference enum that come “before” the argument.
Used in AbstractExternalFunctionNode::computeTemporaryTerms */
inline auto
nodeTreeReferencesBefore(NodeTreeReference tr)
{
vector<NodeTreeReference> v;
// Should be same order as the one appearing in ModelTree::computeTemporaryTerms()
for (auto tr2 : { NodeTreeReference::residuals, NodeTreeReference::firstDeriv, NodeTreeReference::secondDeriv, NodeTreeReference::thirdDeriv })
if (tr == tr2)
return v;
else
v.push_back(tr2);
v.clear();
// Should be same order as the one appearing in ModelTree::computeParamsDerivativesTemporaryTerms()
for (auto tr2 : { NodeTreeReference::residualsParamsDeriv, NodeTreeReference::jacobianParamsDeriv, NodeTreeReference::residualsParamsSecondDeriv,
NodeTreeReference::jacobianParamsSecondDeriv, NodeTreeReference::hessianParamsDeriv})
if (tr == tr2)
return v;
else
v.push_back(tr2);
cerr << "nodeTreeReferencesBelow: impossible case" << endl;
exit(EXIT_FAILURE);
}
struct Block_contain_type struct Block_contain_type
{ {
int Equation, Variable, Own_Derivative; int Equation, Variable, Own_Derivative;
......
...@@ -5349,8 +5349,8 @@ DynamicModel::writeParamsDerivativesFile(const string &basename, bool julia) con ...@@ -5349,8 +5349,8 @@ DynamicModel::writeParamsDerivativesFile(const string &basename, bool julia) con
deriv_node_temp_terms_t tef_terms; deriv_node_temp_terms_t tef_terms;
writeModelLocalVariableTemporaryTerms(temp_term_union, params_derivs_temporary_terms_idxs, tt_output, output_type, tef_terms); writeModelLocalVariableTemporaryTerms(temp_term_union, params_derivs_temporary_terms_idxs, tt_output, output_type, tef_terms);
for (auto it : { make_pair(0,1), make_pair(1,1), make_pair(0,2), make_pair(1,2), make_pair(2,1) }) for (const auto &it : params_derivs_temporary_terms)
writeTemporaryTerms(params_derivs_temporary_terms.find(it)->second, temp_term_union, params_derivs_temporary_terms_idxs, tt_output, output_type, tef_terms); writeTemporaryTerms(it.second, temp_term_union, params_derivs_temporary_terms_idxs, tt_output, output_type, tef_terms);
for (const auto & residuals_params_derivative : params_derivatives.find({ 0, 1 })->second) for (const auto & residuals_params_derivative : params_derivatives.find({ 0, 1 })->second)
{ {
...@@ -6553,8 +6553,8 @@ DynamicModel::writeJsonParamsDerivativesFile(ostream &output, bool writeDetails) ...@@ -6553,8 +6553,8 @@ DynamicModel::writeJsonParamsDerivativesFile(ostream &output, bool writeDetails)
temporary_terms_t temp_term_union; temporary_terms_t temp_term_union;
string concat = "all"; string concat = "all";
for (auto it : { make_pair(0,1), make_pair(1,1), make_pair(0,2), make_pair(1,2), make_pair(2,1) }) for (const auto &it : params_derivs_temporary_terms)
writeJsonTemporaryTerms(params_derivs_temporary_terms.find(it)->second, temp_term_union, model_output, tef_terms, concat); writeJsonTemporaryTerms(it.second, temp_term_union, model_output, tef_terms, concat);
jacobian_output << "\"deriv_wrt_params\": {" jacobian_output << "\"deriv_wrt_params\": {"
<< " \"neqs\": " << equations.size() << " \"neqs\": " << equations.size()
......
...@@ -86,7 +86,7 @@ ExprNode::cost(const temporary_terms_t &temp_terms_map, bool is_matlab) const ...@@ -86,7 +86,7 @@ ExprNode::cost(const temporary_terms_t &temp_terms_map, bool is_matlab) const
} }
int int
ExprNode::cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const ExprNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
{ {
// For a terminal node, the cost is null // For a terminal node, the cost is null
return 0; return 0;
...@@ -146,9 +146,10 @@ ExprNode::collectExogenous(set<pair<int, int>> &result) const ...@@ -146,9 +146,10 @@ ExprNode::collectExogenous(set<pair<int, int>> &result) const
} }
void void
ExprNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, ExprNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const
{ {
// Nothing to do for a terminal node // Nothing to do for a terminal node
} }
...@@ -2169,7 +2170,7 @@ UnaryOpNode::computeDerivative(int deriv_id) ...@@ -2169,7 +2170,7 @@ UnaryOpNode::computeDerivative(int deriv_id)
} }
int int
UnaryOpNode::cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const UnaryOpNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
{ {
// For a temporary term, the cost is null // For a temporary term, the cost is null
for (const auto & it : temp_terms_map) for (const auto & it : temp_terms_map)
...@@ -2295,17 +2296,18 @@ UnaryOpNode::cost(int cost, bool is_matlab) const ...@@ -2295,17 +2296,18 @@ UnaryOpNode::cost(int cost, bool is_matlab) const
} }
void void
UnaryOpNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, UnaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const
{ {
expr_t this2 = const_cast<UnaryOpNode *>(this); expr_t this2 = const_cast<UnaryOpNode *>(this);
auto it = reference_count.find(this2); auto it = reference_count.find(this2);
if (it == reference_count.end()) if (it == reference_count.end())
{ {
reference_count[this2] = { 1, tr }; reference_count[this2] = { 1, derivOrder };
arg->computeTemporaryTerms(reference_count, temp_terms_map, is_matlab, tr); arg->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
} }
else else
{ {
...@@ -4057,7 +4059,7 @@ BinaryOpNode::precedenceJson(const temporary_terms_t &temporary_terms) const ...@@ -4057,7 +4059,7 @@ BinaryOpNode::precedenceJson(const temporary_terms_t &temporary_terms) const
} }
int int
BinaryOpNode::cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const BinaryOpNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
{ {
// For a temporary term, the cost is null // For a temporary term, the cost is null
for (const auto & it : temp_terms_map) for (const auto & it : temp_terms_map)
...@@ -4142,9 +4144,10 @@ BinaryOpNode::cost(int cost, bool is_matlab) const ...@@ -4142,9 +4144,10 @@ BinaryOpNode::cost(int cost, bool is_matlab) const
} }
void void
BinaryOpNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, BinaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const
{ {
expr_t this2 = const_cast<BinaryOpNode *>(this); expr_t this2 = const_cast<BinaryOpNode *>(this);
auto it = reference_count.find(this2); auto it = reference_count.find(this2);
...@@ -4152,9 +4155,9 @@ BinaryOpNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &r ...@@ -4152,9 +4155,9 @@ BinaryOpNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &r
{ {
// If this node has never been encountered, set its ref count to one, // If this node has never been encountered, set its ref count to one,
// and travel through its children // and travel through its children
reference_count[this2] = { 1, tr }; reference_count[this2] = { 1, derivOrder };
arg1->computeTemporaryTerms(reference_count, temp_terms_map, is_matlab, tr); arg1->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
arg2->computeTemporaryTerms(reference_count, temp_terms_map, is_matlab, tr); arg2->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
} }
else else
{ {
...@@ -5964,7 +5967,7 @@ TrinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_ ...@@ -5964,7 +5967,7 @@ TrinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_
} }
int int
TrinaryOpNode::cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const TrinaryOpNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
{ {
// For a temporary term, the cost is null // For a temporary term, the cost is null
for (const auto & it : temp_terms_map) for (const auto & it : temp_terms_map)
...@@ -6016,9 +6019,10 @@ TrinaryOpNode::cost(int cost, bool is_matlab) const ...@@ -6016,9 +6019,10 @@ TrinaryOpNode::cost(int cost, bool is_matlab) const
} }
void void
TrinaryOpNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, TrinaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const
{ {
expr_t this2 = const_cast<TrinaryOpNode *>(this); expr_t this2 = const_cast<TrinaryOpNode *>(this);
auto it = reference_count.find(this2); auto it = reference_count.find(this2);
...@@ -6026,10 +6030,10 @@ TrinaryOpNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> & ...@@ -6026,10 +6030,10 @@ TrinaryOpNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &
{ {
// If this node has never been encountered, set its ref count to one, // If this node has never been encountered, set its ref count to one,
// and travel through its children // and travel through its children
reference_count[this2] = { 1, tr }; reference_count[this2] = { 1, derivOrder };
arg1->computeTemporaryTerms(reference_count, temp_terms_map, is_matlab, tr); arg1->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
arg2->computeTemporaryTerms(reference_count, temp_terms_map, is_matlab, tr); arg2->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
arg3->computeTemporaryTerms(reference_count, temp_terms_map, is_matlab, tr); arg3->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
} }
else else
{ {
...@@ -7118,9 +7122,10 @@ AbstractExternalFunctionNode::getIndxInTefTerms(int the_symb_id, const deriv_nod ...@@ -7118,9 +7122,10 @@ AbstractExternalFunctionNode::getIndxInTefTerms(int the_symb_id, const deriv_nod
} }
void void
AbstractExternalFunctionNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, AbstractExternalFunctionNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const
{ {
/* All external function nodes are declared as temporary terms. /* All external function nodes are declared as temporary terms.
...@@ -7133,18 +7138,17 @@ AbstractExternalFunctionNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTr ...@@ -7133,18 +7138,17 @@ AbstractExternalFunctionNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTr
corresponding to the same external function call is present in that corresponding to the same external function call is present in that
previous level. */ previous level. */
for (auto tr2 : nodeTreeReferencesBefore(tr)) for (auto &tt : temp_terms_map)
{ {
auto it = find_if(temp_terms_map[tr2].cbegin(), temp_terms_map[tr2].cend(), auto it = find_if(tt.second.cbegin(), tt.second.cend(), sameTefTermPredicate());
sameTefTermPredicate()); if (it != tt.second.cend())
if (it != temp_terms_map[tr2].cend())
{ {
temp_terms_map[tr2].insert(const_cast<AbstractExternalFunctionNode *>(this)); tt.second.insert(const_cast<AbstractExternalFunctionNode *>(this));
return; return;
} }
} }
temp_terms_map[tr].insert(const_cast<AbstractExternalFunctionNode *>(this)); temp_terms_map[derivOrder].insert(const_cast<AbstractExternalFunctionNode *>(this));
} }
bool bool
...@@ -8460,9 +8464,10 @@ VarExpectationNode::VarExpectationNode(DataTree &datatree_arg, ...@@ -8460,9 +8464,10 @@ VarExpectationNode::VarExpectationNode(DataTree &datatree_arg,
} }
void void
VarExpectationNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, VarExpectationNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const
{ {
cerr << "VarExpectationNode::computeTemporaryTerms not implemented." << endl; cerr << "VarExpectationNode::computeTemporaryTerms not implemented." << endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
...@@ -8917,11 +8922,12 @@ PacExpectationNode::PacExpectationNode(DataTree &datatree_arg, ...@@ -8917,11 +8922,12 @@ PacExpectationNode::PacExpectationNode(DataTree &datatree_arg,
} }
void void
PacExpectationNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, PacExpectationNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const
{ {
temp_terms_map[tr].insert(const_cast<PacExpectationNode *>(this)); temp_terms_map[derivOrder].insert(const_cast<PacExpectationNode *>(this));
} }
void void
......
...@@ -188,7 +188,7 @@ class ExprNode ...@@ -188,7 +188,7 @@ class ExprNode
/*! Nodes included in temporary_terms are considered having a null cost */ /*! Nodes included in temporary_terms are considered having a null cost */
virtual int cost(int cost, bool is_matlab) const; virtual int cost(int cost, bool is_matlab) const;
virtual int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const; virtual int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const;
virtual int cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const; virtual int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const;
//! For creating equation cross references //! For creating equation cross references
struct EquationInfo struct EquationInfo
...@@ -237,11 +237,26 @@ class ExprNode ...@@ -237,11 +237,26 @@ class ExprNode
/*! Equals 100 for constants, variables, unary ops, and temporary terms */ /*! Equals 100 for constants, variables, unary ops, and temporary terms */
virtual int precedence(ExprNodeOutputType output_t, const temporary_terms_t &temporary_terms) const; virtual int precedence(ExprNodeOutputType output_t, const temporary_terms_t &temporary_terms) const;
//! Fills temporary_terms set, using reference counts //! Compute temporary terms in this expression
/*! A node will be marked as a temporary term if it is referenced at least two times (i.e. has at least two parents), and has a computing cost (multiplied by reference count) greater to datatree.min_cost */ /*!
virtual void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, \param[in] derivOrder the derivation order (first w.r.t. endo/exo,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, second w.r.t. params)
bool is_matlab, NodeTreeReference tr) const; \param[out] temp_terms_map the computed temporary terms, associated
with their derivation order
\param[out] reference_count a temporary structure, used to count
references to each node (integer in outer pair is the
reference count, the inner pair is the derivation order)
\param[in] is_matlab whether we are in a MATLAB context, since that
affects the cost of each operator
A node will be marked as a temporary term if it is referenced at least
two times (i.e. has at least two parents), and has a computing cost
(multiplied by reference count) greater to datatree.min_cost
*/
virtual void computeTemporaryTerms(const pair<int, int> &derivOrder,
map<pair<int, int>, temporary_terms_t> &temp_terms_map,
map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const;
//! Writes output of node, using a Txxx notation for nodes in temporary_terms, and specifiying the set of already written external functions //! Writes output of node, using a Txxx notation for nodes in temporary_terms, and specifiying the set of already written external functions
/*! /*!
...@@ -249,8 +264,8 @@ class ExprNode ...@@ -249,8 +264,8 @@ class ExprNode
\param[in] output_type the type of output (MATLAB, C, LaTeX...) \param[in] output_type the type of output (MATLAB, C, LaTeX...)
\param[in] temporary_terms the nodes that are marked as temporary terms \param[in] temporary_terms the nodes that are marked as temporary terms
\param[in] a map from temporary_terms to integers indexes (in the \param[in] a map from temporary_terms to integers indexes (in the
MATLAB or Julia vector of temporary terms); can be empty MATLAB, C or Julia vector of temporary terms); can be empty
when writing C or MATLAB with block decomposition) when writing MATLAB with block decomposition)
\param[in] tef_terms the set of already written external function nodes \param[in] tef_terms the set of already written external function nodes
*/ */
virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const = 0; virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const = 0;
...@@ -821,15 +836,16 @@ private: ...@@ -821,15 +836,16 @@ private:
expr_t computeDerivative(int deriv_id) override; expr_t computeDerivative(int deriv_id) override;
int cost(int cost, bool is_matlab) const override; int cost(int cost, bool is_matlab) const override;
int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const override; int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const override;
int cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const override; int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
//! Returns the derivative of this node if darg is the derivative of the argument //! Returns the derivative of this node if darg is the derivative of the argument
expr_t composeDerivatives(expr_t darg, int deriv_id); expr_t composeDerivatives(expr_t darg, int deriv_id);
public: public:
UnaryOpNode(DataTree &datatree_arg, int idx_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, string adl_param_name_arg, vector<int> adl_lags_arg); UnaryOpNode(DataTree &datatree_arg, int idx_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, string adl_param_name_arg, vector<int> adl_lags_arg);
void prepareForDerivation() override; void prepareForDerivation() override;
void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, void computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const override; map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const override;
void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override; void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
void writeJsonAST(ostream &output) const override; void writeJsonAST(ostream &output) const override;
void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic) const override; void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic) const override;
...@@ -933,7 +949,7 @@ private: ...@@ -933,7 +949,7 @@ private:
expr_t computeDerivative(int deriv_id) override; expr_t computeDerivative(int deriv_id) override;
int cost(int cost, bool is_matlab) const override; int cost(int cost, bool is_matlab) const override;
int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const override; int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const override;
int cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const override; int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
//! Returns the derivative of this node if darg1 and darg2 are the derivatives of the arguments //! Returns the derivative of this node if darg1 and darg2 are the derivatives of the arguments
expr_t composeDerivatives(expr_t darg1, expr_t darg2); expr_t composeDerivatives(expr_t darg1, expr_t darg2);
public: public:
...@@ -942,9 +958,10 @@ public: ...@@ -942,9 +958,10 @@ public:
void prepareForDerivation() override; void prepareForDerivation() override;
int precedenceJson(const temporary_terms_t &temporary_terms) const override; int precedenceJson(const temporary_terms_t &temporary_terms) const override;
int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const override; int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const override;
void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, void computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const override; map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const override;
void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override; void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
void writeJsonAST(ostream &output) const override; void writeJsonAST(ostream &output) const override;
void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic) const override; void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic) const override;
...@@ -1065,7 +1082,7 @@ private: ...@@ -1065,7 +1082,7 @@ private:
expr_t computeDerivative(int deriv_id) override; expr_t computeDerivative(int deriv_id) override;
int cost(int cost, bool is_matlab) const override; int cost(int cost, bool is_matlab) const override;
int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const override; int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const override;
int cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const override; int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
//! Returns the derivative of this node if darg1, darg2 and darg3 are the derivatives of the arguments //! Returns the derivative of this node if darg1, darg2 and darg3 are the derivatives of the arguments
expr_t composeDerivatives(expr_t darg1, expr_t darg2, expr_t darg3); expr_t composeDerivatives(expr_t darg1, expr_t darg2, expr_t darg3);
public: public:
...@@ -1073,9 +1090,10 @@ public: ...@@ -1073,9 +1090,10 @@ public:
TrinaryOpcode op_code_arg, const expr_t arg2_arg, const expr_t arg3_arg); TrinaryOpcode op_code_arg, const expr_t arg2_arg, const expr_t arg3_arg);
void prepareForDerivation() override; void prepareForDerivation() override;
int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const override; int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const override;
void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, void computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const override; map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const override;
void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override; void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
void writeJsonAST(ostream &output) const override; void writeJsonAST(ostream &output) const override;
void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic) const override; void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic) const override;
...@@ -1193,9 +1211,10 @@ public: ...@@ -1193,9 +1211,10 @@ public:
AbstractExternalFunctionNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, AbstractExternalFunctionNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg,
vector<expr_t> arguments_arg); vector<expr_t> arguments_arg);
void prepareForDerivation() override; void prepareForDerivation() override;
void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, void computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const override; map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const override;
void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override = 0; void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override = 0;
void writeJsonAST(ostream &output) const override = 0; void writeJsonAST(ostream &output) const override = 0;
void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic = true) const override = 0; void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic = true) const override = 0;
...@@ -1421,9 +1440,10 @@ class VarExpectationNode : public ExprNode ...@@ -1421,9 +1440,10 @@ class VarExpectationNode : public ExprNode
public: public:
const string model_name; const string model_name;
VarExpectationNode(DataTree &datatree_arg, int idx_arg, string model_name_arg); VarExpectationNode(DataTree &datatree_arg, int idx_arg, string model_name_arg);
void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, void computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const override; map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const override;
void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override; void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
void computeTemporaryTerms(map<expr_t, int> &reference_count, void computeTemporaryTerms(map<expr_t, int> &reference_count,
temporary_terms_t &temporary_terms, temporary_terms_t &temporary_terms,
...@@ -1519,9 +1539,10 @@ private: ...@@ -1519,9 +1539,10 @@ private:
vector<tuple<int, int, int, double>> non_optim_vars_params_and_constants; vector<tuple<int, int, int, double>> non_optim_vars_params_and_constants;
public: public:
PacExpectationNode(DataTree &datatree_arg, int idx_arg, string model_name); PacExpectationNode(DataTree &datatree_arg, int idx_arg, string model_name);
void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count, void computeTemporaryTerms(const pair<int, int> &derivOrder,
map<NodeTreeReference, temporary_terms_t> &temp_terms_map, map<pair<int, int>, temporary_terms_t> &temp_terms_map,
bool is_matlab, NodeTreeReference tr) const override; map<expr_t, pair<int, pair<int, int>>> &reference_count,
bool is_matlab) const override;
void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override; void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
void computeTemporaryTerms(map<expr_t, int> &reference_count, void computeTemporaryTerms(map<expr_t, int> &reference_count,
temporary_terms_t &temporary_terms, temporary_terms_t &temporary_terms,
......
...@@ -1315,17 +1315,12 @@ ModelTree::computeDerivatives(int order, const set<int> &vars) ...@@ -1315,17 +1315,12 @@ ModelTree::computeDerivatives(int order, const set<int> &vars)
void void
ModelTree::computeTemporaryTerms(bool is_matlab, bool no_tmp_terms) ModelTree::computeTemporaryTerms(bool is_matlab, bool no_tmp_terms)
{ {
map<expr_t, pair<int, NodeTreeReference>> reference_count;
temporary_terms.clear();
temporary_terms_mlv.clear();
temporary_terms_derivatives.clear();
temporary_terms_derivatives.resize(4);
/* Collect all model local variables appearing in equations (and only those, /* Collect all model local variables appearing in equations (and only those,
because printing unused model local variables can lead to a crash, because printing unused model local variables can lead to a crash,
see Dynare/dynare#101). see Dynare/dynare#101).
Then store them in a dedicated structure (temporary_terms_mlv), that will Then store them in a dedicated structure (temporary_terms_mlv), that will
be treated as the rest of temporary terms. */ be treated as the rest of temporary terms. */
temporary_terms_mlv.clear();
set<int> used_local_vars; set<int> used_local_vars;
for (auto & equation : equations) for (auto & equation : equations)
equation->collectVariables(SymbolType::modelLocalVariable, used_local_vars); equation->collectVariables(SymbolType::modelLocalVariable, used_local_vars);
...@@ -1335,59 +1330,44 @@ ModelTree::computeTemporaryTerms(bool is_matlab, bool no_tmp_terms) ...@@ -1335,59 +1330,44 @@ ModelTree::computeTemporaryTerms(bool is_matlab, bool no_tmp_terms)
temporary_terms_mlv[v] = local_variables_table.find(used_local_var)->second; temporary_terms_mlv[v] = local_variables_table.find(used_local_var)->second;
} }
map<NodeTreeReference, temporary_terms_t> temp_terms_map; // Compute the temporary terms in equations and derivatives
temp_terms_map[NodeTreeReference::residuals] = temporary_terms_derivatives[0]; map<pair<int, int>, temporary_terms_t> temp_terms_map;
temp_terms_map[NodeTreeReference::firstDeriv] = temporary_terms_derivatives[1];
temp_terms_map[NodeTreeReference::secondDeriv] = temporary_terms_derivatives[2];
temp_terms_map[NodeTreeReference::thirdDeriv] = temporary_terms_derivatives[3];
if (!no_tmp_terms) if (!no_tmp_terms)
{ {
map<expr_t, pair<int, pair<int, int>>> reference_count;
for (auto & equation : equations) for (auto & equation : equations)
equation->computeTemporaryTerms(reference_count, equation->computeTemporaryTerms({ 0, 0 },
temp_terms_map, temp_terms_map,