From 2deb4b42fbb2c3b93d8babc112f4a43376ffcc70 Mon Sep 17 00:00:00 2001
From: houtanb <houtan@dynare.org>
Date: Fri, 23 Jun 2017 18:01:07 +0200
Subject: [PATCH] preprocessor: fix bug in adl implementation

---
 DataTree.cc    |  5 ++---
 DataTree.hh    | 18 ++++++++----------
 ExprNode.cc    | 35 +++++++++++++++++++++++------------
 ExprNode.hh    |  3 +++
 SymbolTable.cc |  2 +-
 SymbolTable.hh |  2 +-
 6 files changed, 38 insertions(+), 27 deletions(-)

diff --git a/DataTree.cc b/DataTree.cc
index 2ce90dcc..39278b19 100644
--- a/DataTree.cc
+++ b/DataTree.cc
@@ -264,10 +264,9 @@ DataTree::AddDiff(expr_t iArg1)
 }
 
 expr_t
-DataTree::AddAdl(expr_t iArg1, string &name, expr_t iArg2)
+DataTree::AddAdl(expr_t iArg1, const string &name, expr_t iArg2)
 {
-  expr_t adlnode = AddBinaryOp(iArg1, oAdl, iArg2);
-  adl_map[adlnode] = new string(name);
+  expr_t adlnode = AddBinaryOp(iArg1, oAdl, iArg2, 0, string(name));
   return adlnode;
 }
 
diff --git a/DataTree.hh b/DataTree.hh
index 18e73417..59d05a1a 100644
--- a/DataTree.hh
+++ b/DataTree.hh
@@ -57,10 +57,6 @@ protected:
   //! A reference to the external functions table
   ExternalFunctionsTable &external_functions_table;
 
-  //! A reference to the adl table
-  typedef map<expr_t, string *> adl_map_t;
-  adl_map_t adl_map;
-
   typedef map<int, NumConstNode *> num_const_node_map_t;
   num_const_node_map_t num_const_node_map;
   //! Pair (symbol_id, lag) used as key
@@ -70,7 +66,7 @@ protected:
   typedef map<pair<pair<expr_t, UnaryOpcode>, pair<int, pair<int, int> > >, UnaryOpNode *> unary_op_node_map_t;
   unary_op_node_map_t unary_op_node_map;
   //! Pair( Pair( Pair(arg1, arg2), order of Power Derivative), opCode)
-  typedef map<pair<pair<pair<expr_t, expr_t>, int>, BinaryOpcode>, BinaryOpNode *> binary_op_node_map_t;
+  typedef map<pair<pair<pair<expr_t, expr_t>, pair<int, string> >, BinaryOpcode>, BinaryOpNode *> binary_op_node_map_t;
   binary_op_node_map_t binary_op_node_map;
   typedef map<pair<pair<pair<expr_t, expr_t>, expr_t>, TrinaryOpcode>, TrinaryOpNode *> trinary_op_node_map_t;
   trinary_op_node_map_t trinary_op_node_map;
@@ -108,7 +104,7 @@ private:
 
   inline expr_t AddPossiblyNegativeConstant(double val);
   inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0);
-  inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0);
+  inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0, const string &adlparam = "");
   inline expr_t AddTrinaryOp(expr_t arg1, TrinaryOpcode op_code, expr_t arg2, expr_t arg3);
 
 public:
@@ -169,7 +165,7 @@ public:
   //! Adds "diff(arg)" to model tree
   expr_t AddDiff(expr_t iArg1);
   //! Adds "adl(arg1, arg2)" to model tree
-  expr_t AddAdl(expr_t iArg1, string &name, expr_t iArg2);
+  expr_t AddAdl(expr_t iArg1, const string &name, expr_t iArg2);
   //! Adds "exp(arg)" to model tree
   expr_t AddExp(expr_t iArg1);
   //! Adds "log(arg)" to model tree
@@ -346,9 +342,11 @@ DataTree::AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set, int
 }
 
 inline expr_t
-DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder)
+DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder, const string &adlparam)
 {
-  binary_op_node_map_t::iterator it = binary_op_node_map.find(make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code));
+  binary_op_node_map_t::iterator it = binary_op_node_map.find(make_pair(make_pair(make_pair(arg1, arg2),
+                                                                                  make_pair(powerDerivOrder, adlparam)),
+                                                                        op_code));
   if (it != binary_op_node_map.end())
     return it->second;
 
@@ -363,7 +361,7 @@ DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerD
   catch (ExprNode::EvalException &e)
     {
     }
-  return new BinaryOpNode(*this, arg1, op_code, arg2, powerDerivOrder);
+  return new BinaryOpNode(*this, arg1, op_code, arg2, powerDerivOrder, adlparam);
 }
 
 inline expr_t
diff --git a/ExprNode.cc b/ExprNode.cc
index c2d2a518..1705c2f9 100644
--- a/ExprNode.cc
+++ b/ExprNode.cc
@@ -2887,9 +2887,10 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
   arg1(arg1_arg),
   arg2(arg2_arg),
   op_code(op_code_arg),
-  powerDerivOrder(0)
+  powerDerivOrder(0),
+  adlparam("")
 {
-  datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code)] = this;
+  datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
 }
 
 BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
@@ -2898,10 +2899,25 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
   arg1(arg1_arg),
   arg2(arg2_arg),
   op_code(op_code_arg),
-  powerDerivOrder(powerDerivOrder_arg)
+  powerDerivOrder(powerDerivOrder_arg),
+  adlparam("")
 {
   assert(powerDerivOrder >= 0);
-  datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code)] = this;
+  datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
+}
+
+BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
+                           BinaryOpcode op_code_arg, const expr_t arg2_arg,
+                           int powerDerivOrder_arg, string adlparam_arg) :
+  ExprNode(datatree_arg),
+  arg1(arg1_arg),
+  arg2(arg2_arg),
+  op_code(op_code_arg),
+  powerDerivOrder(powerDerivOrder_arg),
+  adlparam(adlparam_arg)
+{
+  assert(powerDerivOrder >= 0);
+  datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
 }
 
 void
@@ -4101,9 +4117,7 @@ BinaryOpNode::buildSimilarBinaryOpNode(expr_t alt_arg1, expr_t alt_arg2, DataTre
     case oPowerDeriv:
       return alt_datatree.AddPowerDeriv(alt_arg1, alt_arg2, powerDerivOrder);
     case oAdl:
-      DataTree::adl_map_t::const_iterator it = datatree.adl_map.find(const_cast<BinaryOpNode *>(this));
-      assert (it != datatree.adl_map.end());
-      return alt_datatree.AddAdl(alt_arg1, *(it->second), alt_arg2);
+      return alt_datatree.AddAdl(alt_arg1, adlparam, alt_arg2);
     }
   // Suppress GCC warning
   exit(EXIT_FAILURE);
@@ -4301,16 +4315,13 @@ BinaryOpNode::substituteAdlAndDiff() const
     }
 
   expr_t arg1subst = arg1->substituteAdlAndDiff();
-  DataTree::adl_map_t::const_iterator it = datatree.adl_map.find(const_cast<BinaryOpNode *>(this));
-  assert (it != datatree.adl_map.end());
-
   int i = 1;
-  expr_t retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(*(it->second), i), 0),
+  expr_t retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(adlparam, i), 0),
                                     arg1subst->decreaseLeadsLags(i));
   i++;
   for (; i <= (int) arg2->eval(eval_context_t()); i++)
     retval = datatree.AddPlus(retval,
-                              datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(*(it->second), i), 0),
+                              datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.addAdlParameter(adlparam, i), 0),
                                                 arg1subst->decreaseLeadsLags(i)));
   return retval;
 }
diff --git a/ExprNode.hh b/ExprNode.hh
index 4845e4a6..b54587ee 100644
--- a/ExprNode.hh
+++ b/ExprNode.hh
@@ -730,11 +730,14 @@ private:
   //! Returns the derivative of this node if darg1 and darg2 are the derivatives of the arguments
   expr_t composeDerivatives(expr_t darg1, expr_t darg2);
   const int powerDerivOrder;
+  const string adlparam;
 public:
   BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
                BinaryOpcode op_code_arg, const expr_t arg2_arg);
   BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
                BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder);
+  BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
+               BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder_arg, string adlparam_arg);
   virtual void prepareForDerivation();
   virtual int precedenceJson(const temporary_terms_t &temporary_terms) const;
   virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const;
diff --git a/SymbolTable.cc b/SymbolTable.cc
index 27450363..d92661e6 100644
--- a/SymbolTable.cc
+++ b/SymbolTable.cc
@@ -650,7 +650,7 @@ SymbolTable::addLagAuxiliaryVarInternal(bool endo, int orig_symb_id, int orig_le
 }
 
 int
-SymbolTable::addAdlParameter(string &basename, int lag) throw (FrozenException)
+SymbolTable::addAdlParameter(const string &basename, int lag) throw (FrozenException)
 {
   ostringstream varname;
   varname << basename << "_lag_" << lag;
diff --git a/SymbolTable.hh b/SymbolTable.hh
index a804c908..bca56d5f 100644
--- a/SymbolTable.hh
+++ b/SymbolTable.hh
@@ -287,7 +287,7 @@ public:
   /*
   // Adds a parameter for the transformation of the adl operator
   */
-  int addAdlParameter(string &basename, int lag) throw (FrozenException);
+  int addAdlParameter(const string &basename, int lag) throw (FrozenException);
   //! Returns the number of auxiliary variables
   int
   AuxVarsSize() const
-- 
GitLab