Skip to content

Commit 200ca34

Browse files
authored
[SYCL] Fix shortcut functions to get kernel info for specific device (#20435)
Currently, these shortcut functions create a kernel bundle for the entire context, which is expensive when the context contains multiple devices. Instead, the bundle should be created only for the device that was passed in.
1 parent b111993 commit 200ca34

File tree

3 files changed

+78
-3
lines changed

3 files changed

+78
-3
lines changed

sycl/include/sycl/ext/oneapi/get_kernel_info.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ template <typename KernelName, typename Param>
4040
typename sycl::detail::is_kernel_device_specific_info_desc<Param>::return_type
4141
get_kernel_info(const context &Ctx, const device &Dev) {
4242
auto Bundle =
43-
sycl::get_kernel_bundle<KernelName, sycl::bundle_state::executable>(Ctx);
43+
sycl::get_kernel_bundle<KernelName, sycl::bundle_state::executable>(
44+
Ctx, {Dev});
4445
return Bundle.template get_kernel<KernelName>().template get_info<Param>(Dev);
4546
}
4647

@@ -49,7 +50,7 @@ typename sycl::detail::is_kernel_device_specific_info_desc<Param>::return_type
4950
get_kernel_info(const queue &Q) {
5051
auto Bundle =
5152
sycl::get_kernel_bundle<KernelName, sycl::bundle_state::executable>(
52-
Q.get_context());
53+
Q.get_context(), {Q.get_device()});
5354
return Bundle.template get_kernel<KernelName>().template get_info<Param>(
5455
Q.get_device());
5556
}
@@ -73,7 +74,7 @@ std::enable_if_t<ext::oneapi::experimental::is_kernel_v<Func>,
7374
Param>::return_type>
7475
get_kernel_info(const context &ctxt, const device &dev) {
7576
auto Bundle = sycl::ext::oneapi::experimental::get_kernel_bundle<
76-
Func, sycl::bundle_state::executable>(ctxt);
77+
Func, sycl::bundle_state::executable>(ctxt, {dev});
7778
return Bundle.template ext_oneapi_get_kernel<Func>().template get_info<Param>(
7879
dev);
7980
}

sycl/unittests/kernel-and-program/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_sycl_unittest(KernelAndProgramTests OBJECT
99
OutOfResources.cpp
1010
InMemCacheEviction.cpp
1111
KernelArgs.cpp
12+
KernelInfoShortcuts.cpp
1213
)
1314
target_compile_definitions(KernelAndProgramTests_non_preview PRIVATE __SYCL_INTERNAL_API)
1415
target_compile_definitions(KernelAndProgramTests_preview PRIVATE __SYCL_INTERNAL_API __INTEL_PREVIEW_BREAKING_CHANGES)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//==-------------------------- KernelInfoShortcuts.cpp -------------------==//
2+
//
3+
// Unit test to ensure get_kernel_info for a device queries/uses kernel bundle
4+
// for that specific device only and doesn't trigger builds for all devices.
5+
//
6+
7+
#include <helpers/MockDeviceImage.hpp>
8+
#include <helpers/MockKernelInfo.hpp>
9+
#include <helpers/ScopedEnvVar.hpp>
10+
#include <helpers/UrMock.hpp>
11+
#include <sycl/sycl.hpp>
12+
13+
#include <gtest/gtest.h>
14+
15+
using namespace sycl;
16+
using namespace sycl::unittest;
17+
18+
class ShortcutKernelInfoTestKernel;
19+
MOCK_INTEGRATION_HEADER(ShortcutKernelInfoTestKernel)
20+
21+
static int ProgramBuildCounter = 0;
22+
/// Replacement callback for the program-build entry point.
///
/// Records that a build was requested by bumping ProgramBuildCounter and
/// reports success without performing any work. The parameter pack is not
/// inspected.
static ur_result_t redefinedurProgramBuild(void *pParams) {
  (void)pParams; // The build arguments are irrelevant for counting.
  ProgramBuildCounter += 1;
  return UR_RESULT_SUCCESS;
}
26+
27+
static ur_result_t redefinedDeviceGet(void *pParams) {
28+
auto params = *static_cast<ur_device_get_params_t *>(pParams);
29+
if (*params.ppNumDevices) {
30+
**params.ppNumDevices = 2; // two devices total
31+
return UR_RESULT_SUCCESS;
32+
}
33+
if (*params.pphDevices) {
34+
// provide two mock device handles
35+
(*params.pphDevices)[0] = reinterpret_cast<ur_device_handle_t>(0x1);
36+
(*params.pphDevices)[1] = reinterpret_cast<ur_device_handle_t>(0x2);
37+
}
38+
return UR_RESULT_SUCCESS;
39+
}
40+
41+
ur_result_t redefinedurKernelGetGroupInfo(void *pParams) {
42+
return UR_RESULT_SUCCESS;
43+
}
44+
45+
TEST(ShortcutKernelInfo, QueryInfoForSingleDevice) {
46+
unittest::UrMock<> Mock;
47+
static sycl::unittest::MockDeviceImage DevImage =
48+
sycl::unittest::generateDefaultImage({"ShortcutKernelInfoTestKernel"});
49+
static sycl::unittest::MockDeviceImageArray<1> DevImageArray = {&DevImage};
50+
51+
mock::getCallbacks().set_replace_callback("urDeviceGet", &redefinedDeviceGet);
52+
mock::getCallbacks().set_replace_callback("urProgramBuildExp",
53+
&redefinedurProgramBuild);
54+
mock::getCallbacks().set_replace_callback("urKernelGetGroupInfo",
55+
&redefinedurKernelGetGroupInfo);
56+
57+
platform Plt = platform();
58+
std::vector<device> Devs = Plt.get_devices();
59+
ASSERT_GE(Devs.size(), 2u) << "Test requires at least 2 devices";
60+
context Ctx = context(Devs);
61+
queue Queue = queue(Ctx, Devs[0]);
62+
63+
// Query kernel info for the first device only
64+
ProgramBuildCounter = 0;
65+
sycl::ext::oneapi::get_kernel_info<
66+
ShortcutKernelInfoTestKernel,
67+
sycl::info::kernel_device_specific::work_group_size>(Ctx, Devs[0]);
68+
sycl::ext::oneapi::get_kernel_info<
69+
ShortcutKernelInfoTestKernel,
70+
sycl::info::kernel_device_specific::work_group_size>(Queue);
71+
72+
EXPECT_EQ(ProgramBuildCounter, 1);
73+
}

0 commit comments

Comments
 (0)