diff --git a/src/QrAlgo.jl b/src/QrAlgo.jl index 71129ba8fdaf130a7bb58a115e3e798df7c00bf1..0d37edaac0d75306aea176bde53a77446687de6f 100644 --- a/src/QrAlgo.jl +++ b/src/QrAlgo.jl @@ -5,9 +5,11 @@ import LinearAlgebra: BlasInt import LinearAlgebra.BLAS: @blasfunc import LinearAlgebra.LAPACK: liblapack, chklapackerror -export QrWs, QrpWs, geqrf_core!, geqp3_core!, ormqr_core! +export QrWs, QrpWs, geqrf_core!, geqp3!, ormqr_core! -struct QrWs{T <: Number} +abstract type QR end + +struct QrWs{T <: Number} <: QR tau::Vector{T} work::Vector{T} lwork::Ref{BlasInt} @@ -65,7 +67,7 @@ for (geqrf, ormqr, elty) in @eval begin function ormqr_core!(side::Char, A::StridedMatrix{$elty}, - C::StridedMatrix{$elty}, ws::QrWs) + C::StridedMatrix{$elty}, ws::QR) mm,nn = size(C) m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) @@ -89,7 +91,7 @@ for (geqrf, ormqr, elty) in @eval begin function ormqr_core!(side::Char, A::$elty2, - C::StridedMatrix{$elty}, ws::QrWs) + C::StridedMatrix{$elty}, ws::QR) mm,nn = size(C) m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) @@ -112,7 +114,7 @@ for (geqrf, ormqr, elty) in @eval begin function ormqr_core!(side::Char, A::StridedMatrix{$elty}, - C::StridedMatrix{$elty}, ws::QrWs) + C::StridedMatrix{$elty}, ws::QR) mm,nn = size(C) m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) @@ -138,7 +140,7 @@ for (geqrf, ormqr, elty) in @eval begin function ormqr_core!(side::Char, A::$elty2, - C::StridedMatrix{$elty}, ws::QrWs) + C::StridedMatrix{$elty}, ws::QR) mm,nn = size(C) m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) @@ -155,11 +157,11 @@ for (geqrf, ormqr, elty) in end end -struct QrpWs{T <: Number} +struct QrpWs{T <: Number} <: QR tau::Vector{T} jpvt::Vector{BlasInt} work::Vector{T} - lwork::Ref{BlasInt} + lwork::BlasInt info::Ref{BlasInt} end @@ -171,37 +173,37 @@ for (geqp3, elty) in @eval begin - function QrpWs(A::StridedMatrix{T}) where T <: $elty - nn, mm = size(A) - m = Ref{BlasInt}(mm) - n = Ref{BlasInt}(nn) - RldA = Ref{BlasInt}(max(1,stride(A,2))) - jpvt = Vector{BlasInt}(undef, mm) - tau = Vector{T}(undef, min(nn,mm)) - work = Vector{T}(undef, 1) - lwork = Ref{BlasInt}(-1) - info = Ref{BlasInt}(0) + function QrpWs(A::StridedMatrix{$elty}) + m, n = size(A) + RldA = BlasInt(max(1,stride(A,2))) + jpvt = zeros(BlasInt, n) + tau = Vector{$elty}(undef, min(m, n)) + work = Vector{$elty}(undef, 1) + lwork = BlasInt(-1) + info = Ref{BlasInt}() ccall((@blasfunc($geqp3), liblapack), Nothing, - (Ref{BlasInt}, Ref{BlasInt}, Ptr{T}, Ref{BlasInt}, - Ptr{BlasInt}, Ptr{T}, Ptr{T}, Ref{BlasInt}, Ref{BlasInt}), + (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, + Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ref{BlasInt}), m, n, A, RldA, jpvt, tau, work, lwork, info) chklapackerror(info[]) - lwork = Ref{BlasInt}(real(work[1])) - work = Array{T}(undef, lwork[]) + lwork = BlasInt(real(work[1])) + @show lwork + @show m + @show n + @show RldA + @show info[] + work = resize!(work, lwork) QrpWs(tau, jpvt, work, lwork, info) end - function geqp3_core!(A::StridedMatrix{$elty}, ws::QrpWs) - mm,nn = size(A) - m = Ref{BlasInt}(mm) - n = Ref{BlasInt}(nn) - RldA = Ref{BlasInt}(max(1,stride(A,2))) + function geqp3!(A::StridedMatrix{$elty}, ws::QrpWs) + m, n = size(A) + RldA = BlasInt(max(1,stride(A,2))) ccall((@blasfunc($geqp3), liblapack), Nothing, - (Ref{BlasInt},Ref{BlasInt},Ptr{$elty},Ref{BlasInt}, - Ptr{BlasInt}, Ptr{$elty},Ptr{$elty},Ref{BlasInt},Ref{BlasInt}), - m,n,A,RldA,ws.jpvt, ws.tau,ws.work,ws.lwork,ws.info) + (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, + Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ref{BlasInt}), + m, n, A, RldA, ws.jpvt, ws.tau, ws.work, ws.lwork, ws.info) chklapackerror(ws.info[]) - println(ws.jpvt) end t1 = StridedMatrix{$elty}