96 changes: 22 additions & 74 deletions docs/libcudacxx/extended_api/work_stealing.rst
@@ -9,26 +9,31 @@ Defined in header ``<cuda/work_stealing>`` if the CUDA compiler supports at least

namespace cuda {

template <int ThreadBlockRank = 3, invocable<dim3> UnaryFunction = ..unspecified..>
__device__ void for_each_canceled_block(UnaryFunction uf);

template <int ThreadBlockRank = 3, invocable<dim3> UnaryFunction = ..unspecified..>
__device__ void for_each_canceled_cluster(UnaryFunction uf);

} // namespace cuda

**Note**: On devices with compute capability 10.0 or higher, these functions may leverage hardware acceleration.

These APIs are primarily intended for implementing work-stealing at the thread-block or cluster level.

Compared to alternative work distribution techniques, such as `grid-stride loops <https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/>`__, which distribute work statically, or dynamic work distribution methods relying on global memory concurrency, these APIs offer several advantages:

- **Dynamic work-stealing**, i.e., thread blocks that complete their tasks sooner may take on additional work from slower thread blocks.
- **GPU Work Scheduler cooperation**, e.g., to respect work priorities and improve load balancing.
- **Lower latency**, e.g., when compared to global memory atomics.
For better performance, extract the shared prologue and epilogue from the work to be performed and reuse them across iterations (a minimal skeleton is sketched below):

- Prologue: Initialization code and data common to all thread blocks or clusters, such as ``__shared__`` memory allocation and initialization.
- Epilogue: Finalization code common to all thread blocks or clusters, such as writing shared memory back to global memory.

The ``for_each_canceled_cluster`` API may be used with thread-block clusters of any size, including one.
The ``for_each_canceled_block`` API is optimized for and requires thread-block clusters of size one.
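
For illustration, a minimal skeleton of this structure (not part of this change; the kernel, the buffer names, and the 256-thread block size are assumptions) could look as follows:

.. code:: cuda

   #include <cuda/work_stealing>

   // Hypothetical kernel: copies `in` to `out` and accumulates a grand total in `*total`.
   // Assumes a launch with 256 threads per block and thread-block clusters of size one.
   __global__ void copy_and_sum(float* out, const float* in, float* total, int n) {
     // Prologue: state shared by every block index this thread block processes.
     __shared__ float scratch[256];
     scratch[threadIdx.x] = 0.0f;
     __syncthreads();

     cuda::for_each_canceled_block<1>([&](dim3 block_idx) {
       // Use block_idx, not the built-in blockIdx, inside the lambda.
       int i = threadIdx.x + block_idx.x * blockDim.x;
       if (i < n) {
         out[i] = in[i];
         scratch[threadIdx.x] += in[i];
       }
     });

     // Epilogue: finalize once, after all stolen block indices have been processed.
     __syncthreads();
     if (threadIdx.x == 0) {
       float block_sum = 0.0f;
       for (int t = 0; t < blockDim.x; ++t) {
         block_sum += scratch[t];
       }
       atomicAdd(total, block_sum);
     }
   }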

**Mandates**:

@@ -37,79 +42,22 @@

**Preconditions**:

- All threads within a thread-block cluster shall call either ``for_each_canceled_block`` or ``for_each_canceled_cluster``, and do so **exactly once**.

**Effects**:

- Invokes ``uf`` with ``blockIdx`` and then repeatedly attempts to cancel the launch of another thread block or cluster within the current grid:

- If successful: invokes ``uf`` with the canceled thread block's ``blockIdx`` and repeats.
- Otherwise, the function returns; it failed to cancel the launch of another thread block or cluster.
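
Informally, ignoring how cancellation is actually performed, this behavior can be pictured with the following sketch. It is a model only, not the library's implementation, and ``try_cancel_another_block`` is a hypothetical stand-in for the unspecified cancellation mechanism:

.. code:: cuda

   // Illustrative model of the loop described above; not the real implementation,
   // which is unspecified and may be hardware accelerated.
   __device__ bool try_cancel_another_block(dim3* idx);  // hypothetical helper, exposition only

   template <int ThreadBlockRank, class UnaryFunction>
   __device__ void for_each_canceled_block_model(UnaryFunction uf) {
     dim3 block_idx = blockIdx;  // first invocation uses this block's own index
     bool stole_work = true;
     while (stole_work) {
       uf(block_idx);  // process one thread block's worth of work
       // Try to cancel another not-yet-started thread block of the current grid;
       // on success, its index is written to block_idx and the loop repeats.
       stole_work = try_cancel_another_block(&block_idx);
     }
   }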

**Remarks**: ``for_each_canceled_cluster`` guarantees that the thread block index ``uf`` is invoked with always has the same relative position within its cluster.
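
As an illustration of the cluster-level API (not taken from this change; the kernel and the compile-time cluster size of two are assumptions), every invocation of ``uf`` in a given thread block receives a block index at the same relative position within its cluster:

.. code:: cuda

   #include <cuda/work_stealing>

   // Hypothetical kernel launched with clusters of two thread blocks along x.
   __global__ void __cluster_dims__(2, 1, 1) scale(float* data, float factor, int n) {
     // Cluster-wide prologue (e.g. distributed shared memory setup) would go here.

     cuda::for_each_canceled_cluster<1>([&](dim3 block_idx) {
       // Per the remark above, block_idx.x % 2 == blockIdx.x % 2 on every
       // invocation: this block's position within its cluster never changes.
       int i = threadIdx.x + block_idx.x * blockDim.x;
       if (i < n) {
         data[i] *= factor;
       }
     });

     // Cluster-wide epilogue would go here.
   }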

Example
-------

This example demonstrates work-stealing at thread-block granularity using ``for_each_canceled_block``.

.. literalinclude:: ../libcudacxx/examples/work_stealing.cu
:language: c++
121 changes: 121 additions & 0 deletions libcudacxx/examples/work_stealing.cu
@@ -0,0 +1,121 @@
#include <cuda/atomic>
#include <cuda/cmath>
#include <cuda/work_stealing>

#include <iostream>

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

// Before: process one scalar addition per thread block:
// __global__ void vec_add_reduce(int* s, int* a, int* b, int* c, int n) {
// int idx = threadIdx.x + blockIdx.x * blockDim.x;
// int thread_sum = 0;
// if (idx < n) {
// int sum = a[idx] + b[idx];
// c[idx] += sum;
// thread_sum += sum;
// }
// auto block = cg::this_thread_block();
// auto tile = cg::tiled_partition<32>(block);
// cg::reduce_update_async(
// cg::tiled_partition<32>(cg::this_thread_block()), cuda::atomic_ref{*s},
// thread_sum, cg::plus<int>{}
// );
// }

// After: process many scalar additions per thread block:
__global__ void
vec_add_reduce(int* s, int* a, int* b, int* c, int n)
{
// Extract common prologue outside the lambda, e.g.,
// - __shared__ or global (malloc) memory allocation
// - common initialization code
// - etc.
// Here we extract the sum to continue accumulating locally across block indices:
int thread_sum = 0;

cuda::for_each_canceled_block<1>([&](dim3 block_idx) {
// block_idx may be different than the built-in blockIdx variable, that is:
// assert(block_idx == blockIdx); // may fail!
// so we need to use "block_idx" consistently inside for_each_canceled:
int idx = threadIdx.x + block_idx.x * blockDim.x;
if (idx < n)
{
int sum = a[idx] + b[idx];
c[idx] += sum;
thread_sum += sum;
}
});
// Note: Calling for_each_canceled_block again or calling for_each_canceled_cluster from this
// thread block exhibits undefined behavior.

// Extract common epilogue outside the lambda, e.g.,
// - write back shared memory to global memory
// - external synchronization
// - global memory deallocation (free)
// - etc.
// Here we extract the per-thread-block tile reduction into the global memory location:
auto block = cg::this_thread_block();
auto tile  = cg::tiled_partition<32>(block);
cg::reduce_update_async(tile, cuda::atomic_ref{*s}, thread_sum, cg::plus<int>{});
}

int main()
{
int N = 10000;
int *sum, *a, *b, *c;
cudaMallocManaged(&sum, sizeof(int));
cudaMallocManaged(&a, N * sizeof(int));
cudaMallocManaged(&b, N * sizeof(int));
cudaMallocManaged(&c, N * sizeof(int));
*sum = 0;
for (int i = 0; i < N; ++i)
{
a[i] = i;
b[i] = 1;
c[i] = 0;
}

const int threads_per_block = 256;
const int blocks_per_grid = cuda::ceil_div(N, threads_per_block);

vec_add_reduce<<<blocks_per_grid, threads_per_block>>>(sum, a, b, c, N);

bool success = true;
if (cudaGetLastError() != cudaSuccess)
{
std::cerr << "LAUNCH ERROR" << std::endl;
success = false;
}
if (cudaDeviceSynchronize() != cudaSuccess)
{
std::cerr << "SYNC ERRROR" << std::endl;
success = false;
}

int should = 0;
for (int i = 0; i < N; ++i)
{
should += c[i];
if (c[i] != (1 + i))
{
std::cerr << "ERROR " << i << ": " << c[i] << " != " << (1 + i) << std::endl;
success = false;
}
}

if (*sum != should)
{
std::cerr << "SUM ERROR " << *sum << " != " << should << std::endl;
success = false;
}

cudaFree(sum);
cudaFree(a);
cudaFree(b);
cudaFree(c);

return success ? 0 : 1;
}