diff --git a/src/triangular.jl b/src/triangular.jl index b0af1c34..1d8570bf 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) = Base.isstored(A::UpperOrLowerTriangular, i::Int, j::Int) = _shouldforwardindex(A, i, j) ? Base.isstored(A.data, i, j) : false -@propagate_inbounds getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} = - _shouldforwardindex(A, i, j) ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T)) -@propagate_inbounds getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) = - _shouldforwardindex(A, i, j) ? A.data[i,j] : diagzero(A,i,j) +@propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} + if _shouldforwardindex(A, i, j) + A.data[i,j] + else + @boundscheck checkbounds(A, i, j) + ifelse(i == j, oneunit(T), zero(T)) + end +end +@propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) + if _shouldforwardindex(A, i, j) + A.data[i,j] + else + @boundscheck checkbounds(A, i, j) + @inbounds diagzero(A,i,j) + end +end _shouldforwardindex(U::UpperTriangular, b::BandIndex) = b.band >= 0 _shouldforwardindex(U::LowerTriangular, b::BandIndex) = b.band <= 0 @@ -250,10 +262,20 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0 # these specialized getindex methods enable constant-propagation of the band Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, b::BandIndex) where {T} - _shouldforwardindex(A, b) ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T)) + if _shouldforwardindex(A, b) + A.data[b] + else + @boundscheck checkbounds(A, b) + ifelse(b.band == 0, oneunit(T), zero(T)) + end end Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, b::BandIndex) - _shouldforwardindex(A, b) ? A.data[b] : diagzero(A.data, b) + if _shouldforwardindex(A, b) + A.data[b] + else + @boundscheck checkbounds(A, b) + @inbounds diagzero(A, b) + end end _zero_triangular_half_str(::Type{<:UpperOrUnitUpperTriangular}) = "lower" @@ -265,14 +287,20 @@ _zero_triangular_half_str(::Type{<:LowerOrUnitLowerTriangular}) = "upper" throw(ArgumentError( lazy"cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)")) end -@noinline function throw_nononeerror(T, @nospecialize(x), i, j) +@noinline function throw_nonuniterror(T, @nospecialize(x), i, j) + check_compatible_type(T, x) Tn = nameof(T) throw(ArgumentError( lazy"cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)")) end +function check_compatible_type(T, @nospecialize(x)) + ET = eltype(T) + convert(ET, x) # check that the types are compatible with setindex! +end @propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer) if i > j + @boundscheck checkbounds(A, i, j) iszero(x) || throw_nonzeroerror(typeof(A), x, i, j) else A.data[i,j] = x @@ -282,9 +310,11 @@ end @propagate_inbounds function setindex!(A::UnitUpperTriangular, x, i::Integer, j::Integer) if i > j + @boundscheck checkbounds(A, i, j) iszero(x) || throw_nonzeroerror(typeof(A), x, i, j) elseif i == j - x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j) + @boundscheck checkbounds(A, i, j) + x == oneunit(eltype(A)) || throw_nonuniterror(typeof(A), x, i, j) else A.data[i,j] = x end @@ -293,6 +323,7 @@ end @propagate_inbounds function setindex!(A::LowerTriangular, x, i::Integer, j::Integer) if i < j + @boundscheck checkbounds(A, i, j) iszero(x) || throw_nonzeroerror(typeof(A), x, i, j) else A.data[i,j] = x @@ -302,9 +333,11 @@ end @propagate_inbounds function setindex!(A::UnitLowerTriangular, x, i::Integer, j::Integer) if i < j + @boundscheck checkbounds(A, i, j) iszero(x) || throw_nonzeroerror(typeof(A), x, i, j) elseif i == j - x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j) + @boundscheck checkbounds(A, i, j) + x == oneunit(eltype(A)) || throw_nonuniterror(typeof(A), x, i, j) else A.data[i,j] = x end @@ -560,7 +593,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un @eval @inline function _copy!(A::$UT, B::$T) for dind in diagind(A, IndexStyle(A)) if A[dind] != B[dind] - throw_nononeerror(typeof(A), B[dind], Tuple(dind)...) + throw_nonuniterror(typeof(A), B[dind], Tuple(dind)...) end end _copy!($T(parent(A)), B) @@ -740,7 +773,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu checksize1(A, B) _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) - @inbounds _modify!(_add, c, A, (j,j)) + @inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j)) for i in firstindex(B.data,1):(j - 1) @inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j)) end @@ -751,7 +784,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang checksize1(A, B) _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) - @inbounds _modify!(_add, c, A, (j,j)) + @inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j)) for i in firstindex(B.data,1):(j - 1) @inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j)) end @@ -782,7 +815,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu checksize1(A, B) _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) - @inbounds _modify!(_add, c, A, (j,j)) + @inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j)) for i in (j + 1):lastindex(B.data,1) @inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j)) end @@ -793,7 +826,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang checksize1(A, B) _iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta) for j in axes(B.data,2) - @inbounds _modify!(_add, c, A, (j,j)) + @inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j)) for i in (j + 1):lastindex(B.data,1) @inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j)) end diff --git a/test/triangular.jl b/test/triangular.jl index c5dca32d..4dd36ded 100644 --- a/test/triangular.jl +++ b/test/triangular.jl @@ -942,6 +942,68 @@ end @test 2\U == 2\M @test U*2 == M*2 @test 2*U == 2*M + + U2 = copy(U) + @test rmul!(U, 1) == U2 + @test lmul!(1, U) == U2 +end + +@testset "indexing checks" begin + P = [1 2; 3 4] + @testset "getindex" begin + U = UnitUpperTriangular(P) + @test_throws BoundsError U[0,0] + @test_throws BoundsError U[1,0] + @test_throws BoundsError U[BandIndex(0,0)] + @test_throws BoundsError U[BandIndex(-1,0)] + + U = UpperTriangular(P) + @test_throws BoundsError U[1,0] + @test_throws BoundsError U[BandIndex(-1,0)] + + L = UnitLowerTriangular(P) + @test_throws BoundsError L[0,0] + @test_throws BoundsError L[0,1] + @test_throws BoundsError U[BandIndex(0,0)] + @test_throws BoundsError U[BandIndex(1,0)] + + L = LowerTriangular(P) + @test_throws BoundsError L[0,1] + @test_throws BoundsError L[BandIndex(1,0)] + end + @testset "setindex!" begin + A = SizedArrays.SizedArray{(2,2)}(P) + M = fill(A, 2, 2) + U = UnitUpperTriangular(M) + @test_throws "Cannot `convert` an object of type $Int" U[1,1] = 1 + L = UnitLowerTriangular(M) + @test_throws "Cannot `convert` an object of type $Int" L[1,1] = 1 + + U = UnitUpperTriangular(P) + @test_throws BoundsError U[0,0] = 1 + @test_throws BoundsError U[1,0] = 0 + + U = UpperTriangular(P) + @test_throws BoundsError U[1,0] = 0 + + L = UnitLowerTriangular(P) + @test_throws BoundsError L[0,0] = 1 + @test_throws BoundsError L[0,1] = 0 + + L = LowerTriangular(P) + @test_throws BoundsError L[0,1] = 0 + end +end + +@testset "unit triangular l/rdiv!" begin + A = rand(3,3) + @testset for (UT,T) in ((UnitUpperTriangular, UpperTriangular), + (UnitLowerTriangular, LowerTriangular)) + UnitTri = UT(A) + Tri = T(LinearAlgebra.full(UnitTri)) + @test 2 \ UnitTri ≈ 2 \ Tri + @test UnitTri / 2 ≈ Tri / 2 + end end end # module TestTriangular