Skip to content

Commit f880a87

Browse files
Generalize batched_vec to support N-D batches (#652)
* Initial plan * Implement generalized batched_vec for N-D batches Co-authored-by: CarloLucibello <[email protected]> * Update src/batched/batchedmul.jl * Update test/batchedmul.jl --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: CarloLucibello <[email protected]> Co-authored-by: Carlo Lucibello <[email protected]>
1 parent b265a0a commit f880a87

File tree

2 files changed

+57
-3
lines changed

2 files changed

+57
-3
lines changed

src/batched/batchedmul.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,18 @@ _semi_batched_mul(A::Transpose{<:Number,<:AbstractMatrix}, B::AbstractArray{<:An
162162
batched_mul(batched_transpose(reshape(parent(A), size(parent(A))..., 1)), B)
163163

164164
"""
165-
batched_vec(A::Array{T,3}, B::Matrix)
166-
batched_vec(A::Array{T,3}, b::Vector)
165+
batched_vec(A::AbstractArray{T,3}, B::AbstractMatrix)
166+
batched_vec(A::AbstractArray{T,3}, b::AbstractVector)
167+
batched_vec(A::AbstractArray, B::AbstractArray)
167168
168-
Batched matrix-vector multiplication:
169+
Batched matrix-vector multiplication. For the 3D case:
169170
the result has `C[:,:,k] == A[:,:,k] * B[:,k]` for all `k`,
170171
or else `C[:,:,k] == A[:,:,k] * b` for `b::Vector`.
171172
173+
For the general N-D case where `ndims(A) == ndims(B) + 1`:
174+
the result has `C[:,k...] == A[:,:,k...] * B[:,k...]` for all batch indices `k...`.
175+
The batch dimensions must match: `size(A)[3:end] == size(B)[2:end]`.
176+
172177
With the same argument types, `batched_mul(A, B)` would regard `B` as
173178
a fixed matrix, not a batch of vectors. Both reshape and then
174179
call `batched_mul(::Array{T,3}, ::Array{T,3})`.
@@ -181,8 +186,27 @@ julia> batched_vec(A,B) |> size
181186
182187
julia> batched_vec(A,b) |> size
183188
(16, 32)
189+
190+
julia> A4d, B3d = randn(16,8,10,32), randn(8,10,32); # 4D and 3D arrays
191+
192+
julia> batched_vec(A4d, B3d) |> size
193+
(16, 10, 32)
184194
```
185195
"""
196+
function batched_vec(A::AbstractArray, B::AbstractArray)
197+
ndims(A) == ndims(B) + 1 || throw(DimensionMismatch(
198+
"batched_vec requires ndims(A) == ndims(B) + 1, got ndims(A)=$(ndims(A)) and ndims(B)=$(ndims(B))"))
199+
size(A)[3:end] == size(B)[2:end] || throw(DimensionMismatch(
200+
"batch dimensions must match: size(A)[3:end]=$(size(A)[3:end]) != size(B)[2:end]=$(size(B)[2:end])"))
201+
202+
# Reshape B to add a singleton dimension for matrix multiplication
203+
B_reshaped = reshape(B, size(B, 1), 1, size(B)[2:end]...)
204+
# Perform batched multiplication
205+
C = batched_mul(A, B_reshaped)
206+
# Remove the singleton dimension
207+
return dropdims(C, dims=2)
208+
end
209+
186210
batched_vec(A::AbstractArray{T,3} where T, B::AbstractMatrix) =
187211
reshape(batched_mul(A, reshape(B, size(B,1), 1, size(B,2))), size(A,1), size(A,3))
188212

test/batchedmul.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,33 @@ FiniteDifferences.to_vec(x::BatchedTranspose) = FiniteDifferences.to_vec(collect
303303

304304
gradtest(batched_vec, randn(rng, M, P, B), randn(rng, P))
305305
end
306+
307+
@testset "batched_vec: N-D batches" begin
308+
# Test 4D case: A is 4D, B is 3D
309+
A4d = randn(4, 5, 3, 2) # (matrix_rows, matrix_cols, batch_dim1, batch_dim2)
310+
B3d = randn(5, 3, 2) # (vector_length, batch_dim1, batch_dim2)
311+
312+
C = batched_vec(A4d, B3d)
313+
@test size(C) == (4, 3, 2)
314+
315+
# Manual verification
316+
for i in 1:3, j in 1:2
317+
@test C[:, i, j] A4d[:, :, i, j] * B3d[:, i, j]
318+
end
319+
320+
# Test 5D case: A is 5D, B is 4D
321+
A5d = randn(3, 4, 2, 3, 2) # (matrix_rows, matrix_cols, batch1, batch2, batch3)
322+
B4d = randn(4, 2, 3, 2) # (vector_length, batch1, batch2, batch3)
323+
324+
C5 = batched_vec(A5d, B4d)
325+
@test size(C5) == (3, 2, 3, 2)
326+
327+
# Manual verification for a few cases
328+
@test C5[:, 1, 1, 1] A5d[:, :, 1, 1, 1] * B4d[:, 1, 1, 1]
329+
@test C5[:, 2, 3, 2] A5d[:, :, 2, 3, 2] * B4d[:, 2, 3, 2]
330+
331+
# Test dimension mismatch errors
332+
@test_throws DimensionMismatch batched_vec(randn(3, 4, 2), randn(4, 3)) # ndims mismatch
333+
@test_throws DimensionMismatch batched_vec(randn(3, 4, 2, 3), randn(4, 2, 2)) # batch size mismatch
334+
335+
end

0 commit comments

Comments
 (0)