Commit 730015bc authored by Frédéric Karamé's avatar Frédéric Karamé

First changes to julia version 1+

parent 968c6594
module ExtendedMul
import Base.LinAlg: A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt
import LinearAlgebra: A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt
import Base.BLAS: gemm!
import Base.LinAlg: BlasInt, BlasFloat
import Base.LinAlg.BLAS: @blasfunc, libblas
import LinearAlgebra: BlasInt, BlasFloat
import LinearAlgebra.BLAS: @blasfunc, libblas
export A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_B!
function A_mul_B!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64)
gemm!('N', 'N', 1.0, Ref(a, offset_a), ma, na, Ref(b, offset_b),
nb, 0.0, Ref(c, offset_c))
end
function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::SubArray{Float64,2,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},UnitRange{Int64}},true}, offset_a::Int64, ma::Int64, na::Int64, b::Array{Float64,1}, offset_b::Int64, nb::Int64)
function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::SubArray{Float64,2,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},UnitRange{Int64}},true},
offset_a::Int64, ma::Int64, na::Int64, b::Array{Float64,1}, offset_b::Int64, nb::Int64)
if offset_a != 1
throw(DimensionMismatch("offset_a must be 1"))
end
......@@ -25,93 +25,94 @@ function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::SubArray{Float64,2,Ar
ref_b = Ref(b, offset_b)
ref_c = Ref(c, offset_c)
lda = max(1,size(a.parent,1))
ccall((@blasfunc(dgemm_), libblas), Void,
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}),
'N', 'N', ma, nb,
na, 1.0, ref_a, lda,
ref_b, na, 0.0, ref_c,
ma)
'N', 'N', ma, nb,
na, 1.0, ref_a, lda,
ref_b, na, 0.0, ref_c,
ma)
end
function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::Array{Float64,1}, offset_a::Int64, ma::Int64, na::Int64, b::SubArray{Float64,2,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},UnitRange{Int64}},true}, offset_b::Int64, nb::Int64)
function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::Array{Float64,1}, offset_a::Int64, ma::Int64, na::Int64,
b::SubArray{Float64,2,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},UnitRange{Int64}},true},
offset_b::Int64, nb::Int64)
if offset_b != 1
throw(DimensionMismatch("offset_a must be 1"))
end
ccall((@blasfunc(dgemm_), libblas), Void,
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}),
'N', 'N', ma, nb,
na, 1.0, Ref(a, offset_a), max(1,ma),
Ref(b, offset_b), max(1,size(b.parent,1)), 0.0, Ref(c, offset_c),
max(1,ma))
'N', 'N', ma, nb,
na, 1.0, Ref(a, offset_a), max(1,ma),
Ref(b, offset_b), max(1,size(b.parent,1)), 0.0, Ref(c, offset_c),
max(1,ma))
end
function At_mul_B!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Void,
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}),
'T', 'N', ma, nb,
na, 1.0, Ref(a, offset_a), max(1,na),
Ref(b, offset_b), max(1,na), 0.0, Ref(c, offset_c),
max(1,ma))
'T', 'N', ma, nb,
na, 1.0, Ref(a, offset_a), max(1,na),
Ref(b, offset_b), max(1,na), 0.0, Ref(c, offset_c),
max(1,ma))
end
function A_mul_Bt!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Void,
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}),
'N', 'T', ma, nb,
na, 1.0, Ref(a, offset_a), max(1,ma),
Ref(b, offset_b), max(1,nb), 0.0, Ref(c, offset_c),
max(1,ma))
'N', 'T', ma, nb,
na, 1.0, Ref(a, offset_a), max(1,ma),
Ref(b, offset_b), max(1,nb), 0.0, Ref(c, offset_c),
max(1,ma))
end
function At_mul_Bt!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Void,
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}),
'T', 'T', ma, nb,
na, 1.0, Ref(a, offset_a), max(1,na),
Ref(b, offset_b), max(1,nb), 0.0, Ref(c, offset_c),
max(1,ma))
'T', 'T', ma, nb,
na, 1.0, Ref(a, offset_a), max(1,na),
Ref(b, offset_b), max(1,nb), 0.0, Ref(c, offset_c),
max(1,ma))
end
function gemm!(ta::Char, tb::Char, alpha::Float64, a::Union{Ref{Float64},VecOrMat{Float64}},
ma::Int64, na::Int64, b::Union{Ref{Float64},VecOrMat{Float64}}, nb::Int64,
beta::Float64, c::Union{Ref{Float64},VecOrMat{Float64}})
ccall((@blasfunc(dgemm_), libblas), Void,
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
Ref{BlasInt}),
ta, tb, ma, nb,
na, alpha, a, max(1,ma),
b, max(1,na), beta, c,
max(1,ma))
ta, tb, ma, nb,
na, alpha, a, max(1,ma),
b, max(1,na), beta, c,
max(1,ma))
end
function gemm_t!(ta::Char, tb::Char, alpha::Float64, a::Union{Ref{Float64},VecOrMat{Float64}},
ma::Int64, na::Int64, b::Union{Ref{Float64},VecOrMat{Float64}}, nb::Int64,
beta::Float64, c::Union{Ref{Float64},VecOrMat{Float64}})
ma::Int64, na::Int64, b::Union{Ref{Float64},VecOrMat{Float64}}, nb::Int64,
beta::Float64, c::Union{Ref{Float64},VecOrMat{Float64}})
end
end
......@@ -4,7 +4,7 @@ using Base.Test
using ExtendedMul
import Base.convert
export a_mul_kron_b!, a_mul_b_kron_c!, kron_at_kron_b_mul_c!, a_mul_b_kron_c_d!, at_mul_b_kron_c!, a_mul_b_kron_ct!
import Base.LinAlg: A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt
import LinearAlgebra: A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt
import Base.BLAS: gemm!
"""
......@@ -44,7 +44,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix,
end
copy!(c,v2)
end
function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix, order::Int64)
ma, na = size(a)
mb, nb = size(b)
......@@ -55,7 +55,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix,
mc == ma || throw(DimensionMismatch("The number of rows of c, $mc, doesn't match the number of rows of a, $ma"))
nc == nborder || throw(DimensionMismatch("The number of columns of c, $nc, doesn't match the number of columns of b, $nb, times order = $order"))
mb == nb || throw(DimensionMismatch("B must be a square matrix"))
avec = vec(a)
cvec = vec(c)
for q=0:order-1
......@@ -67,7 +67,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix,
end
end
end
"""
a_mul_b_kron_c!(d::AbstractVecOrMat, a::AbstractVecOrMat, b::AbstractMatrix, c::AbstractMatrix, order::Int64)
......@@ -97,7 +97,7 @@ end
"""
function kron_at_kron_b_mul_c!(d::AbstractVector, offset_c::Int64, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, offset_cc::Int64, work1::AbstractVector, work2::AbstractVector)
computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using work vectors work1 and work2
"""
"""
function kron_at_kron_b_mul_c!(d::AbstractVector, offset_d::Int64, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, offset_c::Int64, work1::AbstractVector, work2::AbstractVector, offset_w::Int64)
mb,nb = size(b)
if order == 0
......@@ -124,7 +124,7 @@ end
"""
function kron_at_kron_b_mul_c!(d::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, work1::AbstractVector, work2::AbstractVector)
computes d = (a^T ⊗ a^T ⊗ ... ⊗ a^T ⊗ b)c using work vectors work1 and work2
"""
"""
function kron_at_kron_b_mul_c!(d::AbstractVector, a::AbstractMatrix, order::Int64, b::AbstractMatrix, c::AbstractVector, work1::AbstractVector, work2::AbstractVector)
mb,nb = size(b)
if order == 0
......@@ -212,7 +212,7 @@ function at_mul_b_kron_c!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatri
At_mul_B!(vec(d), 1, a, 1, na, ma, work1, 1, nc^order)
end
end
function a_mul_b_kron_ct!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix, c::AbstractMatrix, order::Int64, work1::AbstractVector, work2::AbstractVector)
ma, na = size(a)
mb, nb = size(b)
......@@ -225,7 +225,7 @@ function a_mul_b_kron_ct!(d::AbstractMatrix, a::AbstractMatrix, b::AbstractMatri
At_mul_B!(vec(d), 1, a, 1, ma, na, work1, 1, nc^order)
end
end
function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractVector,work::AbstractVector)
ma, na = size(a)
order = length(b)
......@@ -241,7 +241,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractVector,w
mc == ma || throw(DimensionMismatch("The number of rows of c, $mc, doesn't match the number of rows of a, $ma"))
nc == nborder || throw(DimensionMismatch("The number of columns of c, $nc, doesn't match the number of columns of matrices in b, $nborder"))
mborder <= nborder || throw(DimensionMismatch("the product of the number of rows of the b matrices needs to be smaller or equal to the product of the number of columns. Otherwise, you need to call a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::Vector{AbstractMatrix},work::AbstractVector)"))
copy!(work,a)
mb, nb = size(b[1])
vwork = view(work,1:ma*na)
......@@ -262,7 +262,7 @@ function a_mul_kron_b!(c::AbstractMatrix, a::AbstractMatrix, b::AbstractVector,w
end
end
end
"""
a_mul_b_kron_c_d!(d::AbstractVecOrMat, a::AbstractVecOrMat, b::AbstractMatrix, c::AbstractMatrix, order::Int64)
......@@ -340,7 +340,7 @@ function kron_mul_elem_t!(c::Vector, offset_c::Int64, a::AbstractMatrix, b::Vect
m, n = size(a)
length(b) >= m*p*q || throw(DimensionMismatch("The dimension of vector b, $(length(b)) doesn't correspond to order, ($p, $q) and the dimensions of the matrix, $(size(a))"))
length(c) >= n*p*q || throw(DimensionMismatch("The dimension of the vector c, $(length(c)) doesn't correspond to order, ($p, $q) and the dimensions of the matrix, $(size(a))"))
begin
if p == 1 && q == 1
# a'*b
......
module LinSolveAlgo
import Base.LinAlg.BlasInt
import Base.LinAlg.BLAS.@blasfunc
import Base.LinAlg.BLAS.libblas
import Base.LinAlg.LAPACK: liblapack, chklapackerror
import LinearAlgebra.BlasInt
import LinearAlgebra.BLAS.@blasfunc
import LinearAlgebra.BLAS.libblas
import LinearAlgebra.LAPACK: liblapack, chklapackerror
export LinSolveWS, linsolve_core!, linsolve_core_no_lu!, lu!
struct LinSolveWS
lu::Matrix{Float64}
ipiv::Vector{BlasInt}
function LinSolveWS(n)
lu = Matrix{Float64}(n,n)
ipiv = Vector{BlasInt}(n)
function LinSolveWS(n::Int64)
lu = zeros(Float64,n,n)
ipiv = zeros(BlasInt,n)
new(lu,ipiv)
end
end
......@@ -27,7 +26,7 @@ function linsolve_core_no_lu!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{
ldb = Ref{BlasInt}(max(1,stride(b,2)))
info = Ref{BlasInt}(0)
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info)
......@@ -47,7 +46,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
info = Ref{BlasInt}(0)
lu!(ws.lu,a,ws.ipiv)
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info)
......@@ -68,7 +67,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
lu!(ws.lu,a,ws.ipiv)
nhrs = Ref{BlasInt}(size(b,2))
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info)
......@@ -77,7 +76,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
chklapackerror(info[])
end
nhrs = Ref{BlasInt}(size(c,2))
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,c,ldc,info)
......@@ -99,7 +98,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
info = Ref{BlasInt}(0)
lu!(ws.lu,a,ws.ipiv)
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info)
......@@ -107,7 +106,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
println("dgetrs ",info[])
chklapackerror(info[])
end
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,c,ldc,info)
......@@ -115,7 +114,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
println("dgetrs ",info[])
chklapackerror(info[])
end
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,d,ldd,info)
......@@ -132,7 +131,7 @@ function lu!(lu,a,ipiv)
n = Ref{BlasInt}(nn)
lda = Ref{BlasInt}(max(1,stride(a,2)))
info = Ref{BlasInt}(0)
ccall((@blasfunc(dgetrf_), liblapack), Void,
ccall((@blasfunc(dgetrf_), liblapack), Cvoid,
(Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ref{BlasInt}),
m,n,lu,lda,ipiv,info)
......
......@@ -7,18 +7,18 @@ import Base.getindex
import Base.setindex!
import Base.copy
import Base: A_mul_B!, At_mul_B!, A_mul_Bt!, A_ldiv_B!
import Base.LinAlg.BlasInt
import Base.LinAlg.BLAS.@blasfunc
import Base.LinAlg.BLAS.libblas
import LinearAlgebra.BlasInt
import LinearAlgebra.BLAS.@blasfunc
import LinearAlgebra.BLAS.libblas
export QuasiUpperTriangular, I_plus_rA_ldiv_B!, I_plus_rA_plus_sB_ldiv_C!, A_rdiv_B!, A_rdiv_Bt!
export QuasiUpperTriangular, I_plus_rA_ldiv_B!, I_plus_rA_plus_sB_ldiv_C!, A_rdiv_B!, A_rdiv_Bt!
struct QuasiUpperTriangular{T<:Real,S<:AbstractMatrix} <: AbstractMatrix{T}
data::S
end
QuasiUpperTriangular(A::QuasiUpperTriangular) = A
function QuasiUpperTriangular(A::AbstractMatrix)
Base.LinAlg.checksquare(A)
LinearAlgebra.checksquare(A)
return QuasiUpperTriangular{eltype(A), typeof(A)}(A)
end
......@@ -197,7 +197,7 @@ function A_mul_B!(c::AbstractVecOrMat, alpha::Float64, a::QuasiUpperTriangular,
Ref{UInt8}('L'), Ref{UInt8}('U'), Ref{UInt8}('N'), Ref{UInt8}('N'), Ref{BlasInt}(nr), Ref{BlasInt}(nc),
Ref{Float64}(alpha), a.data, Ref{BlasInt}(nr), c, Ref{BlasInt}(nr))
b1 = reshape(b,nr,nc)
b1 = reshape(b,nr,nc)
c1 = reshape(c,nr,nc)
@inbounds for i= 2:m
x = a[i,i-1]
......@@ -432,7 +432,7 @@ function A_mul_B!(c::VecOrMat{Float64}, offset_c::Int64, a::QuasiUpperTriangular
@inbounds for i = 2:ma
x = a[i,i-1]
indb = offset_b
indb = offset_b
indc = offset_c + 1
@simd for j = 1:nb
c[indc] += x*b[indb]
......@@ -452,7 +452,7 @@ function A_mul_B!(c::Array{Float64,1}, offset_c::Int64,
Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}),
Ref{UInt8}('L'), Ref{UInt8}('U'), Ref{UInt8}('N'), Ref{UInt8}('N'), Ref{BlasInt}(na), Ref{BlasInt}(nb),
Ref{Float64}(alpha), a.data, Ref{BlasInt}(ma), Ref(c, offset_c), Ref{BlasInt}(na))
@inbounds for i= 2:ma
x = a[i,i-1]
indb = offset_b
......@@ -483,7 +483,7 @@ function A_mul_B!(c::SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}}
offsetc_orig = offset_c
@inbounds for i= 2:ma
x = a[i,i-1]
indb = offset_b
indb = offset_b
indc = offset_c +1
@simd for j=1:nb
c[indc] += x*b[indb]
......@@ -506,7 +506,7 @@ function At_mul_B!(c::VecOrMat{Float64}, offset_c::Int64, a::QuasiUpperTriangula
Ref{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}),
Ref{UInt8}('L'), Ref{UInt8}('U'), Ref{UInt8}('T'), Ref{UInt8}('N'), Ref{BlasInt}(ma), Ref{BlasInt}(nb),
Ref{Float64}(alpha), a.data, Ref{BlasInt}(ma), Ref(c, offset_c), Ref{BlasInt}(ma))
@inbounds for i = 2:ma
x = a[i,i-1]
indb = offset_b + i - 1
......@@ -579,7 +579,7 @@ function A_ldiv_B!(a::QuasiUpperTriangular, b::AbstractMatrix)
end
end
function A_rdiv_B!(a::AbstractMatrix, b::QuasiUpperTriangular)
m, n = size(a)
nb, p = size(b)
......@@ -664,7 +664,7 @@ function I_plus_rA_ldiv_B!(r::Float64,a::QuasiUpperTriangular, b::AbstractVector
j = n
@inbounds while j > 0
if j == 1 || r*a.data[j,j-1] == 0
pivot = 1.0 + r*a.data[j,j]
pivot = 1.0 + r*a.data[j,j]
pivot == zero(a.data[j,j]) && throw(SingularException(j))
b[j] = pivot \ b[j]
xj = r*b[j]
......@@ -713,7 +713,7 @@ function I_plus_rA_plus_sB_ldiv_C!(r::Float64, s::Float64,a::QuasiUpperTriangula
end
j -= 1
else
a11 = 1.0 + r*a.data[j-1,j-1] + s*b.data[j-1,j-1]
a11 = 1.0 + r*a.data[j-1,j-1] + s*b.data[j-1,j-1]
a21 = r*a.data[j,j-1] + s*b.data[j,j-1]
a12 = r*a.data[j-1,j] + s*b.data[j-1,j]
a22 = 1.0 + r*a.data[j,j] + s*b.data[j,j]
......@@ -793,7 +793,7 @@ function I_plus_rA_plus_sB_ldiv_C!(r::Float64, s::Float64,a::QuasiUpperTriangula
end
j -= 1
else
a11 = 1.0 + r*a.data[j-1,j-1] + s*b.data[j-1,j-1]
a11 = 1.0 + r*a.data[j-1,j-1] + s*b.data[j-1,j-1]
a21 = r*a.data[j,j-1] + s*b.data[j,j-1]
a12 = r*a.data[j-1,j] + s*b.data[j-1,j]
a22 = 1.0 + r*a.data[j,j] + s*b.data[j,j]
......
module QrAlgo
import Base.LinAlg.BlasInt
import Base.LinAlg.BLAS.@blasfunc
import Base.LinAlg.BLAS.libblas
import Base.LinAlg.LAPACK: liblapack, chklapackerror
import LinearAlgebra.BlasInt
import LinearAlgebra.BLAS.@blasfunc
import LinearAlgebra.BLAS.libblas
import LinearAlgebra.LAPACK: liblapack, chklapackerror
export QrWS, dgeqrf_core!, dormrqf_core!
......@@ -59,7 +59,5 @@ function dormrqf_core!(ws::QrWS,side::Ref{UInt8},trans::Ref{UInt8},A::StridedMat
side,trans,m,n,k,A,RldA,ws.tau,C,RldC,ws.work,ws.lwork,ws.info)
chklapackerror(ws.info[])
end
end
end
This diff is collapsed.
......@@ -4,7 +4,7 @@ using LinearAlgebra
import ...DynLinAlg.LinSolveAlgo: LinSolveWS, linsolve_core!
import Base.LinAlg.BLAS: scal!, gemm!
import LinearAlgebra.BLAS: scal!, gemm!
export CyclicReductionWS, cyclic_reduction!, cyclic_reduction_check
......@@ -25,23 +25,23 @@ mutable struct CyclicReductionWS
m2_a2::SubArray{Float64}
info::Int64
function CyclicReductionWS(n)
function CyclicReductionWS(n::Int64)
linsolve_ws = LinSolveWS(n)
ahat1 = Matrix{Float64}(n,n)
a1copy = Matrix{Float64}(n,n)
m = Matrix{Float64}(2*n,2*n)
ahat1 = zeros(Float64,n,n)
a1copy = zeros(Float64,n,n)
m = zeros(Float64,2*n,2*n)
m00 = view(m,1:n,1:n)
m02 = view(m,1:n,n+(1:n))
m20 = view(m,n+(1:n),1:n)
m22 = view(m,n+(1:n),n+(1:n))
m1 = Matrix{Float64}(n,2*n)
m1 = zeros(Float64,n,2*n)
m1_a0 = view(m1,1:n,1:n)
m1_a2 = view(m1,1:n,n+(1:n))
m2 = Matrix{Float64}(2*n,n)
m2 = zeros(Float64,2*n,n)
m2_a0 = view(m2,1:n,1:n)
m2_a2 = view(m2,n+(1:n),1:n)
info = 0
new(linsolve_ws,ahat1,a1copy,m, m00, m02, m20, m22, m1, m1_a0, m1_a2, m2, m2_a0, m2_a2, info)
new(linsolve_ws,ahat1,a1copy,m, m00, m02, m20, m22, m1, m1_a0, m1_a2, m2, m2_a0, m2_a2, info)
end
end
......@@ -86,7 +86,7 @@ function cyclic_reduction!(x::Array{Float64},a0::Array{Float64},a1::Array{Float6
it = 0
@inbounds while it < max_it
# ws.m = [a0; a2]*(a1\[a0 a2])
copy!(ws.a1copy,a1)
copy!(ws.a1copy,a1)
linsolve_core!(ws.linsolve_ws,Ref{UInt8}('N'),ws.a1copy,ws.m1)
gemm!('N','N',-1.0,ws.m2,ws.m1,0.0,ws.m)
@simd for i in eachindex(a1)
......@@ -97,11 +97,7 @@ function cyclic_reduction!(x::Array{Float64},a0::Array{Float64},a1::Array{Float6
copy!(ws.m2_a0, ws.m00)
copy!(ws.m2_a2, ws.m22)
if any(isinf.(ws.m))
if norm(ws.m1_a0) < cvg_tol
ws.info = 2
else
ws.info = 1
end
ws.info norm(ws.m1_a0) < cvg_tol ? 2 : 1
fill!(x,NaN)
return
end
......@@ -116,26 +112,22 @@ function cyclic_reduction!(x::Array{Float64},a0::Array{Float64},a1::Array{Float6
it += 1
end
if it == max_it
if norm(ws.m1_a0) < cvg_tol
ws.info = 2
else
ws.info = 1
end
ws.info = norm(ws.m1_a0) < cvg_tol ? 2 : 1
fill!(x,NaN)
return
else
linsolve_core!(ws.linsolve_ws,Ref{UInt8}('N'),ws.ahat1,x)
@inbounds scal!(length(x),-1.0,x,1)
ws.info = 0
end
end
end
function cyclic_reduction_check(x::Array{Float64,2},a0::Array{Float64,2}, a1::Array{Float64,2}, a2::Array{Float64,2},cvg_tol::Float64)
res = a0 + a1*x + a2*x*x
if (sum(sum(abs.(res))) > cvg_tol)
if (sum(abs.(res)) > cvg_tol)
print("the norm of the residuals, ", res, ", compared to the tolerance criterion ",cvg_tol)
end
nothing
end
end
......@@ -8,69 +8,69 @@ import ...DynLinAlg.LinSolveAlgo: LinSolveWS, linsolve_core!, linsolve_core_no_l
import ..Solvers: ResultsPerturbationWs
import ..SolveEyePlusMinusAkronB: EyePlusAtKronBWS, generalized_sylvester_solver!
import Base.LinAlg.BLAS: gemm!
import LinearAlgebra.BLAS: gemm!
export FirstOrderSolverWS, first_order_solver
type FirstOrderSolverWS
jacobian_static::Matrix{Float64}
qr_ws::QrWS
solver_ws::Union{GsSolverWS,CyclicReductionWS}
ghx::StridedMatrix{Float64}
gx::Matrix{Float64}
hx::Matrix{Float64}
temp1::Matrix{Float64}
temp2::Matrix{Float64}
temp3::Matrix{Float64}
temp4::Matrix{Float64}
temp5::Matrix{Float64}
b10::Matrix{Float64}
b11::Matrix{Float64}
linsolve_static_ws::LinSolveWS
eye_plus_at_kron_b_ws::EyePlusAtKronBWS
function FirstOrderSolverWS(algo::String, jacobian::Matrix, m::Model)
if m.n_static > 0
jacobian_static = Matrix{Float64}(m.endo_nbr,m.n_static)
qr_ws = QrWS(jacobian_static)
else
jacobian_static = Matrix{Float64}(0,0)
qr_ws = QrWS(Matrix{Float64}(0,0))
end
if algo == "GS"
d = zeros(m.n_dyn,m.n_dyn)
e = zeros(m.n_dyn,m.n_dyn)
solver_ws = GsSolverWS(d,e,m.n_bkwrd+m.n_both)
elseif algo == "CR"
n = m.endo_nbr - m.n_static
solver_ws = CyclicReductionWS(n)
end
ghx = Matrix{Float64}(m.endo_nbr,m.n_bkwrd+m.n_both)
gx = Matrix{Float64}(m.n_fwrd+m.n_both,m.n_bkwrd+m.n_both)
hx = Matrix{Float64}(m.n_bkwrd+m.n_both,m.n_bkwrd+m.n_both)
temp1 = Matrix{Float64}(m.n_static,m.n_fwrd+m.n_both)
temp2 = Matrix{Float64}(m.n_static,m.n_bkwrd+m.n_both)
temp3 = Matrix{Float64}(m.n_static,m.n_bkwrd+m.n_both)
temp4 = Matrix{Float64}(m.endo_nbr - m.n_static,m.n_bkwrd+m.n_both)
temp5 = Matrix{Float64}(m.endo_nbr,max(m.current_exogenous_nbr,m.lagged_exogenous_nbr))
b10 = Matrix{Float64}(m.n_static,m.n_static)
b11 = Matrix{Float64}(m.n_static,length(m.p_current_ns))
linsolve_static_ws = LinSolveWS(m.n_static)
if m.serially_correlated_exogenous
eye_plus_at_kron_b_ws = EyePlusAtKronBWS(ma, mb, mc, 1)
else
eye_plus_at_kron_b_ws = EyePlusAtKronBWS(1, 1, 1, 1)
end
new(jacobian_static, qr_ws, solver_ws, ghx, gx, hx, temp1,
temp2, temp3, temp4, temp5, b10, b11, linsolve_static_ws,
eye_plus_at_kron_b_ws)
end
mutable struct FirstOrderSolverWS
jacobian_static::Matrix{Float64}
qr_ws::QrWS
solver_ws::Union{GsSolverWS,CyclicReductionWS}
ghx::StridedMatrix{Float64}
gx::Matrix{Float64}
hx::Matrix{Float64}
temp1::Matrix{Float64}