From 9b98da424a6d6e35b667f3aa99637b6ae232eee4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Tue, 19 Mar 2019 18:31:27 +0100
Subject: [PATCH] Simplification in BinaryOpNode::getPacAREC()

---
 src/ExprNode.cc | 84 +++++++++++++++++++------------------------------
 1 file changed, 33 insertions(+), 51 deletions(-)

diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index f7e146ca..c4ff0ce8 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -5617,63 +5617,45 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
 
   for (const auto & it : terms)
     {
-      auto bopn = dynamic_cast<BinaryOpNode *>(it.first);
-      auto pen = dynamic_cast<PacExpectationNode *>(it.first);
-      if (pen)
+      if (dynamic_cast<PacExpectationNode *>(it.first))
         continue;
-      if (bopn != nullptr)
+
+      int pid, vid, lag;
+      double constant;
+      try
+        {
+          tie(vid, lag, pid, constant) = it.first->matchVariableTimesConstantTimesParam();
+          constant *= it.second;
+        }
+      catch (MatchFailureException &e)
         {
-          auto vn1 = dynamic_cast<VariableNode *>(bopn->arg1);
-          auto vn2 = dynamic_cast<VariableNode *>(bopn->arg2);
-          if (vn1 && vn2)
+          cerr << "Unsupported expression in PAC equation" << endl;
+          exit(EXIT_FAILURE);
+        }
+
+      int vidorig = vid;
+      while (datatree.symbol_table.isAuxiliaryVariable(vid))
+        try
+          {
+            vid = datatree.symbol_table.getOrigSymbIdForAuxVar(vid);
+          }
+        catch (...)
+          {
+            break;
+          }
+      if (vid == lhs_symb_id || vid == lhs_orig_symb_id)
+        {
+          // This is an autoregressive term
+          if (constant != 1 || pid == -1)
             {
-              int pid, vid, lag;
-              pid = vid = lag = -1;
-              if (datatree.symbol_table.getType(vn1->symb_id) == SymbolType::parameter
-                  && (datatree.symbol_table.getType(vn2->symb_id) == SymbolType::endogenous
-                      || datatree.symbol_table.getType(vn2->symb_id) == SymbolType::exogenous))
-                {
-                  pid = vn1->symb_id;
-                  vid = vn2->symb_id;
-                  lag = vn2->lag;
-                }
-              else if (datatree.symbol_table.getType(vn2->symb_id) == SymbolType::parameter
-                       && (datatree.symbol_table.getType(vn1->symb_id) == SymbolType::endogenous
-                           || datatree.symbol_table.getType(vn1->symb_id) == SymbolType::exogenous))
-                {
-                  pid = vn2->symb_id;
-                  vid = vn1->symb_id;
-                  lag = vn1->lag;
-                }
-              if (pid >= 0 && vid >= 0)
-                {
-                  int vidorig = vid;
-                  while (datatree.symbol_table.isAuxiliaryVariable(vid))
-                    try
-                      {
-                        vid = datatree.symbol_table.getOrigSymbIdForAuxVar(vid);
-                      }
-                    catch (...)
-                      {
-                        break;
-                      }
-                  if (vid == lhs_symb_id || vid == lhs_orig_symb_id)
-                    ar_params_and_vars.insert({pid, {vidorig, lag}});
-                  else
-                    {
-                      auto m = it.first->matchVariableTimesConstantTimesParam();
-                      get<3>(m) *= it.second;
-                      additive_vars_params_and_constants.push_back(m);
-                    }
-                }
+              cerr << "BinaryOpNode::getPacAREC: autoregressive terms must be of the form 'parameter*lagged_variable" << endl;
+              exit(EXIT_FAILURE);
             }
+          ar_params_and_vars.insert({pid, { vidorig, lag }});
         }
       else
-        {
-          auto m = it.first->matchVariableTimesConstantTimesParam();
-          get<3>(m) *= it.second;
-          additive_vars_params_and_constants.push_back(m);
-        }
+        // This is a residual additive term
+        additive_vars_params_and_constants.push_back({ vidorig, lag, pid, constant});
     }
 }
 
-- 
GitLab