From 5979885714bdd9856a2fce6362aa1baaa301818a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Fri, 2 Oct 2020 18:31:55 +0200
Subject: [PATCH] Block decomposition, chain rule derivation: code improvement

---
 src/DynamicModel.cc |  2 +-
 src/ExprNode.cc     | 22 +++++++++-------------
 src/ExprNode.hh     | 18 +++++++++---------
 src/StaticModel.cc  |  2 +-
 4 files changed, 20 insertions(+), 24 deletions(-)

diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc
index 858c7087..4d687bf6 100644
--- a/src/DynamicModel.cc
+++ b/src/DynamicModel.cc
@@ -4520,7 +4520,7 @@ DynamicModel::computeChainRuleJacobian()
       int nb_recursives = blocks[blk].getRecursiveSize();
 
       // Create a map from recursive vars to their defining (normalized) equation
-      map<int, expr_t> recursive_vars;
+      map<int, BinaryOpNode *> recursive_vars;
       for (int i = 0; i < nb_recursives; i++)
         {
           int deriv_id = getDerivID(symbol_table.getID(SymbolType::endogenous, getBlockVariableID(blk, i)), 0);
diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index dff1abb6..6de8f844 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -483,7 +483,7 @@ NumConstNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
 }
 
 expr_t
-NumConstNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+NumConstNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   return datatree.Zero;
 }
@@ -1307,7 +1307,7 @@ VariableNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs
 }
 
 expr_t
-VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+VariableNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   switch (get_type())
     {
@@ -1322,11 +1322,7 @@ VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recur
       // If there is in the equation a recursive variable we could use a chaine rule derivation
       else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag));
                it != recursive_variables.end())
-        {
-          map<int, expr_t> recursive_vars2(recursive_variables);
-          recursive_vars2.erase(it->first);
-          return datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2));
-        }
+        return it->second->arg2->getChainRuleDerivative(deriv_id, recursive_variables);
       else
         return datatree.Zero;
 
@@ -3088,7 +3084,7 @@ UnaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs)
 }
 
 expr_t
-UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables);
   return composeDerivatives(darg, deriv_id);
@@ -4820,7 +4816,7 @@ BinaryOpNode::normalizeEquation(int symb_id, int lag) const
 }
 
 expr_t
-BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
   expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
@@ -6089,7 +6085,7 @@ TrinaryOpNode::normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rh
 }
 
 expr_t
-TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
   expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
@@ -6508,7 +6504,7 @@ AbstractExternalFunctionNode::computeDerivative(int deriv_id)
 }
 
 expr_t
-AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   assert(datatree.external_functions_table.getNargs(symb_id) > 0);
   vector<expr_t> dargs;
@@ -8185,7 +8181,7 @@ VarExpectationNode::computeDerivative(int deriv_id)
 }
 
 expr_t
-VarExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+VarExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   cerr << "VarExpectationNode::getChainRuleDerivative not implemented." << endl;
   exit(EXIT_FAILURE);
@@ -8582,7 +8578,7 @@ PacExpectationNode::computeDerivative(int deriv_id)
 }
 
 expr_t
-PacExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
+PacExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables)
 {
   cerr << "PacExpectationNode::getChainRuleDerivative: shouldn't arrive here." << endl;
   exit(EXIT_FAILURE);
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index e7b1f910..4e3b2cd4 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -268,7 +268,7 @@ public:
     \param deriv_id The derivation ID with respect to which we are derivating
     \param recursive_variables Contains the derivation ID for which chain rules must be applied. Keys are derivation IDs, values are equations of the form x=f(y) where x is the key variable and x doesn't appear in y
   */
-  virtual expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) = 0;
+  virtual expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) = 0;
 
   //! Returns precedence of node
   /*! Equals 100 for constants, variables, unary ops, and temporary terms */
@@ -747,7 +747,7 @@ public:
   void computeXrefs(EquationInfo &ei) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -820,7 +820,7 @@ public:
   SymbolType get_type() const;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -921,7 +921,7 @@ public:
   void computeXrefs(EquationInfo &ei) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1028,7 +1028,7 @@ public:
   //! Try to normalize an equation with respect to a given dynamic variable.
   /*! Should only be called on Equal nodes. The variable must appear in the equation. */
   BinaryOpNode *normalizeEquation(int symb_id, int lag) const;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1160,7 +1160,7 @@ public:
   void computeXrefs(EquationInfo &ei) const override;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1272,7 +1272,7 @@ public:
   void computeXrefs(EquationInfo &ei) const override = 0;
   void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override;
   BinaryOpNode *normalizeEquationHelper(const set<expr_t> &contain_var, expr_t rhs) const override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   int maxEndoLead() const override;
   int maxExoLead() const override;
   int maxEndoLag() const override;
@@ -1458,7 +1458,7 @@ public:
   expr_t decreaseLeadsLags(int n) const override;
   void prepareForDerivation() override;
   expr_t computeDerivative(int deriv_id) override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   bool containsExternalFunction() const override;
   double eval(const eval_context_t &eval_context) const noexcept(false) override;
   void computeXrefs(EquationInfo &ei) const override;
@@ -1531,7 +1531,7 @@ public:
   expr_t decreaseLeadsLags(int n) const override;
   void prepareForDerivation() override;
   expr_t computeDerivative(int deriv_id) override;
-  expr_t getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables) override;
+  expr_t getChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables) override;
   bool containsExternalFunction() const override;
   double eval(const eval_context_t &eval_context) const noexcept(false) override;
   void computeXrefs(EquationInfo &ei) const override;
diff --git a/src/StaticModel.cc b/src/StaticModel.cc
index 68e6e418..cd506fed 100644
--- a/src/StaticModel.cc
+++ b/src/StaticModel.cc
@@ -2065,7 +2065,7 @@ StaticModel::computeChainRuleJacobian()
     {
       int nb_recursives = blocks[blk].getRecursiveSize();
 
-      map<int, expr_t> recursive_vars;
+      map<int, BinaryOpNode *> recursive_vars;
       for (int i = 0; i < nb_recursives; i++)
         {
           int deriv_id = getDerivID(symbol_table.getID(SymbolType::endogenous, getBlockVariableID(blk, i)), 0);
-- 
GitLab