Skip to content

Commit

Permalink
Introduce new transformation 'LowerConstantArrayIndices' to allow to …
Browse files Browse the repository at this point in the history
…pass/lower constant array indices to kernel(s) calls
  • Loading branch information
MichaelSt98 committed Jul 22, 2024
1 parent a39336b commit 3751695
Show file tree
Hide file tree
Showing 3 changed files with 322 additions and 5 deletions.
180 changes: 177 additions & 3 deletions loki/transformations/array_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,28 @@
from itertools import count
import operator as op

from loki.batch import Transformation, ProcedureItem
from loki.logging import info
from loki.analyse import dataflow_analysis_attached
from loki.expression import (
symbols as sym, simplify, symbolic_op, FindVariables, SubstituteExpressions
symbols as sym, simplify, symbolic_op, FindVariables, SubstituteExpressions,
is_constant
)
from loki.ir import (
Assignment, Loop, VariableDeclaration, FindNodes, Transformer
Assignment, Loop, VariableDeclaration, FindNodes, Transformer, nodes as ir
)
from loki.tools import as_tuple, CaseInsensitiveDict
from loki.types import SymbolAttributes, BasicType
from loki.transformations.inline import inline_constant_parameters


__all__ = [
'shift_to_zero_indexing', 'invert_array_indices',
'resolve_vector_notation', 'normalize_range_indexing',
'promote_variables', 'promote_nonmatching_variables',
'promotion_dimensions_from_loop_nest', 'demote_variables',
'flatten_arrays', 'normalize_array_shape_and_access'
'flatten_arrays', 'normalize_array_shape_and_access',
'LowerConstantArrayIndices'
]


Expand Down Expand Up @@ -603,3 +607,173 @@ def new_dims(dim, shape):
routine.variables = [v.clone(dimensions=as_tuple(sym.Product(v.shape)),
type=v.type.clone(shape=as_tuple(sym.Product(v.shape))))
if isinstance(v, sym.Array) else v for v in routine.variables]

class LowerConstantArrayIndices(Transformation):
"""
A transformation to pass/lower constant array indices down the call tree.
For example, the following code:
.. code-block:: fortran
subroutine driver(...)
real, intent(inout) :: var(nlon,nlev,5,nb)
do ibl=1,10
call kernel(var(:, :, 1, ibl), var(:, :, 2:5, ibl))
end do
end subroutine driver
subroutine kernel(var1, var2)
real, intent(inout) :: var1(nlon, nlev)
real, intent(inout) :: var2(nlon, nlev, 4)
var1(:, :) = ...
do jk=1,nlev
do jl=1,nlon
var1(jl, jk) = ...
do jt=1,4
var2(jl, jk, jt) = ...
enddo
enddo
enddo
end subroutine kernel
is transformed to:
.. code-block:: fortran
subroutine driver(...)
real, intent(inout) :: var(nlon,nlev,5,nb)
do ibl=1,10
call kernel(var(:, :, :, ibl), var(:, :, :, ibl))
end do
end subroutine driver
subroutine kernel(var1, var2)
real, intent(inout) :: var1(nlon, nlev, 5)
real, intent(inout) :: var2(nlon, nlev, 5)
var1(:, :, 1) = ...
do jk=1,nlev
do jl=1,nlon
var1(jl, jk, 1) = ...
do jt=1,4
var2(jl, jk, jt + 2 + -1) = ...
enddo
enddo
enddo
end subroutine kernel
Parameters
----------
recurse_to_kernels: bool
Recurse to kernels, thus lower constant array indices below the driver level for nested
kernel calls (default: `True`).
inline_external_only: bool
Inline only external constant expressions or all of them (default: `False`)
"""

# This trafo only operates on procedures
item_filter = (ProcedureItem,)

def __init__(self, recurse_to_kernels=True, inline_external_only=False):
self.recurse_to_kernels = recurse_to_kernels
self.inline_external_only = inline_external_only

@staticmethod
def explicit_dimensions(routine):
"""
Make dimensions of arrays explicit within :any:`Subroutine` ``routine``.
E.g., convert two-dimensional array ``arr2d`` to ``arr2d(:,:)`` or
``arr3d`` to ``arr3d(:,:,:)``.
Parameters
----------
routine: :any:`Subroutine`
The subroutine to check
"""
arrays = [var for var in FindVariables(unique=False).visit(routine.body) if isinstance(var, sym.Array)]
array_map = {}
for array in arrays:
if not array.dimensions:
new_dimensions = (sym.RangeIndex((None, None)),) * len(array.shape)
array_map[array] = array.clone(dimensions=new_dimensions)
routine.body = SubstituteExpressions(array_map).visit(routine.body)

@staticmethod
def is_constant_dim(dim):
"""
Check whether dimension dim is constant, thus, either a constant
value or a constant range index.
Parameters
----------
dim: :py:class:`pymbolic.primitives.Expression`
"""
if is_constant(dim):
return True
if isinstance(dim, sym.RangeIndex)\
and all(child is not None and is_constant(child) for child in dim.children[:-1]):
return True
return False

def transform_subroutine(self, routine, **kwargs):
role = kwargs['role']
targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets', None)))
if role == 'driver' or self.recurse_to_kernels:
inline_constant_parameters(routine, external_only=self.inline_external_only)
self.process(routine, targets)

def process(self, routine, targets):
dispatched_routines = ()
for call in FindNodes(ir.CallStatement).visit(routine.body):
if str(call.name).lower() not in targets:
continue
# make array dimensions explicit
self.explicit_dimensions(call.routine)
routine_vmap = {}
updated_call_args = {}
introduce_index = {}
introduce_offset = {}
# collect and iterate over relevant arguments (being arrays)
relevant_arguments = [arg for arg in call.arg_iter() if isinstance(arg[1], sym.Array)]
for routine_arg, call_arg in relevant_arguments:
# check for constant dimensions
insert_dims = tuple((i, dim) for i, dim in enumerate(call_arg.dimensions) if self.is_constant_dim(dim))
if insert_dims:
for insert_dim in insert_dims:
new_dims = list(call_arg.dimensions)
new_dims[insert_dim[0]] = sym.RangeIndex((None, None))
updated_call_args[call_arg.name] = call_arg.clone(dimensions=as_tuple(new_dims))
if call.routine not in dispatched_routines:
new_dims = list(routine_arg.dimensions)
# dimension is a constant RangeIndex, e.g., '1:3'
if isinstance(insert_dim[1], sym.RangeIndex):
introduce_offset[routine_arg.name] = (insert_dim[0], insert_dim[1].children[0])
new_dims[insert_dim[0]] = call_arg.shape[insert_dim[0]]
# dimension is a constant literal, e.g., '1'
else:
introduce_index[routine_arg.name] = (insert_dim[0], insert_dim[1])
new_dims.insert(insert_dim[0], call_arg.shape[insert_dim[0]])
routine_arg_new_type = routine_arg.type.clone(shape=as_tuple(new_dims))
routine_vmap[routine_arg] = routine_arg.clone(type=routine_arg_new_type,
dimensions=as_tuple(new_dims))
# apply changes to call.routine (if this routine has not yet been processed)
if call.routine not in dispatched_routines:
call.routine.spec = SubstituteExpressions(routine_vmap).visit(call.routine.spec)
vmap = {}
for var in FindVariables(unique=False).visit(call.routine.body):
if var.name in introduce_index and var.dimensions is not None and var.dimensions:
var_dim = list(var.dimensions)
var_dim.insert(introduce_index[var.name][0], introduce_index[var.name][1])
vmap[var] = var.clone(dimensions=as_tuple(var_dim))
if var.name in introduce_offset and var.dimensions is not None and var.dimensions:
var_dim = list(var.dimensions)
var_dim[introduce_offset[var.name][0]] += introduce_offset[var.name][1] - 1
vmap[var] = var.clone(dimensions=as_tuple(var_dim))
call.routine.body = SubstituteExpressions(vmap).visit(call.routine.body)
# update the call itself
new_args = tuple(updated_call_args[arg.name] if not isinstance(arg, sym._Literal)\
and arg.name in updated_call_args else arg for arg in call.arguments)
new_kwargs = ((kwarg[0], updated_call_args[kwarg[1].name]) if not isinstance(kwarg[1], sym._Literal)\
and kwarg[1].name in updated_call_args else kwarg for kwarg in call.kwarguments)
call._update(arguments=new_args, kwarguments=new_kwargs)
dispatched_routines += (call.routine,)
140 changes: 139 additions & 1 deletion loki/transformations/tests/test_array_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from loki import Module, Subroutine, fgen
from loki.build import jit_compile, jit_compile_lib, clean_test, Builder, Obj
from loki.expression import symbols as sym, FindVariables
from loki.frontend import available_frontends
from loki.frontend import available_frontends, OMNI
from loki.ir import FindNodes, CallStatement

from loki.transformations.array_indexing import (
promote_variables, demote_variables, normalize_range_indexing,
invert_array_indices, flatten_arrays,
normalize_array_shape_and_access, shift_to_zero_indexing,
LowerConstantArrayIndices
)
from loki.transformations.transpile import FortranCTransformation

Expand Down Expand Up @@ -733,3 +734,140 @@ def validate_routine(routine):
assert (b_flattened == b_ref.flatten(order='F')).all()

builder.clean()

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('recurse_to_kernels', (False, True))
@pytest.mark.parametrize('inline_external_only', (False, True))
def test_lower_constant_array_indices(frontend, recurse_to_kernels, inline_external_only):

fcode_driver = """
subroutine driver(nlon,nlev,nb,var)
use kernel_mod, only: kernel
implicit none
integer, parameter :: param_1 = 1
integer, parameter :: param_2 = 2
integer, parameter :: param_3 = 5
integer, intent(in) :: nlon,nlev,nb
real, intent(inout) :: var(nlon,nlev,param_3,nb)
integer :: ibl
integer :: offset
integer :: some_val
integer :: loop_start, loop_end
loop_start = 2
loop_end = nb
some_val = 0
offset = 1
!$omp test
do ibl=loop_start, loop_end
call kernel(nlon,nlev,var(:,:,param_1,ibl), var(:,:,param_2:param_3,ibl), offset, loop_start, loop_end)
call kernel(nlon,nlev,var(:,:,param_1,ibl), var(:,:,param_2:param_3,ibl), offset, loop_start, loop_end)
enddo
end subroutine driver
"""

fcode_kernel = """
module kernel_mod
implicit none
contains
subroutine kernel(nlon,nlev,var,another_var,icend,lstart,lend)
use compute_mod, only: compute
implicit none
integer, intent(in) :: nlon,nlev,icend,lstart,lend
real, intent(inout) :: var(nlon,nlev)
real, intent(inout) :: another_var(nlon,nlev,4)
integer :: jk, jl, jt
var(:,:) = 0.
do jk = 1,nlev
do jl = 1, nlon
var(jl, jk) = 0.
do jt= 1,4
another_var(jl, jk, jt) = 0.0
end do
end do
end do
call compute(nlon,nlev,var)
call compute(nlon,nlev,var)
end subroutine kernel
end module kernel_mod
"""

fcode_nested_kernel = """
module compute_mod
implicit none
contains
subroutine compute(nlon,nlev,var)
implicit none
integer, intent(in) :: nlon,nlev
real, intent(inout) :: var(nlon,nlev)
var(:,:) = 0.
end subroutine compute
end module compute_mod
"""

# recurse_to_kernels = True
nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend)
kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod)
driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod)

kwargs = {'recurse_to_kernels': recurse_to_kernels, 'inline_external_only': inline_external_only}
LowerConstantArrayIndices(**kwargs).apply(driver, role='driver', targets=('kernel',))
LowerConstantArrayIndices(**kwargs).apply(kernel_mod['kernel'], role='kernel', targets=('compute',))
LowerConstantArrayIndices(**kwargs).apply(nested_kernel_mod['compute'], role='kernel')

# driver
kernel_calls = FindNodes(CallStatement).visit(driver.body)
for kernel_call in kernel_calls:
if inline_external_only and frontend != OMNI:
assert kernel_call.arguments[2].dimensions == (':', ':', 'param_1', 'ibl')
assert kernel_call.arguments[3].dimensions == (':', ':', 'param_2:param_3', 'ibl')
else:
assert kernel_call.arguments[2].dimensions == (':', ':', ':', 'ibl')
assert kernel_call.arguments[3].dimensions == (':', ':', ':', 'ibl')
# kernel
kernel_vars = kernel_mod['kernel'].variable_map
if inline_external_only and frontend != OMNI:
assert kernel_vars['var'].shape == ('nlon', 'nlev')
assert kernel_vars['var'].dimensions == ('nlon', 'nlev')
assert kernel_vars['another_var'].shape == ('nlon', 'nlev', 4)
assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 4)
else:
assert kernel_vars['var'].shape == ('nlon', 'nlev', 5)
assert kernel_vars['var'].dimensions == ('nlon', 'nlev', 5)
assert kernel_vars['another_var'].shape == ('nlon', 'nlev', 5)
assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 5)
if inline_external_only and frontend != OMNI:
for var in FindVariables().visit(kernel_mod['kernel'].body):
if var.name.lower() == 'var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions):
assert var.dimensions == ('jl', 'jk')
if var.name.lower() == 'another_var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions):
assert tuple(str(dim) for dim in var.dimensions) == ('jl', 'jk', 'jt')
else:
for var in FindVariables().visit(kernel_mod['kernel'].body):
if var.name.lower() == 'var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions):
assert var.dimensions == ('jl', 'jk', 1)
if var.name.lower() == 'another_var' and not any(isinstance(dim, sym.RangeIndex) for dim in var.dimensions):
assert tuple(str(dim) for dim in var.dimensions) == ('jl', 'jk', 'jt + 2 + -1')
compute_calls = FindNodes(CallStatement).visit(kernel_mod['kernel'].body)
for compute_call in compute_calls:
for arg in compute_call.arguments:
if arg.name.lower() == 'var':
if inline_external_only and frontend != OMNI:
assert arg.dimensions == (':', ':')
elif recurse_to_kernels:
assert arg.dimensions == (':', ':', ':')
else:
assert arg.dimensions == (':', ':', '1')
# nested kernel
nested_kernel_var = nested_kernel_mod['compute'].variable_map['var']
if recurse_to_kernels and (not inline_external_only or frontend == OMNI):
assert nested_kernel_var.shape == ('nlon', 'nlev', 5)
assert nested_kernel_var.dimensions == ('nlon', 'nlev', 5)
for var in FindVariables().visit(nested_kernel_mod['compute'].body):
if var.name.lower() == 'var':
assert var.dimensions == (':', ':', 1)
else:
assert nested_kernel_var.shape == ('nlon', 'nlev')
assert nested_kernel_var.dimensions == ('nlon', 'nlev')
for var in FindVariables().visit(nested_kernel_mod['compute'].body):
if var.name.lower() == 'var':
assert var.dimensions == (':', ':')
7 changes: 6 additions & 1 deletion scripts/loki_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from loki.transformations.argument_shape import (
ArgumentArrayShapeAnalysis, ExplicitArgumentArrayShapeTransformation
)
from loki.transformations.array_indexing import normalize_range_indexing
from loki.transformations.array_indexing import (
normalize_range_indexing, LowerConstantArrayIndices
)
from loki.transformations.build_system import (
DependencyTransformation, ModuleWrapTransformation, FileWriteTransformation
)
Expand Down Expand Up @@ -324,6 +326,9 @@ def transform_subroutine(self, routine, **kwargs):
)
scheduler.process( pipeline )

if 'cuf' in mode:
scheduler.process( LowerConstantArrayIndices() )

if mode in ['cuf-parametrise', 'cuf-hoist', 'cuf-dynamic']:
# These transformations requires complex constructor arguments,
# so we use the file-based transformation configuration.
Expand Down

0 comments on commit 3751695

Please sign in to comment.