Skip to content

Commit 4af7c3e

Browse files
authored
[slimtensor] Introduce CUDA guard to aoti/slim/cuda (#16724)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #16565 * #16551 * #16469 * #16457 * #16455 * #16454 * #16453 * #16452 * #16451 * #16450 * #16449 * #16448 * #16447 * #16446 * __->__ #16724 Copy CUDAGuard and CUDAStreamGuard from cuda/runtime/ to aoti/slim/cuda/ to support the slimtensor requirement while getting rid of a potential circular dependency: - cuda_backend/main_functionalities -> aoti/slimtensor -> cuda_backend/cuda_guard This change: - copies guard.h, guard.cpp, and the test files from backend/cuda_backend to backend/aoti/slim/cuda/ Differential Revision: [D91056808](https://our.internmc.facebook.com/intern/diff/D91056808/)
1 parent 2b841eb commit 4af7c3e

File tree

11 files changed

+851
-10
lines changed

11 files changed

+851
-10
lines changed

backends/aoti/slim/c10/cuda/Exception.h

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,55 @@
88

99
#pragma once
1010

11-
#ifdef CUDA_AVAILABLE
12-
1311
#include <cuda.h>
1412
#include <cuda_runtime.h>
1513

1614
#include <executorch/backends/aoti/slim/c10/macros/Macros.h>
15+
#include <executorch/runtime/core/error.h>
1716
#include <executorch/runtime/platform/assert.h>
1817
#include <executorch/runtime/platform/log.h>
1918

2019
/// Checks a CUDA expression and aborts on error.
/// Logs the failing file/line and the CUDA error string, then triggers
/// ET_CHECK_MSG (fatal). Expands to nothing new if already defined, so
/// multiple headers can provide it without clashing.
/// @param EXPR The CUDA expression to check (evaluated exactly once).
#ifndef ET_CUDA_CHECK
#define ET_CUDA_CHECK(EXPR)                                                \
  do {                                                                     \
    /* Parenthesize EXPR so comma-containing expressions bind correctly;   \
       avoid '__err' — identifiers containing '__' are reserved. */        \
    const cudaError_t et_cuda_err_ = (EXPR);                               \
    if (et_cuda_err_ == cudaSuccess) {                                     \
      break;                                                               \
    }                                                                      \
    ET_LOG(                                                                \
        Error,                                                             \
        "%s:%d CUDA error: %s",                                            \
        __FILE__,                                                          \
        __LINE__,                                                          \
        cudaGetErrorString(et_cuda_err_));                                 \
    ET_CHECK_MSG(false, "CUDA error: %s", cudaGetErrorString(et_cuda_err_)); \
  } while (0)
#endif
37+
38+
/// Checks a CUDA expression and returns Error::Internal on failure.
/// Logs the failing file/line and CUDA error string before returning, so
/// the enclosing function must return ::executorch::runtime::Error (or a
/// Result convertible from it). Guarded so a prior definition wins.
/// @param EXPR The CUDA expression to check (evaluated exactly once).
#ifndef ET_CUDA_CHECK_OR_RETURN_ERROR
#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR)                                \
  do {                                                                     \
    /* Parenthesize EXPR so comma-containing expressions bind correctly;   \
       avoid '__err' — identifiers containing '__' are reserved. */        \
    const cudaError_t et_cuda_err_ = (EXPR);                               \
    if (et_cuda_err_ == cudaSuccess) {                                     \
      break;                                                               \
    }                                                                      \
    ET_LOG(                                                                \
        Error,                                                             \
        "%s:%d CUDA error: %s",                                            \
        __FILE__,                                                          \
        __LINE__,                                                          \
        cudaGetErrorString(et_cuda_err_));                                 \
    return ::executorch::runtime::Error::Internal;                         \
  } while (0)
#endif
2856

2957
/// Checks a CUDA expression and logs a warning on error (non-fatal).
3058
/// @param EXPR The CUDA expression to check.
59+
#ifndef ET_CUDA_LOG_WARN
3160
#define ET_CUDA_LOG_WARN(EXPR) \
3261
do { \
3362
const cudaError_t __err = EXPR; \
@@ -36,5 +65,17 @@
3665
ET_LOG(Error, "CUDA warning: %s", cudaGetErrorString(__err)); \
3766
} \
3867
} while (0)
68+
#endif
69+
70+
/// Kernel launch check (returning form): inspects cudaGetLastError() after a
/// kernel launch and returns Error::Internal from the enclosing function if
/// the launch failed.
#ifndef ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR
#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())
#endif

/// Kernel launch check (aborting form): inspects cudaGetLastError() after a
/// kernel launch and aborts via ET_CUDA_CHECK if the launch failed.
#ifndef ET_CUDA_KERNEL_LAUNCH_CHECK
#define ET_CUDA_KERNEL_LAUNCH_CHECK() ET_CUDA_CHECK(cudaGetLastError())
#endif

backends/aoti/slim/core/Storage.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
#ifdef CUDA_AVAILABLE
1414
#include <executorch/backends/aoti/slim/c10/cuda/Exception.h>
15-
#include <executorch/backends/cuda/runtime/guard.h>
15+
#include <executorch/backends/aoti/slim/cuda/guard.h>
1616
#endif
1717

1818
#include <executorch/backends/aoti/slim/c10/core/Device.h>

backends/aoti/slim/core/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def define_common_targets():
1818
"//executorch/backends/aoti/slim/util:size_util",
1919
"//executorch/runtime/platform:platform",
2020
"//executorch/backends/aoti/slim/c10/cuda:exception",
21-
"//executorch/backends/cuda/runtime:guard",
21+
"//executorch/backends/aoti/slim/cuda:guard",
2222
],
2323
)
2424

backends/aoti/slim/cuda/TARGETS

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
load(":targets.bzl", "define_common_targets")
3+
4+
oncall("executorch")
5+
6+
define_common_targets()

backends/aoti/slim/cuda/guard.cpp

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/aoti/slim/cuda/guard.h>
10+
#include <executorch/runtime/platform/log.h>
11+
#include <limits>
12+
#include <unordered_map>
13+
14+
namespace executorch::backends::cuda {
15+
16+
namespace {
17+
// Thread-local stream storage (private to this file)
18+
thread_local std::unordered_map<DeviceIndex, cudaStream_t> current_streams_;
19+
} // namespace
20+
21+
// Records `stream` as the calling thread's current stream for
// `device_index`. Passing -1 resolves to the device CUDA currently reports
// active. Returns Error::Ok on success, or an error if the device query
// fails.
Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index) {
  if (device_index == -1) {
    // Resolve the active device. The CUDA API hands back an int; narrow it
    // to DeviceIndex (int8_t) explicitly, mirroring ATen.
    int active_device = -1;
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&active_device));
    device_index = static_cast<DeviceIndex>(active_device);
  }

  current_streams_[device_index] = stream;
  return Error::Ok;
}
34+
35+
// Returns the calling thread's current stream for `device_index`
// (-1 resolves to the active device). If no stream has been recorded yet,
// lazily creates one, caches it, and returns it. The created stream is
// intentionally kept alive in the thread-local cache for reuse.
Result<cudaStream_t> getCurrentCUDAStream(DeviceIndex device_index) {
  if (device_index == -1) {
    // CUDA API returns int; narrow to DeviceIndex (int8_t) explicitly,
    // mirroring ATen.
    int active_device = -1;
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&active_device));
    device_index = static_cast<DeviceIndex>(active_device);
  }

  auto it = current_streams_.find(device_index);
  if (it != current_streams_.end()) {
    return it->second;
  }

  cudaStream_t stream;
  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&stream));
  // Fix: the original discarded this Error return; propagate it so a cache
  // failure is not silently ignored. (With a concrete device_index the call
  // cannot fail today, so behavior for callers is unchanged.)
  ET_CHECK_OK_OR_RETURN_ERROR(setCurrentCUDAStream(stream, device_index));
  return stream;
}
54+
55+
// Move constructor: takes over responsibility for restoring the device.
// The moved-from guard is neutralized by making its original index equal its
// current index, so its destructor becomes a no-op.
CUDAGuard::CUDAGuard(CUDAGuard&& src) noexcept
    : original_device_index_(src.original_device_index_),
      current_device_index_(src.current_device_index_) {
  src.original_device_index_ = src.current_device_index_;
}
62+
63+
// Destructor: switches back to the device that was active when the guard was
// set. Failure is logged rather than thrown — destructors must not fail.
CUDAGuard::~CUDAGuard() {
  if (original_device_index_ == current_device_index_) {
    return; // nothing to restore (also covers moved-from guards)
  }
  // DeviceIndex (int8_t) widens implicitly to int for cudaSetDevice.
  const cudaError_t err = cudaSetDevice(original_device_index_);
  if (err != cudaSuccess) {
    ET_LOG(
        Error,
        "~CUDAGuard: Failed to restore device to %d: %s",
        static_cast<int>(original_device_index_),
        cudaGetErrorString(err));
  }
}
76+
77+
// Remembers the currently active device, then switches to `device_index` if
// it differs. Returns Error::Ok, or an error if either CUDA call fails.
Error CUDAGuard::set_index(DeviceIndex device_index) {
  // CUDA reports the device as int; narrow to DeviceIndex (int8_t)
  // explicitly, mirroring ATen.
  int active_device = -1;
  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&active_device));

  original_device_index_ = static_cast<DeviceIndex>(active_device);
  current_device_index_ = device_index;

  // Only touch the runtime when an actual switch is needed; DeviceIndex
  // widens implicitly to int for cudaSetDevice.
  if (current_device_index_ != original_device_index_) {
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaSetDevice(current_device_index_));
  }

  return Error::Ok;
}
92+
93+
// Factory: builds a guard that switches to `device_index`, reporting any
// CUDA failure through the Result instead of aborting.
Result<CUDAGuard> CUDAGuard::create(DeviceIndex device_index) {
  CUDAGuard guard;
  ET_CHECK_OK_OR_RETURN_ERROR(guard.set_index(device_index));
  return guard;
}
98+
99+
// Move constructor: transfers both the device guard and the stream-restore
// responsibility. The moved-from guard is neutralized by making its original
// stream equal its current stream, so its destructor restores nothing.
CUDAStreamGuard::CUDAStreamGuard(CUDAStreamGuard&& src) noexcept
    : device_guard_(std::move(src.device_guard_)),
      original_stream_(src.original_stream_),
      current_stream_(src.current_stream_),
      device_index_(src.device_index_) {
  src.original_stream_ = src.current_stream_;
}
108+
109+
// Destructor: reinstates the stream that was current when the guard was set.
// A moved-from guard has original_stream_ == current_stream_ and skips the
// restore. nullptr is a legitimate stream value (the default stream), so the
// restore must run even when original_stream_ is nullptr. Failures are only
// logged — destructors must not fail.
CUDAStreamGuard::~CUDAStreamGuard() {
  if (original_stream_ == current_stream_) {
    return; // moved-from, or nothing changed
  }
  const Error err = setCurrentCUDAStream(original_stream_, device_index_);
  if (err != Error::Ok) {
    ET_LOG(
        Error,
        "~CUDAStreamGuard: Failed to restore stream for device %d",
        static_cast<int>(device_index_));
  }
}
125+
126+
// Snapshots the thread's current stream for `device_index`, then installs
// `stream` as the new current stream. Returns Error::Ok, or the underlying
// error if the snapshot or the install fails.
Error CUDAStreamGuard::set_stream(
    cudaStream_t stream,
    DeviceIndex device_index) {
  auto previous = getCurrentCUDAStream(device_index);
  if (!previous.ok()) {
    ET_LOG(
        Error,
        "Failed to get current stream for device %d",
        static_cast<int>(device_index));
    return previous.error();
  }

  original_stream_ = previous.get();
  current_stream_ = stream;
  device_index_ = device_index;

  ET_CHECK_OK_OR_RETURN_ERROR(setCurrentCUDAStream(stream, device_index));

  return Error::Ok;
}
146+
147+
// Factory: builds a guard that switches to `device_index` and installs
// `stream` as that device's current stream, reporting failures through the
// Result instead of aborting.
Result<CUDAStreamGuard> CUDAStreamGuard::create(
    cudaStream_t stream,
    DeviceIndex device_index) {
  auto device_guard = CUDAGuard::create(device_index);
  ET_CHECK_OK_OR_RETURN_ERROR(device_guard.error());

  CUDAStreamGuard guard(std::move(device_guard.get()));
  ET_CHECK_OK_OR_RETURN_ERROR(guard.set_stream(stream, device_index));

  return guard;
}
158+
159+
} // namespace executorch::backends::cuda

0 commit comments

Comments
 (0)