@@ -150,6 +150,8 @@ ur_result_t ur_queue_immediate_in_order_t::queueGetNativeHandle(
150
150
ur_result_t ur_queue_immediate_in_order_t::queueFinish () {
151
151
TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::queueFinish" );
152
152
153
+ hContext->getAsyncPool ()->cleanupPoolsForQueue (this );
154
+
153
155
auto commandListLocked = commandListManager.lock ();
154
156
// TODO: use zeEventHostSynchronize instead?
155
157
TRACK_SCOPE_LATENCY (
@@ -703,31 +705,142 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueWriteHostPipe(
703
705
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
704
706
}
705
707
708
+ ur_result_t ur_queue_immediate_in_order_t::enqueueUSMAllocHelper (
709
+ ur_usm_pool_handle_t pPool, const size_t size,
710
+ const ur_exp_async_usm_alloc_properties_t *, uint32_t numEventsInWaitList,
711
+ const ur_event_handle_t *phEventWaitList, void **ppMem,
712
+ ur_event_handle_t *phEvent, ur_usm_type_t type) {
713
+ auto commandListLocked = commandListManager.lock ();
714
+
715
+ if (!pPool) {
716
+ pPool = hContext->getAsyncPool ();
717
+ }
718
+
719
+ auto device = (type == UR_USM_TYPE_HOST) ? nullptr : hDevice;
720
+ auto waitListView =
721
+ getWaitListView (commandListLocked, phEventWaitList, numEventsInWaitList);
722
+
723
+ auto asyncAlloc =
724
+ pPool->allocateEnqueued (hContext, this , device, nullptr , type, size);
725
+ if (!asyncAlloc) {
726
+ auto Ret = pPool->allocate (hContext, device, nullptr , type, size, ppMem);
727
+ if (Ret) {
728
+ return Ret;
729
+ }
730
+ } else {
731
+ ur_event_handle_t originAllocEvent;
732
+ std::tie (*ppMem, originAllocEvent) = *asyncAlloc;
733
+ waitListView = getWaitListView (commandListLocked, phEventWaitList,
734
+ numEventsInWaitList, originAllocEvent);
735
+ }
736
+
737
+ ur_command_t commandType = UR_COMMAND_FORCE_UINT32;
738
+ switch (type) {
739
+ case UR_USM_TYPE_HOST:
740
+ commandType = UR_COMMAND_ENQUEUE_USM_HOST_ALLOC_EXP;
741
+ break ;
742
+ case UR_USM_TYPE_DEVICE:
743
+ commandType = UR_COMMAND_ENQUEUE_USM_DEVICE_ALLOC_EXP;
744
+ break ;
745
+ case UR_USM_TYPE_SHARED:
746
+ commandType = UR_COMMAND_ENQUEUE_USM_SHARED_ALLOC_EXP;
747
+ break ;
748
+ default :
749
+ logger::error (" enqueueUSMAllocHelper: unsupported USM type" );
750
+ throw UR_RESULT_ERROR_UNKNOWN;
751
+ }
752
+
753
+ auto zeSignalEvent = getSignalEvent (commandListLocked, phEvent, commandType);
754
+ auto [pWaitEvents, numWaitEvents] = waitListView;
755
+ if (numWaitEvents > 0 ) {
756
+ ZE2UR_CALL (
757
+ zeCommandListAppendWaitOnEvents,
758
+ (commandListLocked->getZeCommandList (), numWaitEvents, pWaitEvents));
759
+ }
760
+ if (zeSignalEvent) {
761
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
762
+ (commandListLocked->getZeCommandList (), zeSignalEvent));
763
+ }
764
+
765
+ return UR_RESULT_SUCCESS;
766
+ }
767
+
706
768
ur_result_t ur_queue_immediate_in_order_t::enqueueUSMDeviceAllocExp (
707
- ur_usm_pool_handle_t , const size_t ,
708
- const ur_exp_async_usm_alloc_properties_t *, uint32_t ,
709
- const ur_event_handle_t *, void **, ur_event_handle_t *) {
710
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
769
+ ur_usm_pool_handle_t pPool, const size_t size,
770
+ const ur_exp_async_usm_alloc_properties_t *pProperties,
771
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
772
+ void **ppMem, ur_event_handle_t *phEvent) {
773
+ TRACK_SCOPE_LATENCY (
774
+ " ur_queue_immediate_in_order_t::enqueueUSMDeviceAllocExp" );
775
+
776
+ return enqueueUSMAllocHelper (pPool, size, pProperties, numEventsInWaitList,
777
+ phEventWaitList, ppMem, phEvent,
778
+ UR_USM_TYPE_DEVICE);
711
779
}
712
780
713
781
ur_result_t ur_queue_immediate_in_order_t::enqueueUSMSharedAllocExp (
714
- ur_usm_pool_handle_t , const size_t ,
715
- const ur_exp_async_usm_alloc_properties_t *, uint32_t ,
716
- const ur_event_handle_t *, void **, ur_event_handle_t *) {
717
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
782
+ ur_usm_pool_handle_t pPool, const size_t size,
783
+ const ur_exp_async_usm_alloc_properties_t *pProperties,
784
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
785
+ void **ppMem, ur_event_handle_t *phEvent) {
786
+ TRACK_SCOPE_LATENCY (
787
+ " ur_queue_immediate_in_order_t::enqueueUSMSharedAllocExp" );
788
+
789
+ return enqueueUSMAllocHelper (pPool, size, pProperties, numEventsInWaitList,
790
+ phEventWaitList, ppMem, phEvent,
791
+ UR_USM_TYPE_SHARED);
718
792
}
719
793
720
794
ur_result_t ur_queue_immediate_in_order_t::enqueueUSMHostAllocExp (
721
- ur_usm_pool_handle_t , const size_t ,
722
- const ur_exp_async_usm_alloc_properties_t *, uint32_t ,
723
- const ur_event_handle_t *, void **, ur_event_handle_t *) {
724
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
795
+ ur_usm_pool_handle_t pPool, const size_t size,
796
+ const ur_exp_async_usm_alloc_properties_t *pProperties,
797
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
798
+ void **ppMem, ur_event_handle_t *phEvent) {
799
+ TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::enqueueUSMHostAllocExp" );
800
+
801
+ return enqueueUSMAllocHelper (pPool, size, pProperties, numEventsInWaitList,
802
+ phEventWaitList, ppMem, phEvent,
803
+ UR_USM_TYPE_HOST);
725
804
}
726
805
727
806
ur_result_t ur_queue_immediate_in_order_t::enqueueUSMFreeExp (
728
- ur_usm_pool_handle_t , void *, uint32_t , const ur_event_handle_t *,
729
- ur_event_handle_t *) {
730
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
807
+ ur_usm_pool_handle_t pPool, void *pMem, uint32_t numEventsInWaitList,
808
+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
809
+ TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::enqueueUSMFreeExp" );
810
+ auto commandListLocked = commandListManager.lock ();
811
+
812
+ auto zeSignalEvent = getSignalEvent (commandListLocked, phEvent,
813
+ UR_COMMAND_ENQUEUE_USM_FREE_EXP);
814
+ auto [pWaitEvents, numWaitEvents] =
815
+ getWaitListView (commandListLocked, phEventWaitList, numEventsInWaitList);
816
+
817
+ umf_memory_pool_handle_t hPool = umfPoolByPtr (pMem);
818
+ if (!hPool) {
819
+ return UR_RESULT_ERROR_INVALID_MEM_OBJECT
820
+ }
821
+
822
+ UsmPool *usmPool = nullptr ;
823
+ auto ret = umfPoolGetTag (hPool, (void **)&usmPool);
824
+ if (ret != UR_RESULT_SUCCESS || !usmPool) {
825
+ // This should never happen
826
+ return UR_RESULT_ERROR_UNKNOWN;
827
+ }
828
+
829
+ size_t size = umfPoolMallocUsableSize (hPool, pMem);
830
+ usmPool->asyncPool .insert (pMem, size, *phEvent, this );
831
+
832
+ if (numWaitEvents > 0 ) {
833
+ ZE2UR_CALL (
834
+ zeCommandListAppendWaitOnEvents,
835
+ (commandListLocked->getZeCommandList (), numWaitEvents, pWaitEvents));
836
+ }
837
+
838
+ if (zeSignalEvent) {
839
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
840
+ (commandListLocked->getZeCommandList (), zeSignalEvent));
841
+ }
842
+
843
+ return UR_RESULT_SUCCESS;
731
844
}
732
845
733
846
ur_result_t ur_queue_immediate_in_order_t::bindlessImagesImageCopyExp (
@@ -866,9 +979,9 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueGenericCommandListsExp(
866
979
" ur_queue_immediate_in_order_t::enqueueGenericCommandListsExp" );
867
980
868
981
auto commandListLocked = commandListManager.lock ();
982
+
869
983
auto zeSignalEvent =
870
984
getSignalEvent (commandListLocked, phEvent, callerCommand);
871
-
872
985
auto [pWaitEvents, numWaitEvents] =
873
986
getWaitListView (commandListLocked, phEventWaitList, numEventsInWaitList,
874
987
additionalWaitEvent);
0 commit comments