Skip to content

Commit 56bfbb5

Browse files
committed
add api
1 parent 7175536 commit 56bfbb5

File tree

2 files changed

+163
-28
lines changed

2 files changed

+163
-28
lines changed

unified-runtime/source/loader/layers/sanitizer/asan/asan_ddi.cpp

+81-14
Original file line numberDiff line numberDiff line change
@@ -513,22 +513,14 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch(
513513

514514
UR_CALL(getAsanInterceptor()->preLaunchKernel(hKernel, hQueue, LaunchInfo));
515515

516-
ur_event_handle_t hEvent{};
517-
ur_result_t result =
518-
pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
519-
pGlobalWorkSize, LaunchInfo.LocalWorkSize.data(),
520-
numEventsInWaitList, phEventWaitList, &hEvent);
521-
522-
if (result == UR_RESULT_SUCCESS) {
523-
UR_CALL(
524-
getAsanInterceptor()->postLaunchKernel(hKernel, hQueue, LaunchInfo));
525-
}
516+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnKernelLaunch(
517+
hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
518+
LaunchInfo.LocalWorkSize.data(), numEventsInWaitList, phEventWaitList,
519+
phEvent));
526520

527-
if (phEvent) {
528-
*phEvent = hEvent;
529-
}
521+
UR_CALL(getAsanInterceptor()->postLaunchKernel(hKernel, hQueue, LaunchInfo));
530522

531-
return result;
523+
return UR_RESULT_SUCCESS;
532524
}
533525

534526
///////////////////////////////////////////////////////////////////////////////
@@ -1410,6 +1402,57 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap(
14101402
return UR_RESULT_SUCCESS;
14111403
}
14121404

1405+
///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urEnqueueCooperativeKernelLaunchExp
///
/// Brackets the cooperative kernel launch stored in the context's urDdiTable
/// with the ASAN interceptor's preLaunchKernel / postLaunchKernel hooks.
///
/// @returns
///     - ::UR_RESULT_SUCCESS when the end of the function is reached.
///
/// NOTE(review): every step is wrapped in UR_CALL, which presumably returns
/// the failing result early -- confirm against the macro definition.
/// NOTE(review): neighbouring intercepts in this file (e.g. urKernelRetain)
/// are declared __urdlllocal; this one is not -- confirm that is intended.
ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
1406+
/// [in] handle of the queue object
1407+
ur_queue_handle_t hQueue,
1408+
/// [in] handle of the kernel object
1409+
ur_kernel_handle_t hKernel,
1410+
/// [in] number of dimensions, from 1 to 3, to specify the global and
1411+
/// work-group work-items
1412+
uint32_t workDim,
1413+
/// [in] pointer to an array of workDim unsigned values that specify the
1414+
/// offset used to calculate the global ID of a work-item
1415+
const size_t *pGlobalWorkOffset,
1416+
/// [in] pointer to an array of workDim unsigned values that specify the
1417+
/// number of global work-items in workDim that will execute the kernel
1418+
/// function
1419+
const size_t *pGlobalWorkSize,
1420+
/// [in][optional] pointer to an array of workDim unsigned values that
1421+
/// specify the number of local work-items forming a work-group that will
1422+
/// execute the kernel function.
1423+
/// If nullptr, the runtime implementation will choose the work-group size.
1424+
const size_t *pLocalWorkSize,
1425+
/// [in] size of the event wait list
1426+
uint32_t numEventsInWaitList,
1427+
/// [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1428+
/// events that must be complete before the kernel execution.
1429+
/// If nullptr, numEventsInWaitList must be 0, indicating that there is no
1430+
/// event to wait on.
1431+
const ur_event_handle_t *phEventWaitList,
1432+
/// [out][optional][alloc] return an event object that identifies this
1433+
/// particular kernel execution instance. If phEventWaitList and phEvent
1434+
/// are not NULL, phEvent must not refer to an element of the
1435+
/// phEventWaitList array.
1436+
ur_event_handle_t *phEvent) {
1437+
1438+
getContext()->logger.debug("==== urEnqueueCooperativeKernelLaunchExp");
1439+
1440+
// Per-launch state built from the queue's context and device plus the
// requested global/local work sizes, offset and dimension count.
LaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue), pGlobalWorkSize,
1441+
pLocalWorkSize, pGlobalWorkOffset, workDim);
1442+
// presumably pushes the launch data to the device through hQueue -- TODO
// confirm against LaunchInfo's definition.
UR_CALL(LaunchInfo.Data.syncToDevice(hQueue));
1443+
1444+
// Pre-launch hook of the ASAN interceptor; receives the LaunchInfo built above.
UR_CALL(getAsanInterceptor()->preLaunchKernel(hKernel, hQueue, LaunchInfo));
1445+
1446+
// Forward to the entry stored in the context's urDdiTable. The local work
// size passed on is LaunchInfo.LocalWorkSize, not the caller's
// pLocalWorkSize (presumably it may have been filled in or adjusted before
// this point -- TODO confirm).
UR_CALL(getContext()->urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp(
1447+
hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
1448+
LaunchInfo.LocalWorkSize.data(), numEventsInWaitList, phEventWaitList,
1449+
phEvent));
1450+
1451+
// Post-launch hook of the ASAN interceptor, invoked after the enqueue call.
UR_CALL(getAsanInterceptor()->postLaunchKernel(hKernel, hQueue, LaunchInfo));
1452+
1453+
return UR_RESULT_SUCCESS;
1454+
}
1455+
14131456
///////////////////////////////////////////////////////////////////////////////
14141457
/// @brief Intercept function for urKernelRetain
14151458
__urdlllocal ur_result_t UR_APICALL urKernelRetain(
@@ -1952,6 +1995,25 @@ __urdlllocal ur_result_t UR_APICALL urGetDeviceProcAddrTable(
19521995
return result;
19531996
}
19541997

1998+
///////////////////////////////////////////////////////////////////////////////
1999+
/// @brief Exported function for filling application's EnqueueExp table
2000+
/// with current process' addresses
2001+
///
/// Only pfnCooperativeKernelLaunchExp is overwritten (with the ASAN
/// intercept); every other entry of the table is left as it was.
///
2002+
/// @returns
2003+
/// - ::UR_RESULT_SUCCESS
2004+
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER when pDdiTable is nullptr
2005+
__urdlllocal ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
2006+
/// [in,out] pointer to table of DDI function pointers
2007+
ur_enqueue_exp_dditable_t *pDdiTable) {
2008+
// Reject a missing table before touching it.
if (nullptr == pDdiTable) {
2009+
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
2010+
}
2011+
2012+
// Route cooperative kernel launches through the ASAN intercept.
pDdiTable->pfnCooperativeKernelLaunchExp =
2013+
ur_sanitizer_layer::asan::urEnqueueCooperativeKernelLaunchExp;
2014+
return UR_RESULT_SUCCESS;
2015+
}
2016+
19552017
template <class A, class B> struct NotSupportedApi;
19562018

19572019
template <class MsgType, class R, class... A>
@@ -2147,6 +2209,11 @@ ur_result_t initAsanDDITable(ur_dditable_t *dditable) {
21472209
UR_API_VERSION_CURRENT, &dditable->VirtualMem);
21482210
}
21492211

2212+
if (UR_RESULT_SUCCESS == result) {
2213+
result = ur_sanitizer_layer::asan::urGetEnqueueExpProcAddrTable(
2214+
&dditable->EnqueueExp);
2215+
}
2216+
21502217
if (result != UR_RESULT_SUCCESS) {
21512218
getContext()->logger.error("Initialize ASAN DDI table failed: {}", result);
21522219
}

unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp

+82-14
Original file line numberDiff line numberDiff line change
@@ -431,22 +431,14 @@ ur_result_t urEnqueueKernelLaunch(
431431

432432
UR_CALL(getMsanInterceptor()->preLaunchKernel(hKernel, hQueue, LaunchInfo));
433433

434-
ur_event_handle_t hEvent{};
435-
ur_result_t result =
436-
pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
437-
pGlobalWorkSize, LaunchInfo.LocalWorkSize.data(),
438-
numEventsInWaitList, phEventWaitList, &hEvent);
434+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnKernelLaunch(
435+
hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
436+
LaunchInfo.LocalWorkSize.data(), numEventsInWaitList, phEventWaitList,
437+
phEvent));
439438

440-
if (result == UR_RESULT_SUCCESS) {
441-
UR_CALL(
442-
getMsanInterceptor()->postLaunchKernel(hKernel, hQueue, LaunchInfo));
443-
}
444-
445-
if (phEvent) {
446-
*phEvent = hEvent;
447-
}
439+
UR_CALL(getMsanInterceptor()->postLaunchKernel(hKernel, hQueue, LaunchInfo));
448440

449-
return result;
441+
return UR_RESULT_SUCCESS;
450442
}
451443

452444
///////////////////////////////////////////////////////////////////////////////
@@ -1314,6 +1306,58 @@ ur_result_t urEnqueueMemUnmap(
13141306
return UR_RESULT_SUCCESS;
13151307
}
13161308

1309+
///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urEnqueueCooperativeKernelLaunchExp
///
/// Brackets the cooperative kernel launch stored in the context's urDdiTable
/// with the MSAN interceptor's preLaunchKernel / postLaunchKernel hooks.
///
/// @returns
///     - ::UR_RESULT_SUCCESS when the end of the function is reached.
///
/// NOTE(review): every step is wrapped in UR_CALL, which presumably returns
/// the failing result early -- confirm against the macro definition.
/// NOTE(review): the neighbouring intercept urKernelRetain in this file is
/// declared without UR_APICALL; this one uses it -- confirm which is intended.
ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
1310+
/// [in] handle of the queue object
1311+
ur_queue_handle_t hQueue,
1312+
/// [in] handle of the kernel object
1313+
ur_kernel_handle_t hKernel,
1314+
/// [in] number of dimensions, from 1 to 3, to specify the global and
1315+
/// work-group work-items
1316+
uint32_t workDim,
1317+
/// [in] pointer to an array of workDim unsigned values that specify the
1318+
/// offset used to calculate the global ID of a work-item
1319+
const size_t *pGlobalWorkOffset,
1320+
/// [in] pointer to an array of workDim unsigned values that specify the
1321+
/// number of global work-items in workDim that will execute the kernel
1322+
/// function
1323+
const size_t *pGlobalWorkSize,
1324+
/// [in][optional] pointer to an array of workDim unsigned values that
1325+
/// specify the number of local work-items forming a work-group that will
1326+
/// execute the kernel function.
1327+
/// If nullptr, the runtime implementation will choose the work-group size.
1328+
const size_t *pLocalWorkSize,
1329+
/// [in] size of the event wait list
1330+
uint32_t numEventsInWaitList,
1331+
/// [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1332+
/// events that must be complete before the kernel execution.
1333+
/// If nullptr, numEventsInWaitList must be 0, indicating that there is no
1334+
/// event to wait on.
1335+
const ur_event_handle_t *phEventWaitList,
1336+
/// [out][optional][alloc] return an event object that identifies this
1337+
/// particular kernel execution instance. If phEventWaitList and phEvent
1338+
/// are not NULL, phEvent must not refer to an element of the
1339+
/// phEventWaitList array.
1340+
ur_event_handle_t *phEvent) {
1341+
1342+
getContext()->logger.debug("==== urEnqueueCooperativeKernelLaunchExp");
1343+
1344+
// Per-launch state built from the queue's context and device plus the
// requested global/local work sizes, offset and dimension count.
USMLaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue),
1345+
pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
1346+
workDim);
1347+
// Must succeed before the hooks run; what initialize() prepares is not
// visible from here.
UR_CALL(LaunchInfo.initialize());
1348+
1349+
// Pre-launch hook of the MSAN interceptor; receives the LaunchInfo built above.
UR_CALL(getMsanInterceptor()->preLaunchKernel(hKernel, hQueue, LaunchInfo));
1350+
1351+
// Forward to the entry stored in the context's urDdiTable. The local work
// size passed on is LaunchInfo.LocalWorkSize, not the caller's
// pLocalWorkSize (presumably it may have been filled in or adjusted before
// this point -- TODO confirm).
UR_CALL(getContext()->urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp(
1352+
hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
1353+
LaunchInfo.LocalWorkSize.data(), numEventsInWaitList, phEventWaitList,
1354+
phEvent));
1355+
1356+
// Post-launch hook of the MSAN interceptor, invoked after the enqueue call.
UR_CALL(getMsanInterceptor()->postLaunchKernel(hKernel, hQueue, LaunchInfo));
1357+
1358+
return UR_RESULT_SUCCESS;
1359+
}
1360+
13171361
///////////////////////////////////////////////////////////////////////////////
13181362
/// @brief Intercept function for urKernelRetain
13191363
ur_result_t urKernelRetain(
@@ -1891,6 +1935,25 @@ ur_result_t urCheckVersion(ur_api_version_t version) {
18911935
return UR_RESULT_SUCCESS;
18921936
}
18931937

1938+
///////////////////////////////////////////////////////////////////////////////
1939+
/// @brief Exported function for filling application's EnqueueExp table
1940+
/// with current process' addresses
1941+
///
/// Only pfnCooperativeKernelLaunchExp is overwritten (with the MSAN
/// intercept); every other entry of the table is left as it was.
///
1942+
/// @returns
1943+
/// - ::UR_RESULT_SUCCESS
1944+
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER when pDdiTable is nullptr
1945+
__urdlllocal ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
1946+
/// [in,out] pointer to table of DDI function pointers
1947+
ur_enqueue_exp_dditable_t *pDdiTable) {
1948+
// Reject a missing table before touching it.
if (nullptr == pDdiTable) {
1949+
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
1950+
}
1951+
1952+
// Route cooperative kernel launches through the MSAN intercept.
pDdiTable->pfnCooperativeKernelLaunchExp =
1953+
ur_sanitizer_layer::msan::urEnqueueCooperativeKernelLaunchExp;
1954+
return UR_RESULT_SUCCESS;
1955+
}
1956+
18941957
} // namespace msan
18951958

18961959
ur_result_t initMsanDDITable(ur_dditable_t *dditable) {
@@ -1945,6 +2008,11 @@ ur_result_t initMsanDDITable(ur_dditable_t *dditable) {
19452008
result = ur_sanitizer_layer::msan::urGetUSMProcAddrTable(&dditable->USM);
19462009
}
19472010

2011+
if (UR_RESULT_SUCCESS == result) {
2012+
result = ur_sanitizer_layer::msan::urGetEnqueueExpProcAddrTable(
2013+
&dditable->EnqueueExp);
2014+
}
2015+
19482016
if (result != UR_RESULT_SUCCESS) {
19492017
getContext()->logger.error("Initialize MSAN DDI table failed: {}", result);
19502018
}

0 commit comments

Comments
 (0)