From 4c6032895da8ecabd345d70d28170de64e737244 Mon Sep 17 00:00:00 2001 From: Houtan Bastani <houtan@dynare.org> Date: Mon, 24 Jun 2019 15:01:22 +0200 Subject: [PATCH] macro processor: fix bug in indexing of strings/arrays --- src/macro/Expressions.cc | 60 +++++++++++++++++++++++++++++----------- src/macro/Expressions.hh | 9 ++---- src/macro/Parser.yy | 2 +- 3 files changed, 47 insertions(+), 24 deletions(-) diff --git a/src/macro/Expressions.cc b/src/macro/Expressions.cc index c0cc4c43..e7cff488 100644 --- a/src/macro/Expressions.cc +++ b/src/macro/Expressions.cc @@ -593,26 +593,55 @@ Variable::eval() { ArrayPtr map = dynamic_pointer_cast<Array>(indices->eval()); vector<ExpressionPtr> index = map->getValue(); + vector<int> ind; + double intpart; + for (auto it : index) + { + // Necessary to handle indexes like: y[1:2,2] + // In general this evaluates to [[1:2],2] but when subscripting we want to expand it to [1,2,2] + auto db = dynamic_pointer_cast<Double>(it); + if (db) + { + if (modf(*db, &intpart) != 0.0) + throw StackTrace("variable", "When indexing a variable you must pass an int or an int array", location); + ind.emplace_back(*db); + } + else if (dynamic_pointer_cast<Array>(it)) + for (auto it1 : dynamic_pointer_cast<Array>(it)->getValue()) + { + db = dynamic_pointer_cast<Double>(it1); + if (db) + { + if (modf(*db, &intpart) != 0.0) + throw StackTrace("variable", "When indexing a variable you must pass an int or an int array", location); + ind.emplace_back(*db); + } + else + throw StackTrace("variable", "You cannot index a variable with a nested array", location); + } + else + throw StackTrace("variable", "You can only index a variable with an int or an int array", location); + } + switch (env.getType(name)) { case codes::BaseType::Bool: throw StackTrace("variable", "You cannot index a boolean", location); case codes::BaseType::Double: throw StackTrace("variable", "You cannot index a double", location); + case codes::BaseType::Tuple: + throw StackTrace("variable", "You cannot index a tuple", location); case codes::BaseType::String: { string orig_string = dynamic_pointer_cast<String>(env.getVariable(name))->to_string(); string retvals; - for (auto & it : index) + for (auto it : ind) try { - DoublePtr idx = dynamic_pointer_cast<Double>(it); - if (!idx) - throw StackTrace("variable", "indexing must be done with an int array", location); - retvals += orig_string.substr(*idx - 1, 1); + retvals += orig_string.substr(it - 1, 1); } - catch (const std::out_of_range& oor) + catch (const std::out_of_range &ex) { throw StackTrace("variable", "Index out of range", location); } @@ -621,23 +650,22 @@ Variable::eval() case codes::BaseType::Array: { ArrayPtr ap = dynamic_pointer_cast<Array>(env.getVariable(name)); - vector<ExpressionPtr> retval; - for (auto & it : index) + vector<BaseTypePtr> retval; + for (auto it : ind) try { - DoublePtr idx = dynamic_pointer_cast<Double>(it); - if (!idx) - throw StackTrace("variable", "indexing must be done with int array", location); - retval.emplace_back(ap->at(*idx - 1)->eval()); + retval.emplace_back(ap->at(it - 1)->eval()); } - catch (const std::out_of_range& oor) + catch (const out_of_range &ex) { throw StackTrace("variable", "Index out of range", location); } - return make_shared<Array>(retval, env); + + if (retval.size() == 1) + return retval.at(0); + vector<ExpressionPtr> retvala(retval.begin(), retval.end()); + return make_shared<Array>(retvala, env); } - case codes::BaseType::Tuple: - throw StackTrace("variable", "You cannot index a tuple", location); } } return env.getVariable(name)->eval(); diff --git a/src/macro/Expressions.hh b/src/macro/Expressions.hh index e3fcc221..39082a12 100644 --- a/src/macro/Expressions.hh +++ b/src/macro/Expressions.hh @@ -370,17 +370,12 @@ namespace macro { private: const string name; - ArrayPtr indices; // for strings + ArrayPtr indices; // for strings/arrays public: Variable(const string name_arg, Environment &env_arg, const Tokenizer::location location_arg) : Expression(env_arg, move(location_arg)), name{move(name_arg)} { } - inline void addIndexing(ExpressionPtr indices_arg) - { - indices = dynamic_pointer_cast<Array>(indices_arg); - if (!indices) - indices = make_shared<Array>(vector<ExpressionPtr>{indices_arg}, env); - } + inline void addIndexing(const vector<ExpressionPtr> indices_arg) { indices = make_shared<Array>(indices_arg, env); } inline string to_string() const noexcept override { return name; } inline void print(ostream &output, bool matlab_output = false) const noexcept override { output << name; } BaseTypePtr eval() override; diff --git a/src/macro/Parser.yy b/src/macro/Parser.yy index 8629cf7a..cf1b4c4b 100644 --- a/src/macro/Parser.yy +++ b/src/macro/Parser.yy @@ -300,7 +300,7 @@ expr : LPAREN expr RPAREN { $$ = make_shared<Array>($1, $3, driver.env, @$); } | LBRACKET comma_expr RBRACKET { $$ = make_shared<Array>($2, driver.env, @$); } - | symbol LBRACKET expr RBRACKET + | symbol LBRACKET comma_expr RBRACKET { $1->addIndexing($3); $$ = $1; } | LPAREN tuple_comma_expr RPAREN { $$ = make_shared<Tuple>($2, driver.env, @$); } -- GitLab