diff --git a/unified-runtime/source/adapters/cuda/queue.cpp b/unified-runtime/source/adapters/cuda/queue.cpp index a6dc80bd9ce3c..3708b08c72397 100644 --- a/unified-runtime/source/adapters/cuda/queue.cpp +++ b/unified-runtime/source/adapters/cuda/queue.cpp @@ -16,32 +16,37 @@ #include #include -void ur_queue_handle_t_::computeStreamWaitForBarrierIfNeeded(CUstream Stream, - uint32_t StreamI) { +template <> +void cuda_stream_queue::computeStreamWaitForBarrierIfNeeded(CUstream Stream, + uint32_t StreamI) { if (BarrierEvent && !ComputeAppliedBarrier[StreamI]) { UR_CHECK_ERROR(cuStreamWaitEvent(Stream, BarrierEvent, 0)); ComputeAppliedBarrier[StreamI] = true; } } -void ur_queue_handle_t_::transferStreamWaitForBarrierIfNeeded( - CUstream Stream, uint32_t StreamI) { +template <> +void cuda_stream_queue::transferStreamWaitForBarrierIfNeeded(CUstream Stream, + uint32_t StreamI) { if (BarrierEvent && !TransferAppliedBarrier[StreamI]) { UR_CHECK_ERROR(cuStreamWaitEvent(Stream, BarrierEvent, 0)); TransferAppliedBarrier[StreamI] = true; } } -ur_queue_handle_t ur_queue_handle_t_::getEventQueue(const ur_event_handle_t e) { +template <> +ur_queue_handle_t cuda_stream_queue::getEventQueue(const ur_event_handle_t e) { return e->getQueue(); } +template <> uint32_t -ur_queue_handle_t_::getEventComputeStreamToken(const ur_event_handle_t e) { +cuda_stream_queue::getEventComputeStreamToken(const ur_event_handle_t e) { return e->getComputeStreamToken(); } -CUstream ur_queue_handle_t_::getEventStream(const ur_event_handle_t e) { +template <> +CUstream cuda_stream_queue::getEventStream(const ur_event_handle_t e) { return e->getStream(); } @@ -87,7 +92,7 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice, } Queue = std::unique_ptr(new ur_queue_handle_t_{ - IsOutOfOrder, hContext, hDevice, Flags, URFlags, Priority}); + {IsOutOfOrder, hContext, hDevice, Flags, URFlags, Priority}}); *phQueue = Queue.release(); @@ -203,8 +208,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle( pProperties ? pProperties->isNativeHandleOwned : false; // Create queue from a native stream - *phQueue = new ur_queue_handle_t_{CuStream, hContext, hDevice, - CuFlags, Flags, isNativeHandleOwned}; + *phQueue = new ur_queue_handle_t_{ + {CuStream, hContext, hDevice, CuFlags, Flags, isNativeHandleOwned}}; return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/cuda/queue.hpp b/unified-runtime/source/adapters/cuda/queue.hpp index f6369996b86af..636206c1159ab 100644 --- a/unified-runtime/source/adapters/cuda/queue.hpp +++ b/unified-runtime/source/adapters/cuda/queue.hpp @@ -19,38 +19,25 @@ #include -/// UR queue mapping on to CUstream objects. -/// -struct ur_queue_handle_t_ : stream_queue_t { - using stream_queue_t::stream_queue_t; - - CUevent BarrierEvent = nullptr; - CUevent BarrierTmpEvent = nullptr; - - void computeStreamWaitForBarrierIfNeeded(CUstream Strean, - uint32_t StreamI) override; - void transferStreamWaitForBarrierIfNeeded(CUstream Stream, - uint32_t StreamI) override; - ur_queue_handle_t getEventQueue(const ur_event_handle_t) override; - uint32_t getEventComputeStreamToken(const ur_event_handle_t) override; - CUstream getEventStream(const ur_event_handle_t) override; - - // Function which creates the profiling stream. Called only from makeNative - // event when profiling is required. - void createHostSubmitTimeStream() { - static std::once_flag HostSubmitTimeStreamFlag; - std::call_once(HostSubmitTimeStreamFlag, [&]() { - UR_CHECK_ERROR(cuStreamCreateWithPriority(&HostSubmitTimeStream, - CU_STREAM_NON_BLOCKING, 0)); - }); - } - - void createStreamWithPriority(CUstream *Stream, unsigned int Flags, - int Priority) override { - UR_CHECK_ERROR(cuStreamCreateWithPriority(Stream, Flags, Priority)); - } -}; +using cuda_stream_queue = stream_queue_t; +struct ur_queue_handle_t_ : public cuda_stream_queue {}; + +// Function which creates the profiling stream. Called only from makeNative +// event when profiling is required. +template <> inline void cuda_stream_queue::createHostSubmitTimeStream() { + static std::once_flag HostSubmitTimeStreamFlag; + std::call_once(HostSubmitTimeStreamFlag, [&]() { + UR_CHECK_ERROR(cuStreamCreateWithPriority(&HostSubmitTimeStream, + CU_STREAM_NON_BLOCKING, 0)); + }); +} + +template <> +inline void cuda_stream_queue::createStreamWithPriority(CUstream *Stream, + unsigned int Flags, + int Priority) { + UR_CHECK_ERROR(cuStreamCreateWithPriority(Stream, Flags, Priority)); +} // RAII object to make hQueue stream getter methods all return the same stream // within the lifetime of this object. diff --git a/unified-runtime/source/adapters/hip/queue.cpp b/unified-runtime/source/adapters/hip/queue.cpp index 0824a631b90dc..fcda95f1f587f 100644 --- a/unified-runtime/source/adapters/hip/queue.cpp +++ b/unified-runtime/source/adapters/hip/queue.cpp @@ -12,32 +12,37 @@ #include "context.hpp" #include "event.hpp" -void ur_queue_handle_t_::computeStreamWaitForBarrierIfNeeded( - hipStream_t Stream, uint32_t Stream_i) { +template <> +void hip_stream_queue::computeStreamWaitForBarrierIfNeeded(hipStream_t Stream, + uint32_t Stream_i) { if (BarrierEvent && !ComputeAppliedBarrier[Stream_i]) { UR_CHECK_ERROR(hipStreamWaitEvent(Stream, BarrierEvent, 0)); ComputeAppliedBarrier[Stream_i] = true; } } -void ur_queue_handle_t_::transferStreamWaitForBarrierIfNeeded( - hipStream_t Stream, uint32_t Stream_i) { +template <> +void hip_stream_queue::transferStreamWaitForBarrierIfNeeded(hipStream_t Stream, + uint32_t Stream_i) { if (BarrierEvent && !TransferAppliedBarrier[Stream_i]) { UR_CHECK_ERROR(hipStreamWaitEvent(Stream, BarrierEvent, 0)); TransferAppliedBarrier[Stream_i] = true; } } -ur_queue_handle_t ur_queue_handle_t_::getEventQueue(const ur_event_handle_t e) { +template <> +ur_queue_handle_t hip_stream_queue::getEventQueue(const ur_event_handle_t e) { return e->getQueue(); } +template <> uint32_t -ur_queue_handle_t_::getEventComputeStreamToken(const ur_event_handle_t e) { +hip_stream_queue::getEventComputeStreamToken(const ur_event_handle_t e) { return e->getComputeStreamToken(); } -hipStream_t ur_queue_handle_t_::getEventStream(const ur_event_handle_t e) { +template <> +hipStream_t hip_stream_queue::getEventStream(const ur_event_handle_t e) { return e->getStream(); } @@ -76,7 +81,7 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice, : false; QueueImpl = std::unique_ptr(new ur_queue_handle_t_{ - IsOutOfOrder, hContext, hDevice, Flags, URFlags, Priority}); + {IsOutOfOrder, hContext, hDevice, Flags, URFlags, Priority}}); *phQueue = QueueImpl.release(); @@ -238,8 +243,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle( // Create queue and set num_compute_streams to 1, as computeHIPStreams has // valid stream - *phQueue = new ur_queue_handle_t_{HIPStream, hContext, hDevice, - HIPFlags, Flags, isNativeHandleOwned}; + *phQueue = new ur_queue_handle_t_{ + {HIPStream, hContext, hDevice, HIPFlags, Flags, isNativeHandleOwned}}; return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/hip/queue.hpp b/unified-runtime/source/adapters/hip/queue.hpp index b06f681953962..6c3f2c8f0ef8e 100644 --- a/unified-runtime/source/adapters/hip/queue.hpp +++ b/unified-runtime/source/adapters/hip/queue.hpp @@ -16,38 +16,25 @@ #include -/// UR queue mapping on to hipStream_t objects. -/// -struct ur_queue_handle_t_ : stream_queue_t { - using stream_queue_t::stream_queue_t; - - hipEvent_t BarrierEvent = nullptr; - hipEvent_t BarrierTmpEvent = nullptr; - - void computeStreamWaitForBarrierIfNeeded(hipStream_t Strean, - uint32_t StreamI) override; - void transferStreamWaitForBarrierIfNeeded(hipStream_t Stream, - uint32_t StreamI) override; - ur_queue_handle_t getEventQueue(const ur_event_handle_t) override; - uint32_t getEventComputeStreamToken(const ur_event_handle_t) override; - hipStream_t getEventStream(const ur_event_handle_t) override; - - // Function which creates the profiling stream. Called only from makeNative - // event when profiling is required. - void createHostSubmitTimeStream() { - static std::once_flag HostSubmitTimeStreamFlag; - std::call_once(HostSubmitTimeStreamFlag, [&]() { - UR_CHECK_ERROR(hipStreamCreateWithFlags(&HostSubmitTimeStream, - hipStreamNonBlocking)); - }); - } - - void createStreamWithPriority(hipStream_t *Stream, unsigned int Flags, - int Priority) override { - UR_CHECK_ERROR(hipStreamCreateWithPriority(Stream, Flags, Priority)); - } -}; +using hip_stream_queue = stream_queue_t; +struct ur_queue_handle_t_ : public hip_stream_queue {}; + +template <> +inline void hip_stream_queue::createStreamWithPriority(hipStream_t *Stream, + unsigned int Flags, + int Priority) { + UR_CHECK_ERROR(hipStreamCreateWithPriority(Stream, Flags, Priority)); +} + +// Function which creates the profiling stream. Called only from makeNative +// event when profiling is required. +template <> inline void hip_stream_queue::createHostSubmitTimeStream() { + static std::once_flag HostSubmitTimeStreamFlag; + std::call_once(HostSubmitTimeStreamFlag, [&]() { + UR_CHECK_ERROR( + hipStreamCreateWithFlags(&HostSubmitTimeStream, hipStreamNonBlocking)); + }); +} // RAII object to make hQueue stream getter methods all return the same stream // within the lifetime of this object. diff --git a/unified-runtime/source/common/cuda-hip/stream_queue.hpp b/unified-runtime/source/common/cuda-hip/stream_queue.hpp index 7e5024c97156b..a0be37f3d8c4e 100644 --- a/unified-runtime/source/common/cuda-hip/stream_queue.hpp +++ b/unified-runtime/source/common/cuda-hip/stream_queue.hpp @@ -21,7 +21,8 @@ using ur_stream_guard = std::unique_lock; /// backend 'stream' objects. /// /// This class is specifically designed for the CUDA and HIP adapters. -template struct stream_queue_t { +template +struct stream_queue_t { using native_type = ST; static constexpr int DefaultNumComputeStreams = CS; static constexpr int DefaultNumTransferStreams = TS; @@ -61,6 +62,8 @@ template struct stream_queue_t { std::mutex TransferStreamMutex; std::mutex BarrierMutex; bool HasOwnership; + BarrierEventT BarrierEvent = nullptr; + BarrierEventT BarrierTmpEvent = nullptr; stream_queue_t(bool IsOutOfOrder, ur_context_handle_t_ *Context, ur_device_handle_t_ *Device, unsigned int Flags, @@ -88,17 +91,18 @@ template struct stream_queue_t { urContextRetain(Context); } - virtual ~stream_queue_t() { urContextRelease(Context); } + ~stream_queue_t() { urContextRelease(Context); } - virtual void computeStreamWaitForBarrierIfNeeded(native_type Strean, - uint32_t StreamI) = 0; - virtual void transferStreamWaitForBarrierIfNeeded(native_type Stream, - uint32_t StreamI) = 0; - virtual void createStreamWithPriority(native_type *Stream, unsigned int Flags, - int Priority) = 0; - virtual ur_queue_handle_t getEventQueue(const ur_event_handle_t) = 0; - virtual uint32_t getEventComputeStreamToken(const ur_event_handle_t) = 0; - virtual native_type getEventStream(const ur_event_handle_t) = 0; + void computeStreamWaitForBarrierIfNeeded(native_type Strean, + uint32_t StreamI); + void transferStreamWaitForBarrierIfNeeded(native_type Stream, + uint32_t StreamI); + void createStreamWithPriority(native_type *Stream, unsigned int Flags, + int Priority); + ur_queue_handle_t getEventQueue(const ur_event_handle_t); + uint32_t getEventComputeStreamToken(const ur_event_handle_t); + native_type getEventStream(const ur_event_handle_t); + void createHostSubmitTimeStream(); // get_next_compute/transfer_stream() functions return streams from // appropriate pools in round-robin fashion