Skip to content

Commit

Permalink
Merge pull request #363 from ecmwf-ifs/naml-omni-fix-range-indexing
Browse files Browse the repository at this point in the history
OMNI: Fix dimension range-indexing in frontend
  • Loading branch information
reuterbal authored Aug 26, 2024
2 parents 9a76c21 + 29a4869 commit 68cb274
Show file tree
Hide file tree
Showing 24 changed files with 240 additions and 293 deletions.
13 changes: 4 additions & 9 deletions loki/expression/tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,11 +473,11 @@ def test_index_ranges(frontend):
# OMNI will insert implicit lower=1 into shape declarations,
# we simply have to live with it... :(
assert str(vmap['v4']) == 'v4(dim)' or str(vmap['v4']) == 'v4(1:dim)'
assert str(vmap['v5']) == 'v5(1:dim)'
assert str(vmap['v5']) == 'v5(1:dim)' or str(vmap['v5']) == 'v5(dim)'

vmap_body = {v.name: v for v in FindVariables().visit(routine.body)}
assert str(vmap_body['v1']) == 'v1(::2)'
assert str(vmap_body['v2']) == 'v2(1:dim)'
assert str(vmap_body['v2']) == 'v2(dim)' or str(vmap_body['v2']) == 'v2(1:dim)'
assert str(vmap_body['v3']) == 'v3(0:4:2)'
assert str(vmap_body['v5']) == 'v5(:)'

Expand Down Expand Up @@ -1853,13 +1853,8 @@ def to_str(_parsed):
assert isinstance(parsed, sym.Array)
assert all(isinstance(_parsed, sym.Scalar) for _parsed in parsed.dimensions)
assert all(_parsed.scope == routine for _parsed in parsed.dimensions)
if frontend == OMNI:
assert all(isinstance(_parsed, sym.RangeIndex) for _parsed in parsed.shape)
assert all(isinstance(_parsed.upper, sym.Scalar) for _parsed in parsed.shape)
assert all(_parsed.upper.scope == routine for _parsed in parsed.shape)
else:
assert all(isinstance(_parsed, sym.Scalar) for _parsed in parsed.shape)
assert all(_parsed.scope == routine for _parsed in parsed.shape)
assert all(isinstance(_parsed, sym.Scalar) for _parsed in parsed.shape)
assert all(_parsed.scope == routine for _parsed in parsed.shape)
assert to_str(parsed) == 'arr(i1,i2,i3)'

parsed = parse_expr(convert_to_case('my_func(i1)', mode=case), scope=routine)
Expand Down
9 changes: 3 additions & 6 deletions loki/frontend/tests/test_frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,9 @@ def test_associates(tmp_path, frontend):
module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = module['associates']
variables = FindVariables().visit(routine.body)
if frontend == OMNI:
assert all(v.shape == ('1:3',)
for v in variables if v.name in ['vector', 'vector2'])
else:
assert all(v.shape == ('3',)
for v in variables if v.name in ['vector', 'vector2'])
assert all(
v.shape == ('3',) for v in variables if v.name in ['vector', 'vector2']
)

for assoc in FindNodes(ir.Associate).visit(routine.body):
for var in FindVariables().visit(assoc.body):
Expand Down
85 changes: 85 additions & 0 deletions loki/frontend/tests/test_omni.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Specific test battery for the OMNI parser frontend.
"""

import pytest

from loki import Module, Subroutine
from loki.expression import FindVariables
from loki.frontend import OMNI, HAVE_OMNI
from loki.ir import nodes as ir, FindNodes


@pytest.mark.skipif(not HAVE_OMNI, reason='Test tequires OMNI frontend.')
def test_derived_type_definitions(tmp_path):
""" Test correct parsing of derived type declarations. """
fcode = """
module omni_derived_type_mod
type explicit
real(kind=8) :: scalar, vector(3), matrix(3, 3)
end type explicit
type deferred
real(kind=8), allocatable :: scalar, vector(:), matrix(:, :)
end type deferred
type ranged
real(kind=8) :: scalar, vector(1:3), matrix(0:3, 0:3)
end type ranged
end module omni_derived_type_mod
"""
# Parse the source and validate the IR
module = Module.from_source(fcode, frontend=OMNI, xmods=[tmp_path])

assert len(module.typedefs) == 3
explicit_symbols = FindVariables(unique=False).visit(module['explicit'].body)
assert explicit_symbols == ('scalar', 'vector(3)', 'matrix(3, 3)')

deferred_symbols = FindVariables(unique=False).visit(module['deferred'].body)
assert deferred_symbols == ('scalar', 'vector(:)', 'matrix(:, :)')

ranged_symbols = FindVariables(unique=False).visit(module['ranged'].body)
assert ranged_symbols == ('scalar', 'vector(3)', 'matrix(0:3, 0:3)')


@pytest.mark.skipif(not HAVE_OMNI, reason='Test tequires OMNI frontend.')
def test_array_dimensions(tmp_path):
""" Test correct parsing of derived type declarations. """
fcode = """
subroutine omni_array_indexing(n, a, b)
integer, intent(in) :: n
real(kind=8), intent(inout) :: a(3), b(n)
real(kind=8) :: c(n, n)
real(kind=8) :: d(1:n, 0:n)
a(:) = 11.
b(1:n) = 42.
c(2:n, 0:n) = 66.
d(:, 0:n) = 68.
end subroutine omni_array_indexing
"""
# Parse the source and validate the IR
routine = Subroutine.from_source(fcode, frontend=OMNI, xmods=[tmp_path])

# OMNI separate declarations per variable
decls = FindNodes(ir.VariableDeclaration).visit(routine.spec)
assert len(decls) == 5
assert decls[0].symbols == ('n',)
assert decls[1].symbols == ('a(3)',)
assert decls[2].symbols == ('b(n)',)
assert decls[3].symbols == ('c(n, n)',)
assert decls[4].symbols == ('d(n, 0:n)',)

assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 4
assert assigns[0].lhs == 'a(:)'
assert assigns[1].lhs == 'b(1:n)'
assert assigns[2].lhs == 'c(2:n, 0:n)'
assert assigns[3].lhs == 'd(:, 0:n)'
36 changes: 36 additions & 0 deletions loki/frontend/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from codetiming import Timer
from more_itertools import split_after

from loki.expression import (
symbols as sym, SubstituteExpressionsMapper, ExpressionRetriever
)
from loki.ir import (
NestedTransformer, FindNodes, PatternFinder, Transformer,
Assignment, Comment, CommentBlock, VariableDeclaration,
Expand Down Expand Up @@ -267,6 +270,35 @@ def visit_tuple(self, o, **kwargs):
return tuple(i for i in visited if i is not None and as_tuple(i))


class RangeIndexTransformer(Transformer):
"""
:any:`Transformer` that replaces ``arr(1:n)`` notations with
``arr(n)`` in :any:`VariableDeclaration`.
"""

retriever = ExpressionRetriever(lambda e: isinstance(e, (sym.Array)))

@staticmethod
def is_one_index(dim):
return isinstance(dim, sym.RangeIndex) and dim.lower == 1 and dim.step is None

def visit_VariableDeclaration(self, o, **kwargs): # pylint: disable=unused-argument
"""
Gets all :any:`Array` symbols and adjusts dimension and shape.
"""
vmap = {}
for v in self.retriever.retrieve(o.symbols):
dimensions = tuple(d.upper if self.is_one_index(d) else d for d in v.dimensions)
_type = v.type
if _type.shape:
shape = tuple(d.upper if self.is_one_index(d) else d for d in _type.shape)
_type = _type.clone(shape=shape)
vmap[v] = v.clone(dimensions=dimensions, type=_type)

mapper = SubstituteExpressionsMapper(vmap, invalidate_source=self.invalidate_source)
return o.clone(symbols=mapper(o.symbols, recurse_to_declaration_attributes=True))


@Timer(logger=perf, text=lambda s: f'[Loki::Frontend] Executed sanitize_ir in {s:.2f}s')
def sanitize_ir(_ir, frontend, pp_registry=None, pp_info=None):
"""
Expand Down Expand Up @@ -303,6 +335,10 @@ def sanitize_ir(_ir, frontend, pp_registry=None, pp_info=None):
_ir = InlineCommentTransformer(inplace=True, invalidate_source=False).visit(_ir)
_ir = ClusterCommentTransformer(inplace=True, invalidate_source=False).visit(_ir)

if frontend == OMNI:
# Revert OMNI's array dimension expansion from `a(n)` => `arr(1:n)`
_ir = RangeIndexTransformer(invalidate_source=False).visit(_ir)

if frontend in (OMNI, OFP):
_ir = inline_labels(_ir)

Expand Down
8 changes: 2 additions & 6 deletions loki/tests/test_derived_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,12 +1173,8 @@ def test_derived_type_rescope_symbols_shadowed(tmp_path, shadowed_typedef_symbol
assert istate in ('istate(nmaxstreams)', 'istate(1:nmaxstreams)')
assert istate.scope is rng_type

if frontend == OMNI:
assert istate.dimensions[0] == '1:nmaxstreams'
assert istate.dimensions[0].stop.scope
else:
assert istate.dimensions[0] == 'nmaxstreams'
assert istate.dimensions[0].scope
assert istate.dimensions[0] == 'nmaxstreams'
assert istate.dimensions[0].scope

# FIXME: Use of NMaxStreams from parent scope is in the wrong scope (LOKI-52)
#assert istate.dimensions[0].scope is module
Expand Down
6 changes: 3 additions & 3 deletions loki/tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_module_external_typedefs_subroutine(frontend, tmp_path):
pt_ext = routine.variables[0]

# OMNI resolves explicit shape parameters in the frontend parser
exptected_array_shape = '(1:2, 1:3)' if frontend == OMNI else '(x, y)'
exptected_array_shape = '(2, 3)' if frontend == OMNI else '(x, y)'

# Check that the `array` variable in the `ext` type is found and
# has correct type and shape info
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_module_external_typedefs_type(frontend, tmp_path):
assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef)

# OMNI resolves explicit shape parameters in the frontend parser
exptected_array_shape = '(1:2, 1:3)' if frontend == OMNI else '(x, y)'
exptected_array_shape = '(2, 3)' if frontend == OMNI else '(x, y)'

# Check that the `array` variable in the `ext` type is found and
# has correct type and shape info
Expand Down Expand Up @@ -234,7 +234,7 @@ def test_module_nested_types(frontend, tmp_path):
end module type_mod
"""
# OMNI resolves explicit shape parameters in the frontend parser
exptected_array_shape = '(1:2, 1:3)' if frontend == OMNI else '(x, y)'
exptected_array_shape = '(2, 3)' if frontend == OMNI else '(x, y)'

module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
parent = module.typedef_map['parent_type']
Expand Down
Loading

0 comments on commit 68cb274

Please sign in to comment.