Commit 730015bc authored by Frédéric Karamé's avatar Frédéric Karamé
Browse files

First changes to julia version 1+

parent 968c6594
module ExtendedMul
import Base.LinAlg: A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt
import LinearAlgebra: A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt
import Base.BLAS: gemm!
import Base.LinAlg: BlasInt, BlasFloat
import Base.LinAlg.BLAS: @blasfunc, libblas
import LinearAlgebra: BlasInt, BlasFloat
import LinearAlgebra.BLAS: @blasfunc, libblas
export A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_B!
function A_mul_B!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64)
gemm!('N', 'N', 1.0, Ref(a, offset_a), ma, na, Ref(b, offset_b),
nb, 0.0, Ref(c, offset_c))
end
function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::SubArray{Float64,2,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},UnitRange{Int64}},true}, offset_a::Int64, ma::Int64, na::Int64, b::Array{Float64,1}, offset_b::Int64, nb::Int64)
function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::SubArray{Float64,2,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},UnitRange{Int64}},true},
offset_a::Int64, ma::Int64, na::Int64, b::Array{Float64,1}, offset_b::Int64, nb::Int64)
if offset_a != 1
throw(DimensionMismatch("offset_a must be 1"))
end
......@@ -25,7 +25,7 @@ function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::SubArray{Float64,2,Ar
ref_b = Ref(b, offset_b)
ref_c = Ref(c, offset_c)
lda = max(1,size(a.parent,1))
ccall((@blasfunc(dgemm_), libblas), Void,
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
......@@ -36,11 +36,13 @@ function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::SubArray{Float64,2,Ar
ma)
end
function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::Array{Float64,1}, offset_a::Int64, ma::Int64, na::Int64, b::SubArray{Float64,2,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},UnitRange{Int64}},true}, offset_b::Int64, nb::Int64)
function A_mul_B!(c::Array{Float64,1}, offset_c::Int64, a::Array{Float64,1}, offset_a::Int64, ma::Int64, na::Int64,
b::SubArray{Float64,2,Array{Float64,2},Tuple{Base.Slice{Base.OneTo{Int64}},UnitRange{Int64}},true},
offset_b::Int64, nb::Int64)
if offset_b != 1
throw(DimensionMismatch("offset_a must be 1"))
end
ccall((@blasfunc(dgemm_), libblas), Void,
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
......@@ -54,7 +56,7 @@ end
function At_mul_B!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Void,
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
......@@ -68,7 +70,7 @@ end
function A_mul_Bt!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Void,
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
......@@ -82,8 +84,7 @@ end
function At_mul_Bt!(c::VecOrMat{Float64}, offset_c::Int64, a::VecOrMat{Float64},
offset_a::Int64, ma::Int64, na::Int64, b::VecOrMat{Float64},
offset_b::Int64, nb::Int64)
ccall((@blasfunc(dgemm_), libblas), Void,
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
......@@ -97,7 +98,7 @@ end
function gemm!(ta::Char, tb::Char, alpha::Float64, a::Union{Ref{Float64},VecOrMat{Float64}},
ma::Int64, na::Int64, b::Union{Ref{Float64},VecOrMat{Float64}}, nb::Int64,
beta::Float64, c::Union{Ref{Float64},VecOrMat{Float64}})
ccall((@blasfunc(dgemm_), libblas), Void,
ccall((@blasfunc(dgemm_), libblas), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt},
Ref{BlasInt}, Ref{Float64}, Ptr{Float64}, Ref{BlasInt},
Ptr{Float64}, Ref{BlasInt}, Ref{Float64}, Ptr{Float64},
......
......@@ -4,7 +4,7 @@ using Base.Test
using ExtendedMul
import Base.convert
export a_mul_kron_b!, a_mul_b_kron_c!, kron_at_kron_b_mul_c!, a_mul_b_kron_c_d!, at_mul_b_kron_c!, a_mul_b_kron_ct!
import Base.LinAlg: A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt
import LinearAlgebra: A_mul_B!, At_mul_B!, A_mul_Bt!, At_mul_Bt
import Base.BLAS: gemm!
"""
......
module LinSolveAlgo
import Base.LinAlg.BlasInt
import Base.LinAlg.BLAS.@blasfunc
import Base.LinAlg.BLAS.libblas
import Base.LinAlg.LAPACK: liblapack, chklapackerror
import LinearAlgebra.BlasInt
import LinearAlgebra.BLAS.@blasfunc
import LinearAlgebra.BLAS.libblas
import LinearAlgebra.LAPACK: liblapack, chklapackerror
export LinSolveWS, linsolve_core!, linsolve_core_no_lu!, lu!
struct LinSolveWS
lu::Matrix{Float64}
ipiv::Vector{BlasInt}
function LinSolveWS(n)
lu = Matrix{Float64}(n,n)
ipiv = Vector{BlasInt}(n)
function LinSolveWS(n::Int64)
lu = zeros(Float64,n,n)
ipiv = zeros(BlasInt,n)
new(lu,ipiv)
end
end
......@@ -27,7 +26,7 @@ function linsolve_core_no_lu!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{
ldb = Ref{BlasInt}(max(1,stride(b,2)))
info = Ref{BlasInt}(0)
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info)
......@@ -47,7 +46,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
info = Ref{BlasInt}(0)
lu!(ws.lu,a,ws.ipiv)
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info)
......@@ -68,7 +67,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
lu!(ws.lu,a,ws.ipiv)
nhrs = Ref{BlasInt}(size(b,2))
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info)
......@@ -77,7 +76,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
chklapackerror(info[])
end
nhrs = Ref{BlasInt}(size(c,2))
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,c,ldc,info)
......@@ -99,7 +98,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
info = Ref{BlasInt}(0)
lu!(ws.lu,a,ws.ipiv)
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,b,ldb,info)
......@@ -107,7 +106,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
println("dgetrs ",info[])
chklapackerror(info[])
end
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,c,ldc,info)
......@@ -115,7 +114,7 @@ function linsolve_core!(ws::LinSolveWS,trans::Ref{UInt8},a::StridedMatrix{Float6
println("dgetrs ",info[])
chklapackerror(info[])
end
ccall((@blasfunc(dgetrs_), liblapack), Void,
ccall((@blasfunc(dgetrs_), liblapack), Cvoid,
(Ref{UInt8},Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ptr{Float64},Ref{BlasInt},Ref{BlasInt}),
trans,n,nhrs,ws.lu,lda,ws.ipiv,d,ldd,info)
......@@ -132,7 +131,7 @@ function lu!(lu,a,ipiv)
n = Ref{BlasInt}(nn)
lda = Ref{BlasInt}(max(1,stride(a,2)))
info = Ref{BlasInt}(0)
ccall((@blasfunc(dgetrf_), liblapack), Void,
ccall((@blasfunc(dgetrf_), liblapack), Cvoid,
(Ref{BlasInt},Ref{BlasInt},Ptr{Float64},Ref{BlasInt},
Ptr{BlasInt},Ref{BlasInt}),
m,n,lu,lda,ipiv,info)
......
......@@ -7,9 +7,9 @@ import Base.getindex
import Base.setindex!
import Base.copy
import Base: A_mul_B!, At_mul_B!, A_mul_Bt!, A_ldiv_B!
import Base.LinAlg.BlasInt
import Base.LinAlg.BLAS.@blasfunc
import Base.LinAlg.BLAS.libblas
import LinearAlgebra.BlasInt
import LinearAlgebra.BLAS.@blasfunc
import LinearAlgebra.BLAS.libblas
export QuasiUpperTriangular, I_plus_rA_ldiv_B!, I_plus_rA_plus_sB_ldiv_C!, A_rdiv_B!, A_rdiv_Bt!
......@@ -18,7 +18,7 @@ struct QuasiUpperTriangular{T<:Real,S<:AbstractMatrix} <: AbstractMatrix{T}
end
QuasiUpperTriangular(A::QuasiUpperTriangular) = A
function QuasiUpperTriangular(A::AbstractMatrix)
Base.LinAlg.checksquare(A)
LinearAlgebra.checksquare(A)
return QuasiUpperTriangular{eltype(A), typeof(A)}(A)
end
......
module QrAlgo
import Base.LinAlg.BlasInt
import Base.LinAlg.BLAS.@blasfunc
import Base.LinAlg.BLAS.libblas
import Base.LinAlg.LAPACK: liblapack, chklapackerror
import LinearAlgebra.BlasInt
import LinearAlgebra.BLAS.@blasfunc
import LinearAlgebra.BLAS.libblas
import LinearAlgebra.LAPACK: liblapack, chklapackerror
export QrWS, dgeqrf_core!, dormrqf_core!
......@@ -61,5 +61,3 @@ function dormrqf_core!(ws::QrWS,side::Ref{UInt8},trans::Ref{UInt8},A::StridedMat
end
end
module SchurAlgo
# general Schur decomposition with reordering
# adataped from ./base/linalg/lapack.jl
# adapted from ./base/linalg/lapack.jl
include("exceptions.jl")
import Base: USE_BLAS64, LAPACKException
import Base.LinAlg: BlasInt, BlasFloat, checksquare, chkstride1
import Base.LinAlg.BLAS: @blasfunc, libblas
import Base.LinAlg.LAPACK: liblapack, chklapackerror
import LinearAlgebra: BlasInt, BlasFloat, checksquare, chkstride1
import LinearAlgebra.BLAS: @blasfunc, libblas
import LinearAlgebra.LAPACK: liblapack, chklapackerror
export DgeesWS, dgees!, DggesWS, dgges!
......@@ -26,8 +26,8 @@ function mycompare{T}(alphar_::Ptr{T}, alphai_::Ptr{T}, beta_::Ptr{T})
return convert(Cint, ((alphar*alphar + alphai*alphai) < criterium*beta*beta) ? 1 : 0)
end
const mycompare_c = cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble}))
const mycompare_g_c = cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cdouble}))
const mycompare_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble}))
const mycompare_g_c = @cfunction(mycompare, Cint, (Ptr{Cdouble}, Ptr{Cdouble}, Ptr{Cdouble}))
mutable struct DgeesWS
jobvs::Ref{UInt8}
......@@ -49,8 +49,8 @@ mutable struct DgeesWS
n = Ref{BlasInt}(size(A,1))
RldA = Ref{BlasInt}(max(1,stride(A,2)))
Rsort = Ref{UInt8}('N')
ccall((@blasfunc(dgees_), liblapack), Void,
(Ref{UInt8}, Ref{UInt8}, Ptr{Void},
ccall((@blasfunc(dgees_), liblapack), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid},
Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{Float64},
Ptr{Float64}, Ptr{Float64}, Ref{BlasInt},
......@@ -67,10 +67,8 @@ mutable struct DgeesWS
work = Vector{Float64}(lwork[])
new(jobvs, sdim, wr, wi, ldvs, vs, work, lwork, bwork, eigen_values, info)
end
end
function DgeesWS(A::StridedMatrix{Float64})
chkstride1(A)
n, = checksquare(A)
......@@ -86,7 +84,6 @@ function DgeesWS(A::StridedMatrix{Float64})
eigen_values = Vector{Complex{Float64}}(n)
info = Ref{BlasInt}(0)
DgeesWS(jobvs, A, sdim, wr, wi, ldvs, vs, work, lwork, bwork, eigen_values, info)
end
function DgeesWS(n::Int64)
......@@ -98,8 +95,8 @@ function dgees!(ws::DgeesWS,A::StridedMatrix{Float64})
n = Ref{BlasInt}(size(A,1))
RldA = Ref{BlasInt}(max(1,stride(A,2)))
sort = Ref{UInt8}('S')
ccall((@blasfunc(dgees_), liblapack), Void,
(Ref{UInt8}, Ref{UInt8}, Ptr{Void},
ccall((@blasfunc(dgees_), liblapack), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid},
Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
Ref{BlasInt}, Ptr{Float64}, Ptr{Float64},
Ptr{Float64}, Ref{BlasInt},
......@@ -142,8 +139,8 @@ mutable struct DggesWS
work = Vector{Float64}(1)
sdim = BlasInt(0)
info = BlasInt(0)
ccall((@blasfunc(dgges_), liblapack), Void,
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ptr{Void},
ccall((@blasfunc(dgges_), liblapack), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid},
Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ptr{Float64},
Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64},
......@@ -176,8 +173,8 @@ function dgges!(jobvsl::Char, jobvsr::Char, A::StridedMatrix{Float64}, B::Stride
sort = 'S'
sdim = Ref{BlasInt}(0)
info = Ref{BlasInt}(0)
ccall((@blasfunc(dgges_), liblapack), Void,
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ptr{Void},
ccall((@blasfunc(dgges_), liblapack), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{UInt8}, Ptr{Cvoid},
Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{Float64}, Ptr{Float64},
Ptr{Float64}, Ptr{Float64}, Ref{BlasInt}, Ptr{Float64},
......
......@@ -4,7 +4,7 @@ using LinearAlgebra
import ...DynLinAlg.LinSolveAlgo: LinSolveWS, linsolve_core!
import Base.LinAlg.BLAS: scal!, gemm!
import LinearAlgebra.BLAS: scal!, gemm!
export CyclicReductionWS, cyclic_reduction!, cyclic_reduction_check
......@@ -25,19 +25,19 @@ mutable struct CyclicReductionWS
m2_a2::SubArray{Float64}
info::Int64
function CyclicReductionWS(n)
function CyclicReductionWS(n::Int64)
linsolve_ws = LinSolveWS(n)
ahat1 = Matrix{Float64}(n,n)
a1copy = Matrix{Float64}(n,n)
m = Matrix{Float64}(2*n,2*n)
ahat1 = zeros(Float64,n,n)
a1copy = zeros(Float64,n,n)
m = zeros(Float64,2*n,2*n)
m00 = view(m,1:n,1:n)
m02 = view(m,1:n,n+(1:n))
m20 = view(m,n+(1:n),1:n)
m22 = view(m,n+(1:n),n+(1:n))
m1 = Matrix{Float64}(n,2*n)
m1 = zeros(Float64,n,2*n)
m1_a0 = view(m1,1:n,1:n)
m1_a2 = view(m1,1:n,n+(1:n))
m2 = Matrix{Float64}(2*n,n)
m2 = zeros(Float64,2*n,n)
m2_a0 = view(m2,1:n,1:n)
m2_a2 = view(m2,n+(1:n),1:n)
info = 0
......@@ -97,11 +97,7 @@ function cyclic_reduction!(x::Array{Float64},a0::Array{Float64},a1::Array{Float6
copy!(ws.m2_a0, ws.m00)
copy!(ws.m2_a2, ws.m22)
if any(isinf.(ws.m))
if norm(ws.m1_a0) < cvg_tol
ws.info = 2
else
ws.info = 1
end
ws.info norm(ws.m1_a0) < cvg_tol ? 2 : 1
fill!(x,NaN)
return
end
......@@ -116,11 +112,7 @@ function cyclic_reduction!(x::Array{Float64},a0::Array{Float64},a1::Array{Float6
it += 1
end
if it == max_it
if norm(ws.m1_a0) < cvg_tol
ws.info = 2
else
ws.info = 1
end
ws.info = norm(ws.m1_a0) < cvg_tol ? 2 : 1
fill!(x,NaN)
return
else
......@@ -132,7 +124,7 @@ end
function cyclic_reduction_check(x::Array{Float64,2},a0::Array{Float64,2}, a1::Array{Float64,2}, a2::Array{Float64,2},cvg_tol::Float64)
res = a0 + a1*x + a2*x*x
if (sum(sum(abs.(res))) > cvg_tol)
if (sum(abs.(res)) > cvg_tol)
print("the norm of the residuals, ", res, ", compared to the tolerance criterion ",cvg_tol)
end
nothing
......
......@@ -8,11 +8,11 @@ import ...DynLinAlg.LinSolveAlgo: LinSolveWS, linsolve_core!, linsolve_core_no_l
import ..Solvers: ResultsPerturbationWs
import ..SolveEyePlusMinusAkronB: EyePlusAtKronBWS, generalized_sylvester_solver!
import Base.LinAlg.BLAS: gemm!
import LinearAlgebra.BLAS: gemm!
export FirstOrderSolverWS, first_order_solver
type FirstOrderSolverWS
mutable struct FirstOrderSolverWS
jacobian_static::Matrix{Float64}
qr_ws::QrWS
solver_ws::Union{GsSolverWS,CyclicReductionWS}
......@@ -31,11 +31,11 @@ type FirstOrderSolverWS
function FirstOrderSolverWS(algo::String, jacobian::Matrix, m::Model)
if m.n_static > 0
jacobian_static = Matrix{Float64}(m.endo_nbr,m.n_static)
jacobian_static = zeros(Float64,m.endo_nbr,m.n_static)
qr_ws = QrWS(jacobian_static)
else
jacobian_static = Matrix{Float64}(0,0)
qr_ws = QrWS(Matrix{Float64}(0,0))
jacobian_static = zeros(Float64,0,0)
qr_ws = QrWS(zeros(Float64,0,0))
end
if algo == "GS"
d = zeros(m.n_dyn,m.n_dyn)
......@@ -45,24 +45,21 @@ type FirstOrderSolverWS
n = m.endo_nbr - m.n_static
solver_ws = CyclicReductionWS(n)
end
ghx = Matrix{Float64}(m.endo_nbr,m.n_bkwrd+m.n_both)
gx = Matrix{Float64}(m.n_fwrd+m.n_both,m.n_bkwrd+m.n_both)
hx = Matrix{Float64}(m.n_bkwrd+m.n_both,m.n_bkwrd+m.n_both)
temp1 = Matrix{Float64}(m.n_static,m.n_fwrd+m.n_both)
temp2 = Matrix{Float64}(m.n_static,m.n_bkwrd+m.n_both)
temp3 = Matrix{Float64}(m.n_static,m.n_bkwrd+m.n_both)
temp4 = Matrix{Float64}(m.endo_nbr - m.n_static,m.n_bkwrd+m.n_both)
temp5 = Matrix{Float64}(m.endo_nbr,max(m.current_exogenous_nbr,m.lagged_exogenous_nbr))
b10 = Matrix{Float64}(m.n_static,m.n_static)
b11 = Matrix{Float64}(m.n_static,length(m.p_current_ns))
linsolve_static_ws = LinSolveWS(m.n_static)
if m.serially_correlated_exogenous
eye_plus_at_kron_b_ws = EyePlusAtKronBWS(ma, mb, mc, 1)
else
eye_plus_at_kron_b_ws = EyePlusAtKronBWS(1, 1, 1, 1)
end
new(jacobian_static, qr_ws, solver_ws, ghx, gx, hx, temp1,
temp2, temp3, temp4, temp5, b10, b11, linsolve_static_ws,
eye_plus_at_kron_b_ws = m.serially_correlated_exogenous ? EyePlusAtKronBWS(ma, mb, mc, 1) : EyePlusAtKronBWS(1, 1, 1, 1)
new(jacobian_static,
qr_ws,
solver_ws,
zeros(Float64,m.endo_nbr,m.n_bkwrd+m.n_both), #ghx
zeros(Float64,m.n_fwrd+m.n_both,m.n_bkwrd+m.n_both),#gx
zeros(Float64,m.n_bkwrd+m.n_both,m.n_bkwrd+m.n_both),#hx
zeros(Float64,m.n_static,m.n_fwrd+m.n_both),#temp1
zeros(Float64,m.n_static,m.n_bkwrd+m.n_both), #temp2
zeros(Float64,m.n_static,m.n_bkwrd+m.n_both), temp3
zeros(Float64,m.endo_nbr - m.n_static,m.n_bkwrd+m.n_both),#temp4,
zeros(Float64,m.endo_nbr,max(m.current_exogenous_nbr,m.lagged_exogenous_nbr)),#temp5,
zeros(Float64,m.n_static,m.n_static),#b10,
zeros(Float64,m.n_static,length(m.p_current_ns)),#b11,
LinSolveWS(m.n_static),#linsolve_static_ws,
eye_plus_at_kron_b_ws)
end
end
......@@ -70,7 +67,10 @@ end
function remove_static!(ws::FirstOrderSolverWS,jacobian::Matrix,p_static::Vector)
ws.jacobian_static[:,:] = view(jacobian,:,p_static)
dgeqrf_core!(ws.qr_ws,ws.jacobian_static)
dormrqf_core!(ws.qr_ws,Ref{UInt8}('L'),Ref{UInt8}('T'),ws.jacobian_static,
dormrqf_core!(ws.qr_ws,
Ref{UInt8}('L'),
Ref{UInt8}('T'),
ws.jacobian_static,
jacobian)
end
......@@ -83,25 +83,21 @@ function add_static!(results::ResultsPerturbationWs,ws::FirstOrderSolverWS,jacob
ws.b10 = view(jacobian,i_static, model.p_static)
ws.b11 = view(jacobian,i_static, model.p_current_ns)
ws.temp3 -= view(jacobian,i_static,model.p_bkwrd_b)
for i=1:(model.n_bkwrd + model.n_both)
for j=1:length(model.i_dyn)
@inbounds for i=1:(model.n_bkwrd + model.n_both), j=1:length(model.i_dyn)
ws.temp4[j,i] = results.g[1][model.i_dyn[j],i]
end
end
gemm!('N','N',-1.0,ws.b11,ws.temp4,1.0,ws.temp3)
linsolve_core!(ws.linsolve_static_ws,Ref{UInt8}('N'),ws.b10,ws.temp3)
for i = 1:model.n_states
for j=1:model.n_static
@inbounds for i = 1:model.n_states, j=1:model.n_static
results.g[1][model.i_static[j],i] = ws.temp3[j,i]
end
end
end
using Base.Test
function make_f1g1plusf2!(results::ResultsPerturbationWs,model,jacobian)
nstate = model.n_bkwrd + model.n_both
so = nstate*model.endo_nbr + 1
for i=1:model.n_current
@inbounds for i=1:model.n_current
copy!(results.f1g1plusf2,(model.i_current[i]-1)*model.endo_nbr+1,jacobian,so,model.endo_nbr)
so += model.endo_nbr
end
......@@ -111,10 +107,10 @@ function make_f1g1plusf2!(results::ResultsPerturbationWs,model,jacobian)
z = view(results.f1g1plusf2, :, model.i_bkwrd_b[i])
for j=1:nstate
x = 0.0
for k=1:(model.n_fwrd + model.n_both)
@inbounds for k=1:(model.n_fwrd + model.n_both)
x += jacobian[j, offset + k]*y[model.i_fwrd_b[k]]
end
z[j] += x
@inbounds z[j] += x
end
end
lu!(results.f1g1plusf2_linsolve_ws.lu, results.f1g1plusf2, results.f1g1plusf2_linsolve_ws.ipiv)
......@@ -123,11 +119,9 @@ end
function solve_for_derivatives_with_respect_to_shocks(results::ResultsPerturbationWs, jacobian::AbstractMatrix, ws::FirstOrderSolverWS, model::Model)
if model.lagged_exogenous_nbr > 0
f6 = view(jacobian,:,model.i_lagged_exogenous)
for i = 1:model.current_exogenous_nbr
for j = 1:model.endo_nbr
@inbounds for i = 1:model.current_exogenous_nbr, j = 1:model.endo_nbr
results.g1_3[i,j] = -f6[i,j]
end
end
linsolve_core_no_lu!(ws, Ref{UInt8}('N'), results.f1g1plusf2, results.g1_3)
end
if model.current_exogenous_nbr > 0
......@@ -156,7 +150,7 @@ function first_order_solver(results::ResultsPerturbationWs,ws::FirstOrderSolverW
if ws.solver_ws.info[1] > 0
error("CR didn't converge")
end
for i = 1:length(model.i_bkwrd_ns)
@inbounds for i = 1:length(model.i_bkwrd_ns)
for j = 1:length(model.i_dyn)
results.g[1][model.i_dyn[j],i] = x[j,model.i_bkwrd_ns[i]]
end
......@@ -164,12 +158,11 @@ function first_order_solver(results::ResultsPerturbationWs,ws::FirstOrderSolverW
results.gs[1][j,i] = results.g[1][model.hx_rows[j],i]
end
end
elseif algo == "GS"
d, e = get_de(jacobian[model.n_static+1:end,:],model)
gs_solver!(ws.solver_ws,d,e,model.n_bkwrd+model.n_both,options.generalized_schur.criterium)
results.gs[1] = ws.solver_ws.g2
for i = 1:model.n_bkwrd+model.n_both
@inbounds for i = 1:model.n_bkwrd+model.n_both
for j = 1:model.n_bkwrd
results.g[1][model.i_bkwrd[j],i] = ws.solver_ws.g1[j,i]
end
......
......@@ -5,7 +5,7 @@ module GeneralizedSchurDecompositionSolver
import ...DynLinAlg.SchurAlgo: DggesWS, dgges!
import ...DynLinAlg.LinSolveAlgo: LinSolveWS, linsolve_core!
import Base.LinAlg.BLAS: scal!, gemm!
import LinearAlgebra.BLAS: scal!, gemm!
export GsSolverWS, gs_solver!
......@@ -25,25 +25,33 @@ mutable struct GsSolverWS
g2::Matrix{Float64}
eigval::Vector{Complex64}
function GsSolverWS(d,e,n1)
function GsSolverWS(d,e,n1) # type?
dgges_ws = DggesWS(Ref{UInt8}('N'), Ref{UInt8}('N'), Ref{UInt8}('N'), e, d)
n = size(d,1)
n2 = n - n1
linsolve_ws = LinSolveWS(n2)
D11 = view(d,1:n1,1:n1)
E11 = view(e,1:n1,1:n1)
vsr = Matrix{Float64}(n,n)
vsr = zeros(Float64,n,n)
Z11 = view(vsr,1:n1,1:n1)