From d7c581eb5ba93a40009b466bfd5c127dc0a48152 Mon Sep 17 00:00:00 2001
From: gasoonjia
Date: Mon, 5 Jan 2026 11:02:22 -0800
Subject: [PATCH] [slimtensor] Add CUDA Storage with DeviceTraits and memory
 allocation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This diff adds CUDA storage infrastructure to SlimTensor, enabling GPU memory
allocation and management.

**Key changes:**

1. **`cuda/Guard.h`** - CUDAGuard RAII class:
   - Saves the current CUDA device on construction, restores it on destruction
   - Exception-safe device context switching
   - Constructors accept a device index or a Device object

2. **`core/Storage.h`** - Extended for CUDA support:
   - Added a CUDA `DeviceTraits` specialization with:
     - `allocate()` - Uses cudaMalloc, under a CUDAGuard for device selection
     - `free()` - Uses cudaFree, logging a non-fatal warning on error
     - `memcpy()` - Supports Host↔Device and Device↔Device copies
   - Added a `DEFAULT_CUDA_DEVICE` constant
   - Updated the `MaybeOwningStorage` constructor to handle CUDA devices
   - Stub implementation when `CUDA_AVAILABLE` is not defined (aborts with an
     error message)

3. **`c10/cuda/Exception.h`** - Error-handling macros used by both:
   - `ET_CUDA_CHECK()` - Aborts with the CUDA error string on failure
   - `ET_CUDA_LOG_WARN()` - Logs the CUDA error on failure and continues

Differential Revision: [D89826553](https://our.internmc.facebook.com/intern/diff/D89826553/)

[ghstack-poisoned]
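A usage sketch for reviewers (illustrative only, not part of the diff; assumes
a `CUDA_AVAILABLE=1` build, and `example()` is a hypothetical function):

```cpp
#include <executorch/backends/aoti/slim/core/Storage.h>

using namespace executorch::backends::aoti::slim;

void example() {
  // Owning allocations: for the CUDA storage, cudaMalloc runs under a
  // CUDAGuard that selects device 0 and restores the prior device afterwards.
  MaybeOwningStorage gpu(DEFAULT_CUDA_DEVICE, /*nbytes=*/1024);
  MaybeOwningStorage cpu(CPU_DEVICE, /*nbytes=*/1024);

  // Host -> device copy through the CUDA DeviceTraits specialization.
  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
      gpu.data(), cpu.data(), 1024, DEFAULT_CUDA_DEVICE, CPU_DEVICE);
}  // Destructors release both buffers (cudaFree for the CUDA storage).
```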
---
 backends/aoti/slim/c10/cuda/Exception.h       |  40 +++
 backends/aoti/slim/c10/cuda/TARGETS           |   6 +
 backends/aoti/slim/c10/cuda/targets.bzl       |  16 +
 backends/aoti/slim/core/Storage.h             | 115 +++++-
 backends/aoti/slim/core/targets.bzl           |   4 +-
 backends/aoti/slim/core/test/targets.bzl      |  37 +-
 backends/aoti/slim/core/test/test_storage.cpp | 331 ++++++++++++++----
 backends/aoti/slim/cuda/Guard.h               |  82 +++++
 backends/aoti/slim/cuda/TARGETS               |   6 +
 backends/aoti/slim/cuda/targets.bzl           |  16 +
 10 files changed, 575 insertions(+), 78 deletions(-)
 create mode 100644 backends/aoti/slim/c10/cuda/Exception.h
 create mode 100644 backends/aoti/slim/c10/cuda/TARGETS
 create mode 100644 backends/aoti/slim/c10/cuda/targets.bzl
 create mode 100644 backends/aoti/slim/cuda/Guard.h
 create mode 100644 backends/aoti/slim/cuda/TARGETS
 create mode 100644 backends/aoti/slim/cuda/targets.bzl

diff --git a/backends/aoti/slim/c10/cuda/Exception.h b/backends/aoti/slim/c10/cuda/Exception.h
new file mode 100644
index 00000000000..33d8414e661
--- /dev/null
+++ b/backends/aoti/slim/c10/cuda/Exception.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#ifdef CUDA_AVAILABLE
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <executorch/backends/aoti/slim/c10/macros/Macros.h>
+#include <executorch/runtime/platform/assert.h>
+#include <executorch/runtime/platform/log.h>
+
+/// Checks a CUDA expression and aborts on error.
+/// @param EXPR The CUDA expression to check.
+#define ET_CUDA_CHECK(EXPR)                                                 \
+  do {                                                                      \
+    const cudaError_t __err = EXPR;                                         \
+    ET_CHECK_MSG(                                                           \
+        __err == cudaSuccess, "CUDA error: %s", cudaGetErrorString(__err)); \
+  } while (0)
+
+/// Checks a CUDA expression and logs a warning on error (non-fatal).
+/// @param EXPR The CUDA expression to check.
+#define ET_CUDA_LOG_WARN(EXPR)                                             \
+  do {                                                                     \
+    const cudaError_t __err = EXPR;                                        \
+    if (SLIMTENSOR_UNLIKELY(__err != cudaSuccess)) {                       \
+      [[maybe_unused]] auto error_unused = cudaGetLastError();             \
+      ET_LOG(Error, "CUDA warning: %s", cudaGetErrorString(__err));        \
+    }                                                                      \
+  } while (0)
+
+#endif // CUDA_AVAILABLE
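Reviewer note: the two macros split CUDA error handling by call site.
`ET_CUDA_CHECK` is for operations that must not continue past a failure;
`ET_CUDA_LOG_WARN` is for cleanup paths where aborting would be worse than
leaking. A minimal sketch of the intended call sites (illustrative only; the
helper functions are hypothetical):

```cpp
// Fatal path: aborts with the CUDA error string if cudaMalloc fails.
void* device_alloc(size_t nbytes) {
  void* ptr = nullptr;
  ET_CUDA_CHECK(cudaMalloc(&ptr, nbytes));
  return ptr;
}

// Non-fatal path: logs the error and continues, safe in teardown code.
void device_free(void* ptr) {
  ET_CUDA_LOG_WARN(cudaFree(ptr));
}
```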
diff --git a/backends/aoti/slim/c10/cuda/TARGETS b/backends/aoti/slim/c10/cuda/TARGETS
new file mode 100644
index 00000000000..08e83a5f3c4
--- /dev/null
+++ b/backends/aoti/slim/c10/cuda/TARGETS
@@ -0,0 +1,6 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
diff --git a/backends/aoti/slim/c10/cuda/targets.bzl b/backends/aoti/slim/c10/cuda/targets.bzl
new file mode 100644
index 00000000000..1d44bd1f032
--- /dev/null
+++ b/backends/aoti/slim/c10/cuda/targets.bzl
@@ -0,0 +1,16 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+def define_common_targets():
+    """Define targets for SlimTensor CUDA exception handling module."""
+
+    runtime.cxx_library(
+        name = "exception",
+        exported_headers = [
+            "Exception.h",
+        ],
+        visibility = ["@EXECUTORCH_CLIENTS"],
+        exported_deps = [
+            "//executorch/backends/aoti/slim/c10/macros:macros",
+            "//executorch/runtime/platform:platform",
+        ],
+    )
diff --git a/backends/aoti/slim/core/Storage.h b/backends/aoti/slim/core/Storage.h
index d122e86c1d4..121031a6d59 100644
--- a/backends/aoti/slim/core/Storage.h
+++ b/backends/aoti/slim/core/Storage.h
@@ -10,12 +10,18 @@
 
 #include
 
+#ifdef CUDA_AVAILABLE
+#include <executorch/backends/aoti/slim/c10/cuda/Exception.h>
+#include <executorch/backends/aoti/slim/cuda/Guard.h>
+#endif
+
 #include
 #include
 #include
 #include
 #include
 #include
+#include
 
 namespace executorch::backends::aoti::slim {
 
@@ -30,6 +36,10 @@ inline void noop(void*) {}
 /// Default CPU device constant.
 inline const c10::Device CPU_DEVICE = c10::Device(c10::DeviceType::CPU, 0);
 
+/// Default CUDA device constant.
+inline const c10::Device DEFAULT_CUDA_DEVICE =
+    c10::Device(c10::DeviceType::CUDA, 0);
+
 /// DeviceTraits template for device-specific operations.
 /// Device-specific implementations provide allocate(), free(), and memcpy().
 template <c10::DeviceType D>
@@ -74,6 +84,93 @@ struct DeviceTraits {
   }
 };
 
+#ifdef CUDA_AVAILABLE
+/// CUDA specialization of DeviceTraits.
+/// Provides CUDA memory allocation and copy operations using
+/// cudaMalloc/cudaFree.
+template <>
+struct DeviceTraits<c10::DeviceType::CUDA> {
+  /// Allocates CUDA device memory.
+  /// @param nbytes Number of bytes to allocate.
+  /// @param device The target CUDA device.
+  /// @return Pointer to allocated device memory.
+  static void* allocate(size_t nbytes, const c10::Device& device) {
+    cuda::CUDAGuard guard(device);
+    void* data = nullptr;
+    ET_CUDA_CHECK(cudaMalloc(&data, nbytes));
+    return data;
+  }
+
+  /// Frees CUDA device memory.
+  /// @param ptr Pointer to device memory to free.
+  static void free(void* ptr) {
+    ET_CUDA_LOG_WARN(cudaFree(ptr));
+  }
+
+  /// Copies memory between CPU and CUDA or CUDA and CUDA.
+  /// @param dst Destination pointer.
+  /// @param src Source pointer.
+  /// @param nbytes Number of bytes to copy.
+  /// @param dst_device Destination device.
+  /// @param src_device Source device.
+  static void memcpy(
+      void* dst,
+      const void* src,
+      size_t nbytes,
+      const c10::Device& dst_device,
+      const c10::Device& src_device) {
+    cudaMemcpyKind direction = cudaMemcpyDeviceToDevice;
+    c10::Device cuda_device = dst_device;
+
+    if (src_device.is_cpu()) {
+      direction = cudaMemcpyHostToDevice;
+    } else if (dst_device.is_cpu()) {
+      direction = cudaMemcpyDeviceToHost;
+      cuda_device = src_device;
+    } else {
+      ET_CHECK_MSG(
+          src_device.index() == dst_device.index(),
+          "CUDA memcpy across different device indices not supported: %d != %d",
+          static_cast<int>(src_device.index()),
+          static_cast<int>(dst_device.index()));
+    }
+
+    cuda::CUDAGuard guard(cuda_device);
+    ET_CUDA_CHECK(cudaMemcpy(dst, src, nbytes, direction));
+  }
+};
+#else
+/// CUDA stub when CUDA_AVAILABLE is not defined.
+/// All operations abort with an error message.
+template <>
+struct DeviceTraits<c10::DeviceType::CUDA> {
+  static void* allocate(size_t nbytes, const c10::Device& device) {
+    (void)nbytes;
+    (void)device;
+    ET_CHECK_MSG(false, "Build with CUDA_AVAILABLE=1 to enable CUDA support");
+  }
+
+  static void free(void* ptr) {
+    (void)ptr;
+    ET_LOG(Error, "Build with CUDA_AVAILABLE=1 to enable CUDA support");
+  }
+
+  static void memcpy(
+      void* dst,
+      const void* src,
+      size_t nbytes,
+      const c10::Device& dst_device,
+      const c10::Device& src_device) {
+    (void)dst;
+    (void)src;
+    (void)nbytes;
+    (void)dst_device;
+    (void)src_device;
+    ET_CHECK_MSG(false, "Build with CUDA_AVAILABLE=1 to enable CUDA support");
+  }
+};
+#endif // CUDA_AVAILABLE
+
 /**
  * MaybeOwningStorage - A storage class that manages tensor data memory.
  *
@@ -93,17 +190,19 @@ struct DeviceTraits {
 class MaybeOwningStorage {
  public:
   /// Constructs owning storage with allocated memory.
-  /// @param device The device for storage (must be CPU).
+  /// @param device The device for storage (CPU or CUDA).
   /// @param nbytes Number of bytes to allocate.
   MaybeOwningStorage(const c10::Device& device, size_t nbytes)
       : device_(device), capacity_(nbytes), is_owning_(true) {
-    ET_CHECK_MSG(
-        device.is_cpu(),
-        "Only CPU device is currently supported, got: %s",
-        device.str().c_str());
-
-    data_ = DeviceTraits<c10::DeviceType::CPU>::allocate(nbytes, device);
-    deleter_ = DeviceTraits<c10::DeviceType::CPU>::free;
+    if (device.is_cpu()) {
+      data_ = DeviceTraits<c10::DeviceType::CPU>::allocate(nbytes, device);
+      deleter_ = DeviceTraits<c10::DeviceType::CPU>::free;
+    } else if (device.is_cuda()) {
+      data_ = DeviceTraits<c10::DeviceType::CUDA>::allocate(nbytes, device);
+      deleter_ = DeviceTraits<c10::DeviceType::CUDA>::free;
+    } else {
+      ET_CHECK_MSG(false, "Unsupported device type: %s", device.str().c_str());
+    }
   }
 
   /// Default constructor is deleted - storage must have a device.
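Reviewer note: the CUDA `memcpy` derives both the `cudaMemcpyKind` and the
device to guard from its two `c10::Device` arguments. A sketch of the three
supported combinations (illustrative; the pointer and size names are
hypothetical):

```cpp
// Host -> device: guards dst_device, uses cudaMemcpyHostToDevice.
DeviceTraits<c10::DeviceType::CUDA>::memcpy(
    dev_dst, host_src, nbytes, DEFAULT_CUDA_DEVICE, CPU_DEVICE);

// Device -> host: guards src_device, uses cudaMemcpyDeviceToHost.
DeviceTraits<c10::DeviceType::CUDA>::memcpy(
    host_dst, dev_src, nbytes, CPU_DEVICE, DEFAULT_CUDA_DEVICE);

// Device -> device: both indices must match; uses cudaMemcpyDeviceToDevice.
DeviceTraits<c10::DeviceType::CUDA>::memcpy(
    dev_dst, dev_src, nbytes, DEFAULT_CUDA_DEVICE, DEFAULT_CUDA_DEVICE);
```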
diff --git a/backends/aoti/slim/core/targets.bzl b/backends/aoti/slim/core/targets.bzl
index 2056b8c6866..d0ee397c112 100644
--- a/backends/aoti/slim/core/targets.bzl
+++ b/backends/aoti/slim/core/targets.bzl
@@ -17,10 +17,12 @@ def define_common_targets():
             "//executorch/backends/aoti/slim/util:shared_ptr",
             "//executorch/backends/aoti/slim/util:size_util",
             "//executorch/runtime/platform:platform",
+            "//executorch/backends/aoti/slim/c10/cuda:exception",
+            "//executorch/backends/aoti/slim/cuda:guard",
         ],
     )
 
-    # Header-only library for SlimTensor
+    # Header-only library for SlimTensor (CPU-only for now)
     runtime.cxx_library(
         name = "slimtensor",
         headers = [
diff --git a/backends/aoti/slim/core/test/targets.bzl b/backends/aoti/slim/core/test/targets.bzl
index c7debd46836..3a7e99dd37c 100644
--- a/backends/aoti/slim/core/test/targets.bzl
+++ b/backends/aoti/slim/core/test/targets.bzl
@@ -1,17 +1,36 @@
+load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
+def get_backend_mode():
+    """Get the supported backend modes of SlimTensor."""
+    return ["cuda", "cpu"]
+
 def define_common_targets():
     """Define test targets for SlimTensor core module."""
 
-    runtime.cxx_test(
-        name = "test_storage",
-        srcs = [
-            "test_storage.cpp",
-        ],
-        deps = [
-            "//executorch/backends/aoti/slim/core:storage",
-        ],
-    )
+    # Storage tests, generated once per backend mode (CPU and CUDA)
+    for backend_mode in get_backend_mode():
+        backend_suffix = "_" + backend_mode if backend_mode == "cuda" else ""
+
+        backend_kwargs = {
+            "external_deps": [("cuda", None, "cuda-lazy")],
+            "preprocessor_flags": ["-DCUDA_AVAILABLE=1"],
+            "keep_gpu_sections": True,
+            "remote_execution": re_test_utils.remote_execution(
+                platform = "gpu-remote-execution",
+            ),
+        } if backend_mode == "cuda" else {}
+
+        runtime.cxx_test(
+            name = "test_storage" + backend_suffix,
+            srcs = [
+                "test_storage.cpp",
+            ],
+            deps = [
+                "//executorch/backends/aoti/slim/core:storage",
+            ],
+            **backend_kwargs
+        )
 
     runtime.cxx_test(
         name = "test_slimtensor_basic",
diff --git a/backends/aoti/slim/core/test/test_storage.cpp b/backends/aoti/slim/core/test/test_storage.cpp
index bf92eb46a72..8a5e78ba058 100644
--- a/backends/aoti/slim/core/test/test_storage.cpp
+++ b/backends/aoti/slim/core/test/test_storage.cpp
@@ -10,8 +10,29 @@
 
 #include
 
+#ifdef CUDA_AVAILABLE
+#include <cuda_runtime.h>
+#endif
+
 namespace executorch::backends::aoti::slim {
 
+// =============================================================================
+// Test Device Helpers
+// =============================================================================
+
+inline std::vector<c10::Device> getTestDevices() {
+  std::vector<c10::Device> devices = {CPU_DEVICE};
+#ifdef CUDA_AVAILABLE
+  devices.push_back(DEFAULT_CUDA_DEVICE);
+#endif
+  return devices;
+}
+
+inline std::string deviceToString(
+    const testing::TestParamInfo<c10::Device>& info) {
+  return info.param.is_cpu() ? "CPU" : "CUDA";
"CPU" : "CUDA"; +} + // ============================================================================= // DeviceTraits Tests // ============================================================================= @@ -52,48 +73,39 @@ TEST(DeviceTraitsCPUTest, MemcpyCPUToCPU) { } // ============================================================================= -// MaybeOwningStorage Tests - Owning Mode +// MaybeOwningStorage Parameterized Tests (CPU and CUDA) // ============================================================================= -TEST(MaybeOwningStorageTest, ConstructOwning) { +class MaybeOwningStorageParamTest : public testing::TestWithParam { + protected: + c10::Device device() const { + return GetParam(); + } +}; + +TEST_P(MaybeOwningStorageParamTest, ConstructOwning) { constexpr size_t kNbytes = 512; - MaybeOwningStorage storage(CPU_DEVICE, kNbytes); + MaybeOwningStorage storage(device(), kNbytes); EXPECT_NE(storage.data(), nullptr); EXPECT_EQ(storage.nbytes(), kNbytes); - EXPECT_TRUE(storage.device().is_cpu()); + EXPECT_EQ(storage.device().type(), device().type()); EXPECT_TRUE(storage.is_owning()); EXPECT_TRUE(storage.is_resizable()); } -TEST(MaybeOwningStorageTest, ConstructOwningZeroBytes) { - MaybeOwningStorage storage(CPU_DEVICE, 0); +TEST_P(MaybeOwningStorageParamTest, ConstructOwningZeroBytes) { + MaybeOwningStorage storage(device(), 0); EXPECT_EQ(storage.data(), nullptr); EXPECT_EQ(storage.nbytes(), 0); - EXPECT_TRUE(storage.device().is_cpu()); + EXPECT_EQ(storage.device().type(), device().type()); EXPECT_TRUE(storage.is_owning()); } -TEST(MaybeOwningStorageTest, DataPersistence) { - constexpr size_t kNumFloats = 64; - constexpr size_t kNbytes = kNumFloats * sizeof(float); - MaybeOwningStorage storage(CPU_DEVICE, kNbytes); - - float* data = static_cast(storage.data()); - for (size_t i = 0; i < kNumFloats; ++i) { - data[i] = static_cast(i) * 2.0f; - } - - float* read_data = static_cast(storage.data()); - for (size_t i = 0; i < kNumFloats; ++i) { - EXPECT_FLOAT_EQ(read_data[i], static_cast(i) * 2.0f); - } -} - -TEST(MaybeOwningStorageTest, MoveConstruct) { +TEST_P(MaybeOwningStorageParamTest, MoveConstruct) { constexpr size_t kNbytes = 256; - MaybeOwningStorage original(CPU_DEVICE, kNbytes); + MaybeOwningStorage original(device(), kNbytes); void* original_data = original.data(); MaybeOwningStorage moved(std::move(original)); @@ -101,17 +113,18 @@ TEST(MaybeOwningStorageTest, MoveConstruct) { EXPECT_EQ(moved.data(), original_data); EXPECT_EQ(moved.nbytes(), kNbytes); EXPECT_TRUE(moved.is_owning()); + EXPECT_EQ(moved.device().type(), device().type()); EXPECT_EQ(original.data(), nullptr); EXPECT_EQ(original.nbytes(), 0); EXPECT_FALSE(original.is_owning()); } -TEST(MaybeOwningStorageTest, MoveAssign) { +TEST_P(MaybeOwningStorageParamTest, MoveAssign) { constexpr size_t kNbytes1 = 256; constexpr size_t kNbytes2 = 512; - MaybeOwningStorage storage1(CPU_DEVICE, kNbytes1); - MaybeOwningStorage storage2(CPU_DEVICE, kNbytes2); + MaybeOwningStorage storage1(device(), kNbytes1); + MaybeOwningStorage storage2(device(), kNbytes2); void* storage2_data = storage2.data(); storage1 = std::move(storage2); @@ -125,7 +138,33 @@ TEST(MaybeOwningStorageTest, MoveAssign) { EXPECT_FALSE(storage2.is_owning()); } -TEST(MaybeOwningStorageTest, Clone) { +INSTANTIATE_TEST_SUITE_P( + DeviceTests, + MaybeOwningStorageParamTest, + testing::ValuesIn(getTestDevices()), + deviceToString); + +// ============================================================================= +// MaybeOwningStorage CPU-Only Tests 
+// =============================================================================
+
+TEST(MaybeOwningStorageCPUTest, DataPersistence) {
+  constexpr size_t kNumFloats = 64;
+  constexpr size_t kNbytes = kNumFloats * sizeof(float);
+  MaybeOwningStorage storage(CPU_DEVICE, kNbytes);
+
+  float* data = static_cast<float*>(storage.data());
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    data[i] = static_cast<float>(i) * 2.0f;
+  }
+
+  float* read_data = static_cast<float*>(storage.data());
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    EXPECT_FLOAT_EQ(read_data[i], static_cast<float>(i) * 2.0f);
+  }
+}
+
+TEST(MaybeOwningStorageCPUTest, Clone) {
   constexpr size_t kNumFloats = 32;
   constexpr size_t kNbytes = kNumFloats * sizeof(float);
   MaybeOwningStorage original(CPU_DEVICE, kNbytes);
@@ -150,7 +189,7 @@ TEST(MaybeOwningStorageTest, Clone) {
   EXPECT_FLOAT_EQ(cloned_data[0], 0.0f);
 }
 
-TEST(MaybeOwningStorageTest, CopyFunction) {
+TEST(MaybeOwningStorageCPUTest, CopyFunction) {
   constexpr size_t kNumFloats = 16;
   constexpr size_t kNbytes = kNumFloats * sizeof(float);
   MaybeOwningStorage src_storage(CPU_DEVICE, kNbytes);
@@ -171,26 +210,33 @@ TEST(MaybeOwningStorageTest, CopyFunction) {
 }
 
 // =============================================================================
-// Storage (SharedPtr) Tests
+// Storage (SharedPtr) Parameterized Tests
 // =============================================================================
 
-TEST(StorageSharedPtrTest, BasicUsage) {
+class StorageSharedPtrParamTest : public testing::TestWithParam<c10::Device> {
+ protected:
+  c10::Device device() const {
+    return GetParam();
+  }
+};
+
+TEST_P(StorageSharedPtrParamTest, BasicUsage) {
   constexpr size_t kNbytes = 128;
-  Storage storage(new MaybeOwningStorage(CPU_DEVICE, kNbytes));
+  Storage storage(new MaybeOwningStorage(device(), kNbytes));
 
   EXPECT_NE(storage.get(), nullptr);
   EXPECT_NE(storage->data(), nullptr);
   EXPECT_EQ(storage->nbytes(), kNbytes);
-  EXPECT_TRUE(storage->device().is_cpu());
+  EXPECT_EQ(storage->device().type(), device().type());
   EXPECT_EQ(storage.use_count(), 1);
 }
 
-TEST(StorageSharedPtrTest, SharedOwnership) {
+TEST_P(StorageSharedPtrParamTest, SharedOwnership) {
   constexpr size_t kNbytes = 128;
-  Storage storage1(new MaybeOwningStorage(CPU_DEVICE, kNbytes));
+  Storage storage1(new MaybeOwningStorage(device(), kNbytes));
   void* data_ptr = storage1->data();
 
-  const Storage& storage2 = storage1;
+  Storage storage2 = storage1;
 
   EXPECT_EQ(storage1.use_count(), 2);
   EXPECT_EQ(storage2.use_count(), 2);
@@ -198,7 +244,52 @@
   EXPECT_EQ(storage1->data(), data_ptr);
   EXPECT_EQ(storage2->data(), data_ptr);
 }
 
-TEST(StorageSharedPtrTest, SharedOwnershipModification) {
+TEST_P(StorageSharedPtrParamTest, ReferenceCountDecrement) {
+  constexpr size_t kNbytes = 64;
+  Storage storage1(new MaybeOwningStorage(device(), kNbytes));
+  EXPECT_EQ(storage1.use_count(), 1);
+
+  {
+    Storage storage2 = storage1;
+    EXPECT_EQ(storage1.use_count(), 2);
+  }
+
+  EXPECT_EQ(storage1.use_count(), 1);
+}
+
+TEST_P(StorageSharedPtrParamTest, MoveSemantics) {
+  constexpr size_t kNbytes = 64;
+  Storage storage1(new MaybeOwningStorage(device(), kNbytes));
+  void* data_ptr = storage1->data();
+
+  Storage storage2 = std::move(storage1);
+
+  EXPECT_EQ(storage1.get(), nullptr);
+  EXPECT_EQ(storage2->data(), data_ptr);
+  EXPECT_EQ(storage2.use_count(), 1);
+}
+
+TEST_P(StorageSharedPtrParamTest, MakeShared) {
+  constexpr size_t kNbytes = 256;
+  Storage storage = make_shared<MaybeOwningStorage>(device(), kNbytes);
+
+  EXPECT_NE(storage.get(), nullptr);
+  EXPECT_NE(storage->data(), nullptr);
+  EXPECT_EQ(storage->nbytes(), kNbytes);
+  EXPECT_EQ(storage.use_count(), 1);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    DeviceTests,
+    StorageSharedPtrParamTest,
+    testing::ValuesIn(getTestDevices()),
+    deviceToString);
+
+// =============================================================================
+// Storage CPU-Only Tests (require direct data access)
+// =============================================================================
+
+TEST(StorageSharedPtrCPUTest, SharedOwnershipModification) {
   constexpr size_t kNumFloats = 8;
   constexpr size_t kNbytes = kNumFloats * sizeof(float);
   Storage storage1(new MaybeOwningStorage(CPU_DEVICE, kNbytes));
@@ -208,7 +299,7 @@
     data[i] = 0.0f;
   }
 
-  const Storage& storage2 = storage1;
+  Storage storage2 = storage1;
 
   float* data2 = static_cast<float*>(storage2->data());
   for (size_t i = 0; i < kNumFloats; ++i) {
@@ -221,36 +312,156 @@
 }
 
-TEST(StorageSharedPtrTest, ReferenceCountDecrement) {
-  constexpr size_t kNbytes = 64;
-  Storage storage1(new MaybeOwningStorage(CPU_DEVICE, kNbytes));
-  EXPECT_EQ(storage1.use_count(), 1);
+#ifdef CUDA_AVAILABLE
 
-  {
-    Storage storage2 = storage1;
-    EXPECT_EQ(storage1.use_count(), 2);
-  }
+// =============================================================================
+// DeviceTraits Tests
+// =============================================================================
 
-  EXPECT_EQ(storage1.use_count(), 1);
+TEST(DeviceTraitsCUDATest, AllocateAndFree) {
+  constexpr size_t kSize = 1024;
+  void* ptr =
+      DeviceTraits<c10::DeviceType::CUDA>::allocate(kSize, DEFAULT_CUDA_DEVICE);
+  ASSERT_NE(ptr, nullptr);
+
+  DeviceTraits<c10::DeviceType::CUDA>::free(ptr);
 }
 
-TEST(StorageSharedPtrTest, MoveSemantics) {
-  constexpr size_t kNbytes = 64;
-  Storage storage1(new MaybeOwningStorage(CPU_DEVICE, kNbytes));
-  void* data_ptr = storage1->data();
+TEST(DeviceTraitsCUDATest, AllocateZeroBytes) {
+  void* ptr =
+      DeviceTraits<c10::DeviceType::CUDA>::allocate(0, DEFAULT_CUDA_DEVICE);
+  DeviceTraits<c10::DeviceType::CUDA>::free(ptr);
+}
 
-  Storage storage2 = std::move(storage1);
+TEST(DeviceTraitsCUDATest, MemcpyCPUToCUDA) {
+  constexpr size_t kSize = 256;
+  float* cpu_src = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kSize * sizeof(float)));
+  float* cuda_dst =
+      static_cast<float*>(DeviceTraits<c10::DeviceType::CUDA>::allocate(
+          kSize * sizeof(float), DEFAULT_CUDA_DEVICE));
+  float* cpu_verify = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kSize * sizeof(float)));
 
-  EXPECT_EQ(storage1.get(), nullptr);
-  EXPECT_EQ(storage2->data(), data_ptr);
-  EXPECT_EQ(storage2.use_count(), 1);
+  for (size_t i = 0; i < kSize; ++i) {
+    cpu_src[i] = static_cast<float>(i) * 2.5f;
+  }
+
+  // Copy CPU -> CUDA
+  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+      cuda_dst,
+      cpu_src,
+      kSize * sizeof(float),
+      DEFAULT_CUDA_DEVICE,
+      CPU_DEVICE);
+
+  // Copy CUDA -> CPU to verify
+  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+      cpu_verify,
+      cuda_dst,
+      kSize * sizeof(float),
+      CPU_DEVICE,
+      DEFAULT_CUDA_DEVICE);
+
+  for (size_t i = 0; i < kSize; ++i) {
+    EXPECT_FLOAT_EQ(cpu_verify[i], static_cast<float>(i) * 2.5f);
+  }
+
+  DeviceTraits<c10::DeviceType::CPU>::free(cpu_src);
+  DeviceTraits<c10::DeviceType::CUDA>::free(cuda_dst);
+  DeviceTraits<c10::DeviceType::CPU>::free(cpu_verify);
 }
 
-TEST(StorageSharedPtrTest, MakeShared) {
-  constexpr size_t kNbytes = 256;
-  Storage storage = make_shared<MaybeOwningStorage>(CPU_DEVICE, kNbytes);
+TEST(DeviceTraitsCUDATest, MemcpyCUDAToCPU) {
+  constexpr size_t kSize = 128;
+  float* cpu_src = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kSize * sizeof(float)));
+  float* cuda_mem =
+      static_cast<float*>(DeviceTraits<c10::DeviceType::CUDA>::allocate(
+          kSize * sizeof(float), DEFAULT_CUDA_DEVICE));
+  float* cpu_dst = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kSize * sizeof(float)));
-
-  EXPECT_NE(storage.get(), nullptr);
-  EXPECT_NE(storage->data(), nullptr);
-  EXPECT_EQ(storage->nbytes(), kNbytes);
-  EXPECT_EQ(storage.use_count(), 1);
+  for (size_t i = 0; i < kSize; ++i) {
+    cpu_src[i] = static_cast<float>(i) + 100.0f;
+  }
+
+  // Copy CPU -> CUDA
+  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+      cuda_mem,
+      cpu_src,
+      kSize * sizeof(float),
+      DEFAULT_CUDA_DEVICE,
+      CPU_DEVICE);
+
+  // Copy CUDA -> CPU
+  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+      cpu_dst,
+      cuda_mem,
+      kSize * sizeof(float),
+      CPU_DEVICE,
+      DEFAULT_CUDA_DEVICE);
+
+  for (size_t i = 0; i < kSize; ++i) {
+    EXPECT_FLOAT_EQ(cpu_dst[i], static_cast<float>(i) + 100.0f);
+  }
+
+  DeviceTraits<c10::DeviceType::CPU>::free(cpu_src);
+  DeviceTraits<c10::DeviceType::CUDA>::free(cuda_mem);
+  DeviceTraits<c10::DeviceType::CPU>::free(cpu_dst);
+}
+
+TEST(DeviceTraitsCUDATest, MemcpyCUDAToCUDA) {
+  constexpr size_t kSize = 64;
+  float* cpu_src = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kSize * sizeof(float)));
+  float* cuda_src =
+      static_cast<float*>(DeviceTraits<c10::DeviceType::CUDA>::allocate(
+          kSize * sizeof(float), DEFAULT_CUDA_DEVICE));
+  float* cuda_dst =
+      static_cast<float*>(DeviceTraits<c10::DeviceType::CUDA>::allocate(
+          kSize * sizeof(float), DEFAULT_CUDA_DEVICE));
+  float* cpu_verify = static_cast<float*>(
+      DeviceTraits<c10::DeviceType::CPU>::allocate(kSize * sizeof(float)));
+
+  for (size_t i = 0; i < kSize; ++i) {
+    cpu_src[i] = static_cast<float>(i) * 3.0f;
+  }
+
+  // Copy CPU -> CUDA src
+  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+      cuda_src,
+      cpu_src,
+      kSize * sizeof(float),
+      DEFAULT_CUDA_DEVICE,
+      CPU_DEVICE);
+
+  // Copy CUDA src -> CUDA dst
+  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+      cuda_dst,
+      cuda_src,
+      kSize * sizeof(float),
+      DEFAULT_CUDA_DEVICE,
+      DEFAULT_CUDA_DEVICE);
+
+  // Copy CUDA dst -> CPU to verify
+  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+      cpu_verify,
+      cuda_dst,
+      kSize * sizeof(float),
+      CPU_DEVICE,
+      DEFAULT_CUDA_DEVICE);
+
+  for (size_t i = 0; i < kSize; ++i) {
+    EXPECT_FLOAT_EQ(cpu_verify[i], static_cast<float>(i) * 3.0f);
+  }
+
+  DeviceTraits<c10::DeviceType::CPU>::free(cpu_src);
+  DeviceTraits<c10::DeviceType::CUDA>::free(cuda_src);
+  DeviceTraits<c10::DeviceType::CUDA>::free(cuda_dst);
+  DeviceTraits<c10::DeviceType::CPU>::free(cpu_verify);
 }
 
+#endif // CUDA_AVAILABLE
+
 } // namespace executorch::backends::aoti::slim
diff --git a/backends/aoti/slim/cuda/Guard.h b/backends/aoti/slim/cuda/Guard.h
new file mode 100644
index 00000000000..f553e0dcc4b
--- /dev/null
+++ b/backends/aoti/slim/cuda/Guard.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#ifdef CUDA_AVAILABLE
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <executorch/backends/aoti/slim/c10/core/Device.h>
+#include <executorch/backends/aoti/slim/c10/cuda/Exception.h>
+
+namespace executorch::backends::aoti::slim::cuda {
+
+/**
+ * CUDAGuard - RAII class that sets the current CUDA device.
+ *
+ * This class saves the current CUDA device on construction and restores it
+ * on destruction, providing exception-safe device switching.
+ *
+ * Thread Safety: NOT THREAD-SAFE
+ * - Must only be used within a single thread
+ */
+struct CUDAGuard {
+  /// No default constructor - device must be specified.
+  CUDAGuard() = delete;
+
+  /// Sets the current CUDA device to the specified device index.
+  /// @param device_index The CUDA device index to switch to.
+  explicit CUDAGuard(c10::DeviceIndex device_index) {
+    set_index(device_index);
+  }
+
+  /// Sets the current CUDA device to the specified device.
+  /// @param device The CUDA device to switch to. Must be a CUDA device.
+ explicit CUDAGuard(c10::Device device) { + ET_CHECK_MSG(device.is_cuda(), "Expected a CUDA device for CUDAGuard"); + set_index(device.index()); + } + + // Copy is not allowed + CUDAGuard(const CUDAGuard&) = delete; + CUDAGuard& operator=(const CUDAGuard&) = delete; + + // Move is not allowed + CUDAGuard(CUDAGuard&& other) = delete; + CUDAGuard& operator=(CUDAGuard&& other) = delete; + + /// Restores the original CUDA device on destruction. + ~CUDAGuard() { + if (original_device_index_ != current_device_index_) { + ET_CUDA_LOG_WARN(cudaSetDevice(original_device_index_)); + } + } + + /// Sets the CUDA device to the given device index. + /// @param device_index The device index to switch to. + void set_index(c10::DeviceIndex device_index) { + int orig_index = -1; + ET_CUDA_CHECK(cudaGetDevice(&orig_index)); + + original_device_index_ = orig_index; + current_device_index_ = device_index; + if (current_device_index_ != original_device_index_) { + ET_CUDA_CHECK(cudaSetDevice(current_device_index_)); + } + } + + private: + c10::DeviceIndex original_device_index_; + c10::DeviceIndex current_device_index_; +}; + +} // namespace executorch::backends::aoti::slim::cuda + +#endif // CUDA_AVAILABLE diff --git a/backends/aoti/slim/cuda/TARGETS b/backends/aoti/slim/cuda/TARGETS new file mode 100644 index 00000000000..08e83a5f3c4 --- /dev/null +++ b/backends/aoti/slim/cuda/TARGETS @@ -0,0 +1,6 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/aoti/slim/cuda/targets.bzl b/backends/aoti/slim/cuda/targets.bzl new file mode 100644 index 00000000000..059d27034ae --- /dev/null +++ b/backends/aoti/slim/cuda/targets.bzl @@ -0,0 +1,16 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Define targets for SlimTensor CUDA guard module.""" + + runtime.cxx_library( + name = "guard", + exported_headers = [ + "Guard.h", + ], + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + "//executorch/backends/aoti/slim/c10/core:device", + "//executorch/backends/aoti/slim/c10/cuda:exception", + ], + )
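Reviewer note: CUDAGuard's save/restore behavior in one sketch (illustrative
only; assumes a CUDA_AVAILABLE build and a machine with at least two CUDA
devices):

```cpp
{
  // Suppose the current device is 0 here.
  cuda::CUDAGuard guard(/*device_index=*/1);  // calls cudaSetDevice(1)
  // Work issued here (cudaMalloc, cudaMemcpy, ...) targets device 1.
}  // ~CUDAGuard() calls cudaSetDevice(0); skipped if the index never changed
```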