Skip to content

Commit

Permalink
DynamicUboundCheckRule: fixer method now preserves subroutine arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
awnawab committed Aug 14, 2023
1 parent 08adb4e commit a357911
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
19 changes: 15 additions & 4 deletions lint_rules/lint_rules/debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
FindNodes, CallStatement, Assignment, Scalar, RangeIndex, resolve_associates,
simplify, Sum, Product, IntLiteral, as_tuple, SubstituteExpressions, Array,
symbolic_op, StringLiteral, is_constant, LogicLiteral, VariableDeclaration, flatten,
FindInlineCalls, Conditional, FindExpressions, Comparison
FindInlineCalls, Conditional, FindExpressions, Comparison, single_variable_declaration
)
from loki.lint import GenericRule, RuleType

Expand Down Expand Up @@ -284,11 +284,22 @@ def fix_subroutine(cls, subroutine, rule_report, config):
vtype = arg.type.clone(shape=new_shape, scope=subroutine)
new_vars += as_tuple(arg.clone(type=vtype, dimensions=new_shape, scope=subroutine))

#TODO: add 'VariableDeclaration.symbols' should be of type 'Variable' rather than 'Expression'
# simplify variable declarations
single_variable_declaration(subroutine)

#TODO: 'VariableDeclaration.symbols' should be of type 'Variable' rather than 'Expression'
# to enable case-insensitive search here
new_var_names = [v.name.lower() for v in new_vars]
subroutine.variables = [var for var in subroutine.variables if not var.name.lower() in new_var_names]
subroutine.variables += new_vars

routine = subroutine.clone()
routine.variables = [var for var in routine.variables if not var.name.lower() in new_var_names]
routine.variables += new_vars

old_decls = as_tuple([decl for decl in FindNodes(VariableDeclaration).visit(subroutine.spec)
if decl.symbols[0].name.lower() in new_var_names])
new_decls = as_tuple([decl for decl in FindNodes(VariableDeclaration).visit(routine.spec)
if decl.symbols[0].name.lower() in new_var_names])
node_map.update({old_decls: new_decls})

return node_map

Expand Down
10 changes: 9 additions & 1 deletion lint_rules/tests/test_debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest

from conftest import run_linter, available_frontends
from loki import Sourcefile, FindInlineCalls
from loki import Sourcefile, FindInlineCalls, FindNodes, VariableDeclaration
from loki.lint import DefaultHandler


Expand Down Expand Up @@ -229,4 +229,12 @@ def test_dynamic_ubound_checks(rules, frontend):
assert all(s.name == d for s, d in zip(routine.variable_map['var0'].shape, shape))
assert all(s.name == d for s, d in zip(routine.variable_map['var2'].shape, shape))

arg_names = ['klon', 'klev', 'nblk', 'var0', 'var1', 'var2']
assert [arg.name.lower() for arg in routine.arguments] == arg_names

# check that variable declarations have not been duplicated
symbols = [s.name.lower() for decl in FindNodes(VariableDeclaration).visit(routine.spec) for s in decl.symbols]
assert len(symbols) == 6
assert set(symbols) == {'klon', 'klev', 'nblk', 'var0', 'var1', 'var2'}

os.remove(kernel.path)

0 comments on commit a357911

Please sign in to comment.