From 952e899f3acea73468e4ceab3fcad86f68718be8 Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Tue, 10 Dec 2019 16:25:28 +0100
Subject: [PATCH] fix bug in macro processor ensuring short-circuit
 functionality of `||` and `&&` statements

closes dynare#1676
---
 src/macro/Expressions.cc | 79 ++++++++++++++++++++++++----------------
 src/macro/Expressions.hh | 12 +++---
 2 files changed, 53 insertions(+), 38 deletions(-)

diff --git a/src/macro/Expressions.cc b/src/macro/Expressions.cc
index 5fd216e4..16a4553a 100644
--- a/src/macro/Expressions.cc
+++ b/src/macro/Expressions.cc
@@ -39,25 +39,33 @@ Bool::is_equal(const BaseTypePtr &btp) const
 }
 
 BoolPtr
-Bool::logical_and(const BaseTypePtr &btp) const
+Bool::logical_and(const ExpressionPtr &ep) const
 {
+  if (!value)
+    return make_shared<Bool>(false, env);
+
+  auto btp = ep->eval();
   if (auto btp2 = dynamic_pointer_cast<Bool>(btp); btp2)
-    return make_shared<Bool>(value && *btp2, env);
+    return make_shared<Bool>(*btp2, env);
 
   if (auto btp2 = dynamic_pointer_cast<Real>(btp); btp2)
-    return make_shared<Bool>(value && *btp2, env);
+    return make_shared<Bool>(*btp2, env);
 
   throw StackTrace("Type mismatch for operands of && operator");
 }
 
 BoolPtr
-Bool::logical_or(const BaseTypePtr &btp) const
+Bool::logical_or(const ExpressionPtr &ep) const
 {
+  if (value)
+    return make_shared<Bool>(true, env);
+
+  auto btp = ep->eval();
   if (auto btp2 = dynamic_pointer_cast<Bool>(btp); btp2)
-    return make_shared<Bool>(value || *btp2, env);
+    return make_shared<Bool>(*btp2, env);
 
   if (auto btp2 = dynamic_pointer_cast<Real>(btp); btp2)
-    return make_shared<Bool>(value || *btp2, env);
+    return make_shared<Bool>(*btp2, env);
 
   throw StackTrace("Type mismatch for operands of || operator");
 }
@@ -159,25 +167,33 @@ Real::is_equal(const BaseTypePtr &btp) const
 }
 
 BoolPtr
-Real::logical_and(const BaseTypePtr &btp) const
+Real::logical_and(const ExpressionPtr &ep) const
 {
+  if (!value)
+    return make_shared<Bool>(false, env);
+
+  auto btp = ep->eval();
   if (auto btp2 = dynamic_pointer_cast<Real>(btp); btp2)
-    return make_shared<Bool>(value && *btp2, env);
+    return make_shared<Bool>(*btp2, env);
 
   if (auto btp2 = dynamic_pointer_cast<Bool>(btp); btp2)
-    return make_shared<Bool>(value && *btp2, env);
+    return make_shared<Bool>(*btp2, env);
 
   throw StackTrace("Type mismatch for operands of && operator");
 }
 
 BoolPtr
-Real::logical_or(const BaseTypePtr &btp) const
+Real::logical_or(const ExpressionPtr &ep) const
 {
+  if (value)
+    return make_shared<Bool>(true, env);
+
+  auto btp = ep->eval();
   if (auto btp2 = dynamic_pointer_cast<Real>(btp); btp2)
-    return make_shared<Bool>(value || *btp2, env);
+    return make_shared<Bool>(*btp2, env);
 
   if (auto btp2 = dynamic_pointer_cast<Bool>(btp); btp2)
-    return make_shared<Bool>(value || *btp2, env);
+    return make_shared<Bool>(*btp2, env);
 
   throw StackTrace("Type mismatch for operands of || operator");
 }
@@ -859,47 +875,46 @@ BinaryOp::eval()
   try
     {
       auto arg1bt = arg1->eval();
-      auto arg2bt = arg2->eval();
       switch(op_code)
         {
         case codes::BinaryOp::plus:
-          return arg1bt->plus(arg2bt);
+          return arg1bt->plus(arg2->eval());
         case codes::BinaryOp::minus:
-          return arg1bt->minus(arg2bt);
+          return arg1bt->minus(arg2->eval());
         case codes::BinaryOp::times:
-          return arg1bt->times(arg2bt);
+          return arg1bt->times(arg2->eval());
         case codes::BinaryOp::divide:
-          return arg1bt->divide(arg2bt);
+          return arg1bt->divide(arg2->eval());
         case codes::BinaryOp::power:
-          return arg1bt->power( arg2bt);
+          return arg1bt->power( arg2->eval());
         case codes::BinaryOp::equal_equal:
-          return arg1bt->is_equal(arg2bt);
+          return arg1bt->is_equal(arg2->eval());
         case codes::BinaryOp::not_equal:
-          return arg1bt->is_different(arg2bt);
+          return arg1bt->is_different(arg2->eval());
         case codes::BinaryOp::less:
-          return arg1bt->is_less(arg2bt);
+          return arg1bt->is_less(arg2->eval());
         case codes::BinaryOp::greater:
-          return arg1bt->is_greater(arg2bt);
+          return arg1bt->is_greater(arg2->eval());
         case codes::BinaryOp::less_equal:
-          return arg1bt->is_less_equal(arg2bt);
+          return arg1bt->is_less_equal(arg2->eval());
         case codes::BinaryOp::greater_equal:
-          return arg1bt->is_greater_equal(arg2bt);
+          return arg1bt->is_greater_equal(arg2->eval());
         case codes::BinaryOp::logical_and:
-          return arg1bt->logical_and(arg2bt);
+          return arg1bt->logical_and(arg2);
         case codes::BinaryOp::logical_or:
-          return arg1bt->logical_or(arg2bt);
+          return arg1bt->logical_or(arg2);
         case codes::BinaryOp::in:
-          return arg2bt->contains(arg1bt);
+          return arg2->eval()->contains(arg1bt);
         case codes::BinaryOp::set_union:
-          return arg1bt->set_union(arg2bt);
+          return arg1bt->set_union(arg2->eval());
         case codes::BinaryOp::set_intersection:
-          return arg1bt->set_intersection(arg2bt);
+          return arg1bt->set_intersection(arg2->eval());
         case codes::BinaryOp::max:
-          return arg1bt->max(arg2bt);
+          return arg1bt->max(arg2->eval());
         case codes::BinaryOp::min:
-          return arg1bt->min(arg2bt);
+          return arg1bt->min(arg2->eval());
         case codes::BinaryOp::mod:
-          return arg1bt->mod(arg2bt);
+          return arg1bt->mod(arg2->eval());
         }
     }
   catch (StackTrace &ex)
diff --git a/src/macro/Expressions.hh b/src/macro/Expressions.hh
index fdc3843a..668051e4 100644
--- a/src/macro/Expressions.hh
+++ b/src/macro/Expressions.hh
@@ -145,8 +145,8 @@ namespace macro
     virtual BoolPtr is_greater_equal(const BaseTypePtr &btp) const { throw StackTrace("Operator >= does not exist for this type"); }
     virtual BoolPtr is_equal(const BaseTypePtr &btp) const = 0;
     virtual BoolPtr is_different(const BaseTypePtr &btp) const final;
-    virtual BoolPtr logical_and(const BaseTypePtr &btp) const { throw StackTrace("Operator && does not exist for this type"); }
-    virtual BoolPtr logical_or(const BaseTypePtr &btp) const { throw StackTrace("Operator || does not exist for this type"); }
+    virtual BoolPtr logical_and(const ExpressionPtr &ep) const { throw StackTrace("Operator && does not exist for this type"); }
+    virtual BoolPtr logical_or(const ExpressionPtr &ep) const { throw StackTrace("Operator || does not exist for this type"); }
     virtual BoolPtr logical_not() const { throw StackTrace("Operator ! does not exist for this type"); }
     virtual ArrayPtr set_union(const BaseTypePtr &btp) const { throw StackTrace("Operator | does not exist for this type"); }
     virtual ArrayPtr set_intersection(const BaseTypePtr &btp) const { throw StackTrace("Operator & does not exist for this type"); }
@@ -216,8 +216,8 @@ namespace macro
   public:
     operator bool() const { return value; }
     BoolPtr is_equal(const BaseTypePtr &btp) const override;
-    BoolPtr logical_and(const BaseTypePtr &btp) const override;
-    BoolPtr logical_or(const BaseTypePtr &btp) const override;
+    BoolPtr logical_and(const ExpressionPtr &ep) const override;
+    BoolPtr logical_or(const ExpressionPtr &ep) const override;
     BoolPtr logical_not() const override;
     inline BoolPtr isboolean() const noexcept override { return make_shared<Bool>(true, env, location); }
     inline BoolPtr cast_bool() const override { return make_shared<Bool>(value, env); }
@@ -278,8 +278,8 @@ namespace macro
       double intpart;
       return make_shared<Bool>(modf(value, &intpart) == 0.0, env, location);
     }
-    BoolPtr logical_and(const BaseTypePtr &btp) const override;
-    BoolPtr logical_or(const BaseTypePtr &btp) const override;
+    BoolPtr logical_and(const ExpressionPtr &ep) const override;
+    BoolPtr logical_or(const ExpressionPtr &ep) const override;
     BoolPtr logical_not() const override;
     RealPtr max(const BaseTypePtr &btp) const override;
     RealPtr min(const BaseTypePtr &btp) const override;
-- 
GitLab