diff --git a/src/macro/Expressions.cc b/src/macro/Expressions.cc index 7e2de2a82e2bc475bdb022d95de9abd1f5207cbe..2cde1d0ea0af81f134901a256a822c2fc82fdc01 100644 --- a/src/macro/Expressions.cc +++ b/src/macro/Expressions.cc @@ -876,7 +876,7 @@ TrinaryOp::eval() } BaseTypePtr -Comprehension::eval() +ListComprehension::eval() { ArrayPtr ap; VariablePtr vp; @@ -885,15 +885,15 @@ Comprehension::eval() { ap = dynamic_pointer_cast<Array>(c_set->eval()); if (!ap) - throw StackTrace("Comprehension", "The input set must evaluate to an array", location); + throw StackTrace("ListComprehension", "The input set must evaluate to an array", location); vp = dynamic_pointer_cast<Variable>(c_vars); mt = dynamic_pointer_cast<Tuple>(c_vars); if ((!vp && !mt) || (vp && mt)) - throw StackTrace("Comprehension", "the output expression must be either a tuple or a variable", location); + throw StackTrace("ListComprehension", "the output expression must be either a tuple or a variable", location); } catch (StackTrace &ex) { - ex.push("Comprehension: ", location); + ex.push("ListComprehension: ", location); throw; } @@ -908,19 +908,19 @@ Comprehension::eval() { auto mt2 = dynamic_pointer_cast<Tuple>(btp); if (mt->size() != mt2->size()) - throw StackTrace("Comprehension", "The number of elements in the input set tuple are not " + throw StackTrace("ListComprehension", "The number of elements in the input set tuple are not " "the same as the number of elements in the output expression tuple", location); for (size_t j = 0; j < mt->size(); j++) { auto vp2 = dynamic_pointer_cast<Variable>(mt->at(j)); if (!vp2) - throw StackTrace("Comprehension", "Output expression tuple must be comprised of variable names", location); + throw StackTrace("ListComprehension", "Output expression tuple must be comprised of variable names", location); env.define(vp2, mt2->at(j)); } } else - throw StackTrace("Comprehension", "assigning to tuple in output expression " + throw StackTrace("ListComprehension", "assigning to tuple in output expression " "but input expression does not contain tuples", location); DoublePtr dp; @@ -934,7 +934,7 @@ Comprehension::eval() } catch (StackTrace &ex) { - ex.push("Comprehension", location); + ex.push("ListComprehension", location); throw; } if ((bp && *bp) || (dp && *dp)) @@ -943,13 +943,123 @@ Comprehension::eval() return make_shared<Array>(values, env); } +BaseTypePtr +ArrayComprehension::eval() +{ + ArrayPtr input_set; + VariablePtr vp; + TuplePtr mt; + try + { + input_set = dynamic_pointer_cast<Array>(c_set->eval()); + if (!input_set) + throw StackTrace("ArrayComprehension", "The input set must evaluate to an array", location); + vp = dynamic_pointer_cast<Variable>(c_vars); + mt = dynamic_pointer_cast<Tuple>(c_vars); + if ((!vp && !mt) || (vp && mt)) + throw StackTrace("ArrayComprehension", "the loop variables must be either " + "a tuple or a variable", location); + } + catch (StackTrace &ex) + { + ex.push("ArrayComprehension: ", location); + throw; + } + + vector<ExpressionPtr> values; + for (size_t i = 0; i < input_set->size(); i++) + { + auto btp = dynamic_pointer_cast<BaseType>(input_set->at(i)); + if (vp) + env.define(vp, btp); + else + if (btp->getType() == codes::BaseType::Tuple) + { + auto mt2 = dynamic_pointer_cast<Tuple>(btp); + if (mt->size() != mt2->size()) + throw StackTrace("ArrayComprehension", "The number of elements in the input " + " set tuple are not the same as the number of elements in " + "the output expression tuple", location); + + for (size_t j = 0; j < mt->size(); j++) + { + auto vp2 = dynamic_pointer_cast<Variable>(mt->at(j)); + if (!vp2) + throw StackTrace("ArrayComprehension", "Output expression tuple must be " + "comprised of variable names", location); + env.define(vp2, mt2->at(j)); + } + } + else + throw StackTrace("ArrayComprehension", "assigning to tuple in output expression " + "but input expression does not contain tuples", location); + + if (!c_when) + values.emplace_back(c_expr->clone()->eval()); + else + { + DoublePtr dp; + BoolPtr bp; + try + { + dp = dynamic_pointer_cast<Double>(c_when->eval()); + bp = dynamic_pointer_cast<Bool>(c_when->eval()); + if (!bp && !dp) + throw StackTrace("The condition must evaluate to a boolean or a double"); + } + catch (StackTrace &ex) + { + ex.push("ArrayComprehension", location); + throw; + } + if ((bp && *bp) || (dp && *dp)) + values.emplace_back(c_expr->clone()->eval()); + } + } + return make_shared<Array>(values, env); +} + +ExpressionPtr +Tuple::clone() const noexcept +{ + vector<ExpressionPtr> tup_copy; + for (auto & it : tup) + tup_copy.emplace_back(it->clone()); + return make_shared<Tuple>(tup_copy, env, location); +} + +ExpressionPtr +Array::clone() const noexcept +{ + if (range1 && range2) + return make_shared<Array>(range1, range2, env, location); + vector<ExpressionPtr> arr_copy; + for (auto & it : arr) + arr_copy.emplace_back(it->clone()); + return make_shared<Array>(arr_copy, env, location); +} + +ExpressionPtr +Function::clone() const noexcept +{ + vector<ExpressionPtr> args_copy; + for (auto & it : args) + args_copy.emplace_back(it->clone()); + return make_shared<Function>(name, args_copy, env, location); +} + string Array::to_string() const noexcept { - string retval = "["; - for (const auto & it : arr) - retval += dynamic_pointer_cast<BaseType>(it)->to_string() + ", "; - return retval.substr(0, retval.size()-2) + "]"; + if (!arr.empty()) + { + string retval = "["; + for (const auto & it : arr) + retval += dynamic_pointer_cast<BaseType>(it)->to_string() + ", "; + return retval.substr(0, retval.size()-2) + "]"; + } + else + return "[" + range1->to_string() + ":" + range2->to_string() + "]"; } string @@ -1112,6 +1222,15 @@ TrinaryOp::to_string() const noexcept exit(EXIT_FAILURE); } +string +ArrayComprehension::to_string() const noexcept +{ + string retval = "[" + c_expr->to_string() + " for " + c_vars->to_string() + " in " + c_set->to_string(); + if (c_when) + retval += " when " + c_when->to_string(); + return retval + "]"; +} + void String::print(ostream &output, bool matlab_output) const noexcept { @@ -1368,7 +1487,7 @@ TrinaryOp::print(ostream &output, bool matlab_output) const noexcept } void -Comprehension::print(ostream &output, bool matlab_output) const noexcept +ListComprehension::print(ostream &output, bool matlab_output) const noexcept { output << "["; c_vars->print(output, matlab_output); @@ -1379,3 +1498,20 @@ Comprehension::print(ostream &output, bool matlab_output) const noexcept output << "]"; } +void +ArrayComprehension::print(ostream &output, bool matlab_output) const noexcept +{ + output << "["; + c_expr->print(output, matlab_output); + output << " for "; + c_vars->print(output, matlab_output); + output << " in "; + c_set->print(output, matlab_output); + if (c_when) + { + output << " when "; + c_when->print(output, matlab_output); + } + output << "]"; +} + diff --git a/src/macro/Expressions.hh b/src/macro/Expressions.hh index 0600e9d511dc89233a39281dc8fac27fd5e9079a..3e4a344f541ffea2e3b2db30174d6d710a69bb45 100644 --- a/src/macro/Expressions.hh +++ b/src/macro/Expressions.hh @@ -116,6 +116,7 @@ namespace macro virtual string to_string() const noexcept = 0; virtual void print(ostream &output, bool matlab_output = false) const noexcept = 0; virtual BaseTypePtr eval() = 0; + virtual ExpressionPtr clone() const noexcept = 0; }; @@ -194,6 +195,7 @@ namespace macro inline codes::BaseType getType() const noexcept override { return codes::BaseType::Bool; } inline string to_string() const noexcept override { return value ? "true" : "false"; } inline void print(ostream &output, bool matlab_output = false) const noexcept override { output << to_string(); } + inline ExpressionPtr clone() const noexcept override { return make_shared<Bool>(value, env, location); } public: operator bool() const { return value; } BoolPtr is_equal(const BaseTypePtr &btp) const override; @@ -215,8 +217,8 @@ namespace macro BaseType(env_arg, move(location_arg)), value{strtod(value_arg.c_str(), nullptr)} { } Double(double value_arg, - Environment &env_arg) : - BaseType(env_arg), + Environment &env_arg, Tokenizer::location location_arg = Tokenizer::location()) : + BaseType(env_arg, move(location_arg)), value{value_arg} { } inline codes::BaseType getType() const noexcept override { return codes::BaseType::Double; } inline string to_string() const noexcept override @@ -226,6 +228,7 @@ namespace macro return strs.str(); } inline void print(ostream &output, bool matlab_output = false) const noexcept override { output << to_string(); } + inline ExpressionPtr clone() const noexcept override { return make_shared<Double>(value, env, location); } public: operator double() const { return value; } BaseTypePtr plus(const BaseTypePtr &bt) const override; @@ -288,6 +291,7 @@ namespace macro inline codes::BaseType getType() const noexcept override { return codes::BaseType::String; } inline string to_string() const noexcept override { return value; } void print(ostream &output, bool matlab_output = false) const noexcept override; + inline ExpressionPtr clone() const noexcept override { return make_shared<String>(value, env, location); } public: operator string() const { return value; } BaseTypePtr plus(const BaseTypePtr &bt) const override; @@ -313,6 +317,7 @@ namespace macro string to_string() const noexcept override; void print(ostream &output, bool matlab_output = false) const noexcept override; BaseTypePtr eval() override; + ExpressionPtr clone() const noexcept override; public: inline size_t size() const { return tup.size(); } inline bool empty() const { return tup.empty(); } @@ -346,6 +351,7 @@ namespace macro string to_string() const noexcept override; void print(ostream &output, bool matlab_output = false) const noexcept override; BaseTypePtr eval() override; + ExpressionPtr clone() const noexcept override; public: inline size_t size() const { return arr.size(); } inline vector<ExpressionPtr> getValue() const { return arr; } @@ -373,10 +379,14 @@ namespace macro Variable(const string name_arg, Environment &env_arg, const Tokenizer::location location_arg) : Expression(env_arg, move(location_arg)), name{move(name_arg)} { } + Variable(const string name_arg, const ArrayPtr indices_arg, + Environment &env_arg, const Tokenizer::location location_arg) : + Expression(env_arg, move(location_arg)), name{move(name_arg)}, indices{move(indices_arg)} { } inline void addIndexing(const vector<ExpressionPtr> indices_arg) { indices = make_shared<Array>(indices_arg, env); } inline string to_string() const noexcept override { return name; } inline void print(ostream &output, bool matlab_output = false) const noexcept override { output << name; } BaseTypePtr eval() override; + inline ExpressionPtr clone() const noexcept override { return indices ? make_shared<Variable>(name, indices, env, location) : make_shared<Variable>(name, env, location); } public: inline string getName() const noexcept { return name; } inline codes::BaseType getType() const { return env.getType(name); } @@ -396,6 +406,7 @@ namespace macro string to_string() const noexcept override; inline void print(ostream &output, bool matlab_output = false) const noexcept override { printName(output); printArgs(output); } BaseTypePtr eval() override; + ExpressionPtr clone() const noexcept override; public: inline void printName(ostream &output) const noexcept { output << name; } void printArgs(ostream &output) const noexcept; @@ -417,6 +428,7 @@ namespace macro string to_string() const noexcept override; void print(ostream &output, bool matlab_output = false) const noexcept override; BaseTypePtr eval() override; + inline ExpressionPtr clone() const noexcept override { return make_shared<UnaryOp>(op_code, arg->clone(), env, location); } }; @@ -435,6 +447,7 @@ namespace macro string to_string() const noexcept override; void print(ostream &output, bool matlab_output = false) const noexcept override; BaseTypePtr eval() override; + inline ExpressionPtr clone() const noexcept override { return make_shared<BinaryOp>(op_code, arg1->clone(), arg2->clone(), env, location); } }; @@ -452,23 +465,51 @@ namespace macro string to_string() const noexcept override; void print(ostream &output, bool matlab_output = false) const noexcept override; BaseTypePtr eval() override; + inline ExpressionPtr clone() const noexcept override { return make_shared<TrinaryOp>(op_code, arg1->clone(), arg2->clone(), arg3->clone(), env, location); } }; - class Comprehension final : public Expression + class ListComprehension final : public Expression { private: const ExpressionPtr c_vars, c_set, c_when; public: - Comprehension(const ExpressionPtr c_vars_arg, - const ExpressionPtr c_set_arg, - const ExpressionPtr c_when_arg, - Environment &env_arg, const Tokenizer::location location_arg) : + ListComprehension(const ExpressionPtr c_vars_arg, + const ExpressionPtr c_set_arg, + const ExpressionPtr c_when_arg, + Environment &env_arg, const Tokenizer::location location_arg) : Expression(env_arg, move(location_arg)), c_vars{move(c_vars_arg)}, c_set{move(c_set_arg)}, c_when{move(c_when_arg)} { } inline string to_string() const noexcept override { return "[" + c_vars->to_string() + " in " + c_set->to_string() + " when " + c_when->to_string() + "]"; } void print(ostream &output, bool matlab_output = false) const noexcept override; BaseTypePtr eval() override; + inline ExpressionPtr clone() const noexcept override { return make_shared<ListComprehension>(c_vars->clone(), c_set->clone(), c_when->clone(), env, location); } + }; + + + class ArrayComprehension final : public Expression + { + private: + const ExpressionPtr c_expr, c_vars, c_set, c_when; + public: + ArrayComprehension(const ExpressionPtr c_expr_arg, + const ExpressionPtr c_vars_arg, + const ExpressionPtr c_set_arg, + const ExpressionPtr c_when_arg, + Environment &env_arg, const Tokenizer::location location_arg) : + Expression(env_arg, move(location_arg)), + c_expr{move(c_expr_arg)}, c_vars{move(c_vars_arg)}, + c_set{move(c_set_arg)}, c_when{move(c_when_arg)} { } + ArrayComprehension(const ExpressionPtr c_expr_arg, + const ExpressionPtr c_vars_arg, + const ExpressionPtr c_set_arg, + Environment &env_arg, const Tokenizer::location location_arg) : + Expression(env_arg, move(location_arg)), + c_expr{move(c_expr_arg)}, c_vars{move(c_vars_arg)}, c_set{move(c_set_arg)} { } + string to_string() const noexcept override; + void print(ostream &output, bool matlab_output = false) const noexcept override; + BaseTypePtr eval() override; + inline ExpressionPtr clone() const noexcept override { return make_shared<ArrayComprehension>(c_expr->clone(), c_vars->clone(), c_set->clone(), c_when->clone(), env, location); } }; } #endif diff --git a/src/macro/Parser.yy b/src/macro/Parser.yy index 85972412fd4e8e055cf7d7248e9dc89363c76368..739ad3d15cc229fc0d73ad338dd57e8c29e296b5 100644 --- a/src/macro/Parser.yy +++ b/src/macro/Parser.yy @@ -304,7 +304,11 @@ expr : LPAREN expr RPAREN | LPAREN tuple_comma_expr RPAREN { $$ = make_shared<Tuple>($2, driver.env, @$); } | LBRACKET expr IN expr WHEN expr RBRACKET - { $$ = make_shared<Comprehension>($2, $4, $6, driver.env, @$); } + { $$ = make_shared<ListComprehension>($2, $4, $6, driver.env, @$); } + | LBRACKET expr FOR expr IN expr RBRACKET + { $$ = make_shared<ArrayComprehension>($2, $4, $6, driver.env, @$); } + | LBRACKET expr FOR expr IN expr WHEN expr RBRACKET + { $$ = make_shared<ArrayComprehension>($2, $4, $6, $8, driver.env, @$); } | NOT expr { $$ = make_shared<UnaryOp>(codes::UnaryOp::logical_not, $2, driver.env, @$); } | MINUS expr %prec UMINUS diff --git a/src/macro/Tokenizer.ll b/src/macro/Tokenizer.ll index a1268a4a56faeadde1ce7b3caf052264275943e4..a76d4314b22e3aa19c8427f8cb301a2268c96eb8 100644 --- a/src/macro/Tokenizer.ll +++ b/src/macro/Tokenizer.ll @@ -111,6 +111,7 @@ CONT \\\\{SPC}* <expr,eval>\[ { return token::LBRACKET; } <expr,eval>\] { return token::RBRACKET; } <expr,eval>in { return token::IN; } +<expr,eval>for { return token::FOR; } <expr,eval>when { return token::WHEN; } <expr,eval>save { return token::SAVE; }