diff --git a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
index edd2ab2c5..074724b7e 100644
--- a/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
+++ b/cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
@@ -212,7 +212,13 @@ cdef class ParamHolder:
         for i, arg in enumerate(kernel_args):
             if isinstance(arg, Buffer):
                 # we need the address of where the actual buffer address is stored
-                self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
+                if isinstance(arg.handle, int):
+                    # see note below on handling int arguments
+                    prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
+                    continue
+                else:
+                    # it's a CUdeviceptr:
+                    self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
                 continue
             elif isinstance(arg, int):
                 # Here's the dilemma: We want to have a fast path to pass in Python
diff --git a/cuda_core/examples/memory_ops.py b/cuda_core/examples/memory_ops.py
new file mode 100644
index 000000000..ceff29dd3
--- /dev/null
+++ b/cuda_core/examples/memory_ops.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# ################################################################################
+#
+# This demo illustrates:
+#
+# 1. How to use different memory resources to allocate and manage memory
+# 2. How to copy data between different memory types
+# 3. How to use DLPack to interoperate with other libraries
+#
+# ################################################################################
+
+import sys
+
+import cupy as cp
+import numpy as np
+
+from cuda.core.experimental import (
+    Device,
+    LaunchConfig,
+    LegacyPinnedMemoryResource,
+    Program,
+    ProgramOptions,
+    launch,
+)
+
+if np.__version__ < "2.1.0":
+    print("This example requires NumPy 2.1.0 or later", file=sys.stderr)
+    sys.exit(0)
+
+# Kernel for memory operations
+code = """
+extern "C"
+__global__ void memory_ops(float* device_data,
+                           float* pinned_data,
+                           size_t N) {
+    const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
+    if (tid < N) {
+        // Access device memory
+        device_data[tid] = device_data[tid] + 1.0f;
+
+        // Access pinned memory (zero-copy from GPU)
+        pinned_data[tid] = pinned_data[tid] * 3.0f;
+    }
+}
+"""
+
+dev = Device()
+dev.set_current()
+stream = dev.create_stream()
+# tell CuPy to use our stream as the current stream:
+cp.cuda.ExternalStream(int(stream.handle)).use()
+
+# Compile kernel
+arch = "".join(f"{i}" for i in dev.compute_capability)
+program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
+prog = Program(code, code_type="c++", options=program_options)
+mod = prog.compile("cubin")
+kernel = mod.get_kernel("memory_ops")
+
+# Create different memory resources
+device_mr = dev.memory_resource
+pinned_mr = LegacyPinnedMemoryResource()
+
+# Allocate different types of memory
+size = 1024
+dtype = cp.float32
+element_size = dtype().itemsize
+total_size = size * element_size
+
+# 1. Device Memory (GPU-only)
+device_buffer = device_mr.allocate(total_size, stream=stream)
+device_array = cp.from_dlpack(device_buffer).view(dtype=dtype)
+
+# 2. Pinned Memory (CPU memory, GPU accessible)
+pinned_buffer = pinned_mr.allocate(total_size, stream=stream)
+pinned_array = np.from_dlpack(pinned_buffer).view(dtype=dtype)
+
+# Initialize data
+rng = cp.random.default_rng()
+device_array[:] = rng.random(size, dtype=dtype)
+pinned_array[:] = rng.random(size, dtype=dtype).get()
+
+# Store original values for verification
+device_original = device_array.copy()
+pinned_original = pinned_array.copy()
+
+# Sync before kernel launch
+stream.sync()
+
+# Launch kernel
+block = 256
+grid = (size + block - 1) // block
+config = LaunchConfig(grid=grid, block=block)
+
+launch(stream, config, kernel, device_buffer, pinned_buffer, cp.uint64(size))
+stream.sync()
+
+# Verify kernel operations
+assert cp.allclose(device_array, device_original + 1.0), "Device memory operation failed"
+assert cp.allclose(pinned_array, pinned_original * 3.0), "Pinned memory operation failed"
+
+# Copy data between different memory types
+print("\nCopying data between memory types...")
+
+# Copy from device to pinned memory
+device_buffer.copy_to(pinned_buffer, stream=stream)
+stream.sync()
+
+# Verify the copy operation
+assert cp.allclose(pinned_array, device_array), "Device to pinned copy failed"
+
+# Create a new device buffer and copy from pinned
+new_device_buffer = device_mr.allocate(total_size, stream=stream)
+new_device_array = cp.from_dlpack(new_device_buffer).view(dtype=dtype)
+
+pinned_buffer.copy_to(new_device_buffer, stream=stream)
+stream.sync()
+
+# Verify the copy operation
+assert cp.allclose(new_device_array, pinned_array), "Pinned to device copy failed"
+
+# Clean up
+device_buffer.close(stream)
+pinned_buffer.close(stream)
+new_device_buffer.close(stream)
+stream.close()
+cp.cuda.Stream.null.use()  # reset CuPy's current stream to the null stream
+
+# Verify buffers are properly closed
+assert device_buffer.handle == 0, "Device buffer should be closed"
+assert pinned_buffer.handle == 0, "Pinned buffer should be closed"
+assert new_device_buffer.handle == 0, "New device buffer should be closed"
+
+print("Memory management example completed!")
diff --git a/cuda_core/tests/test_launcher.py b/cuda_core/tests/test_launcher.py
index 3a02065de..a6648d8a4 100644
--- a/cuda_core/tests/test_launcher.py
+++ b/cuda_core/tests/test_launcher.py
@@ -5,11 +5,21 @@
 import os
 import pathlib
 
+import cupy as cp
 import numpy as np
 import pytest
 from conftest import skipif_need_cuda_headers
 
-from cuda.core.experimental import Device, LaunchConfig, LegacyPinnedMemoryResource, Program, ProgramOptions, launch
+from cuda.core.experimental import (
+    Device,
+    DeviceMemoryResource,
+    LaunchConfig,
+    LegacyPinnedMemoryResource,
+    Program,
+    ProgramOptions,
+    launch,
+)
+from cuda.core.experimental._memory import _SynchronousMemoryResource
 
 
 def test_launch_config_init(init_cuda):
@@ -197,3 +207,102 @@ def test_cooperative_launch():
     config = LaunchConfig(grid=1, block=1, cooperative_launch=True)
     launch(s, config, ker)
     s.sync()
+
+
+@pytest.mark.parametrize(
+    "memory_resource_class",
+    [
+        "device_memory_resource",  # kludgy, but can go away after #726 is resolved
+        pytest.param(
+            LegacyPinnedMemoryResource,
+            marks=pytest.mark.skipif(
+                tuple(int(i) for i in np.__version__.split(".")[:3]) < (2, 2, 5),
+                reason="need numpy 2.2.5+, numpy GH #28632",
+            ),
+        ),
+    ],
+)
+def test_launch_with_buffers_allocated_by_memory_resource(init_cuda, memory_resource_class):
+    """Test that kernels can access memory allocated by memory resources."""
+    dev = Device()
+    dev.set_current()
+    stream = dev.create_stream()
+    # tell CuPy to use our stream as the current stream:
+    cp.cuda.ExternalStream(int(stream.handle)).use()
+
+    # Kernel that operates on memory
+    code = """
+    extern "C"
+    __global__ void memory_ops(float* data, size_t N) {
+        const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
+        if (tid < N) {
+            // Access memory (device or pinned)
+            data[tid] = data[tid] * 3.0f;
+        }
+    }
+    """
+
+    # Compile kernel
+    arch = "".join(f"{i}" for i in dev.compute_capability)
+    program_options = ProgramOptions(std="c++17", arch=f"sm_{arch}")
+    prog = Program(code, code_type="c++", options=program_options)
+    mod = prog.compile("cubin")
+    kernel = mod.get_kernel("memory_ops")
+
+    # Create memory resource
+    if memory_resource_class == "device_memory_resource":
+        if dev.properties.memory_pools_supported:
+            mr = DeviceMemoryResource(dev.device_id)
+        else:
+            mr = _SynchronousMemoryResource(dev.device_id)
+    else:
+        mr = memory_resource_class()
+
+    # Allocate memory
+    size = 1024
+    dtype = np.float32
+    element_size = dtype().itemsize
+    total_size = size * element_size
+
+    buffer = mr.allocate(total_size, stream=stream)
+
+    # Create array view based on memory type
+    if mr.is_host_accessible:
+        # For pinned memory, use numpy
+        array = np.from_dlpack(buffer).view(dtype=dtype)
+    else:
+        array = cp.from_dlpack(buffer).view(dtype=dtype)
+
+    # Initialize data with random values
+    if mr.is_host_accessible:
+        rng = np.random.default_rng()
+        array[:] = rng.random(size, dtype=dtype)
+    else:
+        rng = cp.random.default_rng()
+        array[:] = rng.random(size, dtype=dtype)
+
+    # Store original values for verification
+    original = array.copy()
+
+    # Sync before kernel launch
+    stream.sync()
+
+    # Launch kernel
+    block = 256
+    grid = (size + block - 1) // block
+    config = LaunchConfig(grid=grid, block=block)
+
+    launch(stream, config, kernel, buffer, np.uint64(size))
+    stream.sync()
+
+    # Verify kernel operations
+    assert cp.allclose(array, original * 3.0), f"{memory_resource_class.__name__} operation failed"
+
+    # Clean up
+    buffer.close(stream)
+    stream.close()
+
+    cp.cuda.Stream.null.use()  # reset CuPy's current stream to the null stream
+
+    # Verify buffer is properly closed
+    assert buffer.handle == 0, f"{memory_resource_class.__name__} buffer should be closed"