CyclicReduction.jl 4.24 KB
Newer Older
Michel Juillard's avatar
Michel Juillard committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
using FastLapackInterface.LinSolveAlgo: LinSolveWs, linsolve_core!

using LinearAlgebra
using LinearAlgebra.BLAS: gemm!
export CyclicReductionWs, cyclic_reduction!, cyclic_reduction_check

# Concrete type of a rectangular-block view `view(A, i1:i2, j1:j2)` of a Matrix{Float64}
# with `UnitRange{Int}` indices. Typing the view fields with it (instead of the
# non-concrete `SubArray{Float64}`) lets the compiler infer field accesses in the solver.
const CRBlockView = typeof(view(Matrix{Float64}(undef, 0, 0), 1:0, 1:0))

"""
    CyclicReductionWs(n)

Preallocated workspace for `cyclic_reduction!` on problems whose matrices are `n × n`.

Fields:
- `linsolve_ws`: workspace passed to `linsolve_core!`.
- `ahat1`: `n × n` accumulator for the `ahat1` matrix used to form the solution.
- `a1copy`: `n × n` scratch copy of `a1`, handed to each linear solve.
- `m`: `2n × 2n` product buffer, with quadrant views `m00`, `m02`, `m20`, `m22`.
- `m1`: `n × 2n` buffer holding `[a0 a2]`, with views `m1_a0`, `m1_a2`.
- `m2`: `2n × n` buffer holding `[a0; a2]`, with views `m2_a0`, `m2_a2`.
- `info`: status code, initialised to `0`.
"""
mutable struct CyclicReductionWs
    linsolve_ws::LinSolveWs
    ahat1::Matrix{Float64}
    a1copy::Matrix{Float64}
    m::Matrix{Float64}
    m00::CRBlockView
    m02::CRBlockView
    m20::CRBlockView
    m22::CRBlockView
    m1::Matrix{Float64}
    m1_a0::CRBlockView
    m1_a2::CRBlockView
    m2::Matrix{Float64}
    m2_a0::CRBlockView
    m2_a2::CRBlockView
    info::Int64

    function CyclicReductionWs(n)
        # Ranges must be UnitRange{Int} so the views have exactly type CRBlockView.
        n = Int(n)
        lo = 1:n
        hi = (n + 1):(2n)
        linsolve_ws = LinSolveWs(n)
        ahat1 = Matrix{Float64}(undef, n, n)
        a1copy = Matrix{Float64}(undef, n, n)
        m = Matrix{Float64}(undef, 2n, 2n)
        m00 = view(m, lo, lo)
        m02 = view(m, lo, hi)
        m20 = view(m, hi, lo)
        m22 = view(m, hi, hi)
        m1 = Matrix{Float64}(undef, n, 2n)
        m1_a0 = view(m1, lo, lo)
        m1_a2 = view(m1, lo, hi)
        m2 = Matrix{Float64}(undef, 2n, n)
        m2_a0 = view(m2, lo, lo)
        m2_a2 = view(m2, hi, lo)
        info = 0
        return new(linsolve_ws, ahat1, a1copy, m, m00, m02, m20, m22,
                   m1, m1_a0, m1_a2, m2, m2_a0, m2_a2, info)
    end
end

"""
    cyclic_reduction!(x::Array{Float64},a0::Array{Float64},a1::Array{Float64},a2::Array{Float64},ws::CyclicReductionWs, cvg_tol::Float64, max_it::Int64)

Solve the quadratic matrix equation a0 + a1*x + a2*x*x = 0, using the cyclic reduction method from Bini et al. (???).

50
51
The solution is returned in matrix x. In case of nonconvergency, x is set to NaN and 
UndeterminateSystemExcpetion or UnstableSystemException is thrown
Michel Juillard's avatar
Michel Juillard committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73

# Example
```meta
DocTestSetup = quote
     using CyclicReduction
     n = 3
     ws = CyclicReductionWs(n)
     a0 = [0.5 0 0; 0 0.5 0; 0 0 0];
     a1 = eye(n)
     a2 = [0 0 0; 0 0 0; 0 0 0.8]
     x = zeros(n,n)
end
```

```jldoctest
julia> display(names(CyclicReduction))
```

```jldoctest
julia> cyclic_reduction!(x,a0,a1,a2,ws,1e-8,50)
```
"""
Michel Juillard's avatar
Michel Juillard committed
74
75
76
77
78
79
80
function cyclic_reduction!(x::AbstractMatrix{Float64},
                           a0::AbstractMatrix{Float64},
                           a1::AbstractMatrix{Float64},
                           a2::AbstractMatrix{Float64},
                           ws::CyclicReductionWs,
                           cvg_tol::Float64,
                           max_it::Int64)
Michel Juillard's avatar
Michel Juillard committed
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
    copyto!(x,a0)
    copyto!(ws.ahat1,1,a1,1,length(a1))
    @inbounds copyto!(ws.m1_a0, a0)
    @inbounds copyto!(ws.m1_a2, a2)
    @inbounds copyto!(ws.m2_a0, a0)
    @inbounds copyto!(ws.m2_a2, a2)
    it = 0
    @inbounds while it < max_it
        #        ws.m = [a0; a2]*(a1\[a0 a2])
	copyto!(ws.a1copy,a1)
        linsolve_core!(ws.a1copy, ws.m1, ws.linsolve_ws)
        gemm!('N','N',-1.0,ws.m2,ws.m1,0.0,ws.m)
        @simd for i in eachindex(a1)
            a1[i] += ws.m02[i] + ws.m20[i]
        end
        copyto!(ws.m1_a0, ws.m00)
        copyto!(ws.m1_a2, ws.m22)
        copyto!(ws.m2_a0, ws.m00)
        copyto!(ws.m2_a2, ws.m22)
100
101
102
103
        if any(isinf.(ws.m)) || any(isnan.(ws.m))
            fill!(x,NaN)
            if norm(ws.m1_a0) < Inf
                throw(UndeterminateSystemException())
Michel Juillard's avatar
Michel Juillard committed
104
            else
105
                throw(UnstableSystemException())
Michel Juillard's avatar
Michel Juillard committed
106
107
108
109
110
111
112
113
114
115
116
117
118
            end
        end
        ws.ahat1 += ws.m20
        crit = norm(ws.m1_a0,1)
        if crit < cvg_tol
	    # keep iterating until condition on a2 is met
            if norm(ws.m1_a2,1) < cvg_tol
                break
            end
        end
        it += 1
    end
    if it == max_it
119
        println("max_it")
Michel Juillard's avatar
Michel Juillard committed
120
        if norm(ws.m1_a0) < cvg_tol
121
            throw(UnstableSystemException())
Michel Juillard's avatar
Michel Juillard committed
122
        else
123
            throw(UndeterminateSystemException())
Michel Juillard's avatar
Michel Juillard committed
124
125
126
127
128
129
130
        end
        fill!(x,NaN)
        return
    else
        linsolve_core!(ws.ahat1, x, ws.linsolve_ws)
        @inbounds lmul!(-1.0,x)
        ws.info = 0
131
    end
Michel Juillard's avatar
Michel Juillard committed
132
133
134
135
136
137
138
139
140
141
end

"""
    cyclic_reduction_check(x, a0, a1, a2, cvg_tol; io=stdout)

Check a candidate solution `x` of `a0 + a1*x + a2*x*x = 0`.

The residual matrix is computed and the sum of the absolute values of its entries is
compared to `cvg_tol`. When it exceeds the tolerance, that value is printed to `io`.
Returns `nothing`.
"""
function cyclic_reduction_check(x::AbstractMatrix{Float64}, a0::AbstractMatrix{Float64},
                                a1::AbstractMatrix{Float64}, a2::AbstractMatrix{Float64},
                                cvg_tol::Float64; io::IO=stdout)
    res = a0 + a1 * x + a2 * x * x
    # entrywise 1-norm of the residual; the message reports this scalar, not `res`
    resnorm = sum(abs, res)
    if resnorm > cvg_tol
        println(io, "the norm of the residuals, ", resnorm,
                ", compared to the tolerance criterion ", cvg_tol)
    end
    return nothing
end