Skip to content
Snippets Groups Projects
Commit cffaeffc authored by MichelJuillard's avatar MichelJuillard
Browse files

improve QR algotithm

parent 30e19594
No related branches found
No related tags found
No related merge requests found
...@@ -5,9 +5,11 @@ import LinearAlgebra: BlasInt ...@@ -5,9 +5,11 @@ import LinearAlgebra: BlasInt
import LinearAlgebra.BLAS: @blasfunc import LinearAlgebra.BLAS: @blasfunc
import LinearAlgebra.LAPACK: liblapack, chklapackerror import LinearAlgebra.LAPACK: liblapack, chklapackerror
export QrWs, QrpWs, geqrf_core!, geqp3_core!, ormqr_core! export QrWs, QrpWs, geqrf_core!, geqp3!, ormqr_core!
struct QrWs{T <: Number} abstract type QR end
struct QrWs{T <: Number} <: QR
tau::Vector{T} tau::Vector{T}
work::Vector{T} work::Vector{T}
lwork::Ref{BlasInt} lwork::Ref{BlasInt}
...@@ -65,7 +67,7 @@ for (geqrf, ormqr, elty) in ...@@ -65,7 +67,7 @@ for (geqrf, ormqr, elty) in
@eval begin @eval begin
function ormqr_core!(side::Char, A::StridedMatrix{$elty}, function ormqr_core!(side::Char, A::StridedMatrix{$elty},
C::StridedMatrix{$elty}, ws::QrWs) C::StridedMatrix{$elty}, ws::QR)
mm,nn = size(C) mm,nn = size(C)
m = Ref{BlasInt}(mm) m = Ref{BlasInt}(mm)
n = Ref{BlasInt}(nn) n = Ref{BlasInt}(nn)
...@@ -89,7 +91,7 @@ for (geqrf, ormqr, elty) in ...@@ -89,7 +91,7 @@ for (geqrf, ormqr, elty) in
@eval begin @eval begin
function ormqr_core!(side::Char, A::$elty2, function ormqr_core!(side::Char, A::$elty2,
C::StridedMatrix{$elty}, ws::QrWs) C::StridedMatrix{$elty}, ws::QR)
mm,nn = size(C) mm,nn = size(C)
m = Ref{BlasInt}(mm) m = Ref{BlasInt}(mm)
n = Ref{BlasInt}(nn) n = Ref{BlasInt}(nn)
...@@ -112,7 +114,7 @@ for (geqrf, ormqr, elty) in ...@@ -112,7 +114,7 @@ for (geqrf, ormqr, elty) in
@eval begin @eval begin
function ormqr_core!(side::Char, A::StridedMatrix{$elty}, function ormqr_core!(side::Char, A::StridedMatrix{$elty},
C::StridedMatrix{$elty}, ws::QrWs) C::StridedMatrix{$elty}, ws::QR)
mm,nn = size(C) mm,nn = size(C)
m = Ref{BlasInt}(mm) m = Ref{BlasInt}(mm)
n = Ref{BlasInt}(nn) n = Ref{BlasInt}(nn)
...@@ -138,7 +140,7 @@ for (geqrf, ormqr, elty) in ...@@ -138,7 +140,7 @@ for (geqrf, ormqr, elty) in
@eval begin @eval begin
function ormqr_core!(side::Char, A::$elty2, function ormqr_core!(side::Char, A::$elty2,
C::StridedMatrix{$elty}, ws::QrWs) C::StridedMatrix{$elty}, ws::QR)
mm,nn = size(C) mm,nn = size(C)
m = Ref{BlasInt}(mm) m = Ref{BlasInt}(mm)
n = Ref{BlasInt}(nn) n = Ref{BlasInt}(nn)
...@@ -155,11 +157,11 @@ for (geqrf, ormqr, elty) in ...@@ -155,11 +157,11 @@ for (geqrf, ormqr, elty) in
end end
end end
struct QrpWs{T <: Number} struct QrpWs{T <: Number} <: QR
tau::Vector{T} tau::Vector{T}
jpvt::Vector{BlasInt} jpvt::Vector{BlasInt}
work::Vector{T} work::Vector{T}
lwork::Ref{BlasInt} lwork::BlasInt
info::Ref{BlasInt} info::Ref{BlasInt}
end end
...@@ -171,37 +173,37 @@ for (geqp3, elty) in ...@@ -171,37 +173,37 @@ for (geqp3, elty) in
@eval begin @eval begin
function QrpWs(A::StridedMatrix{T}) where T <: $elty function QrpWs(A::StridedMatrix{$elty})
nn, mm = size(A) m, n = size(A)
m = Ref{BlasInt}(mm) RldA = BlasInt(max(1,stride(A,2)))
n = Ref{BlasInt}(nn) jpvt = zeros(BlasInt, n)
RldA = Ref{BlasInt}(max(1,stride(A,2))) tau = Vector{$elty}(undef, min(m, n))
jpvt = Vector{BlasInt}(undef, mm) work = Vector{$elty}(undef, 1)
tau = Vector{T}(undef, min(nn,mm)) lwork = BlasInt(-1)
work = Vector{T}(undef, 1) info = Ref{BlasInt}()
lwork = Ref{BlasInt}(-1)
info = Ref{BlasInt}(0)
ccall((@blasfunc($geqp3), liblapack), Nothing, ccall((@blasfunc($geqp3), liblapack), Nothing,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{T}, Ref{BlasInt}, (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{T}, Ptr{T}, Ref{BlasInt}, Ref{BlasInt}), Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ref{BlasInt}),
m, n, A, RldA, jpvt, tau, work, lwork, info) m, n, A, RldA, jpvt, tau, work, lwork, info)
chklapackerror(info[]) chklapackerror(info[])
lwork = Ref{BlasInt}(real(work[1])) lwork = BlasInt(real(work[1]))
work = Array{T}(undef, lwork[]) @show lwork
@show m
@show n
@show RldA
@show info[]
work = resize!(work, lwork)
QrpWs(tau, jpvt, work, lwork, info) QrpWs(tau, jpvt, work, lwork, info)
end end
function geqp3_core!(A::StridedMatrix{$elty}, ws::QrpWs) function geqp3!(A::StridedMatrix{$elty}, ws::QrpWs)
mm,nn = size(A) m, n = size(A)
m = Ref{BlasInt}(mm) RldA = BlasInt(max(1,stride(A,2)))
n = Ref{BlasInt}(nn)
RldA = Ref{BlasInt}(max(1,stride(A,2)))
ccall((@blasfunc($geqp3), liblapack), Nothing, ccall((@blasfunc($geqp3), liblapack), Nothing,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ref{BlasInt}), Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ref{BlasInt}),
m, n, A, RldA, ws.jpvt, ws.tau, ws.work, ws.lwork, ws.info) m, n, A, RldA, ws.jpvt, ws.tau, ws.work, ws.lwork, ws.info)
chklapackerror(ws.info[]) chklapackerror(ws.info[])
println(ws.jpvt)
end end
t1 = StridedMatrix{$elty} t1 = StridedMatrix{$elty}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment