Skip to content

Commit df3febb

Browse files
committed
[UR] Track handles from enqueue functions in validation layer
On handle release, the validation layer reports event handles obtained via enqueue functions as invalid. This patch should fix this behavior.
1 parent c7dbbef commit df3febb

File tree

5 files changed

+295
-12
lines changed

5 files changed

+295
-12
lines changed

Diff for: unified-runtime/scripts/templates/helper.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -1890,18 +1890,22 @@ def _get_create_get_retain_release_functions(specs, namespace, tags):
18901890
funcs.append(make_func_name(namespace, tags, obj))
18911891

18921892
create_suffixes = r"(Create[A-Za-z]*){1}$"
1893+
enqueue_prefixes = r"(Enqueue){1}"
18931894
get_suffixes = r"(Get){1}$"
18941895
retain_suffixes = r"(Retain){1}$"
18951896
release_suffixes = r"(Release){1}$"
18961897
common_prefix = r"^" + namespace
18971898

18981899
create_exp = common_prefix + r"[A-Za-z]+" + create_suffixes
1900+
enqueue_exp = common_prefix + enqueue_prefixes + r"[A-Za-z]+$"
18991901
get_exp = common_prefix + r"[A-Za-z]+" + get_suffixes
19001902
retain_exp = common_prefix + r"[A-Za-z]+" + retain_suffixes
19011903
release_exp = common_prefix + r"[A-Za-z]+" + release_suffixes
19021904

19031905
create_funcs, get_funcs, retain_funcs, release_funcs = (
1904-
list(filter(lambda f: re.match(create_exp, f), funcs)),
1906+
list(
1907+
filter(lambda f: re.match(create_exp, f) or re.match(enqueue_exp, f), funcs)
1908+
),
19051909
list(filter(lambda f: re.match(get_exp, f), funcs)),
19061910
list(filter(lambda f: re.match(retain_exp, f), funcs)),
19071911
list(filter(lambda f: re.match(release_exp, f), funcs)),
@@ -1934,10 +1938,17 @@ def get_handle_create_get_retain_release_functions(specs, namespace, tags):
19341938
continue
19351939

19361940
class_type = subt(namespace, tags, h["class"])
1937-
create_funcs = list(filter(lambda f: class_type in f, funcs["create"]))
1938-
get_funcs = list(filter(lambda f: class_type in f, funcs["get"]))
1939-
retain_funcs = list(filter(lambda f: class_type in f, funcs["retain"]))
1940-
release_funcs = list(filter(lambda f: class_type in f, funcs["release"]))
1941+
1942+
prefixes = [class_type]
1943+
if class_type == namespace + "Event":
1944+
prefixes.append(namespace + "Enqueue")
1945+
# Functions prefixed with $xEnqueue are also 'create' functions for event handles
1946+
1947+
has_prefix = lambda f: any(p in f for p in prefixes)
1948+
create_funcs = list(filter(has_prefix, funcs["create"]))
1949+
get_funcs = list(filter(has_prefix, funcs["get"]))
1950+
retain_funcs = list(filter(has_prefix, funcs["retain"]))
1951+
release_funcs = list(filter(has_prefix, funcs["release"]))
19411952

19421953
record = {}
19431954
record["handle"] = subt(namespace, tags, h["name"])
@@ -1953,7 +1964,7 @@ def get_handle_create_get_retain_release_functions(specs, namespace, tags):
19531964

19541965
"""
19551966
Public:
1956-
returns a list of objects representing functions that accept $x_queue_handle_t as a first param
1967+
returns a list of objects representing functions that accept $x_queue_handle_t as a first param
19571968
"""
19581969

19591970

Diff for: unified-runtime/scripts/templates/valddi.cpp.mako

+8-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ from templates import helper as th
1111
handle_create_get_retain_release_funcs=th.get_handle_create_get_retain_release_functions(specs, n, tags)
1212
%>/*
1313
*
14-
* Copyright (C) 2023-2024 Intel Corporation
14+
* Copyright (C) 2023-2025 Intel Corporation
1515
*
1616
* Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
1717
* Exceptions.
@@ -117,9 +117,14 @@ namespace ur_validation_layer
117117
<%
118118
tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None)
119119
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
120+
is_handle_to_event = ("_event_handle_t" in tp['type'])
120121
%>
121122
%if func_name in tp_handle_funcs['create']:
122-
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
123+
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS
124+
%if is_handle_to_event:
125+
&& ${tp['name']}
126+
%endif
127+
)
123128
{
124129
getContext()->refCountContext->createRefCount(*${tp['name']});
125130
}
@@ -236,7 +241,7 @@ namespace ur_validation_layer
236241
if (enableLeakChecking) {
237242
getContext()->refCountContext->logInvalidReferences();
238243
}
239-
244+
240245
return ${X}_RESULT_SUCCESS;
241246
}
242247

0 commit comments

Comments
 (0)