From 4897ab7d6955dd209aa3f417374f7287fee9ae9b Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Wed, 28 Feb 2018 11:31:08 +0100
Subject: [PATCH] clean up adl code

---
 src/DataTree.cc      |  8 +----
 src/DataTree.hh      | 13 ++++----
 src/ExprNode.cc      | 72 ++++++++++++++++++--------------------------
 src/ExprNode.hh      |  3 +-
 src/ParsingDriver.cc | 19 +++---------
 5 files changed, 43 insertions(+), 72 deletions(-)

diff --git a/src/DataTree.cc b/src/DataTree.cc
index 49313722..217a3137 100644
--- a/src/DataTree.cc
+++ b/src/DataTree.cc
@@ -263,16 +263,10 @@ DataTree::AddDiff(expr_t iArg1)
   return AddUnaryOp(oDiff, iArg1);
 }
 
-expr_t
-DataTree::AddAdl(expr_t iArg1, const string &name, int lag)
-{
-  return AddUnaryOp(oAdl, iArg1, 0, 0, 0, string(name), lag);
-}
-
 expr_t
 DataTree::AddAdl(expr_t iArg1, const string &name, const vector<int> &lags)
 {
-  return AddUnaryOp(oAdl, iArg1, 0, 0, 0, string(name), -1, lags);
+  return AddUnaryOp(oAdl, iArg1, 0, 0, 0, string(name), lags);
 }
 
 expr_t
diff --git a/src/DataTree.hh b/src/DataTree.hh
index e55de76c..d302063e 100644
--- a/src/DataTree.hh
+++ b/src/DataTree.hh
@@ -65,7 +65,7 @@ protected:
   variable_node_map_t variable_node_map;
   //! Pair( Pair(arg1, UnaryOpCode), Pair( Expectation Info Set, Pair(param1_symb_id, param2_symb_id)) ))
 
-  typedef map<pair<pair<expr_t, UnaryOpcode>, pair<pair<int, pair<int, int> >, pair<string, pair<int, vector<int> > > > >, UnaryOpNode *> unary_op_node_map_t;
+  typedef map<pair<pair<expr_t, UnaryOpcode>, pair<pair<int, pair<int, int> >, pair<string, vector<int> > > >, UnaryOpNode *> unary_op_node_map_t;
   unary_op_node_map_t unary_op_node_map;
   //! Pair( Pair( Pair(arg1, arg2), order of Power Derivative), opCode)
   typedef map<pair<pair<pair<expr_t, expr_t>, int>, BinaryOpcode>, BinaryOpNode *> binary_op_node_map_t;
@@ -111,7 +111,7 @@ private:
   int node_counter;
 
   inline expr_t AddPossiblyNegativeConstant(double val);
-  inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0, const string &adl_param_name = "", int adl_param_lag = -1, const vector<int> &adl_lags = vector<int>());
+  inline expr_t AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set = 0, int param1_symb_id = 0, int param2_symb_id = 0, const string &adl_param_name = "", const vector<int> &adl_lags = vector<int>());
   inline expr_t AddBinaryOp(expr_t arg1, BinaryOpcode op_code, expr_t arg2, int powerDerivOrder = 0);
   inline expr_t AddTrinaryOp(expr_t arg1, TrinaryOpcode op_code, expr_t arg2, expr_t arg3);
 
@@ -172,8 +172,7 @@ public:
   expr_t AddExpectation(int iArg1, expr_t iArg2);
   //! Adds "diff(arg)" to model tree
   expr_t AddDiff(expr_t iArg1);
-  //! Adds "adl(arg1, arg2)" to model tree
-  expr_t AddAdl(expr_t iArg1, const string &name, int lag);
+  //! Adds "adl(arg1, name, lag/lags)" to model tree
   expr_t AddAdl(expr_t iArg1, const string &name, const vector<int> &lags);
   //! Adds "exp(arg)" to model tree
   expr_t AddExp(expr_t iArg1);
@@ -327,10 +326,10 @@ DataTree::AddPossiblyNegativeConstant(double v)
 }
 
 inline expr_t
-DataTree::AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set, int param1_symb_id, int param2_symb_id, const string &adl_param_name, int adl_param_lag, const vector<int> &adl_lags)
+DataTree::AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set, int param1_symb_id, int param2_symb_id, const string &adl_param_name, const vector<int> &adl_lags)
 {
   // If the node already exists in tree, share it
-  unary_op_node_map_t::iterator it = unary_op_node_map.find(make_pair(make_pair(arg, op_code), make_pair(make_pair(arg_exp_info_set, make_pair(param1_symb_id, param2_symb_id)), make_pair(adl_param_name, make_pair(adl_param_lag, adl_lags)))));
+  unary_op_node_map_t::iterator it = unary_op_node_map.find(make_pair(make_pair(arg, op_code), make_pair(make_pair(arg_exp_info_set, make_pair(param1_symb_id, param2_symb_id)), make_pair(adl_param_name, adl_lags))));
   if (it != unary_op_node_map.end())
     return it->second;
 
@@ -349,7 +348,7 @@ DataTree::AddUnaryOp(UnaryOpcode op_code, expr_t arg, int arg_exp_info_set, int
         {
         }
     }
-  return new UnaryOpNode(*this, op_code, arg, arg_exp_info_set, param1_symb_id, param2_symb_id, adl_param_name, adl_param_lag, adl_lags);
+  return new UnaryOpNode(*this, op_code, arg, arg_exp_info_set, param1_symb_id, param2_symb_id, adl_param_name, adl_lags);
 }
 
 inline expr_t
diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 5ed78458..43095bcf 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -1676,7 +1676,7 @@ VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
       model_endos_and_lags[varname] = lag;
 }
 
-UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, const string &adl_param_name_arg, int adl_param_lag_arg, vector<int> adl_lags_arg) :
+UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, const string &adl_param_name_arg, vector<int> adl_lags_arg) :
   ExprNode(datatree_arg),
   arg(arg_arg),
   expectation_information_set(expectation_information_set_arg),
@@ -1684,14 +1684,13 @@ UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const
   param2_symb_id(param2_symb_id_arg),
   op_code(op_code_arg),
   adl_param_name(adl_param_name_arg),
-  adl_param_lag(adl_param_lag_arg),
   adl_lags(adl_lags_arg)
 {
   // Add myself to the unary op map
   datatree.unary_op_node_map[make_pair(make_pair(arg, op_code),
                                        make_pair(make_pair(expectation_information_set,
                                                            make_pair(param1_symb_id, param2_symb_id)),
-                                                 make_pair(adl_param_name, make_pair(adl_param_lag, adl_lags))))] = this;
+                                                 make_pair(adl_param_name, adl_lags)))] = this;
 }
 
 void
@@ -2126,10 +2125,16 @@ UnaryOpNode::writeJsonOutput(ostream &output,
       output << "diff";
       break;
     case oAdl:
-      output << "adl";
-      output << "(";
+      output << "adl(";
       arg->writeJsonOutput(output, temporary_terms, tef_terms);
-      output << "," << adl_param_name << "," << adl_param_lag << ")";
+      output << ", '" << adl_param_name << "', [";
+      for (vector<int>::const_iterator it = adl_lags.begin(); it != adl_lags.end(); it++)
+        {
+          if (it != adl_lags.begin())
+            output << ", ";
+          output << *it;
+        }
+      output << "])";
       return;
     case oSteadyState:
       output << "(";
@@ -2727,7 +2732,7 @@ UnaryOpNode::buildSimilarUnaryOpNode(expr_t alt_arg, DataTree &alt_datatree) con
     case oDiff:
       return alt_datatree.AddDiff(alt_arg);
     case oAdl:
-      return alt_datatree.AddAdl(alt_arg, adl_param_name, adl_param_lag);
+      return alt_datatree.AddAdl(alt_arg, adl_param_name, adl_lags);
     }
   // Suppress GCC warning
   exit(EXIT_FAILURE);
@@ -2795,41 +2800,24 @@ UnaryOpNode::substituteAdl() const
   expr_t arg1subst = arg->substituteAdl();
   expr_t retval = NULL;
   ostringstream inttostr;
-  if (adl_param_lag >= 0)
-    {
-      int i = 1;
-      inttostr << i;
-      retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_" + inttostr.str()), 0),
-                                 arg1subst->decreaseLeadsLags(i));
-      i++;
-      for (; i <= adl_param_lag; i++)
-        {
-          inttostr.clear();
-          inttostr.str("");
-          inttostr << i;
-          retval = datatree.AddPlus(retval,
-                                    datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_" + inttostr.str()), 0),
-                                                      arg1subst->decreaseLeadsLags(i)));
-        }
-    }
-  else
-    for (vector<int>::const_iterator it = adl_lags.begin(); it != adl_lags.end(); it++)
-      if (it == adl_lags.begin())
-        {
-          inttostr << *it;
-          retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_" + inttostr.str()), 0),
-                                     arg1subst->decreaseLeadsLags(*it));
-        }
-      else
-        {
-          inttostr.clear();
-          inttostr.str("");
-          inttostr << *it;
-          retval = datatree.AddPlus(retval,
-                                    datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_"
-                                                                                                       + inttostr.str()), 0),
-                                                      arg1subst->decreaseLeadsLags(*it)));
-        }
+
+  for (vector<int>::const_iterator it = adl_lags.begin(); it != adl_lags.end(); it++)
+    if (it == adl_lags.begin())
+      {
+        inttostr << *it;
+        retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_" + inttostr.str()), 0),
+                                   arg1subst->decreaseLeadsLags(*it));
+      }
+    else
+      {
+        inttostr.clear();
+        inttostr.str("");
+        inttostr << *it;
+        retval = datatree.AddPlus(retval,
+                                  datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_"
+                                                                                                     + inttostr.str()), 0),
+                                                    arg1subst->decreaseLeadsLags(*it)));
+      }
   return retval;
 }
 
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index cc8409ad..7fe371cc 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -667,7 +667,6 @@ private:
   const int param1_symb_id, param2_symb_id;
   const UnaryOpcode op_code;
   const string adl_param_name;
-  const int adl_param_lag;
   const vector<int> adl_lags;
   virtual expr_t computeDerivative(int deriv_id);
   virtual int cost(int cost, bool is_matlab) const;
@@ -676,7 +675,7 @@ private:
   //! Returns the derivative of this node if darg is the derivative of the argument
   expr_t composeDerivatives(expr_t darg, int deriv_id);
 public:
-  UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, const string &adl_param_name_arg, int adl_param_lag_arg, vector<int> adl_lags_arg);
+  UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, const string &adl_param_name_arg, vector<int> adl_lags_arg);
   virtual void prepareForDerivation();
   virtual void computeTemporaryTerms(map<expr_t, pair<int, NodeTreeReference> > &reference_count,
                                      map<NodeTreeReference, temporary_terms_t> &temp_terms_map,
diff --git a/src/ParsingDriver.cc b/src/ParsingDriver.cc
index 4e2ede10..e949294e 100644
--- a/src/ParsingDriver.cc
+++ b/src/ParsingDriver.cc
@@ -2757,22 +2757,13 @@ ParsingDriver::add_diff(expr_t arg1)
 expr_t
 ParsingDriver::add_adl(expr_t arg1, string *name, string *lag)
 {
-  expr_t id = data_tree->AddAdl(arg1, *name, atoi(lag->c_str()));
+  vector<int> *lags = new vector<int>();
+  for (int i = 1; i <= atoi(lag->c_str()); i++)
+    lags->push_back(i);
 
-  // Declare parameters here so that parameters can be initialized after the model block
-  int i = 0;
-  ostringstream inttostr;
-  for (; i < atoi(lag->c_str()); i++)
-    {
-      inttostr.clear();
-      inttostr.str("");
-      inttostr << i + 1;
-      declare_parameter(new string(*name + "_lag_" + inttostr.str()));
-    }
-
-  delete name;
   delete lag;
-  return id;
+
+  return add_adl(arg1, name, lags);
 }
 
 expr_t
-- 
GitLab