@@ -150,6 +150,8 @@ ur_result_t ur_queue_immediate_in_order_t::queueGetNativeHandle(
150
150
ur_result_t ur_queue_immediate_in_order_t::queueFinish () {
151
151
TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::queueFinish" );
152
152
153
+ hContext->getAsyncPool ()->cleanupPoolsForQueue (this );
154
+
153
155
auto commandListLocked = commandListManager.lock ();
154
156
// TODO: use zeEventHostSynchronize instead?
155
157
TRACK_SCOPE_LATENCY (
@@ -703,31 +705,149 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueWriteHostPipe(
703
705
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
704
706
}
705
707
708
+ ur_result_t ur_queue_immediate_in_order_t::enqueueUSMAllocHelper (
709
+ ur_usm_pool_handle_t pPool, const size_t size,
710
+ const ur_exp_async_usm_alloc_properties_t *, uint32_t numEventsInWaitList,
711
+ const ur_event_handle_t *phEventWaitList, void **ppMem,
712
+ ur_event_handle_t *phEvent, ur_usm_type_t type) {
713
+ auto commandListLocked = commandListManager.lock ();
714
+
715
+ if (!pPool) {
716
+ pPool = hContext->getAsyncPool ();
717
+ }
718
+
719
+ auto device = (type == UR_USM_TYPE_HOST) ? nullptr : hDevice;
720
+ auto waitListView =
721
+ getWaitListView (commandListLocked, phEventWaitList, numEventsInWaitList);
722
+
723
+ auto asyncAlloc =
724
+ pPool->allocateEnqueued (hContext, this , device, nullptr , type, size);
725
+ if (!asyncAlloc) {
726
+ auto Ret = pPool->allocate (hContext, device, nullptr , type, size, ppMem);
727
+ if (Ret) {
728
+ return Ret;
729
+ }
730
+ } else {
731
+ ur_event_handle_t originAllocEvent;
732
+ std::tie (*ppMem, originAllocEvent) = *asyncAlloc;
733
+ if (originAllocEvent) {
734
+ waitListView = getWaitListView (commandListLocked, phEventWaitList,
735
+ numEventsInWaitList, originAllocEvent);
736
+ }
737
+ }
738
+
739
+ ur_command_t commandType = UR_COMMAND_FORCE_UINT32;
740
+ switch (type) {
741
+ case UR_USM_TYPE_HOST:
742
+ commandType = UR_COMMAND_ENQUEUE_USM_HOST_ALLOC_EXP;
743
+ break ;
744
+ case UR_USM_TYPE_DEVICE:
745
+ commandType = UR_COMMAND_ENQUEUE_USM_DEVICE_ALLOC_EXP;
746
+ break ;
747
+ case UR_USM_TYPE_SHARED:
748
+ commandType = UR_COMMAND_ENQUEUE_USM_SHARED_ALLOC_EXP;
749
+ break ;
750
+ default :
751
+ logger::error (" enqueueUSMAllocHelper: unsupported USM type" );
752
+ throw UR_RESULT_ERROR_UNKNOWN;
753
+ }
754
+
755
+ auto zeSignalEvent = getSignalEvent (commandListLocked, phEvent, commandType);
756
+ auto [pWaitEvents, numWaitEvents] = waitListView;
757
+ if (numWaitEvents > 0 ) {
758
+ ZE2UR_CALL (
759
+ zeCommandListAppendWaitOnEvents,
760
+ (commandListLocked->getZeCommandList (), numWaitEvents, pWaitEvents));
761
+ }
762
+ if (zeSignalEvent) {
763
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
764
+ (commandListLocked->getZeCommandList (), zeSignalEvent));
765
+ }
766
+
767
+ return UR_RESULT_SUCCESS;
768
+ }
769
+
706
770
ur_result_t ur_queue_immediate_in_order_t::enqueueUSMDeviceAllocExp (
707
- ur_usm_pool_handle_t , const size_t ,
708
- const ur_exp_async_usm_alloc_properties_t *, uint32_t ,
709
- const ur_event_handle_t *, void **, ur_event_handle_t *) {
710
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
771
+ ur_usm_pool_handle_t pPool, const size_t size,
772
+ const ur_exp_async_usm_alloc_properties_t *pProperties,
773
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
774
+ void **ppMem, ur_event_handle_t *phEvent) {
775
+ TRACK_SCOPE_LATENCY (
776
+ " ur_queue_immediate_in_order_t::enqueueUSMDeviceAllocExp" );
777
+
778
+ return enqueueUSMAllocHelper (pPool, size, pProperties, numEventsInWaitList,
779
+ phEventWaitList, ppMem, phEvent,
780
+ UR_USM_TYPE_DEVICE);
711
781
}
712
782
713
783
ur_result_t ur_queue_immediate_in_order_t::enqueueUSMSharedAllocExp (
714
- ur_usm_pool_handle_t , const size_t ,
715
- const ur_exp_async_usm_alloc_properties_t *, uint32_t ,
716
- const ur_event_handle_t *, void **, ur_event_handle_t *) {
717
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
784
+ ur_usm_pool_handle_t pPool, const size_t size,
785
+ const ur_exp_async_usm_alloc_properties_t *pProperties,
786
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
787
+ void **ppMem, ur_event_handle_t *phEvent) {
788
+ TRACK_SCOPE_LATENCY (
789
+ " ur_queue_immediate_in_order_t::enqueueUSMSharedAllocExp" );
790
+
791
+ return enqueueUSMAllocHelper (pPool, size, pProperties, numEventsInWaitList,
792
+ phEventWaitList, ppMem, phEvent,
793
+ UR_USM_TYPE_SHARED);
718
794
}
719
795
720
796
ur_result_t ur_queue_immediate_in_order_t::enqueueUSMHostAllocExp (
721
- ur_usm_pool_handle_t , const size_t ,
722
- const ur_exp_async_usm_alloc_properties_t *, uint32_t ,
723
- const ur_event_handle_t *, void **, ur_event_handle_t *) {
724
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
797
+ ur_usm_pool_handle_t pPool, const size_t size,
798
+ const ur_exp_async_usm_alloc_properties_t *pProperties,
799
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
800
+ void **ppMem, ur_event_handle_t *phEvent) {
801
+ TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::enqueueUSMHostAllocExp" );
802
+
803
+ return enqueueUSMAllocHelper (pPool, size, pProperties, numEventsInWaitList,
804
+ phEventWaitList, ppMem, phEvent,
805
+ UR_USM_TYPE_HOST);
725
806
}
726
807
727
808
ur_result_t ur_queue_immediate_in_order_t::enqueueUSMFreeExp (
728
- ur_usm_pool_handle_t , void *, uint32_t , const ur_event_handle_t *,
729
- ur_event_handle_t *) {
730
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
809
+ ur_usm_pool_handle_t , void *pMem, uint32_t numEventsInWaitList,
810
+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
811
+ TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::enqueueUSMFreeExp" );
812
+ auto commandListLocked = commandListManager.lock ();
813
+
814
+ auto zeSignalEvent = getSignalEvent (commandListLocked, phEvent,
815
+ UR_COMMAND_ENQUEUE_USM_FREE_EXP);
816
+ auto [pWaitEvents, numWaitEvents] =
817
+ getWaitListView (commandListLocked, phEventWaitList, numEventsInWaitList);
818
+
819
+ umf_memory_pool_handle_t hPool = umfPoolByPtr (pMem);
820
+ if (!hPool) {
821
+ return UR_RESULT_ERROR_INVALID_MEM_OBJECT;
822
+ }
823
+
824
+ UsmPool *usmPool = nullptr ;
825
+ auto ret = umfPoolGetTag (hPool, (void **)&usmPool);
826
+ if (ret != UMF_RESULT_SUCCESS || !usmPool) {
827
+ // This should never happen
828
+ return UR_RESULT_ERROR_UNKNOWN;
829
+ }
830
+
831
+ size_t size = umfPoolMallocUsableSize (hPool, pMem);
832
+ ur_event_handle_t poolEvent = nullptr ;
833
+ if (phEvent) {
834
+ poolEvent = *phEvent;
835
+ poolEvent->RefCount .increment ();
836
+ }
837
+ usmPool->asyncPool .insert (pMem, size, poolEvent, this );
838
+
839
+ if (numWaitEvents > 0 ) {
840
+ ZE2UR_CALL (
841
+ zeCommandListAppendWaitOnEvents,
842
+ (commandListLocked->getZeCommandList (), numWaitEvents, pWaitEvents));
843
+ }
844
+
845
+ if (zeSignalEvent) {
846
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
847
+ (commandListLocked->getZeCommandList (), zeSignalEvent));
848
+ }
849
+
850
+ return UR_RESULT_SUCCESS;
731
851
}
732
852
733
853
ur_result_t ur_queue_immediate_in_order_t::bindlessImagesImageCopyExp (
@@ -866,9 +986,9 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueGenericCommandListsExp(
866
986
" ur_queue_immediate_in_order_t::enqueueGenericCommandListsExp" );
867
987
868
988
auto commandListLocked = commandListManager.lock ();
989
+
869
990
auto zeSignalEvent =
870
991
getSignalEvent (commandListLocked, phEvent, callerCommand);
871
-
872
992
auto [pWaitEvents, numWaitEvents] =
873
993
getWaitListView (commandListLocked, phEventWaitList, numEventsInWaitList,
874
994
additionalWaitEvent);
0 commit comments