diff --git a/Cargo.toml b/Cargo.toml index 96caa18e2..7a2d16eb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,7 +66,11 @@ variadics_please = "1" dashmap = "6.1.0" foldhash = { version = "0.2", default-features = false } hashbrown = "0.16" -smallvec = { version = "1", features = ["union", "const_generics"] } +smallvec = { version = "1", features = [ + "union", + "const_generics", + "const_new", +] } spin = { version = "0.10.0", features = ["mutex", "spin_mutex"] } xxhash-rust = { version = "0.8", default-features = false } diff --git a/crates/cubecl-core/Cargo.toml b/crates/cubecl-core/Cargo.toml index e7479743e..c0f06e961 100644 --- a/crates/cubecl-core/Cargo.toml +++ b/crates/cubecl-core/Cargo.toml @@ -29,8 +29,8 @@ tracing = ["dep:tracing", "cubecl-runtime/tracing", "cubecl-ir/tracing"] [dependencies] # Runtime Deps log = { workspace = true } -tracing = { workspace = true, features = ["attributes"], optional = true } test-log = { workspace = true, optional = true } +tracing = { workspace = true, features = ["attributes"], optional = true } cubecl-ir = { path = "../cubecl-ir", version = "=0.10.0-pre.1", default-features = false, features = [ "serde", @@ -63,6 +63,6 @@ tempfile = { version = "3.20", optional = true } variadics_please = { workspace = true } [dev-dependencies] -test-log = { workspace = true, features = ["trace"] } pretty_assertions = { workspace = true } +test-log = { workspace = true, features = ["trace"] } trybuild = "1" diff --git a/crates/cubecl-core/src/frontend/container/sequence/launch.rs b/crates/cubecl-core/src/frontend/container/sequence/launch.rs index d85674771..3e18617fd 100644 --- a/crates/cubecl-core/src/frontend/container/sequence/launch.rs +++ b/crates/cubecl-core/src/frontend/container/sequence/launch.rs @@ -1,6 +1,7 @@ use std::{cell::RefCell, rc::Rc}; use cubecl_runtime::runtime::Runtime; +use cubecl_zspace::SmallVec; use crate::{ compute::KernelBuilder, @@ -29,7 +30,7 @@ impl<'a, R: Runtime, T: LaunchArg> SequenceArg<'a, R, T> 
{ } pub struct SequenceCompilationArg { - pub values: Vec, + pub values: SmallVec<[C::CompilationArg; 5]>, } impl CompilationArg for SequenceCompilationArg {} diff --git a/crates/cubecl-core/src/frontend/container/tensor/tensormap.rs b/crates/cubecl-core/src/frontend/container/tensor/tensormap.rs index 49d6fd681..81c7e767f 100644 --- a/crates/cubecl-core/src/frontend/container/tensor/tensormap.rs +++ b/crates/cubecl-core/src/frontend/container/tensor/tensormap.rs @@ -5,6 +5,7 @@ use crate::ir::ExpandElement; use crate::{prelude::*, unexpanded}; use cubecl_ir::{LineSize, StorageType, Type}; use cubecl_runtime::server::TensorMapMeta; +use cubecl_zspace::{Strides, metadata::Metadata, strides}; use paste::paste; use serde::{Deserialize, Serialize}; @@ -74,10 +75,8 @@ impl<'a, R: Runtime, K: TensorMapKind> TensorMapArg<'a, R, K> { Self { metadata: TensorMapMeta { format: K::as_format(args), - rank, - shape: handle.shape.to_vec(), - strides: handle.strides.to_vec(), - elem_stride: vec![1; rank], + metadata: Metadata::new(handle.shape, handle.strides), + elem_stride: strides![1; rank], interleave: TensorMapInterleave::None, swizzle: TensorMapSwizzle::None, prefetch: TensorMapPrefetch::None, @@ -89,7 +88,7 @@ impl<'a, R: Runtime, K: TensorMapKind> TensorMapArg<'a, R, K> { } } - pub fn with_elem_stride(mut self, elem_stride: Vec) -> Self { + pub fn with_elem_stride(mut self, elem_stride: Strides) -> Self { self.metadata.elem_stride = elem_stride; self } diff --git a/crates/cubecl-core/src/lib.rs b/crates/cubecl-core/src/lib.rs index cd3e8d650..141d8c377 100644 --- a/crates/cubecl-core/src/lib.rs +++ b/crates/cubecl-core/src/lib.rs @@ -101,12 +101,12 @@ pub enum LineSizeError { /// is divisible by the vectorization. /// The last condition ensure that the current axis is contiguous within the next stride. 
pub fn tensor_line_size_parallel( - supported_line_sizes: impl Iterator, + optimized_line_sizes: impl Iterator, shape: &[usize], strides: &[usize], axis: usize, ) -> LineSize { - try_tensor_line_size_parallel(supported_line_sizes, shape, strides, axis).unwrap_or(1) + try_tensor_line_size_parallel(optimized_line_sizes, shape, strides, axis).unwrap_or(1) } /// Like `try_tensor_line_size_parallel` but does not assume 1 is supported diff --git a/crates/cubecl-core/src/runtime_tests/line.rs b/crates/cubecl-core/src/runtime_tests/line.rs index 2bbca267b..3677225e1 100644 --- a/crates/cubecl-core/src/runtime_tests/line.rs +++ b/crates/cubecl-core/src/runtime_tests/line.rs @@ -13,7 +13,7 @@ pub fn kernel_line_index(output: &mut Array, #[comptime] line_size: #[allow(clippy::needless_range_loop)] pub fn test_line_index(client: ComputeClient) { - for line_size in client.io_optimized_line_sizes(&F::as_type_native().unwrap()) { + for line_size in client.io_optimized_line_sizes(size_of::()) { if line_size < 4 { continue; } @@ -50,7 +50,7 @@ pub fn kernel_line_index_assign(output: &mut Array>) { } pub fn test_line_index_assign(client: ComputeClient) { - for line_size in client.io_optimized_line_sizes(&F::as_type_native().unwrap()) { + for line_size in client.io_optimized_line_sizes(size_of::()) { let handle = client.create_from_slice(F::as_bytes(&vec![F::new(0.0); line_size])); unsafe { kernel_line_index_assign::launch_unchecked::( @@ -88,7 +88,7 @@ pub fn kernel_line_loop_unroll( } pub fn test_line_loop_unroll(client: ComputeClient) { - for line_size in client.io_optimized_line_sizes(&F::as_type_native_unchecked()) { + for line_size in client.io_optimized_line_sizes(size_of::()) { let handle = client.create_from_slice(F::as_bytes(&vec![F::new(0.0); line_size])); unsafe { kernel_line_loop_unroll::launch_unchecked::( @@ -171,7 +171,7 @@ pub fn kernel_shared_memory(output: &mut Array>) { } pub fn test_shared_memory(client: ComputeClient) { - for line_size in 
client.io_optimized_line_sizes(&F::as_type_native().unwrap()) { + for line_size in client.io_optimized_line_sizes(size_of::()) { let output = client.create_from_slice(F::as_bytes(&vec![F::new(0.0); line_size])); unsafe { kernel_shared_memory::launch_unchecked::( diff --git a/crates/cubecl-core/src/runtime_tests/tensor.rs b/crates/cubecl-core/src/runtime_tests/tensor.rs index 2ff195323..47ea79cc6 100644 --- a/crates/cubecl-core/src/runtime_tests/tensor.rs +++ b/crates/cubecl-core/src/runtime_tests/tensor.rs @@ -27,7 +27,7 @@ pub fn test_tensor_coordinate(client: ComputeClient) { let output_size = shape.len() * input_size; // The result is independent of the line size - for &line_size in R::supported_line_sizes() { + for line_size in client.io_optimized_line_sizes(size_of::()) { let output = client.empty(core::mem::size_of::() * output_size); unsafe { tensor_coordinate::launch( diff --git a/crates/cubecl-core/src/runtime_tests/tensormap.rs b/crates/cubecl-core/src/runtime_tests/tensormap.rs index 6230577b0..1669c7b1a 100644 --- a/crates/cubecl-core/src/runtime_tests/tensormap.rs +++ b/crates/cubecl-core/src/runtime_tests/tensormap.rs @@ -5,6 +5,7 @@ use cubecl_runtime::{ server::{Allocation, ComputeServer, CopyDescriptor}, storage::ComputeStorage, }; +use cubecl_zspace::shape; use std::fmt::Debug; #[cube(launch)] @@ -129,7 +130,7 @@ where CubeDim::new_2d(32, 16), TensorMapArg::new( TiledArgs { - tile_size: vec![16, 32], + tile_size: shape![16, 32], }, input, F::as_type_native_unchecked(), @@ -173,7 +174,7 @@ where unsafe { ArrayArg::from_raw_parts::(&handle, 32 * 16, 1) }, TensorMapArg::new( TiledArgs { - tile_size: vec![16, 32], + tile_size: shape![16, 32], }, unsafe { TensorArg::from_raw_parts::(&out.handle, &out.strides, &[64, 64], 1) }, F::as_type_native_unchecked(), @@ -316,14 +317,14 @@ where input_1, TensorMapArg::new( TiledArgs { - tile_size: vec![16, 16], + tile_size: shape![16, 16], }, output_1, F::as_type_native_unchecked(), ), TensorMapArg::new( TiledArgs 
{ - tile_size: vec![16, 32], + tile_size: shape![16, 32], }, input_2, F::as_type_native_unchecked(), diff --git a/crates/cubecl-cpu/src/compute/server.rs b/crates/cubecl-cpu/src/compute/server.rs index d8988f989..65de49bbd 100644 --- a/crates/cubecl-cpu/src/compute/server.rs +++ b/crates/cubecl-cpu/src/compute/server.rs @@ -18,6 +18,7 @@ use cubecl_core::{ ExecutionError, IoError, LaunchError, ProfileError, ProfilingToken, ServerCommunication, ServerUtilities, }, + zspace::{Strides, strides}, }; use cubecl_runtime::{ compiler::CubeTask, @@ -335,9 +336,9 @@ impl ServerCommunication for CpuServer { const SERVER_COMM_ENABLED: bool = false; } -pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec { +pub(crate) fn contiguous_strides(shape: &[usize]) -> Strides { let rank = shape.len(); - let mut strides = vec![1; rank]; + let mut strides = strides![1; rank]; for i in (0..rank - 1).rev() { strides[i] = strides[i + 1] * shape[i + 1]; } diff --git a/crates/cubecl-cpu/src/runtime.rs b/crates/cubecl-cpu/src/runtime.rs index 651117bfc..62d0a6873 100644 --- a/crates/cubecl-cpu/src/runtime.rs +++ b/crates/cubecl-cpu/src/runtime.rs @@ -58,6 +58,7 @@ impl DeviceState for CpuServer { num_streaming_multiprocessors: None, num_tensor_cores: None, min_tensor_cores_dim: None, + max_line_size: LineSize::MAX, }; const ALIGNMENT: u64 = 4; @@ -95,10 +96,6 @@ impl Runtime for CpuRuntime { "cpu" } - fn supported_line_sizes() -> &'static [LineSize] { - &[128, 64, 32, 16, 8, 4, 2, 1] - } - fn max_cube_count() -> (u32, u32, u32) { (u32::MAX, u32::MAX, u32::MAX) } diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index b1d57aaf1..6be28e5ab 100644 --- a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -21,6 +21,7 @@ use cubecl_core::{ ExecutionError, IoError, LaunchError, ProfileError, ProfilingToken, ServerCommunication, ServerUtilities, TensorMapBinding, TensorMapMeta, }, + zspace::{Shape, Strides, 
strides}, }; use cubecl_runtime::{ compiler::CubeTask, @@ -109,7 +110,7 @@ impl ComputeServer for CudaServer { let pitch = width_bytes.next_multiple_of(pitch_align); let size = height * pitch; total_size += size.next_multiple_of(self.mem_alignment); - let mut stride = vec![1; rank]; + let mut stride = strides![1; rank]; if rank > 1 { stride[rank - 2] = pitch / descriptor.elem_size; } @@ -245,9 +246,16 @@ impl ComputeServer for CudaServer { let mut map_ptr = MaybeUninit::zeroed(); - let shape: Vec<_> = map.shape.iter().rev().map(|s| *s as u64).collect(); + let shape: Vec<_> = map + .metadata + .shape() + .iter() + .rev() + .map(|s| *s as u64) + .collect(); let strides: Vec<_> = map - .strides + .metadata + .strides() .iter() .rev() .skip(1) @@ -257,12 +265,13 @@ impl ComputeServer for CudaServer { match &map.format { TensorMapFormat::Tiled(TiledArgs { tile_size }) => unsafe { - let tile_size: Vec<_> = tile_size.iter().rev().copied().collect(); + let tile_size: Vec<_> = + tile_size.iter().rev().copied().map(|s| s as u32).collect(); cuTensorMapEncodeTiled( map_ptr.as_mut_ptr(), elem_to_tensor_map_type(map.storage_ty), - map.rank as u32, + map.metadata.rank() as u32, device_ptr, shape.as_ptr(), strides.as_ptr(), @@ -296,7 +305,7 @@ impl ComputeServer for CudaServer { cuTensorMapEncodeIm2col( map_ptr.as_mut_ptr(), elem_to_tensor_map_type(map.storage_ty), - map.rank as u32, + map.metadata.rank() as u32, device_ptr, shape.as_ptr(), strides.as_ptr(), @@ -339,7 +348,7 @@ impl ComputeServer for CudaServer { cuTensorMapEncodeIm2colWide( map_ptr.as_mut_ptr(), elem_to_tensor_map_type(map.storage_ty), - map.rank as u32, + map.metadata.rank() as u32, device_ptr, shape.as_ptr(), strides.as_ptr(), @@ -532,7 +541,7 @@ impl CudaServer { stream_id_src: StreamId, stream_id_dst: StreamId, ) -> Result { - let strides = src.strides.to_vec(); + let strides = src.strides.into(); let binding = src.binding.clone(); let context_src = server_src.ctx.context; @@ -590,8 +599,8 @@ impl CudaServer 
{ stream_id_src: StreamId, stream_id_dst: StreamId, ) -> Result { - let shape = src.shape.to_vec(); - let strides = src.strides.to_vec(); + let shape: Shape = src.shape.into(); + let strides: Strides = src.strides.into(); let elem_size = src.elem_size; let binding = src.binding.clone(); let num_bytes = shape.iter().product::() * elem_size; @@ -826,9 +835,12 @@ fn check_tma_generic( } // tensorRank invariants - launch_check!((1..=5).contains(&map.rank), "Rank must be between 1 and 5")?; launch_check!( - matches!(map.interleave, TensorMapInterleave::None) || map.rank >= 3, + (1..=5).contains(&map.metadata.rank()), + "Rank must be between 1 and 5" + )?; + launch_check!( + matches!(map.interleave, TensorMapInterleave::None) || map.metadata.rank() >= 3, "When interleave is enabled, rank must be >= 3" )?; @@ -881,7 +893,10 @@ fn check_tma_generic( } fn check_tma_tiled(map: &TensorMapMeta, tile_size: &[u32]) -> Result<(), LaunchError> { - launch_check!(tile_size.len() == map.rank, "Tile shape should match rank")?; + launch_check!( + tile_size.len() == map.metadata.rank(), + "Tile shape should match rank" + )?; launch_check!( tile_size.iter().all(|it| *it > 0 && *it <= 256), "Tile shape must be non-zero and <= 256" @@ -920,20 +935,20 @@ fn check_tma_im2col( pixels_per_column: u32, ) -> Result<(), LaunchError> { launch_check!( - lower_corner.len() == map.rank - 2, + lower_corner.len() == map.metadata.rank() - 2, "Lower corner must be rank - 2 elements" )?; launch_check!( - upper_corner.len() == map.rank - 2, + upper_corner.len() == map.metadata.rank() - 2, "Upper corner must be rank - 2 elements" )?; launch_check!( - map.rank >= 3 && map.rank <= 5, + map.metadata.rank() >= 3 && map.metadata.rank() <= 5, "im2col requires rank to be between 3 and 5" )?; - let (range_lower, range_upper) = match map.rank { + let (range_lower, range_upper) = match map.metadata.rank() { 3 => (-32768, 32767), 4 => (-128, 127), 5 => (-16, 15), @@ -944,14 +959,14 @@ fn check_tma_im2col( .iter() 
.all(|it| *it >= range_lower && *it <= range_upper), "Lower corner must be in range [{range_lower}, {range_upper}] for {}D im2col", - map.rank + map.metadata.rank() )?; launch_check!( upper_corner .iter() .all(|it| *it >= range_lower && *it <= range_upper), "Upper corner must be in range [{range_lower}, {range_upper}] for {}D im2col", - map.rank + map.metadata.rank() )?; launch_check!( diff --git a/crates/cubecl-cuda/src/runtime.rs b/crates/cubecl-cuda/src/runtime.rs index c4023e105..6c60a6b33 100644 --- a/crates/cubecl-cuda/src/runtime.rs +++ b/crates/cubecl-cuda/src/runtime.rs @@ -151,6 +151,7 @@ impl DeviceState for CudaServer { Some(8) }, num_cpu_cores: None, + max_line_size: LineSize::MAX, } }; @@ -320,10 +321,6 @@ impl Runtime for CudaRuntime { true } - fn supported_line_sizes() -> &'static [LineSize] { - &[16, 8, 4, 2, 1] - } - fn max_cube_count() -> (u32, u32, u32) { (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32) } diff --git a/crates/cubecl-hip/src/compute/server.rs b/crates/cubecl-hip/src/compute/server.rs index 1206c1a42..276546bc2 100644 --- a/crates/cubecl-hip/src/compute/server.rs +++ b/crates/cubecl-hip/src/compute/server.rs @@ -17,6 +17,7 @@ use cubecl_core::{ Allocation, AllocationKind, Binding, Bindings, CopyDescriptor, ExecutionError, IoError, LaunchError, ProfileError, ProfilingToken, ServerCommunication, ServerUtilities, }, + zspace::{Shape, Strides, strides}, }; use cubecl_runtime::{ compiler::CubeTask, @@ -88,7 +89,7 @@ impl ComputeServer for HipServer { let size = height * pitch; total_size += size.next_multiple_of(self.mem_alignment); - let mut stride = vec![1; rank]; + let mut stride = strides![1; rank]; if rank > 1 { stride[rank - 2] = pitch / descriptor.elem_size; } @@ -343,8 +344,8 @@ impl HipServer { stream_id_src: StreamId, stream_id_dst: StreamId, ) -> Result { - let shape = src.shape.to_vec(); - let strides = src.strides.to_vec(); + let shape: Shape = src.shape.into(); + let strides: Strides = src.strides.into(); let elem_size 
= src.elem_size; let binding = src.binding.clone(); let num_bytes = shape.iter().product::() * elem_size; diff --git a/crates/cubecl-hip/src/runtime.rs b/crates/cubecl-hip/src/runtime.rs index 044765e15..d7a6dbe01 100644 --- a/crates/cubecl-hip/src/runtime.rs +++ b/crates/cubecl-hip/src/runtime.rs @@ -139,6 +139,7 @@ impl DeviceState for HipServer { Some(16) }, num_cpu_cores: None, + max_line_size: LineSize::MAX, }; let mut device_props = DeviceProperties::new( @@ -204,10 +205,6 @@ impl Runtime for HipRuntime { true } - fn supported_line_sizes() -> &'static [LineSize] { - &[16, 8, 4, 2, 1] - } - fn max_cube_count() -> (u32, u32, u32) { (i32::MAX as u32, u16::MAX as u32, u16::MAX as u32) } diff --git a/crates/cubecl-ir/src/properties.rs b/crates/cubecl-ir/src/properties.rs index 235bad710..7d03c9d18 100644 --- a/crates/cubecl-ir/src/properties.rs +++ b/crates/cubecl-ir/src/properties.rs @@ -1,7 +1,7 @@ use core::hash::{BuildHasher, Hash, Hasher}; use crate::{ - AddressType, SemanticType, StorageType, Type, TypeHash, + AddressType, LineSize, SemanticType, StorageType, Type, TypeHash, features::{Features, TypeUsage}, }; use cubecl_common::profile::TimingMethod; @@ -50,6 +50,8 @@ pub struct HardwareProperties { /// For a backend that only supports 16x16x16, the value would be 16. /// For a backend that also supports 32x8x16, the value would be 8. pub min_tensor_cores_dim: Option, + /// Maximum line size supported by the device + pub max_line_size: LineSize, } /// Properties of the device related to allocation. 
diff --git a/crates/cubecl-runtime/Cargo.toml b/crates/cubecl-runtime/Cargo.toml index f260f644f..ff7192dd5 100644 --- a/crates/cubecl-runtime/Cargo.toml +++ b/crates/cubecl-runtime/Cargo.toml @@ -47,6 +47,7 @@ cubecl-common = { path = "../cubecl-common", version = "=0.10.0-pre.1", default- cubecl-ir = { path = "../cubecl-ir", version = "=0.10.0-pre.1", default-features = false, features = [ "serde", ] } +cubecl-zspace = { path = "../cubecl-zspace", version = "=0.10.0-pre.1", default-features = false } derive-new = { workspace = true } derive_more = { workspace = true, features = ["eq"] } dirs = { workspace = true, optional = true } diff --git a/crates/cubecl-runtime/src/client.rs b/crates/cubecl-runtime/src/client.rs index 21aeebbf5..795454904 100644 --- a/crates/cubecl-runtime/src/client.rs +++ b/crates/cubecl-runtime/src/client.rs @@ -22,7 +22,7 @@ use cubecl_common::{ future::DynFut, profile::ProfileDuration, }; -use cubecl_ir::{DeviceProperties, LineSize, StorageType}; +use cubecl_ir::{DeviceProperties, LineSize}; #[allow(unused)] use cubecl_common::profile::TimingMethod; @@ -811,29 +811,11 @@ impl ComputeClient { } /// Returns all line sizes that are useful to perform optimal IO operation on the given element. - pub fn io_optimized_line_sizes( - &self, - elem: &StorageType, - ) -> impl Iterator + Clone { - let load_width = self.properties().hardware.load_width as usize; - let max = load_width / elem.size_bits(); - let supported = R::supported_line_sizes(); - supported.iter().filter(move |v| **v <= max).cloned() - } - - /// Returns all line sizes that are useful to perform optimal IO operation on the given element. - /// Ignores native support, and allows all line sizes. This means the returned size may be - /// unrolled, and may not support dynamic indexing. 
- pub fn io_optimized_line_sizes_unchecked( - &self, - size: usize, - ) -> impl Iterator + Clone { + pub fn io_optimized_line_sizes(&self, size: usize) -> impl Iterator + Clone { let load_width = self.properties().hardware.load_width as usize; let size_bits = size * 8; let max = load_width / size_bits; - // This makes this effectively the same as checked, if it doesn't work it's a problem with - // unroll that should be investigated instead. But separate PR. - let max = usize::min(R::max_global_line_size(), max); + let max = usize::min(self.properties().hardware.max_line_size, max); // If the max is 8, we want to test 1, 2, 4, 8 which is log2(8) + 1. let num_candidates = max.trailing_zeros() + 1; diff --git a/crates/cubecl-runtime/src/runtime.rs b/crates/cubecl-runtime/src/runtime.rs index 278f7fdda..2929c45ae 100644 --- a/crates/cubecl-runtime/src/runtime.rs +++ b/crates/cubecl-runtime/src/runtime.rs @@ -1,6 +1,6 @@ use alloc::boxed::Box; use cubecl_common::device::Device; -use cubecl_ir::{LineSize, TargetProperties}; +use cubecl_ir::TargetProperties; use crate::{ client::ComputeClient, @@ -28,14 +28,6 @@ pub trait Runtime: Sized + Send + Sync + 'static + core::fmt::Debug { false } - /// Returns the supported line sizes for the current runtime's compiler. - fn supported_line_sizes() -> &'static [LineSize]; - - /// The maximum line size that can be used for global buffer bindings. - fn max_global_line_size() -> LineSize { - u8::MAX as usize - } - /// Returns the maximum cube count on each dimension that can be launched. 
fn max_cube_count() -> (u32, u32, u32); diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs index e7fd84e1f..00cb040f9 100644 --- a/crates/cubecl-runtime/src/server.rs +++ b/crates/cubecl-runtime/src/server.rs @@ -22,6 +22,7 @@ use cubecl_common::{ stream_id::StreamId, }; use cubecl_ir::{DeviceProperties, StorageType}; +use cubecl_zspace::{Strides, metadata::Metadata}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -490,12 +491,22 @@ impl<'a> AllocationDescriptor<'a> { } /// An allocation with associated strides. Strides depend on tensor layout. -#[derive(new, Debug)] +#[derive(Debug)] pub struct Allocation { /// The handle for the memory resource pub handle: Handle, /// The strides of the tensor - pub strides: Vec, + pub strides: Strides, +} + +impl Allocation { + /// Create a new allocation + pub fn new(handle: Handle, strides: impl Into) -> Self { + Allocation { + handle, + strides: strides.into(), + } + } } /// Error returned from `create`/`read`/`write` functions. 
Due to async execution not all errors @@ -723,15 +734,11 @@ pub struct TensorMapBinding { pub struct TensorMapMeta { /// Tensormap format (tiled or im2col) pub format: TensorMapFormat, - /// Rank of the backing tensor - pub rank: usize, - /// Shape of the backing tensor - pub shape: Vec, - /// Strides of the backing tensor - pub strides: Vec, + /// Metadata of the backing tensor + pub metadata: Metadata, /// Element stride, usually 1 but may be 2 for complex tensors /// For im2col, this is equivalent to the kernel stride - pub elem_stride: Vec, + pub elem_stride: Strides, /// Interleave mode pub interleave: TensorMapInterleave, /// Swizzle mode diff --git a/crates/cubecl-runtime/src/tma.rs b/crates/cubecl-runtime/src/tma.rs index f52b8bf6e..9dc39c0a4 100644 --- a/crates/cubecl-runtime/src/tma.rs +++ b/crates/cubecl-runtime/src/tma.rs @@ -1,4 +1,5 @@ use alloc::vec::Vec; +use cubecl_zspace::Shape; #[cfg(any(target_os = "windows", target_os = "linux", target_os = "macos"))] use serde::{Deserialize, Serialize}; @@ -14,7 +15,7 @@ pub struct TiledArgs { /// If a dimension isn't present in the tile, it should just be set to `1`. /// /// For CUDA, this must be a power of two and `<= 256` on each dimension. 
- pub tile_size: Vec, + pub tile_size: Shape, } /// Args for im2col tensor maps diff --git a/crates/cubecl-runtime/tests/dummy/compute.rs b/crates/cubecl-runtime/tests/dummy/compute.rs index 850f883b7..0039e23df 100644 --- a/crates/cubecl-runtime/tests/dummy/compute.rs +++ b/crates/cubecl-runtime/tests/dummy/compute.rs @@ -2,7 +2,7 @@ use super::DummyServer; use crate::dummy::KernelTask; use cubecl_common::device::{Device, DeviceState}; use cubecl_ir::MemoryDeviceProperties; -use cubecl_ir::{LineSize, StorageType}; +use cubecl_ir::StorageType; use cubecl_runtime::{ client::ComputeClient, compiler::{CompilationError, Compiler}, @@ -109,10 +109,6 @@ impl Runtime for DummyRuntime { unimplemented!() } - fn supported_line_sizes() -> &'static [LineSize] { - unimplemented!() - } - fn max_cube_count() -> (u32, u32, u32) { unimplemented!() } diff --git a/crates/cubecl-runtime/tests/dummy/server.rs b/crates/cubecl-runtime/tests/dummy/server.rs index eb0695c92..c3a3b4f08 100644 --- a/crates/cubecl-runtime/tests/dummy/server.rs +++ b/crates/cubecl-runtime/tests/dummy/server.rs @@ -2,8 +2,8 @@ use super::DummyKernel; use crate::dummy::DummyCompiler; use cubecl_common::{bytes::Bytes, future::DynFut, profile::ProfileDuration, stream_id::StreamId}; use cubecl_ir::{ - DeviceProperties, ElemType, HardwareProperties, MemoryDeviceProperties, StorageType, UIntKind, - features::Features, + DeviceProperties, ElemType, HardwareProperties, LineSize, MemoryDeviceProperties, StorageType, + UIntKind, features::Features, }; use cubecl_runtime::{ compiler::{CompilationError, CubeTask}, @@ -19,6 +19,7 @@ use cubecl_runtime::{ storage::{BindingResource, BytesResource, BytesStorage, ComputeStorage}, timestamp_profiler::TimestampProfiler, }; +use cubecl_zspace::strides; use std::sync::Arc; /// The dummy server is used to test the cubecl-runtime infrastructure. 
@@ -112,7 +113,7 @@ impl ComputeServer for DummyServer { .into_iter() .map(|descriptor| { let rank = descriptor.shape.len(); - let mut strides = vec![1; rank]; + let mut strides = strides![1; rank]; for i in (0..rank - 1).rev() { strides[i] = strides[i + 1] * descriptor.shape[i + 1]; } @@ -274,6 +275,7 @@ impl DummyServer { num_tensor_cores: None, min_tensor_cores_dim: None, num_cpu_cores: None, + max_line_size: LineSize::MAX, }; let features = Features::default(); let timing_method = cubecl_common::profile::TimingMethod::System; diff --git a/crates/cubecl-std/src/tensor/contiguous/base.rs b/crates/cubecl-std/src/tensor/contiguous/base.rs index 7646178d9..368ae6c5f 100644 --- a/crates/cubecl-std/src/tensor/contiguous/base.rs +++ b/crates/cubecl-std/src/tensor/contiguous/base.rs @@ -13,6 +13,7 @@ use cubecl_core::{ self as cubecl, calculate_cube_count_elemwise, ir::{LineSize, StorageType}, tensor_line_size_parallel, + zspace::{Strides, strides}, }; pub const NUM_SM_APPROX: u32 = 50; @@ -296,13 +297,13 @@ pub fn copy_gpu_ref( let in_rank = input.strides.len(); let out_rank = output.strides.len(); let line_size_in = tensor_line_size_parallel( - client.io_optimized_line_sizes(&dtype), + client.io_optimized_line_sizes(dtype.size()), input.shape, input.strides, in_rank - 1, ); let line_size_out = tensor_line_size_parallel( - client.io_optimized_line_sizes(&dtype), + client.io_optimized_line_sizes(dtype.size()), output.shape, output.strides, out_rank - 1, @@ -339,7 +340,7 @@ pub fn copy_gpu_ref( } else { // Recompute because it needs to account for `num_elems_per_unit` client - .io_optimized_line_sizes(&dtype) + .io_optimized_line_sizes(dtype.size()) .filter(|it| num_elems_per_unit.is_multiple_of(*it)) .max() .unwrap_or(1) @@ -393,7 +394,7 @@ pub fn into_contiguous_packed_ref( let out_rank = output.strides.len(); let in_packed_dim = in_rank - packed_dim - 1; let line_size = tensor_line_size_parallel( - client.io_optimized_line_sizes(&dtype), + 
client.io_optimized_line_sizes(dtype.size()), output.shape, output.strides, out_rank - 1, @@ -463,7 +464,7 @@ pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool { return true; } - for (expected, &stride) in compact_strides(shape).into_iter().zip(strides) { + for (&expected, &stride) in compact_strides(shape).iter().zip(strides) { if expected != stride { return false; } @@ -500,9 +501,9 @@ pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool { true } -pub fn compact_strides(shape: &[usize]) -> Vec { +pub fn compact_strides(shape: &[usize]) -> Strides { let rank = shape.len(); - let mut strides = vec![1; rank]; + let mut strides = strides![1; rank]; for i in (0..rank - 1).rev() { strides[i] = strides[i + 1] * shape[i + 1]; } diff --git a/crates/cubecl-std/src/tensor/contiguous/perpendicular.rs b/crates/cubecl-std/src/tensor/contiguous/perpendicular.rs index 20cbb17a3..6cf2fabf6 100644 --- a/crates/cubecl-std/src/tensor/contiguous/perpendicular.rs +++ b/crates/cubecl-std/src/tensor/contiguous/perpendicular.rs @@ -136,13 +136,13 @@ pub fn launch_copy_perpendicular_ref( let rank = output.shape.len(); let line_size_perpendicular = tensor_line_size_perpendicular( - client.io_optimized_line_sizes(&dtype), + client.io_optimized_line_sizes(dtype.size()), input.shape, input.strides, rank - 1, ); let line_size_parallel = tensor_line_size_parallel( - client.io_optimized_line_sizes(&dtype), + client.io_optimized_line_sizes(dtype.size()), output.shape, output.strides, rank - 1, diff --git a/crates/cubecl-std/src/tensor/handle.rs b/crates/cubecl-std/src/tensor/handle.rs index f3473df65..c4f5a26d6 100644 --- a/crates/cubecl-std/src/tensor/handle.rs +++ b/crates/cubecl-std/src/tensor/handle.rs @@ -1,9 +1,12 @@ use core::marker::PhantomData; -use cubecl_core::ir::StorageType; -use cubecl_core::tensor_line_size_parallel; -use cubecl_core::{Runtime, server}; +use cubecl_core::{Runtime, server, zspace::strides}; use 
cubecl_core::{calculate_cube_count_elemwise, server::Allocation}; +use cubecl_core::{ir::StorageType, zspace::metadata::Metadata}; use cubecl_core::{prelude::*, server::CopyDescriptor}; +use cubecl_core::{ + tensor_line_size_parallel, + zspace::{Shape, Strides}, +}; use cubecl_runtime::server::Handle; /// Tensor representation containing a [server handle](Handle) as well as basic tensor metadata., @@ -13,8 +16,7 @@ where { /// The buffer where the data are stored. pub handle: server::Handle, - pub shape: Vec, - pub strides: Vec, + pub metadata: Box, /// The type used as storage. pub dtype: StorageType, runtime: PhantomData, @@ -27,7 +29,9 @@ where fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_fmt(format_args!( "Tensor {{ shape: {:?}, strides: {:?}, dtype: {}}}", - self.shape, self.strides, self.dtype, + self.shape(), + self.strides(), + self.dtype, )) } } @@ -39,8 +43,7 @@ where fn clone(&self) -> Self { Self { handle: self.handle.clone(), - shape: self.shape.clone(), - strides: self.strides.clone(), + metadata: self.metadata.clone(), dtype: self.dtype, runtime: PhantomData, } @@ -54,20 +57,20 @@ where /// Create a new tensor. 
pub fn new( handle: server::Handle, - shape: Vec, - strides: Vec, + shape: impl Into, + strides: impl Into, storage: StorageType, ) -> Self { Self { handle, - shape, - strides, + metadata: Box::new(Metadata::new(shape, strides)), dtype: storage, runtime: PhantomData, } } - pub fn empty(client: &ComputeClient, shape: Vec, storage: StorageType) -> Self { + pub fn empty(client: &ComputeClient, shape: impl Into, storage: StorageType) -> Self { + let shape = shape.into(); let elem_size = storage.size(); let Allocation { handle, strides } = client.empty_tensor(&shape, elem_size); @@ -78,21 +81,20 @@ where pub fn from_ref(handle: &TensorHandleRef<'_, R>, storage: StorageType) -> Self { Self { handle: handle.handle.clone(), - shape: handle.shape.to_vec(), - strides: handle.strides.to_vec(), + metadata: Box::new(Metadata::new(handle.shape, handle.strides)), dtype: storage, runtime: PhantomData, } } /// Create a new tensor with a contiguous memory layout. - pub fn new_contiguous(shape: Vec, handle: Handle, storage: StorageType) -> Self { + pub fn new_contiguous(shape: impl Into, handle: Handle, storage: StorageType) -> Self { + let shape = shape.into(); let strides = Self::contiguous_strides(&shape); Self { handle, - shape, - strides, + metadata: Box::new(Metadata::new(shape, strides)), dtype: storage, runtime: PhantomData, } @@ -107,8 +109,8 @@ where unsafe { TensorHandleRef::from_raw_parts( &self.handle, - &self.strides, - &self.shape, + self.strides(), + self.shape(), self.dtype.size(), ) } @@ -132,8 +134,8 @@ where pub fn as_copy_descriptor<'a>(&'a self) -> CopyDescriptor<'a> { CopyDescriptor { binding: self.handle.clone().binding(), - shape: &self.shape, - strides: &self.strides, + shape: self.shape(), + strides: self.strides(), elem_size: self.dtype.size(), } } @@ -143,12 +145,20 @@ where AddressType::from_len(len as usize) } - fn contiguous_strides(shape: &[usize]) -> Vec { - let mut strides = Vec::with_capacity(shape.len()); + pub fn shape(&self) -> &Shape { + 
self.metadata.shape() + } + + pub fn strides(&self) -> &Strides { + self.metadata.strides() + } + + fn contiguous_strides(shape: &[usize]) -> Strides { + let mut strides = strides![1; shape.len()]; let mut current = 1; - shape.iter().enumerate().rev().for_each(|(_, val)| { - strides.push(current); + shape.iter().rev().enumerate().for_each(|(i, val)| { + strides[i] = current; current *= val; }); strides.reverse(); @@ -159,15 +169,16 @@ impl TensorHandle where R: Runtime, { - pub fn zeros(client: &ComputeClient, shape: Vec, dtype: StorageType) -> Self { + pub fn zeros(client: &ComputeClient, shape: impl Into, dtype: StorageType) -> Self { + let shape = shape.into(); let num_elements: usize = shape.iter().product(); let rank = shape.len(); let output = Self::empty(client, shape, dtype); let line_size = tensor_line_size_parallel( - R::supported_line_sizes().iter().cloned(), - &output.shape, - &output.strides, + client.io_optimized_line_sizes(dtype.size()), + output.shape(), + output.strides(), rank - 1, ); diff --git a/crates/cubecl-std/src/tensor/identity.rs b/crates/cubecl-std/src/tensor/identity.rs index 26a843e7c..783bfb4d0 100644 --- a/crates/cubecl-std/src/tensor/identity.rs +++ b/crates/cubecl-std/src/tensor/identity.rs @@ -55,7 +55,7 @@ pub fn launch_ref( ); let vectorization_factor = tensor_line_size_parallel( - R::supported_line_sizes().iter().cloned(), + client.io_optimized_line_sizes(dtype.size()), output.shape, output.strides, 1, diff --git a/crates/cubecl-std/src/tensor/layout/permuted.rs b/crates/cubecl-std/src/tensor/layout/permuted.rs index ff645e275..1c813672e 100644 --- a/crates/cubecl-std/src/tensor/layout/permuted.rs +++ b/crates/cubecl-std/src/tensor/layout/permuted.rs @@ -1,5 +1,5 @@ use cubecl::prelude::*; -use cubecl_core::{self as cubecl, ir::LineSize}; +use cubecl_core::{self as cubecl, ir::LineSize, zspace::Strides}; use crate::{ FastDivmod, FastDivmodArgs, @@ -78,7 +78,7 @@ impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> { "Shape should be 
equal to reference or 1 on each dimension" ); - let strides: Vec = strides + let strides: Strides = strides .iter() .zip(shape.iter().zip(reference_shape)) .map(|(stride, (s, r))| if *s == *r { *stride } else { 0 }) diff --git a/crates/cubecl-std/src/tests/tensor/identity.rs b/crates/cubecl-std/src/tests/tensor/identity.rs index 61c83ddf9..e7cac2f79 100644 --- a/crates/cubecl-std/src/tests/tensor/identity.rs +++ b/crates/cubecl-std/src/tests/tensor/identity.rs @@ -20,8 +20,8 @@ pub fn test_identity( tensor::identity::launch(&client, &identity); let actual = client.read_one_tensor(identity.handle.clone().copy_descriptor( - &identity.shape, - &identity.strides, + identity.shape(), + identity.strides(), size_of::(), )); let actual = C::from_bytes(&actual); diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index 87ffe32c7..933d1f428 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -17,6 +17,7 @@ use cubecl_core::{ IoError, LaunchError, ProfileError, ProfilingToken, ResourceLimitError, ServerCommunication, ServerUtilities, }, + zspace::{Strides, strides}, }; #[cfg(feature = "spirv")] use cubecl_core::{cache::CacheOption, compilation_cache::CompilationCache, hash::StableHash}; @@ -291,7 +292,7 @@ impl ComputeServer for WgpuServer { let mut streams = vec![stream_id]; let mut resources = Vec::with_capacity(descriptors.len()); for desc in descriptors { - if contiguous_strides(desc.shape) != desc.strides { + if &*contiguous_strides(desc.shape) != desc.strides { return Box::pin(async { Err(IoError::UnsupportedStrides { backtrace: BackTrace::capture(), @@ -306,7 +307,7 @@ impl ComputeServer for WgpuServer { Ok(val) => val, Err(err) => return Box::pin(async move { Err(err) }), }; - resources.push((resource, desc.shape.to_vec(), desc.elem_size)); + resources.push((resource, desc.shape.into(), desc.elem_size)); } self.scheduler.execute_streams(streams); @@ -320,7 +321,7 @@ impl 
ComputeServer for WgpuServer { stream_id: StreamId, ) -> Result<(), IoError> { for (desc, data) in descriptors { - if contiguous_strides(desc.shape) != desc.strides { + if &*contiguous_strides(desc.shape) != desc.strides { return Err(IoError::UnsupportedStrides { backtrace: BackTrace::capture(), }); @@ -437,9 +438,9 @@ fn compiler(backend: wgpu::Backend) -> AutoCompiler { } } -pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec { +pub(crate) fn contiguous_strides(shape: &[usize]) -> Strides { let rank = shape.len(); - let mut strides = vec![1; rank]; + let mut strides = strides![1; rank]; for i in (0..rank - 1).rev() { strides[i] = strides[i + 1] * shape[i + 1]; } diff --git a/crates/cubecl-wgpu/src/compute/stream.rs b/crates/cubecl-wgpu/src/compute/stream.rs index 00969188d..e467d1e98 100644 --- a/crates/cubecl-wgpu/src/compute/stream.rs +++ b/crates/cubecl-wgpu/src/compute/stream.rs @@ -10,6 +10,7 @@ use cubecl_core::{ CubeCount, MemoryConfiguration, future::{self, DynFut}, server::{ExecutionError, Handle, IoError, ProfileError, ProfilingToken}, + zspace::Shape, }; use cubecl_ir::MemoryDeviceProperties; use cubecl_runtime::{logging::ServerLogger, timestamp_profiler::TimestampProfiler}; @@ -121,7 +122,7 @@ impl WgpuStream { /// A [Result] containing a vector of [Bytes] with the copied data, or an [`IoError`] if any copy fails. 
pub fn read_resources( &mut self, - descriptors: Vec<(WgpuResource, Vec, usize)>, + descriptors: Vec<(WgpuResource, Shape, usize)>, ) -> DynFut, IoError>> { self.compute_pass = None; let mut staging_info = Vec::with_capacity(descriptors.len()); diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index a5bd96707..04aca8434 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -4,8 +4,8 @@ use crate::{ }; use cubecl_common::device::{Device, DeviceState}; use cubecl_common::{future, profile::TimingMethod}; +use cubecl_core::server::ServerUtilities; use cubecl_core::{Runtime, ir::TargetProperties}; -use cubecl_core::{ir::LineSize, server::ServerUtilities}; use cubecl_ir::{DeviceProperties, HardwareProperties, MemoryDeviceProperties}; pub use cubecl_runtime::memory_management::MemoryConfiguration; use cubecl_runtime::{ @@ -57,21 +57,6 @@ impl Runtime for WgpuRuntime { } } - fn supported_line_sizes() -> &'static [LineSize] { - #[cfg(feature = "msl")] - { - &[8, 4, 2, 1] - } - #[cfg(not(feature = "msl"))] - { - &[4, 2, 1] - } - } - - fn max_global_line_size() -> LineSize { - 4 - } - fn max_cube_count() -> (u32, u32, u32) { let max_dim = u16::MAX as u32; (max_dim, max_dim, max_dim) @@ -82,7 +67,7 @@ impl Runtime for WgpuRuntime { return true; } - for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) { + for (&expected, &stride) in contiguous_strides(shape).iter().zip(strides) { if expected != stride { return false; } @@ -248,6 +233,7 @@ pub(crate) fn create_server(setup: WgpuSetup, options: RuntimeOptions) -> WgpuSe num_tensor_cores: None, min_tensor_cores_dim: None, num_cpu_cores: None, // TODO: Check if device is CPU. 
+ max_line_size: 4, }; let mut compilation_options = Default::default(); diff --git a/crates/cubecl-zspace/Cargo.toml b/crates/cubecl-zspace/Cargo.toml index da8349099..5567d8d2b 100644 --- a/crates/cubecl-zspace/Cargo.toml +++ b/crates/cubecl-zspace/Cargo.toml @@ -1,14 +1,14 @@ [package] -name = "cubecl-zspace" authors = [ "nathanielsimard ", "louisfd ", "maxtremblay ", - "crutcher " + "crutcher ", ] categories = ["science", "mathematics", "algorithms"] -keywords = [] description = "CubeCL ZSpace Library." +keywords = [] +name = "cubecl-zspace" edition.workspace = true license.workspace = true @@ -25,5 +25,7 @@ workspace = true default = [] std = [] - [dependencies] +derive-new = { workspace = true } +serde = { workspace = true } +smallvec = { workspace = true, features = ["serde"] } diff --git a/crates/cubecl-zspace/src/errors.rs b/crates/cubecl-zspace/src/errors.rs index 23aad0881..7a2250af2 100644 --- a/crates/cubecl-zspace/src/errors.rs +++ b/crates/cubecl-zspace/src/errors.rs @@ -5,6 +5,8 @@ use core::fmt::{Display, Formatter}; use core::ops::Range; use std::error::Error; +use crate::{Shape, Strides}; + /// Describes the kind of an index. #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] pub enum IndexKind { @@ -153,15 +155,15 @@ impl ExpressionError { /// or that the ranks match. #[derive(Debug, Clone, PartialEq)] pub struct StrideRecord { - pub shape: Vec, - pub strides: Vec, + pub shape: Shape, + pub strides: Strides, } impl StrideRecord { /// Create a new [`StrideRecord`] from a slice of usize strides. pub fn from_usize_strides(shape: &[usize], strides: &[usize]) -> StrideRecord { StrideRecord { - shape: shape.to_vec(), + shape: shape.into(), strides: strides.iter().map(|s| *s as isize).collect(), } } @@ -169,8 +171,8 @@ impl StrideRecord { /// Create a new [`StrideRecord`] from a slice of isize strides. 
pub fn from_isize_strides(shape: &[usize], strides: &[isize]) -> StrideRecord { StrideRecord { - shape: shape.to_vec(), - strides: strides.to_vec(), + shape: shape.into(), + strides: strides.into(), } } } diff --git a/crates/cubecl-zspace/src/lib.rs b/crates/cubecl-zspace/src/lib.rs index 389a22922..b6ff4d4e9 100644 --- a/crates/cubecl-zspace/src/lib.rs +++ b/crates/cubecl-zspace/src/lib.rs @@ -15,3 +15,16 @@ extern crate alloc; pub mod errors; pub mod indexing; pub mod striding; + +pub(crate) const INLINE_DIMS: usize = 5; + +pub mod metadata; +mod shape; +mod strides; + +/// Reexport to avoid annoying rust-analyzer bug where it imports the module instead of the macro +pub use shape::*; +pub use strides::*; + +/// Reexport for use in macros +pub use smallvec::{SmallVec, smallvec}; diff --git a/crates/cubecl-zspace/src/metadata.rs b/crates/cubecl-zspace/src/metadata.rs new file mode 100644 index 000000000..d3932e3c7 --- /dev/null +++ b/crates/cubecl-zspace/src/metadata.rs @@ -0,0 +1,96 @@ +use serde::{Deserialize, Serialize}; + +use crate::{MetadataError, shape::Shape, strides::Strides}; + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct Metadata { + pub shape: Shape, + pub strides: Strides, +} + +impl Metadata { + pub fn new(shape: impl Into, strides: impl Into) -> Self { + let shape = shape.into(); + let strides = strides.into(); + debug_assert_eq!( + shape.rank(), + strides.rank(), + "Rank of shape and strides must be the same" + ); + + Self { shape, strides } + } + + pub fn shape(&self) -> &Shape { + &self.shape + } + + pub fn shape_mut(&mut self) -> &mut Shape { + &mut self.shape + } + + pub fn strides(&self) -> &Strides { + &self.strides + } + + pub fn strides_mut(&mut self) -> &mut Strides { + &mut self.strides + } + + pub fn rank(&self) -> usize { + self.num_dims() + } + + pub fn num_dims(&self) -> usize { + self.shape.num_dims() + } + + /// Returns the total number of elements of a tensor having this shape + pub 
fn num_elements(&self) -> usize { + self.shape.num_elements() + } + + pub fn swapped(mut self, dim0: usize, dim1: usize) -> Self { + self.swap(dim0, dim1); + self + } + + pub fn swap(&mut self, dim0: usize, dim1: usize) { + debug_assert!(dim0 < self.rank(), "dim0 is out of bounds"); + debug_assert!(dim1 < self.rank(), "dim1 is out of bounds"); + self.shape.swap(dim0, dim1); + self.strides.swap(dim0, dim1); + } + + /// Reorder the shape dimensions according to the permutation of `axes`. + pub fn permute(&mut self, axes: &[usize]) -> Result<(), MetadataError> { + self.shape.permute(axes)?; + self.strides.permute(axes)?; + + Ok(()) + } + + pub fn permuted(mut self, axes: &[usize]) -> Result { + self.permute(axes)?; + Ok(self) + } + + /// Insert a dimension of `shape` with `stride` at position `index`. + pub fn insert(&mut self, index: usize, shape: usize, stride: usize) { + self.shape.insert(index, shape); + self.strides.insert(index, stride); + } + + /// Remove and return the dimension at position `index` from the metadata. + pub fn remove(&mut self, index: usize) -> (usize, usize) { + let shape = self.shape.remove(index); + let stride = self.strides.remove(index); + (shape, stride) + } + + /// Appends a dimension of `shape` with `stride` to the back of the metadata. + pub fn push(&mut self, shape: usize, stride: usize) { + self.shape.push(shape); + self.strides.push(stride); + } +} diff --git a/crates/cubecl-zspace/src/shape.rs b/crates/cubecl-zspace/src/shape.rs new file mode 100644 index 000000000..d506cdffe --- /dev/null +++ b/crates/cubecl-zspace/src/shape.rs @@ -0,0 +1,1206 @@ +//! Tensor shape definition. 
+ +use super::indexing::ravel_index; +use alloc::format; +use alloc::string::{String, ToString}; +use alloc::vec::Vec; +use core::fmt::{Debug, Display, Formatter}; +use core::str::FromStr; +use core::{ + ops::{Deref, DerefMut, Index, IndexMut, Range}, + slice::{Iter, IterMut, SliceIndex}, +}; +use serde::{Deserialize, Serialize}; +use smallvec::{SmallVec, smallvec}; + +pub use crate::errors::ExpressionError; +use crate::{ + INLINE_DIMS, + indexing::{AsIndex, AsSize}, +}; + +/// Shape of a tensor. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)] +pub struct Shape { + /// The dimensions of the tensor. + dims: SmallVec<[usize; INLINE_DIMS]>, +} + +#[allow(missing_docs)] +#[derive(Debug, Clone, PartialEq, Eq)] +/// Error that can occur when attempting to modify shapes. +pub enum MetadataError { + /// The operands have different ranks. + RankMismatch { left: usize, right: usize }, + /// A pair of dimensions are incompatible for broadcasting. + IncompatibleDims { + left: usize, + right: usize, + dim: usize, + }, + /// Invalid dimension specified for the rank. + OutOfBounds { dim: usize, rank: usize }, + /// A pair of shapes are incompatible for the operation. + IncompatibleShapes { left: Shape, right: Shape }, + /// Invalid shape. + Invalid { reason: String }, +} + +impl MetadataError { + fn empty() -> Self { + Self::Invalid { + reason: "Shape is empty.".into(), + } + } +} + +impl Shape { + /// Constructs a new `Shape`. + pub fn new(dims: [usize; D]) -> Self { + // For backward compat + Self { + dims: SmallVec::from_slice(&dims), + } + } + + /// Constructs a new `Shape` from raw backing storage. Mainly intended for macro use. + pub fn new_raw(dims: SmallVec<[usize; INLINE_DIMS]>) -> Self { + Self { dims } + } + + /// Returns the total number of elements of a tensor having this shape + pub fn num_elements(&self) -> usize { + self.dims.iter().product() + } + + /// Returns the number of dimensions. + /// + /// Alias for `Shape::rank()`. 
+ pub fn num_dims(&self) -> usize { + self.dims.len() + } + + /// Returns the rank (the number of dimensions). + /// + /// Alias for `Shape::num_dims()`. + pub fn rank(&self) -> usize { + self.num_dims() + } + + // For compat with dims: [usize; D] + /// Returns the dimensions of the tensor as an array. + pub fn dims(&self) -> [usize; D] { + let mut dims = [1; D]; + dims[..D].copy_from_slice(&self.dims[..D]); + dims + } + + /// Change the shape to one dimensional with the same number of elements. + pub fn flatten(mut self) -> Self { + self.dims = SmallVec::from_slice(&[self.num_elements()]); + self + } + + /// Flatten the shape along a given range of dimensions. + /// + /// This function collapses the specified range of dimensions into a single dimension, + /// effectively flattening the tensor in that range. + /// + /// # Arguments + /// + /// - `start_dim`: The starting dimension of the range to be flattened, + /// supports negative indexing. + /// - `end_dim`: The ending dimension of the range to be flattened (inclusive), + /// supports negative indexing. + /// + /// # Returns + /// + /// A new `Shape` instance with the specified range of dimensions flattened. 
+ /// + /// # Example + /// + /// ```rust + /// use cubecl_zspace::Shape; + /// + /// fn example() { + /// let shape = Shape::new([2, 3, 4]); + /// + /// let flattened = shape.flatten_dims(1, 2); + /// println!("{flattened}"); + /// // [2, 12] + /// } + /// ``` + pub fn flatten_dims(self, start_dim: impl AsIndex, end_dim: impl AsIndex) -> Self { + let rank = self.rank(); + let start = start_dim.expect_dim_index(rank); + let end = end_dim.expect_dim_index(rank); + + assert!( + start <= end, + "start_dim ({start}) must be <= than end_dim ({end})" + ); + + let existing = self.dims; + + let flattened_size = existing[start..=end].iter().product(); + + let new_rank = rank - (end - start); + let mut dims = smallvec![0; new_rank]; + dims[..start].copy_from_slice(&existing[..start]); + dims[start] = flattened_size; + dims[start + 1..].copy_from_slice(&existing[end + 1..]); + + Self { dims } + } + + /// Compute the ravel index for the given coordinates. + /// + /// This returns the row-major order raveling: + /// * `strides[-1] = 1` + /// * `strides[i] = strides[i+1] * dims[i+1]` + /// * `dim_strides = coords * strides` + /// * `ravel = sum(dim_strides)` + /// + /// # Arguments + /// - `indices`: the index for each dimension; must be the same length as `shape`. + /// + /// # Returns + /// - the ravel offset index. + pub fn ravel_index(&self, indices: &[I]) -> usize { + ravel_index(indices, &self.dims) + } + + /// Convert shape dimensions to full covering ranges (0..dim) for each dimension. + pub fn into_ranges(self) -> Vec> { + self.iter().map(|&d| 0..d).collect() + } + + /// Construct a vector of the dims. + pub fn to_vec(&self) -> Vec { + self.dims.to_vec() + } + + /// Returns an iterator over the shape dimensions. + pub fn iter(&self) -> Iter<'_, usize> { + self.dims.iter() + } + + /// Mutable iterator over the dimensions. + pub fn iter_mut(&mut self) -> IterMut<'_, usize> { + self.dims.iter_mut() + } + + /// Borrow the underlying dimensions slice. 
+ pub fn as_slice(&self) -> &[usize] { + &self.dims + } + + /// Borrow the underlying dimensions slice mutably. + pub fn as_mut_slice(&mut self) -> &mut [usize] { + &mut self.dims + } + + /// Insert a dimension of `size` at position `index`. + pub fn insert(&mut self, index: usize, size: usize) { + self.dims.insert(index, size); + } + + /// Remove and return the dimension at position `index` from the shape. + pub fn remove(&mut self, index: usize) -> usize { + self.dims.remove(index) + } + + /// Appends a dimension of `size` to the back of the shape. + pub fn push(&mut self, size: usize) { + self.dims.push(size) + } + + /// Extend the shape with the content of another shape or iterator. + pub fn extend(&mut self, iter: impl IntoIterator) { + self.dims.extend(iter) + } + + /// Swap two dimensions in the shape. + pub fn swapped(mut self, dim1: usize, dim2: usize) -> Result { + if dim1 >= self.rank() { + return Err(MetadataError::OutOfBounds { + dim: dim1, + rank: self.rank(), + }); + } + if dim2 >= self.rank() { + return Err(MetadataError::OutOfBounds { + dim: dim2, + rank: self.rank(), + }); + } + self.dims.swap(dim1, dim2); + Ok(self) + } + + /// Reorder the shape dimensions according to the permutation of `axes`. + pub fn permute(&mut self, axes: &[usize]) -> Result<(), MetadataError> { + if axes.len() != self.rank() { + return Err(MetadataError::RankMismatch { + left: self.rank(), + right: axes.len(), + }); + } + debug_assert!(axes.iter().all(|i| i < &self.rank())); + + self.dims = axes.iter().map(|&i| self.dims[i]).collect(); + Ok(()) + } + + /// Reorder the shape dimensions according to the permutation of `axes`. + pub fn permuted(mut self, axes: &[usize]) -> Result { + self.permute(axes)?; + Ok(self) + } + + /// Repeated the specified `dim` a number of `times`. 
+ pub fn repeat(mut self, dim: usize, times: usize) -> Result { + if dim >= self.rank() { + return Err(MetadataError::OutOfBounds { + dim, + rank: self.rank(), + }); + } + + self.dims[dim] *= times; + Ok(self) + } + + /// Returns a new shape where the specified `dim` is reduced to size 1. + pub fn reduce(mut self, dim: usize) -> Result { + if dim >= self.rank() { + return Err(MetadataError::OutOfBounds { + dim, + rank: self.rank(), + }); + } + + self.dims[dim] = 1; + Ok(self) + } + + /// Concatenates all shapes into a new one along the given dimension. + pub fn cat<'a, I>(shapes: I, dim: usize) -> Result + where + I: IntoIterator, + { + let mut iter = shapes.into_iter(); + + let first = iter.next().ok_or(MetadataError::empty())?; + + if dim >= first.rank() { + return Err(MetadataError::OutOfBounds { + dim, + rank: first.rank(), + }); + } + + let mut shape = first.clone(); + + for s in iter { + if s.rank() != shape.rank() { + return Err(MetadataError::RankMismatch { + left: shape.rank(), + right: s.rank(), + }); + } + + if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] { + return Err(MetadataError::IncompatibleShapes { + left: shape.clone(), + right: s.clone(), + }); + } + + shape[dim] += s[dim]; + } + + Ok(shape) + } + + /// Compute the output shape for binary operations with broadcasting support. + /// + /// - Shapes must be of the same rank (missing dimensions are not handled automatically). + /// - Two dimensions are compatible if they are equal, or one of them is 1. + /// + /// For example, a shape `[1, 1, 2, 4]` can be broadcast into `[7, 6, 2, 4]` + /// because its axes are either equal or 1. On the other hand, a shape `[2, 2]` + /// can *not* be broadcast into `[2, 4]`. + pub fn broadcast(&self, other: &Self) -> Result { + Self::broadcast_many([self, other]) + } + + /// Compute the broadcasted output shape across multiple input shapes. + /// + /// See also [broadcast](Self::broadcast). 
+ pub fn broadcast_many<'a, I>(shapes: I) -> Result + where + I: IntoIterator, + { + let mut iter = shapes.into_iter(); + let mut broadcasted = iter.next().ok_or(MetadataError::empty())?.clone(); + let rank = broadcasted.rank(); + + for shape in iter { + if shape.rank() != rank { + return Err(MetadataError::RankMismatch { + left: rank, + right: shape.rank(), + }); + } + + for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() { + match (*d_lhs, d_rhs) { + (a, b) if a == b => {} // same + (1, b) => *d_lhs = b, // broadcast to rhs + (_a, 1) => {} // keep existing dimension + _ => { + return Err(MetadataError::IncompatibleDims { + left: *d_lhs, + right: d_rhs, + dim, + }); + } + } + } + } + + Ok(broadcasted) + } + + /// Expand this shape to match the target shape, following broadcasting rules. + pub fn expand(&self, target: Shape) -> Result { + let target_rank = target.rank(); + if self.rank() > target_rank { + return Err(MetadataError::RankMismatch { + left: self.rank(), + right: target_rank, + }); + } + + for (i, (dim_target, dim_self)) in target.iter().rev().zip(self.iter().rev()).enumerate() { + if dim_self != dim_target && *dim_self != 1 { + return Err(MetadataError::IncompatibleDims { + left: *dim_self, + right: *dim_target, + dim: target_rank - i - 1, + }); + } + } + + Ok(target) + } + + /// Reshape this shape to the target shape. + pub fn reshape(&self, args: A) -> Result + where + A: AsRef<[T]> + Debug, + T: AsIndex, + { + let args = args.as_ref(); + let mut infer_index = None; + let mut dims = Vec::new(); + + let mut new_size = 1; + + for (idx, &s) in args.iter().enumerate() { + let s = s.as_index(); + if s > 0 { + let s = s as usize; + new_size *= s; + dims.push(s); + } else if s == 0 { + // We need to find the index of the 0 dimensions and + // replace them with the actual dimension value. 
+ let s = self.dims[idx]; + new_size *= s; + dims.push(s); + } else if s == -1 { + match infer_index { + None => { + infer_index = Some(idx); + // Used by / Replaced by handling later. + dims.push(1); + } + Some(_) => { + return Err(MetadataError::Invalid { + reason: "Repeated -1 in reshape".to_string(), + }); + } + } + } else { + return Err(MetadataError::Invalid { + reason: "The given shape cannot contain negative dimensions (other than -1)." + .to_string(), + }); + } + } + + let source_size = self.num_elements(); + match infer_index { + None => { + if source_size != new_size { + return Err(MetadataError::Invalid { + reason: format!( + "The given shape doesn't have the same number of elements as the current shape. Current shape: {self}, target shape: {dims:?}.", + ), + }); + } + } + Some(idx) => { + if !source_size.is_multiple_of(new_size) { + return Err(MetadataError::Invalid { + reason: format!( + "Cannot infer a valid target shape. Current shape: {self}, target dimensions: {args:?}." + ), + }); + } + dims[idx] = source_size / new_size; + } + } + + Ok(dims.into()) + } +} + +#[macro_export] +macro_rules! shape { + (@one $x:expr) => (1usize); + () => ( + $crate::Shape::new_raw($crate::SmallVec::new()) + ); + ($elem:expr; $n:expr) => ({ + $crate::Shape::new_raw($crate::smallvec!($elem; $n)) + }); + ($($x:expr),+$(,)?) => ({ + $crate::Shape::new_raw($crate::smallvec!($($x),*)) + }); +} + +/// Compute the output shape for matrix multiplication with broadcasting support. +/// +/// The last two dimensions are treated as matrices, while preceding dimensions +/// follow broadcast semantics similar to elementwise operations. 
+pub fn calculate_matmul_output(lhs: &Shape, rhs: &Shape) -> Result { + let rank = lhs.rank(); + if rank != rhs.rank() { + return Err(MetadataError::RankMismatch { + left: rank, + right: rhs.rank(), + }); + } + + if lhs[rank - 1] != rhs[rank - 2] { + return Err(MetadataError::IncompatibleShapes { + left: lhs.clone(), + right: rhs.clone(), + }); + } + + let mut shape = if rank > 2 { + // Broadcast leading dims + Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))? + } else { + Shape::new([]) + }; + shape.extend([lhs[rank - 2], rhs[rank - 1]]); + + Ok(shape) +} + +impl Display for Shape { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + self.dims.fmt(f) + } +} + +impl FromStr for Shape { + type Err = ExpressionError; + + fn from_str(source: &str) -> Result { + let mut s = source.trim(); + + const DELIMS: [(&str, &str); 2] = [("[", "]"), ("(", ")")]; + + for (open, close) in DELIMS { + if let Some(p) = s.strip_prefix(open) { + if let Some(p) = p.strip_suffix(close) { + s = p.trim(); + break; + } else { + return Err(ExpressionError::ParseError { + message: "Unbalanced delimiters".to_string(), + source: source.to_string(), + }); + } + } + } + + if s.is_empty() { + return Ok(Shape::new([])); + } + + let dims = s + .split(',') + .map(|dim_str| { + dim_str + .trim() + .parse::() + .map_err(|_| ExpressionError::ParseError { + message: "Unable to parse shape".to_string(), + source: source.to_string(), + }) + }) + .collect::, ExpressionError>>()?; + + if dims.is_empty() { + unreachable!("Split should have returned at least one element"); + } + + Ok(Shape { dims }) + } +} + +impl Index for Shape +where + Idx: SliceIndex<[usize]>, +{ + type Output = Idx::Output; + + fn index(&self, index: Idx) -> &Self::Output { + &self.dims[index] + } +} + +impl IndexMut for Shape +where + Idx: SliceIndex<[usize]>, +{ + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + &mut self.dims[index] + } +} + +// Allow `&shape` to behave like a slice 
`&[usize]` directly +impl Deref for Shape { + type Target = [usize]; + + fn deref(&self) -> &Self::Target { + &self.dims + } +} + +// Allow `&shape` to behave like a mut slice `&mut [usize]` directly +impl DerefMut for Shape { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.dims + } +} +// Allow `shape.reshape(other_shape)`. +// +// By implementing `AsRef<[usize]>`, `Shape` behaves like a slice of dimensions, +// similar to how `Vec` can be passed to functions expecting a slice. +impl AsRef<[usize]> for Shape { + fn as_ref(&self) -> &[usize] { + &self.dims + } +} + +impl From for Vec { + fn from(shape: Shape) -> Self { + shape.dims.to_vec() + } +} + +impl From for Shape +where + T: IntoIterator, + I: AsSize, +{ + fn from(dims: T) -> Self { + Shape { + dims: dims.into_iter().map(|d| d.as_size()).collect(), + } + } +} + +impl From<&Shape> for Shape { + fn from(value: &Shape) -> Self { + value.clone() + } +} + +impl FromIterator for Shape { + fn from_iter>(iter: T) -> Self { + Shape { + dims: iter.into_iter().map(|it| it.as_size()).collect(), + } + } +} + +#[cfg(test)] +#[allow(clippy::identity_op, reason = "useful for clarity")] +mod tests { + use super::*; + use alloc::string::ToString; + use alloc::vec; + + #[test] + fn test_shape_to_str() { + let shape = Shape::new([2, 3, 4, 5]); + assert_eq!(shape.to_string(), "[2, 3, 4, 5]"); + } + + #[test] + fn test_shape_from_str() { + assert_eq!( + "[2, 3, 4, 5]".parse::().unwrap(), + Shape::new([2, 3, 4, 5]) + ); + assert_eq!( + "(2, 3, 4, 5)".parse::().unwrap(), + Shape::new([2, 3, 4, 5]) + ); + assert_eq!( + "2, 3, 4, 5".parse::().unwrap(), + Shape::new([2, 3, 4, 5]) + ); + + assert_eq!("[2]".parse::().unwrap(), Shape::new([2])); + assert_eq!("(2)".parse::().unwrap(), Shape::new([2])); + assert_eq!("2".parse::().unwrap(), Shape::new([2])); + + assert_eq!("[]".parse::().unwrap(), Shape::new([])); + assert_eq!("".parse::().unwrap(), Shape::new([])); + + assert_eq!( + "[".parse::(), + 
Err(ExpressionError::ParseError { + message: "Unbalanced delimiters".to_string(), + source: "[".to_string() + }) + ); + + assert_eq!( + "[[1]".parse::(), + Err(ExpressionError::ParseError { + message: "Unable to parse shape".to_string(), + source: "[[1]".to_string() + }) + ); + assert_eq!( + "[[1]]".parse::(), + Err(ExpressionError::ParseError { + message: "Unable to parse shape".to_string(), + source: "[[1]]".to_string() + }) + ); + assert_eq!( + "[1)".parse::(), + Err(ExpressionError::ParseError { + message: "Unbalanced delimiters".to_string(), + source: "[1)".to_string() + }) + ); + + assert_eq!( + "]".parse::(), + Err(ExpressionError::ParseError { + message: "Unable to parse shape".to_string(), + source: "]".to_string() + }) + ); + + assert_eq!( + "[a]".parse::(), + Err(ExpressionError::ParseError { + message: "Unable to parse shape".to_string(), + source: "[a]".to_string() + }) + ); + } + + #[test] + fn num_dims_and_rank() { + let dims = [2, 3, 4, 5]; + let shape = Shape::new(dims); + assert_eq!(4, shape.num_dims()); + assert_eq!(4, shape.rank()); + } + + #[test] + fn num_elements() { + let dims = [2, 3, 4, 5]; + let shape = Shape::new(dims); + assert_eq!(120, shape.num_elements()); + } + + #[test] + #[allow(clippy::into_iter_on_ref)] + fn test_shape_into_iter() { + let dims = [2, 3, 4, 5]; + let shape = Shape::new(dims); + + assert_eq!(shape.into_iter().sum::(), 14); + } + + #[test] + fn test_into_ranges() { + let dims = [2, 3, 4, 5]; + let shape = Shape::new(dims); + assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]); + } + + #[test] + fn test_to_vec() { + let dims = [2, 3, 4, 5]; + let shape = Shape::new(dims); + assert_eq!(shape.to_vec(), vec![2, 3, 4, 5]); + } + + #[test] + fn test_shape_index() { + let shape = Shape::new([2, 3, 4, 5]); + + assert_eq!(shape[0], 2); + assert_eq!(shape[1], 3); + assert_eq!(shape[2], 4); + assert_eq!(shape[3], 5); + + // Works with ranges + assert_eq!(shape[1..3], *&[3, 4]); + assert_eq!(shape[1..=2], *&[3, 4]); + 
assert_eq!(shape[..], *&[2, 3, 4, 5]); + } + + #[test] + fn test_shape_slice_methods() { + let shape = Shape::new([2, 3, 4, 5]); + + let dim = shape.first(); + assert_eq!(dim, Some(&2)); + let dim = shape.last(); + assert_eq!(dim, Some(&5)); + + assert!(!shape.is_empty()); + let shape = Shape::new([]); + assert!(shape.is_empty()); + } + + #[test] + fn test_shape_iter() { + let dims = [2, 3, 4, 5]; + let shape = Shape::new(dims); + + for (d, sd) in dims.iter().zip(shape.iter()) { + assert_eq!(d, sd); + } + } + + #[test] + fn test_shape_iter_mut() { + let mut shape = Shape::new([2, 3, 4, 5]); + + for d in shape.iter_mut() { + *d += 1; + } + + assert_eq!(shape.as_slice(), &[3, 4, 5, 6]); + } + + #[test] + fn test_shape_as_slice() { + let dims = [2, 3, 4, 5]; + let shape = Shape::new(dims); + + assert_eq!(shape.as_slice(), dims.as_slice()); + + // Deref coercion + let shape_slice: &[usize] = &shape; + assert_eq!(shape_slice, *&[2, 3, 4, 5]); + } + + #[test] + fn test_shape_as_mut_slice() { + let mut dims = [2, 3, 4, 5]; + let mut shape = Shape::new(dims); + + let shape_mut = shape.as_mut_slice(); + assert_eq!(shape_mut, dims.as_mut_slice()); + shape_mut[1] = 6; + + assert_eq!(shape_mut, &[2, 6, 4, 5]); + + let mut shape = Shape::new(dims); + let shape = &mut shape[..]; + shape[1] = 6; + + assert_eq!(shape, shape_mut) + } + + #[test] + fn test_shape_flatten() { + let shape = Shape::new([2, 3, 4, 5]); + assert_eq!(shape.num_elements(), 120); + + let shape = shape.flatten(); + assert_eq!(shape.num_elements(), 120); + assert_eq!(shape.as_slice(), &[120]); + } + + #[test] + fn test_ravel() { + let shape = Shape::new([2, 3, 4, 5]); + + assert_eq!(shape.ravel_index(&[0, 0, 0, 0]), 0); + assert_eq!( + shape.ravel_index(&[1, 2, 3, 4]), + 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4 + ); + } + + #[test] + fn test_shape_insert_remove_push() { + let dims = [2, 3, 4, 5]; + let mut shape = Shape::new(dims); + let size = 6; + shape.insert(1, size); + + assert_eq!(shape, Shape::new([2, 
6, 3, 4, 5])); + + let removed = shape.remove(1); + assert_eq!(removed, size); + assert_eq!(shape, Shape::new(dims)); + + shape.push(6); + assert_eq!(shape, Shape::new([2, 3, 4, 5, 6])); + } + + #[test] + fn test_shape_swap_permute() { + let dims = [2, 3, 4, 5]; + let shape = Shape::new(dims); + let shape = shape.swapped(1, 2).unwrap(); + + assert_eq!(shape.as_slice(), &[2, 4, 3, 5]); + + let shape = shape.permuted(&[0, 2, 1, 3]).unwrap(); + assert_eq!(shape, Shape::new(dims)); + } + + #[test] + #[should_panic] + fn test_shape_swap_out_of_bounds() { + let shape = Shape::new([2, 3, 4, 5]); + + shape.swapped(0, 4).unwrap(); + } + + #[test] + #[should_panic] + fn test_shape_permute_incomplete() { + let shape = Shape::new([2, 3, 4, 5]); + + shape.permuted(&[0, 2, 1]).unwrap(); + } + + #[test] + fn test_shape_repeat() { + let shape = Shape::new([2, 3, 4, 5]); + + let out = shape.repeat(2, 3).unwrap(); + assert_eq!(out, Shape::new([2, 3, 12, 5])); + } + + #[test] + fn test_shape_repeat_invalid() { + let shape = Shape::new([2, 3, 4, 5]); + + let out = shape.repeat(5, 3); + assert_eq!(out, Err(MetadataError::OutOfBounds { dim: 5, rank: 4 })); + } + + #[test] + fn test_shape_reduce() { + let shape = Shape::new([2, 3, 4, 5]); + + let out = shape.reduce(2).unwrap(); + assert_eq!(out, Shape::new([2, 3, 1, 5])); + } + + #[test] + fn test_shape_reduce_invalid() { + let shape = Shape::new([2, 3, 4, 5]); + + let out = shape.reduce(5); + assert_eq!(out, Err(MetadataError::OutOfBounds { dim: 5, rank: 4 })); + } + + #[test] + fn test_shape_broadcast_binary() { + let lhs = Shape::new([1, 1, 2, 4]); + let rhs = Shape::new([7, 6, 2, 1]); + + let out = lhs.broadcast(&rhs).unwrap(); + assert_eq!(out, Shape::new([7, 6, 2, 4])); + } + + #[test] + fn test_shape_broadcast_rank_mismatch() { + let lhs = Shape::new([1, 2, 4]); + let rhs = Shape::new([7, 6, 2, 4]); + + let out = lhs.broadcast(&rhs); + assert_eq!(out, Err(MetadataError::RankMismatch { left: 3, right: 4 })); + } + + #[test] + fn 
test_shape_broadcast_incompatible_dims() { + let lhs = Shape::new([1, 2, 2, 4]); + let rhs = Shape::new([7, 6, 2, 1]); + + let out = lhs.broadcast(&rhs); + assert_eq!( + out, + Err(MetadataError::IncompatibleDims { + left: 2, + right: 6, + dim: 1 + }) + ); + } + + #[test] + fn test_shape_broadcast_many() { + let s1 = Shape::new([1, 1, 2, 4]); + let s2 = Shape::new([7, 1, 2, 1]); + let s3 = Shape::new([7, 6, 1, 1]); + + let out = Shape::broadcast_many([&s1, &s2, &s3]).unwrap(); + assert_eq!(out, Shape::new([7, 6, 2, 4])); + } + + #[test] + fn test_shape_broadcast_many_rank_mismatch() { + let s1 = Shape::new([1, 1, 2, 4]); + let s2 = Shape::new([7, 1, 2, 1]); + let s3 = Shape::new([1, 6, 1]); + + let out = Shape::broadcast_many([&s1, &s2, &s3]); + assert_eq!(out, Err(MetadataError::RankMismatch { left: 4, right: 3 })); + } + + #[test] + fn test_shape_broadcast_many_incompatible_dims() { + let s1 = Shape::new([1, 1, 2, 4]); + let s2 = Shape::new([7, 1, 2, 1]); + let s3 = Shape::new([4, 6, 1, 1]); + + let out = Shape::broadcast_many([&s1, &s2, &s3]); + assert_eq!( + out, + Err(MetadataError::IncompatibleDims { + left: 7, + right: 4, + dim: 0 + }) + ); + } + + #[test] + fn test_shape_broadcast_many_empty() { + let out = Shape::broadcast_many(&[]); + assert_eq!(out, Err(MetadataError::empty())); + } + + #[test] + fn test_shape_matmul_2d() { + let lhs = Shape::new([2, 4]); + let rhs = Shape::new([4, 2]); + let out = calculate_matmul_output(&lhs, &rhs).unwrap(); + assert_eq!(out, Shape::new([2, 2])); + } + + #[test] + fn test_shape_matmul_4d_broadcasted() { + let lhs = Shape::new([1, 3, 2, 4]); + let rhs = Shape::new([2, 1, 4, 2]); + let out = calculate_matmul_output(&lhs, &rhs).unwrap(); + assert_eq!(out, Shape::new([2, 3, 2, 2])); + } + + #[test] + fn test_shape_matmul_invalid_rank() { + let lhs = Shape::new([3, 2, 4]); + let rhs = Shape::new([2, 1, 4, 2]); + let out = calculate_matmul_output(&lhs, &rhs); + assert_eq!(out, Err(MetadataError::RankMismatch { left: 3, 
right: 4 })); + } + + #[test] + fn test_shape_matmul_invalid_shape() { + let lhs = Shape::new([1, 3, 2, 4]); + let rhs = Shape::new([2, 1, 3, 2]); + let out = calculate_matmul_output(&lhs, &rhs); + assert_eq!( + out, + Err(MetadataError::IncompatibleShapes { + left: lhs, + right: rhs + }) + ); + } + + #[test] + fn test_shape_matmul_invalid_broadcast() { + let lhs = Shape::new([1, 3, 2, 4]); + let rhs = Shape::new([2, 2, 4, 2]); + let out = calculate_matmul_output(&lhs, &rhs); + assert_eq!( + out, + Err(MetadataError::IncompatibleDims { + left: 3, + right: 2, + dim: 1 + }) + ); + } + + #[test] + fn test_shape_cat() { + let s1 = Shape::new([2, 3, 4, 5]); + let s2 = Shape::new([1, 3, 4, 5]); + let s3 = Shape::new([4, 3, 4, 5]); + + let out = Shape::cat(&[s1, s2, s3], 0).unwrap(); + assert_eq!(out, Shape::new([7, 3, 4, 5])); + + let s1 = Shape::new([2, 3, 4, 5]); + let s2 = Shape::new([2, 3, 2, 5]); + let s3 = Shape::new([2, 3, 1, 5]); + + let out = Shape::cat(&[s1, s2, s3], 2).unwrap(); + assert_eq!(out, Shape::new([2, 3, 7, 5])); + } + + #[test] + fn test_shape_cat_empty() { + let out = Shape::cat(&[], 0); + assert_eq!(out, Err(MetadataError::empty())); + } + + #[test] + fn test_shape_cat_dim_out_of_bounds() { + let s1 = Shape::new([2, 3, 4, 5]); + let s2 = Shape::new([2, 3, 4, 5]); + let out = Shape::cat(&[s1, s2], 4); + assert_eq!(out, Err(MetadataError::OutOfBounds { dim: 4, rank: 4 })); + } + + #[test] + fn test_shape_cat_rank_mismatch() { + let s1 = Shape::new([2, 3, 4, 5]); + let s2 = Shape::new([2, 3, 4, 5, 6]); + let out = Shape::cat(&[s1, s2], 0); + assert_eq!(out, Err(MetadataError::RankMismatch { left: 4, right: 5 })); + } + + #[test] + fn test_shape_cat_incompatible_shapes() { + let s1 = Shape::new([2, 3, 4, 5]); + let s2 = Shape::new([1, 3, 4, 5]); + let out = Shape::cat(&[s1.clone(), s2.clone()], 1); + + assert_eq!( + out, + Err(MetadataError::IncompatibleShapes { + left: s1, + right: s2 + }) + ); + } + + #[test] + fn test_shape_expand() { + let shape = 
Shape::new([1, 3, 1]); + let expanded = Shape::new([2, 3, 4]); + let out = shape.expand(expanded.clone()).unwrap(); + assert_eq!(out, expanded); + } + + #[test] + fn test_shape_expand_higher_rank() { + let shape = Shape::new([1, 4]); + let expanded = Shape::new([2, 3, 4]); + let out = shape.expand(expanded.clone()).unwrap(); + assert_eq!(out, expanded); + } + + #[test] + fn test_shape_expand_invalid_rank() { + let shape = Shape::new([1, 3, 1]); + let expanded = Shape::new([3, 4]); + let out = shape.expand(expanded); + assert_eq!(out, Err(MetadataError::RankMismatch { left: 3, right: 2 })); + } + + #[test] + fn test_shape_expand_incompatible_dims() { + let shape = Shape::new([1, 3, 2]); + let expanded = Shape::new([2, 3, 4]); + let out = shape.expand(expanded); + assert_eq!( + out, + Err(MetadataError::IncompatibleDims { + left: 2, + right: 4, + dim: 2 + }) + ); + } + + #[test] + fn test_shape_reshape() { + let shape = Shape::new([2, 3, 4, 5]); + let reshaped = Shape::new([1, 2, 12, 5]); + let out = shape.reshape(reshaped.clone()).unwrap(); + assert_eq!(out, reshaped); + } + + #[test] + fn test_shape_reshape_invalid() { + let shape = Shape::new([2, 3, 4, 5]); + let reshaped = Shape::new([2, 2, 12, 5]); + let out = shape.reshape(reshaped.clone()); + assert_eq!( + out, + Err(MetadataError::Invalid { + reason: "The given shape doesn't have the same number of elements as the current shape. Current shape: [2, 3, 4, 5], target shape: [2, 2, 12, 5].".into(), + }) + ); + } + + #[test] + fn test_shape_reshape_invalid_inferred() { + let shape = Shape::new([2, 4]); + let out = shape.reshape([-1, 3]); + assert_eq!( + out, + Err(MetadataError::Invalid { + reason: "Cannot infer a valid target shape. 
Current shape: [2, 4], target dimensions: [-1, 3].".into(), + }) + ); + } + + #[test] + fn test_flatten_dims() { + let shape = Shape::new([2, 3, 4, 5]); + let flattened = shape.flatten_dims(-2, 3); + assert_eq!(flattened, Shape::new([2, 3, 20])); + } +} diff --git a/crates/cubecl-zspace/src/strides.rs b/crates/cubecl-zspace/src/strides.rs new file mode 100644 index 000000000..d1fd7e5e6 --- /dev/null +++ b/crates/cubecl-zspace/src/strides.rs @@ -0,0 +1,122 @@ +use core::ops::{Deref, DerefMut}; + +use serde::{Deserialize, Serialize}; +use smallvec::SmallVec; + +use crate::{INLINE_DIMS, MetadataError, indexing::AsSize}; + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Hash)] +pub struct Strides { + dims: SmallVec<[usize; INLINE_DIMS]>, +} + +impl Strides { + pub fn new(dims: &[usize]) -> Self { + // For backward compat + Self { + dims: SmallVec::from_slice(dims), + } + } + + pub fn new_raw(dims: SmallVec<[usize; INLINE_DIMS]>) -> Self { + Self { dims } + } + + pub fn rank(&self) -> usize { + self.dims.len() + } + + /// Insert a dimension of `stride` at position `index`. + pub fn insert(&mut self, index: usize, stride: usize) { + self.dims.insert(index, stride); + } + + /// Remove and return the dimension at position `index` from the strides. + pub fn remove(&mut self, index: usize) -> usize { + self.dims.remove(index) + } + + /// Appends a dimension of `stride` to the back of the strides. + pub fn push(&mut self, stride: usize) { + self.dims.push(stride) + } + + /// Extend the strides with the content of another shape or iterator. + pub fn extend(&mut self, iter: impl IntoIterator<Item = usize>) { + self.dims.extend(iter) + } + + /// Reorder the strides dimensions according to the permutation of `axes`.
+ pub fn permute(&mut self, axes: &[usize]) -> Result<(), MetadataError> { + if axes.len() != self.rank() { + return Err(MetadataError::RankMismatch { + left: self.rank(), + right: axes.len(), + }); + } + debug_assert!(axes.iter().all(|i| i < &self.rank())); + + self.dims = axes.iter().map(|&i| self.dims[i]).collect(); + Ok(()) + } + + /// Reorder the strides dimensions according to the permutation of `axes`. + pub fn permuted(mut self, axes: &[usize]) -> Result<Self, MetadataError> { + self.permute(axes)?; + Ok(self) + } +} + +impl Deref for Strides { + type Target = [usize]; + + fn deref(&self) -> &Self::Target { + &self.dims + } +} + +impl DerefMut for Strides { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.dims + } +} + +#[macro_export] +macro_rules! strides { + (@one $x:expr) => (1usize); + () => ( + $crate::Strides::new_raw($crate::SmallVec::new()) + ); + ($elem:expr; $n:expr) => ({ + $crate::Strides::new_raw($crate::smallvec!($elem; $n)) + }); + ($($x:expr),+$(,)?) => ({ + $crate::Strides::new_raw($crate::smallvec!($($x),*)) + }); +} + +impl<T, I> From<T> for Strides +where + T: IntoIterator<Item = I>, + I: AsSize, +{ + fn from(dims: T) -> Self { + Strides { + dims: dims.into_iter().map(|d| d.as_size()).collect(), + } + } +} + +impl From<&Strides> for Strides { + fn from(value: &Strides) -> Self { + value.clone() + } +} + +impl<I: AsSize> FromIterator<I> for Strides { + fn from_iter<T: IntoIterator<Item = I>>(iter: T) -> Self { + Strides { + dims: iter.into_iter().map(|it| it.as_size()).collect(), + } + } +} diff --git a/crates/cubecl-zspace/src/striding/layout_builders.rs b/crates/cubecl-zspace/src/striding/layout_builders.rs index 85bdc6372..bf045fdab 100644 --- a/crates/cubecl-zspace/src/striding/layout_builders.rs +++ b/crates/cubecl-zspace/src/striding/layout_builders.rs @@ -1,6 +1,6 @@ //! # Stride Layout Builders -use alloc::vec; +use crate::{Strides, strides}; /// Construct row-major contiguous strides for a shape.
/// @@ -11,13 +11,13 @@ use alloc::vec; /// - ``for i in 0..rank - 1 { strides[i] == strides[i + 1] * shape[i + 1] }`` /// /// If ``rank == 0``, this will return ``vec![]``. -pub fn row_major_contiguous_strides<S>(shape: S) -> Vec<usize> +pub fn row_major_contiguous_strides<S>(shape: S) -> Strides where S: AsRef<[usize]>, { let shape = shape.as_ref(); let rank = shape.len(); - let mut strides = vec![1; rank]; + let mut strides = strides![1; rank]; if rank > 1 { for i in (0..rank - 1).rev() { strides[i] = strides[i + 1] * shape[i + 1]; @@ -32,7 +32,7 @@ mod tests { #[test] fn test_row_major_contiguous_strides() { - assert_eq!(row_major_contiguous_strides([]), vec![]); - assert_eq!(row_major_contiguous_strides([1, 2, 3]), vec![6, 3, 1]); + assert_eq!(row_major_contiguous_strides([]), strides![]); + assert_eq!(row_major_contiguous_strides([1, 2, 3]), strides![6, 3, 1]); } } diff --git a/crates/cubecl-zspace/src/striding/layout_validation.rs b/crates/cubecl-zspace/src/striding/layout_validation.rs index bf5a46af2..0fe7a3a4f 100644 --- a/crates/cubecl-zspace/src/striding/layout_validation.rs +++ b/crates/cubecl-zspace/src/striding/layout_validation.rs @@ -184,6 +184,8 @@ where #[cfg(test)] mod tests { + use crate::{shape, strides}; + use super::*; #[test] @@ -194,8 +196,8 @@ mod tests { &try_check_matching_ranks([1, 2], [1, 2, 3]), &Err(StrideError::MalformedRanks { record: StrideRecord { - shape: vec![1, 2], - strides: vec![1, 2, 3] + shape: shape![1, 2], + strides: strides![1, 2, 3] } }) ); @@ -214,8 +216,8 @@ mod tests { Err(StrideError::UnsupportedRank { rank: 0, record: StrideRecord { - shape: vec![], - strides: vec![] + shape: shape![], + strides: strides![] } }) ); @@ -226,8 +228,8 @@ mod tests { Err(StrideError::Invalid { message: "strides are not contiguous in row major order".to_string(), record: StrideRecord { - shape: vec![2, 2], - strides: vec![3, 1] + shape: shape![2, 2], + strides: strides![3, 1] } }) ); @@ -238,8 +240,8 @@ mod tests { Err(StrideError::Invalid {
message: "strides are not contiguous in row major order".to_string(), record: StrideRecord { - shape: vec![1, 2], - strides: vec![1, 2] + shape: shape![1, 2], + strides: strides![1, 2] } }) );