Skip to content

Commit

Permalink
Merge pull request #361 from ecmwf-ifs/nams-resolve-vector-notation-fix
Browse files Browse the repository at this point in the history
fix for 'resolve_vector_notation' utility
  • Loading branch information
reuterbal committed Aug 23, 2024
2 parents cb2d5d5 + 4f08967 commit 49795ea
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 5 deletions.
6 changes: 3 additions & 3 deletions loki/transformations/array_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ def resolve_vector_notation(routine):
# Create new index variable
vtype = SymbolAttributes(BasicType.INTEGER)
ivar = sym.Variable(name=f'{ivar_basename}_{i}', type=vtype, scope=routine)
shape_index_map[s] = ivar
shape_index_map[(i, s)] = ivar
index_range_map[ivar] = s

if ivar not in vdims:
vdims.append(ivar)

# Add index variable to range replacement
new_dims = as_tuple(shape_index_map.get(s, d)
for d, s in zip(v.dimensions, as_tuple(v.shape)))
new_dims = as_tuple(shape_index_map.get((i, s), d)
for i, d, s in zip(count(), v.dimensions, as_tuple(v.shape)))
vmap[v] = v.clone(dimensions=new_dims)

index_vars.update(list(vdims))
Expand Down
53 changes: 51 additions & 2 deletions loki/transformations/tests/test_array_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from loki import Module, Subroutine, fgen
from loki.build import jit_compile, jit_compile_lib, clean_test, Builder, Obj
from loki.expression import symbols as sym, FindVariables
from loki.frontend import available_frontends
from loki.ir import FindNodes, CallStatement
from loki.frontend import available_frontends, OMNI
from loki.ir import FindNodes, CallStatement, Loop

from loki.transformations.array_indexing import (
promote_variables, demote_variables, normalize_range_indexing,
invert_array_indices, flatten_arrays,
normalize_array_shape_and_access, shift_to_zero_indexing,
resolve_vector_notation
)
from loki.transformations.transpile import FortranCTransformation

Expand Down Expand Up @@ -719,3 +720,51 @@ def validate_routine(routine):
assert (b_flattened == b_ref.flatten(order='F')).all()

builder.clean()

@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_promote_resolve_vector_notation(tmp_path, frontend):
"""
Apply and test resolve vector notation utility.
"""
fcode = """
subroutine transform_resolve_vector_notation(ret1, ret2)
implicit none
integer, parameter :: param1 = 3
integer, parameter :: param2 = 5
integer, intent(out) :: ret1(param1, param1), ret2(param1, param2)
integer :: tmp, jk
ret1(:, :) = 11
ret2(:, :) = 42
end subroutine transform_resolve_vector_notation
""".strip()
routine = Subroutine.from_source(fcode, frontend=frontend)
resolve_vector_notation(routine)

loops = FindNodes(Loop).visit(routine.body)
arrays = [var for var in FindVariables(unique=False).visit(routine.body) if isinstance(var, sym.Array)]

assert len(loops) == 4
assert loops[0].variable == 'i_ret1_1'
assert loops[0].bounds.children == (1, 'param1', 1) if frontend != OMNI else (1, 3, 1)
assert loops[1].variable == 'i_ret1_0'
assert loops[1].bounds.children == (1, 'param1', 1) if frontend != OMNI else (1, 3, 1)
assert loops[2].variable == 'i_ret2_1'
assert loops[2].bounds.children == (1, 'param2', 1) if frontend != OMNI else (1, 5, 1)
assert loops[3].variable == 'i_ret2_0'
assert loops[3].bounds.children == (1, 'param1', 1) if frontend != OMNI else (1, 3, 1)

assert len(arrays) == 2
assert arrays[0].dimensions == ('i_ret1_0', 'i_ret1_1')
assert arrays[1].dimensions == ('i_ret2_0', 'i_ret2_1')

ret1 = np.zeros(shape=(3, 3), order='F', dtype=np.int32)
ret2 = np.zeros(shape=(3, 5), order='F', dtype=np.int32)

filepath = tmp_path/(f'{routine.name}_{frontend}.f90')
function = jit_compile(routine, filepath=filepath, objname=routine.name)
function(ret1, ret2)

assert np.all(ret1 == 11)
assert np.all(ret2 == 42)

0 comments on commit 49795ea

Please sign in to comment.