Skip to content
Draft

Sort #77

Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
456d1a9
WIP baseline (single threaded)
ArthurBrussee Jan 29, 2026
646f997
Partially parallel
ArthurBrussee Jan 29, 2026
9dab244
WIP
ArthurBrussee Jan 29, 2026
09fa433
Proper paralellization and f32/i32 support
ArthurBrussee Jan 29, 2026
7f9f843
Add benches
ArthurBrussee Jan 29, 2026
97d346d
Cleanup
ArthurBrussee Jan 29, 2026
fab22c4
Small speedup
ArthurBrussee Jan 29, 2026
945d6a3
Coalesced shared mem
ArthurBrussee Jan 29, 2026
17e60f6
Speedup
ArthurBrussee Jan 30, 2026
28059b6
Cached digit
ArthurBrussee Jan 30, 2026
ed9b3f9
Misc speedups
ArthurBrussee Jan 30, 2026
610f3b9
Better benchmark
ArthurBrussee Jan 30, 2026
e900584
Speedups
ArthurBrussee Jan 30, 2026
cb93627
Moar faster
ArthurBrussee Jan 30, 2026
934c017
~100GB/s
ArthurBrussee Jan 30, 2026
86deb42
Dont cache digit after all
ArthurBrussee Jan 30, 2026
5d9be94
Improve speed, measure batched perf
ArthurBrussee Jan 30, 2026
6869e87
Use plane_id and local CubeCL for now
ArthurBrussee Jan 30, 2026
02c4961
Cleanup
ArthurBrussee Jan 31, 2026
9dc514a
Cleanup, cleanup tests, support more types
ArthurBrussee Feb 2, 2026
df4c188
Cleanup bench, just measure keys/s, fix CUDA crash, add test for OOB
ArthurBrussee Feb 2, 2026
23c3f25
Fixes for CUDA (shared mem init), cleanup benches
ArthurBrussee Feb 2, 2026
9accdc5
Remove a test
ArthurBrussee Feb 2, 2026
584f9b3
Cleanup
ArthurBrussee Feb 2, 2026
2446861
Refactor cubek-sort to support implicit indices
ArthurBrussee Feb 3, 2026
bcbf1a8
Small cleanups in kernel
ArthurBrussee Feb 3, 2026
89d4423
Simplify scan kernel a bit
ArthurBrussee Feb 3, 2026
3c1dd94
Skip mem init in scan kernel
ArthurBrussee Feb 3, 2026
32be390
Remove unneeded write to g_scane
ArthurBrussee Feb 3, 2026
c0ab7e2
Cleanuo & remove some unneeded args
ArthurBrussee Feb 11, 2026
d003a48
Add support for different value sizes, cleanup
ArthurBrussee Feb 13, 2026
34ae8bc
Some more cleanup of the kernels
ArthurBrussee Feb 13, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ version = "0.1.0"

[workspace.dependencies]
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
cubecl-std = { path = "../cubecl/crates/cubecl-std", default-features = false }

### For the main cubek branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", rev = "a3987c60e410d77b6b38e99d1c2a6bc0979d38ba", default-features = false }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", rev = "a3987c60e410d77b6b38e99d1c2a6bc0979d38ba", default-features = false }
# cubecl = { git = "https://github.com/tracel-ai/cubecl", rev = "a3987c60e410d77b6b38e99d1c2a6bc0979d38ba", default-features = false }
# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", rev = "a3987c60e410d77b6b38e99d1c2a6bc0979d38ba", default-features = false }

### For releases ###
# cubecl = { version = "=0.9.0", default-features = false }
Expand Down
8 changes: 8 additions & 0 deletions benchmarks/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ repository = "https://github.com/tracel-ai/cubek"
version.workspace = true

[features]
default = []
cuda = ["cubecl/cuda"]

[dependencies]
cubecl = { workspace = true, features = ["test-runtime", "stdlib"] }
Expand All @@ -23,8 +25,10 @@ cubek = { path = "../crates/cubek", version = "=0.1.0", default-features = false
"matmul",
"convolution",
"attention",
"sort",
] }
half = { workspace = true }
rand = { workspace = true }

[[bench]]
harness = false
Expand Down Expand Up @@ -53,3 +57,7 @@ name = "attention"
[[bench]]
harness = false
name = "contiguous"

[[bench]]
harness = false
name = "sort"
151 changes: 151 additions & 0 deletions benchmarks/benches/sort.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
//! Benchmarks for cubek-sort.

use cubecl::{
benchmark::{Benchmark, BenchmarkComputations, TimingMethod},
future,
prelude::*,
server::Handle,
};
use cubek::sort::{SortKey, SortOrder, SortValues, sort};
use std::marker::PhantomData;
use std::time::Duration;

const NUM_SAMPLES: usize = 100;

/// Benchmark harness for sorting keys of type `K` on runtime `R`.
struct SortBench<R: Runtime, K> {
    // Number of keys to sort per run.
    num_items: usize,
    // Compute client used to allocate buffers and launch the sort.
    client: ComputeClient<R>,
    // Marker for the key element type being benchmarked.
    _key: PhantomData<K>,
}

impl<R: Runtime, K> Benchmark for SortBench<R, K>
where
    K: SortKey + CubeElement + Numeric,
    K::Radix: SortKey<Radix = K::Radix>,
{
    type Input = Handle;
    type Output = Handle;

    /// Upload the input keys: a descending sequence, wrapped to the key width.
    fn prepare(&self) -> Self::Input {
        let elem_bytes = std::mem::size_of::<K>();
        // The wrap modulus is a power of two, so wrapping reduces to a bit
        // mask. A mask also avoids the bug in `i % (max_val + 1)`: for 8-byte
        // keys `max_val == u64::MAX`, so `max_val + 1` overflows (panics in
        // debug; wraps to 0 and panics on modulo-by-zero in release).
        let mask = if elem_bytes >= 8 {
            u64::MAX
        } else {
            (1u64 << (elem_bytes * 8)) - 1
        };
        // Reversed order gives the sort real work to do on every sample.
        let data: Vec<K> = (0..self.num_items as u64)
            .rev()
            .map(|i| K::from_int((i & mask) as i64))
            .collect();
        self.client.create_from_slice(K::as_bytes(&data))
    }

    /// Human-readable benchmark id, e.g. `sort-u32-16m-<runtime>`.
    fn name(&self) -> String {
        format!(
            "sort-{}-{}-{}",
            std::any::type_name::<K>().split("::").last().unwrap_or("?"),
            format_size(self.num_items),
            R::name(&self.client),
        )
        .to_lowercase()
    }

    /// Block until all queued device work has completed.
    fn sync(&self) {
        future::block_on(self.client.sync()).expect("sync failed")
    }

    /// Launch one keys-only ascending sort over the prepared buffer.
    fn execute(&self, input: Self::Input) -> Result<Self::Output, String> {
        let shape = [self.num_items];
        let strides = [1];

        // SAFETY: handle/shape/strides describe the buffer created in
        // `prepare`, which holds exactly `num_items` elements of `K`.
        let input_ref = unsafe {
            TensorHandleRef::from_raw_parts(&input, &strides, &shape, std::mem::size_of::<K>())
        };

        let output = sort::<R, K>(
            &self.client,
            input_ref,
            SortValues::None,
            self.num_items,
            SortOrder::Ascending,
        )
        .map_err(|e| format!("Sort failed: {:?}", e))?;

        Ok(output.keys)
    }

    fn num_samples(&self) -> usize {
        NUM_SAMPLES
    }
}

/// Render an item count as a compact size label: `"16m"`, `"512k"`, or the
/// plain number below one thousand (integer division truncates).
fn format_size(n: usize) -> String {
    let (divisor, suffix) = if n >= 1_000_000 {
        (1_000_000, "m")
    } else if n >= 1_000 {
        (1_000, "k")
    } else {
        (1, "")
    };
    format!("{}{}", n / divisor, suffix)
}

/// Run the sort benchmark for key type `K` at each size in `sizes`, printing
/// throughput (keys/sec) and mean/min/max timings per size.
fn run<R: Runtime, K>(client: &ComputeClient<R>, sizes: &[usize])
where
    K: SortKey + CubeElement + Numeric,
    K::Radix: SortKey<Radix = K::Radix>,
{
    let key_bits = std::mem::size_of::<K>() * 8;
    println!("--- {key_bits}-bit keys ---");

    for &size in sizes {
        let bench = SortBench::<R, K> {
            num_items: size,
            client: client.clone(),
            _key: PhantomData,
        };

        match bench.run(TimingMethod::System) {
            Ok(bench_durations) => {
                let durations = &bench_durations.durations;
                let computed = BenchmarkComputations::new(&bench_durations);

                let total_time: Duration = durations.iter().sum();
                let mean: Duration = total_time / durations.len() as u32;

                // Throughput derived from the measured mean. This uses the
                // actual number of recorded samples (via `mean`) instead of
                // the NUM_SAMPLES constant, so it stays correct even if the
                // harness records a different sample count.
                let keys_per_sec = (size as f64) / mean.as_secs_f64();

                println!(
                    "{:>4}M: {:.2E} keys/sec (mean={:.2}ms min={:.2}ms max={:.2}ms)",
                    size / 1_000_000,
                    keys_per_sec,
                    mean.as_secs_f64() * 1000.0,
                    computed.min.as_secs_f64() * 1000.0,
                    computed.max.as_secs_f64() * 1000.0,
                );
            }
            Err(e) => {
                println!("{:>4}M: Failed - {}", size / 1_000_000, e);
            }
        }
    }
    println!();
}

/// Benchmark entry point: sorts 16M, 32M and 64M keys on the test runtime,
/// for u32 always and for u16 when the runtime supports 16-bit uints.
fn main() {
    use cubecl::ir::{ElemType, UIntKind};

    let client = cubecl::TestRuntime::client(&Default::default());

    // 2^24 .. 2^26 keys (16M, 32M, 64M).
    let sizes: Vec<usize> = (24..=26).map(|shift| 1usize << shift).collect();

    run::<cubecl::TestRuntime, u32>(&client, &sizes);

    let supports_u16 = client
        .properties()
        .features
        .supports_type(ElemType::UInt(UIntKind::U16));
    if supports_u16 {
        run::<cubecl::TestRuntime, u16>(&client, &sizes);
    }
}
32 changes: 32 additions & 0 deletions crates/cubek-sort/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
[package]
authors = [
"nathanielsimard <nathaniel.simard.42@gmail.com>",
"arthurbrussee <arthur.brussee@gmail.com>",
]
categories = ["science", "mathematics", "algorithms"]
description = "CubeK: Sorting Kernels"
edition.workspace = true
keywords = []
license.workspace = true
name = "cubek-sort"
readme.workspace = true
repository = "https://github.com/tracel-ai/cubek/tree/main/crates/cubek-sort"
version.workspace = true

[features]
default = ["std", "cubecl/default"]
std = ["cubecl/std", "thiserror/std"]

[dependencies]
cubecl = { workspace = true }
cubecl-std = { workspace = true }
half = { workspace = true }
thiserror = { workspace = true }

num-traits = "0.2.19"
pretty_assertions = { workspace = true, optional = true }

[dev-dependencies]
cubecl = { workspace = true, features = ["test-runtime"] }
cubecl-common = { workspace = true }
rand = { workspace = true }
20 changes: 20 additions & 0 deletions crates/cubek-sort/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# CubeK Sort

Implements a portable radix sort that is hardware-agnostic and supports key/value pairs.

## Implementation

The implementation is based on the device radix sort from [b0nes164](https://github.com/b0nes164/GPUSorting). While a single-pass decoupled look-back would be faster, it requires forward-progress guarantees, which not all runtimes provide. This could be added in a future version as a variant of the kernels.

## Other features

- Radix sorting is a _stable_ sorting algorithm, which means values with the same key are preserved in their original order.
- Supports sorting of key/value pairs.
- Supports sorting floating point values as well as integers.

## Resources

https://gpuopen.com/learn/boosting_gpu_radix_sort/
https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
https://linebender.org/wiki/gpu/sorting/
https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceRadixSort.html
56 changes: 56 additions & 0 deletions crates/cubek-sort/src/components/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Number of key bits consumed per radix pass.
pub const RADIX_BITS: usize = 8;
// Number of histogram buckets per pass (2^RADIX_BITS).
pub const NUM_BUCKETS: usize = 1 << RADIX_BITS;

/// Launch configuration for a radix-sort pass: how many items each thread
/// handles and how many threads make up a block.
#[derive(Clone, Debug)]
pub struct SortStrategy {
    pub items_per_thread: u32,
    pub threads_per_block: u32,
}

impl Default for SortStrategy {
    /// 512 threads x 16 items = 8192 items per block. Larger blocks reduce
    /// scan overhead, traded against lower occupancy.
    fn default() -> Self {
        Self {
            items_per_thread: 16,
            threads_per_block: 512,
        }
    }
}

// These heuristics are deliberately simple. A better approach would likely be
// autotuning (e.g. in Burn) over a low/mid/high block size, or deriving the
// settings in a more principled way.
impl SortStrategy {
    /// Strategy tuned for keys-only sorting of `num_items` elements.
    pub fn for_keys(num_items: usize) -> Self {
        // Threshold below which smaller blocks pay off.
        const SMALL_INPUT_LIMIT: usize = 4_000_000;
        if num_items >= SMALL_INPUT_LIMIT {
            // Large inputs: use the default (8192 items/block).
            Self::default()
        } else {
            // Small inputs: 256 threads x 15 items = 3840 items/block,
            // which improves GPU occupancy.
            Self {
                items_per_thread: 15,
                threads_per_block: 256,
            }
        }
    }

    /// Strategy tuned for key-value pair sorting of `num_items` elements.
    pub fn for_pairs(_num_items: usize) -> Self {
        Self::default()
    }

    /// Total number of items processed by a single block.
    pub fn items_per_block(&self) -> u32 {
        self.threads_per_block * self.items_per_thread
    }

    /// Number of blocks needed to cover `num_items` (last block may be partial).
    pub fn num_blocks(&self, num_items: usize) -> u32 {
        num_items.div_ceil(self.items_per_block() as usize) as u32
    }

    /// Number of planes (warps / subgroups) per block for the given plane width.
    pub fn num_planes(&self, plane_dim: u32) -> u32 {
        self.threads_per_block.div_ceil(plane_dim)
    }
}
Loading