diff --git a/transformations/tests/sources/projParallelRoutineDispatch/dispatch_routine.F90 b/transformations/tests/sources/projParallelRoutineDispatch/dispatch_routine.F90 index f4e6dfa1a..5b3fbe641 100644 --- a/transformations/tests/sources/projParallelRoutineDispatch/dispatch_routine.F90 +++ b/transformations/tests/sources/projParallelRoutineDispatch/dispatch_routine.F90 @@ -40,14 +40,16 @@ SUBROUTINE DISPATCH_ROUTINE(YDGEOMETRY, YDCPG_BNDS, YDCPG_OPTS, & INSTEP_DEB=1 INSTEP_FIN=1 -!$ACDC PARALLEL,TARGET=OpenMP/OpenMPSingleColumn/OpenACCSingleColumn,NAME=CPPHINP { +!!!$ACDC PARALLEL,TARGET=OpenMP/OpenMPSingleColumn/OpenACCSingleColumn,NAME=CPPHINP { +!$loki parallel PARALLEL,TARGET=OpenMP/OpenMPSingleColumn/OpenACCSingleColumn,NAME=CPPHINP CALL CPPHINP(YDGEOMETRY, YDMODEL, YDCPG_BNDS%KIDIA, YDCPG_BNDS%KFDIA, YDVARS%GEOMETRY%GEMU%T0, & & YDVARS%GEOMETRY%GELAM%T0, YDVARS%U%T0, YDVARS%V%T0, YDVARS%Q%T0, YDVARS%Q%DL, YDVARS%Q%DM, YDVARS%CVGQ%DL, & & YDVARS%CVGQ%DM, YDCPG_PHY0%XYB%RDELP, YDCPG_DYN0%CTY%EVEL, YDVARS%CVGQ%T0, ZRDG_MU0, ZRDG_MU0LU, ZRDG_MU0M, & & ZRDG_MU0N, ZRDG_CVGQ, YDMF_PHYS_SURF%GSD_VF%PZ0F) -!$ACDC } +!$loki end parallel +!!!$ACDC } IF (LHOOK) CALL DR_HOOK('DISPATCH_ROUTINE',1,ZHOOK_HANDLE) diff --git a/transformations/tests/test_parallel_routine_dispatch.py b/transformations/tests/test_parallel_routine_dispatch.py index 2161fbe7b..e4afc3739 100644 --- a/transformations/tests/test_parallel_routine_dispatch.py +++ b/transformations/tests/test_parallel_routine_dispatch.py @@ -11,7 +11,7 @@ import pytest from loki.frontend import available_frontends, OMNI -from loki import Sourcefile +from loki import Sourcefile, FindNodes, CallStatement from transformations.parallel_routine_dispatch import ParallelRoutineDispatchTransformation @@ -22,10 +22,19 @@ def fixture_here(): @pytest.mark.parametrize('frontend', available_frontends(skip=[OMNI])) -def test_parallel_routine_dispatch_parallel_regions(here, frontend): +def test_parallel_routine_dispatch_dr_hook(here, frontend): source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend) + routine = source['dispatch_routine'] + + calls = FindNodes(CallStatement).visit(routine.body) + assert len(calls) == 3 + transformation = ParallelRoutineDispatchTransformation() transformation.apply(source['dispatch_routine']) - assert transformation.dummy_return_value == ['dispatch_routine'] + calls = FindNodes(CallStatement).visit(routine.body) + assert len(calls) == 5 + assert [str(call.name).lower() for call in calls] == [ + 'dr_hook', 'dr_hook', 'cpphinp', 'dr_hook', 'dr_hook' + ] diff --git a/transformations/transformations/parallel_routine_dispatch.py b/transformations/transformations/parallel_routine_dispatch.py index 064a71e55..2639dfd22 100644 --- a/transformations/transformations/parallel_routine_dispatch.py +++ b/transformations/transformations/parallel_routine_dispatch.py @@ -5,6 +5,11 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from loki.expression import symbols as sym +from loki.ir import ( + is_loki_pragma, get_pragma_parameters, pragma_regions_attached, + FindNodes, nodes as ir +) from loki.transform import Transformation __all__ = ['ParallelRoutineDispatchTransformation'] @@ -12,8 +17,42 @@ class ParallelRoutineDispatchTransformation(Transformation): - def __init__(self): - self.dummy_return_value = [] - def transform_subroutine(self, routine, **kwargs): - self.dummy_return_value += [routine.name.lower()] + with pragma_regions_attached(routine): + for region in FindNodes(ir.PragmaRegion).visit(routine.body): + if is_loki_pragma(region.pragma): + self.process_parallel_region(routine, region) + + def process_parallel_region(self, routine, region): + pragma_content = region.pragma.content.split(maxsplit=1) + pragma_content = [entry.split('=', maxsplit=1) for entry in pragma_content[1].split(',')] + pragma_attrs = { + entry[0].lower(): entry[1] if len(entry) == 2 else None + for entry in pragma_content + } + if 'parallel' not in pragma_attrs: + return + + dr_hook_calls = self.create_dr_hook_calls( + routine, pragma_attrs['name'], + sym.Variable(name='ZHOOK_HANDLE_FIELD_API', scope=routine) + ) + + region.prepend(dr_hook_calls[0]) + region.append(dr_hook_calls[1]) + + @staticmethod + def create_dr_hook_calls(scope, cdname, pkey): + dr_hook_calls = [] + for kswitch in (0, 1): + call_stmt = ir.CallStatement( + name=sym.Variable(name='DR_HOOK', scope=scope), + arguments=(sym.StringLiteral(cdname), sym.IntLiteral(kswitch), pkey) + ) + dr_hook_calls += [ + ir.Conditional( + condition=sym.Variable(name='LHOOK', scope=scope), + inline=True, body=(call_stmt,) + ) + ] + return dr_hook_calls