Skip to content

Commit e6d994e

Browse files
authored
Add cuda.compute APIs for upper_bound and lower_bound (#7250)
* Python bindings to binary search * Add Python API for binary search * Add tests, examples and benchmarks for binary search. * Docs clarification * Update binary_search to use new registered caching mechanism * Update binary_search to use new registered caching mechanism * Change to private method * Actually, we don't need to work with the same pointers between __init__ and __call__ * Remove our hack to return indices --------- Co-authored-by: Ashwin Srinath <shwina@users.noreply.github.com>
1 parent eebd133 commit e6d994e

File tree

12 files changed

+708
-0
lines changed

12 files changed

+708
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
import cupy as cp
5+
import numpy as np
6+
import pytest
7+
8+
import cuda.compute
9+
10+
11+
def lower_bound_run(d_data, d_values, d_out, build_only):
12+
searcher = cuda.compute.make_lower_bound(d_data, d_values, d_out)
13+
if not build_only:
14+
searcher(d_data, d_values, d_out, len(d_data), len(d_values))
15+
cp.cuda.runtime.deviceSynchronize()
16+
17+
18+
def upper_bound_run(d_data, d_values, d_out, build_only):
19+
searcher = cuda.compute.make_upper_bound(d_data, d_values, d_out)
20+
if not build_only:
21+
searcher(d_data, d_values, d_out, len(d_data), len(d_values))
22+
cp.cuda.runtime.deviceSynchronize()
23+
24+
25+
@pytest.mark.parametrize("bench_fixture", ["compile_benchmark", "benchmark"])
26+
def bench_lower_bound(bench_fixture, request, size):
27+
actual_size = 100 if bench_fixture == "compile_benchmark" else size
28+
d_data = cp.sort(cp.random.randint(0, 1000, actual_size, dtype=np.int32))
29+
d_values = cp.random.randint(0, 1000, actual_size, dtype=np.int32)
30+
d_out = cp.empty_like(d_values, dtype=np.uintp)
31+
32+
def run():
33+
lower_bound_run(
34+
d_data, d_values, d_out, build_only=(bench_fixture == "compile_benchmark")
35+
)
36+
37+
fixture = request.getfixturevalue(bench_fixture)
38+
fixture(run)
39+
40+
41+
@pytest.mark.parametrize("bench_fixture", ["compile_benchmark", "benchmark"])
42+
def bench_upper_bound(bench_fixture, request, size):
43+
actual_size = 100 if bench_fixture == "compile_benchmark" else size
44+
d_data = cp.sort(cp.random.randint(0, 1000, actual_size, dtype=np.int32))
45+
d_values = cp.random.randint(0, 1000, actual_size, dtype=np.int32)
46+
d_out = cp.empty_like(d_values, dtype=np.uintp)
47+
48+
def run():
49+
upper_bound_run(
50+
d_data, d_values, d_out, build_only=(bench_fixture == "compile_benchmark")
51+
)
52+
53+
fixture = request.getfixturevalue(bench_fixture)
54+
fixture(run)

python/cuda_cccl/cuda/compute/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ def __getattr__(name):
2121
exclusive_scan,
2222
histogram_even,
2323
inclusive_scan,
24+
lower_bound,
2425
make_binary_transform,
2526
make_exclusive_scan,
2627
make_histogram_even,
2728
make_inclusive_scan,
29+
make_lower_bound,
2830
make_merge_sort,
2931
make_radix_sort,
3032
make_reduce_into,
@@ -34,6 +36,7 @@ def __getattr__(name):
3436
make_three_way_partition,
3537
make_unary_transform,
3638
make_unique_by_key,
39+
make_upper_bound,
3740
merge_sort,
3841
radix_sort,
3942
reduce_into,
@@ -43,6 +46,7 @@ def __getattr__(name):
4346
three_way_partition,
4447
unary_transform,
4548
unique_by_key,
49+
upper_bound,
4650
)
4751
from .determinism import Determinism
4852
from .iterators import (
@@ -72,11 +76,13 @@ def __getattr__(name):
7276
"gpu_struct",
7377
"histogram_even",
7478
"inclusive_scan",
79+
"lower_bound",
7580
"make_binary_transform",
7681
"make_exclusive_scan",
7782
"make_select",
7883
"make_histogram_even",
7984
"make_inclusive_scan",
85+
"make_lower_bound",
8086
"make_merge_sort",
8187
"make_radix_sort",
8288
"make_reduce_into",
@@ -85,6 +91,7 @@ def __getattr__(name):
8591
"make_three_way_partition",
8692
"make_unary_transform",
8793
"make_unique_by_key",
94+
"make_upper_bound",
8895
"merge_sort",
8996
"OpKind",
9097
"Determinism",
@@ -101,5 +108,6 @@ def __getattr__(name):
101108
"three_way_partition",
102109
"unary_transform",
103110
"unique_by_key",
111+
"upper_bound",
104112
"ZipIterator",
105113
]

python/cuda_cccl/cuda/compute/_bindings.pyi

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ class Determinism(IntEnum):
6969
RUN_TO_RUN = ...
7070
GPU_TO_GPU = ...
7171

72+
class BinarySearchMode(IntEnum):
73+
_value_: int
74+
LOWER_BOUND = ...
75+
UPPER_BOUND = ...
76+
7277
class Op:
7378
def __init__(
7479
self,
@@ -465,6 +470,31 @@ class DeviceHistogramBuildResult:
465470
stream,
466471
) -> None: ...
467472

473+
# -------------------
474+
# DeviceBinarySearch
475+
# -------------------
476+
477+
class DeviceBinarySearchBuildResult:
478+
def __init__(
479+
self,
480+
mode: BinarySearchMode,
481+
d_data: Iterator,
482+
d_values: Iterator,
483+
d_out: Iterator,
484+
comparison_op: Op,
485+
info: CommonData,
486+
): ...
487+
def compute(
488+
self,
489+
d_data: Iterator,
490+
num_items: int,
491+
d_values: Iterator,
492+
num_values: int,
493+
d_out: Iterator,
494+
comparison_op: Op,
495+
stream,
496+
) -> None: ...
497+
468498
# -----------------
469499
# DeviceSegmentedSort
470500
# -----------------

python/cuda_cccl/cuda/compute/_bindings_impl.pyx

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ cdef extern from "cccl/c/types.h":
130130
RUN_TO_RUN "CCCL_RUN_TO_RUN"
131131
GPU_TO_GPU "CCCL_GPU_TO_GPU"
132132

133+
cpdef enum cccl_binary_search_mode_t:
134+
LOWER_BOUND "CCCL_BINARY_SEARCH_LOWER_BOUND"
135+
UPPER_BOUND "CCCL_BINARY_SEARCH_UPPER_BOUND"
136+
133137
cdef void arg_type_check(
134138
str arg_name,
135139
object expected_type,
@@ -147,6 +151,7 @@ IteratorKind = cccl_iterator_kind_t
147151
SortOrder = cccl_sort_order_t
148152
InitKind = cccl_init_kind_t
149153
Determinism = cccl_determinism_t
154+
BinarySearchMode = cccl_binary_search_mode_t
150155

151156
cdef void _validate_alignment(int alignment) except *:
152157
"""
@@ -2188,6 +2193,127 @@ cdef class DeviceHistogramBuildResult:
21882193
)
21892194

21902195

2196+
# -------------------
2197+
# DeviceBinarySearch
2198+
# -------------------
2199+
cdef extern from "cccl/c/binary_search.h":
2200+
cdef struct cccl_device_binary_search_build_result_t 'cccl_device_binary_search_build_result_t':
2201+
int cc
2202+
void* cubin
2203+
size_t cubin_size
2204+
CUlibrary library
2205+
CUkernel kernel
2206+
2207+
cdef CUresult cccl_device_binary_search_build(
2208+
cccl_device_binary_search_build_result_t*,
2209+
cccl_binary_search_mode_t,
2210+
cccl_iterator_t,
2211+
cccl_iterator_t,
2212+
cccl_iterator_t,
2213+
cccl_op_t,
2214+
int, int, const char*, const char*, const char*, const char*
2215+
) nogil
2216+
2217+
cdef CUresult cccl_device_binary_search(
2218+
cccl_device_binary_search_build_result_t,
2219+
cccl_iterator_t,
2220+
uint64_t,
2221+
cccl_iterator_t,
2222+
uint64_t,
2223+
cccl_iterator_t,
2224+
cccl_op_t,
2225+
CUstream
2226+
) nogil
2227+
2228+
cdef CUresult cccl_device_binary_search_cleanup(
2229+
cccl_device_binary_search_build_result_t *build_ptr
2230+
) nogil
2231+
2232+
2233+
cdef class DeviceBinarySearchBuildResult:
2234+
cdef cccl_device_binary_search_build_result_t build_data
2235+
2236+
def __dealloc__(DeviceBinarySearchBuildResult self):
2237+
cdef CUresult status = -1
2238+
with nogil:
2239+
status = cccl_device_binary_search_cleanup(&self.build_data)
2240+
if (status != 0):
2241+
print(f"Return code {status} encountered during binary_search result cleanup")
2242+
2243+
def __cinit__(
2244+
DeviceBinarySearchBuildResult self,
2245+
cccl_binary_search_mode_t mode,
2246+
Iterator d_data,
2247+
Iterator d_values,
2248+
Iterator d_out,
2249+
Op op,
2250+
CommonData common_data
2251+
):
2252+
cdef CUresult status = -1
2253+
cdef int cc_major = common_data.get_cc_major()
2254+
cdef int cc_minor = common_data.get_cc_minor()
2255+
cdef const char *cub_path = common_data.cub_path_get_c_str()
2256+
cdef const char *thrust_path = common_data.thrust_path_get_c_str()
2257+
cdef const char *libcudacxx_path = common_data.libcudacxx_path_get_c_str()
2258+
cdef const char *ctk_path = common_data.ctk_path_get_c_str()
2259+
2260+
memset(&self.build_data, 0, sizeof(cccl_device_binary_search_build_result_t))
2261+
with nogil:
2262+
status = cccl_device_binary_search_build(
2263+
&self.build_data,
2264+
mode,
2265+
d_data.iter_data,
2266+
d_values.iter_data,
2267+
d_out.iter_data,
2268+
op.op_data,
2269+
cc_major,
2270+
cc_minor,
2271+
cub_path,
2272+
thrust_path,
2273+
libcudacxx_path,
2274+
ctk_path,
2275+
)
2276+
if status != 0:
2277+
raise RuntimeError(
2278+
f"Failed building binary_search, error code: {status}"
2279+
)
2280+
2281+
cpdef void compute(
2282+
DeviceBinarySearchBuildResult self,
2283+
Iterator d_data,
2284+
size_t num_items,
2285+
Iterator d_values,
2286+
size_t num_values,
2287+
Iterator d_out,
2288+
Op op,
2289+
stream
2290+
):
2291+
cdef CUresult status = -1
2292+
cdef CUstream c_stream = <CUstream><uintptr_t>(stream) if stream else NULL
2293+
2294+
with nogil:
2295+
status = cccl_device_binary_search(
2296+
self.build_data,
2297+
d_data.iter_data,
2298+
<uint64_t>num_items,
2299+
d_values.iter_data,
2300+
<uint64_t>num_values,
2301+
d_out.iter_data,
2302+
op.op_data,
2303+
c_stream
2304+
)
2305+
if status != 0:
2306+
raise RuntimeError(
2307+
f"Failed executing binary_search, error code: {status}"
2308+
)
2309+
2310+
def _get_cubin(self):
2311+
return PyBytes_FromStringAndSize(
2312+
<const char*>self.build_data.cubin,
2313+
self.build_data.cubin_size
2314+
)
2315+
2316+
21912317
# ----------------------------------
21922318
# DeviceThreeWayPartitionBuildResult
21932319
# ----------------------------------

python/cuda_cccl/cuda/compute/algorithms/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
#
44
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
55

6+
from ._binary_search import lower_bound as lower_bound
7+
from ._binary_search import make_lower_bound as make_lower_bound
8+
from ._binary_search import make_upper_bound as make_upper_bound
9+
from ._binary_search import upper_bound as upper_bound
610
from ._histogram import histogram_even as histogram_even
711
from ._histogram import make_histogram_even as make_histogram_even
812
from ._reduce import make_reduce_into as make_reduce_into
@@ -33,6 +37,10 @@
3337
__all__ = [
3438
"reduce_into",
3539
"make_reduce_into",
40+
"lower_bound",
41+
"make_lower_bound",
42+
"upper_bound",
43+
"make_upper_bound",
3644
"inclusive_scan",
3745
"make_inclusive_scan",
3846
"exclusive_scan",

0 commit comments

Comments
 (0)