@@ -49,13 +49,32 @@ function gather_kernel!(dst, src, idx::CUDA.CuDeviceArray{<:CartesianIndex}, max
4949 return nothing
5050end
5151
"""
    checkbounds_src(src, dims::Int, ::Type{<:Any})

Return a predicate `i -> Bool` that reports whether indexing `src` with `dims`
leading colons followed by the splatted index `i` is in bounds.

This method covers index elements that can be splatted into trailing indices
(for example an integer or a tuple of integers).

# Arguments
- `src`: the array whose bounds are checked.
- `dims`: number of leading dimensions that are covered by `Colon()`.
- third argument: the element type of the index collection, used only for dispatch.

# Throws
- `ArgumentError` (from `ntuple`) if `dims` is negative.
"""
function checkbounds_src(src, dims::Int, ::Type{<:Any})
    # The colon prefix does not depend on the index being tested, so build it
    # once here instead of on every call of the returned predicate.
    leading = ntuple(_ -> Colon(), dims)
    return index -> checkbounds(Bool, src, leading..., index...)
end
55+
"""
    checkbounds_src(src, dims::Int, ::Type{<:CartesianIndex})

Return a predicate `i -> Bool` that reports whether indexing `src` with `dims`
leading colons followed by the index `i` is in bounds.

This method is selected when the index element type is a `CartesianIndex`;
the index is passed to `checkbounds` as a single argument rather than splatted.

# Arguments
- `src`: the array whose bounds are checked.
- `dims`: number of leading dimensions that are covered by `Colon()`.
- third argument: the element type of the index collection, used only for dispatch.

# Throws
- `ArgumentError` (from `ntuple`) if `dims` is negative.
"""
function checkbounds_src(src, dims::Int, ::Type{<:CartesianIndex})
    # The colon prefix does not depend on the index being tested, so build it
    # once here instead of on every call of the returned predicate.
    leading = ntuple(_ -> Colon(), dims)
    return index -> checkbounds(Bool, src, leading..., index)
end
59+
5260function NNlib. gather! (dst:: AnyCuArray , src:: AnyCuArray , idx:: AnyCuArray )
61+ # check dims
5362 dims = gather_check_dims (src, dst, idx)
5463 dims_size = size (src)[1 : dims]
5564 max_dims_idx = prod (dims_size)
5665 max_idx = max_dims_idx * length (idx)
57- args = dst, src, idx, max_idx, max_dims_idx, dims_size
5866
67+ # check bounds
68+ chk = checkbounds_src (src, dims, eltype (idx))
69+ in_bnd = map (chk, collect (idx))
70+ if ! all (in_bnd)
71+ j = findfirst (i -> ! i, in_bnd)
72+ k = CUDA. @allowscalar idx[j]
73+ throw (BoundsError (src, k))
74+ end
75+
76+ # cuda kernel
77+ args = dst, src, idx, max_idx, max_dims_idx, dims_size
5978 kernel = @cuda launch= false gather_kernel! (args... )
6079 config = launch_configuration (kernel. fun; max_threads= 256 )
6180 threads = min (max_idx, config. threads)
0 commit comments