Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New transformation 'LowerConstantArrayIndices' to allow to … #348

Merged
merged 3 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 240 additions & 3 deletions loki/transformations/array_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,28 @@
from itertools import count
import operator as op

from loki.batch import Transformation, ProcedureItem
from loki.logging import info
from loki.analyse import dataflow_analysis_attached
from loki.expression import (
symbols as sym, simplify, symbolic_op, FindVariables, SubstituteExpressions
symbols as sym, simplify, symbolic_op, FindVariables, SubstituteExpressions,
is_constant
)
from loki.ir import (
Assignment, Loop, VariableDeclaration, FindNodes, Transformer
Assignment, Loop, VariableDeclaration, FindNodes, Transformer, nodes as ir
)
from loki.tools import as_tuple, CaseInsensitiveDict
from loki.types import SymbolAttributes, BasicType
from loki.transformations.inline import inline_constant_parameters


__all__ = [
'shift_to_zero_indexing', 'invert_array_indices',
'resolve_vector_notation', 'normalize_range_indexing',
'promote_variables', 'promote_nonmatching_variables',
'promotion_dimensions_from_loop_nest', 'demote_variables',
'flatten_arrays', 'normalize_array_shape_and_access'
'flatten_arrays', 'normalize_array_shape_and_access',
'LowerConstantArrayIndices'
]


Expand Down Expand Up @@ -603,3 +607,236 @@ def new_dims(dim, shape):
routine.variables = [v.clone(dimensions=as_tuple(sym.Product(v.shape)),
type=v.type.clone(shape=as_tuple(sym.Product(v.shape))))
if isinstance(v, sym.Array) else v for v in routine.variables]

class LowerConstantArrayIndices(Transformation):
"""
A transformation to pass/lower constant array indices down the call tree.

For example, the following code:

.. code-block:: fortran

subroutine driver(...)
real, intent(inout) :: var(nlon,nlev,5,nb)
do ibl=1,10
call kernel(var(:, :, 1, ibl), var(:, :, 2:5, ibl))
end do
end subroutine driver

subroutine kernel(var1, var2)
real, intent(inout) :: var1(nlon, nlev)
real, intent(inout) :: var2(nlon, nlev, 4)
var1(:, :) = ...
do jk=1,nlev
do jl=1,nlon
var1(jl, jk) = ...
do jt=1,4
var2(jl, jk, jt) = ...
enddo
enddo
enddo
end subroutine kernel

is transformed to:

.. code-block:: fortran

subroutine driver(...)
real, intent(inout) :: var(nlon,nlev,5,nb)
do ibl=1,10
call kernel(var(:, :, :, ibl), var(:, :, :, ibl))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have potential issues with aliasing here, or rather compilers not being able to figure out that there isn't any? I'm wondering whether we could avoid passing the same array multiple times and instead rename the argument then on kernel side?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, that is possible. However, this makes the transformation more invasive and the question arises whether we want to do that within this transformation or introduce a separate utility/transformation to do that. What do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, good point, that could indeed be a separate transformation pass! In that case, disregard this here.

end do
end subroutine driver

subroutine kernel(var1, var2)
real, intent(inout) :: var1(nlon, nlev, 5)
real, intent(inout) :: var2(nlon, nlev, 5)
var1(:, :, 1) = ...
do jk=1,nlev
do jl=1,nlon
var1(jl, jk, 1) = ...
do jt=1,4
var2(jl, jk, jt + 2 + -1) = ...
enddo
enddo
enddo
end subroutine kernel

Parameters
----------
recurse_to_kernels: bool
Recurse to kernels, thus lower constant array indices below the driver level for nested
kernel calls (default: `True`).
inline_external_only: bool
Inline only external constant expressions or all of them (default: `True`)
"""

# This trafo only operates on procedures
item_filter = (ProcedureItem,)

def __init__(self, recurse_to_kernels=True, inline_external_only=True):
    """
    Store the configuration of the transformation.

    Parameters
    ----------
    recurse_to_kernels: bool
        Whether routines with a role other than ``'driver'`` are processed
        as callers, too (default: `True`).
    inline_external_only: bool
        Forwarded as ``external_only`` to ``inline_constant_parameters``
        (default: `True`).
    """
    self.inline_external_only = inline_external_only
    self.recurse_to_kernels = recurse_to_kernels

@staticmethod
def explicit_dimensions(routine):
    """
    Give every array used without subscripts in the body of ``routine``
    an explicit ``:`` subscript per dimension of its shape.

    A two-dimensional array ``arr2d`` becomes ``arr2d(:,:)``, a
    three-dimensional array ``arr3d`` becomes ``arr3d(:,:,:)``. Arrays
    that already carry subscripts are left untouched. The body of
    ``routine`` is replaced by the substituted version.

    Parameters
    ----------
    routine: :any:`Subroutine`
        The subroutine whose body is rewritten
    """
    # One ':' range per entry of the array's shape, for arrays used without subscripts
    explicit_map = {
        var: var.clone(dimensions=(sym.RangeIndex((None, None)),) * len(var.shape))
        for var in FindVariables(unique=False).visit(routine.body)
        if isinstance(var, sym.Array) and not var.dimensions
    }
    routine.body = SubstituteExpressions(explicit_map).visit(routine.body)

@staticmethod
def is_constant_dim(dim):
    """
    Tell whether the array subscript ``dim`` is constant.

    A subscript counts as constant if ``is_constant`` accepts it
    directly, or if it is a :any:`RangeIndex` for which every child
    except the last one is set (not ``None``) and constant.

    Parameters
    ----------
    dim: :py:class:`pymbolic.primitives.Expression`
        The subscript expression to classify

    Returns
    -------
    bool
        ``True`` for a constant subscript, ``False`` otherwise
    """
    if is_constant(dim):
        return True
    if not isinstance(dim, sym.RangeIndex):
        return False
    # The last child is ignored (presumably the step of the range - TODO confirm)
    bounds = dim.children[:-1]
    return all(bound is not None and is_constant(bound) for bound in bounds)

def transform_subroutine(self, routine, **kwargs):
    """
    Apply the transformation to ``routine``.

    Routines with role ``'driver'`` are always processed; routines with
    any other role only if ``recurse_to_kernels`` is set. Processing
    first inlines constant parameters and then lowers the constant array
    indices of calls to the given targets.

    Parameters
    ----------
    routine: :any:`Subroutine`
        The routine acting as caller
    role: str
        Mandatory keyword argument giving the role of ``routine``
    targets: optional
        Keyword argument with the names of the calls to be processed;
        names are compared in lower case
    """
    role = kwargs['role']
    target_names = tuple(str(target).lower() for target in as_tuple(kwargs.get('targets', None)))
    # Nothing to do for non-driver routines unless recursion into kernels is requested
    if role != 'driver' and not self.recurse_to_kernels:
        return
    inline_constant_parameters(routine, external_only=self.inline_external_only)
    self.process(routine, target_names)

def process(self, routine, targets):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pretty lengthy code and hard to digest. Can we break this down a bit, e.g., like this?

def process_caller():
    dispatched_routines = set()
    for call in ...:
        for routine_arg, call_arg in ...:
            # Create argument mappings
            self.update_caller_arg(updated_call_args, ...)
            if call.routine.name not in dispatched_routines:
                self.update_callee_arg(routine_vmap, ...)

        # Apply mapping to callee
        if call.routine.name not in dispatched_routines:
            dispatched_routines.add(call.routine.name)
            self.process_callee(call.routine, routine_vmap)

        # Update the call
        self.update_call(...)

"""
Process the driver and possibly kernels
"""
dispatched_routines = ()
offset_map = {}
for call in FindNodes(ir.CallStatement).visit(routine.body):
if str(call.name).lower() not in targets:
continue
# skip already dispatched routines but still update the call signature
if call.routine in dispatched_routines:
self.update_call_signature(call)
continue
# explicit array dimensions for the callee
self.explicit_dimensions(call.routine)
reuterbal marked this conversation as resolved.
Show resolved Hide resolved
dispatched_routines += (call.routine,)
# create the offset map and apply to call and callee
offset_map[call.routine.name.lower()] = self.create_offset_map(call)
self.process_callee(call.routine, offset_map[call.routine.name.lower()])
self.update_call_signature(call)

def update_call_signature(self, call):
    """
    Replace constant indices for call arguments being arrays with ':' and update the call.

    Both positional and keyword arguments are handled; arguments that
    are not arrays are passed on unchanged. The call is updated in place.

    Parameters
    ----------
    call: :any:`CallStatement`
        The call whose arguments are rewritten
    """
    def _widen(array):
        # Swap each constant subscript for a full ':' range, keep all others
        widened = tuple(
            sym.RangeIndex((None, None)) if self.is_constant_dim(dim) else dim
            for dim in array.dimensions
        )
        return array.clone(dimensions=widened)

    new_args = [
        _widen(arg) if isinstance(arg, sym.Array) else arg
        for arg in call.arguments
    ]
    new_kwargs = [
        (kw[0], _widen(kw[1])) if isinstance(kw[1], sym.Array) else kw
        for kw in call.kwarguments
    ]
    call._update(arguments=as_tuple(new_args), kwarguments=as_tuple(new_kwargs))

def create_offset_map(self, call):
    """
    Create map/dictionary for arguments with constant array indices.

    The result is keyed by the name of the callee's array dummy argument.
    Each value maps the position of a call-site subscript to a tuple
    ``(callee_index, offset, size)``.

    For, e.g.,

        integer :: arg(len1, len2, len3, len4)
        call kernel(..., arg(:, 2, 4:6, i), ...)

    offset_map[arg] = {
        0: (0, None, None),     # ':', same index as before, no offset
        1: (None, 2, len2),     # constant index: new index, offset 2, size of the dimension is len2
        2: (1, 4, len3),        # constant range: used to be position 1, offset 4, size of the dimension is len3
        3: (-1, None, None),    # disregard as this is neither constant nor passed to callee
    }

    Call arguments that are not arrays, or arrays passed without any
    subscript, carry no index that could be lowered and get no entry.

    Parameters
    ----------
    call: :any:`CallStatement`
        The call to analyse; its ``arg_iter()`` must yield pairs of
        callee argument and call argument

    Returns
    -------
    dict
        The offset map as described above
    """
    offset_map = {}
    for routine_arg, call_arg in call.arg_iter():
        if not isinstance(routine_arg, sym.Array):
            continue
        # Anything but a subscripted array (e.g. an expression passed to an
        # array dummy, or a whole array) has no dimensions to classify
        if not isinstance(call_arg, sym.Array) or not call_arg.dimensions:
            continue
        arg_map = {}
        # Position of the current call-site dimension within the callee's
        # original declaration; dimensions unknown to the callee are not counted
        current_index = 0
        for i, dim in enumerate(call_arg.dimensions):
            is_range = isinstance(dim, sym.RangeIndex)
            if self.is_constant_dim(dim):
                if is_range:
                    # constant array index is e.g. '1:3' or '5:10'
                    arg_map[i] = (current_index, dim.children[0], call_arg.shape[i])
                else:
                    # constant array index is e.g., '1' or '42'
                    arg_map[i] = (None, dim, call_arg.shape[i])
                    current_index -= 1
            elif is_range:
                # non constant array index is ':'
                arg_map[i] = (current_index, None, None)
            else:
                # non constant array index is a variable e.g. 'jl'
                arg_map[i] = (-1, None, None)
                current_index -= 1
            current_index += 1
        offset_map[routine_arg.name] = arg_map
    return offset_map

def process_callee(self, routine, offset_map):
    """
    Process/adapt the callee according to information in `offset_map`.

    Adapt the variable declarations and usage/indexing: the declared
    dimensions/shape of each argument listed in ``offset_map`` are
    rebuilt in the spec of ``routine``, and every subscripted use of
    these arguments in the body is re-indexed accordingly.

    Parameters
    ----------
    routine: :any:`Subroutine`
        The callee to adapt; its spec and body are replaced
    offset_map: dict
        Per argument name, a dict mapping the call-site dimension position
        to a tuple ``(original_index, offset, size)``, where
        ``original_index`` is the position in the callee's current
        declaration (``None`` for a newly introduced dimension, ``-1`` for
        a dimension not passed to the callee), ``offset`` is the constant
        lower index (or ``None``) and ``size`` the size of the dimension
        on the caller side (or ``None``)
    """
    # Arguments without any recorded dimension need no adaptation; dropping
    # them avoids computing the new rank from an empty map
    offset_map = {name: dim_map for name, dim_map in offset_map.items() if dim_map}

    # adapt variable declarations, thus adapt the dimension and shape of the corresponding arguments
    vmap = {}
    variable_map = routine.variable_map
    for var_name, dim_map in offset_map.items():
        var = variable_map[var_name]
        new_dims = ()
        for i in sorted(dim_map):
            original_index, offset, size = dim_map[i]
            # skip dimensions that are not passed to the callee
            if not (original_index is None or 0 <= original_index < len(var.dimensions)):
                continue
            if offset is not None:
                # constant index or range at the call site: declare the full caller-side extent
                new_dims += (size,)
            else:
                new_dims += (var.shape[original_index],)
        vmap[var] = var.clone(dimensions=new_dims, type=var.type.clone(shape=new_dims))
    routine.spec = SubstituteExpressions(vmap).visit(routine.spec)

    # adapt the variable usage, thus the indexing/dimension
    vmap = {}
    for var in FindVariables(unique=False).visit(routine.body):
        if var.name not in offset_map or not var.dimensions:
            continue
        dim_map = offset_map[var.name]
        new_dims = ()
        for i in sorted(dim_map):
            original_index, offset, _ = dim_map[i]
            if not (original_index is None or 0 <= original_index < len(var.dimensions)):
                continue
            if offset is None:
                # dimension passed through unchanged
                new_dims += (var.dimensions[original_index],)
            elif original_index is None:
                # dimension new to the callee: index it with the constant itself
                new_dims += (offset,)
            else:
                # constant range at the call site: shift the callee's index by the lower bound
                # NOTE(review): if the callee subscript is itself a range (e.g. ':'),
                # adding the offset presumably does not yield a shifted range - confirm
                new_dims += (var.dimensions[original_index] + offset - 1,)
        vmap[var] = var.clone(dimensions=new_dims)
    routine.body = SubstituteExpressions(vmap).visit(routine.body)
Loading
Loading