Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extending depdendency trafo #436

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions loki/transformations/build_system/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def transform_module(self, module, **kwargs):
if self.replace_ignore_items and (item := kwargs.get('item')):
targets += tuple(str(i).lower() for i in item.ignore)
self.rename_imports(module, imports=module.imports, targets=targets)
active_nodes = None
if self.remove_inactive_items and not kwargs.get('items') is None:
active_nodes = [item.scope_ir.name.lower() for item in kwargs['items']]
self.rename_access_specs(module, targets=targets, active_nodes=active_nodes)

def transform_subroutine(self, routine, **kwargs):
"""
Expand Down Expand Up @@ -329,6 +333,34 @@ def rename_imports(self, source, imports, targets=None):
if import_map:
source.spec = Transformer(import_map).visit(source.spec)

def rename_access_specs(self, module, targets=None, active_nodes=None):
"""
Update/rename access specifiers.

Parameters
----------
module : :any:`Module`
The IR object to transform
targets : list of str
Optional list of subroutine names for which to modify access specs
active_nodes : list of str
Optional list of active nodes
"""
if module.public_access_spec:
if active_nodes is not None:
new_access_spec = tuple(elem for elem in module.public_access_spec if elem in active_nodes)
else:
new_access_spec = module.public_access_spec
module.public_access_spec = tuple(f'{elem}{self.suffix}' if not targets or elem in targets else
elem for elem in new_access_spec)
if module.private_access_spec:
if active_nodes is not None:
new_access_spec = tuple(elem for elem in module.private_access_spec if elem in active_nodes)
else:
new_access_spec = module.public_access_spec
module.private_access_spec = tuple(f'{elem}{self.suffix}' if not targets or elem in targets
else elem for elem in new_access_spec)

def rename_interfaces(self, intfs, targets=None):
"""
Update explicit interfaces to actively transformed subroutines.
Expand Down
112 changes: 112 additions & 0 deletions loki/transformations/build_system/tests/test_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,118 @@ def test_dependency_transformation_globalvar_imports(frontend, use_scheduler, tm
assert 'some_const' in [str(s) for s in driver['driver'].spec.body[1].symbols]


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('use_scheduler', [False, True])
def test_dependency_transformation_access_specs(frontend, use_scheduler, tmp_path, config):
"""
Test that global variable imports are not renamed as a
call statement would be.
"""

kernel_fcode = """
MODULE kernel_mod

INTEGER, PUBLIC :: some_const

PRIVATE
PUBLIC kernel, kernel_2, unused_kernel
CONTAINS
SUBROUTINE kernel(a, b, c)
IMPLICIT NONE
INTEGER, INTENT(INOUT) :: a, b, c

call kernel_2(a, b)
call kernel_3(c)
END SUBROUTINE kernel
SUBROUTINE kernel_2(a, b)
IMPLICIT NONE
INTEGER, INTENT(INOUT) :: a, b

a = 1
b = 2
END SUBROUTINE kernel_2
SUBROUTINE kernel_3(a)
IMPLICIT NONE
INTEGER, INTENT(INOUT) :: a

a = 3
END SUBROUTINE kernel_3
SUBROUTINE unused_kernel(a)
IMPLICIT NONE
INTEGER, INTENT(INOUT) :: a

a = 3
END SUBROUTINE unused_kernel
END MODULE kernel_mod
""".strip()

driver_fcode = """
SUBROUTINE driver(a, b, c)
USE kernel_mod, only: kernel
USE kernel_mod, only: some_const
IMPLICIT NONE
INTEGER, INTENT(INOUT) :: a, b, c

CALL kernel(a, b ,c)
END SUBROUTINE driver
""".strip()

transformation = DependencyTransformation(suffix='_test', module_suffix='_mod')
if use_scheduler:
(tmp_path/'kernel_mod.F90').write_text(kernel_fcode)
(tmp_path/'driver.F90').write_text(driver_fcode)
scheduler = Scheduler(
paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path]
)
scheduler.process(transformation)

# Check that both, old and new module exist now in the scheduler graph
assert 'kernel_test_mod#kernel_test' in scheduler.items # for the subroutine
assert 'kernel_mod' in scheduler.items # for the global variable

kernel = scheduler['kernel_test_mod#kernel_test'].source
driver = scheduler['#driver'].source

# Check that the not-renamed module is indeed the original one
scheduler.item_factory.item_cache[str(tmp_path/'kernel_mod.F90')].source.make_complete(
frontend=frontend, xmods=[tmp_path]
)
assert (
Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path]).to_fortran() ==
scheduler.item_factory.item_cache[str(tmp_path/'kernel_mod.F90')].source.to_fortran()
)

else:
kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path])
driver = Sourcefile.from_source(driver_fcode, frontend=frontend, xmods=[tmp_path])

kernel.apply(transformation, role='kernel')
driver['driver'].apply(transformation, role='driver', targets=('kernel', 'kernel_mod'))

# Check that the global variable declaration remains unchanged
assert kernel.modules[0].variables[0].name == 'some_const'

# Check that calls and matching import have been diverted to the re-generated routine
calls = FindNodes(CallStatement).visit(driver['driver'].body)
assert len(calls) == 1
assert calls[0].name == 'kernel_test'
imports = FindNodes(Import).visit(driver['driver'].spec)
assert len(imports) == 2
assert isinstance(imports[0], Import)
assert driver['driver'].spec.body[0].module == 'kernel_test_mod'
assert 'kernel_test' in [str(s) for s in driver['driver'].spec.body[0].symbols]

# Check that global variable import remains unchanged
assert isinstance(imports[1], Import)
assert driver['driver'].spec.body[1].module == 'kernel_mod'
assert 'some_const' in [str(s) for s in driver['driver'].spec.body[1].symbols]

if use_scheduler:
assert kernel.modules[0].public_access_spec == ('kernel_test', 'kernel_2_test')
else:
assert kernel.modules[0].public_access_spec == ('kernel_test', 'kernel_2_test', 'unused_kernel_test')


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('use_scheduler', [False, True])
def test_dependency_transformation_globalvar_imports_driver_mod(frontend, use_scheduler, tmp_path, config):
Expand Down
Loading