From 4f089675b1cff92b23a7fbc35c0a8af98da59d0d Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Thu, 22 Aug 2024 12:03:39 +0000 Subject: [PATCH] fix for 'resolve_vector_notation' utility that occurred whenever two dimensions had the same size --- loki/transformations/array_indexing.py | 6 +-- .../tests/test_array_indexing.py | 53 ++++++++++++++++++- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/loki/transformations/array_indexing.py b/loki/transformations/array_indexing.py index 063c52650..3feeb4520 100644 --- a/loki/transformations/array_indexing.py +++ b/loki/transformations/array_indexing.py @@ -119,15 +119,15 @@ def resolve_vector_notation(routine): # Create new index variable vtype = SymbolAttributes(BasicType.INTEGER) ivar = sym.Variable(name=f'{ivar_basename}_{i}', type=vtype, scope=routine) - shape_index_map[s] = ivar + shape_index_map[(i, s)] = ivar index_range_map[ivar] = s if ivar not in vdims: vdims.append(ivar) # Add index variable to range replacement - new_dims = as_tuple(shape_index_map.get(s, d) - for d, s in zip(v.dimensions, as_tuple(v.shape))) + new_dims = as_tuple(shape_index_map.get((i, s), d) + for i, d, s in zip(count(), v.dimensions, as_tuple(v.shape))) vmap[v] = v.clone(dimensions=new_dims) index_vars.update(list(vdims)) diff --git a/loki/transformations/tests/test_array_indexing.py b/loki/transformations/tests/test_array_indexing.py index d544daf2e..520fe6fe6 100644 --- a/loki/transformations/tests/test_array_indexing.py +++ b/loki/transformations/tests/test_array_indexing.py @@ -12,13 +12,14 @@ from loki import Module, Subroutine, fgen from loki.build import jit_compile, jit_compile_lib, clean_test, Builder, Obj from loki.expression import symbols as sym, FindVariables -from loki.frontend import available_frontends -from loki.ir import FindNodes, CallStatement +from loki.frontend import available_frontends, OMNI +from loki.ir import FindNodes, CallStatement, Loop from loki.transformations.array_indexing import ( promote_variables, demote_variables, normalize_range_indexing, invert_array_indices, flatten_arrays, normalize_array_shape_and_access, shift_to_zero_indexing, + resolve_vector_notation ) from loki.transformations.transpile import FortranCTransformation @@ -719,3 +720,51 @@ def validate_routine(routine): assert (b_flattened == b_ref.flatten(order='F')).all() builder.clean() + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_promote_resolve_vector_notation(tmp_path, frontend): + """ + Apply and test resolve vector notation utility. + """ + fcode = """ +subroutine transform_resolve_vector_notation(ret1, ret2) + implicit none + integer, parameter :: param1 = 3 + integer, parameter :: param2 = 5 + integer, intent(out) :: ret1(param1, param1), ret2(param1, param2) + integer :: tmp, jk + + ret1(:, :) = 11 + ret2(:, :) = 42 + +end subroutine transform_resolve_vector_notation + """.strip() + routine = Subroutine.from_source(fcode, frontend=frontend) + resolve_vector_notation(routine) + + loops = FindNodes(Loop).visit(routine.body) + arrays = [var for var in FindVariables(unique=False).visit(routine.body) if isinstance(var, sym.Array)] + + assert len(loops) == 4 + assert loops[0].variable == 'i_ret1_1' + assert loops[0].bounds.children == (1, 'param1', 1) if frontend != OMNI else (1, 3, 1) + assert loops[1].variable == 'i_ret1_0' + assert loops[1].bounds.children == (1, 'param1', 1) if frontend != OMNI else (1, 3, 1) + assert loops[2].variable == 'i_ret2_1' + assert loops[2].bounds.children == (1, 'param2', 1) if frontend != OMNI else (1, 5, 1) + assert loops[3].variable == 'i_ret2_0' + assert loops[3].bounds.children == (1, 'param1', 1) if frontend != OMNI else (1, 3, 1) + + assert len(arrays) == 2 + assert arrays[0].dimensions == ('i_ret1_0', 'i_ret1_1') + assert arrays[1].dimensions == ('i_ret2_0', 'i_ret2_1') + + ret1 = np.zeros(shape=(3, 3), order='F', dtype=np.int32) + ret2 = np.zeros(shape=(3, 5), order='F', dtype=np.int32) + + filepath = tmp_path/(f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + function(ret1, ret2) + + assert np.all(ret1 == 11) + assert np.all(ret2 == 42)