From 53987fc039d2342be7a21c63d3cbd5166799e659 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Fri, 3 Nov 2023 17:56:06 +0100
Subject: [PATCH] Block: more symmetry between bytecode and non-bytecode write
 helpers

---
 src/ModelTree.hh | 52 +++++++++++++++++++-----------------------------
 1 file changed, 21 insertions(+), 31 deletions(-)

diff --git a/src/ModelTree.hh b/src/ModelTree.hh
index d7cdea22..6169ba74 100644
--- a/src/ModelTree.hh
+++ b/src/ModelTree.hh
@@ -1193,25 +1193,23 @@ ModelTree::writePerBlockHelper(int blk, ostream &output, temporary_terms_t &temp
               lhs = e->arg1;
               rhs = e->arg2;
             }
-          else if (equ_type != EquationType::evaluate)
-            {
-              cerr << "Type mismatch for equation " << getBlockEquationID(blk, eq)+1  << endl;
-              exit(EXIT_FAILURE);
-            }
+          else
+            assert(equ_type == EquationType::evaluate);
           output << "  ";
           lhs->writeOutput(output, output_type, temporary_terms, blocks_temporary_terms_idxs);
           output << '=';
           rhs->writeOutput(output, output_type, temporary_terms, blocks_temporary_terms_idxs);
           output << ';' << endl;
           break;
-        case BlockSimulationType::solveBackwardSimple:
-        case BlockSimulationType::solveForwardSimple:
         case BlockSimulationType::solveBackwardComplete:
         case BlockSimulationType::solveForwardComplete:
         case BlockSimulationType::solveTwoBoundariesComplete:
         case BlockSimulationType::solveTwoBoundariesSimple:
           if (eq < block_recursive_size)
             goto evaluation;
+          [[fallthrough]];
+        case BlockSimulationType::solveBackwardSimple:
+        case BlockSimulationType::solveForwardSimple:
           output << "  residual" << LEFT_ARRAY_SUBSCRIPT(output_type)
                  << eq-block_recursive_size+ARRAY_SUBSCRIPT_OFFSET(output_type)
                  << RIGHT_ARRAY_SUBSCRIPT(output_type) << "=(";
@@ -1220,9 +1218,6 @@ ModelTree::writePerBlockHelper(int blk, ostream &output, temporary_terms_t &temp
           rhs->writeOutput(output, output_type, temporary_terms, blocks_temporary_terms_idxs);
           output << ");" << endl;
           break;
-        default:
-          cerr << "Incorrect type for block " << blk+1 << endl;
-          exit(EXIT_FAILURE);
         }
     }
 
@@ -1681,29 +1676,25 @@ ModelTree::writeBlockBytecodeHelper(BytecodeWriter &code_file, int block, tempor
     {
       write_eq_tt(i);
 
+      EquationType equ_type { getBlockEquationType(block, i) };
+      BinaryOpNode *e { getBlockEquationExpr(block, i) };
+      expr_t lhs { e->arg1 }, rhs { e->arg2 };
       switch (simulation_type)
         {
-        evaluation:
         case BlockSimulationType::evaluateBackward:
         case BlockSimulationType::evaluateForward:
-          code_file << FNUMEXPR_{ExpressionType::ModelEquation, getBlockEquationID(block, i)};
-          if (EquationType equ_type {getBlockEquationType(block, i)};
-              equ_type == EquationType::evaluate)
-            {
-              BinaryOpNode *eq_node {getBlockEquationExpr(block, i)};
-              expr_t lhs {eq_node->arg1};
-              expr_t rhs {eq_node->arg2};
-              rhs->writeBytecodeOutput(code_file, output_type, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms);
-              lhs->writeBytecodeOutput(code_file, assignment_lhs_output_type, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms);
-            }
-          else if (equ_type == EquationType::evaluateRenormalized)
+          evaluation:
+          if (equ_type == EquationType::evaluateRenormalized)
             {
-              BinaryOpNode *eq_node {getBlockEquationRenormalizedExpr(block, i)};
-              expr_t lhs {eq_node->arg1};
-              expr_t rhs {eq_node->arg2};
-              rhs->writeBytecodeOutput(code_file, output_type, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms);
-              lhs->writeBytecodeOutput(code_file, assignment_lhs_output_type, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms);
+              e = getBlockEquationRenormalizedExpr(block, i);
+              lhs = e->arg1;
+              rhs = e->arg2;
             }
+          else
+            assert(equ_type == EquationType::evaluate);
+          code_file << FNUMEXPR_{ExpressionType::ModelEquation, getBlockEquationID(block, i)};
+          rhs->writeBytecodeOutput(code_file, output_type, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms);
+          lhs->writeBytecodeOutput(code_file, assignment_lhs_output_type, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms);
           break;
         case BlockSimulationType::solveBackwardComplete:
         case BlockSimulationType::solveForwardComplete:
@@ -1712,14 +1703,13 @@ ModelTree::writeBlockBytecodeHelper(BytecodeWriter &code_file, int block, tempor
           if (i < block_recursive)
             goto evaluation;
           [[fallthrough]];
-        default:
+        case BlockSimulationType::solveBackwardSimple:
+        case BlockSimulationType::solveForwardSimple:
           code_file << FNUMEXPR_{ExpressionType::ModelEquation, getBlockEquationID(block, i)};
-          BinaryOpNode *eq_node {getBlockEquationExpr(block, i)};
-          expr_t lhs {eq_node->arg1};
-          expr_t rhs {eq_node->arg2};
           lhs->writeBytecodeOutput(code_file, output_type, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms);
           rhs->writeBytecodeOutput(code_file, output_type, temporary_terms_union, blocks_temporary_terms_idxs, tef_terms);
           code_file << FBINARY_{BinaryOpcode::minus} << FSTPR_{i - block_recursive};
+          break;
         }
     }
 
-- 
GitLab