@@ -16,6 +16,7 @@ use cubecl_core::{
1616} ;
1717use cubecl_runtime:: logging:: ServerLogger ;
1818use cubecl_runtime:: memory_management:: offset_handles;
19+ use cubecl_runtime:: stride:: contiguous_strides;
1920use cubecl_runtime:: {
2021 memory_management:: MemoryDeviceProperties , server:: ComputeServer , storage:: BindingResource ,
2122} ;
@@ -166,15 +167,14 @@ impl ComputeServer for WgpuServer {
166167 continue ;
167168 }
168169
169- // 2D pitched rows: support rank==2 with inner-most contiguous dimension.
170- // Note: contiguous path unchanged; pitched path uses per-row queue.write_buffer with small overhead.
171- let shape = desc. shape ;
172- if shape. len ( ) == 2 && desc. strides [ 1 ] == 1 {
173- let rows = shape[ 0 ] as u64 ;
174- let cols = shape[ 1 ] as u64 ;
170+ // Pitched rows with a contiguous inner-most dimension: rank >= 2 and stride 1 on the last axis.
171+ if desc. shape . len ( ) >= 2 && desc. strides [ desc. shape . len ( ) - 1 ] == 1 {
172+ let last = desc. shape . len ( ) - 1 ;
173+ let rows = desc. shape [ ..last] . iter ( ) . product :: < usize > ( ) as u64 ;
174+ let cols = desc. shape [ last] as u64 ;
175175 let elem = desc. elem_size as u64 ;
176176 let row_bytes = cols * elem;
177- let row_pitch = desc. strides [ 0 ] as u64 * elem;
177+ let row_pitch = desc. strides [ last - 1 ] as u64 * elem;
178178
179179 let resource = self . stream . mem_manage . get_resource ( desc. binding ) ;
180180 self . stream
@@ -245,11 +245,4 @@ fn compiler(backend: wgpu::Backend) -> AutoCompiler {
245245 }
246246}
247247
/// Computes canonical row-major strides, in elements, for `shape`.
///
/// The inner-most dimension gets stride 1; every outer stride is the stride of the
/// next-inner dimension multiplied by that dimension's extent. An empty (rank-0)
/// shape yields an empty vector instead of underflowing on `rank - 1`.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // Walk from the second-to-last axis outwards. `saturating_sub` makes the range
    // empty for rank 0 and rank 1, where there is nothing left to compute.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}
248+ // Note: use cubecl_runtime::stride::contiguous_strides for canonical row-major strides.
0 commit comments