diff --git a/crates/cubecl-cuda/Cargo.toml b/crates/cubecl-cuda/Cargo.toml index 65e10885e1..cca4beeaec 100644 --- a/crates/cubecl-cuda/Cargo.toml +++ b/crates/cubecl-cuda/Cargo.toml @@ -78,6 +78,9 @@ matmul_tests_tma = ["cubecl-matmul/matmul_tests_tma"] matmul_tests_unit = ["cubecl-matmul/matmul_tests_unit"] matmul_tests_vecmat = ["cubecl-matmul/matmul_tests_vecmat"] +cu-nccl = ["cubecl-runtime/plugin", "cudarc/nccl"] +plugin = ["cubecl-runtime/plugin"] + [dependencies] cubecl-common = { path = "../cubecl-common", version = "=0.9.0-pre.3", default-features = false, features = [ "cache", diff --git a/crates/cubecl-cuda/src/compute/mod.rs b/crates/cubecl-cuda/src/compute/mod.rs index 0b53fb7a58..9af69214f5 100644 --- a/crates/cubecl-cuda/src/compute/mod.rs +++ b/crates/cubecl-cuda/src/compute/mod.rs @@ -5,6 +5,9 @@ pub(crate) mod storage; pub(crate) mod stream; pub(crate) mod sync; +#[cfg(feature = "cu-nccl")] +pub mod nccl; + mod server; pub use server::*; diff --git a/crates/cubecl-cuda/src/compute/nccl.rs b/crates/cubecl-cuda/src/compute/nccl.rs new file mode 100644 index 0000000000..5b88adf3bc --- /dev/null +++ b/crates/cubecl-cuda/src/compute/nccl.rs @@ -0,0 +1,790 @@ +#![allow(dead_code)] +#![allow(missing_docs)] +use std::mem::MaybeUninit; + +use cubecl_core::server::Binding; +use cubecl_runtime::plugin::{Plugin, PluginError, PluginType}; +use cudarc::nccl::sys::{ + ncclComm_t, ncclDataType_t, ncclRedOp_t, ncclScalarResidence_t, ncclUniqueId, +}; + +#[derive(Debug)] +pub struct NcclComm(*mut cudarc::nccl::sys::ncclComm); + +unsafe impl Send for NcclComm {} +unsafe impl Sync for NcclComm {} + +impl NcclComm { + pub fn as_ptr(&self) -> ncclComm_t { + self.0 + } +} + +impl Drop for NcclComm { + fn drop(&mut self) { + unsafe { + if !self.0.is_null() { + cudarc::nccl::result::comm_destroy(self.0).ok(); + } + } + } +} + +pub struct NcclExtension; + +/// The types are defined which are used by client and server end. 
+impl Plugin for NcclExtension { + type ClientHandle = NcclClientHandle; + type ServerHandle = NcclServerHandle; + type InitType = NcclInit; + type ReturnVal = (); + type Fns = fn(Self::ServerHandle) -> Result<(), PluginError>; + const EXTENSION_NAME: &'static str = "cuda_nccl"; +} + +/// Info needed to initialize a NcclComm. +#[derive(Debug, Clone)] +pub struct NcclInit { + /// Rank passed to `comm_init_rank`; the tests use `0` and `1` for `dev_count = 2` + pub id: i32, + /// Device::device_count_total() + pub dev_count: i32, + /// cudarc::nccl::result::get_uniqueid().unwrap() + pub uid: ncclUniqueId, +} + +/// Here the `PluginType` is used to initialize `Nccl`. +/// `NcclComm` is used as the `Insert` type and gets injected +/// into `CudaServer` when called through `ComputeClient`'s new `plugin_init()` +impl PluginType for NcclInit { + type Insert = NcclComm; + + /// Function used to generate the Insert type + fn init(self) -> Self::Insert { + let mut comm = MaybeUninit::uninit(); + let comm = unsafe { + cudarc::nccl::result::comm_init_rank( + comm.as_mut_ptr(), + self.dev_count, + self.uid, + self.id, + ) + .unwrap(); + comm.assume_init() + }; + NcclComm(comm) + } +} + +#[derive(new, Debug, Clone)] +pub struct NcclClientHandle { + /// Send buffer; `None` for operations that do not read one, e.g. `recv`. + /// That is why this field is an `Option`. + pub input: Option, + /// Receive buffer; `None` for operations that do not write one, e.g. `send`.
+ pub output: Option, + /// Element count forwarded to the NCCL call (the tests pass the buffer length) + pub device_count: usize, + /// NCCL data type of the buffer elements + pub nccl_type: ncclDataType_t, +} + +unsafe impl Send for NcclServerHandle {} + +/// This struct will be constructed by `CudaServer`, +/// when `client.plugin_fn(c: NcclClientHandle)` is called +pub struct NcclServerHandle { + pub input: Option<*const ::core::ffi::c_void>, + pub output: Option<*mut std::ffi::c_void>, + pub dev_count: usize, + pub ty: ncclDataType_t, + pub comm: ncclComm_t, + pub stream: cudarc::driver::sys::CUstream, +} + +impl NcclServerHandle { + pub fn all_reduce(self, op: ReduceOp) -> Result<(), PluginError> { + unsafe { + cudarc::nccl::result::all_reduce( + self.input + .ok_or_else(|| PluginError::InvalidHandle("Input required".into()))?, + self.output + .ok_or_else(|| PluginError::InvalidHandle("Output required".into()))?, + self.dev_count, + self.ty, + op.convert(), + self.comm, + self.stream as *mut cudarc::nccl::sys::CUstream_st, + ) + .map(|_| ()) + .map_err(|e| PluginError::ExecutionFailed(format!("all_reduce: {:?}", e))) + } + } + + pub fn broadcast(self, root: i32) -> Result<(), PluginError> { + unsafe { + cudarc::nccl::result::broadcast( + self.input + .ok_or_else(|| PluginError::InvalidHandle("Input required".into()))?, + self.output + .ok_or_else(|| PluginError::InvalidHandle("Output required".into()))?, + self.dev_count, + self.ty, + root, + self.comm, + self.stream as *mut cudarc::nccl::sys::CUstream_st, + ) + .map(|_| ()) + .map_err(|e| PluginError::ExecutionFailed(format!("broadcast: {:?}", e))) + } + } + + pub fn reduce(self, op: ReduceOp, root: i32) -> Result<(), PluginError> { + unsafe { + cudarc::nccl::result::reduce( + self.input + .ok_or_else(|| PluginError::InvalidHandle("Input required".into()))?, + self.output + .ok_or_else(|| PluginError::InvalidHandle("Output required".into()))?, + self.dev_count, + self.ty, + op.convert(), + root, + self.comm, + self.stream as *mut
cudarc::nccl::sys::CUstream_st, + ) + .map(|_| ()) + .map_err(|e| PluginError::ExecutionFailed(format!("reduce: {:?}", e))) + } + } + + pub fn reduce_scatter(self, op: ReduceOp) -> Result<(), PluginError> { + unsafe { + cudarc::nccl::result::reduce_scatter( + self.input + .ok_or_else(|| PluginError::InvalidHandle("Input required".into()))?, + self.output + .ok_or_else(|| PluginError::InvalidHandle("Output required".into()))?, + self.dev_count, + self.ty, + op.convert(), + self.comm, + self.stream as *mut cudarc::nccl::sys::CUstream_st, + ) + .map(|_| ()) + .map_err(|e| PluginError::ExecutionFailed(format!("reduce_scatter: {:?}", e))) + } + } + + pub fn all_gather(self) -> Result<(), PluginError> { + unsafe { + cudarc::nccl::result::all_gather( + self.input + .ok_or_else(|| PluginError::InvalidHandle("Input required".into()))?, + self.output + .ok_or_else(|| PluginError::InvalidHandle("Output required".into()))?, + self.dev_count, + self.ty, + self.comm, + self.stream as *mut cudarc::nccl::sys::CUstream_st, + ) + .map(|_| ()) + .map_err(|e| PluginError::ExecutionFailed(format!("all_gather: {:?}", e))) + } + } + + pub fn send(self, peer: i32) -> Result<(), PluginError> { + unsafe { + cudarc::nccl::result::send( + self.input + .ok_or_else(|| PluginError::InvalidHandle("Input required".into()))?, + self.dev_count, + self.ty, + peer, + self.comm, + self.stream as *mut cudarc::nccl::sys::CUstream_st, + ) + .map(|_| ()) + .map_err(|e| PluginError::ExecutionFailed(format!("send: {:?}", e))) + } + } + + pub fn recv(self, peer: i32) -> Result<(), PluginError> { + unsafe { + cudarc::nccl::result::recv( + self.output + .ok_or_else(|| PluginError::InvalidHandle("Output required".into()))?, + self.dev_count, + self.ty, + peer, + self.comm, + self.stream as *mut cudarc::nccl::sys::CUstream_st, + ) + .map(|_| ()) + .map_err(|e| PluginError::ExecutionFailed(format!("recv: {:?}", e))) + } + } + + pub fn create_custom_pre_mul_sum( + scalar: *mut ::core::ffi::c_void, + datatype: 
ncclDataType_t, + residence: ncclScalarResidence_t, + comm: ncclComm_t, + ) -> Result { + unsafe { + let mut op = MaybeUninit::uninit(); + cudarc::nccl::result::reduce_op_create_pre_mul_sum( + op.as_mut_ptr(), + scalar, + datatype, + residence, + comm, + ) + .map(|_| op.assume_init()) + .map_err(|e| { + PluginError::ExecutionFailed(format!("create_custom_pre_mul_sum: {:?}", e)) + }) + } + } + + pub fn destroy_custom_op(op: ncclRedOp_t, comm: ncclComm_t) -> Result<(), PluginError> { + unsafe { + cudarc::nccl::result::reduce_op_destroy(op, comm) + .map(|_| ()) + .map_err(|e| PluginError::ExecutionFailed(format!("destroy_custom_op: {:?}", e))) + } + } +} + +#[derive(Debug, Clone, Copy)] +/// Reduce operation type. +pub enum ReduceOp { + Sum, + Prod, + Max, + Min, + Avg, + Custom(ncclRedOp_t), +} + +impl ReduceOp { + fn convert(&self) -> ncclRedOp_t { + match self { + ReduceOp::Sum => cudarc::nccl::sys::ncclRedOp_t::ncclSum, + ReduceOp::Prod => cudarc::nccl::sys::ncclRedOp_t::ncclProd, + ReduceOp::Max => cudarc::nccl::sys::ncclRedOp_t::ncclMax, + ReduceOp::Min => cudarc::nccl::sys::ncclRedOp_t::ncclMin, + ReduceOp::Avg => cudarc::nccl::sys::ncclRedOp_t::ncclAvg, + ReduceOp::Custom(op) => *op, + } + } +} + +#[derive(Debug, Clone, Copy)] +enum ScalarResidence { + Device, + Host, +} + +impl ScalarResidence { + fn to_nccl(&self) -> ncclScalarResidence_t { + match self { + ScalarResidence::Device => cudarc::nccl::sys::ncclScalarResidence_t::ncclScalarDevice, + ScalarResidence::Host => { + cudarc::nccl::sys::ncclScalarResidence_t::ncclScalarHostImmediate + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::compute::CudaServer; + use crate::{CudaDevice, CudaRuntime}; + use cubecl_common::device::Device; + use cubecl_core::prelude::*; + use cubecl_runtime::client::ComputeClient; + use std::sync::Mutex; + + static NCCL_TEST_LOCK: Mutex<()> = Mutex::new(()); + + fn has_multi_gpu() -> bool { + let dev_count = CudaDevice::device_count_total(); + println!("{}", 
dev_count); + dev_count >= 2 + } + + fn init_nccl_pair() -> Option<( + ComputeClient, + ComputeClient, + cudarc::nccl::sys::ncclUniqueId, + )> { + if !has_multi_gpu() { + println!("Skipping test: requires at least 2 GPUs"); + return None; + } + + let device0 = CudaDevice::new(0); + let device1 = CudaDevice::new(1); + + let client0 = CudaRuntime::client(&device0); + let client1 = CudaRuntime::client(&device1); + + let uid = cudarc::nccl::result::get_uniqueid().unwrap(); + + Some((client0, client1, uid)) + } + + fn setup_nccl_communicators( + client0: &ComputeClient, + client1: &ComputeClient, + uid: cudarc::nccl::sys::ncclUniqueId, + ) { + let init0 = NcclInit { + id: 0, + dev_count: 2, + uid, + }; + let init1 = NcclInit { + id: 1, + dev_count: 2, + uid, + }; + + std::thread::scope(|s| { + let h0 = s.spawn(|| { + client0 + .plugin_init::(init0) + .expect("Failed to initialize NCCL on device 0"); + }); + + let h1 = s.spawn(|| { + client1 + .plugin_init::(init1) + .expect("Failed to initialize NCCL on device 1"); + }); + + h0.join().unwrap(); + h1.join().unwrap(); + }); + } + + fn run_nccl_op( + client0: &ComputeClient, + client1: &ComputeClient, + uid: cudarc::nccl::sys::ncclUniqueId, + handle0: NcclClientHandle, + handle1: NcclClientHandle, + op: fn(NcclServerHandle) -> Result<(), PluginError>, + ) { + let init0 = NcclInit { + id: 0, + dev_count: 2, + uid, + }; + let init1 = NcclInit { + id: 1, + dev_count: 2, + uid, + }; + + std::thread::scope(|s| { + let h0 = s.spawn(move || { + client0 + .plugin_init::(init0) + .expect("Failed to initialize NCCL on device 0"); + client0 + .plugin_fn::(handle0, op) + .expect("NCCL operation on GPU 0 failed"); + }); + + let h1 = s.spawn(move || { + client1 + .plugin_init::(init1) + .expect("Failed to initialize NCCL on device 1"); + client1 + .plugin_fn::(handle1, op) + .expect("NCCL operation on GPU 1 failed"); + }); + + h0.join().unwrap(); + h1.join().unwrap(); + }); + } + + #[test] + fn test_nccl_all_reduce_sum_f32() { + let 
_lock = NCCL_TEST_LOCK.lock().unwrap(); + let Some((client0, client1, uid)) = init_nccl_pair() else { + return; + }; + + let data0: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let data1: Vec = vec![5.0, 6.0, 7.0, 8.0]; + + let expected: Vec = vec![6.0, 8.0, 10.0, 12.0]; + + let input0 = client0.create_from_slice(bytemuck::cast_slice(&data0)); + let output0 = client0.empty(data0.len() * std::mem::size_of::()); + + let input1 = client1.create_from_slice(bytemuck::cast_slice(&data1)); + let output1 = client1.empty(data1.len() * std::mem::size_of::()); + + let handle0 = NcclClientHandle::new( + Some(input0.clone().binding()), + Some(output0.clone().binding()), + data0.len(), + ncclDataType_t::ncclFloat32, + ); + + let handle1 = NcclClientHandle::new( + Some(input1.clone().binding()), + Some(output1.clone().binding()), + data1.len(), + ncclDataType_t::ncclFloat32, + ); + + let all_reduce_sum = |handle: NcclServerHandle| -> Result<(), PluginError> { + handle.all_reduce(ReduceOp::Sum) + }; + + run_nccl_op(&client0, &client1, uid, handle0, handle1, all_reduce_sum); + + cubecl_common::reader::read_sync(client0.sync()); + cubecl_common::reader::read_sync(client1.sync()); + + let result0_bytes = client0.read(vec![output0]); + let result1_bytes = client1.read(vec![output1]); + + let result0: Vec = bytemuck::cast_slice(&result0_bytes[0]).to_vec(); + let result1: Vec = bytemuck::cast_slice(&result1_bytes[0]).to_vec(); + + assert_eq!(result0, expected, "GPU 0 result mismatch"); + assert_eq!(result1, expected, "GPU 1 result mismatch"); + } + + #[test] + fn test_nccl_broadcast() { + let _lock = NCCL_TEST_LOCK.lock().unwrap(); + let Some((client0, client1, uid)) = init_nccl_pair() else { + return; + }; + + let data0: Vec = vec![10.0, 20.0, 30.0, 40.0]; + let data1: Vec = vec![0.0, 0.0, 0.0, 0.0]; + + let input0 = client0.create_from_slice(bytemuck::cast_slice(&data0)); + let output0 = client0.empty(data0.len() * std::mem::size_of::()); + + let input1 = 
client1.create_from_slice(bytemuck::cast_slice(&data1)); + let output1 = client1.empty(data1.len() * std::mem::size_of::()); + + let handle0 = NcclClientHandle::new( + Some(input0.clone().binding()), + Some(output0.clone().binding()), + data0.len(), + ncclDataType_t::ncclFloat32, + ); + + let handle1 = NcclClientHandle::new( + Some(input1.clone().binding()), + Some(output1.clone().binding()), + data1.len(), + ncclDataType_t::ncclFloat32, + ); + + let broadcast_root_0 = + |handle: NcclServerHandle| -> Result<(), PluginError> { handle.broadcast(0) }; + + run_nccl_op(&client0, &client1, uid, handle0, handle1, broadcast_root_0); + + cubecl_common::reader::read_sync(client0.sync()); + cubecl_common::reader::read_sync(client1.sync()); + + let result0_bytes = client0.read(vec![output0]); + let result1_bytes = client1.read(vec![output1]); + + let result0: Vec = bytemuck::cast_slice(&result0_bytes[0]).to_vec(); + let result1: Vec = bytemuck::cast_slice(&result1_bytes[0]).to_vec(); + + assert_eq!(result0, data0, "GPU 0 result mismatch"); + assert_eq!(result1, data0, "GPU 1 result mismatch"); + } + + #[test] + fn test_nccl_reduce() { + let _lock = NCCL_TEST_LOCK.lock().unwrap(); + let Some((client0, client1, uid)) = init_nccl_pair() else { + return; + }; + + let data0: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let data1: Vec = vec![5.0, 6.0, 7.0, 8.0]; + + let expected: Vec = vec![6.0, 8.0, 10.0, 12.0]; + + let input0 = client0.create_from_slice(bytemuck::cast_slice(&data0)); + let output0 = client0.empty(data0.len() * std::mem::size_of::()); + + let input1 = client1.create_from_slice(bytemuck::cast_slice(&data1)); + let output1 = client1.empty(data1.len() * std::mem::size_of::()); + + let handle0 = NcclClientHandle::new( + Some(input0.clone().binding()), + Some(output0.clone().binding()), + data0.len(), + ncclDataType_t::ncclFloat32, + ); + + let handle1 = NcclClientHandle::new( + Some(input1.clone().binding()), + Some(output1.clone().binding()), + data1.len(), + 
ncclDataType_t::ncclFloat32, + ); + + let reduce_sum_root_0 = |handle: NcclServerHandle| -> Result<(), PluginError> { + handle.reduce(ReduceOp::Sum, 0) + }; + + run_nccl_op(&client0, &client1, uid, handle0, handle1, reduce_sum_root_0); + + cubecl_common::reader::read_sync(client0.sync()); + cubecl_common::reader::read_sync(client1.sync()); + + let result0_bytes = client0.read(vec![output0]); + let result0: Vec = bytemuck::cast_slice(&result0_bytes[0]).to_vec(); + + assert_eq!(result0, expected, "GPU 0 (root) result mismatch"); + } + + #[test] + fn test_nccl_reduce_scatter() { + let _lock = NCCL_TEST_LOCK.lock().unwrap(); + let Some((client0, client1, uid)) = init_nccl_pair() else { + return; + }; + + let data0: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let data1: Vec = vec![5.0, 6.0, 7.0, 8.0]; + + let expected0: Vec = vec![6.0, 8.0]; + let expected1: Vec = vec![10.0, 12.0]; + + let input0 = client0.create_from_slice(bytemuck::cast_slice(&data0)); + let output0 = client0.empty(expected0.len() * std::mem::size_of::()); + + let input1 = client1.create_from_slice(bytemuck::cast_slice(&data1)); + let output1 = client1.empty(expected1.len() * std::mem::size_of::()); + + let handle0 = NcclClientHandle::new( + Some(input0.clone().binding()), + Some(output0.clone().binding()), + expected0.len(), + ncclDataType_t::ncclFloat32, + ); + + let handle1 = NcclClientHandle::new( + Some(input1.clone().binding()), + Some(output1.clone().binding()), + expected1.len(), + ncclDataType_t::ncclFloat32, + ); + + let reduce_scatter_sum = |handle: NcclServerHandle| -> Result<(), PluginError> { + handle.reduce_scatter(ReduceOp::Sum) + }; + + run_nccl_op( + &client0, + &client1, + uid, + handle0, + handle1, + reduce_scatter_sum, + ); + + cubecl_common::reader::read_sync(client0.sync()); + cubecl_common::reader::read_sync(client1.sync()); + + let result0_bytes = client0.read(vec![output0]); + let result1_bytes = client1.read(vec![output1]); + + let result0: Vec = 
bytemuck::cast_slice(&result0_bytes[0]).to_vec(); + let result1: Vec = bytemuck::cast_slice(&result1_bytes[0]).to_vec(); + + assert_eq!(result0, expected0, "GPU 0 result mismatch"); + assert_eq!(result1, expected1, "GPU 1 result mismatch"); + } + + #[test] + fn test_nccl_all_gather() { + let _lock = NCCL_TEST_LOCK.lock().unwrap(); + let Some((client0, client1, uid)) = init_nccl_pair() else { + return; + }; + + let data0: Vec = vec![1.0, 2.0]; + let data1: Vec = vec![3.0, 4.0]; + + let expected: Vec = vec![1.0, 2.0, 3.0, 4.0]; + + let input0 = client0.create_from_slice(bytemuck::cast_slice(&data0)); + let output0 = client0.empty(expected.len() * std::mem::size_of::()); + + let input1 = client1.create_from_slice(bytemuck::cast_slice(&data1)); + let output1 = client1.empty(expected.len() * std::mem::size_of::()); + + let handle0 = NcclClientHandle::new( + Some(input0.clone().binding()), + Some(output0.clone().binding()), + data0.len(), + ncclDataType_t::ncclFloat32, + ); + + let handle1 = NcclClientHandle::new( + Some(input1.clone().binding()), + Some(output1.clone().binding()), + data1.len(), + ncclDataType_t::ncclFloat32, + ); + + let all_gather = + |handle: NcclServerHandle| -> Result<(), PluginError> { handle.all_gather() }; + + run_nccl_op(&client0, &client1, uid, handle0, handle1, all_gather); + + cubecl_common::reader::read_sync(client0.sync()); + cubecl_common::reader::read_sync(client1.sync()); + + let result0_bytes = client0.read(vec![output0]); + let result1_bytes = client1.read(vec![output1]); + + let result0: Vec = bytemuck::cast_slice(&result0_bytes[0]).to_vec(); + let result1: Vec = bytemuck::cast_slice(&result1_bytes[0]).to_vec(); + + assert_eq!(result0, expected, "GPU 0 result mismatch"); + assert_eq!(result1, expected, "GPU 1 result mismatch"); + } + + #[test] + fn test_nccl_send_recv() { + let _lock = NCCL_TEST_LOCK.lock().unwrap(); + let Some((client0, client1, uid)) = init_nccl_pair() else { + return; + }; + + let data0: Vec = vec![10.0, 20.0, 
30.0, 40.0]; + + let input0 = client0.create_from_slice(bytemuck::cast_slice(&data0)); + let output1 = client1.empty(data0.len() * std::mem::size_of::()); + + let handle0 = NcclClientHandle::new( + Some(input0.clone().binding()), + None, + data0.len(), + ncclDataType_t::ncclFloat32, + ); + + let handle1 = NcclClientHandle::new( + None, + Some(output1.clone().binding()), + data0.len(), + ncclDataType_t::ncclFloat32, + ); + + let init0 = NcclInit { + id: 0, + dev_count: 2, + uid, + }; + let init1 = NcclInit { + id: 1, + dev_count: 2, + uid, + }; + + std::thread::scope(|s| { + let h0 = s.spawn(|| { + client0 + .plugin_init::(init0) + .expect("Failed to initialize NCCL on device 0"); + client0 + .plugin_fn::(handle0, |handle: NcclServerHandle| handle.send(1)) + .expect("Send on GPU 0 failed"); + }); + + let h1 = s.spawn(|| { + client1 + .plugin_init::(init1) + .expect("Failed to initialize NCCL on device 1"); + client1 + .plugin_fn::(handle1, |handle: NcclServerHandle| handle.recv(0)) + .expect("Recv on GPU 1 failed"); + }); + + h0.join().unwrap(); + h1.join().unwrap(); + }); + + cubecl_common::reader::read_sync(client0.sync()); + cubecl_common::reader::read_sync(client1.sync()); + + let result1_bytes = client1.read(vec![output1]); + let result1: Vec = bytemuck::cast_slice(&result1_bytes[0]).to_vec(); + + assert_eq!(result1, data0, "GPU 1 received data mismatch"); + } + + #[test] + fn test_nccl_all_reduce_with_different_ops() { + let _lock = NCCL_TEST_LOCK.lock().unwrap(); + let Some((client0, client1, uid)) = init_nccl_pair() else { + return; + }; + + let data0: Vec = vec![2.0, 4.0, 6.0, 8.0]; + let data1: Vec = vec![1.0, 3.0, 5.0, 7.0]; + + let expected_max: Vec = vec![2.0, 4.0, 6.0, 8.0]; + + let input0 = client0.create_from_slice(bytemuck::cast_slice(&data0)); + let output0 = client0.empty(data0.len() * std::mem::size_of::()); + + let input1 = client1.create_from_slice(bytemuck::cast_slice(&data1)); + let output1 = client1.empty(data1.len() * 
std::mem::size_of::()); + + let handle0 = NcclClientHandle::new( + Some(input0.clone().binding()), + Some(output0.clone().binding()), + data0.len(), + ncclDataType_t::ncclFloat32, + ); + + let handle1 = NcclClientHandle::new( + Some(input1.clone().binding()), + Some(output1.clone().binding()), + data1.len(), + ncclDataType_t::ncclFloat32, + ); + + let all_reduce_max = |handle: NcclServerHandle| -> Result<(), PluginError> { + handle.all_reduce(ReduceOp::Max) + }; + + run_nccl_op(&client0, &client1, uid, handle0, handle1, all_reduce_max); + + cubecl_common::reader::read_sync(client0.sync()); + cubecl_common::reader::read_sync(client1.sync()); + + let result0_bytes = client0.read(vec![output0]); + let result1_bytes = client1.read(vec![output1]); + + let result0: Vec = bytemuck::cast_slice(&result0_bytes[0]).to_vec(); + let result1: Vec = bytemuck::cast_slice(&result1_bytes[0]).to_vec(); + + assert_eq!(result0, expected_max, "GPU 0 result mismatch for Max op"); + assert_eq!(result1, expected_max, "GPU 1 result mismatch for Max op"); + } +} diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index db33ad16db..a0d5b0e34c 100644 --- a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -28,6 +28,8 @@ use cubecl_runtime::config::GlobalConfig; use cubecl_runtime::logging::ServerLogger; use cubecl_runtime::memory_management::{MemoryAllocationMode, offset_handles}; use cubecl_runtime::memory_management::{MemoryDeviceProperties, MemoryUsage}; +#[cfg(feature = "cu-nccl")] +use cubecl_runtime::plugin::{Plugin, PluginType, SupportsPlugin}; use cubecl_runtime::server::{self, ComputeServer}; use cubecl_runtime::storage::BindingResource; use cubecl_runtime::stream::MultiStream; @@ -39,6 +41,11 @@ use cudarc::driver::sys::{ use std::ffi::c_void; use std::mem::MaybeUninit; use std::sync::Arc; +#[cfg(any(feature = "plugin", feature = "cu-nccl"))] +use std::{ + any::{Any, TypeId}, + collections::HashMap, 
+}; pub(crate) const MB: usize = 1024 * 1024; @@ -49,6 +56,8 @@ pub struct CudaServer { peer_activated: bool, mem_alignment: usize, utilities: Arc>, + #[cfg(any(feature = "plugin", feature = "cu-nccl"))] + plugins: HashMap>, } unsafe impl Send for CudaServer {} @@ -509,6 +518,8 @@ impl CudaServer { max_streams, ), utilities: Arc::new(utilities), + #[cfg(any(feature = "plugin", feature = "cu-nccl"))] + plugins: HashMap::new(), } } @@ -919,3 +930,105 @@ fn enable_one_way_peer_access(ctx_src: CUcontext) -> Result<(), CUresult> { } } } + +#[cfg(feature = "cu-nccl")] +impl SupportsPlugin for CudaServer { + /// Initializes the `Plugin` on the `CudaServer`. + /// Takes the `InitType` defined by the `Plugin` and stores the initialized + /// `Insert` type in the server's plugin storage. + /// This function syncs the command stream before initialization to ensure + /// all previous operations are completed. + fn init_type( + &mut self, + plugin_type: ::InitType, + stream: StreamId, + ) -> Result<(), cubecl_runtime::plugin::PluginError> { + { + let mut command = self.command(stream, vec![].iter()); + cubecl_core::future::block_on(command.sync()); + } + let comm = plugin_type.init(); + let type_id = TypeId::of::(); + self.plugins.insert(type_id, Box::new(comm)); + Ok(()) + } + + /// Executes a `Plugin` function on the `CudaServer`. + /// Takes the `ClientHandle` provided by the client and converts it to a `ServerHandle` + /// by resolving the bindings to GPU resources and retrieving the stored plugin state. + /// The function then syncs the command stream and executes the provided operation + /// with the `ServerHandle`. 
+ fn plugin_fn( + &mut self, + client_handle: crate::compute::nccl::NcclClientHandle, + stream_id: StreamId, + op: ::Fns, + ) -> Result<(), cubecl_runtime::plugin::PluginError> { + use crate::compute::nccl::NcclServerHandle; + + let comm_ptr = { + let type_id = TypeId::of::(); + + let comm = self + .plugins + .get(&type_id) + .ok_or_else(|| { + cubecl_runtime::plugin::PluginError::NotInitialized("cuda_nccl".into()) + })? + .downcast_ref::() + .ok_or_else(|| { + cubecl_runtime::plugin::PluginError::InvalidHandle( + "Invalid NCCL communicator type".into(), + ) + })?; + comm.as_ptr() + }; + + let mut bindings: Vec = Vec::new(); + + if let Some(ref binding) = client_handle.input { + bindings.push(binding.clone()); + } + if let Some(ref binding) = client_handle.output { + bindings.push(binding.clone()); + } + let mut command = self.command(stream_id, bindings.iter()); + + let input = if let Some(binding) = client_handle.input { + let resource = command.resource(binding).map_err(|e| { + cubecl_runtime::plugin::PluginError::InvalidHandle(format!( + "Input resource not found: {:?}", + e + )) + })?; + Some(resource.ptr as *const ::core::ffi::c_void) + } else { + None + }; + + let output = if let Some(binding) = client_handle.output { + let resource = command.resource(binding).map_err(|e| { + cubecl_runtime::plugin::PluginError::InvalidHandle(format!( + "Output resource not found: {:?}", + e + )) + })?; + Some(resource.ptr as *mut ::core::ffi::c_void) + } else { + None + }; + + let stream_sys = command.streams.current().sys; + cubecl_common::future::block_on(command.sync()); + + let server_handle = NcclServerHandle { + input, + output, + dev_count: client_handle.device_count, + ty: client_handle.nccl_type, + comm: comm_ptr, + stream: stream_sys, + }; + op(server_handle) + } +} diff --git a/crates/cubecl-cuda/src/lib.rs b/crates/cubecl-cuda/src/lib.rs index f93c46b1c8..79c6cd8a11 100644 --- a/crates/cubecl-cuda/src/lib.rs +++ b/crates/cubecl-cuda/src/lib.rs @@ -9,6 +9,9 @@ 
mod runtime; pub use device::*; pub use runtime::*; +#[cfg(feature = "cu-nccl")] +pub use compute::nccl::*; + #[cfg(feature = "ptx-wmma")] pub(crate) type WmmaCompiler = cubecl_cpp::cuda::mma::PtxWmmaCompiler; diff --git a/crates/cubecl-runtime/Cargo.toml b/crates/cubecl-runtime/Cargo.toml index d8189e0d75..05acc616ca 100644 --- a/crates/cubecl-runtime/Cargo.toml +++ b/crates/cubecl-runtime/Cargo.toml @@ -27,6 +27,7 @@ exclusive-memory-only = [] profile-tracy = ["dep:tracy-client"] std = ["cubecl-common/std", "toml", "dirs"] storage-bytes = [] +plugin = [] [dependencies] async-channel = { workspace = true } # Assume std diff --git a/crates/cubecl-runtime/src/client.rs b/crates/cubecl-runtime/src/client.rs index fd8627888e..f075595786 100644 --- a/crates/cubecl-runtime/src/client.rs +++ b/crates/cubecl-runtime/src/client.rs @@ -10,6 +10,9 @@ use crate::{ }, storage::{BindingResource, ComputeStorage}, }; + +#[cfg(feature = "plugin")] +use crate::plugin::{Plugin, PluginError, SupportsPlugin}; use alloc::format; use alloc::sync::Arc; use alloc::vec; @@ -783,3 +786,31 @@ where alloc } } + +#[cfg(feature = "plugin")] +impl ComputeClient { + /// This function is used to initialize a type on the `Server` + pub fn plugin_init(&self, plugin_type: S::InitType) -> Result<(), PluginError> + where + Server: SupportsPlugin, + { + let mut state = self.context.lock(); + let stream_id = self.stream_id(); + state.init_type(plugin_type, stream_id)?; + Ok(()) + } + + /// Execute an extension function on the `Server`. 
+ pub fn plugin_fn( + &self, + client_handle: S::ClientHandle, + op: S::Fns, + ) -> Result + where + Server: SupportsPlugin, + { + let stream_id = self.stream_id(); + let mut state = self.context.lock(); + state.plugin_fn::(client_handle, stream_id, op) + } +} diff --git a/crates/cubecl-runtime/src/lib.rs b/crates/cubecl-runtime/src/lib.rs index 4c2d4dfe86..b3a0ef484e 100644 --- a/crates/cubecl-runtime/src/lib.rs +++ b/crates/cubecl-runtime/src/lib.rs @@ -50,3 +50,8 @@ pub mod tma; /// Simple system profiling using timestamps. pub mod timestamp_profiler; + +/// Plugin traits used with the ComputeClient extension function to add additional types and fns to +/// a Backend +#[cfg(feature = "plugin")] +pub mod plugin; diff --git a/crates/cubecl-runtime/src/plugin.rs b/crates/cubecl-runtime/src/plugin.rs new file mode 100644 index 0000000000..d0eb27ad5c --- /dev/null +++ b/crates/cubecl-runtime/src/plugin.rs @@ -0,0 +1,62 @@ +#![allow(missing_docs)] +//! This is a set of traits which need to be implemented when adding a `Plugin` to a backend. +//! The `ComputeClient` has two additional functions to execute `Plugin`s. +//! One for type initialisation and one for passing a function to the `ComputeServer`. +//! These traits need to be implemented carefully and with synchronization and cleanup in mind, +//! especially when considering to handle a `Plugin` over a continuous runtime. + +use cubecl_common::stream_id::StreamId; +use crate::server::ComputeServer; + +/// This trait defines all types and the function trait for a `Plugin`. +pub trait Plugin: Send + Sync + 'static { + /// This is a structure of some kind + /// representing all data the client layer needs to provide + /// for executing functions of the `Plugin` on a `ComputeServer`. + type ClientHandle: Send; + /// This is a type which needs to be built by a `ComputeServer`. + type ServerHandle: Send; + /// A `Plugin` wants to add additional functions to a `ComputeServer`.
+ /// Since we know how the `ServerHandle` will look the `Plugin` will + /// be able to define additional functions which can be executed over the + /// `ServerHandle` by a `ComputeServer` + type Fns: FnOnce(Self::ServerHandle) -> Result; + /// A `Plugin` might need an additional type which needs to be + /// initialised after the `ComputeServer` is loaded. + /// This type can be used to define initialisation parameters. + type InitType: PluginType; + /// For the case the `Plugin` needs a return type. + type ReturnVal; + /// Extension name for complex runtime operations. + const EXTENSION_NAME: &'static str; +} + +/// Type we want for initialisation. +pub trait PluginType { + type Insert: Send + Sync; + + fn init(self) -> Self::Insert; +} + +/// These are the functions used by the `ComputeClient` implemented over `ComputeServer`. +pub trait SupportsPlugin: ComputeServer { + /// The `ComputeServer` should be able to initialise the `Plugin`. + fn init_type(&mut self, plugin_type: SP::InitType, stream: StreamId) -> Result<(), PluginError>; + + /// And should be able to take the `Plugin`'s fns plus `ClientHandle` and + /// `StreamId` to build the `ServerHandle` and execute the function. + fn plugin_fn( + &mut self, + client_handle: SP::ClientHandle, + stream_id: StreamId, + op: SP::Fns, + ) -> Result; +} + +#[derive(Debug)] +pub enum PluginError { + NotSupported(&'static str), + NotInitialized(String), + ExecutionFailed(String), + InvalidHandle(String), +}