Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disambiguate structured and abstract matrix multiplication #52464

Merged
merged 9 commits into from
Jan 6, 2024
11 changes: 11 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,17 @@ _makevector(x::AbstractVector) = Vector(x)
_pushzero(A) = (B = similar(A, length(A)+1); @inbounds B[begin:end-1] .= A; @inbounds B[end] = zero(eltype(B)); B)
_droplast!(A) = deleteat!(A, lastindex(A))

# destination type for matmul
matprod_dest(A::StructuredMatrix, B::StructuredMatrix, TS) = similar(B, TS, size(B))
matprod_dest(A, B::StructuredMatrix, TS) = similar(A, TS, size(A))
matprod_dest(A::StructuredMatrix, B, TS) = similar(B, TS, size(B))
matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = similar(A, TS)
matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = similar(B, TS)
matprod_dest(A::Diagonal, B::Diagonal, TS) = similar(B, TS)
matprod_dest(A::HermOrSym, B::Diagonal, TS) = similar(A, TS, size(A))
matprod_dest(A::Diagonal, B::HermOrSym, TS) = similar(B, TS, size(B))

# TODO: remove once not used anymore in SparseArrays.jl
# some trait like this would be cool
# onedefined(::Type{T}) where {T} = hasmethod(one, (T,))
# but we are actually asking for oneunit(T), that is, however, defined for generic T as
Expand Down
26 changes: 14 additions & 12 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -802,34 +802,35 @@ ldiv!(c::AbstractVecOrMat, A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMa
(t = wrapperop(A); _rdiv!(t(c), t(b), t(A)); return c)

### Generic promotion methods and fallbacks
\(A::Bidiagonal, B::AbstractVecOrMat) = ldiv!(_initarray(\, eltype(A), eltype(B), B), A, B)
\(A::Bidiagonal, B::AbstractVecOrMat) =
ldiv!(matprod_dest(A, B, promote_op(\, eltype(A), eltype(B))), A, B)
\(xA::AdjOrTrans{<:Any,<:Bidiagonal}, B::AbstractVecOrMat) = copy(xA) \ B

### Triangular specializations
for tri in (:UpperTriangular, :UnitUpperTriangular)
@eval function \(B::Bidiagonal, U::$tri)
A = ldiv!(_initarray(\, eltype(B), eltype(U), U), B, U)
A = ldiv!(matprod_dest(B, U, promote_op(\, eltype(B), eltype(U))), B, U)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
@eval function \(U::$tri, B::Bidiagonal)
A = ldiv!(_initarray(\, eltype(U), eltype(B), U), U, B)
A = ldiv!(matprod_dest(U, B, promote_op(\, eltype(U), eltype(B))), U, B)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
end
for tri in (:LowerTriangular, :UnitLowerTriangular)
@eval function \(B::Bidiagonal, L::$tri)
A = ldiv!(_initarray(\, eltype(B), eltype(L), L), B, L)
A = ldiv!(matprod_dest(B, L, promote_op(\, eltype(B), eltype(L))), B, L)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
@eval function \(L::$tri, B::Bidiagonal)
A = ldiv!(_initarray(\, eltype(L), eltype(B), L), L, B)
A = ldiv!(matprod_dest(L, B, promote_op(\, eltype(L), eltype(B))), L, B)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
end

### Diagonal specialization
function \(B::Bidiagonal, D::Diagonal)
A = ldiv!(_initarray(\, eltype(B), eltype(D), D), B, D)
A = ldiv!(similar(D, promote_op(\, eltype(B), eltype(D)), size(D)), B, D)
return B.uplo == 'U' ? UpperTriangular(A) : LowerTriangular(A)
end

Expand Down Expand Up @@ -879,33 +880,34 @@ rdiv!(A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) = @inline _rdiv!(A,
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) =
(t = wrapperop(B); ldiv!(t(C), t(B), t(A)); return C)

/(A::AbstractMatrix, B::Bidiagonal) = _rdiv!(_initarray(/, eltype(A), eltype(B), A), A, B)
/(A::AbstractMatrix, B::Bidiagonal) =
_rdiv!(similar(A, promote_op(/, eltype(A), eltype(B)), size(A)), A, B)

### Triangular specializations
for tri in (:UpperTriangular, :UnitUpperTriangular)
@eval function /(U::$tri, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(U), eltype(B), U), U, B)
A = _rdiv!(matprod_dest(U, B, promote_op(/, eltype(U), eltype(B))), U, B)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
@eval function /(B::Bidiagonal, U::$tri)
A = _rdiv!(_initarray(/, eltype(B), eltype(U), U), B, U)
A = _rdiv!(matprod_dest(B, U, promote_op(/, eltype(B), eltype(U))), B, U)
return B.uplo == 'U' ? UpperTriangular(A) : A
end
end
for tri in (:LowerTriangular, :UnitLowerTriangular)
@eval function /(L::$tri, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(L), eltype(B), L), L, B)
A = _rdiv!(matprod_dest(L, B, promote_op(/, eltype(L), eltype(B))), L, B)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
@eval function /(B::Bidiagonal, L::$tri)
A = _rdiv!(_initarray(/, eltype(B), eltype(L), L), B, L)
A = _rdiv!(matprod_dest(B, L, promote_op(/, eltype(B), eltype(L))), B, L)
return B.uplo == 'L' ? LowerTriangular(A) : A
end
end

### Diagonal specialization
function /(D::Diagonal, B::Bidiagonal)
A = _rdiv!(_initarray(/, eltype(D), eltype(B), D), D, B)
A = _rdiv!(similar(D, promote_op(/, eltype(D), eltype(B)), size(D)), D, B)
return B.uplo == 'U' ? UpperTriangular(A) : LowerTriangular(A)
end

Expand Down
26 changes: 7 additions & 19 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,15 +310,6 @@ function (*)(D::Diagonal, V::AbstractVector)
return D.diag .* V
end

(*)(A::AbstractMatrix, D::Diagonal) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), A, D)
(*)(A::HermOrSym, D::Diagonal) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D)
(*)(D::Diagonal, A::AbstractMatrix) =
mul!(similar(A, promote_op(*, eltype(D.diag), eltype(A))), D, A)
(*)(D::Diagonal, A::HermOrSym) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)

rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)

Expand Down Expand Up @@ -431,8 +422,8 @@ function (*)(Da::Diagonal, Db::Diagonal, Dc::Diagonal)
return Diagonal(Da.diag .* Db.diag .* Dc.diag)
end

/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(similar(A, _init_eltype(/, eltype(A), eltype(D))), A, D)
/(A::HermOrSym, D::Diagonal) = _rdiv!(similar(A, _init_eltype(/, eltype(A), eltype(D)), size(A)), A, D)
/(A::AbstractVecOrMat, D::Diagonal) = _rdiv!(matprod_dest(A, D, promote_op(/, eltype(A), eltype(D))), A, D)
/(A::HermOrSym, D::Diagonal) = _rdiv!(matprod_dest(A, D, promote_op(/, eltype(A), eltype(D))), A, D)

rdiv!(A::AbstractVecOrMat, D::Diagonal) = @inline _rdiv!(A, A, D)
# avoid copy when possible via internal 3-arg backend
Expand All @@ -458,8 +449,8 @@ function \(D::Diagonal, B::AbstractVector)
isnothing(j) || throw(SingularException(j))
return D.diag .\ B
end
\(D::Diagonal, B::AbstractMatrix) = ldiv!(similar(B, _init_eltype(\, eltype(D), eltype(B))), D, B)
\(D::Diagonal, B::HermOrSym) = ldiv!(similar(B, _init_eltype(\, eltype(D), eltype(B)), size(B)), D, B)
\(D::Diagonal, B::AbstractMatrix) = ldiv!(matprod_dest(D, B, promote_op(\, eltype(D), eltype(B))), D, B)
\(D::Diagonal, B::HermOrSym) = ldiv!(matprod_dest(D, B, promote_op(\, eltype(D), eltype(B))), D, B)

ldiv!(D::Diagonal, B::AbstractVecOrMat) = @inline ldiv!(B, D, B)
function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat)
Expand All @@ -479,8 +470,8 @@ function ldiv!(B::AbstractVecOrMat, D::Diagonal, A::AbstractVecOrMat)
end

# Optimizations for \, / between Diagonals
\(D::Diagonal, B::Diagonal) = ldiv!(similar(B, promote_op(\, eltype(D), eltype(B))), D, B)
/(A::Diagonal, D::Diagonal) = _rdiv!(similar(A, promote_op(/, eltype(A), eltype(D))), A, D)
\(D::Diagonal, B::Diagonal) = ldiv!(matprod_dest(D, B, promote_op(\, eltype(D), eltype(B))), D, B)
/(A::Diagonal, D::Diagonal) = _rdiv!(matprod_dest(A, D, promote_op(/, eltype(A), eltype(D))), A, D)
function _rdiv!(Dc::Diagonal, Db::Diagonal, Da::Diagonal)
n, k = length(Db.diag), length(Da.diag)
n == k || throw(DimensionMismatch("left hand side has $n columns but D is $k by $k"))
Expand Down Expand Up @@ -543,7 +534,7 @@ function (/)(S::SymTridiagonal, D::Diagonal)
dl = similar(S.ev, T, max(length(S.dv)-1, 0))
_rdiv!(Tridiagonal(dl, d, du), S, D)
end
(/)(T::Tridiagonal, D::Diagonal) = _rdiv!(similar(T, promote_op(/, eltype(T), eltype(D))), T, D)
(/)(T::Tridiagonal, D::Diagonal) = _rdiv!(matprod_dest(T, D, promote_op(/, eltype(T), eltype(D))), T, D)
function _rdiv!(T::Tridiagonal, S::Union{SymTridiagonal,Tridiagonal}, D::Diagonal)
n = size(S, 2)
dd = D.diag
Expand Down Expand Up @@ -876,9 +867,6 @@ function svd(D::Diagonal{T}) where {T<:Number}
return SVD(U, S, Vt)
end

# disambiguation methods: * and / of Diagonal and Adj/Trans AbsVec
*(u::AdjointAbsVec, D::Diagonal) = (D'u')'
*(u::TransposeAbsVec, D::Diagonal) = transpose(transpose(D) * transpose(u))
*(x::AdjointAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
*(x::TransposeAbsVec, D::Diagonal, y::AbstractVector) = _mapreduce_prod(*, x, D, y)
/(u::AdjointAbsVec, D::Diagonal) = (D' \ u')'
Expand Down
12 changes: 6 additions & 6 deletions stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,29 +133,29 @@ for T = (:Number, :UniformScaling, :Diagonal)
end

function *(H::UpperHessenberg, U::UpperOrUnitUpperTriangular)
HH = mul!(_initarray(*, eltype(H), eltype(U), H), H, U)
HH = mul!(matprod_dest(H, U, promote_op(matprod, eltype(H), eltype(U))), H, U)
UpperHessenberg(HH)
end
function *(U::UpperOrUnitUpperTriangular, H::UpperHessenberg)
HH = mul!(_initarray(*, eltype(U), eltype(H), H), U, H)
HH = mul!(matprod_dest(U, H, promote_op(matprod, eltype(U), eltype(H))), U, H)
UpperHessenberg(HH)
end

function /(H::UpperHessenberg, U::UpperTriangular)
HH = _rdiv!(_initarray(/, eltype(H), eltype(U), H), H, U)
HH = _rdiv!(matprod_dest(H, U, promote_op(/, eltype(H), eltype(U))), H, U)
UpperHessenberg(HH)
end
function /(H::UpperHessenberg, U::UnitUpperTriangular)
HH = _rdiv!(_initarray(/, eltype(H), eltype(U), H), H, U)
HH = _rdiv!(matprod_dest(H, U, promote_op(/, eltype(H), eltype(U))), H, U)
UpperHessenberg(HH)
end

function \(U::UpperTriangular, H::UpperHessenberg)
HH = ldiv!(_initarray(\, eltype(U), eltype(H), H), U, H)
HH = ldiv!(matprod_dest(U, H, promote_op(\, eltype(U), eltype(H))), U, H)
UpperHessenberg(HH)
end
function \(U::UnitUpperTriangular, H::UpperHessenberg)
HH = ldiv!(_initarray(\, eltype(U), eltype(H), H), U, H)
HH = ldiv!(matprod_dest(U, H, promote_op(\, eltype(U), eltype(H))), U, H)
UpperHessenberg(HH)
end

Expand Down
5 changes: 4 additions & 1 deletion stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,11 @@ julia> [1 1; 0 1] * [1 0; 1 1]
"""
function (*)(A::AbstractMatrix, B::AbstractMatrix)
TS = promote_op(matprod, eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))), A, B)
mul!(matprod_dest(A, B, TS), A, B)
end

matprod_dest(A, B, TS) = similar(B, TS, (size(A, 1), size(B, 2)))

# optimization for dispatching to BLAS, e.g. *(::Matrix{Float32}, ::Matrix{Float64})
# but avoiding the case *(::Matrix{<:BlasComplex}, ::Matrix{<:BlasReal})
# which is better handled by reinterpreting rather than promotion
Expand Down
29 changes: 4 additions & 25 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1545,22 +1545,11 @@ rmul!(A::LowerTriangular, B::UnitLowerTriangular) = LowerTriangular(rmul!(tril!(
## necessary in the general triangular solve problem.

_inner_type_promotion(op, ::Type{TA}, ::Type{TB}) where {TA<:Integer,TB<:Integer} =
_init_eltype(*, TA, TB)
promote_op(matprod, TA, TB)
_inner_type_promotion(op, ::Type{TA}, ::Type{TB}) where {TA,TB} =
_init_eltype(op, TA, TB)
promote_op(op, TA, TB)
## The general promotion methods
function *(A::AbstractTriangular, B::AbstractTriangular)
TAB = _init_eltype(*, eltype(A), eltype(B))
mul!(similar(B, TAB, size(B)), A, B)
end

for mat in (:AbstractVector, :AbstractMatrix)
### Multiplication with triangle to the left and hence rhs cannot be transposed.
@eval function *(A::AbstractTriangular, B::$mat)
require_one_based_indexing(B)
TAB = _init_eltype(*, eltype(A), eltype(B))
mul!(similar(B, TAB, size(B)), A, B)
end
### Left division with triangle to the left hence rhs cannot be transposed. No quotients.
@eval function \(A::Union{UnitUpperTriangular,UnitLowerTriangular}, B::$mat)
require_one_based_indexing(B)
Expand All @@ -1570,7 +1559,7 @@ for mat in (:AbstractVector, :AbstractMatrix)
### Left division with triangle to the left hence rhs cannot be transposed. Quotients.
@eval function \(A::Union{UpperTriangular,LowerTriangular}, B::$mat)
require_one_based_indexing(B)
TAB = _init_eltype(\, eltype(A), eltype(B))
TAB = promote_op(\, eltype(A), eltype(B))
ldiv!(similar(B, TAB, size(B)), A, B)
end
### Right division with triangle to the right hence lhs cannot be transposed. No quotients.
Expand All @@ -1582,20 +1571,10 @@ for mat in (:AbstractVector, :AbstractMatrix)
### Right division with triangle to the right hence lhs cannot be transposed. Quotients.
@eval function /(A::$mat, B::Union{UpperTriangular,LowerTriangular})
require_one_based_indexing(A)
TAB = _init_eltype(/, eltype(A), eltype(B))
TAB = promote_op(/, eltype(A), eltype(B))
_rdiv!(similar(A, TAB, size(A)), A, B)
end
end
### Multiplication with triangle to the right and hence lhs cannot be transposed.
# Only for AbstractMatrix, hence outside the above loop.
function *(A::AbstractMatrix, B::AbstractTriangular)
require_one_based_indexing(A)
TAB = _init_eltype(*, eltype(A), eltype(B))
mul!(similar(A, TAB, size(A)), A, B)
end
# ambiguity resolution with definitions in matmul.jl
*(v::AdjointAbsVec, A::AbstractTriangular) = adjoint(adjoint(A) * v.parent)
*(v::TransposeAbsVec, A::AbstractTriangular) = transpose(transpose(A) * v.parent)

## Some Triangular-Triangular cases. We might want to write tailored methods
## for these cases, but I'm not sure it is worth it.
Expand Down
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,14 @@ end
end
end

@testset "avoid matmul ambiguities with ::MyMatrix * ::AbstractMatrix" begin
A = [i+j for i in 1:2, j in 1:2]
S = SizedArrays.SizedArray{(2,2)}(A)
D = Diagonal([1:2;])
@test S * D == A * D
@test D * S == D * A
end

@testset "copy" begin
@test copy(Diagonal(1:5)) === Diagonal(1:5)
end
Expand Down
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,14 @@ end
end
end

@testset "avoid matmul ambiguities with ::MyMatrix * ::AbstractMatrix" begin
A = [i+j for i in 1:2, j in 1:2]
S = SizedArrays.SizedArray{(2,2)}(A)
U = UpperTriangular(ones(2,2))
@test S * U == A * U
@test U * S == U * A
end

@testset "custom axes" begin
SZA = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
Expand Down
5 changes: 5 additions & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,9 @@ function *(S1::SizedArrayLike, S2::SizedArrayLike)
SZ = ndims(data) == 1 ? (size(S1, 1), ) : (size(S1, 1), size(S2, 2))
SizedArray{SZ}(data)
end

# deliberately wide method definition to ensure that this doesn't lead to ambiguities with
# structured matrices
*(S1::SizedArrayLike, M::AbstractMatrix) = _data(S1) * M

end