From e22d9049ee79e0dfbdd7552aba277eaa5fd15f20 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Wed, 5 Apr 2023 14:16:22 +0200
Subject: [PATCH] Optimization: use std::unordered_map instead of std::map for
 caching chain rule derivation

Improves performance on very very large models (tens of thousands of equations).
---
 src/DynamicModel.cc |  4 ++--
 src/ExprNode.cc     | 55 +++++++++++++++++++++++----------------------
 src/ExprNode.hh     | 40 ++++++++++++++++++---------------
 src/StaticModel.cc  |  9 ++++----
 4 files changed, 57 insertions(+), 51 deletions(-)

diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc
index fe4d4089..4480b4f5 100644
--- a/src/DynamicModel.cc
+++ b/src/DynamicModel.cc
@@ -2320,8 +2320,8 @@ DynamicModel::computeChainRuleJacobian()
         }
 
       // Compute the block derivatives
-      map<expr_t, set<int>> non_null_chain_rule_derivatives;
-      map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
+      unordered_map<expr_t, set<int>> non_null_chain_rule_derivatives;
+      unordered_map<expr_t, map<int, expr_t>> chain_rule_deriv_cache;
       for (const auto &[indices, derivType] : determineBlockDerivativesType(blk))
         {
           auto [lag, eq, var] = indices;
diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index faeb7d70..8a203031 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -56,8 +56,8 @@ ExprNode::getDerivative(int deriv_id)
 
 expr_t
 ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables,
-                                 map<expr_t, set<int>> &non_null_chain_rule_derivatives,
-                                 map<pair<expr_t, int>, expr_t> &cache)
+                                 unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
+                                 unordered_map<expr_t, map<int, expr_t>> &cache)
 {
   if (!non_null_chain_rule_derivatives.contains(this))
     prepareForChainRuleDerivation(recursive_variables, non_null_chain_rule_derivatives);
@@ -67,15 +67,16 @@ ExprNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &r
     return datatree.Zero;
 
   // If derivative is in the cache, return that value
-  pair key {this, deriv_id};
-  if (auto it = cache.find(key);
+  if (auto it = cache.find(this);
       it != cache.end())
-    return it->second;
+    if (auto it2 = it->second.find(deriv_id);
+        it2 != it->second.end())
+      return it2->second;
 
   auto r = computeChainRuleDerivative(deriv_id, recursive_variables,
                                       non_null_chain_rule_derivatives, cache);
 
-  auto [ignore, success] = cache.emplace(key, r);
+  auto [ignore, success] = cache[this].emplace(deriv_id, r);
   assert(success); // The element should not already exist
   return r;
 }
@@ -489,7 +490,7 @@ NumConstNode::prepareForDerivation()
 
 void
 NumConstNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
-                                            map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+                                            unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
 {
   non_null_chain_rule_derivatives.try_emplace(const_cast<NumConstNode *>(this));
 }
@@ -582,8 +583,8 @@ NumConstNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &contai
 expr_t
 NumConstNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
                                          [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
-                                         [[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives,
-                                         [[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
+                                         [[maybe_unused]] unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
+                                         [[maybe_unused]] unordered_map<expr_t, map<int, expr_t>> &cache)
 {
   return datatree.Zero;
 }
@@ -911,7 +912,7 @@ VariableNode::prepareForDerivation()
 
 void
 VariableNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
-                                            map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+                                            unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
 {
   if (non_null_chain_rule_derivatives.contains(const_cast<VariableNode *>(this)))
     return;
@@ -1470,8 +1471,8 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
 expr_t
 VariableNode::computeChainRuleDerivative(int deriv_id,
                                          const map<int, BinaryOpNode *> &recursive_variables,
-                                         map<expr_t, set<int>> &non_null_chain_rule_derivatives,
-                                         map<pair<expr_t, int>, expr_t> &cache)
+                                         unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
+                                         unordered_map<expr_t, map<int, expr_t>> &cache)
 {
   switch (get_type())
     {
@@ -2202,7 +2203,7 @@ UnaryOpNode::prepareForDerivation()
 
 void
 UnaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
-                                           map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+                                           unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
 {
   if (non_null_chain_rule_derivatives.contains(const_cast<UnaryOpNode *>(this)))
     return;
@@ -3342,8 +3343,8 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
 expr_t
 UnaryOpNode::computeChainRuleDerivative(int deriv_id,
                                         const map<int, BinaryOpNode *> &recursive_variables,
-                                        map<expr_t, set<int>> &non_null_chain_rule_derivatives,
-                                        map<pair<expr_t, int>, expr_t> &cache)
+                                        unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
+                                        unordered_map<expr_t, map<int, expr_t>> &cache)
 {
   expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
   return composeDerivatives(darg, deriv_id);
@@ -4022,7 +4023,7 @@ BinaryOpNode::prepareForDerivation()
 
 void
 BinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
-                                            map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+                                            unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
 {
   if (non_null_chain_rule_derivatives.contains(const_cast<BinaryOpNode *>(this)))
     return;
@@ -5090,8 +5091,8 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const
 expr_t
 BinaryOpNode::computeChainRuleDerivative(int deriv_id,
                                          const map<int, BinaryOpNode *> &recursive_variables,
-                                         map<expr_t, set<int>> &non_null_chain_rule_derivatives,
-                                         map<pair<expr_t, int>, expr_t> &cache)
+                                         unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
+                                         unordered_map<expr_t, map<int, expr_t>> &cache)
 {
   expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
   expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
@@ -5890,7 +5891,7 @@ TrinaryOpNode::prepareForDerivation()
 
 void
 TrinaryOpNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
-                                             map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+                                             unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
 {
   if (non_null_chain_rule_derivatives.contains(const_cast<TrinaryOpNode *>(this)))
     return;
@@ -6375,8 +6376,8 @@ TrinaryOpNode::normalizeEquationHelper([[maybe_unused]] const set<expr_t> &conta
 expr_t
 TrinaryOpNode::computeChainRuleDerivative(int deriv_id,
                                           const map<int, BinaryOpNode *> &recursive_variables,
-                                          map<expr_t, set<int>> &non_null_chain_rule_derivatives,
-                                          map<pair<expr_t, int>, expr_t> &cache)
+                                          unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
+                                          unordered_map<expr_t, map<int, expr_t>> &cache)
 {
   expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
   expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables, non_null_chain_rule_derivatives, cache);
@@ -6723,7 +6724,7 @@ AbstractExternalFunctionNode::prepareForDerivation()
 
 void
 AbstractExternalFunctionNode::prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables,
-                                                            map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+                                                            unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
 {
   if (non_null_chain_rule_derivatives.contains(const_cast<AbstractExternalFunctionNode *>(this)))
     return;
@@ -6758,8 +6759,8 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id)
 expr_t
 AbstractExternalFunctionNode::computeChainRuleDerivative(int deriv_id,
                                                          const map<int, BinaryOpNode *> &recursive_variables,
-                                                         map<expr_t, set<int>> &non_null_chain_rule_derivatives,
-                                                         map<pair<expr_t, int>, expr_t> &cache)
+                                                         unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
+                                                         unordered_map<expr_t, map<int, expr_t>> &cache)
 {
   assert(datatree.external_functions_table.getNargs(symb_id) > 0);
   vector<expr_t> dargs;
@@ -8238,7 +8239,7 @@ SubModelNode::prepareForDerivation()
 
 void
 SubModelNode::prepareForChainRuleDerivation([[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
-                                            [[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
+                                            [[maybe_unused]] unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const
 {
   cerr << "SubModelNode::prepareForChainRuleDerivation not implemented." << endl;
   exit(EXIT_FAILURE);
@@ -8254,8 +8255,8 @@ SubModelNode::computeDerivative([[maybe_unused]] int deriv_id)
 expr_t
 SubModelNode::computeChainRuleDerivative([[maybe_unused]] int deriv_id,
                                          [[maybe_unused]] const map<int, BinaryOpNode *> &recursive_variables,
-                                         [[maybe_unused]] map<expr_t, set<int>> &non_null_chain_rule_derivatives,
-                                         [[maybe_unused]] map<pair<expr_t, int>, expr_t> &cache)
+                                         [[maybe_unused]] unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives,
+                                         [[maybe_unused]] unordered_map<expr_t, map<int, expr_t>> &cache)
 {
   cerr << "SubModelNode::computeChainRuleDerivative not implemented." << endl;
   exit(EXIT_FAILURE);
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 875e27d9..31025900 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -27,6 +27,7 @@
 #include <functional>
 #include <optional>
 #include <utility>
+#include <unordered_map>
 
 using namespace std;
 
@@ -250,7 +251,7 @@ private:
 
   /* Internal helper for getChainRuleDerivative(), that does the computation
      but assumes that the caching of this is handled elsewhere */
-  virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) = 0;
+  virtual expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) = 0;
 
 protected:
   //! Reference to the enclosing DataTree
@@ -285,7 +286,7 @@ protected:
      “recursive_variables”. Note that all non-endogenous variables are
      automatically considered to have a zero derivative (since they’re never
      used in a chain rule context) */
-  virtual void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const = 0;
+  virtual void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const = 0;
 
   //! Cost of computing current node
   /*! Nodes included in temporary_terms are considered having a null cost */
@@ -353,8 +354,11 @@ public:
      variable (since such variables are never used in a chain rule context).
      NB 2: “non_null_chain_rule_derivatives” and “cache” are specific to a given
      value of “recursive_variables”, and thus should not be reused accross
-     calls that use different values of “recursive_variables”. */
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache);
+     calls that use different values of “recursive_variables”.
+     NB 3: the use of std::unordered_map instead of std::map for caching
+     purposes improves performance on very very large models (tens of thousands
+     of equations) */
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache);
 
   //! Returns precedence of node
   /*! Equals 100 for constants, variables, unary ops, and temporary terms */
@@ -855,10 +859,10 @@ public:
   const int id;
 private:
   expr_t computeDerivative(int deriv_id) override;
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
 protected:
   void prepareForDerivation() override;
-  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
@@ -927,10 +931,10 @@ public:
   const int lag;
 private:
   expr_t computeDerivative(int deriv_id) override;
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
 protected:
   void prepareForDerivation() override;
-  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
@@ -996,7 +1000,7 @@ class UnaryOpNode : public ExprNode
 {
 protected:
   void prepareForDerivation() override;
-  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   // Returns the node obtained by applying a transformation recursively on the argument (in same datatree)
@@ -1018,7 +1022,7 @@ public:
   const vector<int> adl_lags;
 private:
   expr_t computeDerivative(int deriv_id) override;
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
   int cost(int cost, bool is_matlab) const override;
   int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
   int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@@ -1108,7 +1112,7 @@ class BinaryOpNode : public ExprNode
 {
 protected:
   void prepareForDerivation() override;
-  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> &param_id, double &constant, bool at_denominator) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 public:
@@ -1118,7 +1122,7 @@ public:
   const string adlparam;
 private:
   expr_t computeDerivative(int deriv_id) override;
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
   int cost(int cost, bool is_matlab) const override;
   int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
   int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@@ -1264,11 +1268,11 @@ public:
   const TrinaryOpcode op_code;
 protected:
   void prepareForDerivation() override;
-  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 private:
   expr_t computeDerivative(int deriv_id) override;
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
   int cost(int cost, bool is_matlab) const override;
   int cost(const vector<vector<temporary_terms_t>> &blocks_temporary_terms, bool is_matlab) const override;
   int cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const override;
@@ -1371,7 +1375,7 @@ public:
   const vector<expr_t> arguments;
 private:
   expr_t computeDerivative(int deriv_id) override;
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
   virtual expr_t composeDerivatives(const vector<expr_t> &dargs) = 0;
   // Computes the maximum of f applied to all arguments (result will always be non-negative)
   int maxHelper(const function<int (expr_t)> &f) const;
@@ -1391,7 +1395,7 @@ protected:
   {
   };
   void prepareForDerivation() override;
-  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   //! Returns true if the given external function has been written as a temporary term
   bool alreadyWrittenAsTefTerm(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const;
   //! Returns the index in the tef_terms map of this external function
@@ -1657,10 +1661,10 @@ public:
   expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override;
 protected:
   void prepareForDerivation() override;
-  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
+  void prepareForChainRuleDerivation(const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
 private:
-  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<expr_t, set<int>> &non_null_chain_rule_derivatives, map<pair<expr_t, int>, expr_t> &cache) override;
+  expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, unordered_map<expr_t, set<int>> &non_null_chain_rule_derivatives, unordered_map<expr_t, map<int, expr_t>> &cache) override;
 };
 
 class VarExpectationNode : public SubModelNode
diff --git a/src/StaticModel.cc b/src/StaticModel.cc
index d3af5132..94fad6e2 100644
--- a/src/StaticModel.cc
+++ b/src/StaticModel.cc
@@ -24,6 +24,7 @@
 #include <algorithm>
 #include <sstream>
 #include <numeric>
+#include <unordered_map>
 
 #include "StaticModel.hh"
 #include "DynamicModel.hh"
@@ -650,8 +651,8 @@ StaticModel::computeChainRuleJacobian()
              && simulation_type != BlockSimulationType::solveTwoBoundariesComplete);
 
       int size = blocks[blk].size;
-      map<expr_t, set<int>> non_null_chain_rule_derivatives;
-      map<pair<expr_t, int>, expr_t> chain_rule_deriv_cache;
+      unordered_map<expr_t, set<int>> non_null_chain_rule_derivatives;
+      unordered_map<expr_t, map<int, expr_t>> chain_rule_deriv_cache;
       for (int eq = nb_recursives; eq < size; eq++)
         {
           int eq_orig = getBlockEquationID(blk, eq);
@@ -822,8 +823,8 @@ StaticModel::computeRamseyMultipliersDerivatives(int ramsey_orig_endo_nbr, bool
     }
 
   // Compute the chain rule derivatives w.r.t. multipliers
-  map<expr_t, set<int>> non_null_chain_rule_derivatives;
-  map<pair<expr_t, int>, expr_t> cache;
+  unordered_map<expr_t, set<int>> non_null_chain_rule_derivatives;
+  unordered_map<expr_t, map<int, expr_t>> cache;
   for (int eq {0}; eq < ramsey_orig_endo_nbr; eq++)
     for (int mult {0}; mult < static_cast<int>(mult_deriv_ids.size()); mult++)
       if (expr_t d { equations[eq]->getChainRuleDerivative(mult_deriv_ids[mult], recursive_variables,
-- 
GitLab