diff --git a/src/macro/Expressions.cc b/src/macro/Expressions.cc index d1a551872f98942854e8a40e54baeb09aac0c0ae..a740b6f967289a30214a55f61e41e2f9745113f5 100644 --- a/src/macro/Expressions.cc +++ b/src/macro/Expressions.cc @@ -297,6 +297,19 @@ String::is_equal(const BaseTypePtr &btp) const return make_shared<Bool>(value == btp2->value, env); } +BoolPtr +String::cast_bool() const +{ + try + { + return make_shared<Bool>(static_cast<bool>(stoi(value)), env); + } + catch (...) + { + throw StackTrace(value + " cannot be converted to a boolean"); + } +} + DoublePtr String::cast_int() const { @@ -518,6 +531,14 @@ Array::sum() const return make_shared<Double>(retval, env); } +BoolPtr +Array::cast_bool() const +{ + if (arr.size() != 1) + throw StackTrace("Array must be of size 1 to be cast to a boolean"); + return arr.at(0)->eval()->cast_bool(); +} + DoublePtr Array::cast_int() const { @@ -568,6 +589,14 @@ Tuple::contains(const BaseTypePtr &btp) const return make_shared<Bool>(false, env); } +BoolPtr +Tuple::cast_bool() const +{ + if (tup.size() != 1) + throw StackTrace("Tuple must be of size 1 to be cast to a boolean"); + return tup.at(0)->eval()->cast_bool(); +} + DoublePtr Tuple::cast_int() const { @@ -767,6 +796,8 @@ UnaryOp::eval() auto argbt = arg->eval(); switch (op_code) { + case codes::UnaryOp::cast_bool: + return argbt->cast_bool(); case codes::UnaryOp::cast_int: return argbt->cast_int(); case codes::UnaryOp::cast_double: @@ -1097,6 +1128,8 @@ UnaryOp::to_string() const noexcept string retval = arg->to_string(); switch (op_code) { + case codes::UnaryOp::cast_bool: + return "(bool)" + retval; case codes::UnaryOp::cast_int: return "(int)" + retval; case codes::UnaryOp::cast_double: @@ -1307,6 +1340,9 @@ UnaryOp::print(ostream &output, bool matlab_output) const noexcept { switch (op_code) { + case codes::UnaryOp::cast_bool: + output << "(bool)"; + break; case codes::UnaryOp::cast_int: output << "(int)"; break; @@ -1407,7 +1443,8 @@ UnaryOp::print(ostream &output, bool matlab_output) const noexcept arg->print(output, matlab_output); - if (op_code != codes::UnaryOp::cast_int + if (op_code != codes::UnaryOp::cast_bool + && op_code != codes::UnaryOp::cast_int && op_code != codes::UnaryOp::cast_double && op_code != codes::UnaryOp::cast_string && op_code != codes::UnaryOp::cast_tuple diff --git a/src/macro/Expressions.hh b/src/macro/Expressions.hh index d195ac155deb6869aaf18a248f4f3c127bb14803..7d7143ff5dc59ebc1c2cef7182c984c96a8c09a2 100644 --- a/src/macro/Expressions.hh +++ b/src/macro/Expressions.hh @@ -180,6 +180,7 @@ namespace macro virtual DoublePtr normpdf(const BaseTypePtr &btp1, const BaseTypePtr &btp2) const { throw StackTrace("Operator `normpdf` does not exist for this type"); } virtual DoublePtr normcdf() const { throw StackTrace("Operator `normcdf` does not exist for this type"); } virtual DoublePtr normcdf(const BaseTypePtr &btp1, const BaseTypePtr &btp2) const { throw StackTrace("Operator `normcdf` does not exist for this type"); } + virtual BoolPtr cast_bool() const { throw StackTrace("This type cannot be cast to a boolean"); } virtual DoublePtr cast_int() const { throw StackTrace("This type cannot be cast to an integer"); } virtual DoublePtr cast_double() const { throw StackTrace("This type cannot be cast to a double"); } virtual StringPtr cast_string() const { throw StackTrace("This type cannot be cast to a string"); } @@ -207,6 +208,7 @@ namespace macro BoolPtr logical_and(const BaseTypePtr &btp) const override; BoolPtr logical_or(const BaseTypePtr &btp) const override; BoolPtr logical_not() const override; + inline BoolPtr cast_bool() const override { return make_shared<Bool>(value, env); } inline DoublePtr cast_int() const override { return value ? make_shared<Double>(1, env) : make_shared<Double>(0, env); } inline DoublePtr cast_double() const override { return cast_int(); } inline StringPtr cast_string() const override { return make_shared<String>(this->to_string(), env); } @@ -302,6 +304,7 @@ namespace macro return normcdf(make_shared<Double>(0, env), make_shared<Double>(1, env)); } DoublePtr normcdf(const BaseTypePtr &btp1, const BaseTypePtr &btp2) const override; + inline BoolPtr cast_bool() const override { return make_shared<Bool>(static_cast<bool>(value), env); } inline DoublePtr cast_int() const override { return make_shared<Double>(static_cast<int>(value), env); } inline DoublePtr cast_double() const override { return make_shared<Double>(value, env); } inline StringPtr cast_string() const override { return make_shared<String>(this->to_string(), env); } @@ -337,8 +340,9 @@ namespace macro BoolPtr is_greater_equal(const BaseTypePtr &btp) const override; BoolPtr is_equal(const BaseTypePtr &btp) const override; inline DoublePtr length() const override { return make_shared<Double>(value.size(), env); } - inline DoublePtr cast_int() const override; - inline DoublePtr cast_double() const override; + BoolPtr cast_bool() const override; + DoublePtr cast_int() const override; + DoublePtr cast_double() const override; inline StringPtr cast_string() const override { return make_shared<String>(value, env); } inline TuplePtr cast_tuple() const override { @@ -373,6 +377,7 @@ namespace macro BoolPtr is_equal(const BaseTypePtr &btp) const override; BoolPtr contains(const BaseTypePtr &btp) const override; inline DoublePtr length() const override { return make_shared<Double>(tup.size(), env); } + BoolPtr cast_bool() const override; DoublePtr cast_int() const override; DoublePtr cast_double() const override; inline StringPtr cast_string() const override { return make_shared<String>(this->to_string(), env); } @@ -419,6 +424,7 @@ namespace macro BoolPtr contains(const BaseTypePtr &btp) const override; inline DoublePtr length() const override { return make_shared<Double>(arr.size(), env); } DoublePtr sum() const override; + BoolPtr cast_bool() const override; DoublePtr cast_int() const override; DoublePtr cast_double() const override; inline StringPtr cast_string() const override { return make_shared<String>(this->to_string(), env); } diff --git a/src/macro/ForwardDeclarationsAndEnums.hh b/src/macro/ForwardDeclarationsAndEnums.hh index 78070e8d2411c1598557e6b8923065daeb9ca849..09eb1065457abe7e862fce3fbc4649ac1d072092 100644 --- a/src/macro/ForwardDeclarationsAndEnums.hh +++ b/src/macro/ForwardDeclarationsAndEnums.hh @@ -67,6 +67,7 @@ namespace macro enum class UnaryOp { + cast_bool, cast_int, cast_double, cast_string, diff --git a/src/macro/Parser.yy b/src/macro/Parser.yy index 7940dcfa21c9fbd13523f4d257f7a82613649d57..36ff49037a221d228601968e4069cbb324ed8d29 100644 --- a/src/macro/Parser.yy +++ b/src/macro/Parser.yy @@ -65,7 +65,7 @@ using namespace macro; %token SQRT CBRT SIGN MAX MIN FLOOR CEIL TRUNC SUM MOD %token ERF ERFC GAMMA LGAMMA ROUND NORMPDF NORMCDF LENGTH -%token INT DOUBLE STRING TUPLE ARRAY +%token BOOL INT DOUBLE STRING TUPLE ARRAY %left OR %left AND @@ -78,7 +78,7 @@ using namespace macro; %left PLUS MINUS %left TIMES DIVIDE %precedence UMINUS UPLUS NOT -%precedence CAST_INT CAST_DOUBLE CAST_STRING CAST_TUPLE CAST_ARRAY +%precedence CAST_BOOL CAST_INT CAST_DOUBLE CAST_STRING CAST_TUPLE CAST_ARRAY %nonassoc POWER %token <string> NAME TEXT QUOTED_STRING NUMBER EOL @@ -322,6 +322,8 @@ expr : LPAREN expr RPAREN { $$ = make_shared<Comprehension>($2, $4, $6, driver.env, @$); } | LBRACKET expr FOR expr IN expr WHEN expr RBRACKET { $$ = make_shared<Comprehension>($2, $4, $6, $8, driver.env, @$); } + | LPAREN BOOL RPAREN expr %prec CAST_BOOL + { $$ = make_shared<UnaryOp>(codes::UnaryOp::cast_bool, $4, driver.env, @$); } | LPAREN INT RPAREN expr %prec CAST_INT { $$ = make_shared<UnaryOp>(codes::UnaryOp::cast_int, $4, driver.env, @$); } | LPAREN DOUBLE RPAREN expr %prec CAST_DOUBLE diff --git a/src/macro/Tokenizer.ll b/src/macro/Tokenizer.ll index dfb15c0f66446ba4ff28a8a51281857a1cc02cfe..a993912804d0abd1af66254d0cc65d3c217cafd3 100644 --- a/src/macro/Tokenizer.ll +++ b/src/macro/Tokenizer.ll @@ -142,6 +142,7 @@ CONT \\\\{SPC}* <expr,eval>normpdf { return token::NORMPDF; } <expr,eval>normcdf { return token::NORMCDF; } +<expr,eval>bool { return token::BOOL; } <expr,eval>int { return token::INT; } <expr,eval>double { return token::DOUBLE; } <expr,eval>string { return token::STRING; }