Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UR] Track handles from enqueue functions in validation layer #17855

Open
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions unified-runtime/scripts/templates/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,18 +1890,22 @@ def _get_create_get_retain_release_functions(specs, namespace, tags):
funcs.append(make_func_name(namespace, tags, obj))

create_suffixes = r"(Create[A-Za-z]*){1}$"
enqueue_prefixes = r"(Enqueue){1}"
get_suffixes = r"(Get){1}$"
retain_suffixes = r"(Retain){1}$"
release_suffixes = r"(Release){1}$"
common_prefix = r"^" + namespace

create_exp = common_prefix + r"[A-Za-z]+" + create_suffixes
enqueue_exp = common_prefix + enqueue_prefixes + r"[A-Za-z]+$"
get_exp = common_prefix + r"[A-Za-z]+" + get_suffixes
retain_exp = common_prefix + r"[A-Za-z]+" + retain_suffixes
release_exp = common_prefix + r"[A-Za-z]+" + release_suffixes

create_funcs, get_funcs, retain_funcs, release_funcs = (
list(filter(lambda f: re.match(create_exp, f), funcs)),
list(
filter(lambda f: re.match(create_exp, f) or re.match(enqueue_exp, f), funcs)
),
list(filter(lambda f: re.match(get_exp, f), funcs)),
list(filter(lambda f: re.match(retain_exp, f), funcs)),
list(filter(lambda f: re.match(release_exp, f), funcs)),
Expand Down Expand Up @@ -1934,10 +1938,17 @@ def get_handle_create_get_retain_release_functions(specs, namespace, tags):
continue

class_type = subt(namespace, tags, h["class"])
create_funcs = list(filter(lambda f: class_type in f, funcs["create"]))
get_funcs = list(filter(lambda f: class_type in f, funcs["get"]))
retain_funcs = list(filter(lambda f: class_type in f, funcs["retain"]))
release_funcs = list(filter(lambda f: class_type in f, funcs["release"]))

prefixes = [class_type]
if class_type == namespace + "Event":
prefixes.append(namespace + "Enqueue")
# Functions prefixed with $xEnqueue are also 'create' functions for event handles

has_prefix = lambda f: any(p in f for p in prefixes)
create_funcs = list(filter(has_prefix, funcs["create"]))
get_funcs = list(filter(has_prefix, funcs["get"]))
retain_funcs = list(filter(has_prefix, funcs["retain"]))
release_funcs = list(filter(has_prefix, funcs["release"]))

record = {}
record["handle"] = subt(namespace, tags, h["name"])
Expand All @@ -1953,7 +1964,7 @@ def get_handle_create_get_retain_release_functions(specs, namespace, tags):

"""
Public:
returns a list of objects representing functions that accept $x_queue_handle_t as a first param
returns a list of objects representing functions that accept $x_queue_handle_t as a first param
"""


Expand Down
11 changes: 8 additions & 3 deletions unified-runtime/scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ from templates import helper as th
handle_create_get_retain_release_funcs=th.get_handle_create_get_retain_release_functions(specs, n, tags)
%>/*
*
* Copyright (C) 2023-2024 Intel Corporation
* Copyright (C) 2023-2025 Intel Corporation
*
* Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
* Exceptions.
Expand Down Expand Up @@ -117,9 +117,14 @@ namespace ur_validation_layer
<%
tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None)
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
is_handle_to_event = ("_event_handle_t" in tp['type'])
%>
%if func_name in tp_handle_funcs['create']:
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS
%if is_handle_to_event:
&& ${tp['name']}
%endif
)
{
getContext()->refCountContext->createRefCount(*${tp['name']});
}
Expand Down Expand Up @@ -236,7 +241,7 @@ namespace ur_validation_layer
if (enableLeakChecking) {
getContext()->refCountContext->logInvalidReferences();
}

return ${X}_RESULT_SUCCESS;
}

Expand Down
Loading
Loading