diff --git a/loki/transformations/tests/test_utilities.py b/loki/transformations/tests/test_utilities.py index bd2af3b73..56e4259c1 100644 --- a/loki/transformations/tests/test_utilities.py +++ b/loki/transformations/tests/test_utilities.py @@ -12,14 +12,15 @@ symbols as sym, FindVariables, FindInlineCalls, SubstituteExpressions ) from loki.frontend import available_frontends, OMNI -from loki.ir import nodes as ir, FindNodes, pragmas_attached +from loki.ir import nodes as ir, FindNodes, pragmas_attached, CallStatement from loki.types import BasicType from loki.transformations.utilities import ( single_variable_declaration, recursive_expression_map_update, convert_to_lower_case, replace_intrinsics, rename_variables, get_integer_variable, get_loop_bounds, is_driver_loop, - find_driver_loops, get_local_arrays, check_routine_pragmas + find_driver_loops, get_local_arrays, check_routine_pragmas, + RemoveDuplicateArgs ) @@ -546,3 +547,132 @@ def test_transform_utilites_check_routine_pragmas(frontend, tmp_path): assert check_routine_pragmas(module['test_acc_seq'], directive=None) assert check_routine_pragmas(module['test_loki_seq'], directive=None) assert check_routine_pragmas(module['test_acc_vec'], directive='openacc') + +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('pass_as_kwarg', (True, False)) +@pytest.mark.parametrize('recurse_to_kernels', (True, False)) +@pytest.mark.parametrize('rename_common', (True, False)) +def test_utilities_remove_duplicate_args(tmp_path, frontend, pass_as_kwarg, recurse_to_kernels, rename_common): + """ + Test lowering constant array indices + """ + fcode_driver = f""" +subroutine driver(nlon,nlev,nb,var) + use kernel_mod, only: kernel + implicit none + integer, intent(in) :: nlon,nlev,nb + real, intent(inout) :: var(nlon,nlev,5,nb) + integer :: ibl + integer :: offset + integer :: some_val + integer :: loop_start, loop_end + loop_start = 2 + loop_end = nb + some_val = 0 + offset = 1 + !$omp test + do ibl=loop_start, loop_end + call kernel(nlon,nlev,{'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl), {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl), {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl), {'icend=' if pass_as_kwarg else ''}offset, {'lstart=' if pass_as_kwarg else ''}loop_start, {'lend=' if pass_as_kwarg else ''}loop_end, {'kend=' if pass_as_kwarg else ''}nlev) + call kernel(nlon,nlev,{'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl), {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl), {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl), {'icend=' if pass_as_kwarg else ''}offset, {'lstart=' if pass_as_kwarg else ''}loop_start, {'lend=' if pass_as_kwarg else ''}loop_end, {'kend=' if pass_as_kwarg else ''}nlev) + enddo +end subroutine driver +""" + + fcode_kernel = """ +module kernel_mod +implicit none +contains +subroutine kernel(nlon,nlev,var1,var2,another_var,icend,lstart,lend,kend) + use compute_mod, only: compute + implicit none + integer, intent(in) :: nlon,nlev,icend,lstart,lend,kend + real, intent(inout) :: var1(nlon,nlev) + real, intent(inout) :: var2(nlon,nlev) + real, intent(inout) :: another_var(nlon,nlev,4) + integer :: jk, jl, jt + var1(:,:) = 0. + do jk = 1,kend + do jl = 1, nlon + var1(jl, jk) = 0. + var2(jl, jk) = 1.0 + do jt= 1,4 + another_var(jl, jk, jt) = 0.0 + end do + end do + end do + call compute(nlon,nlev,var1, var2) + call compute(nlon,nlev,var1, var2) +end subroutine kernel +end module kernel_mod +""" + + fcode_nested_kernel = """ +module compute_mod +implicit none +contains +subroutine compute(nlon,nlev,b_var,a_var) + implicit none + integer, intent(in) :: nlon,nlev + real, intent(inout) :: b_var(nlon,nlev) + real, intent(inout) :: a_var(nlon,nlev) + b_var(:,:) = 0. + a_var(:,:) = 1.0 +end subroutine compute +end module compute_mod +""" + + nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path]) + kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path]) + driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path]) + + kwargs = {'recurse_to_kernels': recurse_to_kernels, 'rename_common': rename_common} + + RemoveDuplicateArgs(**kwargs).apply(driver, role='driver', targets=('kernel',)) + RemoveDuplicateArgs(**kwargs).apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) + RemoveDuplicateArgs(**kwargs).apply(nested_kernel_mod['compute'], role='kernel') + + # driver + kernel_var_name = 'var' if rename_common else 'var1' + kernel_calls = FindNodes(CallStatement).visit(driver.body) + for kernel_call in kernel_calls: + if pass_as_kwarg: + # print(f"kernel_call.kwarguments: {list(kernel_call.kwarguments)}") + assert (kernel_var_name, 'var(:, :, 1, ibl)') in kernel_call.kwarguments + assert ('var2', 'var(:, :, 1, ibl)') not in kernel_call.kwarguments + arg1 = kernel_call.kwarguments[0][1] + arg2 = kernel_call.kwarguments[1][1] + else: + assert 'var(:, :, 1, ibl)' in kernel_call.arguments + assert 'var2(:, :, 1, ibl)' not in kernel_call.arguments + arg1 = kernel_call.arguments[2] + arg2 = kernel_call.arguments[3] + # print(f"arg1: {arg1} | arg2: {arg2}") + assert arg1.dimensions == (':', ':', '1', 'ibl') + assert arg2.dimensions == (':', ':', '2:5', 'ibl') + # kernel + kernel_vars = kernel_mod['kernel'].variable_map + kernel_args = [arg.name.lower() for arg in kernel_mod['kernel'].arguments] + assert kernel_var_name in kernel_args + assert 'var2' not in kernel_args + assert 'var2' not in kernel_vars + assert kernel_vars[kernel_var_name].shape == ('nlon', 'nlev') + assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 4) + compute_calls = FindNodes(CallStatement).visit(kernel_mod['kernel'].body) + for compute_call in compute_calls: + assert kernel_var_name in compute_call.arguments + assert 'var2' not in compute_call.arguments + # nested_kernel + nested_kernel = nested_kernel_mod['compute'] + nested_kernel_vars = nested_kernel.variable_map + nested_kernel_args = [arg.name.lower() for arg in nested_kernel.arguments] + nested_kernel_var_name = 'var' if rename_common else 'b_var' + if recurse_to_kernels: + assert nested_kernel_var_name in nested_kernel_args + assert 'a_var' not in nested_kernel_args + assert nested_kernel_var_name in nested_kernel_vars + assert 'a_var' not in nested_kernel_vars + else: + assert 'b_var' in nested_kernel_args + assert 'a_var' in nested_kernel_args + assert 'b_var' in nested_kernel_vars + assert 'a_var' in nested_kernel_vars diff --git a/loki/transformations/utilities.py b/loki/transformations/utilities.py index 5eb277324..0d460b16c 100644 --- a/loki/transformations/utilities.py +++ b/loki/transformations/utilities.py @@ -9,9 +9,12 @@ Collection of utility routines to deal with general language conversion. """ +import os import platform from collections import defaultdict +import itertools as it from pymbolic.primitives import Expression +from loki.batch import Transformation, ProcedureItem from loki.expression import ( symbols as sym, FindVariables, FindInlineCalls, FindLiterals, SubstituteExpressions, SubstituteExpressionsMapper, ExpressionFinder, @@ -19,11 +22,11 @@ ) from loki.ir import ( nodes as ir, Import, TypeDef, VariableDeclaration, - StatementFunction, Transformer, FindNodes + StatementFunction, Transformer, FindNodes, CallStatement ) from loki.module import Module from loki.subroutine import Subroutine -from loki.tools import CaseInsensitiveDict, as_tuple +from loki.tools import CaseInsensitiveDict, as_tuple, flatten from loki.types import SymbolAttributes, BasicType, DerivedType, ProcedureType @@ -32,10 +35,167 @@ 'sanitise_imports', 'replace_selected_kind', 'single_variable_declaration', 'recursive_expression_map_update', 'get_integer_variable', 'get_loop_bounds', 'find_driver_loops', - 'get_local_arrays', 'check_routine_pragmas' + 'get_local_arrays', 'check_routine_pragmas', 'remove_duplicate_args', + 'modify_variable_declarations', 'RemoveDuplicateArgs', ] +class RemoveDuplicateArgs(Transformation): + """ + Transformation to remove duplicate arguments for both caller + and callee. + + .. warning:: + this won't work properly for multiple calls to the same routine + with differing duplicate arguments + + Parameters + ---------- + recurse_to_kernels : bool, optional + Remove duplicate arguments only at the driver level or recurse to + (nested) kernels (Default: `True`). + rename_common : bool, optional + Try to rename duplicate arguments by finding a common pattern + in those names (Default: `False`). + """ + + # This trafo only operates on procedures + item_filter = (ProcedureItem,) + + def __init__(self, recurse_to_kernels=True, rename_common=False): + self.recurse_to_kernels = recurse_to_kernels + self.rename_common = rename_common + + def transform_subroutine(self, routine, **kwargs): + role = kwargs['role'] + if role == 'driver' or self.recurse_to_kernels: + remove_duplicate_args(routine, rename_common=self.rename_common) + +def remove_duplicate_args(routine, rename_common=False): + """ + Utility to remove duplicate arguments for both caller + and callee. + + .. warning:: + this won't work properly for multiple calls to the same routine + with differing duplicate arguments + + Parameters + ---------- + routine : :any:`Subroutine` + The subroutine to be transformed. + rename_common : bool, optional + Try to rename duplicate arguments by finding a common pattern + in those names (Default: `False`). + """ + + def remove_duplicate_args_call(call): + arg_map = {} + for routine_arg, call_arg in call.arg_iter(): + arg_map.setdefault(call_arg, []).append(routine_arg) + # filter duplicate kwargs (comparing to the other kwarguments) + _new_kwargs = as_tuple(list(_)[0] for g, _ in it.groupby(call.kwarguments, key=lambda x: x[1])) + # filter duplicate kwargs (comparing to the arguments) + new_kwargs = tuple(kwarg for kwarg in _new_kwargs if kwarg[1] not in call.arguments) + # (filter duplicate arguments and) update call + call._update(arguments=as_tuple(dict.fromkeys(call.arguments)), kwarguments=new_kwargs) + return arg_map + + def modify_callee(callee, callee_arg_map): + combine = [routine_args for call_arg, routine_args in callee_arg_map.items() if len(routine_args) > 1] + matches = [os.path.commonprefix([str(elem.name) for elem in args]).rstrip('_') + or os.path.commonprefix([str(elem.name)[::-1] for elem in args]).rstrip('_')[::-1] for args in combine] + rename_common_map = {combine[i][0].name: matches[i] for i in range(len(combine)) if matches[i] != ''}\ + if rename_common else {} + redundant = flatten([routine_args[1:] for routine_args in combine]) + combine_map = {routine_args[0]: as_tuple(routine_args[1:]) for routine_args in combine} + arg_map = {arg.name: rename_common_map[common_arg.name] + if common_arg.name in rename_common_map and rename_common_map[common_arg.name] is not None + else common_arg.name + for common_arg, redundant_args in combine_map.items() for arg in redundant_args} + # remove duplicates from callee.arguments + new_routine_args = tuple(arg for arg in callee.arguments if arg not in redundant) + # rename if common name is possible + new_routine_args = as_tuple(arg.clone(name=rename_common_map[arg.name]) + if arg.name in rename_common_map else arg for arg in new_routine_args) + callee.arguments = new_routine_args + + # rename usage/occurences in callee.body + var_map = {} + variables = FindVariables(unique=False).visit(callee.body) + var_map = {var: var.clone(name=arg_map[var.name]) for var in variables if var.name in arg_map} + var_map.update({var: var.clone(name=rename_common_map[var.name]) for var in variables + if var.name in rename_common_map}) + callee.body = SubstituteExpressions(var_map).visit(callee.body) + # modify the variable declarations, thus remove redundant variable declarations and possibly rename + modify_variable_declarations(callee, remove_symbols=redundant, rename_symbols=rename_common_map) + # store the information for possibly later renaming kwarguments on caller side + return rename_common_map + + def rename_kwarguments(relevant_calls, rename_common_map_routine): + for call in relevant_calls: + kwarguments = call.kwarguments + if kwarguments: + call_name = str(call.routine.name).lower() + new_kwargs = as_tuple((rename_common_map_routine[call_name][kw[0]], kw[1]) + if kw[0] in rename_common_map_routine[call_name] else kw for kw in kwarguments) + call._update(kwarguments=new_kwargs) + + calls = FindNodes(CallStatement).visit(routine.body) + call_arg_map = {} + relevant_calls = [] + # adapt call statements (and remove duplicate args/kwargs) + for call in calls: + if call.routine is BasicType.DEFERRED: + continue + call_arg_map[call.routine] = remove_duplicate_args_call(call) + relevant_calls.append(call) + rename_common_map_routine = {} + # modify/adapt callees + for callee, callee_arg_map in call_arg_map.items(): + rename_common_map_routine[str(callee.name).lower()] = modify_callee(callee, callee_arg_map) + # handle possibly renamed kwarguments on caller side + if rename_common: + rename_kwarguments(relevant_calls, rename_common_map_routine) + + +def modify_variable_declarations(routine, remove_symbols=(), rename_symbols=None): + """ + Utility to modify variable declarations by either removing symbols or renaming + symbols. + + .. note:: + This utility only works on the variable declarations itself and + won't modify variable/symbol usages elsewhere! + + Parameters + ---------- + routine : :any:`Subroutine` + The subroutine to be transformed. + remove_symbols : list, tuple + List of symbols for which their declaration should be removed. + rename_symbols : dict + Dict/Map of symbols for which their declaration should be renamed. + """ + rename_symbols = rename_symbols if rename_symbols is not None else {} + var_decls = FindNodes(VariableDeclaration).visit(routine.spec) + remove_symbol_names = [var.name.lower() for var in remove_symbols] + decl_map = {} + already_declared = () + for decl in var_decls: + symbols = [symbol for symbol in decl.symbols if symbol.name.lower() not in remove_symbol_names] + symbols = [symbol.clone(name=rename_symbols[symbol.name]) + if symbol.name in rename_symbols else symbol for symbol in symbols] + symbols = [symbol for symbol in symbols if not symbol.name.lower() in already_declared] + already_declared += tuple(symbol.name.lower() for symbol in symbols) + if symbols and symbols != decl.symbols: + decl_map[decl] = decl.clone(symbols=as_tuple(symbols)) + else: + if not symbols: + decl_map[decl] = None + routine.spec = Transformer(decl_map).visit(routine.spec) + + def single_variable_declaration(routine, variables=None, group_by_shape=False): """ Modify/extend variable declarations to