Skip to content

Commit

Permalink
#1664: Add tests for ACCKernelDirective with async
Browse files Browse the repository at this point in the history
  • Loading branch information
svalat committed Mar 14, 2023
1 parent 47654db commit 9da051d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
27 changes: 27 additions & 0 deletions src/psyclone/tests/psyir/nodes/acc_directives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,33 @@ def test_acckernelsdirective_gencode(default_present):
" END DO\n"
" !$acc end kernels\n" in code)

# (1/1) Method gen_code
@pytest.mark.parametrize("async_queue", [False, 1, Signature('stream1')])
def test_acckernelsdirective_gencode(async_queue):
'''Check that the gen_code method in the ACCKernelsDirective class
generates the expected code. Use the dynamo0.3 API.
'''
_, info = parse(os.path.join(BASE_PATH, "1_single_invoke.f90"))
psy = PSyFactory(distributed_memory=False).create(info)
sched = psy.invokes.get('invoke_0_testkern_type').schedule

trans = ACCKernelsTrans()
trans.apply(sched, {"async_queue": async_queue})

code = str(psy.gen)
string = ""
if async_queue:
if isinstance(async_queue, int):
string = " async(1)"
elif isinstance(async_queue, Signature):
string = " async(stream1)"
assert (
f" !$acc kernels{string}\n"
f" DO cell=loop0_start,loop0_stop\n" in code)
assert (
" END DO\n"
" !$acc end kernels\n" in code)

def test_acckerneldirective_equality():
''' Test the __eq__ method of ACCKernelsDirective node. '''
Expand Down
17 changes: 16 additions & 1 deletion src/psyclone/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2518,11 +2518,12 @@ def apply(self, node, options=None):
if not options:
options = {}
default_present = options.get("default_present", False)
async_queue = options.get("async_queue", False)

# Create a directive containing the nodes in node_list and insert it.
directive = ACCKernelsDirective(
parent=parent, children=[node.detach() for node in node_list],
default_present=default_present)
default_present=default_present, async_queue=async_queue)

parent.children.insert(start_index, directive)

Expand Down Expand Up @@ -2568,6 +2569,20 @@ def validate(self, nodes, options):
"A kernels transformation must enclose at least one loop or "
"array range but none were found.")

# do not has mixed async
async_queue = None
if options != None:
async_queue = options.get('async_queue', False)
if async_queue != False:
directive_cls = (ACCKernelsDirective, ACCParallelDirective)
for directive in sched.walk(directive_cls):
if directive.async_queue != False and directive.async_queue != async_queue:
raise TransformationError(f"Tried to apply async() while another one is used internally \
with different queue ({async_queue} != {directive.async_queue}) !")
directive = sched.ancestor(directive_cls)
if directive and directive.async_queue != False and directive.async_queue != async_queue:
raise TransformationError(f"Tried to apply async() while another one is used in ancestor \
with different queue ({async_queue} != {directive.async_queue}) !")

class ACCDataTrans(RegionTrans):
'''
Expand Down

0 comments on commit 9da051d

Please sign in to comment.