Commit 730015bc authored by Frédéric Karamé
Browse files

First changes to julia version 1+

parent 968c6594
module ExtendedMul module ExtendedMul
import Base.LinAlg: A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt import LinearAlgebra: A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt
import Base.BLAS: gemm! import Base.BLAS: gemm!
import Base.LinAlg: BlasInt, BlasFloat import LinearAlgebra: BlasInt, BlasFloat
import Base.LinAlg.BLAS: @blasfunc, libblas import LinearAlgebra.BLAS: @blasfunc, libblas
export A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_B! export A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_B!
function A_mul_B!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64}, function A_mul_B!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64) offset_b::Int64, nb::Int64)
gemm!('N', 'N', 1.0, Ref(a, offset_a), ma, na, Ref(b, offset_b), gemm!('N', 'N', 1.0, Ref(a, offset_a), ma, na, Ref(b, offset_b),
nb, 0.0, Ref(c, offset_c)) nb, 0.0, Ref(c, offset_c))
end end
function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::SubArray{Float64,2,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},UnitRange{Int64}},true}, offset_a::Int64, ma::Int64, na::Int64, b::Array{Float64,1}, offset_b::Int64, nb::Int64) function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::SubArray{Float64,2,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},UnitRange{Int64}},true},
offset_a::Int64, ma::Int64, na::Int64, b::Array{Float64,1}, offset_b::Int64, nb::Int64)
if offset_a != 1 if offset_a != 1
throw(DimensionMismatch("offset_a must be 1")) throw(DimensionMismatch("offset_a must be 1"))
end end
...@@ -25,93 +25,94 @@ function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::SubArray{Float64,2,Ar ...@@ -25,93 +25,94 @@ function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::SubArray{Float64,2,Ar
ref_b = Ref(b, offset_b) ref_b = Ref(b, offset_b)
ref_c = Ref(c, offset_c) ref_c = Ref(c, offset_c)
lda = max(1,size(a.parent,1)) lda = max(1,size(a.parent,1))
ccall((@blasfunc(dgemm_), libblas), Void, ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}), Ref{BlasInt}),
'N', 'N', ma, nb, 'N', 'N', ma, nb,
na, 1.0, ref_a, lda, na, 1.0, ref_a, lda,
ref_b, na, 0.0, ref_c, ref_b, na, 0.0, ref_c,
ma) ma)
end end
function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::Array{Float64,1}, offset_a::Int64, ma::Int64, na::Int64, b::SubArray{Float64,2,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},UnitRange{Int64}},true}, offset_b::Int64, nb::Int64) function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::Array{Float64,1}, offset_a::Int64, ma::Int64, na::Int64,
b::SubArray{Float64,2,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},UnitRange{Int64}},true},
offset_b::Int64, nb::Int64)
if offset_b != 1 if offset_b != 1
throw(DimensionMismatch("offset_a must be 1")) throw(DimensionMismatch("offset_a must be 1"))
end end
ccall((@blasfunc(dgemm_), libblas), Void, ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}), Ref{BlasInt}),
'N', 'N', ma, nb, 'N', 'N', ma, nb,
na, 1.0, Ref(a, offset_a), max(1,ma), na, 1.0, Ref(a, offset_a), max(1,ma),
Ref(b, offset_b), max(1,size(b.parent,1)), 0.0, Ref(c, offset_c), Ref(b, offset_b), max(1,size(b.parent,1)), 0.0, Ref(c, offset_c),
max(1,ma)) max(1,ma))
end end
function At_mul_B!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64}, function At_mul_B!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64) offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Void, ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}), Ref{BlasInt}),
'T', 'N', ma, nb, 'T', 'N', ma, nb,
na, 1.0, Ref(a, offset_a), max(1,na), na, 1.0, Ref(a, offset_a), max(1,na),
Ref(b, offset_b), max(1,na), 0.0, Ref(c, offset_c), Ref(b, offset_b), max(1,na), 0.0, Ref(c, offset_c),
max(1,ma)) max(1,ma))
end end
function A_mul_Bt!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64}, function A_mul_Bt!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64) offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Void, ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}), Ref{BlasInt}),
'N', 'T', ma, nb, 'N', 'T', ma, nb,
na, 1.0, Ref(a, offset_a), max(1,ma), na, 1.0, Ref(a, offset_a), max(1,ma),
Ref(b, offset_b), max(1,nb), 0.0, Ref(c, offset_c), Ref(b, offset_b), max(1,nb), 0.0, Ref(c, offset_c),
max(1,ma)) max(1,ma))
end end
function At_mul_Bt!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64}, function At_mul_Bt!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64}, offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64) offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Cvoid,
ccall((@blasfunc(dgemm_), libblas), Void,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}), Ref{BlasInt}),
'T', 'T', ma, nb, 'T', 'T', ma, nb,
na, 1.0, Ref(a, offset_a), max(1,na), na, 1.0, Ref(a, offset_a), max(1,na),
Ref(b, offset_b), max(1,nb), 0.0, Ref(c, offset_c), Ref(b, offset_b), max(1,nb), 0.0, Ref(c, offset_c),
max(1,ma)) max(1,ma))
end end
function gemm!(ta::Char, tb::Char, alpha::Float64, a::Union{Ref{Float64},VecOrMat{Float64}}, function gemm!(ta::Char, tb::Char, alpha::Float64, a::Union{Ref{Float64},VecOrMat{Float64}},
ma::Int64, na::Int64, b::Union{Ref{Float64},VecOrMat{Float64}}, nb::Int64, ma::Int64, na::Int64, b::Union{Ref{Float64},VecOrMat{Float64}}, nb::Int64,
beta::Float64, c::Union{Ref{Float64},VecOrMat{Float64}}) beta::Float64, c::Union{Ref{Float64},VecOrMat{Float64}})
ccall((@blasfunc(dgemm_), libblas), Void, ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}), Ref{BlasInt}),
ta, tb, ma, nb, ta, tb, ma, nb,
na, alpha, a, max(1,ma), na, alpha, a, max(1,ma),
b, max(1,na), beta, c, b, max(1,na), beta, c,
max(1,ma)) max(1,ma))
end end
function gemm_t!(ta::Char, tb::Char, alpha::Float64, a::Union{Ref{Float64},VecOrMat{Float64}}, function gemm_t!(ta::Char, tb::Char, alpha::Float64, a::Union{Ref{Float64},VecOrMat{Float64}},
ma::Int64, na::Int64, b::Union{Ref{Float64},VecOrMat{Float64}}, nb::Int64, ma::Int64, na::Int64, b::Union{Ref{Float64},VecOrMat{Float64}}, nb::Int64,
beta::Float64, c::Union{Ref{Float64},VecOrMat{Float64}}) beta::Float64, c::Union{Ref{Float64},VecOrMat{Float64}})
end end
end end
...@@ -4,7 +4,7 @@ using Base.Test ...@@ -4,7 +4,7 @@ using Base.Test
using ExtendedMul using ExtendedMul
import Base.convert import Base.convert
export a_mul_kron_b!, a_mul_b_kron_c!, kron_at_kron_b_mul_c!, a_mul_b_kron_c_d!, at_mul_b_kron_c!, a_mul_b_kron_ct! export a_mul_kron_b!, a_mul_b_kron_c!, kron_at_kron_b_mul_c!, a_mul_b_kron_c_d!, at_mul_b_kron_c!, a_mul_b_kron_ct!
import Base.LinAlg: A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt import LinearAlgebra: A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt
import Base.BLAS: gemm! import Base.BLAS: gemm!
""" """
...@@ -44,7 +44,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix, ...@@ -44,7 +44,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix,
end end
copy!(c,v2) copy!(c,v2)
end end
function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix, order::Int64) function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix, order::Int64)
ma, na = size(a) ma, na = size(a)
mb, nb = size(b) mb, nb = size(b)
...@@ -55,7 +55,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix, ...@@ -55,7 +55,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix,
mc == ma || throw(DimensionMismatch("The number of rows of c, $mc, doesn't match the number of rows of a, $ma")) mc == ma || throw(DimensionMismatch("The number of rows of c, $mc, doesn't match the number of rows of a, $ma"))
nc == nborder || throw(DimensionMismatch("The number of columns of c, $nc, doesn't match the number of columns of b, $nb, times order = $order")) nc == nborder || throw(DimensionMismatch("The number of columns of c, $nc, doesn't match the number of columns of b, $nb, times order = $order"))
mb == nb || throw(DimensionMismatch("B must be a square matrix")) mb == nb || throw(DimensionMismatch("B must be a square matrix"))
avec = vec(a) avec = vec(a)
cvec = vec(c) cvec = vec(c)
for q=0:order-1 for q=0:order-1
...@@ -67,7 +67,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix, ...@@ -67,7 +67,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix,
end end
end end
end end
""" """
a_mul_b_kron_c!(d::AbstractVecOrMat, a::AbstractVecOrMat, b::AbstractMatrix, c::AbstractMatrix, order::Int64) a_mul_b_kron_c!(d::AbstractVecOrMat, a::AbstractVecOrMat, b::AbstractMatrix, c::AbstractMatrix, order::Int64)
...@@ -97,7 +97,7 @@ end ...@@ -97,7 +97,7 @@ end
""" """
function kron_at_kron_b_mul_c!(d::AbstractVector, offset_c::Int64, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, offset_cc::Int64, work1::AbstractVector, work2::AbstractVector) function kron_at_kron_b_mul_c!(d::AbstractVector, offset_c::Int64, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, offset_cc::Int64, work1::AbstractVector, work2::AbstractVector)
computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using work vectors work1 and work2 computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using work vectors work1 and work2
""" """
function kron_at_kron_b_mul_c!(d::AbstractVector, offset_d::Int64, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, offset_c::Int64, work1::AbstractVector, work2::AbstractVector, offset_w::Int64) function kron_at_kron_b_mul_c!(d::AbstractVector, offset_d::Int64, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, offset_c::Int64, work1::AbstractVector, work2::AbstractVector, offset_w::Int64)
mb,nb = size(b) mb,nb = size(b)
if order == 0 if order == 0
...@@ -124,7 +124,7 @@ end ...@@ -124,7 +124,7 @@ end
""" """
function kron_at_kron_b_mul_c!(d::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, work1::AbstractVector, work2::AbstractVector) function kron_at_kron_b_mul_c!(d::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, work1::AbstractVector, work2::AbstractVector)
computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using work vectors work1 and work2 computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using work vectors work1 and work2
""" """
function kron_at_kron_b_mul_c!(d::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, work1::AbstractVector, work2::AbstractVector) function kron_at_kron_b_mul_c!(d::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, work1::AbstractVector, work2::AbstractVector)
mb,nb = size(b) mb,nb = size(b)
if order == 0 if order == 0
...@@ -212,7 +212,7 @@ function at_mul_b_kron_c!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatri ...@@ -212,7 +212,7 @@ function at_mul_b_kron_c!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatri
At_mul_B!(vec(d), 1, a, 1, na, ma, work1, 1, nc^order) At_mul_B!(vec(d), 1, a, 1, na, ma, work1, 1, nc^order)
end end
end end
function a_mul_b_kron_ct!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix, c::AbstractMatrix, order::Int64, work1::AbstractVector, work2::AbstractVector) function a_mul_b_kron_ct!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix, c::AbstractMatrix, order::Int64, work1::AbstractVector, work2::AbstractVector)
ma, na = size(a) ma, na = size(a)
mb, nb = size(b) mb, nb = size(b)
...@@ -225,7 +225,7 @@ function a_mul_b_kron_ct!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatri ...@@ -225,7 +225,7 @@ function a_mul_b_kron_ct!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatri
At_mul_B!(vec(d), 1, a, 1, ma, na, work1, 1, nc^order) At_mul_B!(vec(d), 1, a, 1, ma, na, work1, 1, nc^order)
end end
end end
function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractVector,work::AbstractVector) function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractVector,work::AbstractVector)
ma, na = size(a) ma, na = size(a)
order = length(b) order = length(b)
...@@ -241,7 +241,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractVector,w ...@@ -241,7 +241,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractVector,w
mc == ma || throw(DimensionMismatch("The number of rows of c, $mc, doesn't match the number of rows of a, $ma")) mc == ma || throw(DimensionMismatch("The number of rows of c, $mc, doesn't match the number of rows of a, $ma"))
nc == nborder || throw(DimensionMismatch("The number of columns of c, $nc, doesn't match the number of columns of matrices in b, $nborder")) nc == nborder || throw(DimensionMismatch("The number of columns of c, $nc, doesn't match the number of columns of matrices in b, $nborder"))
mborder <= nborder || throw(DimensionMismatch("the product of the number of rows of the b matrices needs to be smaller or equal to the product of the number of columns. Otherwise, you need to call a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::Vector{AbstractMatrix},work::AbstractVector)")) mborder <= nborder || throw(DimensionMismatch("the product of the number of rows of the b matrices needs to be smaller or equal to the product of the number of columns. Otherwise, you need to call a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::Vector{AbstractMatrix},work::AbstractVector)"))
copy!(work,a) copy!(work,a)
mb, nb = size(b[1]) mb, nb = size(b[1])
vwork = view(work,1:ma*na) vwork = view(work,1:ma*na)
...@@ -262,7 +262,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractVector,w ...@@ -262,7 +262,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractVector,w
end end
end end
end end
""" """
a_mul_b_kron_c_d!(d::AbstractVecOrMat, a::AbstractVecOrMat, b::AbstractMatrix, c::AbstractMatrix, order::Int64) a_mul_b_kron_c_d!(d::AbstractVecOrMat, a::AbstractVecOrMat, b::AbstractMatrix, c::AbstractMatrix, order::Int64)
...@@ -340,7 +340,7 @@ function kron_mul_elem_t!(c::Vector, offset_c::Int64, a::AbstractMatrix, b::Vect ...@@ -340,7 +340,7 @@ function kron_mul_elem_t!(c::Vector, offset_c::Int64, a::AbstractMatrix, b::Vect
m, n = size(a) m, n = size(a)
length(b) >= m*p*q || throw(DimensionMismatch("The dimension of vector b, $(length(b)) doesn't correspond to order, ($p, $q) and the dimensions of the matrix, $(size(a))")) length(b) >= m*p*q || throw(DimensionMismatch("The dimension of vector b, $(length(b)) doesn't correspond to order, ($p, $q) and the dimensions of the matrix, $(size(a))"))
length(c) >= n*p*q || throw(DimensionMismatch("The dimension of the vector c, $(length(c)) doesn't correspond to order, ($p, $q) and the dimensions of the matrix, $(size(a))")) length(c) >= n*p*q || throw(DimensionMismatch("The dimension of the vector c, $(length(c)) doesn't correspond to order, ($p, $q) and the dimensions of the matrix, $(size(a))"))
begin begin
if p == 1 && q == 1 if p == 1 && q == 1
# a'*b # a'*b
......
module LinSolveAlgo module LinSolveAlgo
import Base.LinAlg.BlasInt import LinearAlgebra.BlasInt
import Base.LinAlg.BLAS.@blasfunc import LinearAlgebra.BLAS.@blasfunc
import Base.LinAlg.BLAS.libblas import LinearAlgebra.BLAS.libblas
import Base.LinAlg.LAPACK: liblapack, chklapackerror import LinearAlgebra.LAPACK: liblapack, chklapackerror
export LinSolveWS, linsolve_core!, linsolve_core_no_lu!, lu! export LinSolveWS, linsolve_core!, linsolve_core_no_lu!, lu!
struct LinSolveWS struct LinSolveWS
lu::Matrix{Float64} lu::Matrix{Float64}
ipiv::Vector{BlasInt} ipiv::Vector{BlasInt}
function LinSolveWS(n::Int64)
function LinSolveWS(n) lu = zeros(Float64,n,n)
lu = Matrix{Float64}(n,n) ipiv = zeros(BlasInt,n)
ipiv = Vector{BlasInt}(n)
new(lu,ipiv) new(lu,ipiv)
end end
end end
...@@ -27,7 +26,7 @@ function linsolve_core_no_lu!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{ ...@@ -27,7 +26,7 @@ function linsolve_core_no_lu!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{
ldb = Ref{BlasInt}(max(1,stride(b,2))) ldb = Ref{BlasInt}(max(1,stride(b,2)))
info = Ref{BlasInt}(0) info = Ref{BlasInt}(0)
ccall((@blasfunc(dgetrs_), liblapack), Void, ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt}, (Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}), Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info) trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info)
...@@ -47,7 +46,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6 ...@@ -47,7 +46,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
info = Ref{BlasInt}(0) info = Ref{BlasInt}(0)
lu!(ws.lu,a,ws.ipiv) lu!(ws.lu,a,ws.ipiv)
ccall((@blasfunc(dgetrs_), liblapack), Void, ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt}, (Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}), Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info) trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info)
...@@ -68,7 +67,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6 ...@@ -68,7 +67,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
lu!(ws.lu,a,ws.ipiv) lu!(ws.lu,a,ws.ipiv)
nhrs = Ref{BlasInt}(size(b,2)) nhrs = Ref{BlasInt}(size(b,2))
ccall((@blasfunc(dgetrs_), liblapack), Void, ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt}, (Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}), Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info) trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info)
...@@ -77,7 +76,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6 ...@@ -77,7 +76,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
chklapackerror(info[]) chklapackerror(info[])
end end
nhrs = Ref{BlasInt}(size(c,2)) nhrs = Ref{BlasInt}(size(c,2))
ccall((@blasfunc(dgetrs_), liblapack), Void, ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt}, (Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}), Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,c,ldc,info) trans,n,nhrs,ws.lu,lda,ws.ipiv,c,ldc,info)
...@@ -99,7 +98,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6 ...@@ -99,7 +98,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
info = Ref{BlasInt}(0) info = Ref{BlasInt}(0)
lu!(ws.lu,a,ws.ipiv) lu!(ws.lu,a,ws.ipiv)
ccall((@blasfunc(dgetrs_), liblapack), Void, ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt}, (Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}), Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info) trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info)
...@@ -107,7 +106,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6 ...@@ -107,7 +106,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
println("dgetrs ",info[]) println("dgetrs ",info[])
chklapackerror(info[]) chklapackerror(info[])
end end
ccall((@blasfunc(dgetrs_), liblapack), Void, ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt}, (Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}), Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,c,ldc,info) trans,n,nhrs,ws.lu,lda,ws.ipiv,c,ldc,info)
...@@ -115,7 +114,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6 ...@@ -115,7 +114,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
println("dgetrs ",info[]) println("dgetrs ",info[])
chklapackerror(info[]) chklapackerror(info[])
end end
ccall((@blasfunc(dgetrs_), liblapack), Void, ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt}, (Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}), Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,d,ldd,info) trans,n,nhrs,ws.lu,lda,ws.ipiv,d,ldd,info)
...@@ -132,7 +131,7 @@ function lu!(lu,a,ipiv) ...@@ -132,7 +131,7 @@ function lu!(lu,a,ipiv)
n = Ref{BlasInt}(nn) n = Ref{BlasInt}(nn)
lda = Ref{BlasInt}(max(1,stride(a,2))) lda = Ref{BlasInt}(max(1,stride(a,2)))
info = Ref{BlasInt}(0) info = Ref{BlasInt}(0)
ccall((@blasfunc(dgetrf_), liblapack), Void, ccall((@blasfunc(dgetrf_), liblapack), Cvoid,
(Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt}, (Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ref{BlasInt}), Ptr{BlasInt},Ref{BlasInt}),
m,n,lu,lda,ipiv,info) m,n,lu,lda,ipiv,info)
......
...@@ -7,18 +7,18 @@ import Base.getindex ...@@ -7,18 +7,18 @@ import Base.getindex
import Base.setindex! import Base.setindex!
import Base.copy import Base.copy
import Base: A_mul_B!, At_mul_B!, A_mul_Bt!, A_ldiv_B! import Base: A_mul_B!, At_mul_B!, A_mul_Bt!, A_ldiv_B!
import Base.LinAlg.BlasInt import LinearAlgebra.BlasInt
import Base.LinAlg.BLAS.@blasfunc import LinearAlgebra.BLAS.@blasfunc
import Base.LinAlg.BLAS.libblas import LinearAlgebra.BLAS.libblas
export QuasiUpperTriangular, I_plus_rA_ldiv_B!, I_plus_rA_plus_sB_ldiv_C!, A_rdiv_B!, A_rdiv_Bt! export QuasiUpperTriangular, I_plus_rA_ldiv_B!, I_plus_rA_plus_sB_ldiv_C!, A_rdiv_B!, A_rdiv_Bt!
struct QuasiUpperTriangular{T<:Real,S<:AbstractMatrix} <: AbstractMatrix{T} struct QuasiUpperTriangular{T<:Real,S<:AbstractMatrix} <: AbstractMatrix{T}
data::S data::S
end end
QuasiUpperTriangular(A::QuasiUpperTriangular) = A QuasiUpperTriangular(A::QuasiUpperTriangular) = A
function QuasiUpperTriangular(A::AbstractMatrix) function QuasiUpperTriangular(A::AbstractMatrix)
Base.LinAlg.checksquare(A) LinearAlgebra.checksquare(A)
return QuasiUpperTriangular{eltype(A), typeof(A)}(A) return QuasiUpperTriangular{eltype(A), typeof(A)}(A)
end end
...@@ -197,7 +197,7 @@ function A_mul_B!(c::AbstractVecOrMat, alpha::Float64, a::QuasiUpperTriangular, ...@@ -197,7 +197,7 @@ function A_mul_B!(c::AbstractVecOrMat, alpha::Float64, a::QuasiUpperTriangular,
Ref{UInt8}('L'), Ref{UInt8}('U'), Ref{UInt8}('N'), Ref{UInt8}('N'), Ref{BlasInt}(nr), Ref{BlasInt}(nc), Ref{UInt8}('L'), Ref{UInt8}('U'), Ref{UInt8}('N'), Ref{UInt8}('N'), Ref{BlasInt}(nr), Ref{BlasInt}(nc),
Ref{Float64}(alpha), a.data, Ref{BlasInt}(nr), c, Ref{BlasInt}(nr)) Ref{Float64}(alpha), a.data, Ref{BlasInt}(nr), c, Ref{BlasInt}(nr))
b1 = reshape(b,nr,nc) b1 = reshape(b,nr,nc)
c1 = reshape(c,nr,nc) c1 = reshape(c,nr,nc)
@inbounds for i= 2:m @inbounds for i= 2:m
x = a[i,i-1] x = a[i,i-1]
...@@ -432,7 +432,7 @@ function A_mul_B!(c::VecOrMat{Float64}, offset_c::Int64, a::QuasiUpperTriangular ...@@ -432,7 +432,7 @@ function A_mul_B!(c::VecOrMat{Float64}, offset_c::Int64, a::QuasiUpperTriangular
@inbounds for i = 2:ma @inbounds for i = 2:ma
x = a[i,i-1] x = a[i,i-1]
indb = offset_b indb = offset_b
indc = offset_c + 1 indc = offset_c + 1
@simd for j = 1:nb @simd for j = 1:nb
c[indc] += x*b[indb] c[indc] += x*b[indb]
...@@ -452,7 +452,7 @@ function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, ...@@ -452,7 +452,7 @@ function A_mul_B!(c::Array{Float64,1}, offset_c::Int64,
Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}), Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}),
Ref{UInt8}('L'), Ref{UInt8}('U'),