Skip to content

Commit

Permalink
Refactor and improve 'LowerConstantArrayIndices' (which now also allows multiple constant array indices for one array argument)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSt98 committed Aug 7, 2024
1 parent 1b71545 commit 6e779ee
Show file tree
Hide file tree
Showing 2 changed files with 285 additions and 58 deletions.
161 changes: 112 additions & 49 deletions loki/transformations/array_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ class LowerConstantArrayIndices(Transformation):
# This trafo only operates on procedures
item_filter = (ProcedureItem,)

def __init__(self, recurse_to_kernels=True, inline_external_only=True):
    """
    Store the configuration flags of the transformation.

    Parameters
    ----------
    recurse_to_kernels : bool, optional
        Stored as ``self.recurse_to_kernels`` (default ``True``).
        NOTE(review): presumably decides whether kernels are processed in
        addition to the driver -- confirm against ``transform_subroutine``.
    inline_external_only : bool, optional
        Stored as ``self.inline_external_only`` (default ``True``).
        NOTE(review): presumably forwarded to the constant-parameter
        inlining step -- confirm against ``transform_subroutine``.
    """
    self.recurse_to_kernels = recurse_to_kernels
    self.inline_external_only = inline_external_only

Expand Down Expand Up @@ -723,57 +723,120 @@ def transform_subroutine(self, routine, **kwargs):
self.process(routine, targets)

def process(self, routine, targets):
    """
    Process the driver and possibly kernels.

    Every call statement in ``routine.body`` whose lower-cased name is
    contained in ``targets`` is handled as follows:

    * the first call to a given callee makes the callee's array dimensions
      explicit, builds the offset map for that call, adapts the callee
      accordingly and finally rewrites the call arguments;
    * any further call to an already adapted callee only gets its call
      arguments rewritten.

    Parameters
    ----------
    routine :
        The routine whose body is searched for call statements.
    targets :
        Collection of lower-case callee names that should be processed.
    """
    # Callees that have already been adapted; each callee must be
    # rewritten only once, otherwise indices would be inserted twice.
    dispatched_routines = []
    for call in FindNodes(ir.CallStatement).visit(routine.body):
        if str(call.name).lower() not in targets:
            continue
        callee = call.routine
        # skip already dispatched routines but still update the call signature
        # NOTE(review): the callee is adapted according to the first call
        # only; a later call using different constant indices is not
        # reflected in the callee -- confirm this is intended.
        if callee in dispatched_routines:
            self.update_call_signature(call)
            continue
        # explicit array dimensions for the callee
        self.explicit_dimensions(callee)
        dispatched_routines.append(callee)
        # create the offset map and apply to call and callee
        callee_offsets = self.create_offset_map(call)
        self.process_callee(callee, callee_offsets)
        self.update_call_signature(call)

def update_call_signature(self, call):
    """
    Replace constant indices of array call arguments with ':' and update the call.

    Positional and keyword arguments are treated alike: every argument that
    is an array gets each dimension for which ``self.is_constant_dim`` holds
    replaced by an open range; all other arguments are passed on unchanged.
    The call is updated in place via ``call._update``.
    """
    def _open_constant_dims(array):
        # Swap each constant dimension for an unbounded range (':').
        opened = tuple(
            sym.RangeIndex((None, None)) if self.is_constant_dim(dim) else dim
            for dim in array.dimensions
        )
        return array.clone(dimensions=opened)

    positional = []
    for arg in call.arguments:
        if isinstance(arg, sym.Array):
            positional.append(_open_constant_dims(arg))
        else:
            positional.append(arg)

    keyword = []
    for kwarg in call.kwarguments:
        key, value = kwarg[0], kwarg[1]
        if isinstance(value, sym.Array):
            keyword.append((key, _open_constant_dims(value)))
        else:
            # keep the original (keyword, value) pair untouched
            keyword.append(kwarg)

    call._update(arguments=as_tuple(positional), kwarguments=as_tuple(keyword))

def create_offset_map(self, call):
    """
    Create a dictionary describing the constant array indices used in a call.

    The result maps the name of each array argument of the callee to a
    dictionary that holds, per dimension index of the corresponding call
    argument, a tuple ``(callee_index, offset, size)``.

    For, e.g.,

        integer :: arg(len1, len2, len3, len4)
        call kernel(..., arg(:, 2, 4:6, i), ...)

    the entry for ``arg`` is

        {
            0: (0, None, None),   # same index as before, no offset
            1: (None, 2, len2),   # new index, constant 2, size of the dimension is len2
            2: (1, 4, len3),      # used to be position 1, offset 4, size of the dimension is len3
            3: (-1, None, None),  # disregard: neither constant nor passed to the callee
        }
    """
    result = {}
    for routine_arg, call_arg in call.arg_iter():
        # only array arguments of the callee are of interest
        if not isinstance(routine_arg, sym.Array):
            continue
        entries = {}
        result[routine_arg.name] = entries
        # position of the next dimension as currently seen by the callee
        callee_pos = 0
        for dim_idx, dim in enumerate(call_arg.dimensions):
            is_constant = self.is_constant_dim(dim)
            is_range = isinstance(dim, sym.RangeIndex)
            if is_constant and is_range:
                # constant range, e.g. '1:3' or '5:10': first child is the offset
                entries[dim_idx] = (callee_pos, dim.children[0], call_arg.shape[dim_idx])
            elif is_constant:
                # constant scalar index, e.g. '1' or '42': new dimension for the callee
                entries[dim_idx] = (None, dim, call_arg.shape[dim_idx])
            elif is_range:
                # non-constant range, i.e. ':'
                entries[dim_idx] = (callee_pos, None, None)
            else:
                # non-constant scalar index, e.g. a variable 'jl'
                entries[dim_idx] = (-1, None, None)
            # only range-like dimensions occupy a position in the callee
            if is_range:
                callee_pos += 1
    return result

def process_callee(self, routine, offset_map):
    """
    Process/adapt the callee according to the information in `offset_map`.

    Both the declarations (``routine.spec``) and the usage/indexing
    (``routine.body``) of the affected array arguments are adapted.

    Parameters
    ----------
    routine :
        The callee to adapt; its ``spec`` and ``body`` are replaced.
    offset_map : dict
        Maps argument names to a dictionary of per-dimension tuples
        ``(original_index, offset, size)`` as built by
        ``create_offset_map``. Arguments with an empty dictionary (array
        passed without explicit subscripts) are left untouched.
    """
    # adapt variable declarations, thus adapt the dimension and shape of the corresponding arguments
    vmap = {}
    variable_map = routine.variable_map
    for var_name, entries in offset_map.items():
        if not entries:
            # Nothing recorded for this argument; `max()` below would raise
            # ValueError on an empty mapping.
            continue
        var = variable_map[var_name]
        new_dims = ()
        for i in range(max(entries) + 1):
            original_index, offset, size = entries[i]
            # skip dimensions that are not passed to the callee (index -1)
            # or that lie outside of the declared dimensions
            if original_index is not None and not 0 <= original_index < len(var.dimensions):
                continue
            if offset is not None:
                # constant index or constant range: use the caller-side size
                new_dims += (size,)
            else:
                new_dims += (var.shape[original_index],)
        vmap[var] = var.clone(dimensions=new_dims, type=var.type.clone(shape=new_dims))
    routine.spec = SubstituteExpressions(vmap).visit(routine.spec)
    # adapt the variable usage, thus the indexing/dimension
    vmap = {}
    for var in FindVariables(unique=False).visit(routine.body):
        entries = offset_map.get(var.name)
        if not entries:
            continue
        # only subscripted usages need to be re-indexed
        if not var.dimensions:
            continue
        new_dims = ()
        for i in range(max(entries) + 1):
            original_index, offset, _ = entries[i]
            if original_index is not None and not 0 <= original_index < len(var.dimensions):
                continue
            if offset is None:
                # dimension passed through unchanged
                new_dims += (var.dimensions[original_index],)
            elif original_index is None:
                # dimension new to the callee: insert the constant index
                new_dims += (offset,)
            else:
                # constant range: shift the existing index by the lower bound
                new_dims += (var.dimensions[original_index] + offset - 1,)
        vmap[var] = var.clone(dimensions=new_dims)
    routine.body = SubstituteExpressions(vmap).visit(routine.body)
Loading

0 comments on commit 6e779ee

Please sign in to comment.