From e5d093c6ad8d57de55e2f7c74f3a8f1163f45b05 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien.villemot@ens.fr>
Date: Fri, 11 Jun 2010 19:11:27 +0200
Subject: [PATCH] Estimation DLL: refactor detrending stuff to avoid allocating
 the matrix for detrended data at every iteration

---
 mex/sources/estimation/DetrendData.cc           |  6 +++---
 mex/sources/estimation/DetrendData.hh           |  2 +-
 .../estimation/InitializeKalmanFilter.cc        | 10 ++++++----
 .../estimation/InitializeKalmanFilter.hh        |  4 ++--
 mex/sources/estimation/KalmanFilter.cc          | 17 ++++++++---------
 mex/sources/estimation/KalmanFilter.hh          |  5 +++--
 mex/sources/estimation/LogLikelihoodMain.cc     |  9 ++++++---
 mex/sources/estimation/LogLikelihoodMain.hh     |  1 +
 .../estimation/LogLikelihoodSubSample.cc        |  4 ++--
 .../estimation/LogLikelihoodSubSample.hh        |  2 +-
 10 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/mex/sources/estimation/DetrendData.cc b/mex/sources/estimation/DetrendData.cc
index 6e045f5f6..928e78959 100644
--- a/mex/sources/estimation/DetrendData.cc
+++ b/mex/sources/estimation/DetrendData.cc
@@ -31,9 +31,9 @@ DetrendData::DetrendData(const bool INlogLinear) //, Vector& INtrendCoeff)
 };
 
 void
-DetrendData::detrend(const VectorView &SteadyState, const MatrixConstView &dataView, Matrix &Y)
+DetrendData::detrend(const VectorView &SteadyState, const MatrixConstView &dataView,
+                     MatrixView &detrendedDataView)
 {
-// dummy
-Y=dataView;
+  detrendedDataView = dataView;
 };
 
diff --git a/mex/sources/estimation/DetrendData.hh b/mex/sources/estimation/DetrendData.hh
index 674e2ec81..5f26d9f69 100644
--- a/mex/sources/estimation/DetrendData.hh
+++ b/mex/sources/estimation/DetrendData.hh
@@ -33,7 +33,7 @@ class DetrendData
 public:
   virtual ~DetrendData(){};
   DetrendData(const bool logLinear); // add later Vector& trendCoeff);
-  void detrend(const VectorView &SteadyState, const MatrixConstView &dataView, Matrix &Y);
+  void detrend(const VectorView &SteadyState, const MatrixConstView &dataView, MatrixView &detrendedDataView);
 
 private:
   const bool logLinear;
diff --git a/mex/sources/estimation/InitializeKalmanFilter.cc b/mex/sources/estimation/InitializeKalmanFilter.cc
index 0df72265a..9c9176964 100644
--- a/mex/sources/estimation/InitializeKalmanFilter.cc
+++ b/mex/sources/estimation/InitializeKalmanFilter.cc
@@ -59,10 +59,11 @@ InitializeKalmanFilter::InitializeKalmanFilter(const std::string &dynamicDllFile
 void
 InitializeKalmanFilter::initialize(VectorView &steadyState, const Vector &deepParams, Matrix &R,
                                    const Matrix &Q, Matrix &RQRt, Matrix &T, 
-                                   double &penalty, const MatrixConstView &dataView, Matrix &Y, int &info)
+                                   double &penalty, const MatrixConstView &dataView,
+                                   MatrixView &detrendedDataView, int &info)
 {
   modelSolution.compute(steadyState, deepParams, g_x, g_u);
-  detrendData.detrend(steadyState, dataView, Y);
+  detrendData.detrend(steadyState, dataView, detrendedDataView);
 
   setT(T, info);
   setRQR(R, Q, RQRt, info);
@@ -72,9 +73,10 @@ InitializeKalmanFilter::initialize(VectorView &steadyState, const Vector &deepPa
 void
 InitializeKalmanFilter::initialize(VectorView &steadyState, const Vector &deepParams, Matrix &R,
                                    const Matrix &Q, Matrix &RQRt, Matrix &T, Matrix &Pstar, Matrix &Pinf,
-                                   double &penalty, const MatrixConstView &dataView, Matrix &Y, int &info)
+                                   double &penalty, const MatrixConstView &dataView,
+                                   MatrixView &detrendedDataView, int &info)
 {
-  initialize(steadyState, deepParams, R, Q, RQRt, T, penalty, dataView, Y, info);
+  initialize(steadyState, deepParams, R, Q, RQRt, T, penalty, dataView, detrendedDataView, info);
   setPstar(Pstar, Pinf, T, RQRt, info);
 }
 
diff --git a/mex/sources/estimation/InitializeKalmanFilter.hh b/mex/sources/estimation/InitializeKalmanFilter.hh
index f5fbb753a..196b3e478 100644
--- a/mex/sources/estimation/InitializeKalmanFilter.hh
+++ b/mex/sources/estimation/InitializeKalmanFilter.hh
@@ -50,10 +50,10 @@ public:
   virtual ~InitializeKalmanFilter();
   // initialise all KF matrices
   void initialize(VectorView &steadyState, const Vector &deepParams, Matrix &R, const Matrix &Q, Matrix &RQRt,
-                  Matrix &T, Matrix &Pstar, Matrix &Pinf, double &penalty, const MatrixConstView &dataView, Matrix &Y, int &info);
+                  Matrix &T, Matrix &Pstar, Matrix &Pinf, double &penalty, const MatrixConstView &dataView, MatrixView &detrendedDataView, int &info);
   // initialise parameter dependent KF matrices only but not Ps
   void initialize(VectorView &steadyState, const Vector &deepParams, Matrix &R, const Matrix &Q, Matrix &RQRt,
-                  Matrix &T, double &penalty, const MatrixConstView &dataView, Matrix &Y, int &info);
+                  Matrix &T, double &penalty, const MatrixConstView &dataView, MatrixView &detrendedDataView, int &info);
 
 private:
   const double lyapunov_tol;
diff --git a/mex/sources/estimation/KalmanFilter.cc b/mex/sources/estimation/KalmanFilter.cc
index 70b27303b..3a5d9216b 100644
--- a/mex/sources/estimation/KalmanFilter.cc
+++ b/mex/sources/estimation/KalmanFilter.cc
@@ -73,18 +73,17 @@ KalmanFilter::compute_zeta_varobs_back_mixed(const std::vector<size_t> &zeta_bac
 double
 KalmanFilter::compute(const MatrixConstView &dataView, VectorView &steadyState,
                       const Matrix &Q, const Matrix &H, const Vector &deepParams,
-                      VectorView &vll, size_t start, size_t period, double &penalty, int &info)
+                      VectorView &vll, MatrixView &detrendedDataView,
+                      size_t start, size_t period, double &penalty, int &info)
 {
-  Matrix Y(dataView.getRows(), dataView.getCols());    // data
-
   if(period==0) // initialise all KF matrices
     initKalmanFilter.initialize(steadyState, deepParams, R, Q, RQRt, T, Pstar, Pinf,
-                              penalty, dataView, Y, info);
+                              penalty, dataView, detrendedDataView, info);
   else  // initialise parameter dependent KF matrices only but not Ps
     initKalmanFilter.initialize(steadyState, deepParams, R, Q, RQRt, T, 
-                              penalty, dataView, Y, info);
+                              penalty, dataView, detrendedDataView, info);
 
-  return filter(Y, H, vll, start, info);
+  return filter(detrendedDataView, H, vll, start, info);
 
 };
 
@@ -92,13 +91,13 @@ KalmanFilter::compute(const MatrixConstView &dataView, VectorView &steadyState,
  * 30:*
  */
 double
-KalmanFilter::filter(const Matrix &dataView,  const Matrix &H, VectorView &vll, size_t start, int &info)
+KalmanFilter::filter(const MatrixView &detrendedDataView,  const Matrix &H, VectorView &vll, size_t start, int &info)
 {
   double loglik=0.0, ll, logFdet, Fdet;
   size_t p = Finv.getRows();
 
   bool nonstationary = true;
-  for (size_t t = 0; t < dataView.getCols(); ++t)
+  for (size_t t = 0; t < detrendedDataView.getCols(); ++t)
     {
       if (nonstationary)
         {
@@ -138,7 +137,7 @@ KalmanFilter::filter(const Matrix &dataView,  const Matrix &H, VectorView &vll,
         }
 
       // err= Yt - Za
-      MatrixConstView yt(dataView, 0, t, dataView.getRows(), 1); // current observation vector
+      MatrixConstView yt(detrendedDataView, 0, t, detrendedDataView.getRows(), 1); // current observation vector
       vt = yt;
       blas::gemm("N", "N", -1.0, Z, a_init, 1.0, vt);
       // at+1= T(at+ KFinv *err)
diff --git a/mex/sources/estimation/KalmanFilter.hh b/mex/sources/estimation/KalmanFilter.hh
index 8c3769ef7..c9b9c548e 100644
--- a/mex/sources/estimation/KalmanFilter.hh
+++ b/mex/sources/estimation/KalmanFilter.hh
@@ -55,7 +55,8 @@ public:
 
   double compute(const MatrixConstView &dataView, VectorView &steadyState,
                  const Matrix &Q, const Matrix &H, const Vector &deepParams,
-                 VectorView &vll, size_t start, size_t period, double &penalty, int &info);
+                 VectorView &vll, MatrixView &detrendedDataView, size_t start, size_t period,
+                 double &penalty, int &info);
 
 private:
   const std::vector<size_t> zeta_varobs_back_mixed;
@@ -77,7 +78,7 @@ private:
   InitializeKalmanFilter initKalmanFilter; //Initialise KF matrices
 
   // Method
-  double filter(const Matrix &data,  const Matrix &H, VectorView &vll, size_t start, int &info);
+  double filter(const MatrixView &detrendedDataView,  const Matrix &H, VectorView &vll, size_t start, int &info);
 
 };
 
diff --git a/mex/sources/estimation/LogLikelihoodMain.cc b/mex/sources/estimation/LogLikelihoodMain.cc
index aa31d63d7..cf480a761 100644
--- a/mex/sources/estimation/LogLikelihoodMain.cc
+++ b/mex/sources/estimation/LogLikelihoodMain.cc
@@ -34,8 +34,8 @@ LogLikelihoodMain::LogLikelihoodMain( //const Matrix &data_arg, Vector &deepPara
   : estSubsamples(estiParDesc.estSubsamples),
   logLikelihoodSubSample(dynamicDllFile, estiParDesc, n_endo, n_exo, zeta_fwrd_arg, zeta_back_arg, zeta_mixed_arg, zeta_static_arg, qz_criterium,
                          varobs, riccati_tol, lyapunov_tol, info_arg),
-    vll(estiParDesc.getNumberOfPeriods()) // time dimension size of data
-
+    vll(estiParDesc.getNumberOfPeriods()), // time dimension size of data
+    detrendedData(varobs.size(), estiParDesc.getNumberOfPeriods())
 {
 
 }
@@ -55,9 +55,12 @@ LogLikelihoodMain::compute(Matrix &steadyState, const Vector &estParams, Vector
 
       MatrixConstView dataView(data, 0, estSubsamples[i].startPeriod,
                                data.getRows(), estSubsamples[i].endPeriod-estSubsamples[i].startPeriod+1);
+      MatrixView detrendedDataView(detrendedData, 0, estSubsamples[i].startPeriod,
+                                   data.getRows(), estSubsamples[i].endPeriod-estSubsamples[i].startPeriod+1);
+
       VectorView vllView(vll, estSubsamples[i].startPeriod, estSubsamples[i].endPeriod-estSubsamples[i].startPeriod+1);
       logLikelihood += logLikelihoodSubSample.compute(vSteadyState, dataView, estParams, deepParams,
-                                                      Q, H, vllView, info, start, i);
+                                                      Q, H, vllView, detrendedDataView, info, start, i);
     }
   return logLikelihood;
 };
diff --git a/mex/sources/estimation/LogLikelihoodMain.hh b/mex/sources/estimation/LogLikelihoodMain.hh
index 70fea18bb..0b7af4556 100644
--- a/mex/sources/estimation/LogLikelihoodMain.hh
+++ b/mex/sources/estimation/LogLikelihoodMain.hh
@@ -32,6 +32,7 @@ private:
   std::vector<EstimationSubsample> &estSubsamples; // reference to member of EstimatedParametersDescription
   LogLikelihoodSubSample logLikelihoodSubSample;
   Vector vll;  // vector of all KF step likelihoods
+  Matrix detrendedData;
 
 public:
   virtual ~LogLikelihoodMain();
diff --git a/mex/sources/estimation/LogLikelihoodSubSample.cc b/mex/sources/estimation/LogLikelihoodSubSample.cc
index e1331a96f..125fb3e07 100644
--- a/mex/sources/estimation/LogLikelihoodSubSample.cc
+++ b/mex/sources/estimation/LogLikelihoodSubSample.cc
@@ -44,12 +44,12 @@ LogLikelihoodSubSample::LogLikelihoodSubSample(const std::string &dynamicDllFile
 
 double
 LogLikelihoodSubSample::compute(VectorView &steadyState, const MatrixConstView &dataView, const Vector &estParams, Vector &deepParams,
-                                Matrix &Q, Matrix &H, VectorView &vll, int &info, size_t start, size_t period)
+                                Matrix &Q, Matrix &H, VectorView &vll, MatrixView &detrendedDataView, int &info, size_t start, size_t period)
 {
 
   updateParams(estParams, deepParams, Q, H, period);
   if (info == 0)
-    logLikelihood = kalmanFilter.compute(dataView, steadyState,  Q, H, deepParams, vll, start, period, penalty,  info);
+    logLikelihood = kalmanFilter.compute(dataView, steadyState,  Q, H, deepParams, vll, detrendedDataView, start, period, penalty,  info);
   //  else
   //    logLikelihood+=penalty;
 
diff --git a/mex/sources/estimation/LogLikelihoodSubSample.hh b/mex/sources/estimation/LogLikelihoodSubSample.hh
index 5567f2ac2..ec339b2ca 100644
--- a/mex/sources/estimation/LogLikelihoodSubSample.hh
+++ b/mex/sources/estimation/LogLikelihoodSubSample.hh
@@ -39,7 +39,7 @@ public:
                          const std::vector<size_t> &varobs_arg, double riccati_tol_in, double lyapunov_tol, int &info);
 
   double compute(VectorView &steadyState, const MatrixConstView &dataView, const Vector &estParams, Vector &deepParams,
-                 Matrix &Q, Matrix &H, VectorView &vll, int &info,  size_t start, size_t period);
+                 Matrix &Q, Matrix &H, VectorView &vll, MatrixView &detrendedDataView, int &info,  size_t start, size_t period);
   virtual ~LogLikelihoodSubSample();
 
 private:
-- 
GitLab