From f72b919ba9a88c154a07fc57c002bce2f502918d Mon Sep 17 00:00:00 2001 From: Michel Juillard <michel@debian.home> Date: Thu, 19 Mar 2020 10:01:19 +0100 Subject: [PATCH] change function names fix handling of transpose operator --- src/QrAlgo.jl | 81 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 77 insertions(+), 4 deletions(-) diff --git a/src/QrAlgo.jl b/src/QrAlgo.jl index d24d7f8..84a7b1e 100644 --- a/src/QrAlgo.jl +++ b/src/QrAlgo.jl @@ -1,9 +1,11 @@ +module QrAlgo + using LinearAlgebra import LinearAlgebra: BlasInt import LinearAlgebra.BLAS: @blasfunc import LinearAlgebra.LAPACK: liblapack, chklapackerror -export QrWs, dgeqrf_core!, dormrqf_core! +export QrWs, geqrf_core!, ormrqf_core! struct QrWs{T <: Number} tau::Vector{T} @@ -39,7 +41,7 @@ for (geqrf, ormqr, elty) in QrWs(tau, work, lwork, info) end - function dgeqrf_core!(A::StridedMatrix{$elty}, ws::QrWs) + function geqrf_core!(A::StridedMatrix{$elty}, ws::QrWs) mm,nn = size(A) m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) @@ -55,14 +57,29 @@ for (geqrf, ormqr, elty) in t2 = Transpose{$elty, <: StridedMatrix} t3 = Adjoint{$elty, <: StridedMatrix} end +end +for (geqrf, ormqr, elty) in + ((:dgeqrf_, :dormqr_, :Float64), + (:sgeqrf_, :sormqr_, :Float32)) + + @eval begin + t1 = StridedMatrix{$elty} + t2 = Transpose{$elty, <: StridedMatrix} + t3 = Adjoint{$elty, <: StridedMatrix} + end + for (elty2, transchar) in ((t1, 'N'), (t2, 'T'), - (t3, 'C')) + (t3, 'T')) + + if (transchar == 'C' && ((elty == Float32) || (elty == Float64))) + transchar = 'T' + end @eval begin - function ormrqf_core!(side::Ref{UInt8}, A::$elty2, + function ormrqf_core!(side::Ref{UInt8}, A::StridedMatrix{$elty}, C::StridedMatrix{$elty}, ws::QrWs) mm,nn = size(C) m = Ref{BlasInt}(mm) @@ -78,6 +95,62 @@ for (geqrf, ormqr, elty) in end end end + + @eval begin + ormrqf_core!(side::Ref{UInt8}, A::Transpose{$elty}, + C::StridedMatrix{$elty}, ws::QrWs) = + ormrqf_core!(side, A.parent, C, ws) + ormrqf_core!(side::Ref{UInt8}, A::Adjoint{$elty}, + C::StridedMatrix{$elty}, ws::QrWs) = + ormrqf_core!(side, A.parent, C, ws) + end end +for (geqrf, ormqr, elty) in + ((:zgeqrf_, :zormqr_, :ComplexF64), + (:cgeqrf_, :cormqr_, :ComplexF32)) + + @eval begin + t1 = StridedMatrix{$elty} + t2 = Transpose{$elty, <: StridedMatrix} + t3 = Adjoint{$elty, <: StridedMatrix} + end + + for (elty2, transchar) in + ((t1, 'N'), + (t2, 'T'), + (t3, 'C')) + + if (transchar == 'C' && ((elty == Float32) || (elty == Float64))) + transchar = 'T' + end + + @eval begin + function ormrqf_core!(side::Ref{UInt8}, A::StridedMatrix{$elty}, + C::StridedMatrix{$elty}, ws::QrWs) + mm,nn = size(C) + m = Ref{BlasInt}(mm) + n = Ref{BlasInt}(nn) + k = Ref{BlasInt}(length(ws.tau)) + RldA = Ref{BlasInt}(max(1,stride(A,2))) + RldC = Ref{BlasInt}(max(1,stride(C,2))) + ccall((@blasfunc($ormqr), liblapack), Nothing, + (Ref{UInt8},Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ref{BlasInt},Ptr{$elty},Ref{BlasInt}, + Ptr{$elty},Ptr{$elty},Ref{BlasInt},Ptr{$elty},Ref{BlasInt},Ref{BlasInt}), + side,$transchar,m,n,k,A,RldA,ws.tau,C,RldC,ws.work,ws.lwork,ws.info) + chklapackerror(ws.info[]) + end + end + end + @eval begin + ormrqf_core!(side::Ref{UInt8}, A::Transpose{$elty}, + C::StridedMatrix{$elty}, ws::QrWs) = + ormrqf_core!(side, A.parent, C, ws) + ormrqf_core!(side::Ref{UInt8}, A::Adjoint{$elty}, + C::StridedMatrix{$elty}, ws::QrWs) = + ormrqf_core!(side, A.parent, C, ws) + end +end + +end -- GitLab