Skip to content

Commit

Permalink
Utility to remove duplicate arguments for calls and callees
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSt98 committed Aug 28, 2024
1 parent 1f87624 commit 40fda0b
Show file tree
Hide file tree
Showing 2 changed files with 295 additions and 5 deletions.
134 changes: 132 additions & 2 deletions loki/transformations/tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
symbols as sym, FindVariables, FindInlineCalls, SubstituteExpressions
)
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes, pragmas_attached
from loki.ir import nodes as ir, FindNodes, pragmas_attached, CallStatement
from loki.types import BasicType

from loki.transformations.utilities import (
single_variable_declaration, recursive_expression_map_update,
convert_to_lower_case, replace_intrinsics, rename_variables,
get_integer_variable, get_loop_bounds, is_driver_loop,
find_driver_loops, get_local_arrays, check_routine_pragmas
find_driver_loops, get_local_arrays, check_routine_pragmas,
RemoveDuplicateArgs
)


Expand Down Expand Up @@ -546,3 +547,132 @@ def test_transform_utilites_check_routine_pragmas(frontend, tmp_path):
assert check_routine_pragmas(module['test_acc_seq'], directive=None)
assert check_routine_pragmas(module['test_loki_seq'], directive=None)
assert check_routine_pragmas(module['test_acc_vec'], directive='openacc')

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('pass_as_kwarg', (True, False))
@pytest.mark.parametrize('recurse_to_kernels', (True, False))
@pytest.mark.parametrize('rename_common', (True, False))
def test_utilities_remove_duplicate_args(tmp_path, frontend, pass_as_kwarg, recurse_to_kernels, rename_common):
"""
Test lowering constant array indices
"""
fcode_driver = f"""
subroutine driver(nlon,nlev,nb,var)
use kernel_mod, only: kernel
implicit none
integer, intent(in) :: nlon,nlev,nb
real, intent(inout) :: var(nlon,nlev,5,nb)
integer :: ibl
integer :: offset
integer :: some_val
integer :: loop_start, loop_end
loop_start = 2
loop_end = nb
some_val = 0
offset = 1
!$omp test
do ibl=loop_start, loop_end
call kernel(nlon,nlev,{'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl), {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl), {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl), {'icend=' if pass_as_kwarg else ''}offset, {'lstart=' if pass_as_kwarg else ''}loop_start, {'lend=' if pass_as_kwarg else ''}loop_end, {'kend=' if pass_as_kwarg else ''}nlev)
call kernel(nlon,nlev,{'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl), {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl), {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl), {'icend=' if pass_as_kwarg else ''}offset, {'lstart=' if pass_as_kwarg else ''}loop_start, {'lend=' if pass_as_kwarg else ''}loop_end, {'kend=' if pass_as_kwarg else ''}nlev)
enddo
end subroutine driver
"""

fcode_kernel = """
module kernel_mod
implicit none
contains
subroutine kernel(nlon,nlev,var1,var2,another_var,icend,lstart,lend,kend)
use compute_mod, only: compute
implicit none
integer, intent(in) :: nlon,nlev,icend,lstart,lend,kend
real, intent(inout) :: var1(nlon,nlev)
real, intent(inout) :: var2(nlon,nlev)
real, intent(inout) :: another_var(nlon,nlev,4)
integer :: jk, jl, jt
var1(:,:) = 0.
do jk = 1,kend
do jl = 1, nlon
var1(jl, jk) = 0.
var2(jl, jk) = 1.0
do jt= 1,4
another_var(jl, jk, jt) = 0.0
end do
end do
end do
call compute(nlon,nlev,var1, var2)
call compute(nlon,nlev,var1, var2)
end subroutine kernel
end module kernel_mod
"""

fcode_nested_kernel = """
module compute_mod
implicit none
contains
subroutine compute(nlon,nlev,b_var,a_var)
implicit none
integer, intent(in) :: nlon,nlev
real, intent(inout) :: b_var(nlon,nlev)
real, intent(inout) :: a_var(nlon,nlev)
b_var(:,:) = 0.
a_var(:,:) = 1.0
end subroutine compute
end module compute_mod
"""

nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path])
kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path])
driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path])

kwargs = {'recurse_to_kernels': recurse_to_kernels, 'rename_common': rename_common}

RemoveDuplicateArgs(**kwargs).apply(driver, role='driver', targets=('kernel',))
RemoveDuplicateArgs(**kwargs).apply(kernel_mod['kernel'], role='kernel', targets=('compute',))
RemoveDuplicateArgs(**kwargs).apply(nested_kernel_mod['compute'], role='kernel')

# driver
kernel_var_name = 'var' if rename_common else 'var1'
kernel_calls = FindNodes(CallStatement).visit(driver.body)
for kernel_call in kernel_calls:
if pass_as_kwarg:
# print(f"kernel_call.kwarguments: {list(kernel_call.kwarguments)}")
assert (kernel_var_name, 'var(:, :, 1, ibl)') in kernel_call.kwarguments
assert ('var2', 'var(:, :, 1, ibl)') not in kernel_call.kwarguments
arg1 = kernel_call.kwarguments[0][1]
arg2 = kernel_call.kwarguments[1][1]
else:
assert 'var(:, :, 1, ibl)' in kernel_call.arguments
assert 'var2(:, :, 1, ibl)' not in kernel_call.arguments
arg1 = kernel_call.arguments[2]
arg2 = kernel_call.arguments[3]
# print(f"arg1: {arg1} | arg2: {arg2}")
assert arg1.dimensions == (':', ':', '1', 'ibl')
assert arg2.dimensions == (':', ':', '2:5', 'ibl')
# kernel
kernel_vars = kernel_mod['kernel'].variable_map
kernel_args = [arg.name.lower() for arg in kernel_mod['kernel'].arguments]
assert kernel_var_name in kernel_args
assert 'var2' not in kernel_args
assert 'var2' not in kernel_vars
assert kernel_vars[kernel_var_name].shape == ('nlon', 'nlev')
assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 4)
compute_calls = FindNodes(CallStatement).visit(kernel_mod['kernel'].body)
for compute_call in compute_calls:
assert kernel_var_name in compute_call.arguments
assert 'var2' not in compute_call.arguments
# nested_kernel
nested_kernel = nested_kernel_mod['compute']
nested_kernel_vars = nested_kernel.variable_map
nested_kernel_args = [arg.name.lower() for arg in nested_kernel.arguments]
nested_kernel_var_name = 'var' if rename_common else 'b_var'
if recurse_to_kernels:
assert nested_kernel_var_name in nested_kernel_args
assert 'a_var' not in nested_kernel_args
assert nested_kernel_var_name in nested_kernel_vars
assert 'a_var' not in nested_kernel_vars
else:
assert 'b_var' in nested_kernel_args
assert 'a_var' in nested_kernel_args
assert 'b_var' in nested_kernel_vars
assert 'a_var' in nested_kernel_vars
166 changes: 163 additions & 3 deletions loki/transformations/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,24 @@
Collection of utility routines to deal with general language conversion.
"""

import os
import platform
from collections import defaultdict
import itertools as it
from pymbolic.primitives import Expression
from loki.batch import Transformation, ProcedureItem
from loki.expression import (
symbols as sym, FindVariables, FindInlineCalls, FindLiterals,
SubstituteExpressions, SubstituteExpressionsMapper, ExpressionFinder,
ExpressionRetriever, TypedSymbol, MetaSymbol
)
from loki.ir import (
nodes as ir, Import, TypeDef, VariableDeclaration,
StatementFunction, Transformer, FindNodes
StatementFunction, Transformer, FindNodes, CallStatement
)
from loki.module import Module
from loki.subroutine import Subroutine
from loki.tools import CaseInsensitiveDict, as_tuple
from loki.tools import CaseInsensitiveDict, as_tuple, flatten
from loki.types import SymbolAttributes, BasicType, DerivedType, ProcedureType


Expand All @@ -32,10 +35,167 @@
'sanitise_imports', 'replace_selected_kind',
'single_variable_declaration', 'recursive_expression_map_update',
'get_integer_variable', 'get_loop_bounds', 'find_driver_loops',
'get_local_arrays', 'check_routine_pragmas'
'get_local_arrays', 'check_routine_pragmas', 'remove_duplicate_args',
'modify_variable_declarations', 'RemoveDuplicateArgs',
]


class RemoveDuplicateArgs(Transformation):
"""
Transformation to remove duplicate arguments for both caller
and callee.
.. warning::
this won't work properly for multiple calls to the same routine
with differing duplicate arguments
Parameters
----------
recurse_to_kernels : bool, optional
Remove duplicate arguments only at the driver level or recurse to
(nested) kernels (Default: `True`).
rename_common : bool, optional
Try to rename duplicate arguments by finding a common pattern
in those names (Default: `False`).
"""

# This trafo only operates on procedures
item_filter = (ProcedureItem,)

def __init__(self, recurse_to_kernels=True, rename_common=False):
self.recurse_to_kernels = recurse_to_kernels
self.rename_common = rename_common

def transform_subroutine(self, routine, **kwargs):
role = kwargs['role']
if role == 'driver' or self.recurse_to_kernels:
remove_duplicate_args(routine, rename_common=self.rename_common)

def remove_duplicate_args(routine, rename_common=False):
"""
Utility to remove duplicate arguments for both caller
and callee.
.. warning::
this won't work properly for multiple calls to the same routine
with differing duplicate arguments
Parameters
----------
routine : :any:`Subroutine`
The subroutine to be transformed.
rename_common : bool, optional
Try to rename duplicate arguments by finding a common pattern
in those names (Default: `False`).
"""

def remove_duplicate_args_call(call):
arg_map = {}
for routine_arg, call_arg in call.arg_iter():
arg_map.setdefault(call_arg, []).append(routine_arg)
# filter duplicate kwargs (comparing to the other kwarguments)
_new_kwargs = as_tuple(list(_)[0] for g, _ in it.groupby(call.kwarguments, key=lambda x: x[1]))
# filter duplicate kwargs (comparing to the arguments)
new_kwargs = tuple(kwarg for kwarg in _new_kwargs if kwarg[1] not in call.arguments)
# (filter duplicate arguments and) update call
call._update(arguments=as_tuple(dict.fromkeys(call.arguments)), kwarguments=new_kwargs)
return arg_map

def modify_callee(callee, callee_arg_map):
combine = [routine_args for call_arg, routine_args in callee_arg_map.items() if len(routine_args) > 1]
matches = [os.path.commonprefix([str(elem.name) for elem in args]).rstrip('_')
or os.path.commonprefix([str(elem.name)[::-1] for elem in args]).rstrip('_')[::-1] for args in combine]
rename_common_map = {combine[i][0].name: matches[i] for i in range(len(combine)) if matches[i] != ''}\
if rename_common else {}
redundant = flatten([routine_args[1:] for routine_args in combine])
combine_map = {routine_args[0]: as_tuple(routine_args[1:]) for routine_args in combine}
arg_map = {arg.name: rename_common_map[common_arg.name]
if common_arg.name in rename_common_map and rename_common_map[common_arg.name] is not None
else common_arg.name
for common_arg, redundant_args in combine_map.items() for arg in redundant_args}
# remove duplicates from callee.arguments
new_routine_args = tuple(arg for arg in callee.arguments if arg not in redundant)
# rename if common name is possible
new_routine_args = as_tuple(arg.clone(name=rename_common_map[arg.name])
if arg.name in rename_common_map else arg for arg in new_routine_args)
callee.arguments = new_routine_args

# rename usage/occurences in callee.body
var_map = {}
variables = FindVariables(unique=False).visit(callee.body)
var_map = {var: var.clone(name=arg_map[var.name]) for var in variables if var.name in arg_map}
var_map.update({var: var.clone(name=rename_common_map[var.name]) for var in variables
if var.name in rename_common_map})
callee.body = SubstituteExpressions(var_map).visit(callee.body)
# modify the variable declarations, thus remove redundant variable declarations and possibly rename
modify_variable_declarations(callee, remove_symbols=redundant, rename_symbols=rename_common_map)
# store the information for possibly later renaming kwarguments on caller side
return rename_common_map

def rename_kwarguments(relevant_calls, rename_common_map_routine):
for call in relevant_calls:
kwarguments = call.kwarguments
if kwarguments:
call_name = str(call.routine.name).lower()
new_kwargs = as_tuple((rename_common_map_routine[call_name][kw[0]], kw[1])
if kw[0] in rename_common_map_routine[call_name] else kw for kw in kwarguments)
call._update(kwarguments=new_kwargs)

calls = FindNodes(CallStatement).visit(routine.body)
call_arg_map = {}
relevant_calls = []
# adapt call statements (and remove duplicate args/kwargs)
for call in calls:
if call.routine is BasicType.DEFERRED:
continue
call_arg_map[call.routine] = remove_duplicate_args_call(call)
relevant_calls.append(call)
rename_common_map_routine = {}
# modify/adapt callees
for callee, callee_arg_map in call_arg_map.items():
rename_common_map_routine[str(callee.name).lower()] = modify_callee(callee, callee_arg_map)
# handle possibly renamed kwarguments on caller side
if rename_common:
rename_kwarguments(relevant_calls, rename_common_map_routine)


def modify_variable_declarations(routine, remove_symbols=(), rename_symbols=None):
"""
Utility to modify variable declarations by either removing symbols or renaming
symbols.
.. note::
This utility only works on the variable declarations itself and
won't modify variable/symbol usages elsewhere!
Parameters
----------
routine : :any:`Subroutine`
The subroutine to be transformed.
remove_symbols : list, tuple
List of symbols for which their declaration should be removed.
rename_symbols : dict
Dict/Map of symbols for which their declaration should be renamed.
"""
rename_symbols = rename_symbols if rename_symbols is not None else {}
var_decls = FindNodes(VariableDeclaration).visit(routine.spec)
remove_symbol_names = [var.name.lower() for var in remove_symbols]
decl_map = {}
already_declared = ()
for decl in var_decls:
symbols = [symbol for symbol in decl.symbols if symbol.name.lower() not in remove_symbol_names]
symbols = [symbol.clone(name=rename_symbols[symbol.name])
if symbol.name in rename_symbols else symbol for symbol in symbols]
symbols = [symbol for symbol in symbols if not symbol.name.lower() in already_declared]
already_declared += tuple(symbol.name.lower() for symbol in symbols)
if symbols and symbols != decl.symbols:
decl_map[decl] = decl.clone(symbols=as_tuple(symbols))
else:
if not symbols:
decl_map[decl] = None
routine.spec = Transformer(decl_map).visit(routine.spec)


def single_variable_declaration(routine, variables=None, group_by_shape=False):
"""
Modify/extend variable declarations to
Expand Down

0 comments on commit 40fda0b

Please sign in to comment.