diff --git a/loki/transformations/array_indexing.py b/loki/transformations/array_indexing.py index 3feeb4520..02daa3adc 100644 --- a/loki/transformations/array_indexing.py +++ b/loki/transformations/array_indexing.py @@ -13,16 +13,19 @@ from itertools import count import operator as op +from loki.batch import Transformation, ProcedureItem from loki.logging import info from loki.analyse import dataflow_analysis_attached from loki.expression import ( - symbols as sym, simplify, symbolic_op, FindVariables, SubstituteExpressions + symbols as sym, simplify, symbolic_op, FindVariables, SubstituteExpressions, + is_constant ) from loki.ir import ( - Assignment, Loop, VariableDeclaration, FindNodes, Transformer + Assignment, Loop, VariableDeclaration, FindNodes, Transformer, nodes as ir ) from loki.tools import as_tuple, CaseInsensitiveDict from loki.types import SymbolAttributes, BasicType +from loki.transformations.inline import inline_constant_parameters __all__ = [ @@ -30,7 +33,8 @@ 'resolve_vector_notation', 'normalize_range_indexing', 'promote_variables', 'promote_nonmatching_variables', 'promotion_dimensions_from_loop_nest', 'demote_variables', - 'flatten_arrays', 'normalize_array_shape_and_access' + 'flatten_arrays', 'normalize_array_shape_and_access', + 'LowerConstantArrayIndices' ] @@ -603,3 +607,236 @@ def new_dims(dim, shape): routine.variables = [v.clone(dimensions=as_tuple(sym.Product(v.shape)), type=v.type.clone(shape=as_tuple(sym.Product(v.shape)))) if isinstance(v, sym.Array) else v for v in routine.variables] + +class LowerConstantArrayIndices(Transformation): + """ + A transformation to pass/lower constant array indices down the call tree. + + For example, the following code: + + .. code-block:: fortran + + subroutine driver(...) + real, intent(inout) :: var(nlon,nlev,5,nb) + do ibl=1,10 + call kernel(var(:, :, 1, ibl), var(:, :, 2:5, ibl)) + end do + end subroutine driver + + subroutine kernel(var1, var2) + real, intent(inout) :: var1(nlon, nlev) + real, intent(inout) :: var2(nlon, nlev, 4) + var1(:, :) = ... + do jk=1,nlev + do jl=1,nlon + var1(jl, jk) = ... + do jt=1,4 + var2(jl, jk, jt) = ... + enddo + enddo + enddo + end subroutine kernel + + is transformed to: + + .. code-block:: fortran + + subroutine driver(...) + real, intent(inout) :: var(nlon,nlev,5,nb) + do ibl=1,10 + call kernel(var(:, :, :, ibl), var(:, :, :, ibl)) + end do + end subroutine driver + + subroutine kernel(var1, var2) + real, intent(inout) :: var1(nlon, nlev, 5) + real, intent(inout) :: var2(nlon, nlev, 5) + var1(:, :, 1) = ... + do jk=1,nlev + do jl=1,nlon + var1(jl, jk, 1) = ... + do jt=1,4 + var2(jl, jk, jt + 2 + -1) = ... + enddo + enddo + enddo + end subroutine kernel + + Parameters + ---------- + recurse_to_kernels: bool + Recurse to kernels, thus lower constant array indices below the driver level for nested + kernel calls (default: `True`). + inline_external_only: bool + Inline only external constant expressions or all of them (default: `False`) + """ + + # This trafo only operates on procedures + item_filter = (ProcedureItem,) + + def __init__(self, recurse_to_kernels=True, inline_external_only=True): + self.recurse_to_kernels = recurse_to_kernels + self.inline_external_only = inline_external_only + + @staticmethod + def explicit_dimensions(routine): + """ + Make dimensions of arrays explicit within :any:`Subroutine` ``routine``. + E.g., convert two-dimensional array ``arr2d`` to ``arr2d(:,:)`` or + ``arr3d`` to ``arr3d(:,:,:)``. + + Parameters + ---------- + routine: :any:`Subroutine` + The subroutine to check + """ + arrays = [var for var in FindVariables(unique=False).visit(routine.body) if isinstance(var, sym.Array)] + array_map = {} + for array in arrays: + if not array.dimensions: + new_dimensions = (sym.RangeIndex((None, None)),) * len(array.shape) + array_map[array] = array.clone(dimensions=new_dimensions) + routine.body = SubstituteExpressions(array_map).visit(routine.body) + + @staticmethod + def is_constant_dim(dim): + """ + Check whether dimension dim is constant, thus, either a constant + value or a constant range index. + + Parameters + ---------- + dim: :py:class:`pymbolic.primitives.Expression` + """ + if is_constant(dim): + return True + if isinstance(dim, sym.RangeIndex)\ + and all(child is not None and is_constant(child) for child in dim.children[:-1]): + return True + return False + + def transform_subroutine(self, routine, **kwargs): + role = kwargs['role'] + targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets', None))) + if role == 'driver' or self.recurse_to_kernels: + inline_constant_parameters(routine, external_only=self.inline_external_only) + self.process(routine, targets) + + def process(self, routine, targets): + """ + Process the driver and possibly kernels + """ + dispatched_routines = () + offset_map = {} + for call in FindNodes(ir.CallStatement).visit(routine.body): + if str(call.name).lower() not in targets: + continue + # skip already dispatched routines but still update the call signature + if call.routine in dispatched_routines: + self.update_call_signature(call) + continue + # explicit array dimensions for the callee + self.explicit_dimensions(call.routine) + dispatched_routines += (call.routine,) + # create the offset map and apply to call and callee + offset_map[call.routine.name.lower()] = self.create_offset_map(call) + self.process_callee(call.routine, offset_map[call.routine.name.lower()]) + self.update_call_signature(call) + + def update_call_signature(self, call): + """ + Replace constant indices for call arguments being arrays with ':' and update the call. + """ + new_args = [arg.clone(dimensions=\ + tuple(sym.RangeIndex((None, None)) if self.is_constant_dim(d) else d for d in arg.dimensions))\ + if isinstance(arg, sym.Array) else arg for arg in call.arguments] + new_kwargs = [(kw[0], kw[1].clone(dimensions=\ + tuple(sym.RangeIndex((None, None)) if self.is_constant_dim(d) else d for d in kw[1].dimensions)))\ + if isinstance(kw[1], sym.Array) else kw for kw in call.kwarguments] + call._update(arguments=as_tuple(new_args), kwarguments=as_tuple(new_kwargs)) + + def create_offset_map(self, call): + """ + Create map/dictionary for arguments with constant array indices. + + For, e.g., + + integer :: arg(len1, len2, len3, len4) + call kernel(..., arg(:, 2, 4:6, i), ...) + + offset_map[arg] = { + 0: (0, None, None), # same index as before, no offset + 1: (None, 1, len2), # New index, offset 1, size of the dimension is len2 + 2: (1, 4, len3), # Used to be position 1, offset 4, size of the dimension is len3 + 3: (-1, None, None), # disregard as this is neither constant nor passed to callee + } + """ + offset_map = {} + for routine_arg, call_arg in call.arg_iter(): + if not isinstance(routine_arg, sym.Array): + continue + offset_map[routine_arg.name] = {} + current_index = 0 + for i, dim in enumerate(call_arg.dimensions): + if self.is_constant_dim(dim): + if isinstance(dim, sym.RangeIndex): + # constant array index is e.g. '1:3' or '5:10' + offset_map[routine_arg.name][i] = (current_index, dim.children[0], call_arg.shape[i]) + else: + # constant array index is e.g., '1' or '42' + offset_map[routine_arg.name][i] = (None, dim, call_arg.shape[i]) + current_index -= 1 + else: + if not isinstance(dim, sym.RangeIndex): + # non constant array index is a variable e.g. 'jl' + offset_map[routine_arg.name][i] = (-1, None, None) + current_index -= 1 + else: + # non constant array index is ':' + offset_map[routine_arg.name][i] = (current_index, None, None) + current_index += 1 + return offset_map + + def process_callee(self, routine, offset_map): + """ + Process/adapt the callee according to information in `offset_map`. + + Adapt the variable declarations and usage/indexing. + """ + # adapt variable declarations, thus adapt the dimension and shape of the corresponding arguments + vmap = {} + variable_map = routine.variable_map + for var_name in offset_map: + var = variable_map[var_name] + new_dims = () + for i in range(max(k for k, v in offset_map[var.name].items() if v != 0) + 1): + original_index = offset_map[var_name][i][0] + offset = offset_map[var_name][i][1] + size = offset_map[var_name][i][2] + if not (original_index is None or 0 <= original_index < len(var.dimensions)): + continue + if offset is not None: + new_dims += (size,) + else: + new_dims += (var.shape[original_index],) + vmap[var] = var.clone(dimensions=new_dims, type=var.type.clone(shape=new_dims)) + routine.spec = SubstituteExpressions(vmap).visit(routine.spec) + # adapt the variable usage, thus the indexing/dimension + vmap = {} + for var in FindVariables(unique=False).visit(routine.body): + if var.name in offset_map and var.dimensions is not None and var.dimensions: + new_dims = () + for i in range(max(k for k, v in offset_map[var.name].items() if v != 0) + 1): + original_index = offset_map[var.name][i][0] + offset = offset_map[var.name][i][1] + if not (original_index is None or 0 <= original_index < len(var.dimensions)): + continue + if offset is not None: + if original_index is None: + new_dims += (offset,) + else: + new_dims += (var.dimensions[original_index] + offset - 1,) + else: + new_dims += (var.dimensions[original_index],) + vmap[var] = var.clone(dimensions=new_dims) + routine.body = SubstituteExpressions(vmap).visit(routine.body) diff --git a/loki/transformations/tests/test_array_indexing.py b/loki/transformations/tests/test_array_indexing.py index 520fe6fe6..42112e7ad 100644 --- a/loki/transformations/tests/test_array_indexing.py +++ b/loki/transformations/tests/test_array_indexing.py @@ -19,7 +19,7 @@ promote_variables, demote_variables, normalize_range_indexing, invert_array_indices, flatten_arrays, normalize_array_shape_and_access, shift_to_zero_indexing, - resolve_vector_notation + resolve_vector_notation, LowerConstantArrayIndices ) from loki.transformations.transpile import FortranCTransformation @@ -721,6 +721,309 @@ def validate_routine(routine): builder.clean() + +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('recurse_to_kernels', (False, True)) +@pytest.mark.parametrize('inline_external_only', (False, True)) +@pytest.mark.parametrize('pass_as_kwarg', (False, True,)) +def test_lower_constant_array_indices(tmp_path, frontend, recurse_to_kernels, inline_external_only, pass_as_kwarg): + """ + Test lowering constant array indices + """ + fcode_driver = f""" +subroutine driver(nlon,nlev,nb,var) + use kernel_mod, only: kernel + implicit none + integer, parameter :: param_1 = 1 + integer, parameter :: param_2 = 2 + integer, parameter :: param_3 = 5 + integer, intent(in) :: nlon,nlev,nb + real, intent(inout) :: var(nlon,nlev,param_3,nb) + integer :: ibl + integer :: offset + integer :: some_val + integer :: loop_start, loop_end + loop_start = 2 + loop_end = nb + some_val = 0 + offset = 1 + !$omp test + do ibl=loop_start, loop_end + call kernel(nlon,nlev,{'var=' if pass_as_kwarg else ''}var(:,:,param_1,ibl), {'another_var=' if pass_as_kwarg else ''}var(:,:,param_2:param_3,ibl), {'icend=' if pass_as_kwarg else ''}offset, {'lstart=' if pass_as_kwarg else ''}loop_start, {'lend=' if pass_as_kwarg else ''}loop_end) + call kernel(nlon,nlev,{'var=' if pass_as_kwarg else ''}var(:,:,param_1,ibl), {'another_var=' if pass_as_kwarg else ''}var(:,:,param_2:param_3,ibl), {'icend=' if pass_as_kwarg else ''}offset, {'lstart=' if pass_as_kwarg else ''}loop_start, {'lend=' if pass_as_kwarg else ''}loop_end) + ! call kernel(nlon,nlev,var(:,:,param_1,ibl), var(:,:,param_2:param_3,ibl), offset, loop_start, loop_end) + enddo +end subroutine driver +""" + + fcode_kernel = """ +module kernel_mod +implicit none +contains +subroutine kernel(nlon,nlev,var,another_var,icend,lstart,lend) + use compute_mod, only: compute + implicit none + integer, intent(in) :: nlon,nlev,icend,lstart,lend + real, intent(inout) :: var(nlon,nlev) + real, intent(inout) :: another_var(nlon,nlev,4) + integer :: jk, jl, jt + var(:,:) = 0. + do jk = 1,nlev + do jl = 1, nlon + var(jl, jk) = 0. + do jt= 1,4 + another_var(jl, jk, jt) = 0.0 + end do + end do + end do + call compute(nlon,nlev,var) + call compute(nlon,nlev,var) +end subroutine kernel +end module kernel_mod +""" + + fcode_nested_kernel = """ +module compute_mod +implicit none +contains +subroutine compute(nlon,nlev,var) + implicit none + integer, intent(in) :: nlon,nlev + real, intent(inout) :: var(nlon,nlev) + var(:,:) = 0. +end subroutine compute +end module compute_mod +""" + + nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path]) + kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path]) + driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path]) + + kwargs = {'recurse_to_kernels': recurse_to_kernels, 'inline_external_only': inline_external_only} + LowerConstantArrayIndices(**kwargs).apply(driver, role='driver', targets=('kernel',)) + LowerConstantArrayIndices(**kwargs).apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) + LowerConstantArrayIndices(**kwargs).apply(nested_kernel_mod['compute'], role='kernel') + + # driver + kernel_calls = FindNodes(CallStatement).visit(driver.body) + for kernel_call in kernel_calls: + if pass_as_kwarg: + arg1 = kernel_call.kwarguments[0][1] + arg2 = kernel_call.kwarguments[1][1] + else: + arg1 = kernel_call.arguments[2] + arg2 = kernel_call.arguments[3] + if inline_external_only and frontend != OMNI: + assert arg1.dimensions == (':', ':', 'param_1', 'ibl') + assert arg2.dimensions == (':', ':', 'param_2:param_3', 'ibl') + else: + assert arg1.dimensions == (':', ':', ':', 'ibl') + assert arg2.dimensions == (':', ':', ':', 'ibl') + # kernel + kernel_vars = kernel_mod['kernel'].variable_map + if inline_external_only and frontend != OMNI: + assert kernel_vars['var'].shape == ('nlon', 'nlev') + assert kernel_vars['var'].dimensions == ('nlon', 'nlev') + assert kernel_vars['another_var'].shape == ('nlon', 'nlev', 4) + assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 4) + else: + assert kernel_vars['var'].shape == ('nlon', 'nlev', 5) + assert kernel_vars['var'].dimensions == ('nlon', 'nlev', 5) + assert kernel_vars['another_var'].shape == ('nlon', 'nlev', 5) + assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 5) + if inline_external_only and frontend != OMNI: + for var in FindVariables().visit(kernel_mod['kernel'].body): + if var.name.lower() == 'var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert var.dimensions == ('jl', 'jk') + if var.name.lower() == 'another_var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert tuple(str(dim) for dim in var.dimensions) == ('jl', 'jk', 'jt') + else: + for var in FindVariables().visit(kernel_mod['kernel'].body): + if var.name.lower() == 'var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert var.dimensions == ('jl', 'jk', 1) + if var.name.lower() == 'another_var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert tuple(str(dim) for dim in var.dimensions) == ('jl', 'jk', 'jt + 2 + -1') + compute_calls = FindNodes(CallStatement).visit(kernel_mod['kernel'].body) + for compute_call in compute_calls: + for arg in compute_call.arguments: + if arg.name.lower() == 'var': + if inline_external_only and frontend != OMNI: + assert arg.dimensions == (':', ':') + elif recurse_to_kernels: + assert arg.dimensions == (':', ':', ':') + else: + assert arg.dimensions == (':', ':', '1') + # nested kernel + nested_kernel_var = nested_kernel_mod['compute'].variable_map['var'] + if recurse_to_kernels and (not inline_external_only or frontend == OMNI): + assert nested_kernel_var.shape == ('nlon', 'nlev', 5) + assert nested_kernel_var.dimensions == ('nlon', 'nlev', 5) + for var in FindVariables().visit(nested_kernel_mod['compute'].body): + if var.name.lower() == 'var': + assert var.dimensions == (':', ':', 1) + else: + assert nested_kernel_var.shape == ('nlon', 'nlev') + assert nested_kernel_var.dimensions == ('nlon', 'nlev') + for var in FindVariables().visit(nested_kernel_mod['compute'].body): + if var.name.lower() == 'var': + assert var.dimensions == (':', ':') + + +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('recurse_to_kernels', (False, True,)) +@pytest.mark.parametrize('inline_external_only', (False, True,)) +def test_lower_constant_array_indices_academic(tmp_path, frontend, recurse_to_kernels, inline_external_only): + """ + Test lowering constant array indices for a valid but somewhat academic example ... + + The transformation is capable to handle that, but let's just hope we'll never see + something like that out there in the wild ... + """ + fcode_driver = """ +subroutine driver(nlon,nlev,nb,var) + use kernel_mod, only: kernel + implicit none + integer, parameter :: param_1 = 1 + integer, parameter :: param_2 = 2 + integer, parameter :: param_3 = 5 + integer, intent(in) :: nlon,nlev,nb + real, intent(inout) :: var(nlon,4,3,nlev,param_3,nb) + ! real, intent(inout) :: var(nlon,3,nlev,param_3,nb) + integer :: ibl, j + integer :: offset + integer :: some_val + integer :: loop_start, loop_end + loop_start = 2 + loop_end = nb + some_val = 0 + offset = 1 + !$omp test + do ibl=loop_start, loop_end + do j=1,4 + call kernel(nlon,nlev,var(:,j,1,:,param_1,ibl), var(:,j,2:3,:,param_2:param_3,ibl), offset, loop_start, loop_end) + call kernel(nlon,nlev,var(:,j,1,:,param_1,ibl), var(:,j,2:3,:,param_2:param_3,ibl), offset, loop_start, loop_end) + end do + enddo +end subroutine driver +""" + + fcode_kernel = """ +module kernel_mod +implicit none +contains +subroutine kernel(nlon,nlev,var,another_var,icend,lstart,lend) + use compute_mod, only: compute + implicit none + integer, intent(in) :: nlon,nlev,icend,lstart,lend + real, intent(inout) :: var(nlon,nlev) + real, intent(inout) :: another_var(nlon,2,nlev,4) + integer :: jk, jl, jt + var(:,:) = 0. + do jk = 1,nlev + do jl = 1, nlon + var(jl, jk) = 0. + do jt= 1,4 + another_var(jl, 1, jk, jt) = 0.0 + end do + end do + end do + call compute(nlon,nlev,var) + call compute(nlon,nlev,var) +end subroutine kernel +end module kernel_mod +""" + + fcode_nested_kernel = """ +module compute_mod +implicit none +contains +subroutine compute(nlon,nlev,var) + implicit none + integer, intent(in) :: nlon,nlev + real, intent(inout) :: var(nlon,nlev) + var(:,:) = 0. +end subroutine compute +end module compute_mod +""" + + nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path]) + kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path]) + driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path]) + + kwargs = {'recurse_to_kernels': recurse_to_kernels, 'inline_external_only': inline_external_only} + LowerConstantArrayIndices(**kwargs).apply(driver, role='driver', targets=('kernel',)) + LowerConstantArrayIndices(**kwargs).apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) + LowerConstantArrayIndices(**kwargs).apply(nested_kernel_mod['compute'], role='kernel') + + # driver + kernel_calls = FindNodes(CallStatement).visit(driver.body) + for kernel_call in kernel_calls: + if inline_external_only and frontend != OMNI: + assert kernel_call.arguments[2].dimensions == (':', 'j', ':', ':', 'param_1', 'ibl') + assert kernel_call.arguments[3].dimensions == (':', 'j', ':', ':', 'param_2:param_3', 'ibl') + else: + assert kernel_call.arguments[2].dimensions == (':', 'j', ':', ':', ':', 'ibl') + assert kernel_call.arguments[3].dimensions == (':', 'j', ':', ':', ':', 'ibl') + # kernel + kernel_vars = kernel_mod['kernel'].variable_map + if inline_external_only and frontend != OMNI: + assert kernel_vars['var'].shape == ('nlon', 3, 'nlev') + assert kernel_vars['var'].dimensions == ('nlon', 3, 'nlev') + assert kernel_vars['another_var'].shape == ('nlon', 3, 'nlev', 4) + assert kernel_vars['another_var'].dimensions == ('nlon', 3, 'nlev', 4) + else: + assert kernel_vars['var'].shape == ('nlon', '3', 'nlev', 5) + assert kernel_vars['var'].dimensions == ('nlon', '3', 'nlev', 5) + assert kernel_vars['another_var'].shape == ('nlon', '3', 'nlev', 5) + assert kernel_vars['another_var'].dimensions == ('nlon', '3', 'nlev', 5) + if inline_external_only and frontend != OMNI: + for var in FindVariables().visit(kernel_mod['kernel'].body): + if var.name.lower() == 'var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert var.dimensions == ('jl', 1, 'jk') + if var.name.lower() == 'another_var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert tuple(str(dim) for dim in var.dimensions) == ('jl', '1 + 2 + -1', 'jk', 'jt') + else: + for var in FindVariables().visit(kernel_mod['kernel'].body): + if var.name.lower() == 'var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert var.dimensions == ('jl', 1, 'jk', 1) + if var.name.lower() == 'another_var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions): + assert tuple(str(dim) for dim in var.dimensions) == ('jl', '1 + 2 + -1', 'jk', 'jt + 2 + -1') + compute_calls = FindNodes(CallStatement).visit(kernel_mod['kernel'].body) + for compute_call in compute_calls: + for arg in compute_call.arguments: + if arg.name.lower() == 'var': + if inline_external_only and frontend != OMNI: + if recurse_to_kernels: + assert arg.dimensions == (':', ':', ':') + else: + assert arg.dimensions == (':', 1, ':') + elif recurse_to_kernels: + assert arg.dimensions == (':', ':', ':', ':') + else: + assert arg.dimensions == (':', 1, ':', '1') + # nested kernel + nested_kernel_var = nested_kernel_mod['compute'].variable_map['var'] + if recurse_to_kernels and (not inline_external_only or frontend == OMNI): + assert nested_kernel_var.shape == ('nlon', 3, 'nlev', 5) + assert nested_kernel_var.dimensions == ('nlon', 3, 'nlev', 5) + for var in FindVariables().visit(nested_kernel_mod['compute'].body): + if var.name.lower() == 'var': + assert var.dimensions == (':', 1, ':', 1) + else: + if recurse_to_kernels: + assert nested_kernel_var.shape == ('nlon', 3, 'nlev') + assert nested_kernel_var.dimensions == ('nlon', 3, 'nlev') + else: + assert nested_kernel_var.shape == ('nlon', 'nlev') + assert nested_kernel_var.dimensions == ('nlon', 'nlev') + for var in FindVariables().visit(nested_kernel_mod['compute'].body): + if var.name.lower() == 'var': + if recurse_to_kernels: + assert var.dimensions == (':', 1, ':') + else: + assert var.dimensions == (':', ':') + + @pytest.mark.parametrize('frontend', available_frontends()) def test_transform_promote_resolve_vector_notation(tmp_path, frontend): """ diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index b7b0bbc01..6b5e8c02d 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -25,7 +25,9 @@ from loki.transformations.argument_shape import ( ArgumentArrayShapeAnalysis, ExplicitArgumentArrayShapeTransformation ) -from loki.transformations.array_indexing import normalize_range_indexing +from loki.transformations.array_indexing import ( + normalize_range_indexing, LowerConstantArrayIndices +) from loki.transformations.build_system import ( DependencyTransformation, ModuleWrapTransformation, FileWriteTransformation ) @@ -319,6 +321,9 @@ def transform_subroutine(self, routine, **kwargs): ) scheduler.process( pipeline ) + if 'cuf' in mode: + scheduler.process( LowerConstantArrayIndices() ) + if mode in ['cuf-parametrise', 'cuf-hoist', 'cuf-dynamic']: # These transformations requires complex constructor arguments, # so we use the file-based transformation configuration.