# CyclicReduction.jl
 Michel Juillard committed Apr 05, 2020 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 ``````using FastLapackInterface.LinSolveAlgo: LinSolveWs, linsolve_core! using LinearAlgebra using LinearAlgebra.BLAS: gemm! export CyclicReductionWs, cyclic_reduction!, cyclic_reduction_check mutable struct CyclicReductionWs linsolve_ws::LinSolveWs ahat1::Matrix{Float64} a1copy::Matrix{Float64} m::Matrix{Float64,} m00::SubArray{Float64} m02::SubArray{Float64} m20::SubArray{Float64} m22::SubArray{Float64} m1::Matrix{Float64} m1_a0::SubArray{Float64} m1_a2::SubArray{Float64} m2::Matrix{Float64} m2_a0::SubArray{Float64} m2_a2::SubArray{Float64} info::Int64 function CyclicReductionWs(n) linsolve_ws = LinSolveWs(n) ahat1 = Matrix{Float64}(undef, n,n) a1copy = Matrix{Float64}(undef, n,n) m = Matrix{Float64}(undef, 2*n,2*n) m00 = view(m,1:n,1:n) m02 = view(m,1:n,n .+ (1:n)) m20 = view(m,n .+ (1:n), 1:n) m22 = view(m,n .+ (1:n),n .+(1:n)) m1 = Matrix{Float64}(undef, n, 2*n) m1_a0 = view(m1,1:n,1:n) m1_a2 = view(m1,1:n,n .+ (1:n)) m2 = Matrix{Float64}(undef, 2*n, n) m2_a0 = view(m2, 1:n, 1:n) m2_a2 = view(m2, n .+ (1:n), 1:n) info = 0 new(linsolve_ws,ahat1,a1copy,m, m00, m02, m20, m22, m1, m1_a0, m1_a2, m2, m2_a0, m2_a2, info) end end """ cyclic_reduction!(x::Array{Float64},a0::Array{Float64},a1::Array{Float64},a2::Array{Float64},ws::CyclicReductionWs, cvg_tol::Float64, max_it::Int64) Solve the quadratic matrix equation a0 + a1*x + a2*x*x = 0, using the cyclic reduction method from Bini et al. (???). The solution is returned in matrix x. 
In case of nonconvergency, x is set to NaN and an error code is returned in ws.info * info = 0: return OK * info = 1: no stable solution (????) * info = 2: multiple stable solutions (????) # Example ```meta DocTestSetup = quote using CyclicReduction n = 3 ws = CyclicReductionWs(n) a0 = [0.5 0 0; 0 0.5 0; 0 0 0]; a1 = eye(n) a2 = [0 0 0; 0 0 0; 0 0 0.8] x = zeros(n,n) end ``` ```jldoctest julia> display(names(CyclicReduction)) ``` ```jldoctest julia> cyclic_reduction!(x,a0,a1,a2,ws,1e-8,50) ``` """ function cyclic_reduction!(x::Array{Float64},a0::Array{Float64},a1::Array{Float64},a2::Array{Float64},ws::CyclicReductionWs, cvg_tol::Float64, max_it::Int64) copyto!(x,a0) copyto!(ws.ahat1,1,a1,1,length(a1)) @inbounds copyto!(ws.m1_a0, a0) @inbounds copyto!(ws.m1_a2, a2) @inbounds copyto!(ws.m2_a0, a0) @inbounds copyto!(ws.m2_a2, a2) it = 0 @inbounds while it < max_it # ws.m = [a0; a2]*(a1\[a0 a2]) copyto!(ws.a1copy,a1) linsolve_core!(ws.a1copy, ws.m1, ws.linsolve_ws) gemm!('N','N',-1.0,ws.m2,ws.m1,0.0,ws.m) @simd for i in eachindex(a1) a1[i] += ws.m02[i] + ws.m20[i] end copyto!(ws.m1_a0, ws.m00) copyto!(ws.m1_a2, ws.m22) copyto!(ws.m2_a0, ws.m00) copyto!(ws.m2_a2, ws.m22) if any(isinf.(ws.m)) if norm(ws.m1_a0) < cvg_tol ws.info = 2 else ws.info = 1 end fill!(x,NaN) return end ws.ahat1 += ws.m20 crit = norm(ws.m1_a0,1) if crit < cvg_tol # keep iterating until condition on a2 is met if norm(ws.m1_a2,1) < cvg_tol break end end it += 1 end if it == max_it if norm(ws.m1_a0) < cvg_tol ws.info = 2 else ws.info = 1 end fill!(x,NaN) return else linsolve_core!(ws.ahat1, x, ws.linsolve_ws) @inbounds lmul!(-1.0,x) ws.info = 0 end end function cyclic_reduction_check(x::Array{Float64,2},a0::Array{Float64,2}, a1::Array{Float64,2}, a2::Array{Float64,2},cvg_tol::Float64) res = a0 + a1*x + a2*x*x if (sum(sum(abs.(res))) > cvg_tol) print("the norm of the residuals, ", res, ", compared to the tolerance criterion ",cvg_tol) end nothing end ``````