From adab6c7f93d5ec88ac72c187bb8dd991c442124f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien@dynare.org>
Date: Fri, 28 Jan 2022 17:24:48 +0100
Subject: [PATCH] Comment improvement + cosmetics

---
 src/ExprNode.cc | 45 ++++++++++++++++++++++++---------------------
 src/ExprNode.hh |  9 +++++++--
 2 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index 48654d7e..1c11d7ac 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -296,14 +296,15 @@ ExprNode::fillErrorCorrectionRow(int eqn,
   vector<pair<expr_t, int>> terms;
   decomposeAdditiveTerms(terms, 1);
 
-  for (const auto &it : terms)
+  for (const auto &[term, sign] : terms)
     {
-      pair<int, vector<tuple<int, int, int, double>>> m;
+      int speed_of_adjustment_param;
+      vector<tuple<int, int, int, double>> error_linear_combination;
       try
         {
-          m = it.first->matchParamTimesLinearCombinationOfVariables();
-          for (auto &t : m.second)
-            get<3>(t) *= it.second; // Update sign of constants
+          tie(speed_of_adjustment_param, error_linear_combination) = term->matchParamTimesLinearCombinationOfVariables();
+          for (auto &[var_id, lag, param_id, constant] : error_linear_combination)
+            constant *= sign; // Update sign of constants
         }
       catch (MatchFailureException &e)
         {
@@ -315,17 +316,17 @@ ExprNode::fillErrorCorrectionRow(int eqn,
       /* Verify that all variables belong to the error-correction term.
          FIXME: same remark as above about skipping terms. */
       bool not_ec = false;
-      for (const auto &t : m.second)
+      for (const auto &[var_id, lag, param_id, constant] : error_linear_combination)
         {
-          auto [vid, vlag] = datatree.symbol_table.unrollDiffLeadLagChain(get<0>(t), get<1>(t));
-          not_ec = not_ec || (find(target_lhs.begin(), target_lhs.end(), vid) == target_lhs.end()
-                              && find(nontarget_lhs.begin(), nontarget_lhs.end(), vid) == nontarget_lhs.end());
+          auto [orig_var_id, orig_lag] = datatree.symbol_table.unrollDiffLeadLagChain(var_id, lag);
+          not_ec = not_ec || (find(target_lhs.begin(), target_lhs.end(), orig_var_id) == target_lhs.end()
+                              && find(nontarget_lhs.begin(), nontarget_lhs.end(), orig_var_id) == nontarget_lhs.end());
         }
       if (not_ec)
         continue;
 
       // Now fill the matrices
-      for (auto [var_id, lag, param_id, constant] : m.second)
+      for (auto [var_id, lag, param_id, constant] : error_linear_combination)
         {
           auto [orig_vid, orig_lag] = datatree.symbol_table.unrollDiffLeadLagChain(var_id, lag);
           if (find(target_lhs.begin(), target_lhs.end(), orig_vid) == target_lhs.end())
@@ -353,13 +354,14 @@ ExprNode::fillErrorCorrectionRow(int eqn,
                        << "symb_id encountered more than once in equation" << endl;
                   exit(EXIT_FAILURE);
                 }
-              A0[{eqn, colidx}] = datatree.AddVariable(m.first);
+              A0[{eqn, colidx}] = datatree.AddVariable(speed_of_adjustment_param);
             }
           else
             {
               // This is a target, so fill A0star
               int colidx = static_cast<int>(distance(target_lhs.begin(), find(target_lhs.begin(), target_lhs.end(), orig_vid)));
-              expr_t e = datatree.AddTimes(datatree.AddVariable(m.first), datatree.AddPossiblyNegativeConstant(-constant));
+              expr_t e = datatree.AddTimes(datatree.AddVariable(speed_of_adjustment_param),
+                                           datatree.AddPossiblyNegativeConstant(-constant));
               if (param_id != -1)
                 e = datatree.AddTimes(e, datatree.AddVariable(param_id));
               if (auto coor = make_pair(eqn, colidx); A0star.find(coor) == A0star.end())
@@ -5293,21 +5295,23 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
       exit(EXIT_FAILURE);
     }
 
-  for (const auto &it : terms)
+  for (const auto &[term, sign] : terms)
     {
-      if (dynamic_cast<PacExpectationNode *>(it.first))
+      if (dynamic_cast<PacExpectationNode *>(term))
         continue;
 
-      pair<int, vector<tuple<int, int, int, double>>> m;
+      int pid;
+      vector<tuple<int, int, int, double>> linear_combination;
       try
         {
-          m = {-1, {it.first->matchVariableTimesConstantTimesParam()}};
+          pid = -1;
+          linear_combination = { term->matchVariableTimesConstantTimesParam() };
         }
       catch (MatchFailureException &e)
         {
           try
             {
-              m = it.first->matchParamTimesLinearCombinationOfVariables();
+              tie(pid, linear_combination) = term->matchParamTimesLinearCombinationOfVariables();
             }
           catch (MatchFailureException &e)
             {
@@ -5316,11 +5320,10 @@ BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
             }
         }
 
-      for (auto &t : m.second)
-        get<3>(t) *= it.second; // Update sign of constants
+      for (auto &[vid, vlag, pidtmp, constant] : linear_combination)
+        constant *= sign; // Update sign of constants
 
-      int pid = get<0>(m);
-      for (auto [vid, vlag, pidtmp, constant] : m.second)
+      for (auto [vid, vlag, pidtmp, constant] : linear_combination)
         {
           if (pid == -1)
             pid = pidtmp;
diff --git a/src/ExprNode.hh b/src/ExprNode.hh
index 6b9c0051..b484318b 100644
--- a/src/ExprNode.hh
+++ b/src/ExprNode.hh
@@ -644,6 +644,13 @@ public:
   */
   vector<tuple<int, int, int, double>> matchLinearCombinationOfVariables(bool variable_obligatory_in_each_term = true) const;
 
+  /* Matches a parameter, times a linear combination of variables (endo or
+     exo), where scalars can be constant*parameters.
+     The first output argument is the symbol ID of the parameter.
+     The second output argument is the linear combination, in the same format
+     as the output of matchLinearCombinationOfVariables(). */
+  pair<int, vector<tuple<int, int, int, double>>> matchParamTimesLinearCombinationOfVariables() const;
+
   /* Matches a linear combination of endogenous, where scalars can be any
      constant expression (i.e. containing no endogenous, no exogenous and no
      exogenous deterministic). The linear combination can contain constant
@@ -653,8 +660,6 @@ public:
      – the sum of all constant (intercept) terms */
   pair<vector<pair<int, expr_t>>, expr_t> matchLinearCombinationOfEndogenousWithConstant() const;
 
-  pair<int, vector<tuple<int, int, int, double>>> matchParamTimesLinearCombinationOfVariables() const;
-
   /* Matches an expression of the form parameter*(var1-endo2).
      endo2 must correspond to symb_id. var1 must be an endogenous or an
      exogenous; it must be of the form X(-1) or log(X(-1)) or log(X)(-1) (unary ops aux var),
-- 
GitLab