Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions unified-runtime/source/adapters/level_zero/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "common.hpp"
#include "external/driver_experimental/zex_graph.h"
#include "logger/ur_logger.hpp"
#include "usm.hpp"

Expand Down Expand Up @@ -350,6 +351,12 @@ getZexStructureType<ze_intel_xe_device_exp_properties_t>() {
return ZE_STRUCTURE_TYPE_INTEL_XE_DEVICE_EXP_PROPERTIES;
}

template <>
ze_structure_type_ext_t
getZexStructureType<ze_record_replay_graph_exp_properties_t>() {
return ZE_STRUCTURE_TYPE_RECORD_REPLAY_GRAPH_EXP_PROPERTIES;
}

// Global variables for ZER_EXT_RESULT_ADAPTER_SPECIFIC_ERROR
thread_local int32_t ErrorMessageCode = 0;
thread_local char ErrorMessage[MaxMessageSize]{};
Expand Down
20 changes: 19 additions & 1 deletion unified-runtime/source/adapters/level_zero/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1575,7 +1575,25 @@ ur_result_t urDeviceGetInfo(
return ReturnValue(static_cast<ur_bool_t>(Device->isIntegrated() != 0));
case UR_DEVICE_INFO_GRAPH_RECORD_AND_REPLAY_SUPPORT_EXP:
#ifdef UR_ADAPTER_LEVEL_ZERO_V2
return ReturnValue(Device->Platform->ZeGraphExt.Supported);
{
if (!Device->Platform->ZeGraphExt.Supported) {
return ReturnValue(false);
}

ze_record_replay_graph_exp_properties_t GraphProperties{};
GraphProperties.stype =
ZE_STRUCTURE_TYPE_RECORD_REPLAY_GRAPH_EXP_PROPERTIES;
GraphProperties.pNext = nullptr;
ZeStruct<ze_device_properties_t> DeviceProperties;
DeviceProperties.pNext = &GraphProperties;
ZE2UR_CALL(zeDeviceGetProperties, (ZeDevice, &DeviceProperties));

constexpr ze_record_replay_graph_exp_flags_t GraphModeMask =
ZE_RECORD_REPLAY_GRAPH_EXP_FLAG_IMMUTABLE_GRAPH |
ZE_RECORD_REPLAY_GRAPH_EXP_FLAG_MUTABLE_GRAPH;
return ReturnValue(static_cast<ur_bool_t>(
(GraphProperties.graphFlags & GraphModeMask) != 0));
}
#else
return ReturnValue(false);
#endif
Expand Down
103 changes: 57 additions & 46 deletions unified-runtime/source/adapters/level_zero/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,12 @@ ur_result_t ur_platform_handle_t_::initialize() {
zeDriverExtensionMap[extension.name] = extension.version;
}

const auto GraphExtension =
zeDriverExtensionMap.find(ZE_RECORD_REPLAY_GRAPH_EXP_NAME);
const bool ZeGraphExtensionSupported =
GraphExtension != zeDriverExtensionMap.end() &&
GraphExtension->second >= ZE_RECORD_REPLAY_GRAPH_EXP_VERSION_1_0;

ZE2UR_CALL(zelLoaderTranslateHandle, (ZEL_HANDLE_DRIVER, ZeDriver,
(void **)&ZeDriverHandleExpTranslated));

Expand Down Expand Up @@ -561,53 +567,58 @@ ur_result_t ur_platform_handle_t_::initialize() {
ZeMemGetPitchFor2dImageExt.Supported =
ZeMemGetPitchFor2dImageExt.zeMemGetPitchFor2dImage != nullptr;

// Populate Graph Extension structure. Mandatory graph functions.
std::unordered_map<std::string, void **> ZeGraphFuncNameToAddrMap = {
{"zeGraphCreateExp",
reinterpret_cast<void **>(&ZeGraphExt.zeGraphCreateExp)},
{"zeCommandListBeginGraphCaptureExp",
reinterpret_cast<void **>(
&ZeGraphExt.zeCommandListBeginGraphCaptureExp)},
{"zeCommandListBeginCaptureIntoGraphExp",
reinterpret_cast<void **>(
&ZeGraphExt.zeCommandListBeginCaptureIntoGraphExp)},
{"zeCommandListEndGraphCaptureExp",
reinterpret_cast<void **>(&ZeGraphExt.zeCommandListEndGraphCaptureExp)},
{"zeCommandListInstantiateGraphExp",
reinterpret_cast<void **>(&ZeGraphExt.zeCommandListInstantiateGraphExp)},
{"zeCommandListAppendGraphExp",
reinterpret_cast<void **>(&ZeGraphExt.zeCommandListAppendGraphExp)},
{"zeGraphDestroyExp",
reinterpret_cast<void **>(&ZeGraphExt.zeGraphDestroyExp)},
{"zeExecutableGraphDestroyExp",
reinterpret_cast<void **>(&ZeGraphExt.zeExecutableGraphDestroyExp)},
{"zeCommandListIsGraphCaptureEnabledExp",
reinterpret_cast<void **>(
&ZeGraphExt.zeCommandListIsGraphCaptureEnabledExp)},
{"zeGraphIsEmptyExp",
reinterpret_cast<void **>(&ZeGraphExt.zeGraphIsEmptyExp)},
{"zeGraphDumpContentsExp",
reinterpret_cast<void **>(&ZeGraphExt.zeGraphDumpContentsExp)},
};

ZeGraphExt.Supported = true;
for (auto &[funcName, funcAddr] : ZeGraphFuncNameToAddrMap) {
ZE_CALL_NOCHECK(zeDriverGetExtensionFunctionAddress,
(ZeDriver, funcName.c_str(), funcAddr));
ZeGraphExt.Supported &= (*funcAddr != nullptr);
}

// Optional graph functions. If the function is not supported due to driver
// version, then still mark graphs as supported and only return unsupported
// code in affected function.
std::unordered_map<std::string, void **> ZeGraphOptionalFuncNameToAddrMap = {
{"zeCommandListGetGraphExp",
reinterpret_cast<void **>(&ZeGraphExt.zeCommandListGetGraphExp)},
};
if (ZeGraphExtensionSupported) {
// Populate Graph Extension structure. Mandatory graph functions.
std::unordered_map<std::string, void **> ZeGraphFuncNameToAddrMap = {
{"zeGraphCreateExp",
reinterpret_cast<void **>(&ZeGraphExt.zeGraphCreateExp)},
{"zeCommandListBeginGraphCaptureExp",
reinterpret_cast<void **>(
&ZeGraphExt.zeCommandListBeginGraphCaptureExp)},
{"zeCommandListBeginCaptureIntoGraphExp",
reinterpret_cast<void **>(
&ZeGraphExt.zeCommandListBeginCaptureIntoGraphExp)},
{"zeCommandListEndGraphCaptureExp",
reinterpret_cast<void **>(
&ZeGraphExt.zeCommandListEndGraphCaptureExp)},
{"zeCommandListInstantiateGraphExp",
reinterpret_cast<void **>(
&ZeGraphExt.zeCommandListInstantiateGraphExp)},
{"zeCommandListAppendGraphExp",
reinterpret_cast<void **>(&ZeGraphExt.zeCommandListAppendGraphExp)},
{"zeGraphDestroyExp",
reinterpret_cast<void **>(&ZeGraphExt.zeGraphDestroyExp)},
{"zeExecutableGraphDestroyExp",
reinterpret_cast<void **>(&ZeGraphExt.zeExecutableGraphDestroyExp)},
{"zeCommandListIsGraphCaptureEnabledExp",
reinterpret_cast<void **>(
&ZeGraphExt.zeCommandListIsGraphCaptureEnabledExp)},
{"zeGraphIsEmptyExp",
reinterpret_cast<void **>(&ZeGraphExt.zeGraphIsEmptyExp)},
{"zeGraphDumpContentsExp",
reinterpret_cast<void **>(&ZeGraphExt.zeGraphDumpContentsExp)},
};

ZeGraphExt.Supported = true;
for (auto &[funcName, funcAddr] : ZeGraphFuncNameToAddrMap) {
ZE_CALL_NOCHECK(zeDriverGetExtensionFunctionAddress,
(ZeDriver, funcName.c_str(), funcAddr));
ZeGraphExt.Supported &= (*funcAddr != nullptr);
}

for (auto &[funcName, funcAddr] : ZeGraphOptionalFuncNameToAddrMap) {
ZE_CALL_NOCHECK(zeDriverGetExtensionFunctionAddress,
(ZeDriver, funcName.c_str(), funcAddr));
// Optional graph functions. If the function is not supported due to driver
// version, then still mark graphs as supported and only return unsupported
// code in affected function.
std::unordered_map<std::string, void **> ZeGraphOptionalFuncNameToAddrMap =
{
{"zeCommandListGetGraphExp",
reinterpret_cast<void **>(&ZeGraphExt.zeCommandListGetGraphExp)},
};

for (auto &[funcName, funcAddr] : ZeGraphOptionalFuncNameToAddrMap) {
ZE_CALL_NOCHECK(zeDriverGetExtensionFunctionAddress,
(ZeDriver, funcName.c_str(), funcAddr));
}
}

if (this->isDriverVersionNewerOrSimilar(1, 14, 36035)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1320,7 +1320,7 @@ ur_result_t ur_command_list_manager::appendKernelLaunchWithArgsExp(
} else {
// We cannot pass cooperativeKernelLaunchRequested to
// appendKernelLaunchWithArgsExpOld() because appendKernelLaunch() must
// check it on its own since it is called also from enqueueKernelLaunch().
// check it on its own since it is called from other kernel launch paths.
return appendKernelLaunchWithArgsExpOld(
hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize,
numArgs, pArgs, launchPropList, waitListView, phEvent);
Expand Down Expand Up @@ -1401,7 +1401,7 @@ ur_command_list_manager::appendGraph(ur_exp_executable_graph_handle_t hGraph,
return UR_RESULT_SUCCESS;
}

ur_result_t ur_command_list_manager::isGraphCaptureActive(bool *pResult) {
ur_result_t ur_command_list_manager::queryGraphCaptureActive(bool *pResult) {
if (!checkGraphExtensionSupport(hContext.get())) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ struct ur_command_list_manager {
ur_result_t beginGraphCapture();
ur_result_t beginCaptureIntoGraph(ur_exp_graph_handle_t hGraph);
ur_result_t endGraphCapture(ur_exp_graph_handle_t *phGraph);
ur_result_t isGraphCaptureActive(bool *pResult);
ur_result_t queryGraphCaptureActive(bool *pResult);
ur_result_t getGraph(ur_exp_graph_handle_t *phGraph);

v2::raii::command_list_unique_handle &&releaseCommandList();
Expand Down
Loading
Loading