|
1 | 1 | using BlockArrays: blocksizes |
2 | 2 | using DiagonalArrays: diagonal |
3 | 3 | using LinearAlgebra: LinearAlgebra, Diagonal |
4 | | -using MatrixAlgebraKit: |
5 | | - MatrixAlgebraKit, |
6 | | - TruncationStrategy, |
7 | | - check_input, |
8 | | - default_eig_algorithm, |
9 | | - default_eigh_algorithm, |
10 | | - diagview, |
11 | | - eig_full!, |
12 | | - eig_trunc!, |
13 | | - eig_vals!, |
14 | | - eigh_full!, |
15 | | - eigh_trunc!, |
16 | | - eigh_vals!, |
17 | | - findtruncated |
| 4 | +using MatrixAlgebraKit: MatrixAlgebraKit, diagview |
| 5 | +using MatrixAlgebraKit: default_eig_algorithm, eig_full!, eig_vals! |
| 6 | +using MatrixAlgebraKit: default_eigh_algorithm, eigh_full!, eigh_vals! |
18 | 7 |
|
19 | 8 | for f in [:default_eig_algorithm, :default_eigh_algorithm] |
20 | 9 | @eval begin |
21 | 10 | function MatrixAlgebraKit.$f(::Type{<:AbstractBlockSparseMatrix}; kwargs...) |
22 | | - return BlockPermutedDiagonalAlgorithm() do block |
| 11 | + return BlockDiagonalAlgorithm() do block |
23 | 12 | return $f(block; kwargs...) |
24 | 13 | end |
25 | 14 | end |
26 | 15 | end |
27 | 16 | end |
28 | 17 |
|
| 18 | +function output_type(::typeof(eig_full!), A::Type{<:AbstractMatrix{T}}) where {T} |
| 19 | + DV = Base.promote_op(eig_full!, A) |
| 20 | + return if isconcretetype(DV) |
| 21 | + DV |
| 22 | + else |
| 23 | + Tuple{AbstractMatrix{complex(T)},AbstractMatrix{complex(T)}} |
| 24 | + end |
| 25 | +end |
| 26 | +function output_type(::typeof(eigh_full!), A::Type{<:AbstractMatrix{T}}) where {T} |
| 27 | + DV = Base.promote_op(eigh_full!, A) |
| 28 | + return isconcretetype(DV) ? DV : Tuple{AbstractMatrix{real(T)},AbstractMatrix{T}} |
| 29 | +end |
| 30 | + |
29 | 31 | function MatrixAlgebraKit.check_input( |
30 | | - ::typeof(eig_full!), A::AbstractBlockSparseMatrix, (D, V) |
| 32 | + ::typeof(eig_full!), A::AbstractBlockSparseMatrix, (D, V), ::BlockDiagonalAlgorithm |
31 | 33 | ) |
32 | 34 | @assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix) |
33 | 35 | @assert eltype(V) === eltype(D) === complex(eltype(A)) |
34 | 36 | @assert axes(A, 1) == axes(A, 2) |
35 | 37 | @assert axes(A) == axes(D) == axes(V) |
| 38 | + @assert isblockdiagonal(A) |
36 | 39 | return nothing |
37 | 40 | end |
38 | 41 | function MatrixAlgebraKit.check_input( |
39 | | - ::typeof(eigh_full!), A::AbstractBlockSparseMatrix, (D, V) |
| 42 | + ::typeof(eigh_full!), A::AbstractBlockSparseMatrix, (D, V), ::BlockDiagonalAlgorithm |
40 | 43 | ) |
41 | 44 | @assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix) |
42 | 45 | @assert eltype(V) === eltype(A) |
43 | 46 | @assert eltype(D) === real(eltype(A)) |
44 | 47 | @assert axes(A, 1) == axes(A, 2) |
45 | 48 | @assert axes(A) == axes(D) == axes(V) |
| 49 | + @assert isblockdiagonal(A) |
46 | 50 | return nothing |
47 | 51 | end |
48 | 52 |
|
49 | | -function output_type(f::typeof(eig_full!), A::Type{<:AbstractMatrix{T}}) where {T} |
50 | | - DV = Base.promote_op(f, A) |
51 | | - !isconcretetype(DV) && return Tuple{AbstractMatrix{complex(T)},AbstractMatrix{complex(T)}} |
52 | | - return DV |
53 | | -end |
54 | | -function output_type(f::typeof(eigh_full!), A::Type{<:AbstractMatrix{T}}) where {T} |
55 | | - DV = Base.promote_op(f, A) |
56 | | - !isconcretetype(DV) && return Tuple{AbstractMatrix{real(T)},AbstractMatrix{T}} |
57 | | - return DV |
58 | | -end |
59 | | - |
60 | 53 | for f in [:eig_full!, :eigh_full!] |
61 | 54 | @eval begin |
62 | 55 | function MatrixAlgebraKit.initialize_output( |
63 | | - ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm |
| 56 | + ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm |
64 | 57 | ) |
65 | 58 | Td, Tv = fieldtypes(output_type($f, blocktype(A))) |
66 | 59 | D = similar(A, BlockType(Td)) |
67 | 60 | V = similar(A, BlockType(Tv)) |
68 | 61 | return (D, V) |
69 | 62 | end |
70 | 63 | function MatrixAlgebraKit.$f( |
71 | | - A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm |
| 64 | + A::AbstractBlockSparseMatrix, (D, V), alg::BlockDiagonalAlgorithm |
72 | 65 | ) |
73 | | - check_input($f, A, (D, V)) |
74 | | - for I in eachstoredblockdiagindex(A) |
75 | | - block = @view!(A[I]) |
76 | | - block_alg = block_algorithm(alg, block) |
77 | | - D[I], V[I] = $f(block, block_alg) |
78 | | - end |
79 | | - for I in eachunstoredblockdiagindex(A) |
80 | | - # TODO: Support setting `LinearAlgebra.I` directly, and/or |
81 | | - # using `FillArrays.Eye`. |
82 | | - V[I] = LinearAlgebra.I(size(@view(V[I]), 1)) |
| 66 | + MatrixAlgebraKit.check_input($f, A, (D, V), alg) |
| 67 | + |
| 68 | + # do decomposition on each block |
| 69 | + for bI in blockdiagindices(A) |
| 70 | + if isstored(A, bI) |
| 71 | + block = @view!(A[bI]) |
| 72 | + block_alg = block_algorithm(alg, block) |
| 73 | + bD, bV = $f(block, block_alg) |
| 74 | + D[bI] = bD |
| 75 | + V[bI] = bV |
| 76 | + else |
| 77 | + # TODO: this should be `V[bI] = LinearAlgebra.I` |
| 78 | + copyto!(@view!(V[bI]), LinearAlgebra.I) |
| 79 | + end |
83 | 80 | end |
84 | 81 | return (D, V) |
85 | 82 | end |
|
100 | 97 | for f in [:eig_vals!, :eigh_vals!] |
101 | 98 | @eval begin |
102 | 99 | function MatrixAlgebraKit.initialize_output( |
103 | | - ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm |
| 100 | + ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm |
104 | 101 | ) |
105 | 102 | T = output_type($f, blocktype(A)) |
106 | 103 | return similar(A, BlockType(T), axes(A, 1)) |
107 | 104 | end |
| 105 | + function MatrixAlgebraKit.check_input( |
| 106 | + ::typeof($f), A::AbstractBlockSparseMatrix, D, ::BlockDiagonalAlgorithm |
| 107 | + ) |
| 108 | + @assert isa(D, AbstractBlockSparseVector) |
| 109 | + @assert eltype(D) === $(f == :eig_vals! ? complex : real)(eltype(A)) |
| 110 | + @assert axes(A, 1) == axes(A, 2) |
| 111 | + @assert (axes(A, 1),) == axes(D) |
| 112 | + @assert isblockdiagonal(A) |
| 113 | + return nothing |
| 114 | + end |
| 115 | + |
108 | 116 | function MatrixAlgebraKit.$f( |
109 | | - A::AbstractBlockSparseMatrix, D, alg::BlockPermutedDiagonalAlgorithm |
| 117 | + A::AbstractBlockSparseMatrix, D, alg::BlockDiagonalAlgorithm |
110 | 118 | ) |
| 119 | + MatrixAlgebraKit.check_input($f, A, D, alg) |
111 | 120 | for I in eachblockstoredindex(A) |
112 | 121 | block = @view!(A[I]) |
113 | | - D[I] = $f(block, block_algorithm(alg, block)) |
| 122 | + D[Tuple(I)[1]] = $f(block, block_algorithm(alg, block)) |
114 | 123 | end |
115 | 124 | return D |
116 | 125 | end |
|
0 commit comments