From 4c6032895da8ecabd345d70d28170de64e737244 Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Mon, 24 Jun 2019 15:01:22 +0200
Subject: [PATCH] macro processor: fix bug in indexing of strings/arrays

---
 src/macro/Expressions.cc | 60 +++++++++++++++++++++++++++++-----------
 src/macro/Expressions.hh |  9 ++----
 src/macro/Parser.yy      |  2 +-
 3 files changed, 47 insertions(+), 24 deletions(-)

diff --git a/src/macro/Expressions.cc b/src/macro/Expressions.cc
index c0cc4c43..e7cff488 100644
--- a/src/macro/Expressions.cc
+++ b/src/macro/Expressions.cc
@@ -593,26 +593,55 @@ Variable::eval()
     {
       ArrayPtr map = dynamic_pointer_cast<Array>(indices->eval());
       vector<ExpressionPtr> index = map->getValue();
+      vector<int> ind;
+      double intpart;
+      for (auto it : index)
+        {
+          // Necessary to handle indexes like: y[1:2,2]
+          // In general this evaluates to [[1:2],2] but when subscripting we want to expand it to [1,2,2]
+          auto db = dynamic_pointer_cast<Double>(it);
+          if (db)
+            {
+              if (modf(*db, &intpart) != 0.0)
+                throw StackTrace("variable", "When indexing a variable you must pass an int or an int array", location);
+              ind.emplace_back(*db);
+            }
+          else if (dynamic_pointer_cast<Array>(it))
+            for (auto it1 : dynamic_pointer_cast<Array>(it)->getValue())
+              {
+                db = dynamic_pointer_cast<Double>(it1);
+                if (db)
+                  {
+                    if (modf(*db, &intpart) != 0.0)
+                      throw StackTrace("variable", "When indexing a variable you must pass an int or an int array", location);
+                    ind.emplace_back(*db);
+                  }
+                else
+                  throw StackTrace("variable", "You cannot index a variable with a nested array", location);
+              }
+          else
+            throw StackTrace("variable", "You can only index a variable with an int or an int array", location);
+        }
+
       switch (env.getType(name))
         {
         case codes::BaseType::Bool:
           throw StackTrace("variable", "You cannot index a boolean", location);
         case codes::BaseType::Double:
           throw StackTrace("variable", "You cannot index a double", location);
+        case codes::BaseType::Tuple:
+          throw StackTrace("variable", "You cannot index a tuple", location);
         case codes::BaseType::String:
           {
             string orig_string =
               dynamic_pointer_cast<String>(env.getVariable(name))->to_string();
             string retvals;
-            for (auto & it : index)
+            for (auto it : ind)
               try
                 {
-                  DoublePtr idx = dynamic_pointer_cast<Double>(it);
-                  if (!idx)
-                    throw StackTrace("variable", "indexing must be done with an int array", location);
-                  retvals += orig_string.substr(*idx - 1, 1);
+                  retvals += orig_string.substr(it - 1, 1);
                 }
-              catch (const std::out_of_range& oor)
+              catch (const std::out_of_range &ex)
                 {
                   throw StackTrace("variable", "Index out of range", location);
                 }
@@ -621,23 +650,22 @@ Variable::eval()
         case codes::BaseType::Array:
           {
             ArrayPtr ap = dynamic_pointer_cast<Array>(env.getVariable(name));
-            vector<ExpressionPtr> retval;
-            for (auto & it : index)
+            vector<BaseTypePtr> retval;
+            for (auto it : ind)
               try
                 {
-                  DoublePtr idx = dynamic_pointer_cast<Double>(it);
-                  if (!idx)
-                    throw StackTrace("variable", "indexing must be done with int array", location);
-                  retval.emplace_back(ap->at(*idx - 1)->eval());
+                  retval.emplace_back(ap->at(it - 1)->eval());
                 }
-              catch (const std::out_of_range& oor)
+              catch (const out_of_range &ex)
                 {
                   throw StackTrace("variable", "Index out of range", location);
                 }
-            return make_shared<Array>(retval, env);
+
+            if (retval.size() == 1)
+              return retval.at(0);
+            vector<ExpressionPtr> retvala(retval.begin(), retval.end());
+            return make_shared<Array>(retvala, env);
           }
-        case codes::BaseType::Tuple:
-          throw StackTrace("variable", "You cannot index a tuple", location);
         }
     }
   return env.getVariable(name)->eval();
diff --git a/src/macro/Expressions.hh b/src/macro/Expressions.hh
index e3fcc221..39082a12 100644
--- a/src/macro/Expressions.hh
+++ b/src/macro/Expressions.hh
@@ -370,17 +370,12 @@ namespace macro
   {
   private:
     const string name;
-    ArrayPtr indices; // for strings
+    ArrayPtr indices; // for strings/arrays
   public:
     Variable(const string name_arg,
              Environment &env_arg, const Tokenizer::location location_arg) :
       Expression(env_arg, move(location_arg)), name{move(name_arg)} { }
-    inline void addIndexing(ExpressionPtr indices_arg)
-    {
-      indices = dynamic_pointer_cast<Array>(indices_arg);
-      if (!indices)
-        indices = make_shared<Array>(vector<ExpressionPtr>{indices_arg}, env);
-    }
+    inline void addIndexing(const vector<ExpressionPtr> indices_arg) { indices = make_shared<Array>(indices_arg, env); }
     inline string to_string() const noexcept override { return name; }
     inline void print(ostream &output, bool matlab_output = false) const noexcept override { output << name; }
     BaseTypePtr eval() override;
diff --git a/src/macro/Parser.yy b/src/macro/Parser.yy
index 8629cf7a..cf1b4c4b 100644
--- a/src/macro/Parser.yy
+++ b/src/macro/Parser.yy
@@ -300,7 +300,7 @@ expr : LPAREN expr RPAREN
        { $$ = make_shared<Array>($1, $3, driver.env, @$); }
      | LBRACKET comma_expr RBRACKET
        { $$ = make_shared<Array>($2, driver.env, @$); }
-     | symbol LBRACKET expr RBRACKET
+     | symbol LBRACKET comma_expr RBRACKET
        { $1->addIndexing($3); $$ = $1; }
      | LPAREN tuple_comma_expr RPAREN
        { $$ = make_shared<Tuple>($2, driver.env, @$); }
-- 
GitLab