Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/cubecl-cuda/src/compute/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,7 @@ impl ComputeServer for CudaServer {
Box::pin(self.sync_stream_async())
}


fn start_profile(&mut self) -> ProfilingToken {
// Wait for current work to be done.
self.ctx.sync();
Expand Down
34 changes: 30 additions & 4 deletions crates/cubecl-cuda/src/compute/sync/fence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@ impl Fence {

/// Wait for the [Fence] to be reached, ensuring that all previous tasks enqueued to the
/// [stream](CUstream_st) are completed.
pub fn wait_sync(self) {
pub fn wait_sync(&self) {
unsafe {
cudarc::driver::result::event::synchronize(self.event).unwrap();
cudarc::driver::result::event::destroy(self.event).unwrap();
}
}

Expand All @@ -50,15 +49,42 @@ impl Fence {
/// # Notes
///
/// The [stream](CUevent_st) must be initialized.
pub fn wait_async(self, stream: *mut CUstream_st) {
pub fn wait_async(&self, stream: *mut CUstream_st) {
unsafe {
cudarc::driver::result::stream::wait_event(
stream,
self.event,
CUevent_wait_flags::CU_EVENT_WAIT_DEFAULT,
)
.unwrap();
cudarc::driver::result::event::destroy(self.event).unwrap();
}
}
}

impl Drop for Fence {
    /// Releases the underlying CUDA event when the fence goes out of scope.
    ///
    /// Since the wait methods take `&self` and no longer destroy the event,
    /// `Drop` is the single place where the handle is released; a null
    /// handle marks "already released".
    fn drop(&mut self) {
        if self.event.is_null() {
            return;
        }
        // SAFETY: `self.event` is a non-null event handle owned by this fence,
        // and it is nulled out right after, so destroy can never run twice.
        // The result is deliberately ignored: `drop` must not panic.
        let _ = unsafe { cudarc::driver::result::event::destroy(self.event) };
        self.event = core::ptr::null_mut();
    }
}
Comment on lines +64 to +73
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fence is destroyed right after we wait on it. The methods wait_async and wait_sync take ownership, so you can't wait more than once on an event. I don't mind removing the destroy call in both functions, but we should also change the method signatures to take &self instead, so that we can wait multiple times on the same event.


#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time signature check: both wait methods must borrow the fence
    /// (`&self`) so a single event can be awaited any number of times.
    /// No CUDA calls are executed; only signatures are validated.
    #[test]
    fn fence_wait_methods_are_ref_and_multiwait() {
        type SyncWaiter = fn(&Fence);
        type AsyncWaiter = fn(&Fence, *mut CUstream_st);

        // Coercing the same method to a fn pointer twice proves that no
        // consuming (by-value) receiver is involved.
        let sync_waiters: [SyncWaiter; 2] = [Fence::wait_sync, Fence::wait_sync];
        let async_waiters: [AsyncWaiter; 2] = [Fence::wait_async, Fence::wait_async];
        let _ = (sync_waiters, async_waiters);
    }
}
29 changes: 26 additions & 3 deletions crates/cubecl-hip/src/compute/fence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,42 @@ impl Fence {
/// # Notes
///
/// The [stream](hipStream_t) must be initialized.
pub fn wait(self) {
pub fn wait(&self) {
unsafe {
let status = cubecl_hip_sys::hipStreamWaitEvent(self.stream, self.event, 0);
assert_eq!(
status, HIP_SUCCESS,
"Should successfully wait for stream event"
);
let status = cubecl_hip_sys::hipEventDestroy(self.event);
assert_eq!(status, HIP_SUCCESS, "Should destrdestroy the stream eventt");
}
}
}

impl Drop for Fence {
    /// Releases the underlying HIP event once the fence is no longer used.
    ///
    /// `wait` takes `&self` and does not destroy the event, so `Drop` is the
    /// sole owner of the destruction; a null handle marks "already released".
    fn drop(&mut self) {
        if self.event.is_null() {
            return;
        }
        // SAFETY: `self.event` is a non-null event handle owned by this fence;
        // nulling it afterwards guarantees destruction runs at most once.
        // Best-effort: errors are ignored because `drop` must not panic.
        let _ = unsafe { cubecl_hip_sys::hipEventDestroy(self.event) };
        self.event = std::ptr::null_mut();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time signature check: `Fence::wait` must take `&self` so the
    /// same event can be waited on repeatedly. No HIP calls are executed.
    #[test]
    fn fence_wait_method_is_ref_and_multiwait() {
        type Waiter = fn(&Fence);
        // Coercing the method to a fn pointer twice proves it is non-consuming.
        let waiters: [Waiter; 2] = [Fence::wait, Fence::wait];
        let _ = waiters;
    }
}

/// A stream synchronization point that blocks until all previously enqueued work in the stream
/// has completed.
///
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-hip/src/compute/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ impl ComputeServer for HipServer {
Box::pin(self.sync_stream_async())
}


fn start_profile(&mut self) -> ProfilingToken {
cubecl_common::future::block_on(self.sync());
self.ctx.timestamps.start()
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-runtime/src/channel/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send
/// Wait for the completion of every task in the server.
fn sync(&self) -> DynFut<()>;


/// Given a resource handle, return the storage resource.
fn get_resource(
&self,
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-runtime/src/channel/cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ where
server.sync()
}



fn get_resource(
&self,
binding: Binding,
Expand Down
7 changes: 7 additions & 0 deletions crates/cubecl-runtime/src/channel/mpsc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ where
),
Flush,
Sync(Callback<()>),
// NOTE: a dedicated `WorkDone` completion-fence message was considered and
// removed; completion of previously submitted work is signalled through the
// existing `Sync` message instead, and the server must not block the message
// loop while waiting for it. (Regular comments here on purpose: a `///` doc
// comment at this position would wrongly attach to the next variant.)
MemoryUsage(Callback<MemoryUsage>),
MemoryCleanup,
AllocationMode(MemoryAllocationMode),
Expand Down Expand Up @@ -168,6 +172,7 @@ where
server.sync().await;
callback.send(()).await.unwrap();
}

Message::Flush => {
server.flush();
}
Expand Down Expand Up @@ -325,6 +330,8 @@ where
})
}



fn memory_usage(&self) -> crate::memory_management::MemoryUsage {
let (callback, response) = async_channel::unbounded();
self.state
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-runtime/src/channel/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ where
server.sync()
}


fn get_resource(
&self,
binding: Binding,
Expand Down
50 changes: 50 additions & 0 deletions crates/cubecl-runtime/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::{
},
storage::{BindingResource, ComputeStorage},
};
use alloc::boxed::Box;
use alloc::format;
use alloc::sync::Arc;
use alloc::vec;
Expand Down Expand Up @@ -199,6 +200,55 @@ where
self.channel.get_resource(binding)
}

/// Returns a future that resolves when all work submitted up to this call completes.
///
/// Note: This forwards to the underlying channel `sync()`, so it waits on every
/// task previously enqueued through this client's channel, not only a subset.
pub fn fence(&self) -> cubecl_common::future::DynFut<()> {
    self.channel.sync()
}

/// Launch a kernel and obtain a future that resolves once the submitted
/// work has completed on the device.
///
/// Completion is observed by chaining the channel's `sync()` after the
/// kernel has been enqueued.
///
/// # Safety
///
/// Forwards to the underlying channel `execute`, which is unsafe. The caller
/// must uphold the same invariants as [`ComputeChannel::execute`], including
/// but not limited to ensuring the kernel performs no out-of-bounds accesses
/// with the provided bindings and execution mode.
pub unsafe fn execute_async(
    &self,
    kernel: Server::Kernel,
    count: CubeCount,
    bindings: Bindings,
    mode: ExecutionMode,
) -> cubecl_common::future::DynFut<()> {
    let logger = self.state.logger.clone();
    // SAFETY: all invariants are delegated to the caller per this
    // function's safety contract.
    unsafe {
        self.channel.execute(kernel, count, bindings, mode, logger);
    }
    self.channel.sync()
}

/// Asynchronously write bytes into device buffers.
///
/// The returned future resolves with `Ok(())` once the device has completed
/// all enqueued work (making the writes visible), or immediately with `Err`
/// when submitting the writes fails.
pub fn write_async(
    &self,
    descriptors: Vec<(CopyDescriptor<'_>, &[u8])>,
) -> cubecl_common::future::DynFut<Result<(), IoError>> {
    // Submit the writes first; bail out early if the channel rejects them.
    if let Err(err) = self.channel.write(descriptors) {
        return Box::pin(async move { Err(err) });
    }
    // Writes accepted: resolve once the device catches up with enqueued work.
    let completion = self.channel.sync();
    Box::pin(async move {
        completion.await;
        Ok(())
    })
}

fn do_create(
&self,
descriptors: Vec<AllocationDescriptor<'_>>,
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-runtime/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ where
/// Wait for the completion of every task in the server.
fn sync(&mut self) -> DynFut<()>;

// Note: Completion fences are exposed via `sync()` implementations.

/// Given a resource handle, returns the storage resource.
fn get_resource(
&mut self,
Expand Down
21 changes: 21 additions & 0 deletions crates/cubecl-runtime/tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod dummy;

use crate::dummy::{DummyDevice, DummyElementwiseAddition, test_client};

use cubecl_common::future::block_on;
use cubecl_runtime::server::Bindings;
use cubecl_runtime::server::CubeCount;
use cubecl_runtime::{local_tuner, tune::LocalTuner};
Expand Down Expand Up @@ -46,6 +47,26 @@ fn execute_elementwise_addition() {
assert_eq!(obtained_resource, Vec::from([4, 5, 6]))
}

#[test]
fn fence_completes_submitted_work() {
    let client = test_client(&DummyDevice);

    // Two input buffers and one output buffer for an element-wise addition.
    let input_a = client.create(&[1, 2, 3]);
    let input_b = client.create(&[4, 5, 6]);
    let output = client.empty(3);

    client.execute(
        KernelTask::new(DummyElementwiseAddition),
        CubeCount::Static(1, 1, 1),
        Bindings::new().with_buffers(vec![
            input_a.binding(),
            input_b.binding(),
            output.clone().binding(),
        ]),
    );

    // The fence future must only resolve after all work submitted above is done.
    block_on(client.fence());

    assert_eq!(client.read_one(output), Vec::from([5, 7, 9]));
}

#[test]
#[cfg(feature = "std")]
fn autotune_basic_addition_execution() {
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/src/compiler/wgsl/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl Display for Extension {
SAFE_TANH_PRIMITIVE,
&[VectorIdent {
name: "x",
item: item.clone(),
item: *item,
}],
*item,
),
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/src/compute/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,11 @@ impl ComputeServer for WgpuServer {
self.stream.flush();
}

/// Wait for the completion of every task in the server.
fn sync(&mut self) -> DynFut<()> {
self.stream.sync()
}


fn start_profile(&mut self) -> ProfilingToken {
self.stream.start_profile()
}
Expand Down
Loading
Loading