7
7
# is strictly prohibited.
8
8
9
9
import os
10
+ import pathlib
10
11
import time
11
12
13
+ import numpy as np
12
14
import pytest
13
15
14
16
import cuda .core .experimental
15
17
from cuda .core .experimental import Device , EventOptions , LaunchConfig , Program , ProgramOptions , launch
18
+ from cuda .core .experimental ._memory import _DefaultPinnedMemorySource
16
19
17
20
18
21
def test_event_init_disabled ():
@@ -113,27 +116,44 @@ def test_error_timing_recorded():
113
116
event3 - event2
114
117
115
118
119
+ # TODO: improve this once path finder can find headers
120
+ @pytest .mark .skipif (os .environ .get ("CUDA_PATH" ) is None , reason = "need libcu++ header" )
116
121
def test_error_timing_incomplete ():
117
122
device = Device ()
118
123
device .set_current ()
119
124
120
- # This kernel is designed to not complete
125
+ # This kernel busy-loops until the host sets the signal flag it polls
121
126
code = """
127
+ #include <cuda/atomic>
128
+
122
129
extern "C"
123
- __global__ void wait() {
124
- while (1 > 0) {
130
+ __global__ void wait(int* val) {
131
+ cuda::atomic_ref<int, cuda::thread_scope_system> signal{*val};
132
+ while (true) {
133
+ if (signal.load(cuda::memory_order_relaxed)) {
134
+ break;
135
+ }
125
136
}
126
137
}
127
138
"""
128
139
129
140
arch = "" .join (f"{ i } " for i in device .compute_capability )
130
- program_options = ProgramOptions (std = "c++11" , arch = f"sm_{ arch } " )
141
+ program_options = ProgramOptions (
142
+ std = "c++17" ,
143
+ arch = f"sm_{ arch } " ,
144
+ include_path = str (pathlib .Path (os .environ ["CUDA_PATH" ]) / pathlib .Path ("include" )),
145
+ )
131
146
prog = Program (code , code_type = "c++" , options = program_options )
132
147
mod = prog .compile (target_type = "cubin" )
133
148
ker = mod .get_kernel ("wait" )
134
149
150
+ mr = _DefaultPinnedMemorySource ()
151
+ b = mr .allocate (4 )
152
+ arr = np .from_dlpack (b ).view (np .int32 )
153
+ arr [0 ] = 0
154
+
135
155
config = LaunchConfig (grid = 1 , block = 1 )
136
- ker_args = ()
156
+ ker_args = (arr . ctypes . data , )
137
157
138
158
enabled = EventOptions (enable_timing = True )
139
159
stream = device .create_stream ()
@@ -145,3 +165,7 @@ def test_error_timing_incomplete():
145
165
# event3 cannot have completed yet because the stream is still blocked on wait(), which spins until it is signaled
146
166
with pytest .raises (RuntimeError , match = "^One or both events have not completed." ):
147
167
event3 - event1
168
+
169
+ arr [0 ] = 1
170
+ event3 .sync ()
171
+ event3 - event1 # this should work
0 commit comments