
Design proposal: A flexible parameter and guarantees framework for (segmented) algorithms #7495

@elstehle

This issue gathers thoughts, use cases, considerations, and design options to serve as a foundation for discussing a parameter framework for (segmented) algorithms.

The parameter framework has two main purposes:

  1. allowing users to mix segment-specific, uniform, and compile-time static parameters, and
  2. providing a mechanism for users to provide parameter constraints as guarantees (at compile-time and at runtime) so the dispatch layer can safely specialize kernels and execute only relevant code paths.

Motivation & Use Cases

Example: DeviceBatchedTopK
DeviceBatchedTopK::Select(
    d_temp_storage,
    temp_storage_bytes,
    d_key_segments_it,
    d_key_segments_out_it,
    d_value_segments_it,
    d_value_segments_out_it,
    segment_sizes,          // static, uniform, or segment-specific
    k,                      // static, uniform, or segment-specific
    select_directions,      // static, uniform, or segment-specific
    num_segments,           // host-accessible or device-accessible-only
    total_num_items)        // needed as guarantee on the upper bound of the total number of items for buffer allocation

Segment-specific vs. uniform parameters vs. compile-time static parameters

Segmented top-k depends on segment_sizes, k, and select_directions. Each can be:

  • Static (compile-time constant).
  • Uniform (same for all segments).
  • Segment-specific (per-segment values).

Covering every combination through its own overload would lead to a combinatorial explosion of function signatures. We need a representation that supports all combinations without multiplying the number of overloads.
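A minimal host-only sketch of the idea, assuming every parameter kind exposes a common get_param(segment_id) accessor (the approach the Option 1 wrappers take). The types static_param, uniform_param, per_segment_param and the function total_selected are hypothetical simplifications that use std:: instead of cuda::std. With three such parameters, as in top-k, one template stands in for 3^3 = 27 overloads.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-only stand-ins for the three parameter kinds.
// All of them expose the same get_param(segment_id) accessor.
template <std::int64_t Value>
struct static_param
{
  std::int64_t get_param(std::int64_t /*segment_id*/) const { return Value; }
};

struct uniform_param
{
  std::int64_t value;
  std::int64_t get_param(std::int64_t /*segment_id*/) const { return value; }
};

template <typename ItT>
struct per_segment_param
{
  ItT it;
  std::int64_t get_param(std::int64_t segment_id) const { return it[segment_id]; }
};

// A single function template accepts every combination of parameter kinds
// (3 kinds x 2 parameters = 9 combinations here) without extra overloads.
// It returns how many items top-k would select in total across all segments.
template <typename SegmentSizeParamT, typename KParamT>
std::int64_t total_selected(SegmentSizeParamT segment_sizes, KParamT k, std::int64_t num_segments)
{
  assert(num_segments >= 0);
  std::int64_t total = 0;
  for (std::int64_t s = 0; s < num_segments; ++s)
  {
    const std::int64_t size = segment_sizes.get_param(s);
    const std::int64_t k_s  = k.get_param(s);
    total += k_s < size ? k_s : size; // a segment cannot yield more than `size` items
  }
  return total;
}
```

For example, per-segment sizes {4, 10, 2} combined with a static k of 3 select 3 + 3 + 2 = 8 items, while the same call site also accepts a uniform size with a per-segment k.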

Constraining parameter values at compile-time and runtime

Many applications, such as LLMs (e.g., Mixture-of-Experts (MoE) and Sparse Attention models), have narrow constraints on parameter values, such as k (e.g., [1, 8]) or segment sizes (e.g., [512, 4096]).
We need a mechanism for users to communicate these constraints to the algorithm. This information is essential for achieving peak performance and for minimizing both the number of compiled code paths and the number of kernels.

Users may want to provide constraints on parameter values at compile-time and runtime. Moreover, they may want to provide runtime constraints as device-accessible-only values:

  • At compile-time: users may want to provide static lower and upper bounds for values of a parameter.
  • At runtime (host-accessible): users may want to further narrow the range of values of a parameter by providing runtime, host-accessible lower and upper bounds.
  • At runtime (device-accessible-only): users may want to further narrow the range of values of a parameter by providing runtime lower and upper bounds that are accessible only on the device.

A few examples of optimizations in the context of top-k that can be achieved by having this information available at compile-time and runtime:

  • Static lower and upper bounds:
    • segment_sizes: Avoid compiling and running load-balancing mechanisms if all segments are small enough to fit into shared memory
    • segment_sizes: Avoid padding logic if segment size equals items_per_block.
    • k: Avoid compiling and running algorithm code paths that are not needed if k is small. E.g., we can use a reduce-based algorithm if k lies in, say, [1, 8].
    • num_items (DeviceRadixSort): May allow us to compile only one of the single-tile and regular sorting kernels instead of both.
  • Runtime host-accessible lower and upper bounds:
    • segment_sizes: Avoid allocating an auxiliary buffer for double-buffering if all segments are small enough to fit into shared memory
  • Runtime device-accessible-only lower and upper bounds:
    • segment_sizes: Avoid running load-balancing logic to partition segments based on their size if each segment can be efficiently processed by a single thread block or single warp.
    • segment_sizes: Avoid padding out-of-bounds items if segment size equals items_per_block.
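The static-bound optimizations listed above can be decided entirely at compile time. A minimal sketch, assuming the bounds are exposed as static_min_value/static_max_value members as in static_bounds_mixin; the threshold of 8 for the reduce-based path and the 4096-item tile capacity are illustrative assumptions, and bounds, select_topk_path and needs_load_balancing are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical compile-time bounds carrier, modeled on static_bounds_mixin.
template <std::int64_t Min, std::int64_t Max>
struct bounds
{
  static_assert(Min <= Max, "Min must be <= Max");
  static constexpr std::int64_t static_min_value = Min;
  static constexpr std::int64_t static_max_value = Max;
};

enum class topk_path
{
  reduce_based,
  general
};

// Illustrative thresholds (assumptions, not tuned values).
constexpr std::int64_t reduce_based_max_k = 8;
constexpr std::int64_t block_tile_items   = 4096;

// Only the branch selected for a given KBoundsT is instantiated; in a real
// dispatch layer the discarded branch would reference the kernel that is not needed.
template <typename KBoundsT>
constexpr topk_path select_topk_path()
{
  if constexpr (KBoundsT::static_max_value <= reduce_based_max_k)
  {
    return topk_path::reduce_based;
  }
  else
  {
    return topk_path::general;
  }
}

// If no segment can exceed one block's tile, load balancing is unnecessary.
template <typename SegmentSizeBoundsT>
constexpr bool needs_load_balancing()
{
  return SegmentSizeBoundsT::static_max_value > block_tile_items;
}
```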

Allow users to enumerate a discrete set of supported options

For parameters that support only a limited set of discrete options, such as the sort order (ascending or descending) of a sorting algorithm, users may want to specify which of these options the algorithm needs to support. This applies to single-problem algorithms as well as to segmented and batched algorithms.
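A minimal sketch of how an enumerated option set could limit kernel instantiation, assuming the options are passed as a non-type template parameter pack. discrete_options, run_sort_kernel and dispatch are hypothetical; run_sort_kernel merely stands in for a kernel specialized on the sort order.

```cpp
#include <cassert>
#include <cstddef>

enum class sort_order
{
  ascending,
  descending
};

// Hypothetical: the set of options the user asks the algorithm to support.
template <typename E, E... Options>
struct discrete_options
{
  static_assert(sizeof...(Options) > 0, "at least one option is required");
  static constexpr std::size_t num_options = sizeof...(Options);
  static constexpr bool supports(E e) { return ((e == Options) || ...); }
};

// Stand-in for a kernel specialized on a compile-time sort order.
template <sort_order Order>
int run_sort_kernel()
{
  return Order == sort_order::ascending ? 1 : 2;
}

// Instantiates run_sort_kernel only for the listed options and runs the one
// matching the runtime value. Returns -1 if the value was not listed.
template <sort_order... Options>
int dispatch(discrete_options<sort_order, Options...>, sort_order runtime_value)
{
  int result = -1;
  (void)((runtime_value == Options ? (result = run_sort_kernel<Options>(), true) : false) || ...);
  return result;
}
```

With discrete_options<sort_order, sort_order::ascending>, only the ascending specialization is ever compiled; a descending request is reported as unsupported instead of silently pulling in a second kernel.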

Accept parameters as device-accessible-only values

For single-problem algorithms, users may want to provide algorithm parameters as device-accessible-only values to support back-to-back kernel launches and avoid superfluous host-device data transfers (e.g., num_items or initial_value for scan).
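To illustrate what accepting a device-accessible-only value implies for dispatch, here is a host-only sketch in which an ordinary pointer stands in for device memory (an assumption; no CUDA calls are made). deferred_num_items, grid_size and items_in_block are hypothetical: the host sizes the launch from a static upper bound without reading the value, and only the "device" side dereferences it.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical wrapper for a count that lives in device-accessible memory.
// Host code must not dereference `ptr`; only "device" code reads it.
template <std::int64_t StaticMax>
struct deferred_num_items
{
  static constexpr std::int64_t static_max_value = StaticMax;
  const std::int64_t* ptr;
};

constexpr std::int64_t ceil_div(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

// Host side: size the grid from the static upper bound, since reading *ptr
// would require a device-to-host copy and a synchronization.
template <std::int64_t StaticMax>
constexpr std::int64_t grid_size(deferred_num_items<StaticMax>, std::int64_t tile_items)
{
  return ceil_div(StaticMax, tile_items);
}

// "Device" side: each block reads the actual count and returns how many items
// fall into its tile (0 for surplus blocks launched because of the bound).
template <std::int64_t StaticMax>
std::int64_t items_in_block(deferred_num_items<StaticMax> n, std::int64_t block_id, std::int64_t tile_items)
{
  const std::int64_t num_items = *n.ptr;
  assert(num_items <= StaticMax);
  const std::int64_t begin = block_id * tile_items;
  if (begin >= num_items)
  {
    return 0;
  }
  return num_items - begin < tile_items ? num_items - begin : tile_items;
}
```

The trade-off is visible in the sketch: a loose static bound launches surplus blocks that exit immediately, which is why a tight bound matters even when the value itself stays on the device.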

Summary

Segmented & batched algorithms:

  • Support any combination of segment-specific parameters and parameters that are uniform across all segments.

Applicable to single-problem and segmented/batched algorithms:

  • Allow users to provide guarantees on parameter values: Users may want to provide guarantees on parameter values, such as the range of values a parameter can take, both at compile-time and at runtime.
  • Allow users to enumerate a discrete set of supported options: Users may want to specify a discrete set of options for a parameter, such as the sort orders (ascending, descending) a sorting algorithm needs to support.
  • Accept parameters as device-accessible-only values: Users may want to provide algorithm parameters as device-accessible-only values to support back-to-back kernel launches and avoid superfluous host-device data transfers (e.g., num_items or initial_value for scan).

Open Questions

  • Do we want to allow users to provide constraints as device-accessible-only values? E.g., maybe it's computed by a prior reduction.
  • Do we want to allow users to provide their custom parameter types by specializing one of our parameter types?
  • Would we ever want to support segment-specific algorithm requirements (e.g., determinism)?

Considerations

  • We may want to propagate the constraints/guarantees down to block- and warp-level primitives.

Our Options

We have two fundamentally different options:

  • Should constraints be provided through wrappers that annotate the parameters themselves, or
  • should constraints be part of a guarantees environment?

Option 1: Annotating parameters with wrappers on the parameter itself:

Example: DeviceBatchedTopK
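using namespace cub::batched_topk;
// thrust::minmax_element returns a pair of iterators; dereference them to obtain the values
const auto [min_it, max_it] = thrust::minmax_element(...);
const auto min_segment_size = *min_it;
const auto max_segment_size = *max_it;
DeviceBatchedTopK::Select(
    ...,
    d_keys_in,
    d_keys_out,
    segment_size_per_segment<decltype(d_segment_sizes), static_min_seg_size, static_max_seg_size>{d_segment_sizes, min_segment_size, max_segment_size},
    k_uniform<1, static_max_k>{k},
    select_direction_uniform{direction},
    num_segments_uniform<1, 256>{num_segments},
    total_num_items_guarantee{num_segments * min_segment_size, num_segments * max_segment_size});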
Example: What the algorithm-specific parameter wrappers look like
// ------------ SELECTION DIRECTION PARAMETER TYPES ------------

// Selection direction known at compile time, same value applies to all segments
template <detail::topk::select SelectDirection>
using select_direction_static = params::uniform_discrete_param<detail::topk::select, SelectDirection>;

// Selection direction is a runtime value, same value applies to all segments
using select_direction_uniform =
  params::uniform_discrete_param<detail::topk::select, detail::topk::select::max, detail::topk::select::min>;

// Per-segment selection direction via iterator
template <typename SelectionDirectionIt, detail::topk::select... SelectDirectionOptions>
using select_direction_per_segment =
  params::per_segment_discrete_param<SelectionDirectionIt, detail::topk::select, SelectDirectionOptions...>;

// ------------ SEGMENT SIZE PARAMETER TYPES ------------

// Segment size known at compile time, same value applies to all segments
template <::cuda::std::int64_t SegmentSize>
using segment_size_static = params::static_constant_param<::cuda::std::int64_t, SegmentSize>;

// Segment size is a runtime value, same value applies to all segments
template <::cuda::std::int64_t MinSegmentSize = 0,
          ::cuda::std::int64_t MaxSegmentSize = ::cuda::std::numeric_limits<::cuda::std::int64_t>::max()>
using segment_size_uniform = params::uniform_param<::cuda::std::int64_t, MinSegmentSize, MaxSegmentSize>;

// Segment size via iterator
template <typename SegmentSizesItT,
          ::cuda::std::int64_t MinSegmentSize = 1,
          ::cuda::std::int64_t MaxSegmentSize = ::cuda::std::numeric_limits<::cuda::std::int64_t>::max()>
using segment_size_per_segment =
  params::per_segment_param<SegmentSizesItT, ::cuda::std::int64_t, MinSegmentSize, MaxSegmentSize>;

// ------------ K PARAMETER TYPES ------------

// K known at compile time, same value applies to all segments
template <::cuda::std::int64_t K>
using k_static = params::static_constant_param<::cuda::std::int64_t, K>;

// K is a runtime value, same value applies to all segments
template <::cuda::std::int64_t MinK = 1,
          ::cuda::std::int64_t MaxK = ::cuda::std::numeric_limits<::cuda::std::int64_t>::max()>
using k_uniform = params::uniform_param<::cuda::std::int64_t, MinK, MaxK>;

// K via iterator
template <typename KItT,
          ::cuda::std::int64_t MinK = 1,
          ::cuda::std::int64_t MaxK = ::cuda::std::numeric_limits<::cuda::std::int64_t>::max()>
using k_per_segment = params::per_segment_param<KItT, ::cuda::std::int64_t, MinK, MaxK>;

// ------------ TOTAL NUMBER OF SEGMENTS ------------
// Number of segments known at compile time
template <::cuda::std::int64_t StaticNumSegments>
using num_segments_static = params::static_constant_param<::cuda::std::int64_t, StaticNumSegments>;

// Number of segments is a runtime value
template <::cuda::std::int64_t MinNumSegments = 1,
          ::cuda::std::int64_t MaxNumSegments = ::cuda::std::numeric_limits<::cuda::std::int64_t>::max()>
using num_segments_uniform = params::uniform_param<::cuda::std::int64_t, MinNumSegments, MaxNumSegments>;

// Number of segments read through an iterator (e.g., a device-accessible-only count)
template <typename NumSegmentsItT,
          ::cuda::std::int64_t MinNumSegments = 1,
          ::cuda::std::int64_t MaxNumSegments = ::cuda::std::numeric_limits<::cuda::std::int64_t>::max()>
using num_segments_per_segment =
  params::per_segment_param<NumSegmentsItT, ::cuda::std::int64_t, MinNumSegments, MaxNumSegments>;

// ------------ TOTAL NUMBER OF ITEMS PARAMETER TYPES ------------

// Number of items guarantee
template <::cuda::std::int64_t MinNumItems = 1,
          ::cuda::std::int64_t MaxNumItems = ::cuda::std::numeric_limits<::cuda::std::int64_t>::max()>
struct total_num_items_guarantee
{
  static constexpr ::cuda::std::int64_t static_min_num_items = MinNumItems;
  static constexpr ::cuda::std::int64_t static_max_num_items = MaxNumItems;

  ::cuda::std::int64_t min_num_items = MinNumItems;
  ::cuda::std::int64_t max_num_items = MaxNumItems;

  // Default ctor keeps the static bounds, 1-param ctor sets an exact count (min == max),
  // 2-param ctor takes explicit min/max
  total_num_items_guarantee() = default;

  _CCCL_HOST_DEVICE total_num_items_guarantee(::cuda::std::int64_t num_items)
      : min_num_items(num_items)
      , max_num_items(num_items)
  {}

  _CCCL_HOST_DEVICE total_num_items_guarantee(::cuda::std::int64_t min_items, ::cuda::std::int64_t max_items)
      : min_num_items(min_items)
      , max_num_items(max_items)
  {}
};
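One way the dispatch layer might consume total_num_items_guarantee when answering the temp-storage query. This is a host-only sketch: the struct is ported to std:: types without _CCCL_HOST_DEVICE, and double_buffer_bytes is a hypothetical helper that assumes the auxiliary double buffer must hold one key per item.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>

// Host-only port of total_num_items_guarantee (std:: instead of cuda::std,
// no _CCCL_HOST_DEVICE).
template <std::int64_t MinNumItems = 1, std::int64_t MaxNumItems = std::numeric_limits<std::int64_t>::max()>
struct total_num_items_guarantee
{
  static constexpr std::int64_t static_min_num_items = MinNumItems;
  static constexpr std::int64_t static_max_num_items = MaxNumItems;

  std::int64_t min_num_items = MinNumItems;
  std::int64_t max_num_items = MaxNumItems;

  total_num_items_guarantee() = default;
  total_num_items_guarantee(std::int64_t num_items)
      : min_num_items(num_items), max_num_items(num_items) {}
  total_num_items_guarantee(std::int64_t min_items, std::int64_t max_items)
      : min_num_items(min_items), max_num_items(max_items) {}
};

// Hypothetical temp-storage helper: the double buffer must be able to hold every
// item, so it is sized from the runtime upper bound, clamped to the static one.
template <typename KeyT, std::int64_t Min, std::int64_t Max>
std::size_t double_buffer_bytes(const total_num_items_guarantee<Min, Max>& guarantee)
{
  assert(guarantee.min_num_items <= guarantee.max_num_items);
  const std::int64_t upper = guarantee.max_num_items < Max ? guarantee.max_num_items : Max;
  return static_cast<std::size_t>(upper) * sizeof(KeyT);
}
```

Only the upper bound matters for allocation here, which matches the role the guarantee plays in the DeviceBatchedTopK signature.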
Example: What the underlying parameter wrappers look like
// -----------------------------------------------------------------------------
// Parameter Mixins and Helpers
// -----------------------------------------------------------------------------

// Allows providing constraints on parameter values at compile-time
template <typename T, T Min = ::cuda::std::numeric_limits<T>::lowest(), T Max = ::cuda::std::numeric_limits<T>::max()>
struct static_bounds_mixin
{
  static_assert(Min <= Max, "Min must be <= Max");

  // Compile-time bounds
  static constexpr T static_min_value = Min;
  static constexpr T static_max_value = Max;

  // Indicates that there's only one possible value
  static constexpr bool is_exact = (Min == Max);
};

// -----------------------------------------------------------------------------
// 1. Compile-time constant parameter
// -----------------------------------------------------------------------------

// A compile-time constant
template <typename T, T Value>
struct static_constant_param : public static_bounds_mixin<T, Value, Value>
{
  using value_type = T;

  template <typename SegmentIndexT>
  _CCCL_HOST_DEVICE constexpr auto get_param([[maybe_unused]] SegmentIndexT segment_id) const
  {
    static_assert(static_bounds_mixin<T, Value, Value>::is_exact, "Static parameter must have exact value");
    return static_bounds_mixin<T, Value, Value>::static_min_value;
  }
};

// -----------------------------------------------------------------------------
// 2. Uniform parameter
// -----------------------------------------------------------------------------
// Added default template args so CTAD can deduce T and default Min/Max
template <typename T, T Min = ::cuda::std::numeric_limits<T>::lowest(), T Max = ::cuda::std::numeric_limits<T>::max()>
struct uniform_param : public static_bounds_mixin<T, Min, Max>
{
  using value_type = T;

  T value;

  _CCCL_HOST_DEVICE constexpr uniform_param(T v)
      : value(v)
  {}

  uniform_param() = default;

  template <typename SegmentIndexT>
  _CCCL_HOST_DEVICE constexpr auto get_param([[maybe_unused]] SegmentIndexT segment_id) const
  {
    return value;
  }
};

template <typename T>
uniform_param(T) -> uniform_param<T>;

// -----------------------------------------------------------------------------
// 3. Per-Segment parameter
// -----------------------------------------------------------------------------
// Added defaults for T, Min, and Max based on the Iterator's value_type
template <typename IteratorT,
          typename T = typename ::cuda::std::iterator_traits<IteratorT>::value_type,
          T Min      = ::cuda::std::numeric_limits<T>::lowest(),
          T Max      = ::cuda::std::numeric_limits<T>::max()>
struct per_segment_param : public static_bounds_mixin<T, Min, Max>
{
  using iterator_type = IteratorT;
  using value_type    = T;

  IteratorT iterator;
  T min_value = Min;
  T max_value = Max;

  _CCCL_HOST_DEVICE constexpr per_segment_param(IteratorT iter, T min_v = Min, T max_v = Max)
      : iterator(iter)
      , min_value(min_v)
      , max_value(max_v)
  {}

  per_segment_param() = default;

  template <typename SegmentIndexT>
  _CCCL_HOST_DEVICE constexpr auto get_param(SegmentIndexT segment_id) const
  {
    return iterator[segment_id];
  }
};
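Because per_segment_param carries both static bounds and (optionally narrower) runtime bounds, the dispatch layer can cheaply sanity-check them before specializing. This is a host-only sketch: the wrapper is ported to std:: types without _CCCL_HOST_DEVICE, and runtime_bounds_valid and values_within_bounds are hypothetical helpers, not part of the proposal.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Host-only port of per_segment_param: std:: instead of cuda::std, no
// _CCCL_HOST_DEVICE, and the static_bounds_mixin members inlined.
template <typename IteratorT,
          typename T,
          T Min = std::numeric_limits<T>::lowest(),
          T Max = std::numeric_limits<T>::max()>
struct per_segment_param
{
  static constexpr T static_min_value = Min;
  static constexpr T static_max_value = Max;

  IteratorT iterator;
  T min_value = Min;
  T max_value = Max;

  T get_param(std::int64_t segment_id) const { return iterator[segment_id]; }
};

// Hypothetical host-side check: runtime bounds must be ordered and must narrow
// (never widen) the static bounds.
template <typename ParamT>
bool runtime_bounds_valid(const ParamT& p)
{
  return ParamT::static_min_value <= p.min_value && p.min_value <= p.max_value
      && p.max_value <= ParamT::static_max_value;
}

// Hypothetical debug check: every per-segment value respects the runtime bounds.
template <typename ParamT>
bool values_within_bounds(const ParamT& p, std::int64_t num_segments)
{
  assert(num_segments >= 0);
  for (std::int64_t s = 0; s < num_segments; ++s)
  {
    const auto v = p.get_param(s);
    if (v < p.min_value || v > p.max_value)
    {
      return false;
    }
  }
  return true;
}
```

The first check is O(1) and could run on every call; the second touches every segment and is only practical as a debug-mode check on host-accessible data.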

Option 2: Using environment-based guarantees:

Example: DeviceBatchedTopK using guarantees
auto static_guarantee_env = cuda::execution::guarantee(
    cuda::execution::min_segment_size<static_min_seg_size>{},
    cuda::execution::max_segment_size<static_max_seg_size>{},
    cuda::execution::max_k<static_max_k>{},
    cuda::execution::min_num_segments<1>{},
    cuda::execution::max_num_segments<256>{});
DeviceBatchedTopK::Select(
    ...,
    segment_sizes,
    cuda::make_constant_iterator(k),
    cuda::make_constant_iterator(direction),
    num_segments,
    static_guarantee_env);
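A minimal sketch of how an environment could carry static guarantees and how an algorithm could query them with a fallback. Everything lives in a hypothetical sketch namespace and is not the cuda::execution API; it assumes each guarantee is a tag type carrying one compile-time std::int64_t value.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

namespace sketch
{
// Hypothetical guarantee tags, each carrying one compile-time value.
template <std::int64_t V>
struct max_k { static constexpr std::int64_t value = V; };
template <std::int64_t V>
struct max_segment_size { static constexpr std::int64_t value = V; };
template <std::int64_t V>
struct min_num_segments { static constexpr std::int64_t value = V; };

// Hypothetical environment: a type-level list of guarantees.
template <typename... Gs>
struct guarantee_env {};

template <typename... Gs>
constexpr guarantee_env<Gs...> guarantee(Gs...) { return {}; }

// Is G an instance of the guarantee template Tag?
template <template <std::int64_t> class Tag, typename G>
struct is_instance_of : std::false_type {};
template <template <std::int64_t> class Tag, std::int64_t V>
struct is_instance_of<Tag, Tag<V>> : std::true_type {};

// Value of the first guarantee of kind Tag in the environment, or Default
// if the user did not provide one.
template <template <std::int64_t> class Tag, std::int64_t Default, typename... Gs>
constexpr std::int64_t query(guarantee_env<Gs...>)
{
  std::int64_t result = Default;
  bool found          = false;
  (void)((!found && is_instance_of<Tag, Gs>::value ? (result = Gs::value, found = true) : false), ...);
  return result;
}
} // namespace sketch
```

The fallback is what lets an algorithm accept partial environments: any guarantee the user omits degrades to the widest bound rather than failing to compile.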

Rationale in favor of having guarantees as part of the environment:

  • We might want to be able to pass guarantees to the block- and warp-level algorithms.

Rationale in favor of annotating parameters with wrappers on the parameter itself:

  • Parameter constraints are more of a property of the parameter itself than of the algorithm.
  • Keeping parameter constraints as part of the parameter itself is less verbose and keeps per-parameter information close to the parameter.
  • Some constraints simply don't make sense for certain parameter types. E.g.,
    • a static parameter doesn't need further static or runtime constraints.
    • a uniform parameter doesn't need runtime constraints.
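The last point can be expressed in the type system: because each wrapper kind carries only the members that make sense for it, the dispatch layer can detect them rather than validate an environment. This is a host-only sketch; the stand-in types, the has_runtime_bounds and has_runtime_value traits, and host_upper_bound are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <utility>

// Hypothetical host-only stand-ins: only the per-segment kind has runtime bounds.
template <std::int64_t V>
struct static_param
{
  static constexpr std::int64_t static_max_value = V;
};
struct uniform_param
{
  std::int64_t value;
};
struct per_segment_param
{
  const std::int64_t* it;
  std::int64_t min_value;
  std::int64_t max_value;
};

// Detect which members a parameter type exposes.
template <typename P, typename = void>
struct has_runtime_bounds : std::false_type {};
template <typename P>
struct has_runtime_bounds<P, std::void_t<decltype(std::declval<const P&>().max_value)>> : std::true_type {};

template <typename P, typename = void>
struct has_runtime_value : std::false_type {};
template <typename P>
struct has_runtime_value<P, std::void_t<decltype(std::declval<const P&>().value)>> : std::true_type {};

// Hypothetical dispatch helper: the tightest upper bound known on the host.
template <typename P>
std::int64_t host_upper_bound(const P& p)
{
  if constexpr (has_runtime_bounds<P>::value)
  {
    return p.max_value; // per-segment: runtime bound narrows the static one
  }
  else if constexpr (has_runtime_value<P>::value)
  {
    return p.value; // uniform: the value is its own bound
  }
  else
  {
    return P::static_max_value; // static: known at compile time
  }
}
```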

References

Existing CUB segmented and batched algorithms:

cub::DeviceSegmentedRadixSort

  • order: ascending or descending
  • segment size: uniform or segment-specific
  • begin_bit: uniform or segment-specific
  • end_bit: uniform or segment-specific
  • number of segments: host-accessible or device-accessible
  • decomposer type(?): uniform or segment-specific
  • total number of items: need some upper bound to allocate auxiliary buffer for double-buffering

cub::DeviceSegmentedSort

  • order: ascending or descending
  • segment size: uniform or segment-specific
  • begin_bit: uniform or segment-specific
  • end_bit: uniform or segment-specific
  • number of segments: host-accessible or device-accessible
  • decomposer type(?): uniform or segment-specific
  • total number of items: need some upper bound to allocate auxiliary buffer for double-buffering

cub::DeviceSegmentedReduce

  • reduction_op: uniform or segment-specific
  • segment size: uniform or segment-specific
  • number of segments: host-accessible or device-accessible
  • total number of items: e.g., to help allocate the load-balancing queue, to which N/CTA_SIZE items are written

cub::DeviceSegmentedScan

  • scan_op: uniform or segment-specific
  • segment size: uniform or segment-specific
  • number of segments: host-accessible or device-accessible
  • inclusive or exclusive scan
  • initial value: uniform or segment-specific
  • total number of items: e.g., to help allocate the load-balancing queue, to which N/CTA_SIZE items are written

cub::DeviceCopy::Batched & cub::DeviceMemcpy::Batched

  • segment size: uniform or segment-specific
  • number of segments: host-accessible or device-accessible
  • total number of items/bytes: e.g., to help allocate the load-balancing queue, to which N/CTA_SIZE items are written
