From 544251290435e878201bf23d3a3cad6fbabb286d Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Fri, 2 Aug 2019 15:03:05 -0400
Subject: [PATCH] macro processor: introduce double casts

---
 src/macro/Expressions.cc                 | 37 ++++++++++++++++++++++++
 src/macro/Expressions.hh                 |  6 ++++
 src/macro/ForwardDeclarationsAndEnums.hh |  1 +
 src/macro/Parser.yy                      |  6 ++--
 src/macro/Tokenizer.ll                   |  1 +
 5 files changed, 49 insertions(+), 2 deletions(-)

diff --git a/src/macro/Expressions.cc b/src/macro/Expressions.cc
index ce8e91b8..5e123148 100644
--- a/src/macro/Expressions.cc
+++ b/src/macro/Expressions.cc
@@ -310,6 +310,19 @@ String::cast_int() const
     }
 }
 
+DoublePtr
+String::cast_double() const
+{
+  try
+    {
+      return make_shared<Double>(stod(value), env);
+    }
+  catch (...)
+    {
+      throw StackTrace(value + " cannot be converted to a double");
+    }
+}
+
 BaseTypePtr
 Array::plus(const BaseTypePtr &btp) const
 {
@@ -513,6 +526,14 @@ Array::cast_int() const
   return arr.at(0)->eval()->cast_int();
 }
 
+DoublePtr
+Array::cast_double() const
+{
+  if (arr.size() != 1)
+    throw StackTrace("Array must be of size 1 to be cast to a double");
+  return arr.at(0)->eval()->cast_double();
+}
+
 BoolPtr
 Tuple::is_equal(const BaseTypePtr &btp) const
 {
@@ -555,6 +576,14 @@ Tuple::cast_int() const
   return tup.at(0)->eval()->cast_int();
 }
 
+DoublePtr
+Tuple::cast_double() const
+{
+  if (tup.size() != 1)
+    throw StackTrace("Tuple must be of size 1 to be cast to a double");
+  return tup.at(0)->eval()->cast_double();
+}
+
 BaseTypePtr
 Array::eval()
 {
@@ -740,6 +769,8 @@ UnaryOp::eval()
         {
         case codes::UnaryOp::cast_int:
           return argbt->cast_int();
+        case codes::UnaryOp::cast_double:
+          return argbt->cast_double();
         case codes::UnaryOp::logical_not:
           return argbt->logical_not();
         case codes::UnaryOp::unary_minus:
@@ -1062,6 +1093,8 @@ UnaryOp::to_string() const noexcept
     {
     case codes::UnaryOp::cast_int:
       return "(int)" + retval;
+    case codes::UnaryOp::cast_double:
+      return "(double)" + retval;
     case codes::UnaryOp::logical_not:
       return "!" + retval;
     case codes::UnaryOp::unary_minus:
@@ -1265,6 +1298,9 @@ UnaryOp::print(ostream &output, bool matlab_output) const noexcept
     case codes::UnaryOp::cast_int:
       output << "(int)";
       break;
+    case codes::UnaryOp::cast_double:
+      output << "(double)";
+      break;
     case codes::UnaryOp::logical_not:
       output << "!";
       break;
@@ -1351,6 +1387,7 @@ UnaryOp::print(ostream &output, bool matlab_output) const noexcept
   arg->print(output, matlab_output);
 
   if (op_code != codes::UnaryOp::cast_int
+      && op_code != codes::UnaryOp::cast_double
       && op_code != codes::UnaryOp::logical_not
       && op_code != codes::UnaryOp::unary_plus
       && op_code != codes::UnaryOp::unary_minus)
diff --git a/src/macro/Expressions.hh b/src/macro/Expressions.hh
index d75bf2cd..2e0b44d6 100644
--- a/src/macro/Expressions.hh
+++ b/src/macro/Expressions.hh
@@ -181,6 +181,7 @@ namespace macro
     virtual DoublePtr normcdf() const { throw StackTrace("Operator `normcdf` does not exist for this type"); }
     virtual DoublePtr normcdf(const BaseTypePtr &btp1, const BaseTypePtr &btp2) const { throw StackTrace("Operator `normcdf` does not exist for this type"); }
     virtual DoublePtr cast_int() const { throw StackTrace("This type cannot be cast to an integer"); }
+    virtual DoublePtr cast_double() const { throw StackTrace("This type cannot be cast to a double"); }
   };
 
 
@@ -204,6 +205,7 @@ namespace macro
     BoolPtr logical_or(const BaseTypePtr &btp) const override;
     BoolPtr logical_not() const override;
     inline DoublePtr cast_int() const override { return value ? make_shared<Double>(1, env) : make_shared<Double>(0, env); }
+    inline DoublePtr cast_double() const override { return cast_int(); }
   };
 
 
@@ -289,6 +291,7 @@ namespace macro
     }
     DoublePtr normcdf(const BaseTypePtr &btp1, const BaseTypePtr &btp2) const override;
     inline DoublePtr cast_int() const override { return make_shared<Double>(static_cast<int>(value), env); }
+    inline DoublePtr cast_double() const override { return make_shared<Double>(value, env); }
   };
 
   class String final : public BaseType
@@ -314,6 +317,7 @@ namespace macro
     BoolPtr is_equal(const BaseTypePtr &btp) const override;
     inline DoublePtr length() const override { return make_shared<Double>(value.size(), env); }
     inline DoublePtr cast_int() const override;
+    inline DoublePtr cast_double() const override;
   };
 
 
@@ -340,6 +344,7 @@ namespace macro
     BoolPtr contains(const BaseTypePtr &btp) const override;
     inline DoublePtr length() const override { return make_shared<Double>(tup.size(), env); }
     DoublePtr cast_int() const override;
+    DoublePtr cast_double() const override;
   };
 
 
@@ -382,6 +387,7 @@ namespace macro
     inline DoublePtr length() const override { return make_shared<Double>(arr.size(), env); }
     DoublePtr sum() const override;
     DoublePtr cast_int() const override;
+    DoublePtr cast_double() const override;
   };
 
 
diff --git a/src/macro/ForwardDeclarationsAndEnums.hh b/src/macro/ForwardDeclarationsAndEnums.hh
index 3da2c3c1..98f10f0c 100644
--- a/src/macro/ForwardDeclarationsAndEnums.hh
+++ b/src/macro/ForwardDeclarationsAndEnums.hh
@@ -68,6 +68,7 @@ namespace macro
     enum class UnaryOp
       {
        cast_int,
+       cast_double,
        logical_not,
        unary_minus,
        unary_plus,
diff --git a/src/macro/Parser.yy b/src/macro/Parser.yy
index d2764da2..1c850945 100644
--- a/src/macro/Parser.yy
+++ b/src/macro/Parser.yy
@@ -65,7 +65,7 @@ using namespace macro;
 %token SQRT CBRT SIGN MAX MIN FLOOR CEIL TRUNC SUM MOD
 %token ERF ERFC GAMMA LGAMMA ROUND NORMPDF NORMCDF LENGTH
 
-%token INT
+%token INT DOUBLE
 
 %left OR
 %left AND
@@ -78,7 +78,7 @@ using namespace macro;
 %left PLUS MINUS
 %left TIMES DIVIDE
 %precedence UMINUS UPLUS NOT
-%precedence CAST_INT
+%precedence CAST_INT CAST_DOUBLE
 %nonassoc POWER
 
 %token <string> NAME TEXT QUOTED_STRING NUMBER EOL
@@ -324,6 +324,8 @@ expr : LPAREN expr RPAREN
        { $$ = make_shared<Comprehension>($2, $4, $6, $8, driver.env, @$); }
      | LPAREN INT RPAREN expr %prec CAST_INT
        { $$ = make_shared<UnaryOp>(codes::UnaryOp::cast_int, $4, driver.env, @$); }
+     | LPAREN DOUBLE RPAREN expr %prec CAST_DOUBLE
+       { $$ = make_shared<UnaryOp>(codes::UnaryOp::cast_double, $4, driver.env, @$); }
      | NOT expr
        { $$ = make_shared<UnaryOp>(codes::UnaryOp::logical_not, $2, driver.env, @$); }
      | MINUS expr %prec UMINUS
diff --git a/src/macro/Tokenizer.ll b/src/macro/Tokenizer.ll
index d09a57b3..7f499531 100644
--- a/src/macro/Tokenizer.ll
+++ b/src/macro/Tokenizer.ll
@@ -143,6 +143,7 @@ CONT \\\\{SPC}*
 <expr,eval>normcdf         { return token::NORMCDF; }
 
 <expr,eval>int             { return token::INT; }
+<expr,eval>double          { return token::DOUBLE; }
 
 <expr,eval>((([0-9]*\.[0-9]+)|([0-9]+\.))([ed][-+]?[0-9]+)?)|([0-9]+([ed][-+]?[0-9]+)?)|nan|inf {
   yylval->build<string>(yytext);
-- 
GitLab