Skip to content

Cleanup and generalize functions of Hermitian matrices #1340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 48 additions & 110 deletions src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,15 @@ const RealHermSymSymTri{T<:Real} = Union{RealHermSym{T}, SymTridiagonal{T}}
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}
const RealHermSymSymTriComplexHerm{T<:Real} = Union{RealHermSymComplexSym{T}, SymTridiagonal{T}}
const SelfAdjoint = Union{Symmetric{<:Real}, Hermitian{<:Number}}
const SelfAdjoint = Union{SymTridiagonal{<:Real}, Symmetric{<:Real}, Hermitian}

wrappertype(::Union{Symmetric, SymTridiagonal}) = Symmetric
wrappertype(::Hermitian) = Hermitian

nonhermitianwrappertype(::SymSymTri{<:Real}) = Symmetric
nonhermitianwrappertype(::Hermitian{<:Real}) = Symmetric
nonhermitianwrappertype(::Hermitian) = identity

size(A::HermOrSym) = size(A.data)
axes(A::HermOrSym) = axes(A.data)
@inline function Base.isassigned(A::HermOrSym, i::Int, j::Int)
Expand Down Expand Up @@ -834,119 +838,74 @@ end
^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p)
^(A::SymTridiagonal{<:Real}, p::Integer) = sympow(A, p)
^(A::SymTridiagonal{<:Complex}, p::Integer) = sympow(A, p)
function sympow(A::SymSymTri, p::Integer)
if p < 0
return Symmetric(Base.power_by_squaring(inv(A), -p))
else
return Symmetric(Base.power_by_squaring(A, p))
end
end
for hermtype in (:Symmetric, :SymTridiagonal)
@eval begin
function ^(A::$hermtype{<:Real}, p::Real)
isinteger(p) && return integerpow(A, p)
F = eigen(A)
if all(λ -> λ ≥ 0, F.values)
return Symmetric((F.vectors * Diagonal((F.values).^p)) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors')
end
end
function ^(A::$hermtype{<:Complex}, p::Real)
isinteger(p) && return integerpow(A, p)
return Symmetric(schurpow(A, p))
end
end
end
function ^(A::Hermitian, p::Integer)
^(A::Hermitian, p::Integer) = sympow(A, p)
function sympow(A, p::Integer)
if p < 0
retmat = Base.power_by_squaring(inv(A), -p)
else
retmat = Base.power_by_squaring(A, p)
end
return Hermitian(retmat)
return wrappertype(A)(retmat)
end
function ^(A::Hermitian{T}, p::Real) where T
function ^(A::SelfAdjoint, p::Real)
isinteger(p) && return integerpow(A, p)
F = eigen(A)
if all(λ -> λ ≥ 0, F.values)
retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors'
return Hermitian(retmat)
return wrappertype(A)(retmat)
else
retmat = (F.vectors * Diagonal((complex.(F.values).^p))) * F.vectors'
if T <: Real
return Symmetric(retmat)
else
return retmat
end
retmat = (F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors'
return nonhermitianwrappertype(A)(retmat)
end
end
function ^(A::SymSymTri{<:Complex}, p::Real)
isinteger(p) && return integerpow(A, p)
return Symmetric(schurpow(A, p))
end

for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
@eval begin
function ($func)(A::RealHermSymSymTri)
F = eigen(A)
return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
end
function ($func)(A::Hermitian{<:Complex})
function ($func)(A::SelfAdjoint)
F = eigen(A)
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
return Hermitian(retmat)
return wrappertype(A)(retmat)
end
end
end

function cis(A::RealHermSymSymTri)
function cis(A::SelfAdjoint)
F = eigen(A)
return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
retmat = F.vectors .* cis.(F.values') * F.vectors'
return nonhermitianwrappertype(A)(retmat)
end
function cis(A::Hermitian{<:Complex})
F = eigen(A)
return F.vectors .* cis.(F.values') * F.vectors'
end


for func in (:acos, :asin)
@eval begin
function ($func)(A::RealHermSymSymTri)
F = eigen(A)
if all(λ -> -1 ≤ λ ≤ 1, F.values)
return wrappertype(A)((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
end
end
function ($func)(A::Hermitian{<:Complex})
function ($func)(A::SelfAdjoint)
F = eigen(A)
if all(λ -> -1 ≤ λ ≤ 1, F.values)
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
return Hermitian(retmat)
return wrappertype(A)(retmat)
else
return (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors'
retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors'
return nonhermitianwrappertype(A)(retmat)
end
end
end
end

function acosh(A::RealHermSymSymTri)
F = eigen(A)
if all(λ -> λ ≥ 1, F.values)
return wrappertype(A)((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
end
end
function acosh(A::Hermitian{<:Complex})
function acosh(A::SelfAdjoint)
F = eigen(A)
if all(λ -> λ ≥ 1, F.values)
retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors'
return Hermitian(retmat)
return wrappertype(A)(retmat)
else
return (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors'
retmat = (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors'
return nonhermitianwrappertype(A)(retmat)
end
end

function sincos(A::RealHermSymSymTri)
function sincos(A::SelfAdjoint)
n = checksquare(A)
F = eigen(A)
T = float(eltype(F.values))
Expand All @@ -956,49 +915,28 @@ function sincos(A::RealHermSymSymTri)
end
return wrappertype(A)((F.vectors * S) * F.vectors'), wrappertype(A)((F.vectors * C) * F.vectors')
end
function sincos(A::Hermitian{<:Complex})
n = checksquare(A)

function log(A::SelfAdjoint)
F = eigen(A)
T = float(eltype(F.values))
S, C = Diagonal(similar(A, T, (n,))), Diagonal(similar(A, T, (n,)))
for i in eachindex(S.diag, C.diag, F.values)
S.diag[i], C.diag[i] = sincos(F.values[i])
end
retmatS, retmatC = (F.vectors * S) * F.vectors', (F.vectors * C) * F.vectors'
for i in diagind(retmatS, IndexStyle(retmatS))
retmatS[i] = real(retmatS[i])
retmatC[i] = real(retmatC[i])
if all(λ -> λ ≥ 0, F.values)
retmat = (F.vectors * Diagonal(log.(F.values))) * F.vectors'
return wrappertype(A)(retmat)
else
retmat = (F.vectors * Diagonal(log.(complex.(F.values)))) * F.vectors'
return nonhermitianwrappertype(A)(retmat)
end
return Hermitian(retmatS), Hermitian(retmatC)
end


for func in (:log, :sqrt)
# sqrt has rtol arg to handle matrices that are semidefinite up to roundoff errors
rtolarg = func === :sqrt ? Any[Expr(:kw, :(rtol::Real), :(eps(real(float(one(T))))*size(A,1)))] : Any[]
rtolval = func === :sqrt ? :(-maximum(abs, F.values) * rtol) : 0
@eval begin
function ($func)(A::RealHermSymSymTri{T}; $(rtolarg...)) where {T<:Real}
F = eigen(A)
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
if all(λ -> λ ≥ λ₀, F.values)
return wrappertype(A)((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
end
end
function ($func)(A::Hermitian{T}; $(rtolarg...)) where {T<:Complex}
n = checksquare(A)
F = eigen(A)
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
if all(λ -> λ ≥ λ₀, F.values)
retmat = (F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors'
return Hermitian(retmat)
else
retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors'
return retmat
end
end
# sqrt has rtol kwarg to handle matrices that are semidefinite up to roundoff errors
function sqrt(A::SelfAdjoint; rtol = eps(real(float(eltype(A)))) * size(A, 1))
F = eigen(A)
λ₀ = -maximum(abs, F.values) * rtol # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
if all(λ -> λ ≥ λ₀, F.values)
retmat = (F.vectors * Diagonal(sqrt.(max.(0, F.values)))) * F.vectors'
return wrappertype(A)(retmat)
else
retmat = (F.vectors * Diagonal(sqrt.(complex.(F.values)))) * F.vectors'
return nonhermitianwrappertype(A)(retmat)
end
end

Expand Down
10 changes: 10 additions & 0 deletions test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1199,4 +1199,14 @@ end
end
end

@testset "asin/acos/acosh for matrix outside the real domain" begin
M = [0 2;2 0] #eigenvalues are ±2
for T ∈ (Float32, Float64, ComplexF32, ComplexF64)
M2 = Hermitian(T.(M))
@test sin(asin(M2)) ≈ M2
@test cos(acos(M2)) ≈ M2
@test cosh(acosh(M2)) ≈ M2
end
end

end # module TestSymmetric