-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New transformation 'LowerConstantArrayIndices' to allow to … #348
Merged
+550
−5
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,24 +13,28 @@ | |
from itertools import count | ||
import operator as op | ||
|
||
from loki.batch import Transformation, ProcedureItem | ||
from loki.logging import info | ||
from loki.analyse import dataflow_analysis_attached | ||
from loki.expression import ( | ||
symbols as sym, simplify, symbolic_op, FindVariables, SubstituteExpressions | ||
symbols as sym, simplify, symbolic_op, FindVariables, SubstituteExpressions, | ||
is_constant | ||
) | ||
from loki.ir import ( | ||
Assignment, Loop, VariableDeclaration, FindNodes, Transformer | ||
Assignment, Loop, VariableDeclaration, FindNodes, Transformer, nodes as ir | ||
) | ||
from loki.tools import as_tuple, CaseInsensitiveDict | ||
from loki.types import SymbolAttributes, BasicType | ||
from loki.transformations.inline import inline_constant_parameters | ||
|
||
|
||
__all__ = [ | ||
'shift_to_zero_indexing', 'invert_array_indices', | ||
'resolve_vector_notation', 'normalize_range_indexing', | ||
'promote_variables', 'promote_nonmatching_variables', | ||
'promotion_dimensions_from_loop_nest', 'demote_variables', | ||
'flatten_arrays', 'normalize_array_shape_and_access' | ||
'flatten_arrays', 'normalize_array_shape_and_access', | ||
'LowerConstantArrayIndices' | ||
] | ||
|
||
|
||
|
@@ -603,3 +607,236 @@ def new_dims(dim, shape): | |
routine.variables = [v.clone(dimensions=as_tuple(sym.Product(v.shape)), | ||
type=v.type.clone(shape=as_tuple(sym.Product(v.shape)))) | ||
if isinstance(v, sym.Array) else v for v in routine.variables] | ||
|
||
class LowerConstantArrayIndices(Transformation): | ||
""" | ||
A transformation to pass/lower constant array indices down the call tree. | ||
|
||
For example, the following code: | ||
|
||
.. code-block:: fortran | ||
|
||
subroutine driver(...) | ||
real, intent(inout) :: var(nlon,nlev,5,nb) | ||
do ibl=1,10 | ||
call kernel(var(:, :, 1, ibl), var(:, :, 2:5, ibl)) | ||
end do | ||
end subroutine driver | ||
|
||
subroutine kernel(var1, var2) | ||
real, intent(inout) :: var1(nlon, nlev) | ||
real, intent(inout) :: var2(nlon, nlev, 4) | ||
var1(:, :) = ... | ||
do jk=1,nlev | ||
do jl=1,nlon | ||
var1(jl, jk) = ... | ||
do jt=1,4 | ||
var2(jl, jk, jt) = ... | ||
enddo | ||
enddo | ||
enddo | ||
end subroutine kernel | ||
|
||
is transformed to: | ||
|
||
.. code-block:: fortran | ||
|
||
subroutine driver(...) | ||
real, intent(inout) :: var(nlon,nlev,5,nb) | ||
do ibl=1,10 | ||
call kernel(var(:, :, :, ibl), var(:, :, :, ibl)) | ||
end do | ||
end subroutine driver | ||
|
||
subroutine kernel(var1, var2) | ||
real, intent(inout) :: var1(nlon, nlev, 5) | ||
real, intent(inout) :: var2(nlon, nlev, 5) | ||
var1(:, :, 1) = ... | ||
do jk=1,nlev | ||
do jl=1,nlon | ||
var1(jl, jk, 1) = ... | ||
do jt=1,4 | ||
var2(jl, jk, jt + 2 + -1) = ... | ||
enddo | ||
enddo | ||
enddo | ||
end subroutine kernel | ||
|
||
Parameters | ||
---------- | ||
recurse_to_kernels: bool | ||
Recurse to kernels, thus lower constant array indices below the driver level for nested | ||
kernel calls (default: `True`). | ||
inline_external_only: bool | ||
Inline only external constant expressions or all of them (default: `False`) | ||
""" | ||
|
||
# This trafo only operates on procedures | ||
item_filter = (ProcedureItem,) | ||
|
||
def __init__(self, recurse_to_kernels=True, inline_external_only=True): | ||
self.recurse_to_kernels = recurse_to_kernels | ||
self.inline_external_only = inline_external_only | ||
|
||
@staticmethod | ||
def explicit_dimensions(routine): | ||
""" | ||
Make dimensions of arrays explicit within :any:`Subroutine` ``routine``. | ||
E.g., convert two-dimensional array ``arr2d`` to ``arr2d(:,:)`` or | ||
``arr3d`` to ``arr3d(:,:,:)``. | ||
|
||
Parameters | ||
---------- | ||
routine: :any:`Subroutine` | ||
The subroutine to check | ||
""" | ||
arrays = [var for var in FindVariables(unique=False).visit(routine.body) if isinstance(var, sym.Array)] | ||
array_map = {} | ||
for array in arrays: | ||
if not array.dimensions: | ||
new_dimensions = (sym.RangeIndex((None, None)),) * len(array.shape) | ||
array_map[array] = array.clone(dimensions=new_dimensions) | ||
routine.body = SubstituteExpressions(array_map).visit(routine.body) | ||
|
||
@staticmethod | ||
def is_constant_dim(dim): | ||
""" | ||
Check whether dimension dim is constant, thus, either a constant | ||
value or a constant range index. | ||
|
||
Parameters | ||
---------- | ||
dim: :py:class:`pymbolic.primitives.Expression` | ||
""" | ||
if is_constant(dim): | ||
return True | ||
if isinstance(dim, sym.RangeIndex)\ | ||
and all(child is not None and is_constant(child) for child in dim.children[:-1]): | ||
return True | ||
return False | ||
|
||
def transform_subroutine(self, routine, **kwargs): | ||
role = kwargs['role'] | ||
targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets', None))) | ||
if role == 'driver' or self.recurse_to_kernels: | ||
inline_constant_parameters(routine, external_only=self.inline_external_only) | ||
self.process(routine, targets) | ||
|
||
def process(self, routine, targets): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is pretty lengthy code and hard to digest. Can we break this down a bit, e.g., like this?
|
||
""" | ||
Process the driver and possibly kernels | ||
""" | ||
dispatched_routines = () | ||
offset_map = {} | ||
for call in FindNodes(ir.CallStatement).visit(routine.body): | ||
if str(call.name).lower() not in targets: | ||
continue | ||
# skip already dispatched routines but still update the call signature | ||
if call.routine in dispatched_routines: | ||
self.update_call_signature(call) | ||
continue | ||
# explicit array dimensions for the callee | ||
self.explicit_dimensions(call.routine) | ||
reuterbal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dispatched_routines += (call.routine,) | ||
# create the offset map and apply to call and callee | ||
offset_map[call.routine.name.lower()] = self.create_offset_map(call) | ||
self.process_callee(call.routine, offset_map[call.routine.name.lower()]) | ||
self.update_call_signature(call) | ||
|
||
def update_call_signature(self, call): | ||
""" | ||
Replace constant indices for call arguments being arrays with ':' and update the call. | ||
""" | ||
new_args = [arg.clone(dimensions=\ | ||
tuple(sym.RangeIndex((None, None)) if self.is_constant_dim(d) else d for d in arg.dimensions))\ | ||
if isinstance(arg, sym.Array) else arg for arg in call.arguments] | ||
new_kwargs = [(kw[0], kw[1].clone(dimensions=\ | ||
tuple(sym.RangeIndex((None, None)) if self.is_constant_dim(d) else d for d in kw[1].dimensions)))\ | ||
if isinstance(kw[1], sym.Array) else kw for kw in call.kwarguments] | ||
call._update(arguments=as_tuple(new_args), kwarguments=as_tuple(new_kwargs)) | ||
|
||
def create_offset_map(self, call): | ||
""" | ||
Create map/dictionary for arguments with constant array indices. | ||
|
||
For, e.g., | ||
|
||
integer :: arg(len1, len2, len3, len4) | ||
call kernel(..., arg(:, 2, 4:6, i), ...) | ||
|
||
offset_map[arg] = { | ||
0: (0, None, None), # same index as before, no offset | ||
1: (None, 1, len2), # New index, offset 1, size of the dimension is len2 | ||
2: (1, 4, len3), # Used to be position 1, offset 4, size of the dimension is len3 | ||
3: (-1, None, None), # disregard as this is neither constant nor passed to callee | ||
} | ||
""" | ||
offset_map = {} | ||
for routine_arg, call_arg in call.arg_iter(): | ||
if not isinstance(routine_arg, sym.Array): | ||
continue | ||
offset_map[routine_arg.name] = {} | ||
current_index = 0 | ||
for i, dim in enumerate(call_arg.dimensions): | ||
if self.is_constant_dim(dim): | ||
if isinstance(dim, sym.RangeIndex): | ||
# constant array index is e.g. '1:3' or '5:10' | ||
offset_map[routine_arg.name][i] = (current_index, dim.children[0], call_arg.shape[i]) | ||
else: | ||
# constant array index is e.g., '1' or '42' | ||
offset_map[routine_arg.name][i] = (None, dim, call_arg.shape[i]) | ||
current_index -= 1 | ||
else: | ||
if not isinstance(dim, sym.RangeIndex): | ||
# non constant array index is a variable e.g. 'jl' | ||
offset_map[routine_arg.name][i] = (-1, None, None) | ||
current_index -= 1 | ||
else: | ||
# non constant array index is ':' | ||
offset_map[routine_arg.name][i] = (current_index, None, None) | ||
current_index += 1 | ||
return offset_map | ||
|
||
def process_callee(self, routine, offset_map): | ||
""" | ||
Process/adapt the callee according to information in `offset_map`. | ||
|
||
Adapt the variable declarations and usage/indexing. | ||
""" | ||
# adapt variable declarations, thus adapt the dimension and shape of the corresponding arguments | ||
vmap = {} | ||
variable_map = routine.variable_map | ||
for var_name in offset_map: | ||
var = variable_map[var_name] | ||
new_dims = () | ||
for i in range(max(k for k, v in offset_map[var.name].items() if v != 0) + 1): | ||
original_index = offset_map[var_name][i][0] | ||
offset = offset_map[var_name][i][1] | ||
size = offset_map[var_name][i][2] | ||
if not (original_index is None or 0 <= original_index < len(var.dimensions)): | ||
continue | ||
if offset is not None: | ||
new_dims += (size,) | ||
else: | ||
new_dims += (var.shape[original_index],) | ||
vmap[var] = var.clone(dimensions=new_dims, type=var.type.clone(shape=new_dims)) | ||
routine.spec = SubstituteExpressions(vmap).visit(routine.spec) | ||
# adapt the variable usage, thus the indexing/dimension | ||
vmap = {} | ||
for var in FindVariables(unique=False).visit(routine.body): | ||
if var.name in offset_map and var.dimensions is not None and var.dimensions: | ||
new_dims = () | ||
for i in range(max(k for k, v in offset_map[var.name].items() if v != 0) + 1): | ||
original_index = offset_map[var.name][i][0] | ||
offset = offset_map[var.name][i][1] | ||
if not (original_index is None or 0 <= original_index < len(var.dimensions)): | ||
continue | ||
if offset is not None: | ||
if original_index is None: | ||
new_dims += (offset,) | ||
else: | ||
new_dims += (var.dimensions[original_index] + offset - 1,) | ||
else: | ||
new_dims += (var.dimensions[original_index],) | ||
vmap[var] = var.clone(dimensions=new_dims) | ||
routine.body = SubstituteExpressions(vmap).visit(routine.body) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have potential issues with aliasing here, or rather compilers not being able to figure out that there isn't any? I'm wondering wether we could avoid passing the same array multiple times and instead rename the argument then on kernel side?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, that is possible. However, this makes the transformation more invasive and the question arises whether we want to do that within this transformation or introduce a separate utility/transformation to do that. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, good point, that could indeed be a separate transformation pass! In that case, disregard this here.