From b44810fdc45c5e89e51be518bd1cc2230d7d19b3 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Mon, 22 Jul 2024 09:22:05 +0000 Subject: [PATCH 1/3] Introduce new transformation 'LowerConstantArrayIndices' to allow to pass/lower constant array indices to kernel(s) calls --- loki/transformations/array_indexing.py | 180 +++++++++++++++++- .../tests/test_array_indexing.py | 141 +++++++++++++- scripts/loki_transform.py | 7 +- 3 files changed, 323 insertions(+), 5 deletions(-) diff --git a/loki/transformations/array_indexing.py b/loki/transformations/array_indexing.py index 3feeb4520..46781769f 100644 --- a/loki/transformations/array_indexing.py +++ b/loki/transformations/array_indexing.py @@ -13,16 +13,19 @@ from itertools import count import operator as op +from loki.batch import Transformation, ProcedureItem from loki.logging import info from loki.analyse import dataflow_analysis_attached from loki.expression import ( - symbols as sym, simplify, symbolic_op, FindVariables, SubstituteExpressions + symbols as sym, simplify, symbolic_op, FindVariables, SubstituteExpressions, + is_constant ) from loki.ir import ( - Assignment, Loop, VariableDeclaration, FindNodes, Transformer + Assignment, Loop, VariableDeclaration, FindNodes, Transformer, nodes as ir ) from loki.tools import as_tuple, CaseInsensitiveDict from loki.types import SymbolAttributes, BasicType +from loki.transformations.inline import inline_constant_parameters __all__ = [ @@ -30,7 +33,8 @@ 'resolve_vector_notation', 'normalize_range_indexing', 'promote_variables', 'promote_nonmatching_variables', 'promotion_dimensions_from_loop_nest', 'demote_variables', - 'flatten_arrays', 'normalize_array_shape_and_access' + 'flatten_arrays', 'normalize_array_shape_and_access', + 'LowerConstantArrayIndices' ] @@ -603,3 +607,173 @@ def new_dims(dim, shape): routine.variables = [v.clone(dimensions=as_tuple(sym.Product(v.shape)), type=v.type.clone(shape=as_tuple(sym.Product(v.shape)))) if isinstance(v, sym.Array) else v for v in routine.variables] + +class LowerConstantArrayIndices(Transformation): + """ + A transformation to pass/lower constant array indices down the call tree. + + For example, the following code: + + .. code-block:: fortran + + subroutine driver(...) + real, intent(inout) :: var(nlon,nlev,5,nb) + do ibl=1,10 + call kernel(var(:, :, 1, ibl), var(:, :, 2:5, ibl)) + end do + end subroutine driver + + subroutine kernel(var1, var2) + real, intent(inout) :: var1(nlon, nlev) + real, intent(inout) :: var2(nlon, nlev, 4) + var1(:, :) = ... + do jk=1,nlev + do jl=1,nlon + var1(jl, jk) = ... + do jt=1,4 + var2(jl, jk, jt) = ... + enddo + enddo + enddo + end subroutine kernel + + is transformed to: + + .. code-block:: fortran + + subroutine driver(...) + real, intent(inout) :: var(nlon,nlev,5,nb) + do ibl=1,10 + call kernel(var(:, :, :, ibl), var(:, :, :, ibl)) + end do + end subroutine driver + + subroutine kernel(var1, var2) + real, intent(inout) :: var1(nlon, nlev, 5) + real, intent(inout) :: var2(nlon, nlev, 5) + var1(:, :, 1) = ... + do jk=1,nlev + do jl=1,nlon + var1(jl, jk, 1) = ... + do jt=1,4 + var2(jl, jk, jt + 2 + -1) = ... + enddo + enddo + enddo + end subroutine kernel + + Parameters + ---------- + recurse_to_kernels: bool + Recurse to kernels, thus lower constant array indices below the driver level for nested + kernel calls (default: `True`). + inline_external_only: bool + Inline only external constant expressions or all of them (default: `False`) + """ + + # This trafo only operates on procedures + item_filter = (ProcedureItem,) + + def __init__(self, recurse_to_kernels=True, inline_external_only=False): + self.recurse_to_kernels = recurse_to_kernels + self.inline_external_only = inline_external_only + + @staticmethod + def explicit_dimensions(routine): + """ + Make dimensions of arrays explicit within :any:`Subroutine` ``routine``. + E.g., convert two-dimensional array ``arr2d`` to ``arr2d(:,:)`` or + ``arr3d`` to ``arr3d(:,:,:)``. + + Parameters + ---------- + routine: :any:`Subroutine` + The subroutine to check + """ + arrays = [var for var in FindVariables(unique=False).visit(routine.body) if isinstance(var, sym.Array)] + array_map = {} + for array in arrays: + if not array.dimensions: + new_dimensions = (sym.RangeIndex((None, None)),) * len(array.shape) + array_map[array] = array.clone(dimensions=new_dimensions) + routine.body = SubstituteExpressions(array_map).visit(routine.body) + + @staticmethod + def is_constant_dim(dim): + """ + Check whether dimension dim is constant, thus, either a constant + value or a constant range index. + + Parameters + ---------- + dim: :py:class:`pymbolic.primitives.Expression` + """ + if is_constant(dim): + return True + if isinstance(dim, sym.RangeIndex)\ + and all(child is not None and is_constant(child) for child in dim.children[:-1]): + return True + return False + + def transform_subroutine(self, routine, **kwargs): + role = kwargs['role'] + targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets', None))) + if role == 'driver' or self.recurse_to_kernels: + inline_constant_parameters(routine, external_only=self.inline_external_only) + self.process(routine, targets) + + def process(self, routine, targets): + dispatched_routines = () + for call in FindNodes(ir.CallStatement).visit(routine.body): + if str(call.name).lower() not in targets: + continue + # make array dimensions explicit + self.explicit_dimensions(call.routine) + routine_vmap = {} + updated_call_args = {} + introduce_index = {} + introduce_offset = {} + # collect and iterate over relevant arguments (being arrays) + relevant_arguments = [arg for arg in call.arg_iter() if isinstance(arg[1], sym.Array)] + for routine_arg, call_arg in relevant_arguments: + # check for constant dimensions + insert_dims = tuple((i, dim) for i, dim in enumerate(call_arg.dimensions) if self.is_constant_dim(dim)) + if insert_dims: + for insert_dim in insert_dims: + new_dims = list(call_arg.dimensions) + new_dims[insert_dim[0]] = sym.RangeIndex((None, None)) + updated_call_args[call_arg.name] = call_arg.clone(dimensions=as_tuple(new_dims)) + if call.routine not in dispatched_routines: + new_dims = list(routine_arg.dimensions) + # dimension is a constant RangeIndex, e.g., '1:3' + if isinstance(insert_dim[1], sym.RangeIndex): + introduce_offset[routine_arg.name] = (insert_dim[0], insert_dim[1].children[0]) + new_dims[insert_dim[0]] = call_arg.shape[insert_dim[0]] + # dimension is a constant literal, e.g., '1' + else: + introduce_index[routine_arg.name] = (insert_dim[0], insert_dim[1]) + new_dims.insert(insert_dim[0], call_arg.shape[insert_dim[0]]) + routine_arg_new_type = routine_arg.type.clone(shape=as_tuple(new_dims)) + routine_vmap[routine_arg] = routine_arg.clone(type=routine_arg_new_type, + dimensions=as_tuple(new_dims)) + # apply changes to call.routine (if this routine has not yet been processed) + if call.routine not in dispatched_routines: + call.routine.spec = SubstituteExpressions(routine_vmap).visit(call.routine.spec) + vmap = {} + for var in FindVariables(unique=False).visit(call.routine.body): + if var.name in introduce_index and var.dimensions is not None and var.dimensions: + var_dim = list(var.dimensions) + var_dim.insert(introduce_index[var.name][0], introduce_index[var.name][1]) + vmap[var] = var.clone(dimensions=as_tuple(var_dim)) + if var.name in introduce_offset and var.dimensions is not None and var.dimensions: + var_dim = list(var.dimensions) + var_dim[introduce_offset[var.name][0]] += introduce_offset[var.name][1] - 1 + vmap[var] = var.clone(dimensions=as_tuple(var_dim)) + call.routine.body = SubstituteExpressions(vmap).visit(call.routine.body) + # update the call itself + new_args = tuple(updated_call_args[arg.name] if not isinstance(arg, sym._Literal)\ + and arg.name in updated_call_args else arg for arg in call.arguments) + new_kwargs = ((kwarg[0], updated_call_args[kwarg[1].name]) if not isinstance(kwarg[1], sym._Literal)\ + and kwarg[1].name in updated_call_args else kwarg for kwarg in call.kwarguments) + call._update(arguments=new_args, kwarguments=new_kwargs) + dispatched_routines += (call.routine,) diff --git a/loki/transformations/tests/test_array_indexing.py b/loki/transformations/tests/test_array_indexing.py index 520fe6fe6..1e022c4f7 100644 --- a/loki/transformations/tests/test_array_indexing.py +++ b/loki/transformations/tests/test_array_indexing.py @@ -19,7 +19,7 @@ promote_variables, demote_variables, normalize_range_indexing, invert_array_indices, flatten_arrays, normalize_array_shape_and_access, shift_to_zero_indexing, - resolve_vector_notation + resolve_vector_notation, LowerConstantArrayIndices ) from loki.transformations.transpile import FortranCTransformation @@ -721,6 +721,145 @@ def validate_routine(routine): builder.clean() + +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('recurse_to_kernels', (False, True)) +@pytest.mark.parametrize('inline_external_only', (False, True)) +def test_lower_constant_array_indices(frontend, recurse_to_kernels, inline_external_only): + + fcode_driver = """ +subroutine driver(nlon,nlev,nb,var) + use kernel_mod, only: kernel + implicit none + integer, parameter :: param_1 = 1 + integer, parameter :: param_2 = 2 + integer, parameter :: param_3 = 5 + integer, intent(in) :: nlon,nlev,nb + real, intent(inout) :: var(nlon,nlev,param_3,nb) + integer :: ibl + integer :: offset + integer :: some_val + integer :: loop_start, loop_end + loop_start = 2 + loop_end = nb + some_val = 0 + offset = 1 + !$omp test + do ibl=loop_start, loop_end + call kernel(nlon,nlev,var(:,:,param_1,ibl), var(:,:,param_2:param_3,ibl), offset, loop_start, loop_end) + call kernel(nlon,nlev,var(:,:,param_1,ibl), var(:,:,param_2:param_3,ibl), offset, loop_start, loop_end) + enddo +end subroutine driver +""" + + fcode_kernel = """ +module kernel_mod +implicit none +contains +subroutine kernel(nlon,nlev,var,another_var,icend,lstart,lend) + use compute_mod, only: compute + implicit none + integer, intent(in) :: nlon,nlev,icend,lstart,lend + real, intent(inout) :: var(nlon,nlev) + real, intent(inout) :: another_var(nlon,nlev,4) + integer :: jk, jl, jt + var(:,:) = 0. + do jk = 1,nlev + do jl = 1, nlon + var(jl, jk) = 0. + do jt= 1,4 + another_var(jl, jk, jt) = 0.0 + end do + end do + end do + call compute(nlon,nlev,var) + call compute(nlon,nlev,var) +end subroutine kernel +end module kernel_mod +""" + + fcode_nested_kernel = """ +module compute_mod +implicit none +contains +subroutine compute(nlon,nlev,var) + implicit none + integer, intent(in) :: nlon,nlev + real, intent(inout) :: var(nlon,nlev) + var(:,:) = 0. +end subroutine compute +end module compute_mod +""" + + # recurse_to_kernels = True + nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend) + kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod) + driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod) + + kwargs = {'recurse_to_kernels': recurse_to_kernels, 'inline_external_only': inline_external_only} + LowerConstantArrayIndices(**kwargs).apply(driver, role='driver', targets=('kernel',)) + LowerConstantArrayIndices(**kwargs).apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) + LowerConstantArrayIndices(**kwargs).apply(nested_kernel_mod['compute'], role='kernel') + + # driver + kernel_calls = FindNodes(CallStatement).visit(driver.body) + for kernel_call in kernel_calls: + if inline_external_only and frontend != OMNI: + assert kernel_call.arguments[2].dimensions == (':', ':', 'param_1', 'ibl') + assert kernel_call.arguments[3].dimensions == (':', ':', 'param_2:param_3', 'ibl') + else: + assert kernel_call.arguments[2].dimensions == (':', ':', ':', 'ibl') + assert kernel_call.arguments[3].dimensions == (':', ':', ':', 'ibl') + # kernel + kernel_vars = kernel_mod['kernel'].variable_map + if inline_external_only and frontend != OMNI: + assert kernel_vars['var'].shape == ('nlon', 'nlev') + assert kernel_vars['var'].dimensions == ('nlon', 'nlev') + assert kernel_vars['another_var'].shape == ('nlon', 'nlev', 4) + assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 4) + else: + assert kernel_vars['var'].shape == ('nlon', 'nlev', 5) + assert kernel_vars['var'].dimensions == ('nlon', 'nlev', 5) + assert kernel_vars['another_var'].shape == ('nlon', 'nlev', 5) + assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 5) + if inline_external_only and frontend != OMNI: + for var in FindVariables().visit(kernel_mod['kernel'].body): + if var.name.lower() == 'var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert var.dimensions == ('jl', 'jk') + if var.name.lower() == 'another_var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert tuple(str(dim) for dim in var.dimensions) == ('jl', 'jk', 'jt') + else: + for var in FindVariables().visit(kernel_mod['kernel'].body): + if var.name.lower() == 'var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert var.dimensions == ('jl', 'jk', 1) + if var.name.lower() == 'another_var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert tuple(str(dim) for dim in var.dimensions) == ('jl', 'jk', 'jt + 2 + -1') + compute_calls = FindNodes(CallStatement).visit(kernel_mod['kernel'].body) + for compute_call in compute_calls: + for arg in compute_call.arguments: + if arg.name.lower() == 'var': + if inline_external_only and frontend != OMNI: + assert arg.dimensions == (':', ':') + elif recurse_to_kernels: + assert arg.dimensions == (':', ':', ':') + else: + assert arg.dimensions == (':', ':', '1') + # nested kernel + nested_kernel_var = nested_kernel_mod['compute'].variable_map['var'] + if recurse_to_kernels and (not inline_external_only or frontend == OMNI): + assert nested_kernel_var.shape == ('nlon', 'nlev', 5) + assert nested_kernel_var.dimensions == ('nlon', 'nlev', 5) + for var in FindVariables().visit(nested_kernel_mod['compute'].body): + if var.name.lower() == 'var': + assert var.dimensions == (':', ':', 1) + else: + assert nested_kernel_var.shape == ('nlon', 'nlev') + assert nested_kernel_var.dimensions == ('nlon', 'nlev') + for var in FindVariables().visit(nested_kernel_mod['compute'].body): + if var.name.lower() == 'var': + assert var.dimensions == (':', ':') + + @pytest.mark.parametrize('frontend', available_frontends()) def test_transform_promote_resolve_vector_notation(tmp_path, frontend): """ diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index b7b0bbc01..6b5e8c02d 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -25,7 +25,9 @@ from loki.transformations.argument_shape import ( ArgumentArrayShapeAnalysis, ExplicitArgumentArrayShapeTransformation ) -from loki.transformations.array_indexing import normalize_range_indexing +from loki.transformations.array_indexing import ( + normalize_range_indexing, LowerConstantArrayIndices +) from loki.transformations.build_system import ( DependencyTransformation, ModuleWrapTransformation, FileWriteTransformation ) @@ -319,6 +321,9 @@ def transform_subroutine(self, routine, **kwargs): ) scheduler.process( pipeline ) + if 'cuf' in mode: + scheduler.process( LowerConstantArrayIndices() ) + if mode in ['cuf-parametrise', 'cuf-hoist', 'cuf-dynamic']: # These transformations requires complex constructor arguments, # so we use the file-based transformation configuration. From 84b8f5b05593c12a36ef9bd44f67f09aae1b941d Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Mon, 22 Jul 2024 12:57:48 +0000 Subject: [PATCH 2/3] use xmods=[tmp_path] for test --- loki/transformations/tests/test_array_indexing.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/loki/transformations/tests/test_array_indexing.py b/loki/transformations/tests/test_array_indexing.py index 1e022c4f7..a7ee33250 100644 --- a/loki/transformations/tests/test_array_indexing.py +++ b/loki/transformations/tests/test_array_indexing.py @@ -725,7 +725,7 @@ def validate_routine(routine): @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('recurse_to_kernels', (False, True)) @pytest.mark.parametrize('inline_external_only', (False, True)) -def test_lower_constant_array_indices(frontend, recurse_to_kernels, inline_external_only): +def test_lower_constant_array_indices(tmp_path, frontend, recurse_to_kernels, inline_external_only): fcode_driver = """ subroutine driver(nlon,nlev,nb,var) @@ -791,10 +791,9 @@ def test_lower_constant_array_indices(frontend, recurse_to_kernels, inline_exter end module compute_mod """ - # recurse_to_kernels = True - nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend) - kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod) - driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod) + nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path]) + kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path]) + driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path]) kwargs = {'recurse_to_kernels': recurse_to_kernels, 'inline_external_only': inline_external_only} LowerConstantArrayIndices(**kwargs).apply(driver, role='driver', targets=('kernel',)) From 80a1a81fab74ae581e79aee08d7553f099e9271a Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Wed, 7 Aug 2024 09:33:44 +0000 Subject: [PATCH 3/3] Refactor and improve 'LowerConstantArrayIndices' (which also allows now multiple constant array indices for one array argument) --- loki/transformations/array_indexing.py | 161 ++++++++++----- .../tests/test_array_indexing.py | 183 +++++++++++++++++- 2 files changed, 286 insertions(+), 58 deletions(-) diff --git a/loki/transformations/array_indexing.py b/loki/transformations/array_indexing.py index 46781769f..02daa3adc 100644 --- a/loki/transformations/array_indexing.py +++ b/loki/transformations/array_indexing.py @@ -674,7 +674,7 @@ class LowerConstantArrayIndices(Transformation): # This trafo only operates on procedures item_filter = (ProcedureItem,) - def __init__(self, recurse_to_kernels=True, inline_external_only=False): + def __init__(self, recurse_to_kernels=True, inline_external_only=True): self.recurse_to_kernels = recurse_to_kernels self.inline_external_only = inline_external_only @@ -723,57 +723,120 @@ def transform_subroutine(self, routine, **kwargs): self.process(routine, targets) def process(self, routine, targets): + """ + Process the driver and possibly kernels + """ dispatched_routines = () + offset_map = {} for call in FindNodes(ir.CallStatement).visit(routine.body): if str(call.name).lower() not in targets: continue - # make array dimensions explicit + # skip already dispatched routines but still update the call signature + if call.routine in dispatched_routines: + self.update_call_signature(call) + continue + # explicit array dimensions for the callee self.explicit_dimensions(call.routine) - routine_vmap = {} - updated_call_args = {} - introduce_index = {} - introduce_offset = {} - # collect and iterate over relevant arguments (being arrays) - relevant_arguments = [arg for arg in call.arg_iter() if isinstance(arg[1], sym.Array)] - for routine_arg, call_arg in relevant_arguments: - # check for constant dimensions - insert_dims = tuple((i, dim) for i, dim in enumerate(call_arg.dimensions) if self.is_constant_dim(dim)) - if insert_dims: - for insert_dim in insert_dims: - new_dims = list(call_arg.dimensions) - new_dims[insert_dim[0]] = sym.RangeIndex((None, None)) - updated_call_args[call_arg.name] = call_arg.clone(dimensions=as_tuple(new_dims)) - if call.routine not in dispatched_routines: - new_dims = list(routine_arg.dimensions) - # dimension is a constant RangeIndex, e.g., '1:3' - if isinstance(insert_dim[1], sym.RangeIndex): - introduce_offset[routine_arg.name] = (insert_dim[0], insert_dim[1].children[0]) - new_dims[insert_dim[0]] = call_arg.shape[insert_dim[0]] - # dimension is a constant literal, e.g., '1' - else: - introduce_index[routine_arg.name] = (insert_dim[0], insert_dim[1]) - new_dims.insert(insert_dim[0], call_arg.shape[insert_dim[0]]) - routine_arg_new_type = routine_arg.type.clone(shape=as_tuple(new_dims)) - routine_vmap[routine_arg] = routine_arg.clone(type=routine_arg_new_type, - dimensions=as_tuple(new_dims)) - # apply changes to call.routine (if this routine has not yet been processed) - if call.routine not in dispatched_routines: - call.routine.spec = SubstituteExpressions(routine_vmap).visit(call.routine.spec) - vmap = {} - for var in FindVariables(unique=False).visit(call.routine.body): - if var.name in introduce_index and var.dimensions is not None and var.dimensions: - var_dim = list(var.dimensions) - var_dim.insert(introduce_index[var.name][0], introduce_index[var.name][1]) - vmap[var] = var.clone(dimensions=as_tuple(var_dim)) - if var.name in introduce_offset and var.dimensions is not None and var.dimensions: - var_dim = list(var.dimensions) - var_dim[introduce_offset[var.name][0]] += introduce_offset[var.name][1] - 1 - vmap[var] = var.clone(dimensions=as_tuple(var_dim)) - call.routine.body = SubstituteExpressions(vmap).visit(call.routine.body) - # update the call itself - new_args = tuple(updated_call_args[arg.name] if not isinstance(arg, sym._Literal)\ - and arg.name in updated_call_args else arg for arg in call.arguments) - new_kwargs = ((kwarg[0], updated_call_args[kwarg[1].name]) if not isinstance(kwarg[1], sym._Literal)\ - and kwarg[1].name in updated_call_args else kwarg for kwarg in call.kwarguments) - call._update(arguments=new_args, kwarguments=new_kwargs) dispatched_routines += (call.routine,) + # create the offset map and apply to call and callee + offset_map[call.routine.name.lower()] = self.create_offset_map(call) + self.process_callee(call.routine, offset_map[call.routine.name.lower()]) + self.update_call_signature(call) + + def update_call_signature(self, call): + """ + Replace constant indices for call arguments being arrays with ':' and update the call. + """ + new_args = [arg.clone(dimensions=\ + tuple(sym.RangeIndex((None, None)) if self.is_constant_dim(d) else d for d in arg.dimensions))\ + if isinstance(arg, sym.Array) else arg for arg in call.arguments] + new_kwargs = [(kw[0], kw[1].clone(dimensions=\ + tuple(sym.RangeIndex((None, None)) if self.is_constant_dim(d) else d for d in kw[1].dimensions)))\ + if isinstance(kw[1], sym.Array) else kw for kw in call.kwarguments] + call._update(arguments=as_tuple(new_args), kwarguments=as_tuple(new_kwargs)) + + def create_offset_map(self, call): + """ + Create map/dictionary for arguments with constant array indices. + + For, e.g., + + integer :: arg(len1, len2, len3, len4) + call kernel(..., arg(:, 2, 4:6, i), ...) + + offset_map[arg] = { + 0: (0, None, None), # same index as before, no offset + 1: (None, 1, len2), # New index, offset 1, size of the dimension is len2 + 2: (1, 4, len3), # Used to be position 1, offset 4, size of the dimension is len3 + 3: (-1, None, None), # disregard as this is neither constant nor passed to callee + } + """ + offset_map = {} + for routine_arg, call_arg in call.arg_iter(): + if not isinstance(routine_arg, sym.Array): + continue + offset_map[routine_arg.name] = {} + current_index = 0 + for i, dim in enumerate(call_arg.dimensions): + if self.is_constant_dim(dim): + if isinstance(dim, sym.RangeIndex): + # constant array index is e.g. '1:3' or '5:10' + offset_map[routine_arg.name][i] = (current_index, dim.children[0], call_arg.shape[i]) + else: + # constant array index is e.g., '1' or '42' + offset_map[routine_arg.name][i] = (None, dim, call_arg.shape[i]) + current_index -= 1 + else: + if not isinstance(dim, sym.RangeIndex): + # non constant array index is a variable e.g. 'jl' + offset_map[routine_arg.name][i] = (-1, None, None) + current_index -= 1 + else: + # non constant array index is ':' + offset_map[routine_arg.name][i] = (current_index, None, None) + current_index += 1 + return offset_map + + def process_callee(self, routine, offset_map): + """ + Process/adapt the callee according to information in `offset_map`. + + Adapt the variable declarations and usage/indexing. + """ + # adapt variable declarations, thus adapt the dimension and shape of the corresponding arguments + vmap = {} + variable_map = routine.variable_map + for var_name in offset_map: + var = variable_map[var_name] + new_dims = () + for i in range(max(k for k, v in offset_map[var.name].items() if v != 0) + 1): + original_index = offset_map[var_name][i][0] + offset = offset_map[var_name][i][1] + size = offset_map[var_name][i][2] + if not (original_index is None or 0 <= original_index < len(var.dimensions)): + continue + if offset is not None: + new_dims += (size,) + else: + new_dims += (var.shape[original_index],) + vmap[var] = var.clone(dimensions=new_dims, type=var.type.clone(shape=new_dims)) + routine.spec = SubstituteExpressions(vmap).visit(routine.spec) + # adapt the variable usage, thus the indexing/dimension + vmap = {} + for var in FindVariables(unique=False).visit(routine.body): + if var.name in offset_map and var.dimensions is not None and var.dimensions: + new_dims = () + for i in range(max(k for k, v in offset_map[var.name].items() if v != 0) + 1): + original_index = offset_map[var.name][i][0] + offset = offset_map[var.name][i][1] + if not (original_index is None or 0 <= original_index < len(var.dimensions)): + continue + if offset is not None: + if original_index is None: + new_dims += (offset,) + else: + new_dims += (var.dimensions[original_index] + offset - 1,) + else: + new_dims += (var.dimensions[original_index],) + vmap[var] = var.clone(dimensions=new_dims) + routine.body = SubstituteExpressions(vmap).visit(routine.body) diff --git a/loki/transformations/tests/test_array_indexing.py b/loki/transformations/tests/test_array_indexing.py index a7ee33250..42112e7ad 100644 --- a/loki/transformations/tests/test_array_indexing.py +++ b/loki/transformations/tests/test_array_indexing.py @@ -725,9 +725,12 @@ def validate_routine(routine): @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('recurse_to_kernels', (False, True)) @pytest.mark.parametrize('inline_external_only', (False, True)) -def test_lower_constant_array_indices(tmp_path, frontend, recurse_to_kernels, inline_external_only): - - fcode_driver = """ +@pytest.mark.parametrize('pass_as_kwarg', (False, True,)) +def test_lower_constant_array_indices(tmp_path, frontend, recurse_to_kernels, inline_external_only, pass_as_kwarg): + """ + Test lowering constant array indices + """ + fcode_driver = f""" subroutine driver(nlon,nlev,nb,var) use kernel_mod, only: kernel implicit none @@ -746,8 +749,9 @@ def test_lower_constant_array_indices(tmp_path, frontend, recurse_to_kernels, in offset = 1 !$omp test do ibl=loop_start, loop_end - call kernel(nlon,nlev,var(:,:,param_1,ibl), var(:,:,param_2:param_3,ibl), offset, loop_start, loop_end) - call kernel(nlon,nlev,var(:,:,param_1,ibl), var(:,:,param_2:param_3,ibl), offset, loop_start, loop_end) + call kernel(nlon,nlev,{'var=' if pass_as_kwarg else ''}var(:,:,param_1,ibl), {'another_var=' if pass_as_kwarg else ''}var(:,:,param_2:param_3,ibl), {'icend=' if pass_as_kwarg else ''}offset, {'lstart=' if pass_as_kwarg else ''}loop_start, {'lend=' if pass_as_kwarg else ''}loop_end) + call kernel(nlon,nlev,{'var=' if pass_as_kwarg else ''}var(:,:,param_1,ibl), {'another_var=' if pass_as_kwarg else ''}var(:,:,param_2:param_3,ibl), {'icend=' if pass_as_kwarg else ''}offset, {'lstart=' if pass_as_kwarg else ''}loop_start, {'lend=' if pass_as_kwarg else ''}loop_end) + ! call kernel(nlon,nlev,var(:,:,param_1,ibl), var(:,:,param_2:param_3,ibl), offset, loop_start, loop_end) enddo end subroutine driver """ @@ -803,12 +807,18 @@ def test_lower_constant_array_indices(tmp_path, frontend, recurse_to_kernels, in # driver kernel_calls = FindNodes(CallStatement).visit(driver.body) for kernel_call in kernel_calls: + if pass_as_kwarg: + arg1 = kernel_call.kwarguments[0][1] + arg2 = kernel_call.kwarguments[1][1] + else: + arg1 = kernel_call.arguments[2] + arg2 = kernel_call.arguments[3] if inline_external_only and frontend != OMNI: - assert kernel_call.arguments[2].dimensions == (':', ':', 'param_1', 'ibl') - assert kernel_call.arguments[3].dimensions == (':', ':', 'param_2:param_3', 'ibl') + assert arg1.dimensions == (':', ':', 'param_1', 'ibl') + assert arg2.dimensions == (':', ':', 'param_2:param_3', 'ibl') else: - assert kernel_call.arguments[2].dimensions == (':', ':', ':', 'ibl') - assert kernel_call.arguments[3].dimensions == (':', ':', ':', 'ibl') + assert arg1.dimensions == (':', ':', ':', 'ibl') + assert arg2.dimensions == (':', ':', ':', 'ibl') # kernel kernel_vars = kernel_mod['kernel'].variable_map if inline_external_only and frontend != OMNI: @@ -859,6 +869,161 @@ def test_lower_constant_array_indices(tmp_path, frontend, recurse_to_kernels, in assert var.dimensions == (':', ':') +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('recurse_to_kernels', (False, True,)) +@pytest.mark.parametrize('inline_external_only', (False, True,)) +def test_lower_constant_array_indices_academic(tmp_path, frontend, recurse_to_kernels, inline_external_only): + """ + Test lowering constant array indices for a valid but somewhat academic example ... + + The transformation is capable to handle that, but let's just hope we'll never see + something like that out there in the wild ... + """ + fcode_driver = """ +subroutine driver(nlon,nlev,nb,var) + use kernel_mod, only: kernel + implicit none + integer, parameter :: param_1 = 1 + integer, parameter :: param_2 = 2 + integer, parameter :: param_3 = 5 + integer, intent(in) :: nlon,nlev,nb + real, intent(inout) :: var(nlon,4,3,nlev,param_3,nb) + ! real, intent(inout) :: var(nlon,3,nlev,param_3,nb) + integer :: ibl, j + integer :: offset + integer :: some_val + integer :: loop_start, loop_end + loop_start = 2 + loop_end = nb + some_val = 0 + offset = 1 + !$omp test + do ibl=loop_start, loop_end + do j=1,4 + call kernel(nlon,nlev,var(:,j,1,:,param_1,ibl), var(:,j,2:3,:,param_2:param_3,ibl), offset, loop_start, loop_end) + call kernel(nlon,nlev,var(:,j,1,:,param_1,ibl), var(:,j,2:3,:,param_2:param_3,ibl), offset, loop_start, loop_end) + end do + enddo +end subroutine driver +""" + + fcode_kernel = """ +module kernel_mod +implicit none +contains +subroutine kernel(nlon,nlev,var,another_var,icend,lstart,lend) + use compute_mod, only: compute + implicit none + integer, intent(in) :: nlon,nlev,icend,lstart,lend + real, intent(inout) :: var(nlon,nlev) + real, intent(inout) :: another_var(nlon,2,nlev,4) + integer :: jk, jl, jt + var(:,:) = 0. + do jk = 1,nlev + do jl = 1, nlon + var(jl, jk) = 0. + do jt= 1,4 + another_var(jl, 1, jk, jt) = 0.0 + end do + end do + end do + call compute(nlon,nlev,var) + call compute(nlon,nlev,var) +end subroutine kernel +end module kernel_mod +""" + + fcode_nested_kernel = """ +module compute_mod +implicit none +contains +subroutine compute(nlon,nlev,var) + implicit none + integer, intent(in) :: nlon,nlev + real, intent(inout) :: var(nlon,nlev) + var(:,:) = 0. +end subroutine compute +end module compute_mod +""" + + nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path]) + kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path]) + driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path]) + + kwargs = {'recurse_to_kernels': recurse_to_kernels, 'inline_external_only': inline_external_only} + LowerConstantArrayIndices(**kwargs).apply(driver, role='driver', targets=('kernel',)) + LowerConstantArrayIndices(**kwargs).apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) + LowerConstantArrayIndices(**kwargs).apply(nested_kernel_mod['compute'], role='kernel') + + # driver + kernel_calls = FindNodes(CallStatement).visit(driver.body) + for kernel_call in kernel_calls: + if inline_external_only and frontend != OMNI: + assert kernel_call.arguments[2].dimensions == (':', 'j', ':', ':', 'param_1', 'ibl') + assert kernel_call.arguments[3].dimensions == (':', 'j', ':', ':', 'param_2:param_3', 'ibl') + else: + assert kernel_call.arguments[2].dimensions == (':', 'j', ':', ':', ':', 'ibl') + assert kernel_call.arguments[3].dimensions == (':', 'j', ':', ':', ':', 'ibl') + # kernel + kernel_vars = kernel_mod['kernel'].variable_map + if inline_external_only and frontend != OMNI: + assert kernel_vars['var'].shape == ('nlon', 3, 'nlev') + assert kernel_vars['var'].dimensions == ('nlon', 3, 'nlev') + assert kernel_vars['another_var'].shape == ('nlon', 3, 'nlev', 4) + assert kernel_vars['another_var'].dimensions == ('nlon', 3, 'nlev', 4) + else: + assert kernel_vars['var'].shape == ('nlon', '3', 'nlev', 5) + assert kernel_vars['var'].dimensions == ('nlon', '3', 'nlev', 5) + assert kernel_vars['another_var'].shape == ('nlon', '3', 'nlev', 5) + assert kernel_vars['another_var'].dimensions == ('nlon', '3', 'nlev', 5) + if inline_external_only and frontend != OMNI: + for var in FindVariables().visit(kernel_mod['kernel'].body): + if var.name.lower() == 'var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert var.dimensions == ('jl', 1, 'jk') + if var.name.lower() == 'another_var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert tuple(str(dim) for dim in var.dimensions) == ('jl', '1 + 2 + -1', 'jk', 'jt') + else: + for var in FindVariables().visit(kernel_mod['kernel'].body): + if var.name.lower() == 'var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert var.dimensions == ('jl', 1, 'jk', 1) + if var.name.lower() == 'another_var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert tuple(str(dim) for dim in var.dimensions) == ('jl', '1 + 2 + -1', 'jk', 'jt + 2 + -1') + compute_calls = FindNodes(CallStatement).visit(kernel_mod['kernel'].body) + for compute_call in compute_calls: + for arg in compute_call.arguments: + if arg.name.lower() == 'var': + if inline_external_only and frontend != OMNI: + if recurse_to_kernels: + assert arg.dimensions == (':', ':', ':') + else: + assert arg.dimensions == (':', 1, ':') + elif recurse_to_kernels: + assert arg.dimensions == (':', ':', ':', ':') + else: + assert arg.dimensions == (':', 1, ':', '1') + # nested kernel + nested_kernel_var = nested_kernel_mod['compute'].variable_map['var'] + if recurse_to_kernels and (not inline_external_only or frontend == OMNI): + assert nested_kernel_var.shape == ('nlon', 3, 'nlev', 5) + assert nested_kernel_var.dimensions == ('nlon', 3, 'nlev', 5) + for var in FindVariables().visit(nested_kernel_mod['compute'].body): + if var.name.lower() == 'var': + assert var.dimensions == (':', 1, ':', 1) + else: + if recurse_to_kernels: + assert nested_kernel_var.shape == ('nlon', 3, 'nlev') + assert nested_kernel_var.dimensions == ('nlon', 3, 'nlev') + else: + assert nested_kernel_var.shape == ('nlon', 'nlev') + assert nested_kernel_var.dimensions == ('nlon', 'nlev') + for var in FindVariables().visit(nested_kernel_mod['compute'].body): + if var.name.lower() == 'var': + if recurse_to_kernels: + assert var.dimensions == (':', 1, ':') + else: + assert var.dimensions == (':', ':') + + @pytest.mark.parametrize('frontend', available_frontends()) def test_transform_promote_resolve_vector_notation(tmp_path, frontend): """