diff --git a/src/SchurAlgo.jl b/src/SchurAlgo.jl index 84118cd39e2b7aec2d791e9de7937b1f45b4c52d..a5155b8a30b4c3f31ce0f85a2056b5c7f8f5d5b7 100644 --- a/src/SchurAlgo.jl +++ b/src/SchurAlgo.jl @@ -14,11 +14,13 @@ export DgeesWs, dgees!, DggesWs, dgges! const criterium = 1+1e-6 +#= function mycompare(wr_, wi_)::Cint wr = unsafe_load(wr_) wi = unsafe_load(wi_) return convert(Cint, ((wr*wr + wi*wi) < criterium) ? 1 : 0) end +=# function mycompare(alphar_::Ptr{T}, alphai_::Ptr{T}, beta_::Ptr{T})::Cint where T alphar = unsafe_load(alphar_) @@ -47,7 +49,7 @@ mutable struct DgeesWs n = Ref{BlasInt}(size(A,1)) RldA = Ref{BlasInt}(max(1,stride(A,2))) Rsort = Ref{UInt8}('N') - mycompare_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble})) +# mycompare_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble})) ccall((@blasfunc(dgees_), liblapack), Nothing, (Ref{UInt8}, Ref{UInt8}, Ptr{Nothing}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, @@ -55,7 +57,7 @@ mutable struct DgeesWs Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), - jobvs, Rsort, mycompare_c, + jobvs, 'N', C_NULL, n, A, RldA, sdim, wr, wi, vs, ldvs, @@ -96,7 +98,38 @@ end function dgees!(ws::DgeesWs,A::StridedMatrix{Float64}) n = Ref{BlasInt}(size(A,1)) RldA = Ref{BlasInt}(max(1,stride(A,2))) - mycompare_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble})) + myfunc::Function = make_select_function(>=, 1.0) + ccall((@blasfunc(dgees_), liblapack), Cvoid, + (Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid}, + Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, + Ptr{BlasInt}, Ptr{Float64}, Ptr{Float64}, + Ptr{Float64}, Ref{BlasInt}, + Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, + Ptr{BlasInt}), + ws.jobvs, 'N', C_NULL, + n, A, RldA, + ws.sdim, ws.wr, ws.wi, + ws.vs, ws.ldvs, + ws.work, ws.lwork, ws.bwork, + ws.info) + copyto!(ws.eigen_values, complex.(ws.wr, ws.wi)) + chklapackerror(ws.info[]) +end + +function make_select_function(op, crit)::Function + mycompare = function(wr_, wi_) + wr = unsafe_load(wr_) + wi = unsafe_load(wi_) + return convert(Cint, op(wr*wr + wi*wi, crit) ? 1 : 0) + end + return mycompare +end + +function dgees!(ws::DgeesWs, A::StridedMatrix{Float64}, op, crit) + n = Ref{BlasInt}(size(A,1)) + RldA = Ref{BlasInt}(max(1,stride(A,2))) + myfunc::Function = make_select_function(op, crit) + mycompare_c = @cfunction($myfunc, Cint, (Ptr{Cdouble}, Ptr{Cdouble})) ccall((@blasfunc(dgees_), liblapack), Cvoid, (Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},