Skip to content
Snippets Groups Projects
Commit 24d5d420 authored by Michel Juillard's avatar Michel Juillard
Browse files

initial files

parents
Branches
No related tags found
No related merge requests found
using Documenter, FastLapackInterface
makedocs(sitename="MyDocumentation")
module FastLapackInterface
greet() = print("Hello World!")
end # module
module LinSolveAlgo
import Base.strides
using LinearAlgebra
import LinearAlgebra.BlasInt
import LinearAlgebra.BLAS.@blasfunc
import LinearAlgebra.LAPACK: liblapack, chklapackerror
export LinSolveWs, linsolve_core!, linsolve_core_no_lu!, lu!
struct LinSolveWs{T <: Number, U <: Integer}
lu::Matrix{T}
ipiv::Vector{BlasInt}
function LinSolveWs{T, U}(n::U) where {T <: Number, U <: Integer}
lu = zeros(T, n, n)
ipiv = zeros(BlasInt, n)
new(lu, ipiv)
end
end
# Float64 is the default
LinSolveWs(n) = LinSolveWs{Float64, Int64}(n)
strides(a::Adjoint) = strides(a.parent)
for (getrf, getrs, elty) in
((:dgetrf_, :dgetrs_, :Float64),
(:sgetrf_, :sgetrs_, :Float32),
(:zgetrf_, :zgetrs_, :ComplexF64),
(:cgetrf_, :cgetrs_, :ComplexF32))
@eval begin
function lu!(a::StridedMatrix{$elty},
ws::LinSolveWs)
copy!(ws.lu, a)
mm,nn = size(a)
m = Ref{BlasInt}(mm)
n = Ref{BlasInt}(nn)
lda = Ref{BlasInt}(max(1,stride(a,2)))
info = Ref{BlasInt}(0)
ccall((@blasfunc($getrf), liblapack), Cvoid,
(Ref{BlasInt},Ref{BlasInt},Ptr{$elty},Ref{BlasInt},
Ptr{BlasInt},Ref{BlasInt}),
m, n, ws.lu, lda, ws.ipiv, info)
if info[] != 0
chklapackerror(info[])
end
end
# if a is an adjoint matrix, we compute lu decomposition
# for its parent as the transposition is done at the
# solution stage
lu!(a::Adjoint,
ws::LinSolveWs) = lu!(a.parent, ws)
t1 = StridedMatrix{$elty}
t2 = Transpose{$elty, <: StridedMatrix}
t3 = Adjoint{$elty, <: StridedMatrix}
end
for (elty2, transchar) in
((t1, 'N'),
(t2, 'T'),
(t3, 'C'))
@eval begin
function linsolve_core_no_lu!(a::$elty2,
b::StridedVecOrMat{$elty},
ws::LinSolveWs)
mm,nn = size(a)
m = Ref{BlasInt}(mm)
n = Ref{BlasInt}(nn)
nhrs = Ref{BlasInt}(size(b,2))
lda = Ref{BlasInt}(max(1,stride(a,2)))
ldb = Ref{BlasInt}(max(1,stride(b,2)))
info = Ref{BlasInt}(0)
ccall((@blasfunc($getrs), liblapack), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ref{BlasInt}),
$transchar, n, nhrs, ws.lu, lda, ws.ipiv, b, ldb, info)
if info[] != 0
chklapackerror(info[])
end
end
function linsolve_core!(a::$elty2,
b::StridedVecOrMat{$elty},
ws::LinSolveWs)
mm,nn = size(a)
m = Ref{BlasInt}(mm)
n = Ref{BlasInt}(nn)
nhrs = Ref{BlasInt}(size(b,2))
lda = Ref{BlasInt}(max(1,stride(a,2)))
ldb = Ref{BlasInt}(max(1,stride(b,2)))
info = Ref{BlasInt}(0)
lu!(a, ws)
ccall((@blasfunc($getrs), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{$elty},Ref{BlasInt},
Ptr{BlasInt},Ptr{$elty},Ref{BlasInt},Ref{BlasInt}),
$transchar, n, nhrs, ws.lu, lda, ws.ipiv, b, ldb, info)
if info[] != 0
chklapackerror(info[])
end
end
end
end
end
end
using LinearAlgebra
import LinearAlgebra: BlasInt
import LinearAlgebra.BLAS: @blasfunc
import LinearAlgebra.LAPACK: liblapack, chklapackerror
export QrWs, dgeqrf_core!, dormrqf_core!
struct QrWs{T <: Number}
tau::Vector{T}
work::Vector{T}
lwork::Ref{BlasInt}
info::Ref{BlasInt}
end
for (geqrf, ormqr, elty) in
((:dgeqrf_, :dormqr_, :Float64),
(:sgeqrf_, :sormqr_, :Float32),
(:zgeqrf_, :zormqr_, :ComplexF64),
(:cgeqrf_, :cormqr_, :ComplexF32))
@eval begin
function QrWs(A::StridedMatrix{T}) where T <: $elty
nn, mm = size(A)
m = Ref{BlasInt}(mm)
n = Ref{BlasInt}(nn)
RldA = Ref{BlasInt}(max(1,stride(A,2)))
tau = Vector{T}(undef, min(nn,mm))
work = Vector{T}(undef, 1)
lwork = Ref{BlasInt}(-1)
info = Ref{BlasInt}(0)
ccall((@blasfunc($geqrf), liblapack), Nothing,
(Ref{BlasInt}, Ref{BlasInt}, Ptr{T}, Ref{BlasInt},
Ptr{T}, Ptr{T}, Ref{BlasInt}, Ref{BlasInt}),
m, n, A, RldA, tau, work, lwork, info)
chklapackerror(info[])
lwork = Ref{BlasInt}(real(work[1]))
work = Array{T}(undef, lwork[])
QrWs(tau, work, lwork, info)
end
function dgeqrf_core!(A::StridedMatrix{$elty}, ws::QrWs)
mm,nn = size(A)
m = Ref{BlasInt}(mm)
n = Ref{BlasInt}(nn)
RldA = Ref{BlasInt}(max(1,stride(A,2)))
ccall((@blasfunc($geqrf), liblapack), Nothing,
(Ref{BlasInt},Ref{BlasInt},Ptr{$elty},Ref{BlasInt},
Ptr{$elty},Ptr{$elty},Ref{BlasInt},Ref{BlasInt}),
m,n,A,RldA,ws.tau,ws.work,ws.lwork,ws.info)
chklapackerror(ws.info[])
end
t1 = StridedMatrix{$elty}
t2 = Transpose{$elty, <: StridedMatrix}
t3 = Adjoint{$elty, <: StridedMatrix}
end
for (elty2, transchar) in
((t1, 'N'),
(t2, 'T'),
(t3, 'C'))
@eval begin
function ormrqf_core!(side::Ref{UInt8}, A::$elty2,
C::StridedMatrix{$elty}, ws::QrWs)
mm,nn = size(C)
m = Ref{BlasInt}(mm)
n = Ref{BlasInt}(nn)
k = Ref{BlasInt}(length(ws.tau))
RldA = Ref{BlasInt}(max(1,stride(A,2)))
RldC = Ref{BlasInt}(max(1,stride(C,2)))
ccall((@blasfunc($ormqr), liblapack), Nothing,
(Ref{UInt8},Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ref{BlasInt},Ptr{$elty},Ref{BlasInt},
Ptr{$elty},Ptr{$elty},Ref{BlasInt},Ptr{$elty},Ref{BlasInt},Ref{BlasInt}),
side,$transchar,m,n,k,A,RldA,ws.tau,C,RldC,ws.work,ws.lwork,ws.info)
chklapackerror(ws.info[])
end
end
end
end
module SchurAlgo
# general Schur decomposition with reordering
# adataped from ./base/linalg/lapack.jl
include("exceptions.jl")
import LinearAlgebra: USE_BLAS64, LAPACKException
import LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1
import LinearAlgebra.BLAS: @blasfunc, libblas
import LinearAlgebra.LAPACK: liblapack, chklapackerror
import Base: has_offset_axes
export DgeesWS, dgees!, DggesWS, dgges!
const criterium = 1+1e-6
function mycompare(wr_, wi_)::Cint
wr = unsafe_load(wr_)
wi = unsafe_load(wi_)
return convert(Cint, ((wr*wr + wi*wi) < criterium) ? 1 : 0)
end
function mycompare(alphar_::Ptr{T}, alphai_::Ptr{T}, beta_::Ptr{T})::Cint where T
alphar = unsafe_load(alphar_)
alphai = unsafe_load(alphai_)
beta = unsafe_load(beta_)
return convert(Cint, ((alphar*alphar + alphai*alphai) < criterium*beta*beta) ? 1 : 0)
end
mutable struct DgeesWS
jobvs::Ref{UInt8}
sdim::Ref{BlasInt}
wr::Vector{Float64}
wi::Vector{Float64}
ldvs::Ref{BlasInt}
vs::Matrix{Float64}
work::Vector{Float64}
lwork::Ref{BlasInt}
bwork::Vector{Int64}
eigen_values::Vector{Complex{Float64}}
info::Ref{BlasInt}
function DgeesWS(jobvs::Ref{UInt8}, A::StridedMatrix{Float64}, sdim::Ref{BlasInt},
wr::Vector{Float64}, wi::Vector{Float64}, ldvs::Ref{BlasInt}, vs::Matrix{Float64},
work::Vector{Float64}, lwork::Ref{BlasInt}, bwork::Vector{Int64},
eigen_values::Vector{Complex{Float64}}, info::Ref{BlasInt})
n = Ref{BlasInt}(size(A,1))
RldA = Ref{BlasInt}(max(1,stride(A,2)))
Rsort = Ref{UInt8}('N')
mycompare_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble}))
ccall((@blasfunc(dgees_), liblapack), Nothing,
(Ref{UInt8}, Ref{UInt8}, Ptr{Nothing},
Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{Float64},
Ptr{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}),
jobvs, Rsort, mycompare_c,
n, A, RldA,
sdim, wr, wi,
vs, ldvs,
work, lwork, bwork,
info)
chklapackerror(info[])
lwork = Ref{BlasInt}(real(work[1]))
work = Vector{Float64}(undef, lwork[])
new(jobvs, sdim, wr, wi, ldvs, vs, work, lwork, bwork, eigen_values, info)
end
end
function DgeesWS(A::StridedMatrix{Float64})
chkstride1(A)
n, = checksquare(A)
jobvs = Ref{UInt8}('V')
sdim = Ref{BlasInt}(0)
wr = Vector{Float64}(undef, n)
wi = Vector{Float64}(undef, n)
ldvs = Ref{BlasInt}(jobvs[] == UInt32('V') ? n : 1)
vs = Matrix{Float64}(undef, ldvs[], n)
work = Vector{Float64}(undef, 1)
lwork = Ref{BlasInt}(-1)
bwork = Vector{Int64}(undef, n)
eigen_values = Vector{Complex{Float64}}(undef, n)
info = Ref{BlasInt}(0)
DgeesWS(jobvs, A, sdim, wr, wi, ldvs, vs, work, lwork, bwork, eigen_values, info)
end
function DgeesWS(n::Int64)
A = zeros(n,n)
DgeesWS(A)
end
function dgees!(ws::DgeesWS,A::StridedMatrix{Float64})
n = Ref{BlasInt}(size(A,1))
RldA = Ref{BlasInt}(max(1,stride(A,2)))
mycompare_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble}))
ccall((@blasfunc(dgees_), liblapack), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid},
Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{Float64}, Ptr{Float64},
Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt},
Ptr{BlasInt}),
ws.jobvs, 'S', mycompare_c,
n, A, RldA,
ws.sdim, ws.wr, ws.wi,
ws.vs, ws.ldvs,
ws.work, ws.lwork, ws.bwork,
ws.info)
copyto!(ws.eigen_values, complex.(ws.wr, ws.wi))
chklapackerror(ws.info[])
end
mutable struct DggesWS
alphar::Vector{Float64}
alphai::Vector{Float64}
beta::Vector{Float64}
lwork::BlasInt
work::Vector{Float64}
bwork::Vector{Int64}
sdim::BlasInt
function DggesWS(jobvsl::Ref{UInt8}, jobvsr::Ref{UInt8}, sort::Ref{UInt8}, A::StridedMatrix{Float64}, B::StridedMatrix{Float64})
chkstride1(A, B)
n, m = checksquare(A, B)
if n != m
throw(DimensionMismatch("Dimensions of A, ($n,$n), and B, ($m,$m), must match"))
end
n = BlasInt(size(A,1))
alphar = Vector{Float64}(undef, n)
alphai = Vector{Float64}(undef, n)
beta = Vector{Float64}(undef, n)
bwork = Vector{Int64}(undef, n)
ldvsl = BlasInt(1)
ldvsr = BlasInt(1)
sdim = BlasInt(0)
lwork = BlasInt(-1)
work = Vector{Float64}(undef, 1)
sdim = BlasInt(0)
info = BlasInt(0)
mycompare_g_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cdouble}))
ccall((@blasfunc(dgges_), liblapack), Nothing,
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ptr{Nothing},
Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ptr{Float64},
Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{Int64},
Ref{BlasInt}),
jobvsl, jobvsr, sort, mycompare_g_c,
n, A, max(1,stride(A, 2)), B,
max(1,stride(B, 2)), sdim, alphar, alphai,
beta, C_NULL, ldvsl, C_NULL,
ldvsr, work, lwork, bwork,
info)
chklapackerror(info)
lwork = BlasInt(real(work[1]))
work = Vector{Float64}(undef, lwork)
new(alphar,alphai,beta,lwork,work,bwork,sdim)
end
end
function DggesWS(A::StridedMatrix{Float64}, B::StridedMatrix{Float64})
DggesWS(Ref{UInt8}('N'), Ref{UInt8}('N'), Ref{UInt8}('N'), A, B)
end
function dgges!(jobvsl::Char, jobvsr::Char, A::StridedMatrix{Float64}, B::StridedMatrix{Float64},
vsl::Matrix{Float64}, vsr::Matrix{Float64}, eigval::Array{ComplexF64,1},
ws::DggesWS)
n = size(A,1)
ldvsl = jobvsl == 'V' ? n : 1
ldvsr = jobvsr == 'V' ? n : 1
sort = 'S'
sdim = Ref{BlasInt}(0)
info = Ref{BlasInt}(0)
mycompare_g_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cdouble}))
ccall((@blasfunc(dgges_), liblapack), Nothing,
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ptr{Nothing},
Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{Float64}, Ptr{Float64},
Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{Int64},
Ref{Int64}),
jobvsl, jobvsr, sort, mycompare_g_c,
n, A, max(1,stride(A, 2)), B,
max(1,stride(B, 2)), sdim, ws.alphar, ws.alphai,
ws.beta, vsl, ldvsl, vsr,
ldvsr, ws.work, ws.lwork, ws.bwork,
info)
ws.sdim = sdim[]
if info[] > 0
throw(DggesException(info[]))
end
for i in 1:n
eigval[i] = complex(ws.alphar[i],ws.alphai[i])/ws.beta[i]
end
end
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment