From 571b5d081657271b7f5e02f93778bddf9020787b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Fri, 30 Nov 2018 12:22:13 +0100
Subject: [PATCH] Computation of temporary terms generalized to any derivation
 order

---
 src/CodeInterpreter.hh |  40 ------------
 src/DynamicModel.cc    |   8 +--
 src/ExprNode.cc        |  88 ++++++++++++++------------
 src/ExprNode.hh        |  79 ++++++++++++++---------
 src/ModelTree.cc       | 140 ++++++++++++-----------------------------
 src/ModelTree.hh       |   2 +
 src/StaticModel.cc     |   8 +--
 7 files changed, 146 insertions(+), 219 deletions(-)

diff --git a/src/CodeInterpreter.hh b/src/CodeInterpreter.hh
index c9881c9e..d9c21d5e 100644
--- a/src/CodeInterpreter.hh
+++ b/src/CodeInterpreter.hh
@@ -257,46 +257,6 @@ enum class PriorDistributions
     weibull = 8
   };
 
-enum class NodeTreeReference
-  {
-    residuals,
-    firstDeriv,
-    secondDeriv,
-    thirdDeriv,
-    residualsParamsDeriv,
-    jacobianParamsDeriv,
-    residualsParamsSecondDeriv,
-    jacobianParamsSecondDeriv,
-    hessianParamsDeriv
-  };
-
-/*! Lists elements of the NodeTreeReference enum that come “before” the argument.
-    Used in AbstractExternalFunctionNode::computeTemporaryTerms */
-inline auto
-nodeTreeReferencesBefore(NodeTreeReference tr)
-{
-  vector<NodeTreeReference> v;
-
-  // Should be same order as the one appearing in ModelTree::computeTemporaryTerms()
-  for (auto tr2 : { NodeTreeReference::residuals, NodeTreeReference::firstDeriv, NodeTreeReference::secondDeriv, NodeTreeReference::thirdDeriv })
-    if (tr == tr2)
-      return v;
-    else
-      v.push_back(tr2);
-  v.clear();
-
-  // Should be same order as the one appearing in ModelTree::computeParamsDerivativesTemporaryTerms()
-  for (auto tr2 : { NodeTreeReference::residualsParamsDeriv, NodeTreeReference::jacobianParamsDeriv, NodeTreeReference::residualsParamsSecondDeriv,
-        NodeTreeReference::jacobianParamsSecondDeriv, NodeTreeReference::hessianParamsDeriv})
-    if (tr == tr2)
-      return v;
-    else
-      v.push_back(tr2);
-
-  cerr << "nodeTreeReferencesBelow: impossible case" << endl;
-  exit(EXIT_FAILURE);
-}
-
 struct Block_contain_type
 {
   int Equation, Variable, Own_Derivative;
diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc
index 41398450..0316c975 100644
--- a/src/DynamicModel.cc
+++ b/src/DynamicModel.cc
@@ -5349,8 +5349,8 @@ DynamicModel::writeParamsDerivativesFile(const string &basename, bool julia) con
   deriv_node_temp_terms_t tef_terms;
 
   writeModelLocalVariableTemporaryTerms(temp_term_union, params_derivs_temporary_terms_idxs, tt_output, output_type, tef_terms);
-  for (auto it : { make_pair(0,1), make_pair(1,1), make_pair(0,2), make_pair(1,2), make_pair(2,1) })
-    writeTemporaryTerms(params_derivs_temporary_terms.find(it)->second, temp_term_union, params_derivs_temporary_terms_idxs, tt_output, output_type, tef_terms);
+  for (const auto &it : params_derivs_temporary_terms)
+    writeTemporaryTerms(it.second, temp_term_union, params_derivs_temporary_terms_idxs, tt_output, output_type, tef_terms);
 
   for (const auto & residuals_params_derivative : params_derivatives.find({ 0, 1 })->second)
     {
@@ -6553,8 +6553,8 @@ DynamicModel::writeJsonParamsDerivativesFile(ostream &output, bool writeDetails)
 
   temporary_terms_t temp_term_union;
   string concat = "all";
-  for (auto it : { make_pair(0,1), make_pair(1,1), make_pair(0,2), make_pair(1,2), make_pair(2,1) })
-    writeJsonTemporaryTerms(params_derivs_temporary_terms.find(it)->second, temp_term_union, model_output, tef_terms, concat);
+  for (const auto &it : params_derivs_temporary_terms)
+    writeJsonTemporaryTerms(it.second, temp_term_union, model_output, tef_terms, concat);
 
   jacobian_output << "\"deriv_wrt_params\": {"
                   << "  \"neqs\": " << equations.size()
diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index a415b7f9..a2821e90 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -86,7 +86,7 @@ ExprNode::cost(const temporary_terms_t &temp_terms_map, bool is_matlab) const
 }
 
 int
-ExprNode::cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const
+ExprNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
 {
   // For a terminal node, the cost is null
   return 0;
@@ -146,9 +146,10 @@ ExprNode::collectExogenous(set<pair<int, int>> &result) const
 }
 
 void
-ExprNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                bool is_matlab, NodeTreeReference tr) const
+ExprNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
+                                map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                                map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                                bool is_matlab) const
 {
   // Nothing to do for a terminal node
 }
@@ -2169,7 +2170,7 @@ UnaryOpNode::computeDerivative(int deriv_id)
 }
 
 int
-UnaryOpNode::cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const
+UnaryOpNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
 {
   // For a temporary term, the cost is null
   for (const auto & it : temp_terms_map)
@@ -2295,17 +2296,18 @@ UnaryOpNode::cost(int cost, bool is_matlab) const
 }
 
 void
-UnaryOpNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                   map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                   bool is_matlab, NodeTreeReference tr) const
+UnaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
+                                   map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                                   map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                                   bool is_matlab) const
 {
   expr_t this2 = const_cast<UnaryOpNode *>(this);
 
   auto it = reference_count.find(this2);
   if (it == reference_count.end())
     {
-      reference_count[this2] = { 1, tr };
-      arg->computeTemporaryTerms(reference_count, temp_terms_map, is_matlab, tr);
+      reference_count[this2] = { 1, derivOrder };
+      arg->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
     }
   else
     {
@@ -4057,7 +4059,7 @@ BinaryOpNode::precedenceJson(const temporary_terms_t &temporary_terms) const
 }
 
 int
-BinaryOpNode::cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const
+BinaryOpNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
 {
   // For a temporary term, the cost is null
   for (const auto & it : temp_terms_map)
@@ -4142,9 +4144,10 @@ BinaryOpNode::cost(int cost, bool is_matlab) const
 }
 
 void
-BinaryOpNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                    map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                    bool is_matlab, NodeTreeReference tr) const
+BinaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
+                                    map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                                    map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                                    bool is_matlab) const
 {
   expr_t this2 = const_cast<BinaryOpNode *>(this);
   auto it = reference_count.find(this2);
@@ -4152,9 +4155,9 @@ BinaryOpNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &r
     {
       // If this node has never been encountered, set its ref count to one,
       //  and travel through its children
-      reference_count[this2] = { 1, tr };
-      arg1->computeTemporaryTerms(reference_count, temp_terms_map, is_matlab, tr);
-      arg2->computeTemporaryTerms(reference_count, temp_terms_map, is_matlab, tr);
+      reference_count[this2] = { 1, derivOrder };
+      arg1->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
+      arg2->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
     }
   else
     {
@@ -5964,7 +5967,7 @@ TrinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_
 }
 
 int
-TrinaryOpNode::cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const
+TrinaryOpNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
 {
   // For a temporary term, the cost is null
   for (const auto & it : temp_terms_map)
@@ -6016,9 +6019,10 @@ TrinaryOpNode::cost(int cost, bool is_matlab) const
 }
 
 void
-TrinaryOpNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                     map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                     bool is_matlab, NodeTreeReference tr) const
+TrinaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
+                                     map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                                     map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                                     bool is_matlab) const
 {
   expr_t this2 = const_cast<TrinaryOpNode *>(this);
   auto it = reference_count.find(this2);
@@ -6026,10 +6030,10 @@ TrinaryOpNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &
     {
       // If this node has never been encountered, set its ref count to one,
       //  and travel through its children
-      reference_count[this2] = { 1, tr };
-      arg1->computeTemporaryTerms(reference_count, temp_terms_map, is_matlab, tr);
-      arg2->computeTemporaryTerms(reference_count, temp_terms_map, is_matlab, tr);
-      arg3->computeTemporaryTerms(reference_count, temp_terms_map, is_matlab, tr);
+      reference_count[this2] = { 1, derivOrder };
+      arg1->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
+      arg2->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
+      arg3->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
     }
   else
     {
@@ -7118,9 +7122,10 @@ AbstractExternalFunctionNode::getIndxInTefTerms(int the_symb_id, const deriv_nod
 }
 
 void
-AbstractExternalFunctionNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                            map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                            bool is_matlab, NodeTreeReference tr) const
+AbstractExternalFunctionNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
+                                                    map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                                                    map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                                                    bool is_matlab) const
 {
   /* All external function nodes are declared as temporary terms.
 
@@ -7133,18 +7138,17 @@ AbstractExternalFunctionNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTr
      corresponding to the same external function call is present in that
      previous level. */
 
-  for (auto tr2 : nodeTreeReferencesBefore(tr))
+  for (auto &tt : temp_terms_map)
     {
-      auto it = find_if(temp_terms_map[tr2].cbegin(), temp_terms_map[tr2].cend(),
-                        sameTefTermPredicate());
-      if (it != temp_terms_map[tr2].cend())
+      auto it = find_if(tt.second.cbegin(), tt.second.cend(), sameTefTermPredicate());
+      if (it != tt.second.cend())
         {
-          temp_terms_map[tr2].insert(const_cast<AbstractExternalFunctionNode *>(this));
+          tt.second.insert(const_cast<AbstractExternalFunctionNode *>(this));
           return;
         }
     }
 
-  temp_terms_map[tr].insert(const_cast<AbstractExternalFunctionNode *>(this));
+  temp_terms_map[derivOrder].insert(const_cast<AbstractExternalFunctionNode *>(this));
 }
 
 bool
@@ -8460,9 +8464,10 @@ VarExpectationNode::VarExpectationNode(DataTree &datatree_arg,
 }
 
 void
-VarExpectationNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                          map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                          bool is_matlab, NodeTreeReference tr) const
+VarExpectationNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
+                                          map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                                          map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                                          bool is_matlab) const
 {
   cerr << "VarExpectationNode::computeTemporaryTerms not implemented." << endl;
   exit(EXIT_FAILURE);
@@ -8917,11 +8922,12 @@ PacExpectationNode::PacExpectationNode(DataTree &datatree_arg,
 }
 
 void
-PacExpectationNode::computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                          map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                          bool is_matlab, NodeTreeReference tr) const
+PacExpectationNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
+                                          map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                                          map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                                          bool is_matlab) const
 {
-  temp_terms_map[tr].insert(const_cast<PacExpectationNode *>(this));
+  temp_terms_map[derivOrder].insert(const_cast<PacExpectationNode *>(this));
 }
 
 void
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index d3f281be..5f9770db 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -188,7 +188,7 @@ class ExprNode
       /*! Nodes included in temporary_terms are considered having a null cost */
       virtual int cost(int cost, bool is_matlab) const;
       virtual int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const;
-      virtual int cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const;
+      virtual int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const;
 
       //! For creating equation cross references
       struct EquationInfo
@@ -237,11 +237,26 @@ class ExprNode
       /*! Equals 100 for constants, variables, unary ops, and temporary terms */
       virtual int precedence(ExprNodeOutputType output_t, const temporary_terms_t &temporary_terms) const;
 
-      //! Fills temporary_terms set, using reference counts
-      /*! A node will be marked as a temporary term if it is referenced at least two times (i.e. has at least two parents), and has a computing cost (multiplied by reference count) greater to datatree.min_cost */
-      virtual void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                         map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                         bool is_matlab, NodeTreeReference tr) const;
+      //! Compute temporary terms in this expression
+      /*!
+        \param[in] derivOrder the derivation order (first w.r.t. endo/exo,
+                   second w.r.t. params)
+        \param[out] temp_terms_map the computed temporary terms, associated
+                    with their derivation order
+        \param[out] reference_count a temporary structure, used to count
+                    references to each node (integer in outer pair is the
+                    reference count, the inner pair is the derivation order)
+        \param[in] is_matlab whether we are in a MATLAB context, since that
+                    affects the cost of each operator
+
+        A node will be marked as a temporary term if it is referenced at least
+        two times (i.e. has at least two parents), and has a computing cost
+        (multiplied by reference count) greater to datatree.min_cost
+      */
+      virtual void computeTemporaryTerms(const pair<int, int> &derivOrder,
+                                         map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                                         map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                                         bool is_matlab) const;
 
       //! Writes output of node, using a Txxx notation for nodes in temporary_terms, and specifiying the set of already written external functions
       /*!
@@ -249,8 +264,8 @@ class ExprNode
         \param[in] output_type the type of output (MATLAB, C, LaTeX...)
         \param[in] temporary_terms the nodes that are marked as temporary terms
         \param[in] a map from temporary_terms to integers indexes (in the
-                   MATLAB or Julia vector of temporary terms); can be empty
-                   when writing C or MATLAB with block decomposition)
+                   MATLAB, C or Julia vector of temporary terms); can be empty
+                   when writing MATLAB with block decomposition)
         \param[in] tef_terms the set of already written external function nodes
       */
       virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const = 0;
@@ -821,15 +836,16 @@ private:
   expr_t computeDerivative(int deriv_id) override;
   int cost(int cost, bool is_matlab) const override;
   int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const override;
-  int cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
+  int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
   //! Returns the derivative of this node if darg is the derivative of the argument
   expr_t composeDerivatives(expr_t darg, int deriv_id);
 public:
   UnaryOpNode(DataTree &datatree_arg, int idx_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, string adl_param_name_arg, vector<int> adl_lags_arg);
   void prepareForDerivation() override;
-  void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                     map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                     bool is_matlab, NodeTreeReference tr) const override;
+  void computeTemporaryTerms(const pair<int, int> &derivOrder,
+                             map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                             map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                             bool is_matlab) const override;
   void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
   void writeJsonAST(ostream &output) const override;
   void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic) const override;
@@ -933,7 +949,7 @@ private:
   expr_t computeDerivative(int deriv_id) override;
   int cost(int cost, bool is_matlab) const override;
   int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const override;
-  int cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
+  int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
   //! Returns the derivative of this node if darg1 and darg2 are the derivatives of the arguments
   expr_t composeDerivatives(expr_t darg1, expr_t darg2);
 public:
@@ -942,9 +958,10 @@ public:
   void prepareForDerivation() override;
   int precedenceJson(const temporary_terms_t &temporary_terms) const override;
   int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const override;
-  void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                     map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                     bool is_matlab, NodeTreeReference tr) const override;
+  void computeTemporaryTerms(const pair<int, int> &derivOrder,
+                             map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                             map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                             bool is_matlab) const override;
   void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
   void writeJsonAST(ostream &output) const override;
   void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic) const override;
@@ -1065,7 +1082,7 @@ private:
   expr_t computeDerivative(int deriv_id) override;
   int cost(int cost, bool is_matlab) const override;
   int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const override;
-  int cost(const map<NodeTreeReference, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
+  int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
   //! Returns the derivative of this node if darg1, darg2 and darg3 are the derivatives of the arguments
   expr_t composeDerivatives(expr_t darg1, expr_t darg2, expr_t darg3);
 public:
@@ -1073,9 +1090,10 @@ public:
                 TrinaryOpcode op_code_arg, const expr_t arg2_arg, const expr_t arg3_arg);
   void prepareForDerivation() override;
   int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const override;
-  void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                     map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                     bool is_matlab, NodeTreeReference tr) const override;
+  void computeTemporaryTerms(const pair<int, int> &derivOrder,
+                             map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                             map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                             bool is_matlab) const override;
   void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
   void writeJsonAST(ostream &output) const override;
   void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic) const override;
@@ -1193,9 +1211,10 @@ public:
   AbstractExternalFunctionNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg,
                                vector<expr_t> arguments_arg);
   void prepareForDerivation() override;
-  void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                     map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                     bool is_matlab, NodeTreeReference tr) const override;
+  void computeTemporaryTerms(const pair<int, int> &derivOrder,
+                             map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                             map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                             bool is_matlab) const override;
   void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override = 0;
   void writeJsonAST(ostream &output) const override = 0;
   void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, const bool isdynamic = true) const override = 0;
@@ -1421,9 +1440,10 @@ class VarExpectationNode : public ExprNode
 public:
   const string model_name;
   VarExpectationNode(DataTree &datatree_arg, int idx_arg, string model_name_arg);
-  void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                     map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                     bool is_matlab, NodeTreeReference tr) const override;
+  void computeTemporaryTerms(const pair<int, int> &derivOrder,
+                             map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                             map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                             bool is_matlab) const override;
   void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
   void computeTemporaryTerms(map<expr_t, int> &reference_count,
                                      temporary_terms_t &temporary_terms,
@@ -1519,9 +1539,10 @@ private:
   vector<tuple<int, int, int, double>> non_optim_vars_params_and_constants;
 public:
   PacExpectationNode(DataTree &datatree_arg, int idx_arg, string model_name);
-  void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference>> &reference_count,
-                                     map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
-                                     bool is_matlab, NodeTreeReference tr) const override;
+  void computeTemporaryTerms(const pair<int, int> &derivOrder,
+                             map<pair<int, int>, temporary_terms_t> &temp_terms_map,
+                             map<expr_t, pair<int, pair<int, int>>> &reference_count,
+                             bool is_matlab) const override;
   void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override;
   void computeTemporaryTerms(map<expr_t, int> &reference_count,
                                      temporary_terms_t &temporary_terms,
diff --git a/src/ModelTree.cc b/src/ModelTree.cc
index 02905b6f..6c5bd13a 100644
--- a/src/ModelTree.cc
+++ b/src/ModelTree.cc
@@ -1315,17 +1315,12 @@ ModelTree::computeDerivatives(int order, const set<int> &vars)
 void
 ModelTree::computeTemporaryTerms(bool is_matlab, bool no_tmp_terms)
 {
-  map<expr_t, pair<int, NodeTreeReference>> reference_count;
-  temporary_terms.clear();
-  temporary_terms_mlv.clear();
-  temporary_terms_derivatives.clear();
-  temporary_terms_derivatives.resize(4);
-
   /* Collect all model local variables appearing in equations (and only those,
      because printing unused model local variables can lead to a crash,
      see Dynare/dynare#101).
      Then store them in a dedicated structure (temporary_terms_mlv), that will
      be treated as the rest of temporary terms. */
+  temporary_terms_mlv.clear();
   set<int> used_local_vars;
   for (auto & equation : equations)
     equation->collectVariables(SymbolType::modelLocalVariable, used_local_vars);
@@ -1335,59 +1330,44 @@ ModelTree::computeTemporaryTerms(bool is_matlab, bool no_tmp_terms)
       temporary_terms_mlv[v] = local_variables_table.find(used_local_var)->second;
     }
 
-  map<NodeTreeReference, temporary_terms_t> temp_terms_map;
-  temp_terms_map[NodeTreeReference::residuals] = temporary_terms_derivatives[0];
-  temp_terms_map[NodeTreeReference::firstDeriv] = temporary_terms_derivatives[1];
-  temp_terms_map[NodeTreeReference::secondDeriv] = temporary_terms_derivatives[2];
-  temp_terms_map[NodeTreeReference::thirdDeriv] = temporary_terms_derivatives[3];
-
+  // Compute the temporary terms in equations and derivatives
+  map<pair<int, int>, temporary_terms_t> temp_terms_map;
   if (!no_tmp_terms)
     {
+      map<expr_t, pair<int, pair<int, int>>> reference_count;
+
       for (auto & equation : equations)
-        equation->computeTemporaryTerms(reference_count,
+        equation->computeTemporaryTerms({ 0, 0 },
                                         temp_terms_map,
-                                        is_matlab, NodeTreeReference::residuals);
-
-      for (auto & first_derivative : derivatives[1])
-        first_derivative.second->computeTemporaryTerms(reference_count,
-                                                       temp_terms_map,
-                                                       is_matlab, NodeTreeReference::firstDeriv);
-
-      for (auto & second_derivative : derivatives[2])
-        second_derivative.second->computeTemporaryTerms(reference_count,
-                                                        temp_terms_map,
-                                                        is_matlab, NodeTreeReference::secondDeriv);
-
-      for (auto & third_derivative : derivatives[3])
-        third_derivative.second->computeTemporaryTerms(reference_count,
-                                                       temp_terms_map,
-                                                       is_matlab, NodeTreeReference::thirdDeriv);
+                                        reference_count,
+                                        is_matlab);
+
+      for (int order = 1; order < (int) derivatives.size(); order++)
+        for (const auto &it : derivatives[order])
+          it.second->computeTemporaryTerms({ 0, order },
+                                           temp_terms_map,
+                                           reference_count,
+                                           is_matlab);
     }
 
-  for (map<NodeTreeReference, temporary_terms_t>::const_iterator it = temp_terms_map.begin();
-       it != temp_terms_map.end(); it++)
-      temporary_terms.insert(it->second.begin(), it->second.end());
+  // Fill the (now obsolete) temporary_terms structure
+  temporary_terms.clear();
+  for (const auto &it : temp_terms_map)
+    temporary_terms.insert(it.second.begin(), it.second.end());
 
-  temporary_terms_derivatives[0] = temp_terms_map[NodeTreeReference::residuals];
-  temporary_terms_derivatives[1] = temp_terms_map[NodeTreeReference::firstDeriv];
-  temporary_terms_derivatives[2] = temp_terms_map[NodeTreeReference::secondDeriv];
-  temporary_terms_derivatives[3] = temp_terms_map[NodeTreeReference::thirdDeriv];
+  // Fill the new structure
+  temporary_terms_derivatives.clear();
+  temporary_terms_derivatives.resize(derivatives.size());
+  for (int order = 0; order < (int) derivatives.size(); order++)
+    temporary_terms_derivatives[order] = move(temp_terms_map[{ 0, order }]);
 
+  // Compute indices in MATLAB/Julia vector
   int idx = 0;
   for (auto &it : temporary_terms_mlv)
     temporary_terms_idxs[it.first] = idx++;
-
-  for (auto it : temporary_terms_derivatives[0])
-    temporary_terms_idxs[it] = idx++;
-
-  for (auto it : temporary_terms_derivatives[1])
-    temporary_terms_idxs[it] = idx++;
-
-  for (auto it : temporary_terms_derivatives[2])
-    temporary_terms_idxs[it] = idx++;
-
-  for (auto it : temporary_terms_derivatives[3])
-    temporary_terms_idxs[it] = idx++;
+  for (int order = 0; order < (int) derivatives.size(); order++)
+    for (const auto &it : temporary_terms_derivatives[order])
+      temporary_terms_idxs[it] = idx++;
 }
 
 void
@@ -2078,66 +2058,24 @@ ModelTree::computeParamsDerivatives(int paramsDerivsOrder)
 void
 ModelTree::computeParamsDerivativesTemporaryTerms()
 {
-  map<expr_t, pair<int, NodeTreeReference >> reference_count;
-  map<NodeTreeReference, temporary_terms_t> temp_terms_map;
-  temp_terms_map[NodeTreeReference::residualsParamsDeriv] = params_derivs_temporary_terms[{ 0, 1 }];
-  temp_terms_map[NodeTreeReference::jacobianParamsDeriv] = params_derivs_temporary_terms[{ 1, 1 }];
-  temp_terms_map[NodeTreeReference::residualsParamsSecondDeriv] = params_derivs_temporary_terms[{ 0, 2 }];
-  temp_terms_map[NodeTreeReference::jacobianParamsSecondDeriv] = params_derivs_temporary_terms[{ 1, 2 }];
-  temp_terms_map[NodeTreeReference::hessianParamsDeriv] = params_derivs_temporary_terms[{ 2, 1}];
+  map<expr_t, pair<int, pair<int, int>>> reference_count;
 
   /* The temp terms should be constructed in the same order as the for loops in
      {Static,Dynamic}Model::write{Json,}ParamsDerivativesFile() */
-
-  for (const auto &residuals_params_derivative : params_derivatives[{ 0, 1 }])
-    residuals_params_derivative.second->computeTemporaryTerms(reference_count,
-                                      temp_terms_map,
-                                      true, NodeTreeReference::residualsParamsDeriv);
-
-  for (const auto &jacobian_params_derivative : params_derivatives[{ 1, 1 }])
-    jacobian_params_derivative.second->computeTemporaryTerms(reference_count,
-                                      temp_terms_map,
-                                      true, NodeTreeReference::jacobianParamsDeriv);
-
-  for (const auto &it : params_derivatives[{ 0, 2 }])
-    it.second->computeTemporaryTerms(reference_count,
-                                     temp_terms_map,
-                                     true, NodeTreeReference::residualsParamsSecondDeriv);
-
-  for (const auto &it : params_derivatives[{ 1, 2 }])
-    it.second->computeTemporaryTerms(reference_count,
-                                     temp_terms_map,
-                                     true, NodeTreeReference::jacobianParamsSecondDeriv);
-
-  for (const auto &it : params_derivatives[{ 2, 1 }])
-    it.second->computeTemporaryTerms(reference_count,
-                                     temp_terms_map,
-                                     true, NodeTreeReference::hessianParamsDeriv);
-
-  params_derivs_temporary_terms[{ 0, 1 }] = temp_terms_map[NodeTreeReference::residualsParamsDeriv];
-  params_derivs_temporary_terms[{ 1, 1 }] = temp_terms_map[NodeTreeReference::jacobianParamsDeriv];
-  params_derivs_temporary_terms[{ 0, 2 }] = temp_terms_map[NodeTreeReference::residualsParamsSecondDeriv];
-  params_derivs_temporary_terms[{ 1, 2 }] = temp_terms_map[NodeTreeReference::jacobianParamsSecondDeriv];
-  params_derivs_temporary_terms[{ 2, 1 }] = temp_terms_map[NodeTreeReference::hessianParamsDeriv];
+  params_derivs_temporary_terms.clear();
+  for (const auto &it : params_derivatives)
+    for (const auto &it2 : it.second)
+      it2.second->computeTemporaryTerms(it.first,
+                                        params_derivs_temporary_terms,
+                                        reference_count,
+                                        true);
 
   int idx = 0;
   for (auto &it : temporary_terms_mlv)
     params_derivs_temporary_terms_idxs[it.first] = idx++;
-
-  for (auto tt : params_derivs_temporary_terms[{ 0, 1 }])
-    params_derivs_temporary_terms_idxs[tt] = idx++;
-
-  for (auto tt : params_derivs_temporary_terms[{ 1, 1 }])
-    params_derivs_temporary_terms_idxs[tt] = idx++;
-
-  for (auto tt : params_derivs_temporary_terms[{ 0, 2 }])
-    params_derivs_temporary_terms_idxs[tt] = idx++;
-
-  for (auto tt : params_derivs_temporary_terms[{ 1, 2 }])
-    params_derivs_temporary_terms_idxs[tt] = idx++;
-
-  for (auto tt : params_derivs_temporary_terms[{ 2, 1 }])
-    params_derivs_temporary_terms_idxs[tt] = idx++;
+  for (const auto &it : params_derivs_temporary_terms)
+    for (const auto &tt : it.second)
+      params_derivs_temporary_terms_idxs[tt] = idx++;
 }
 
 bool
diff --git a/src/ModelTree.hh b/src/ModelTree.hh
index 882a06d6..2eb3c2e0 100644
--- a/src/ModelTree.hh
+++ b/src/ModelTree.hh
@@ -120,12 +120,14 @@ protected:
   /*! Index 0 is temp. terms of residuals, index 1 for first derivatives, ... */
   vector<temporary_terms_t> temporary_terms_derivatives;
 
+  //! Stores, for each temporary term, its index in the MATLAB/Julia vector
   temporary_terms_idxs_t temporary_terms_idxs;
 
   //! Temporary terms for parameter derivatives, under a disaggregated form
   /*! The pair of integers is to be interpreted as in param_derivatives */
   map<pair<int,int>, temporary_terms_t> params_derivs_temporary_terms;
 
+  //! Stores, for each temporary term in param. derivs, its index in the MATLAB/Julia vector
   temporary_terms_idxs_t params_derivs_temporary_terms_idxs;
 
   //! Trend variables and their growth factors
diff --git a/src/StaticModel.cc b/src/StaticModel.cc
index 4081f33b..d364a5ff 100644
--- a/src/StaticModel.cc
+++ b/src/StaticModel.cc
@@ -2615,8 +2615,8 @@ StaticModel::writeParamsDerivativesFile(const string &basename, bool julia) cons
   deriv_node_temp_terms_t tef_terms;
 
   writeModelLocalVariableTemporaryTerms(temp_term_union, params_derivs_temporary_terms_idxs, tt_output, output_type, tef_terms);
-  for (auto it : { make_pair(0,1), make_pair(1,1), make_pair(0,2), make_pair(1,2), make_pair(2,1) })
-    writeTemporaryTerms(params_derivs_temporary_terms.find(it)->second, temp_term_union, params_derivs_temporary_terms_idxs, tt_output, output_type, tef_terms);
+  for (const auto &it : params_derivs_temporary_terms)
+    writeTemporaryTerms(it.second, temp_term_union, params_derivs_temporary_terms_idxs, tt_output, output_type, tef_terms);
 
   for (const auto & residuals_params_derivative : params_derivatives.find({ 0, 1 })->second)
     {
@@ -3082,8 +3082,8 @@ StaticModel::writeJsonParamsDerivativesFile(ostream &output, bool writeDetails)
 
   temporary_terms_t temp_term_union;
   string concat = "all";
-  for (auto it : { make_pair(0,1), make_pair(1,1), make_pair(0,2), make_pair(1,2), make_pair(2,1) })
-    writeJsonTemporaryTerms(params_derivs_temporary_terms.find(it)->second, temp_term_union, model_output, tef_terms, concat);
+  for (const auto &it : params_derivs_temporary_terms)
+    writeJsonTemporaryTerms(it.second, temp_term_union, model_output, tef_terms, concat);
 
   jacobian_output << "\"deriv_wrt_params\": {"
                   << "  \"neqs\": " << equations.size()
-- 
GitLab