Skip to content

Commit f4e3b04

Browse files
committed
Add update_capture_dependencies flags
1 parent 66d7e36 commit f4e3b04

File tree

4 files changed

+21
-6
lines changed

4 files changed

+21
-6
lines changed

doc/driver.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,13 @@ Constants
621621
.. attribute:: ACTIVE
622622
.. attribute:: INVALIDATED
623623

624+
.. class:: update_capture_dependencies_flags
625+
626+
CUDA 11.3 and newer.
627+
628+
.. attribute:: ADD_CAPTURE_DEPENDENCIES
629+
.. attribute:: SET_CAPTURE_DEPENDENCIES
630+
624631

625632
Graphics-related constants
626633
^^^^^^^^^^^^^^^^^^^^^^^^^^

examples/demo_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@
3434
func_plus(a_gpu, numpy.int32(2), block=(4, 4, 1), stream=stream_1)
3535
_, _, graph, deps = stream_1.get_capture_info_v2()
3636
first_node = graph.add_kernel_node(b_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
37-
stream_1.update_capture_dependencies([first_node], 1)
37+
stream_1.update_capture_dependencies([first_node], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)
3838

3939
_, _, graph, deps = stream_1.get_capture_info_v2()
4040
second_node = graph.add_kernel_node(a_gpu, b_gpu, block=(4, 4, 1), func=func_times, dependencies=deps)
41-
stream_1.update_capture_dependencies([second_node], 1)
41+
stream_1.update_capture_dependencies([second_node], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)
4242
cuda.memcpy_dtoh_async(result, a_gpu, stream_1)
4343

4444
graph = stream_1.end_capture()

src/wrapper/wrap_cudadrv.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,12 @@ BOOST_PYTHON_MODULE(_driver)
12771277
.value("ACTIVE", CU_STREAM_CAPTURE_STATUS_ACTIVE)
12781278
.value("INVALIDATED", CU_STREAM_CAPTURE_STATUS_INVALIDATED)
12791279
;
1280+
#endif
1281+
#if CUDAPP_CUDA_VERSION >= 11030
1282+
py::enum_<CUstreamUpdateCaptureDependencies_flags>("update_capture_dependencies_flags")
1283+
.value("ADD_CAPTURE_DEPENDENCIES", CU_STREAM_ADD_CAPTURE_DEPENDENCIES)
1284+
.value("SET_CAPTURE_DEPENDENCIES", CU_STREAM_SET_CAPTURE_DEPENDENCIES)
1285+
;
12801286
#endif
12811287
{
12821288
typedef stream cl;
@@ -1294,7 +1300,9 @@ BOOST_PYTHON_MODULE(_driver)
12941300
py::return_value_policy<py::manage_new_object>())
12951301
.def("get_capture_info_v2", &cl::get_capture_info_v2)
12961302
#if CUDAPP_CUDA_VERSION >= 11030
1297-
.def("update_capture_dependencies", &cl::update_capture_dependencies)
1303+
.def("update_capture_dependencies", &cl::update_capture_dependencies,
1304+
(py::arg("dependencies"),
1305+
py::arg("flags") = CU_STREAM_ADD_CAPTURE_DEPENDENCIES))
12981306
#endif
12991307
#endif
13001308
.add_property("handle", &cl::handle_int)

test/test_graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_dynamic_params(self):
6161
assert stat == drv.capture_status.ACTIVE, "Capture should be active"
6262
assert len(deps) == 0, "Nothing on deps"
6363
newnode = x_graph.add_kernel_node(a_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
64-
stream_1.update_capture_dependencies([newnode], 1)
64+
stream_1.update_capture_dependencies([newnode], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)
6565
drv.memcpy_dtoh_async(result, a_gpu, stream_1) # Capture a copy as well.
6666
graph = stream_1.end_capture()
6767
assert graph == x_graph, "Should be the same"
@@ -110,11 +110,11 @@ def test_many_dynamic_params(self):
110110
assert stat == drv.capture_status.ACTIVE, "Capture should be active"
111111
assert len(deps) == 0, "Nothing on deps"
112112
newnode = x_graph.add_kernel_node(a_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
113-
stream_1.update_capture_dependencies([newnode], 1)
113+
stream_1.update_capture_dependencies([newnode], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)
114114
_, _, x_graph, deps = stream_1.get_capture_info_v2()
115115
assert deps == [newnode], "Call to update_capture_dependencies should set newnode as the only dep"
116116
newnode2 = x_graph.add_kernel_node(b_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
117-
stream_1.update_capture_dependencies([newnode2], 1)
117+
stream_1.update_capture_dependencies([newnode2], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)
118118

119119
# Static capture
120120
func_times(a_gpu, b_gpu, block=(4, 4, 1), stream=stream_1)

0 commit comments

Comments
 (0)