Skip to content

Commit 416ed9f

Browse files
committed
wgpu: resolve merge; import stride::contiguous_strides; format
1 parent 1c82538 commit 416ed9f

2 files changed

Lines changed: 16 additions & 22 deletions

File tree

crates/cubecl-wgpu/src/compute/server.rs

Lines changed: 8 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ use cubecl_core::{
1616
};
1717
use cubecl_runtime::logging::ServerLogger;
1818
use cubecl_runtime::memory_management::offset_handles;
19+
use cubecl_runtime::stride::contiguous_strides;
1920
use cubecl_runtime::{
2021
memory_management::MemoryDeviceProperties, server::ComputeServer, storage::BindingResource,
2122
};
@@ -166,15 +167,14 @@ impl ComputeServer for WgpuServer {
166167
continue;
167168
}
168169

169-
// 2D pitched rows: support rank==2 with inner-most contiguous dimension.
170-
// Note: contiguous path unchanged; pitched path uses per-row queue.write_buffer with small overhead.
171-
let shape = desc.shape;
172-
if shape.len() == 2 && desc.strides[1] == 1 {
173-
let rows = shape[0] as u64;
174-
let cols = shape[1] as u64;
170+
// Inner-contiguous pitched rows: rank>=2, inner-most contiguous
171+
if desc.shape.len() >= 2 && desc.strides[desc.shape.len() - 1] == 1 {
172+
let last = desc.shape.len() - 1;
173+
let rows = desc.shape[..last].iter().product::<usize>() as u64;
174+
let cols = desc.shape[last] as u64;
175175
let elem = desc.elem_size as u64;
176176
let row_bytes = cols * elem;
177-
let row_pitch = desc.strides[0] as u64 * elem;
177+
let row_pitch = desc.strides[last - 1] as u64 * elem;
178178

179179
let resource = self.stream.mem_manage.get_resource(desc.binding);
180180
self.stream
@@ -245,11 +245,4 @@ fn compiler(backend: wgpu::Backend) -> AutoCompiler {
245245
}
246246
}
247247

248-
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
249-
let rank = shape.len();
250-
let mut strides = vec![1; rank];
251-
for i in (0..rank - 1).rev() {
252-
strides[i] = strides[i + 1] * shape[i + 1];
253-
}
254-
strides
255-
}
248+
// Note: use cubecl_runtime::stride::contiguous_strides for canonical row-major strides.

crates/cubecl-wgpu/src/compute/stream.rs

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -189,8 +189,8 @@ impl WgpuStream {
189189
let resource = self.mem_manage.get_resource(binding);
190190

191191
// Contiguous path: copy entire resource range
192-
let is_contiguous = super::super::compute::server::contiguous_strides(descriptor.shape)
193-
== descriptor.strides;
192+
let is_contiguous =
193+
cubecl_runtime::stride::contiguous_strides(descriptor.shape) == descriptor.strides;
194194

195195
if is_contiguous {
196196
let size = descriptor.shape.iter().product::<usize>() * descriptor.elem_size;
@@ -214,12 +214,13 @@ impl WgpuStream {
214214
continue;
215215
}
216216

217-
// 2D pitched rows: rank==2, innermost contiguous
218-
if descriptor.shape.len() == 2 && descriptor.strides[1] == 1 {
219-
let rows = descriptor.shape[0] as u64;
220-
let cols = descriptor.shape[1] as u64;
217+
// Inner-contiguous pitched rows: rank>=2, innermost contiguous
218+
if descriptor.shape.len() >= 2 && descriptor.strides[descriptor.shape.len() - 1] == 1 {
219+
let last = descriptor.shape.len() - 1;
220+
let rows = descriptor.shape[..last].iter().product::<usize>() as u64;
221+
let cols = descriptor.shape[last] as u64;
221222
let row_bytes = cols * elem;
222-
let row_pitch = descriptor.strides[0] as u64 * elem;
223+
let row_pitch = descriptor.strides[last - 1] as u64 * elem;
223224
let total = rows * row_pitch;
224225
let align = wgpu::COPY_BUFFER_ALIGNMENT;
225226
let aligned_total = total.div_ceil(align) * align;

0 commit comments

Comments
 (0)