From a76d4495d3126a0570d3b6f3b76f56cc76bdd792 Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Mon, 3 Jul 2017 16:38:44 +0200
Subject: [PATCH] preprocessor: move adl to unaryopnode

---
 CodeInterpreter.hh |   6 +--
 DataTree.cc        |   5 +-
 DataTree.hh        |  25 +++++----
 ExprNode.cc        | 130 +++++++++++++++++++++------------------------
 ExprNode.hh        |   6 +--
 ParsingDriver.cc   |   3 +-
 6 files changed, 82 insertions(+), 93 deletions(-)

diff --git a/CodeInterpreter.hh b/CodeInterpreter.hh
index 8dfd7e40..de7cf9dc 100644
--- a/CodeInterpreter.hh
+++ b/CodeInterpreter.hh
@@ -201,7 +201,8 @@ enum UnaryOpcode
     oSteadyStateParam2ndDeriv, // for the 2nd derivative of the STEADY_STATE operator w.r.t. to a parameter
     oExpectation,
     oErf,
-    oDiff
+    oDiff,
+    oAdl
   };
 
 enum BinaryOpcode
@@ -220,8 +221,7 @@ enum BinaryOpcode
     oLessEqual,
     oGreaterEqual,
     oEqualEqual,
-    oDifferent,
-    oAdl
+    oDifferent
   };
 
 enum TrinaryOpcode
diff --git a/DataTree.cc b/DataTree.cc
index 39278b19..e4220bcd 100644
--- a/DataTree.cc
+++ b/DataTree.cc
@@ -264,10 +264,9 @@ DataTree::AddDiff(expr_t iArg1)
 }
 
 expr_t
-DataTree::AddAdl(expr_t iArg1, const string &name, expr_t iArg2)
+DataTree::AddAdl(expr_t iArg1, const string &name, int lag)
 {
-  expr_t adlnode = AddBinaryOp(iArg1, oAdl, iArg2, 0, string(name));
-  return adlnode;
+  return AddUnaryOp(oAdl, iArg1, 0, 0, 0, string(name), lag);
 }
 
 expr_t
diff --git a/DataTree.hh b/DataTree.hh
index 59d05a1a..72f8086d 100644
--- a/DataTree.hh
+++ b/DataTree.hh
@@ -63,10 +63,11 @@ protected:
   typedef map<pair<int, int>, VariableNode *> variable_node_map_t;
   variable_node_map_t variable_node_map;
   //! Pair( Pair(arg1, UnaryOpCode), Pair( Expectation Info Set, Pair(param1_symb_id, param2_symb_id)) ))
-  typedef map<pair<pair<expr_t, UnaryOpcode>, pair<int, pair<int, int> > >, UnaryOpNode *> unary_op_node_map_t;
+
+  typedef map<pair<pair<expr_t, UnaryOpcode>, pair<pair<int, pair<int, int> >, pair<string, int> > >, UnaryOpNode *> unary_op_node_map_t;
   unary_op_node_map_t unary_op_node_map;
   //! Pair( Pair( Pair(arg1, arg2), order of Power Derivative), opCode)
-  typedef map<pair<pair<pair<expr_t, expr_t>, pair<int, string> >, BinaryOpcode>, BinaryOpNode *> binary_op_node_map_t;
+  typedef map<pair<pair<pair<expr_t, expr_t>, int>, BinaryOpcode>, BinaryOpNode *> binary_op_node_map_t;
   binary_op_node_map_t binary_op_node_map;
   typedef map<pair<pair<pair<expr_t, expr_t>, expr_t>, TrinaryOpcode>, TrinaryOpNode *> trinary_op_node_map_t;
   trinary_op_node_map_t trinary_op_node_map;
@@ -103,8 +104,8 @@ private:
   int node_counter;
 
   inline expr_t AddPossiblyNegativeConstant(double val);
-  inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0);
-  inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0, const string &adlparam = "");
+  inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0, const string &adl_param_name = "", int adl_param_lag = -1);
+  inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0);
   inline expr_t AddTrinaryOp(expr_t arg1, TrinaryOpcode op_code, expr_t arg2, expr_t arg3);
 
 public:
@@ -165,7 +166,7 @@ public:
   //! Adds "diff(arg)" to model tree
   expr_t AddDiff(expr_t iArg1);
   //! Adds "adl(arg1, arg2)" to model tree
-  expr_t AddAdl(expr_t iArg1, const string &name, expr_t iArg2);
+  expr_t AddAdl(expr_t iArg1, const string &name, int lag);
   //! Adds "exp(arg)" to model tree
   expr_t AddExp(expr_t iArg1);
   //! Adds "log(arg)" to model tree
@@ -316,10 +317,10 @@ DataTree::AddPossiblyNegativeConstant(double v)
 }
 
 inline expr_t
-DataTree::AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set, int param1_symb_id, int param2_symb_id)
+DataTree::AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set, int param1_symb_id, int param2_symb_id, const string &adl_param_name, int adl_param_lag)
 {
   // If the node already exists in tree, share it
-  unary_op_node_map_t::iterator it = unary_op_node_map.find(make_pair(make_pair(arg, op_code), make_pair(arg_exp_info_set, make_pair(param1_symb_id, param2_symb_id))));
+  unary_op_node_map_t::iterator it = unary_op_node_map.find(make_pair(make_pair(arg, op_code), make_pair(make_pair(arg_exp_info_set, make_pair(param1_symb_id, param2_symb_id)), make_pair(adl_param_name, adl_param_lag))));
   if (it != unary_op_node_map.end())
     return it->second;
 
@@ -338,15 +339,13 @@ DataTree::AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set, int
         {
         }
     }
-  return new UnaryOpNode(*this, op_code, arg, arg_exp_info_set, param1_symb_id, param2_symb_id);
+  return new UnaryOpNode(*this, op_code, arg, arg_exp_info_set, param1_symb_id, param2_symb_id, adl_param_name, adl_param_lag);
 }
 
 inline expr_t
-DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder, const string &adlparam)
+DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder)
 {
-  binary_op_node_map_t::iterator it = binary_op_node_map.find(make_pair(make_pair(make_pair(arg1, arg2),
-                                                                                  make_pair(powerDerivOrder, adlparam)),
-                                                                        op_code));
+  binary_op_node_map_t::iterator it = binary_op_node_map.find(make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code));
   if (it != binary_op_node_map.end())
     return it->second;
 
@@ -361,7 +360,7 @@ DataTree::AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerD
   catch (ExprNode::EvalException &e)
     {
     }
-  return new BinaryOpNode(*this, arg1, op_code, arg2, powerDerivOrder, adlparam);
+  return new BinaryOpNode(*this, arg1, op_code, arg2, powerDerivOrder);
 }
 
 inline expr_t
diff --git a/ExprNode.cc b/ExprNode.cc
index 855f6d57..7bc69fb3 100644
--- a/ExprNode.cc
+++ b/ExprNode.cc
@@ -1601,18 +1601,21 @@ VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
       model_endos_and_lags[varname] = lag;
 }
 
-UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg) :
+UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, const string &adl_param_name_arg, int adl_param_lag_arg) :
   ExprNode(datatree_arg),
   arg(arg_arg),
   expectation_information_set(expectation_information_set_arg),
   param1_symb_id(param1_symb_id_arg),
   param2_symb_id(param2_symb_id_arg),
-  op_code(op_code_arg)
+  op_code(op_code_arg),
+  adl_param_name(adl_param_name_arg),
+  adl_param_lag(adl_param_lag_arg)
 {
   // Add myself to the unary op map
   datatree.unary_op_node_map[make_pair(make_pair(arg, op_code),
-                                       make_pair(expectation_information_set,
-                                                 make_pair(param1_symb_id, param2_symb_id)))] = this;
+                                       make_pair(make_pair(expectation_information_set,
+                                                           make_pair(param1_symb_id, param2_symb_id)),
+                                                 make_pair(adl_param_name, adl_param_lag)))] = this;
 }
 
 void
@@ -1761,6 +1764,9 @@ UnaryOpNode::composeDerivatives(expr_t darg, int deriv_id)
     case oDiff:
       cerr << "UnaryOpNode::composeDerivatives: not implemented on oDiff" << endl;
       exit(EXIT_FAILURE);
+    case oAdl:
+      cerr << "UnaryOpNode::composeDerivatives: not implemented on oAdl" << endl;
+      exit(EXIT_FAILURE);
     }
   // Suppress GCC warning
   exit(EXIT_FAILURE);
@@ -1845,6 +1851,9 @@ UnaryOpNode::cost(int cost, bool is_matlab) const
       case oDiff:
         cerr << "UnaryOpNode::cost: not implemented on oDiff" << endl;
         exit(EXIT_FAILURE);
+      case oAdl:
+        cerr << "UnaryOpNode::cost: not implemented on oAdl" << endl;
+        exit(EXIT_FAILURE);
       }
   else
     // Cost for C files
@@ -1890,6 +1899,9 @@ UnaryOpNode::cost(int cost, bool is_matlab) const
       case oDiff:
         cerr << "UnaryOpNode::cost: not implemented on oDiff" << endl;
         exit(EXIT_FAILURE);
+      case oAdl:
+        cerr << "UnaryOpNode::cost: not implemented on oAdl" << endl;
+        exit(EXIT_FAILURE);
       }
   exit(EXIT_FAILURE);
 }
@@ -2036,6 +2048,12 @@ UnaryOpNode::writeJsonOutput(ostream &output,
     case oDiff:
       output << "diff";
       break;
+    case oAdl:
+      output << "adl";
+      output << "(";
+      arg->writeJsonOutput(output, temporary_terms, tef_terms);
+      output << "," << adl_param_name << "," << adl_param_lag << ")";
+      return;
     case oSteadyState:
       output << "(";
       arg->writeJsonOutput(output, temporary_terms, tef_terms);
@@ -2259,6 +2277,9 @@ UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
     case oDiff:
       output << "diff";
       break;
+    case oAdl:
+      output << "adl";
+      break;
     }
 
   bool close_parenthesis = false;
@@ -2367,6 +2388,9 @@ UnaryOpNode::eval_opcode(UnaryOpcode op_code, double v) throw (EvalException, Ev
     case oDiff:
       cerr << "UnaryOpNode::eval_opcode: not implemented on oDiff" << endl;
       exit(EXIT_FAILURE);
+    case oAdl:
+      cerr << "UnaryOpNode::eval_opcode: not implemented on oAdl" << endl;
+      exit(EXIT_FAILURE);
     }
   // Suppress GCC warning
   exit(EXIT_FAILURE);
@@ -2618,6 +2642,8 @@ UnaryOpNode::buildSimilarUnaryOpNode(expr_t alt_arg, DataTree &alt_datatree) con
       return alt_datatree.AddErf(alt_arg);
     case oDiff:
       return alt_datatree.AddDiff(alt_arg);
+    case oAdl:
+      return alt_datatree.AddAdl(alt_arg, adl_param_name, adl_param_lag);
     }
   // Suppress GCC warning
   exit(EXIT_FAILURE);
@@ -2676,15 +2702,36 @@ UnaryOpNode::maxLead() const
 expr_t
 UnaryOpNode::substituteAdlAndDiff() const
 {
-  if (op_code != oDiff)
+  if (op_code != oDiff && op_code != oAdl)
     {
       expr_t argsubst = arg->substituteAdlAndDiff();
       return buildSimilarUnaryOpNode(argsubst, datatree);
     }
 
-  expr_t argsubst = arg->substituteAdlAndDiff();
-  return datatree.AddMinus(argsubst,
-                           argsubst->decreaseLeadsLags(1));
+  if (op_code == oDiff)
+    {
+      expr_t argsubst = arg->substituteAdlAndDiff();
+      return datatree.AddMinus(argsubst,
+                               argsubst->decreaseLeadsLags(1));
+    }
+
+  expr_t arg1subst = arg->substituteAdlAndDiff();
+  int i = 1;
+  ostringstream inttostr;
+  inttostr << i;
+  expr_t retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_" + inttostr.str()), 0),
+                                    arg1subst->decreaseLeadsLags(i));
+  i++;
+  for (; i <= adl_param_lag; i++)
+    {
+      inttostr.clear();
+      inttostr.str("");
+      inttostr << i;
+      retval = datatree.AddPlus(retval,
+                                datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_" + inttostr.str()), 0),
+                                                  arg1subst->decreaseLeadsLags(i)));
+    }
+  return retval;
 }
 
 expr_t
@@ -2887,10 +2934,9 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
   arg1(arg1_arg),
   arg2(arg2_arg),
   op_code(op_code_arg),
-  powerDerivOrder(0),
-  adlparam("")
+  powerDerivOrder(0)
 {
-  datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
+  datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code)] = this;
 }
 
 BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
@@ -2899,25 +2945,10 @@ BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
   arg1(arg1_arg),
   arg2(arg2_arg),
   op_code(op_code_arg),
-  powerDerivOrder(powerDerivOrder_arg),
-  adlparam("")
-{
-  assert(powerDerivOrder >= 0);
-  datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
-}
-
-BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
-                           BinaryOpcode op_code_arg, const expr_t arg2_arg,
-                           int powerDerivOrder_arg, string adlparam_arg) :
-  ExprNode(datatree_arg),
-  arg1(arg1_arg),
-  arg2(arg2_arg),
-  op_code(op_code_arg),
-  powerDerivOrder(powerDerivOrder_arg),
-  adlparam(adlparam_arg)
+  powerDerivOrder(powerDerivOrder_arg)
 {
   assert(powerDerivOrder >= 0);
-  datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), make_pair(powerDerivOrder, adlparam)), op_code)] = this;
+  datatree.binary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), powerDerivOrder), op_code)] = this;
 }
 
 void
@@ -3051,9 +3082,6 @@ BinaryOpNode::composeDerivatives(expr_t darg1, expr_t darg2)
       return datatree.AddPlus(t14, t12);
     case oEqual:
       return datatree.AddMinus(darg1, darg2);
-    case oAdl:
-      cerr << "BinaryOpNode::composeDerivatives not implemented for oAdl";
-      exit(EXIT_FAILURE);
     }
   // Suppress GCC warning
   exit(EXIT_FAILURE);
@@ -3120,9 +3148,6 @@ BinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_t
     case oMin:
     case oMax:
       return 100;
-    case oAdl:
-      cerr << "BinaryOpNode::precedence not implemented for oAdl";
-      exit(EXIT_FAILURE);
     }
   // Suppress GCC warning
   exit(EXIT_FAILURE);
@@ -3159,7 +3184,6 @@ BinaryOpNode::precedenceJson(const temporary_terms_t &temporary_terms) const
       return 5;
     case oMin:
     case oMax:
-    case oAdl:
       return 100;
     }
   // Suppress GCC warning
@@ -3220,9 +3244,6 @@ BinaryOpNode::cost(int cost, bool is_matlab) const
         return cost + (MIN_COST_MATLAB/2+1);
       case oEqual:
         return cost;
-      case oAdl:
-        cerr << "BinaryOpNode::cost not implemented for oAdl";
-        exit(EXIT_FAILURE);
       }
   else
     // Cost for C files
@@ -3250,9 +3271,6 @@ BinaryOpNode::cost(int cost, bool is_matlab) const
         return cost + (MIN_COST_C/2+1);;
       case oEqual:
         return cost;
-      case oAdl:
-        cerr << "BinaryOpNode::cost not implemented for oAdl";
-        exit(EXIT_FAILURE);
       }
   // Suppress GCC warning
   exit(EXIT_FAILURE);
@@ -3365,9 +3383,6 @@ BinaryOpNode::eval_opcode(double v1, BinaryOpcode op_code, double v2, int derivO
       return (v1 != v2);
     case oEqual:
       throw EvalException();
-    case oAdl:
-        cerr << "BinaryOpNode::eval_opcode not implemented for oAdl";
-        exit(EXIT_FAILURE);
     }
   // Suppress GCC warning
   exit(EXIT_FAILURE);
@@ -4116,8 +4131,6 @@ BinaryOpNode::buildSimilarBinaryOpNode(expr_t alt_arg1, expr_t alt_arg2, DataTre
       return alt_datatree.AddDifferent(alt_arg1, alt_arg2);
     case oPowerDeriv:
       return alt_datatree.AddPowerDeriv(alt_arg1, alt_arg2, powerDerivOrder);
-    case oAdl:
-      return alt_datatree.AddAdl(alt_arg1, adlparam, alt_arg2);
     }
   // Suppress GCC warning
   exit(EXIT_FAILURE);
@@ -4307,30 +4320,9 @@ BinaryOpNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpN
 expr_t
 BinaryOpNode::substituteAdlAndDiff() const
 {
-  if (op_code != oAdl)
-    {
-      expr_t arg1subst = arg1->substituteAdlAndDiff();
-      expr_t arg2subst = arg2->substituteAdlAndDiff();
-      return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
-    }
-
   expr_t arg1subst = arg1->substituteAdlAndDiff();
-  int i = 1;
-  ostringstream inttostr;
-  inttostr << i;
-  int param_symb_id = datatree.symbol_table.getID(adlparam + "_lag_" + inttostr.str());
-  expr_t retval = datatree.AddTimes(datatree.AddVariable(param_symb_id, 0), arg1subst->decreaseLeadsLags(i));
-  i++;
-  for (; i <= (int) arg2->eval(eval_context_t());)
-    {
-      inttostr.clear();
-      inttostr.str("");
-      inttostr << i++;
-      retval = datatree.AddPlus(retval,
-                                datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adlparam + "_lag_" + inttostr.str()), 0),
-                                                  arg1subst->decreaseLeadsLags(i)));
-    }
-  return retval;
+  expr_t arg2subst = arg2->substituteAdlAndDiff();
+  return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
 }
 
 expr_t
diff --git a/ExprNode.hh b/ExprNode.hh
index b54587ee..aafd3d23 100644
--- a/ExprNode.hh
+++ b/ExprNode.hh
@@ -633,6 +633,8 @@ private:
   //! Only used for oSteadyStateParamDeriv and oSteadyStateParam2ndDeriv
   const int param1_symb_id, param2_symb_id;
   const UnaryOpcode op_code;
+  const string adl_param_name;
+  const int adl_param_lag;
   virtual expr_t computeDerivative(int deriv_id);
   virtual int cost(int cost, bool is_matlab) const;
   virtual int cost(const temporary_terms_t &temporary_terms, bool is_matlab) const;
@@ -640,7 +642,7 @@ private:
   //! Returns the derivative of this node if darg is the derivative of the argument
   expr_t composeDerivatives(expr_t darg, int deriv_id);
 public:
-  UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg);
+  UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, const string &adl_param_name_arg, int adl_param_lag_arg);
   virtual void prepareForDerivation();
   virtual void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference> > &reference_count,
                                      map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
@@ -736,8 +738,6 @@ public:
                BinaryOpcode op_code_arg, const expr_t arg2_arg);
   BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
                BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder);
-  BinaryOpNode(DataTree &datatree_arg, const expr_t arg1_arg,
-               BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder_arg, string adlparam_arg);
   virtual void prepareForDerivation();
   virtual int precedenceJson(const temporary_terms_t &temporary_terms) const;
   virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const;
diff --git a/ParsingDriver.cc b/ParsingDriver.cc
index dc0943cb..d8a1b1ea 100644
--- a/ParsingDriver.cc
+++ b/ParsingDriver.cc
@@ -2594,8 +2594,7 @@ ParsingDriver::add_diff(expr_t arg1)
 expr_t
 ParsingDriver::add_adl(expr_t arg1, string *name, string *lag)
 {
-  expr_t id = data_tree->AddAdl(arg1, *name,
-                                data_tree->AddNonNegativeConstant(*lag));
+  expr_t id = data_tree->AddAdl(arg1, *name, atoi(lag->c_str()));
 
   // Declare parameters here so that parameters can be initialized after the model block
   int i = 0;
-- 
GitLab