Commit 056d5487 authored by george's avatar george
Browse files

- Extending basic C++ KF with performance boosting algorithm for checking if P...

- Extending basic C++ KF with performance boosting algorithm for checking if P is steady within tolerance and performing reduced operations if so;
- added new functions to GeneralMatrix::isDiff* to check if matrices are different within tolerance level
- adjusting test and dll drivers accordingly


git-svn-id: https://www.dynare.org/svn/dynare/trunk@2843 ac1d8469-bf42-47a9-8791-bf33cf982152
parent b7f96a47
......@@ -474,14 +474,15 @@ void SmootherResults::exportV(GeneralMatrix&out)const
BasicKalmanTask::BasicKalmanTask(const GeneralMatrix&d,const GeneralMatrix&ZZ,
const GeneralMatrix&HH,const GeneralMatrix&TT,
const GeneralMatrix&RR,const GeneralMatrix&QQ,
const StateInit&init_state)
const StateInit&init_state, const double rTol)
: // ssf(Z,H,T,R,Q),
data(d), Zt(*(new ConstGeneralMatrix(ZZ))),
Ht(*(new ConstGeneralMatrix(HH))),
Tt(*(new ConstGeneralMatrix(TT))),
Rt(*(new ConstGeneralMatrix(RR))),
Qt(*(new ConstGeneralMatrix(QQ))),
init(init_state)
init(init_state),
riccatiTol(rTol)
{
TS_RAISE_IF(d.numRows()!=Zt.numRows(),
"Data not compatible with BasicKalmanTask constructor");
......@@ -492,9 +493,9 @@ init(init_state)
BasicKalmanTask::BasicKalmanTask(const GeneralMatrix&d,const ConstGeneralMatrix&ZZ,
const ConstGeneralMatrix&HH,const ConstGeneralMatrix&TT,
const ConstGeneralMatrix&RR,const ConstGeneralMatrix&QQ,
const StateInit&init_state)
const StateInit&init_state, const double rTol)
: // ssf(Z,H,T,R,Q),
data(d), Zt(ZZ), Ht(HH), Tt(TT), Rt(RR), Qt(QQ),init(init_state)
data(d), Zt(ZZ), Ht(HH), Tt(TT), Rt(RR), Qt(QQ),init(init_state), riccatiTol(rTol)
{
TS_RAISE_IF(d.numRows()!=Zt.numRows(),
"Data not compatible with BasicKalmanTask constructor");
......@@ -686,6 +687,7 @@ BasicKalmanTask::filterNonDiffuse(const Vector&a,const GeneralMatrix&P,
PLUFact Ftinv(Ht.numRows(), Ht.numCols());
GeneralMatrix Lt(Tt);
GeneralMatrix PtLttrans(m,m);
GeneralMatrix PtOld(m,m);
GeneralMatrix Mt(m,n);
GeneralMatrix Kt(m,n);
GeneralMatrix KtTransTmp(n,m); // perm space for temp Kt trans
......@@ -703,8 +705,8 @@ BasicKalmanTask::filterNonDiffuse(const Vector&a,const GeneralMatrix&P,
int p;
int t= 1;
double vFinvv,ll;
int nonsteady=1;
for(;t<=data.numCols()&&nonsteady;++t)
bool nonSteady=true;
for(;t<=data.numCols()&&nonSteady;++t)
{
// ConstVector yt(data,t-1);
......@@ -787,6 +789,7 @@ BasicKalmanTask::filterNonDiffuse(const Vector&a,const GeneralMatrix&P,
/*****************
This calculates $$P_{t+1} = T_tP_tL_t^T + R_tQ_tR_t^T.$$
*****************/
PtOld=Pt;
// GeneralMatrix PtLttrans(Pt,Lt,"trans");
// DGEMM: C := alpha*op( A )*op( B ) + beta*C,
BLAS_dgemm("N", "T", &m, &m, &m, &alpha, Pt.base(), &m,
......@@ -815,13 +818,58 @@ BasicKalmanTask::filterNonDiffuse(const Vector&a,const GeneralMatrix&P,
BLAS_dgemm("N", "N", &m, &m, &rcols, &alpha, Rt.base(), &m,
QtRttrans.base(), &rcols, &alpha, Pt.base(), &m);
}
if (PtOld.isDiffSym(Pt, riccatiTol)==false)
nonSteady=false;
}
}
// for(;t<=data.numCols();t++)
// {
// ConstVector yt(data,t-1);
// }
// Steady
double detF=p*log(2*M_PI)+Ftinv.getLogDeterminant();
#ifdef DEBUG
if (nonSteady==false)
mexPrintf("Basickalman_filter Steady at t=%d / %d \n", t,data.numCols());
#endif
for(;t<=data.numCols();++t)
{
/*****************
This calculates $$v_t = y_t - Z_t*a_t.$$
*****************/
memcpy(vt.base(), &(data.get(0,t-1)), n*sizeof(double));
// Zt.multsVec(vt,at);
BLAS_dgemv("N", &n, &m, &neg_alpha, Zt.base(), &n, at.base(),
&inc, &alpha, vt.base(), &inc);
/*****************
Here we calc likelihood and store results.
*****************/
// double ll= calcStepLogLik(Ftinv,vt);
Finvv=vt;
Ftinv.multInvLeft(Finvv);
vFinvv= vt.dot(Finvv);
ll=-0.5*(detF+vFinvv);
// fres.set(t,Ftinv,vt,Lt,at,Pt,ll);
(*vll)[t-1]=ll;
if (t>start) loglik+=ll;
if(t<data.numCols())
{
/*****************
This calculates $$a_{t+1} = T_ta_t + K_tv_t.$$
*****************/
atsave=at;
Tt.multVec(0.0,at,1.0,atsave);
Kt.multVec(1.0,at,1.0,ConstVector(vt));
}
}
return loglik;
}
......
......@@ -164,11 +164,12 @@ class BasicKalmanTask{
const ConstGeneralMatrix &Rt;
const ConstGeneralMatrix &Qt;
const StateInit&init;
const double riccatiTol;
public:
BasicKalmanTask(const GeneralMatrix&d,const GeneralMatrix&ZZ,
const GeneralMatrix&HH,const GeneralMatrix&TT,
const GeneralMatrix&RR,const GeneralMatrix&QQ,
const StateInit&init_state);
const StateInit&init_state, const double riccatiTol);
// BasicKalmanTask(const GeneralMatrix&d,const TMatrix&Z,
// const TMatrix&H,const TMatrix&T,
// const TMatrix&R,const TMatrix&Q,
......@@ -176,7 +177,7 @@ class BasicKalmanTask{
BasicKalmanTask(const GeneralMatrix&d,const ConstGeneralMatrix&ZZ,
const ConstGeneralMatrix&HH,const ConstGeneralMatrix&TT,
const ConstGeneralMatrix&RR,const ConstGeneralMatrix&QQ,
const StateInit&init_state);
const StateInit&init_state, const double riccatiTol);
virtual ~BasicKalmanTask();
// double filter(int&per,int&d)const;
// double filter(int&per,int&d, int start, std::vector<double>* vll)const;
......
......@@ -79,7 +79,8 @@ extern "C" {
mexErrMsgTxt("Must have 1, 2, 3 or 4 output parameters.\n");
//int start = 1; // default start of likelihood calculation
// test for univariate case
bool uni = false;
bool uni = false;
double riccatiTol=0.000001;
const mxArray* const last = prhs[nrhs-1];
if (mxIsChar(last)
&& ((*mxGetChars(last)) == 'u' || (*mxGetChars(last)) == 'U'))
......@@ -151,14 +152,14 @@ extern "C" {
else // basic Kalman
{
init = new StateInit(P, a.getData());
BasicKalmanTask bkt(Y, Z, H, T, R, Q, *init);
BasicKalmanTask bkt(Y, Z, H, T, R, Q, *init, riccatiTol);
#ifdef TIMING_LOOP
for (int tt=0;tt<1000;++tt)
{
#endif
loglik = bkt.filter( per, d, (start-1), vll);
#ifdef DEBUG
mexPrintf("Basickalman_filter: loglik=%d \n", loglik);
mexPrintf("Basickalman_filter: loglik=%f \n", loglik);
#endif
#ifdef TIMING_LOOP
}
......
......@@ -201,7 +201,7 @@ int main(int argc, char* argv[])
}
}
***********/
double riccatiTol=0.000001;
int start = 1;
GeneralMatrix Z(Zmat, 4, 8);
GeneralMatrix a(8, 1);
......@@ -263,7 +263,7 @@ int main(int argc, char* argv[])
else // basic Kalman
{
init = new StateInit(P, a.getData());
BasicKalmanTask bkt(Y, Z, H, T, R, Q, *init);
BasicKalmanTask bkt(Y, Z, H, T, R, Q, *init, riccatiTol);
#ifdef TIMING_LOOP
for (int tt=0;tt<10000;++tt)
{
......
......@@ -294,6 +294,30 @@ void GeneralMatrix::add(double a, const ConstGeneralMatrix& m, const char* dum)
get(i,j) += a*m.get(j,i);
}
bool GeneralMatrix::isDiff(const GeneralMatrix& m, const double tol=0.0)const
{
if (m.numRows() != rows || m.numCols() != cols)
throw SYLV_MES_EXCEPTION("Matrix has different size in GeneralMatrix::isDiff.");
for (int i = 0; i < rows; i++)
for (int j = 0; j < cols; j++)
if (fabs(get(i,j) - m.get(i,j))>tol)
return true;
return false;
}
bool GeneralMatrix::isDiffSym(const GeneralMatrix& m, const double tol=0.0)const
{
if (m.numRows() != rows || m.numCols() != cols || m.numRows() != cols || m.numCols() != rows)
throw SYLV_MES_EXCEPTION("Matrix has different size or not square in GeneralMatrix::isDiffSym.");
for (int i = 0; i < cols; i++)
for (int j = 0; i+j < cols ; j++) // traverse the upper triangle only
if (fabs(get(j,j+i) - m.get(j,j+i))>tol) // along diagonals where higher changes occur
return true;
return false;
}
/* x = scalar(a)*x + scalar(b)*this*d */
void GeneralMatrix::multVec(double a, Vector& x, double b, const ConstVector& d) const
{
......
......@@ -62,7 +62,6 @@ class ConstGeneralMatrix {
bool isFinite() const;
/** Returns true of the matrix is exactly zero. */
bool isZero() const;
virtual void print() const;
protected:
void multInvLeft(const char* trans, int mrows, int mcols, int mld, double* d) const;
......@@ -247,6 +246,12 @@ public:
void add(double a, const GeneralMatrix& m, const char* dum)
{add(a, ConstGeneralMatrix(m), dum);}
/* Returns true if this and m matrices are different for more than tolerance tol */
bool isDiff(const GeneralMatrix& m, const double tol)const;
bool isDiffSym(const GeneralMatrix& m, const double tol)const;
bool isDiffUpprTriang(const GeneralMatrix& m, const double tol=0.0)const
{return isDiffSym(m, tol);}
bool isFinite() const
{return (ConstGeneralMatrix(*this)).isFinite();}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment