From 53907ffbd3b50ffc00491562341fd17bbf543ff3 Mon Sep 17 00:00:00 2001
From: Johannes Pfeifer <jpfeifer@gmx.de>
Date: Wed, 11 Sep 2024 19:08:51 +0200
Subject: [PATCH] SMC: Add support for generate_trace_plots.m

Closes #1932
---
 matlab/estimation/GetAllPosteriorDraws.m |  6 +++-
 matlab/estimation/generate_trace_plots.m | 24 +++++++++------
 matlab/estimation/trace_plot.m           | 39 ++++++++++++++++--------
 tests/estimation/hssmc/fs2000.mod        |  2 ++
 4 files changed, 48 insertions(+), 23 deletions(-)

diff --git a/matlab/estimation/GetAllPosteriorDraws.m b/matlab/estimation/GetAllPosteriorDraws.m
index fd6cdce1cd..9a4b67b3e4 100644
--- a/matlab/estimation/GetAllPosteriorDraws.m
+++ b/matlab/estimation/GetAllPosteriorDraws.m
@@ -41,7 +41,11 @@ if ishssmc(options_)
     % Load draws from the posterior distribution
     pfiles = dir(sprintf('%s/hssmc/particles-*.mat', dname));
     posterior = load(sprintf('%s/hssmc/particles-%u-%u.mat', dname, length(pfiles), length(pfiles)));
-    draws = transpose(posterior.particles(column,:));
+    if column==0
+        draws = posterior.tlogpostkernel;
+    else
+        draws = transpose(posterior.particles(column,:));
+    end
 elseif isdime(options_)
     posterior = load(sprintf('%s%s%s%schains.mat', dname, filesep(), 'dime', filesep()));
     tune = posterior.tune;
diff --git a/matlab/estimation/generate_trace_plots.m b/matlab/estimation/generate_trace_plots.m
index b6e6c8b62f..8476ed36fe 100644
--- a/matlab/estimation/generate_trace_plots.m
+++ b/matlab/estimation/generate_trace_plots.m
@@ -30,15 +30,21 @@ function generate_trace_plots(chain_number)
 
 global M_ options_ estim_params_
 
-if issmc(options_)
-    error('generate_trace_plots:: SMC methods do not support trace plots')
-end
-
-% Get informations about the posterior draws:
-MetropolisFolder = CheckPath('metropolis', M_.dname);
-record=load_last_mh_history_file(MetropolisFolder, M_.fname);
-if max(chain_number)>record.Nblck
-    error('generate_trace_plots:: chain number is bigger than existing number of chains')
+if ~issmc(options_)
+    % Get informations about the posterior draws:
+    MetropolisFolder = CheckPath('metropolis', M_.dname);
+    record=load_last_mh_history_file(MetropolisFolder, M_.fname);
+    if max(chain_number)>record.Nblck
+        error('generate_trace_plots:: chain number is bigger than existing number of chains')
+    end
+else
+    if ishssmc(options_)
+       if max(chain_number)>1
+           error('generate_trace_plots:: HSSMC only has one chain')           
+       end
+    elseif isdime(options_)
+        error('generate_trace_plots:: DIME does not support generate_trace_plots')
+    end
 end
 
 trace_plot(options_, M_, estim_params_, 'PosteriorDensity', chain_number)
diff --git a/matlab/estimation/trace_plot.m b/matlab/estimation/trace_plot.m
index 9de1860b90..a3bb401abb 100644
--- a/matlab/estimation/trace_plot.m
+++ b/matlab/estimation/trace_plot.m
@@ -52,22 +52,35 @@ if isempty(column)
     return
 end
 
-% Get informations about the posterior draws:
-MetropolisFolder = CheckPath('metropolis',M_.dname);
-record=load_last_mh_history_file(MetropolisFolder, M_.fname);
-
-FirstMhFile = 1;
-FirstLine = 1;
-TotalNumberOfMhFiles = sum(record.MhDraws(:,2));
-TotalNumberOfMhDraws = sum(record.MhDraws(:,1));
-[mh_nblck] = size(record.LastParameters,2);
-clear record;
-
-n_nblocks_to_plot=length(blck);
+if ~issmc(options_)
+    % Get informations about the posterior draws:
+    MetropolisFolder = CheckPath('metropolis',M_.dname);
+    record=load_last_mh_history_file(MetropolisFolder, M_.fname);
+
+    FirstMhFile = 1;
+    FirstLine = 1;
+    TotalNumberOfMhFiles = sum(record.MhDraws(:,2));
+    TotalNumberOfMhDraws = sum(record.MhDraws(:,1));
+    [mh_nblck] = size(record.LastParameters,2);
+    clear record;
+
+    n_nblocks_to_plot=length(blck);
+else
+    if ishssmc(options_)
+        n_nblocks_to_plot=1;
+    elseif isdime(options_)
+        error('trace_plot:: DIME does not support the trace_plot command')
+    end
+end
 
 if n_nblocks_to_plot==1
 % Get all the posterior draws:
-    PosteriorDraws = GetAllPosteriorDraws(options_, M_.dname,M_.fname,column, FirstMhFile, FirstLine, TotalNumberOfMhFiles, TotalNumberOfMhDraws, mh_nblck, blck);
+    if ishssmc(options_)
+        PosteriorDraws = GetAllPosteriorDraws(options_, M_.dname,[],column);
+        TotalNumberOfMhDraws=length(PosteriorDraws);
+    else
+        PosteriorDraws = GetAllPosteriorDraws(options_, M_.dname,M_.fname,column, FirstMhFile, FirstLine, TotalNumberOfMhFiles, TotalNumberOfMhDraws, mh_nblck, blck);
+    end
 else
     PosteriorDraws=NaN(TotalNumberOfMhDraws,n_nblocks_to_plot);
     save_string='';
diff --git a/tests/estimation/hssmc/fs2000.mod b/tests/estimation/hssmc/fs2000.mod
index 08a527b481..4302ce3682 100644
--- a/tests/estimation/hssmc/fs2000.mod
+++ b/tests/estimation/hssmc/fs2000.mod
@@ -91,3 +91,5 @@ estimation(order=1, datafile='../fsdat_simul.m', nobs=192, loglinear,
                                       'target', .25),
 bayesian_irf, smoother, moments_varendo,consider_all_endogenous
 );
+
+generate_trace_plots(1);
\ No newline at end of file
-- 
GitLab