Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ variadics_please = "1"
dashmap = "6.1.0"
foldhash = { version = "0.2", default-features = false }
hashbrown = "0.16"
smallvec = { version = "1", features = ["union", "const_generics"] }
smallvec = { version = "1", features = [
"union",
"const_generics",
"const_new",
] }
spin = { version = "0.10.0", features = ["mutex", "spin_mutex"] }
xxhash-rust = { version = "0.8", default-features = false }

Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ tracing = ["dep:tracing", "cubecl-runtime/tracing", "cubecl-ir/tracing"]
[dependencies]
# Runtime Deps
log = { workspace = true }
tracing = { workspace = true, features = ["attributes"], optional = true }
test-log = { workspace = true, optional = true }
tracing = { workspace = true, features = ["attributes"], optional = true }

cubecl-ir = { path = "../cubecl-ir", version = "=0.10.0-pre.1", default-features = false, features = [
"serde",
Expand Down Expand Up @@ -63,6 +63,6 @@ tempfile = { version = "3.20", optional = true }
variadics_please = { workspace = true }

[dev-dependencies]
test-log = { workspace = true, features = ["trace"] }
pretty_assertions = { workspace = true }
test-log = { workspace = true, features = ["trace"] }
trybuild = "1"
3 changes: 2 additions & 1 deletion crates/cubecl-core/src/frontend/container/sequence/launch.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{cell::RefCell, rc::Rc};

use cubecl_runtime::runtime::Runtime;
use cubecl_zspace::SmallVec;

use crate::{
compute::KernelBuilder,
Expand Down Expand Up @@ -29,7 +30,7 @@ impl<'a, R: Runtime, T: LaunchArg> SequenceArg<'a, R, T> {
}

pub struct SequenceCompilationArg<C: LaunchArg> {
pub values: Vec<C::CompilationArg>,
pub values: SmallVec<[C::CompilationArg; 5]>,
}

impl<C: LaunchArg> CompilationArg for SequenceCompilationArg<C> {}
Expand Down
9 changes: 4 additions & 5 deletions crates/cubecl-core/src/frontend/container/tensor/tensormap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::ir::ExpandElement;
use crate::{prelude::*, unexpanded};
use cubecl_ir::{LineSize, StorageType, Type};
use cubecl_runtime::server::TensorMapMeta;
use cubecl_zspace::{Strides, metadata::Metadata, strides};
use paste::paste;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -74,10 +75,8 @@ impl<'a, R: Runtime, K: TensorMapKind> TensorMapArg<'a, R, K> {
Self {
metadata: TensorMapMeta {
format: K::as_format(args),
rank,
shape: handle.shape.to_vec(),
strides: handle.strides.to_vec(),
elem_stride: vec![1; rank],
metadata: Metadata::new(handle.shape, handle.strides),
elem_stride: strides![1; rank],
interleave: TensorMapInterleave::None,
swizzle: TensorMapSwizzle::None,
prefetch: TensorMapPrefetch::None,
Expand All @@ -89,7 +88,7 @@ impl<'a, R: Runtime, K: TensorMapKind> TensorMapArg<'a, R, K> {
}
}

pub fn with_elem_stride(mut self, elem_stride: Vec<usize>) -> Self {
pub fn with_elem_stride(mut self, elem_stride: Strides) -> Self {
self.metadata.elem_stride = elem_stride;
self
}
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ pub enum LineSizeError {
/// is divisible by the vectorization.
/// The last condition ensures that the current axis is contiguous within the next stride.
pub fn tensor_line_size_parallel(
supported_line_sizes: impl Iterator<Item = LineSize>,
optimized_line_sizes: impl Iterator<Item = LineSize>,
shape: &[usize],
strides: &[usize],
axis: usize,
) -> LineSize {
try_tensor_line_size_parallel(supported_line_sizes, shape, strides, axis).unwrap_or(1)
try_tensor_line_size_parallel(optimized_line_sizes, shape, strides, axis).unwrap_or(1)
}

/// Like `try_tensor_line_size_parallel` but does not assume 1 is supported
Expand Down
8 changes: 4 additions & 4 deletions crates/cubecl-core/src/runtime_tests/line.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub fn kernel_line_index<F: Float>(output: &mut Array<F>, #[comptime] line_size:

#[allow(clippy::needless_range_loop)]
pub fn test_line_index<R: Runtime, F: Float + CubeElement>(client: ComputeClient<R>) {
for line_size in client.io_optimized_line_sizes(&F::as_type_native().unwrap()) {
for line_size in client.io_optimized_line_sizes(size_of::<F>()) {
if line_size < 4 {
continue;
}
Expand Down Expand Up @@ -50,7 +50,7 @@ pub fn kernel_line_index_assign<F: Float>(output: &mut Array<Line<F>>) {
}

pub fn test_line_index_assign<R: Runtime, F: Float + CubeElement>(client: ComputeClient<R>) {
for line_size in client.io_optimized_line_sizes(&F::as_type_native().unwrap()) {
for line_size in client.io_optimized_line_sizes(size_of::<F>()) {
let handle = client.create_from_slice(F::as_bytes(&vec![F::new(0.0); line_size]));
unsafe {
kernel_line_index_assign::launch_unchecked::<F, R>(
Expand Down Expand Up @@ -88,7 +88,7 @@ pub fn kernel_line_loop_unroll<F: Float>(
}

pub fn test_line_loop_unroll<R: Runtime, F: Float + CubeElement>(client: ComputeClient<R>) {
for line_size in client.io_optimized_line_sizes(&F::as_type_native_unchecked()) {
for line_size in client.io_optimized_line_sizes(size_of::<F>()) {
let handle = client.create_from_slice(F::as_bytes(&vec![F::new(0.0); line_size]));
unsafe {
kernel_line_loop_unroll::launch_unchecked::<F, R>(
Expand Down Expand Up @@ -171,7 +171,7 @@ pub fn kernel_shared_memory<F: Float>(output: &mut Array<Line<F>>) {
}

pub fn test_shared_memory<R: Runtime, F: Float + CubeElement>(client: ComputeClient<R>) {
for line_size in client.io_optimized_line_sizes(&F::as_type_native().unwrap()) {
for line_size in client.io_optimized_line_sizes(size_of::<F>()) {
let output = client.create_from_slice(F::as_bytes(&vec![F::new(0.0); line_size]));
unsafe {
kernel_shared_memory::launch_unchecked::<F, R>(
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/runtime_tests/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub fn test_tensor_coordinate<R: Runtime>(client: ComputeClient<R>) {
let output_size = shape.len() * input_size;

// The result is independent of the line size
for &line_size in R::supported_line_sizes() {
for line_size in client.io_optimized_line_sizes(size_of::<f32>()) {
let output = client.empty(core::mem::size_of::<u32>() * output_size);
unsafe {
tensor_coordinate::launch(
Expand Down
9 changes: 5 additions & 4 deletions crates/cubecl-core/src/runtime_tests/tensormap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use cubecl_runtime::{
server::{Allocation, ComputeServer, CopyDescriptor},
storage::ComputeStorage,
};
use cubecl_zspace::shape;
use std::fmt::Debug;

#[cube(launch)]
Expand Down Expand Up @@ -129,7 +130,7 @@ where
CubeDim::new_2d(32, 16),
TensorMapArg::new(
TiledArgs {
tile_size: vec![16, 32],
tile_size: shape![16, 32],
},
input,
F::as_type_native_unchecked(),
Expand Down Expand Up @@ -173,7 +174,7 @@ where
unsafe { ArrayArg::from_raw_parts::<F>(&handle, 32 * 16, 1) },
TensorMapArg::new(
TiledArgs {
tile_size: vec![16, 32],
tile_size: shape![16, 32],
},
unsafe { TensorArg::from_raw_parts::<F>(&out.handle, &out.strides, &[64, 64], 1) },
F::as_type_native_unchecked(),
Expand Down Expand Up @@ -316,14 +317,14 @@ where
input_1,
TensorMapArg::new(
TiledArgs {
tile_size: vec![16, 16],
tile_size: shape![16, 16],
},
output_1,
F::as_type_native_unchecked(),
),
TensorMapArg::new(
TiledArgs {
tile_size: vec![16, 32],
tile_size: shape![16, 32],
},
input_2,
F::as_type_native_unchecked(),
Expand Down
5 changes: 3 additions & 2 deletions crates/cubecl-cpu/src/compute/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use cubecl_core::{
ExecutionError, IoError, LaunchError, ProfileError, ProfilingToken, ServerCommunication,
ServerUtilities,
},
zspace::{Strides, strides},
};
use cubecl_runtime::{
compiler::CubeTask,
Expand Down Expand Up @@ -335,9 +336,9 @@ impl ServerCommunication for CpuServer {
const SERVER_COMM_ENABLED: bool = false;
}

pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
pub(crate) fn contiguous_strides(shape: &[usize]) -> Strides {
let rank = shape.len();
let mut strides = vec![1; rank];
let mut strides = strides![1; rank];
for i in (0..rank - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
Expand Down
5 changes: 1 addition & 4 deletions crates/cubecl-cpu/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ impl DeviceState for CpuServer {
num_streaming_multiprocessors: None,
num_tensor_cores: None,
min_tensor_cores_dim: None,
max_line_size: LineSize::MAX,
};

const ALIGNMENT: u64 = 4;
Expand Down Expand Up @@ -95,10 +96,6 @@ impl Runtime for CpuRuntime {
"cpu"
}

fn supported_line_sizes() -> &'static [LineSize] {
&[128, 64, 32, 16, 8, 4, 2, 1]
}

fn max_cube_count() -> (u32, u32, u32) {
(u32::MAX, u32::MAX, u32::MAX)
}
Expand Down
53 changes: 34 additions & 19 deletions crates/cubecl-cuda/src/compute/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use cubecl_core::{
ExecutionError, IoError, LaunchError, ProfileError, ProfilingToken, ServerCommunication,
ServerUtilities, TensorMapBinding, TensorMapMeta,
},
zspace::{Shape, Strides, strides},
};
use cubecl_runtime::{
compiler::CubeTask,
Expand Down Expand Up @@ -109,7 +110,7 @@ impl ComputeServer for CudaServer {
let pitch = width_bytes.next_multiple_of(pitch_align);
let size = height * pitch;
total_size += size.next_multiple_of(self.mem_alignment);
let mut stride = vec![1; rank];
let mut stride = strides![1; rank];
if rank > 1 {
stride[rank - 2] = pitch / descriptor.elem_size;
}
Expand Down Expand Up @@ -245,9 +246,16 @@ impl ComputeServer for CudaServer {

let mut map_ptr = MaybeUninit::zeroed();

let shape: Vec<_> = map.shape.iter().rev().map(|s| *s as u64).collect();
let shape: Vec<_> = map
.metadata
.shape()
.iter()
.rev()
.map(|s| *s as u64)
.collect();
let strides: Vec<_> = map
.strides
.metadata
.strides()
.iter()
.rev()
.skip(1)
Expand All @@ -257,12 +265,13 @@ impl ComputeServer for CudaServer {

match &map.format {
TensorMapFormat::Tiled(TiledArgs { tile_size }) => unsafe {
let tile_size: Vec<_> = tile_size.iter().rev().copied().collect();
let tile_size: Vec<_> =
tile_size.iter().rev().copied().map(|s| s as u32).collect();

cuTensorMapEncodeTiled(
map_ptr.as_mut_ptr(),
elem_to_tensor_map_type(map.storage_ty),
map.rank as u32,
map.metadata.rank() as u32,
device_ptr,
shape.as_ptr(),
strides.as_ptr(),
Expand Down Expand Up @@ -296,7 +305,7 @@ impl ComputeServer for CudaServer {
cuTensorMapEncodeIm2col(
map_ptr.as_mut_ptr(),
elem_to_tensor_map_type(map.storage_ty),
map.rank as u32,
map.metadata.rank() as u32,
device_ptr,
shape.as_ptr(),
strides.as_ptr(),
Expand Down Expand Up @@ -339,7 +348,7 @@ impl ComputeServer for CudaServer {
cuTensorMapEncodeIm2colWide(
map_ptr.as_mut_ptr(),
elem_to_tensor_map_type(map.storage_ty),
map.rank as u32,
map.metadata.rank() as u32,
device_ptr,
shape.as_ptr(),
strides.as_ptr(),
Expand Down Expand Up @@ -532,7 +541,7 @@ impl CudaServer {
stream_id_src: StreamId,
stream_id_dst: StreamId,
) -> Result<Allocation, IoError> {
let strides = src.strides.to_vec();
let strides = src.strides.into();
let binding = src.binding.clone();

let context_src = server_src.ctx.context;
Expand Down Expand Up @@ -590,8 +599,8 @@ impl CudaServer {
stream_id_src: StreamId,
stream_id_dst: StreamId,
) -> Result<Allocation, IoError> {
let shape = src.shape.to_vec();
let strides = src.strides.to_vec();
let shape: Shape = src.shape.into();
let strides: Strides = src.strides.into();
let elem_size = src.elem_size;
let binding = src.binding.clone();
let num_bytes = shape.iter().product::<usize>() * elem_size;
Expand Down Expand Up @@ -826,9 +835,12 @@ fn check_tma_generic(
}

// tensorRank invariants
launch_check!((1..=5).contains(&map.rank), "Rank must be between 1 and 5")?;
launch_check!(
matches!(map.interleave, TensorMapInterleave::None) || map.rank >= 3,
(1..=5).contains(&map.metadata.rank()),
"Rank must be between 1 and 5"
)?;
launch_check!(
matches!(map.interleave, TensorMapInterleave::None) || map.metadata.rank() >= 3,
"When interleave is enabled, rank must be >= 3"
)?;

Expand Down Expand Up @@ -881,7 +893,10 @@ fn check_tma_generic(
}

fn check_tma_tiled(map: &TensorMapMeta, tile_size: &[u32]) -> Result<(), LaunchError> {
launch_check!(tile_size.len() == map.rank, "Tile shape should match rank")?;
launch_check!(
tile_size.len() == map.metadata.rank(),
"Tile shape should match rank"
)?;
launch_check!(
tile_size.iter().all(|it| *it > 0 && *it <= 256),
"Tile shape must be non-zero and <= 256"
Expand Down Expand Up @@ -920,20 +935,20 @@ fn check_tma_im2col(
pixels_per_column: u32,
) -> Result<(), LaunchError> {
launch_check!(
lower_corner.len() == map.rank - 2,
lower_corner.len() == map.metadata.rank() - 2,
"Lower corner must be rank - 2 elements"
)?;
launch_check!(
upper_corner.len() == map.rank - 2,
upper_corner.len() == map.metadata.rank() - 2,
"Upper corner must be rank - 2 elements"
)?;

launch_check!(
map.rank >= 3 && map.rank <= 5,
map.metadata.rank() >= 3 && map.metadata.rank() <= 5,
"im2col requires rank to be between 3 and 5"
)?;

let (range_lower, range_upper) = match map.rank {
let (range_lower, range_upper) = match map.metadata.rank() {
3 => (-32768, 32767),
4 => (-128, 127),
5 => (-16, 15),
Expand All @@ -944,14 +959,14 @@ fn check_tma_im2col(
.iter()
.all(|it| *it >= range_lower && *it <= range_upper),
"Lower corner must be in range [{range_lower}, {range_upper}] for {}D im2col",
map.rank
map.metadata.rank()
)?;
launch_check!(
upper_corner
.iter()
.all(|it| *it >= range_lower && *it <= range_upper),
"Upper corner must be in range [{range_lower}, {range_upper}] for {}D im2col",
map.rank
map.metadata.rank()
)?;

launch_check!(
Expand Down
5 changes: 1 addition & 4 deletions crates/cubecl-cuda/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ impl DeviceState for CudaServer {
Some(8)
},
num_cpu_cores: None,
max_line_size: LineSize::MAX,
}
};

Expand Down Expand Up @@ -320,10 +321,6 @@ impl Runtime for CudaRuntime {
true
}

fn supported_line_sizes() -> &'static [LineSize] {
&[16, 8, 4, 2, 1]
}

fn max_cube_count() -> (u32, u32, u32) {
(i32::MAX as u32, u16::MAX as u32, u16::MAX as u32)
}
Expand Down
Loading
Loading