diff --git a/docs/make.jl b/docs/make.jl index 94883011b950609b4eb1fa4f82bf85a853ea3213..4d7b3bc4442da7dd8297f88b99b72781db2d26a2 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,3 +1,3 @@ using Documenter, FastLapackInterface -makedocs(sitename="MyDocumentation") +makedocs(sitename = "MyDocumentation") diff --git a/src/LinSolveAlgo.jl b/src/LinSolveAlgo.jl index 0088d6d8b3b426fd3c516586bb1c7c3c8ad6d15b..c626dc65a4f09086dee2d9ef44842113711cfbfa 100644 --- a/src/LinSolveAlgo.jl +++ b/src/LinSolveAlgo.jl @@ -7,44 +7,58 @@ const libblastrampoline = "libblastrampoline" using LinearAlgebra import LinearAlgebra.BlasInt import LinearAlgebra.BLAS.@blasfunc -import LinearAlgebra.LAPACK: chklapackerror +import LinearAlgebra.LAPACK: chklapackerror export LinSolveWs, linsolve_core!, linsolve_core_no_lu!, lu! -struct LinSolveWs{T <: Number, U <: Integer} +struct LinSolveWs{T<:Number,U<:Integer} lu::Vector{T} ipiv::Vector{BlasInt} - function LinSolveWs{T, U}(n::U) where {T <: Number, U <: Integer} - lu = zeros(T, n*n) + function LinSolveWs{T,U}(n::U) where {T<:Number,U<:Integer} + lu = zeros(T, n * n) ipiv = zeros(BlasInt, n) new(lu, ipiv) end end # Float64 is the default -LinSolveWs(n) = LinSolveWs{Float64, Int64}(n) +LinSolveWs(n) = LinSolveWs{Float64,Int64}(n) strides(a::Adjoint) = strides(a.parent) -for (getrf, getrs, elty) in - ((:dgetrf_, :dgetrs_, :Float64), - (:sgetrf_, :sgetrs_, :Float32), - (:zgetrf_, :zgetrs_, :ComplexF64), - (:cgetrf_, :cgetrs_, :ComplexF32)) +for (getrf, getrs, elty) in ( + (:dgetrf_, :dgetrs_, :Float64), + (:sgetrf_, :sgetrs_, :Float32), + (:zgetrf_, :zgetrs_, :ComplexF64), + (:cgetrf_, :cgetrs_, :ComplexF32), +) @eval begin - function lu!(a::StridedMatrix{$elty}, - ws::LinSolveWs) + function lu!(a::StridedMatrix{$elty}, ws::LinSolveWs) copyto!(ws.lu, a) - mm,nn = size(a) + mm, nn = size(a) m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) # ws.lu isn't a view and has continuous storage - lda = Ref{BlasInt}(max(1,mm)) + lda = Ref{BlasInt}(max(1, mm)) info = Ref{BlasInt}(0) - ccall((@blasfunc($getrf), libblastrampoline), Cvoid, - (Ref{BlasInt},Ref{BlasInt},Ptr{$elty},Ref{BlasInt}, - Ptr{BlasInt},Ref{BlasInt}), - m, n, ws.lu, lda, ws.ipiv, info) + ccall( + (@blasfunc($getrf), libblastrampoline), + Cvoid, + ( + Ref{BlasInt}, + Ref{BlasInt}, + Ptr{$elty}, + Ref{BlasInt}, + Ptr{BlasInt}, + Ref{BlasInt}, + ), + m, + n, + ws.lu, + lda, + ws.ipiv, + info, + ) if info[] != 0 chklapackerror(info[]) end @@ -52,41 +66,57 @@ for (getrf, getrs, elty) in t1 = StridedMatrix{$elty} - t2 = Transpose{$elty, <: StridedMatrix} - t3 = Adjoint{$elty, <: StridedMatrix} + t2 = Transpose{$elty,<:StridedMatrix} + t3 = Adjoint{$elty,<:StridedMatrix} end - - for (elty2, transchar) in - ((t1, 'N'), - (t2, 'T'), - (t3, 'C')) - + + for (elty2, transchar) in ((t1, 'N'), (t2, 'T'), (t3, 'C')) + @eval begin - function linsolve_core_no_lu!(a::$elty2, - b::StridedVecOrMat{$elty}, - ws::LinSolveWs) - mm,nn = size(a) + function linsolve_core_no_lu!( + a::$elty2, + b::StridedVecOrMat{$elty}, + ws::LinSolveWs, + ) + mm, nn = size(a) m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) - nhrs = Ref{BlasInt}(size(b,2)) + nhrs = Ref{BlasInt}(size(b, 2)) # ws.lu isn't a view and has continuous storage - lda = Ref{BlasInt}(max(1,mm)) - ldb = Ref{BlasInt}(max(1,stride(b,2))) + lda = Ref{BlasInt}(max(1, mm)) + ldb = Ref{BlasInt}(max(1, stride(b, 2))) info = Ref{BlasInt}(0) - ccall((@blasfunc($getrs), libblastrampoline), Cvoid, - (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, - Ref{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, - Ref{BlasInt}), - $transchar, n, nhrs, ws.lu, lda, ws.ipiv, b, ldb, info) + ccall( + (@blasfunc($getrs), libblastrampoline), + Cvoid, + ( + Ref{UInt8}, + Ref{BlasInt}, + Ref{BlasInt}, + Ptr{$elty}, + Ref{BlasInt}, + Ptr{BlasInt}, + Ptr{$elty}, + Ref{BlasInt}, + Ref{BlasInt}, + ), + $transchar, + n, + nhrs, + ws.lu, + lda, + ws.ipiv, + b, + ldb, + info, + ) if info[] != 0 chklapackerror(info[]) end end - function linsolve_core!(a::$elty2, - b::StridedVecOrMat{$elty}, - ws::LinSolveWs) + function linsolve_core!(a::$elty2, b::StridedVecOrMat{$elty}, ws::LinSolveWs) lu!(a, ws) linsolve_core_no_lu!(a, b, ws) diff --git a/src/QrAlgo.jl b/src/QrAlgo.jl index a62522c7867db8d418efc473bada61ca358c33a5..42f03999cd009c56399a3c287f5ef4f0bd116222 100644 --- a/src/QrAlgo.jl +++ b/src/QrAlgo.jl @@ -5,40 +5,59 @@ const libblastrampoline = "libblastrampoline" using LinearAlgebra import LinearAlgebra: BlasInt import LinearAlgebra.BLAS: @blasfunc -import LinearAlgebra.LAPACK: chklapackerror +import LinearAlgebra.LAPACK: chklapackerror export QrWs, QrpWs, geqrf_core!, geqp3!, ormqr_core! abstract type QR end -struct QrWs{T <: Number} <: QR +struct QrWs{T<:Number} <: QR tau::Vector{T} work::Vector{T} lwork::Ref{BlasInt} info::Ref{BlasInt} end -for (geqrf, ormqr, elty) in - ((:dgeqrf_, :dormqr_, :Float64), - (:sgeqrf_, :sormqr_, :Float32), - (:zgeqrf_, :zormqr_, :ComplexF64), - (:cgeqrf_, :cormqr_, :ComplexF32)) +for (geqrf, ormqr, elty) in ( + (:dgeqrf_, :dormqr_, :Float64), + (:sgeqrf_, :sormqr_, :Float32), + (:zgeqrf_, :zormqr_, :ComplexF64), + (:cgeqrf_, :cormqr_, :ComplexF32), +) @eval begin - function QrWs(A::StridedMatrix{T}) where T <: $elty + function QrWs(A::StridedMatrix{T}) where {T<:$elty} nn, mm = size(A) m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) - RldA = Ref{BlasInt}(max(1,stride(A,2))) - tau = Vector{T}(undef, min(nn,mm)) + RldA = Ref{BlasInt}(max(1, stride(A, 2))) + tau = Vector{T}(undef, min(nn, mm)) work = Vector{T}(undef, 1) lwork = Ref{BlasInt}(-1) info = Ref{BlasInt}(0) - ccall((@blasfunc($geqrf), libblastrampoline), Nothing, - (Ref{BlasInt}, Ref{BlasInt}, Ptr{T}, Ref{BlasInt}, - Ptr{T}, Ptr{T}, Ref{BlasInt}, Ref{BlasInt}), - m, n, A, RldA, tau, work, lwork, info) + ccall( + (@blasfunc($geqrf), libblastrampoline), + Nothing, + ( + Ref{BlasInt}, + Ref{BlasInt}, + Ptr{T}, + Ref{BlasInt}, + Ptr{T}, + Ptr{T}, + Ref{BlasInt}, + Ref{BlasInt}, + ), + m, + n, + A, + RldA, + tau, + work, + lwork, + info, + ) chklapackerror(info[]) lwork = Ref{BlasInt}(real(work[1])) work = Array{T}(undef, lwork[]) @@ -46,69 +65,144 @@ for (geqrf, ormqr, elty) in end function geqrf_core!(A::StridedMatrix{$elty}, ws::QrWs) - mm,nn = size(A) + mm, nn = size(A) m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) - RldA = Ref{BlasInt}(max(1,stride(A,2))) - ccall((@blasfunc($geqrf), libblastrampoline), Nothing, - (Ref{BlasInt},Ref{BlasInt},Ptr{$elty},Ref{BlasInt}, - Ptr{$elty},Ptr{$elty},Ref{BlasInt},Ref{BlasInt}), - m,n,A,RldA,ws.tau,ws.work,ws.lwork,ws.info) + RldA = Ref{BlasInt}(max(1, stride(A, 2))) + ccall( + (@blasfunc($geqrf), libblastrampoline), + Nothing, + ( + Ref{BlasInt}, + Ref{BlasInt}, + Ptr{$elty}, + Ref{BlasInt}, + Ptr{$elty}, + Ptr{$elty}, + Ref{BlasInt}, + Ref{BlasInt}, + ), + m, + n, + A, + RldA, + ws.tau, + ws.work, + ws.lwork, + ws.info, + ) chklapackerror(ws.info[]) end t1 = StridedMatrix{$elty} - t2 = Transpose{$elty, <: StridedMatrix} - t3 = Adjoint{$elty, <: StridedMatrix} + t2 = Transpose{$elty,<:StridedMatrix} + t3 = Adjoint{$elty,<:StridedMatrix} end end -for (geqrf, ormqr, elty) in - ((:dgeqrf_, :dormqr_, :Float64), - (:sgeqrf_, :sormqr_, :Float32)) +for (geqrf, ormqr, elty) in ((:dgeqrf_, :dormqr_, :Float64), (:sgeqrf_, :sormqr_, :Float32)) @eval begin - function ormqr_core!(side::Char, A::StridedMatrix{$elty}, - C::StridedMatrix{$elty}, ws::QR) - mm,nn = size(C) + function ormqr_core!( + side::Char, + A::StridedMatrix{$elty}, + C::StridedMatrix{$elty}, + ws::QR, + ) + mm, nn = size(C) m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) k = Ref{BlasInt}(length(ws.tau)) - RldA = Ref{BlasInt}(max(1,stride(A,2))) - RldC = Ref{BlasInt}(max(1,stride(C,2))) - ccall((@blasfunc($ormqr), libblastrampoline), Nothing, - (Ref{UInt8},Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ref{BlasInt},Ptr{$elty},Ref{BlasInt}, - Ptr{$elty},Ptr{$elty},Ref{BlasInt},Ptr{$elty},Ref{BlasInt},Ref{BlasInt}), - side, 'N', m, n, k, A, RldA, ws.tau, C, RldC, ws.work, ws.lwork, ws.info) + RldA = Ref{BlasInt}(max(1, stride(A, 2))) + RldC = Ref{BlasInt}(max(1, stride(C, 2))) + ccall( + (@blasfunc($ormqr), libblastrampoline), + Nothing, + ( + Ref{UInt8}, + Ref{UInt8}, + Ref{BlasInt}, + Ref{BlasInt}, + Ref{BlasInt}, + Ptr{$elty}, + Ref{BlasInt}, + Ptr{$elty}, + Ptr{$elty}, + Ref{BlasInt}, + Ptr{$elty}, + Ref{BlasInt}, + Ref{BlasInt}, + ), + side, + 'N', + m, + n, + k, + A, + RldA, + ws.tau, + C, + RldC, + ws.work, + ws.lwork, + ws.info, + ) chklapackerror(ws.info[]) end end @eval begin - t1 = Transpose{$elty, <: StridedMatrix{$elty}} - t2 = Adjoint{$elty, <: StridedMatrix{$elty}} + t1 = Transpose{$elty,<:StridedMatrix{$elty}} + t2 = Adjoint{$elty,<:StridedMatrix{$elty}} end - + for elty2 in (t1, t2) - + @eval begin - function ormqr_core!(side::Char, A::$elty2, - C::StridedMatrix{$elty}, ws::QR) - mm,nn = size(C) + function ormqr_core!(side::Char, A::$elty2, C::StridedMatrix{$elty}, ws::QR) + mm, nn = size(C) m = Ref{BlasInt}(mm) n = Ref{BlasInt}(nn) k = Ref{BlasInt}(length(ws.tau)) - RldA = Ref{BlasInt}(max(1,stride(A.parent,2))) - RldC = Ref{BlasInt}(max(1,stride(C,2))) - ccall((@blasfunc($ormqr), libblastrampoline), Nothing, - (Ref{UInt8},Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ref{BlasInt},Ptr{$elty},Ref{BlasInt}, - Ptr{$elty},Ptr{$elty},Ref{BlasInt},Ptr{$elty},Ref{BlasInt},Ref{BlasInt}), - side, 'T', m, n, k, A.parent, RldA, ws.tau, C, RldC, ws.work, ws.lwork, ws.info) + RldA = Ref{BlasInt}(max(1, stride(A.parent, 2))) + RldC = Ref{BlasInt}(max(1, stride(C, 2))) + ccall( + (@blasfunc($ormqr), libblastrampoline), + Nothing, + ( + Ref{UInt8}, + Ref{UInt8}, + Ref{BlasInt}, + Ref{BlasInt}, + Ref{BlasInt}, + Ptr{$elty}, + Ref{BlasInt}, + Ptr{$elty}, + Ptr{$elty}, + Ref{BlasInt}, + Ptr{$elty}, + Ref{BlasInt}, + Ref{BlasInt}, + ), + side, + 'T', + m, + n, + k, + A.parent, + RldA, + ws.tau, + C, + RldC, + ws.work, + ws.lwork, + ws.info, + ) chklapackerror(ws.info[]) end end - end -end + end +end #= for (geqrf, ormqr, elty) in @@ -162,7 +256,7 @@ for (geqrf, ormqr, elty) in end =# -struct QrpWs{T <: Number} <: QR +struct QrpWs{T<:Number} <: QR tau::Vector{T} jpvt::Vector{BlasInt} work::Vector{T} @@ -170,26 +264,47 @@ struct QrpWs{T <: Number} <: QR info::Ref{BlasInt} end -for (geqp3, elty) in - ((:dgeqp3_, :Float64), - (:sgeqp3_, :Float32), - (:zgeqp3_, :ComplexF64), - (:cgeqp3_, :ComplexF32)) +for (geqp3, elty) in ( + (:dgeqp3_, :Float64), + (:sgeqp3_, :Float32), + (:zgeqp3_, :ComplexF64), + (:cgeqp3_, :ComplexF32), +) @eval begin function QrpWs(A::StridedMatrix{$elty}) m, n = size(A) - RldA = BlasInt(max(1,stride(A,2))) + RldA = BlasInt(max(1, stride(A, 2))) jpvt = zeros(BlasInt, n) tau = Vector{$elty}(undef, min(m, n)) work = Vector{$elty}(undef, 1) lwork = BlasInt(-1) info = Ref{BlasInt}() - ccall((@blasfunc($geqp3), libblastrampoline), Nothing, - (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, - Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ref{BlasInt}), - m, n, A, RldA, jpvt, tau, work, lwork, info) + ccall( + (@blasfunc($geqp3), libblastrampoline), + Nothing, + ( + Ref{BlasInt}, + Ref{BlasInt}, + Ptr{$elty}, + Ref{BlasInt}, + Ptr{BlasInt}, + Ptr{$elty}, + Ptr{$elty}, + Ref{BlasInt}, + Ref{BlasInt}, + ), + m, + n, + A, + RldA, + jpvt, + tau, + work, + lwork, + info, + ) chklapackerror(info[]) lwork = BlasInt(real(work[1])) work = resize!(work, lwork) @@ -198,17 +313,37 @@ for (geqp3, elty) in function geqp3!(A::StridedMatrix{$elty}, ws::QrpWs) m, n = size(A) - RldA = BlasInt(max(1,stride(A,2))) - ccall((@blasfunc($geqp3), libblastrampoline), Nothing, - (Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, - Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ref{BlasInt}, Ref{BlasInt}), - m, n, A, RldA, ws.jpvt, ws.tau, ws.work, ws.lwork, ws.info) + RldA = BlasInt(max(1, stride(A, 2))) + ccall( + (@blasfunc($geqp3), libblastrampoline), + Nothing, + ( + Ref{BlasInt}, + Ref{BlasInt}, + Ptr{$elty}, + Ref{BlasInt}, + Ptr{BlasInt}, + Ptr{$elty}, + Ptr{$elty}, + Ref{BlasInt}, + Ref{BlasInt}, + ), + m, + n, + A, + RldA, + ws.jpvt, + ws.tau, + ws.work, + ws.lwork, + ws.info, + ) chklapackerror(ws.info[]) end t1 = StridedMatrix{$elty} - t2 = Transpose{$elty, <: StridedMatrix} - t3 = Adjoint{$elty, <: StridedMatrix} + t2 = Transpose{$elty,<:StridedMatrix} + t3 = Adjoint{$elty,<:StridedMatrix} end end diff --git a/src/SchurAlgo.jl b/src/SchurAlgo.jl index e18354b57876d289f9f985a2163cd79ec52e00ab..bfac7a9e30ecdba7c6859f28555d948956e89560 100644 --- a/src/SchurAlgo.jl +++ b/src/SchurAlgo.jl @@ -7,14 +7,14 @@ include("exceptions.jl") const libblastrampoline = "libblastrampoline" import LinearAlgebra: USE_BLAS64, LAPACKException -import LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1 +import LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1 import LinearAlgebra.BLAS: @blasfunc, libblas import LinearAlgebra.LAPACK: chklapackerror import Base: has_offset_axes export DgeesWs, dgees!, DggesWs, dgges! -const criterium = 1+1e-6 +const criterium = 1 + 1e-6 #= function mycompare(wr_, wi_)::Cint @@ -24,11 +24,14 @@ function mycompare(wr_, wi_)::Cint end =# -function mycompare(alphar_::Ptr{T}, alphai_::Ptr{T}, beta_::Ptr{T})::Cint where T +function mycompare(alphar_::Ptr{T}, alphai_::Ptr{T}, beta_::Ptr{T})::Cint where {T} alphar = unsafe_load(alphar_) alphai = unsafe_load(alphai_) beta = unsafe_load(beta_) - return convert(Cint, ((alphar*alphar + alphai*alphai) < criterium*beta*beta) ? 1 : 0) + return convert( + Cint, + ((alphar * alphar + alphai * alphai) < criterium * beta * beta) ? 1 : 0, + ) end struct DgeesWs @@ -44,27 +47,60 @@ struct DgeesWs eigen_values::Vector{Complex{Float64}} info::Ref{BlasInt} - function DgeesWs(jobvs::Ref{UInt8}, A::StridedMatrix{Float64}, sdim::Ref{BlasInt}, - wr::Vector{Float64}, wi::Vector{Float64}, ldvs::Ref{BlasInt}, vs::Matrix{Float64}, - work::Vector{Float64}, lwork::Ref{BlasInt}, bwork::Vector{Int64}, - eigen_values::Vector{Complex{Float64}}, info::Ref{BlasInt}) - n = Ref{BlasInt}(size(A,1)) - RldA = Ref{BlasInt}(max(1,stride(A,2))) + function DgeesWs( + jobvs::Ref{UInt8}, + A::StridedMatrix{Float64}, + sdim::Ref{BlasInt}, + wr::Vector{Float64}, + wi::Vector{Float64}, + ldvs::Ref{BlasInt}, + vs::Matrix{Float64}, + work::Vector{Float64}, + lwork::Ref{BlasInt}, + bwork::Vector{Int64}, + eigen_values::Vector{Complex{Float64}}, + info::Ref{BlasInt}, + ) + n = Ref{BlasInt}(size(A, 1)) + RldA = Ref{BlasInt}(max(1, stride(A, 2))) Rsort = Ref{UInt8}('N') -# mycompare_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble})) - ccall((@blasfunc(dgees_), libblastrampoline), Nothing, - (Ref{UInt8}, Ref{UInt8}, Ptr{Nothing}, - Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, - Ptr{BlasInt}, Ptr{Float64}, - Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, - Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, - Ptr{BlasInt}), - jobvs, 'N', C_NULL, - n, A, RldA, - sdim, wr, wi, - vs, ldvs, - work, lwork, bwork, - info) + # mycompare_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble})) + ccall( + (@blasfunc(dgees_), libblastrampoline), + Nothing, + ( + Ref{UInt8}, + Ref{UInt8}, + Ptr{Nothing}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{BlasInt}, + Ptr{Float64}, + Ptr{Float64}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{BlasInt}, + Ptr{BlasInt}, + ), + jobvs, + 'N', + C_NULL, + n, + A, + RldA, + sdim, + wr, + wi, + vs, + ldvs, + work, + lwork, + bwork, + info, + ) chklapackerror(info[]) lwork = Ref{BlasInt}(real(work[1])) work = Vector{Float64}(undef, lwork[]) @@ -93,62 +129,108 @@ function DgeesWs(A::StridedMatrix{Float64}) end function DgeesWs(n::Int64) - A = zeros(n,n) + A = zeros(n, n) DgeesWs(A) end -function dgees!(ws::DgeesWs,A::StridedMatrix{Float64}) - n = Ref{BlasInt}(size(A,1)) - RldA = Ref{BlasInt}(max(1,stride(A,2))) +function dgees!(ws::DgeesWs, A::StridedMatrix{Float64}) + n = Ref{BlasInt}(size(A, 1)) + RldA = Ref{BlasInt}(max(1, stride(A, 2))) myfunc::Function = make_select_function(>=, 1.0) - ccall((@blasfunc(dgees_), libblastrampoline), Cvoid, - (Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid}, - Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, - Ptr{BlasInt}, Ptr{Float64}, Ptr{Float64}, - Ptr{Float64}, Ref{BlasInt}, - Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, - Ptr{BlasInt}), - ws.jobvs, 'N', C_NULL, - n, A, RldA, - ws.sdim, ws.wr, ws.wi, - ws.vs, ws.ldvs, - ws.work, ws.lwork, ws.bwork, - ws.info) + ccall( + (@blasfunc(dgees_), libblastrampoline), + Cvoid, + ( + Ref{UInt8}, + Ref{UInt8}, + Ptr{Cvoid}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{BlasInt}, + Ptr{Float64}, + Ptr{Float64}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{BlasInt}, + Ptr{BlasInt}, + ), + ws.jobvs, + 'N', + C_NULL, + n, + A, + RldA, + ws.sdim, + ws.wr, + ws.wi, + ws.vs, + ws.ldvs, + ws.work, + ws.lwork, + ws.bwork, + ws.info, + ) copyto!(ws.eigen_values, complex.(ws.wr, ws.wi)) chklapackerror(ws.info[]) end - + function make_select_function(op, crit)::Function - mycompare = function(wr_, wi_) + mycompare = function (wr_, wi_) wr = unsafe_load(wr_) wi = unsafe_load(wi_) - return convert(Cint, op(wr*wr + wi*wi, crit) ? 1 : 0) + return convert(Cint, op(wr * wr + wi * wi, crit) ? 1 : 0) end return mycompare end function dgees!(ws::DgeesWs, A::StridedMatrix{Float64}, op, crit) - n = Ref{BlasInt}(size(A,1)) - RldA = Ref{BlasInt}(max(1,stride(A,2))) + n = Ref{BlasInt}(size(A, 1)) + RldA = Ref{BlasInt}(max(1, stride(A, 2))) myfunc::Function = make_select_function(op, crit) mycompare_c = @cfunction($myfunc, Cint, (Ptr{Cdouble}, Ptr{Cdouble})) - ccall((@blasfunc(dgees_), libblastrampoline), Cvoid, - (Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid}, - Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, - Ptr{BlasInt}, Ptr{Float64}, Ptr{Float64}, - Ptr{Float64}, Ref{BlasInt}, - Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, - Ptr{BlasInt}), - ws.jobvs, 'S', mycompare_c, - n, A, RldA, - ws.sdim, ws.wr, ws.wi, - ws.vs, ws.ldvs, - ws.work, ws.lwork, ws.bwork, - ws.info) + ccall( + (@blasfunc(dgees_), libblastrampoline), + Cvoid, + ( + Ref{UInt8}, + Ref{UInt8}, + Ptr{Cvoid}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{BlasInt}, + Ptr{Float64}, + Ptr{Float64}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{BlasInt}, + Ptr{BlasInt}, + ), + ws.jobvs, + 'S', + mycompare_c, + n, + A, + RldA, + ws.sdim, + ws.wr, + ws.wi, + ws.vs, + ws.ldvs, + ws.work, + ws.lwork, + ws.bwork, + ws.info, + ) copyto!(ws.eigen_values, complex.(ws.wr, ws.wi)) chklapackerror(ws.info[]) end - + mutable struct DggesWs alphar::Vector{Float64} alphai::Vector{Float64} @@ -158,13 +240,19 @@ mutable struct DggesWs bwork::Vector{Int64} sdim::BlasInt - function DggesWs(jobvsl::Ref{UInt8}, jobvsr::Ref{UInt8}, sort::Ref{UInt8}, A::StridedMatrix{Float64}, B::StridedMatrix{Float64}) + function DggesWs( + jobvsl::Ref{UInt8}, + jobvsr::Ref{UInt8}, + sort::Ref{UInt8}, + A::StridedMatrix{Float64}, + B::StridedMatrix{Float64}, + ) chkstride1(A, B) n, m = checksquare(A, B) if n != m throw(DimensionMismatch("Dimensions of A, ($n,$n), and B, ($m,$m), must match")) end - n = BlasInt(size(A,1)) + n = BlasInt(size(A, 1)) alphar = Vector{Float64}(undef, n) alphai = Vector{Float64}(undef, n) beta = Vector{Float64}(undef, n) @@ -176,24 +264,60 @@ mutable struct DggesWs work = Vector{Float64}(undef, 1) sdim = BlasInt(0) info = BlasInt(0) - mycompare_g_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cdouble})) - ccall((@blasfunc(dgges_), libblastrampoline), Nothing, - (Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ptr{Nothing}, - Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, - Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ptr{Float64}, - Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, - Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{Int64}, - Ref{BlasInt}), - jobvsl, jobvsr, sort, mycompare_g_c, - n, A, max(1,stride(A, 2)), B, - max(1,stride(B, 2)), sdim, alphar, alphai, - beta, C_NULL, ldvsl, C_NULL, - ldvsr, work, lwork, bwork, - info) + mycompare_g_c = + @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cdouble})) + ccall( + (@blasfunc(dgges_), libblastrampoline), + Nothing, + ( + Ref{UInt8}, + Ref{UInt8}, + Ref{UInt8}, + Ptr{Nothing}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ref{BlasInt}, + Ptr{Float64}, + Ptr{Float64}, + Ptr{Float64}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{Int64}, + Ref{BlasInt}, + ), + jobvsl, + jobvsr, + sort, + mycompare_g_c, + n, + A, + max(1, stride(A, 2)), + B, + max(1, stride(B, 2)), + sdim, + alphar, + alphai, + beta, + C_NULL, + ldvsl, + C_NULL, + ldvsr, + work, + lwork, + bwork, + info, + ) chklapackerror(info) lwork = BlasInt(real(work[1])) work = Vector{Float64}(undef, lwork) - new(alphar,alphai,beta,lwork,work,bwork,sdim) + new(alphar, alphai, beta, lwork, work, bwork, sdim) end end @@ -201,35 +325,77 @@ function DggesWs(A::StridedMatrix{Float64}, B::StridedMatrix{Float64}) DggesWs(Ref{UInt8}('N'), Ref{UInt8}('N'), Ref{UInt8}('N'), A, B) end -function dgges!(jobvsl::Char, jobvsr::Char, A::StridedMatrix{Float64}, B::StridedMatrix{Float64}, - vsl::Matrix{Float64}, vsr::Matrix{Float64}, eigval::Array{ComplexF64,1}, - ws::DggesWs) - n = size(A,1) +function dgges!( + jobvsl::Char, + jobvsr::Char, + A::StridedMatrix{Float64}, + B::StridedMatrix{Float64}, + vsl::Matrix{Float64}, + vsr::Matrix{Float64}, + eigval::Array{ComplexF64,1}, + ws::DggesWs, +) + n = size(A, 1) ldvsl = jobvsl == 'V' ? n : 1 ldvsr = jobvsr == 'V' ? n : 1 sort = 'S' sdim = Ref{BlasInt}(0) info = Ref{BlasInt}(0) mycompare_g_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cdouble})) - ccall((@blasfunc(dgges_), libblastrampoline), Nothing, - (Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ptr{Nothing}, - Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, - Ref{BlasInt}, Ptr{BlasInt}, Ptr{Float64}, Ptr{Float64}, - Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, - Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{Int64}, - Ref{Int64}), - jobvsl, jobvsr, sort, mycompare_g_c, - n, A, max(1,stride(A, 2)), B, - max(1,stride(B, 2)), sdim, ws.alphar, ws.alphai, - ws.beta, vsl, ldvsl, vsr, - ldvsr, ws.work, ws.lwork, ws.bwork, - info) + ccall( + (@blasfunc(dgges_), libblastrampoline), + Nothing, + ( + Ref{UInt8}, + Ref{UInt8}, + Ref{UInt8}, + Ptr{Nothing}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{BlasInt}, + Ptr{Float64}, + Ptr{Float64}, + Ptr{Float64}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{Float64}, + Ref{BlasInt}, + Ptr{Int64}, + Ref{Int64}, + ), + jobvsl, + jobvsr, + sort, + mycompare_g_c, + n, + A, + max(1, stride(A, 2)), + B, + max(1, stride(B, 2)), + sdim, + ws.alphar, + ws.alphai, + ws.beta, + vsl, + ldvsl, + vsr, + ldvsr, + ws.work, + ws.lwork, + ws.bwork, + info, + ) ws.sdim = sdim[] if info[] > 0 throw(DggesException(info[])) end - for i in 1:n - eigval[i] = complex(ws.alphar[i],ws.alphai[i])/ws.beta[i] + for i = 1:n + eigval[i] = complex(ws.alphar[i], ws.alphai[i]) / ws.beta[i] end end diff --git a/test/LinSolveAlgo_test.jl b/test/LinSolveAlgo_test.jl index 3a8d9f834b645dd3f89d2f0fef932552d7a87e1c..41af70b09c7c6355dc178b6052456630200d78f1 100644 --- a/test/LinSolveAlgo_test.jl +++ b/test/LinSolveAlgo_test.jl @@ -15,12 +15,12 @@ for elty in (Float32, Float64, ComplexF32, ComplexF64) A = copy(A0) # Full matrix - linws = LinSolveAlgo.LinSolveWs{elty, Int64}(n) + linws = LinSolveAlgo.LinSolveWs{elty,Int64}(n) LinSolveAlgo.lu!(A, linws) @test A == A0 F = lu(A) - @test UpperTriangular(reshape(linws.lu, n, n)) ≈ F.U + @test UpperTriangular(reshape(linws.lu, n, n)) ≈ F.U LinSolveAlgo.lu!(A', linws) @test UpperTriangular(reshape(linws.lu, n, n)) ≈ F.U @@ -28,107 +28,107 @@ for elty in (Float32, Float64, ComplexF32, ComplexF64) B = copy(B0) LinSolveAlgo.linsolve_core!(A, B, linws) @test A == A0 - @test B ≈ A\B0 + @test B ≈ A \ B0 copy!(B, B1) LinSolveAlgo.linsolve_core_no_lu!(A, B, linws) @test A == A0 - @test B ≈ A\B1 + @test B ≈ A \ B1 copy!(B, B0) LinSolveAlgo.linsolve_core!(A', B, linws) - @test B ≈ A'\B0 + @test B ≈ A' \ B0 copy!(B, B1) LinSolveAlgo.linsolve_core_no_lu!(A', B, linws) @test A == A0 - @test B ≈ A'\B1 + @test B ≈ A' \ B1 # view of a matrix in upper left corner - linws1 = LinSolveAlgo.LinSolveWs{elty, Int64}(n-1) + linws1 = LinSolveAlgo.LinSolveWs{elty,Int64}(n - 1) C = view(A, 1:n-1, 1:n-1) D = view(B, 1:n-1, 1:m-1) D0 = copy(D) D1 = copy(D) LinSolveAlgo.linsolve_core!(C, D, linws1) @test C == view(A, 1:n-1, 1:n-1) - @test D ≈ C\D0 + @test D ≈ C \ D0 D = copy(D1) LinSolveAlgo.linsolve_core_no_lu!(C, D, linws1) @test C == view(A, 1:n-1, 1:n-1) - @test D ≈ C\D1 + @test D ≈ C \ D1 D = copy(D0) LinSolveAlgo.linsolve_core!(C', D, linws1) @test C == view(A, 1:n-1, 1:n-1) - @test D ≈ C'\D0 + @test D ≈ C' \ D0 D = copy(D1) LinSolveAlgo.linsolve_core_no_lu!(C', D, linws1) @test C == view(A, 1:n-1, 1:n-1) - @test D ≈ C'\D1 + @test D ≈ C' \ D1 # view of a matrix in lower left corner - linws1 = LinSolveAlgo.LinSolveWs{elty, Int64}(n-1) + linws1 = LinSolveAlgo.LinSolveWs{elty,Int64}(n - 1) C = view(A, 2:n, 1:n-1) C1 = copy(C) LinSolveAlgo.lu!(C1, linws1) F = LinearAlgebra.lu!(C) - @test triu(reshape(linws1.lu, n-1, n-1)) ≈ F.U - @test tril(reshape(linws1.lu, n-1, n-1), -1) ≈ tril(F.L, -1) - + @test triu(reshape(linws1.lu, n - 1, n - 1)) ≈ F.U + @test tril(reshape(linws1.lu, n - 1, n - 1), -1) ≈ tril(F.L, -1) + D0 = view(B0, 2:n, 1:m-1) D1 = view(B1, 2:n, 1:m-1) D = copy(D0) LinSolveAlgo.linsolve_core!(C, D, linws1) @test C == view(A, 2:n, 1:n-1) - @test D ≈ C\D0 + @test D ≈ C \ D0 D = copy(D1) LinSolveAlgo.linsolve_core_no_lu!(C, D, linws1) @test C == view(A, 2:n, 1:n-1) - @test D ≈ C\D1 + @test D ≈ C \ D1 D = copy(D0) LinSolveAlgo.linsolve_core!(C', D, linws1) @test C == view(A, 2:n, 1:n-1) - @test D ≈ C'\D0 + @test D ≈ C' \ D0 D = copy(D1) LinSolveAlgo.linsolve_core_no_lu!(C', D, linws1) @test C == view(A, 2:n, 1:n-1) - @test D ≈ C'\D1 + @test D ≈ C' \ D1 # using too big a work space - linws1 = LinSolveAlgo.LinSolveWs{elty, Int64}(n) + linws1 = LinSolveAlgo.LinSolveWs{elty,Int64}(n) C = view(A, 2:n, 1:n-1) C1 = copy(C) LinSolveAlgo.lu!(C1, linws1) F = LinearAlgebra.lu!(C) - @test triu(reshape(linws1.lu[1:(n-1)^2], n-1, n-1)) ≈ F.U - @test tril(reshape(linws1.lu[1:(n-1)^2], n-1, n-1), -1) ≈ tril(F.L, -1) - + @test triu(reshape(linws1.lu[1:(n-1)^2], n - 1, n - 1)) ≈ F.U + @test tril(reshape(linws1.lu[1:(n-1)^2], n - 1, n - 1), -1) ≈ tril(F.L, -1) + D0 = view(B0, 2:n, 1:m-1) D1 = view(B1, 2:n, 1:m-1) D = copy(D0) LinSolveAlgo.linsolve_core!(C, D, linws1) @test C == view(A, 2:n, 1:n-1) - @test D ≈ C\D0 + @test D ≈ C \ D0 D = copy(D1) LinSolveAlgo.linsolve_core_no_lu!(C, D, linws1) @test C == view(A, 2:n, 1:n-1) - @test D ≈ C\D1 + @test D ≈ C \ D1 D = copy(D0) LinSolveAlgo.linsolve_core!(C', D, linws1) @test C == view(A, 2:n, 1:n-1) - @test D ≈ C'\D0 + @test D ≈ C' \ D0 D = copy(D1) LinSolveAlgo.linsolve_core_no_lu!(C', D, linws1) @test C == view(A, 2:n, 1:n-1) - @test D ≈ C'\D1 + @test D ≈ C' \ D1 end diff --git a/test/QrAlgo_test.jl b/test/QrAlgo_test.jl index 84756b01227650ff427c3473bb6f57f656cc18e2..64fce2e04d196354a8f6a9f5751484ba2ef3ad18 100644 --- a/test/QrAlgo_test.jl +++ b/test/QrAlgo_test.jl @@ -1,17 +1,17 @@ - + n = 10 #for elty in (Float32, Float64, ComplexF32, ComplexF64) # A0 = randn(elty, n, n) - A0 = randn(n, n) +A0 = randn(n, n) - A = copy(A0) - ws = QrAlgo.QrpWs(A) +A = copy(A0) +ws = QrAlgo.QrpWs(A) - QrAlgo.geqp3!(A, ws) +QrAlgo.geqp3!(A, ws) - target = qr(A0, ColumnNorm()) +target = qr(A0, ColumnNorm()) #display(triu(A)) #display(triu(target.R)) diff --git a/test/SchurAlgo_test.jl b/test/SchurAlgo_test.jl index 8b22b06da3be852144025200c5a6bb1cb17ed882..201d353befb12f0a7caba4560840c26a1e1a60ec 100644 --- a/test/SchurAlgo_test.jl +++ b/test/SchurAlgo_test.jl @@ -1,4 +1,4 @@ -A = diagm([1, -0.5 , 1]) +A = diagm([1, -0.5, 1]) ws = DgeesWs(3)