diff --git a/lint_rules/lint_rules/debug_rules.py b/lint_rules/lint_rules/debug_rules.py index 5061da53c..e30c3c2c2 100644 --- a/lint_rules/lint_rules/debug_rules.py +++ b/lint_rules/lint_rules/debug_rules.py @@ -10,7 +10,7 @@ FindNodes, CallStatement, Assignment, Scalar, RangeIndex, resolve_associates, simplify, Sum, Product, IntLiteral, as_tuple, SubstituteExpressions, Array, symbolic_op, StringLiteral, is_constant, LogicLiteral, VariableDeclaration, flatten, - FindInlineCalls, Conditional, FindExpressions, Comparison + FindInlineCalls, Conditional, FindExpressions, Comparison, single_variable_declaration ) from loki.lint import GenericRule, RuleType @@ -284,11 +284,22 @@ def fix_subroutine(cls, subroutine, rule_report, config): vtype = arg.type.clone(shape=new_shape, scope=subroutine) new_vars += as_tuple(arg.clone(type=vtype, dimensions=new_shape, scope=subroutine)) - #TODO: add 'VariableDeclaration.symbols' should be of type 'Variable' rather than 'Expression' + # simplify variable declarations + single_variable_declaration(subroutine) + + #TODO: 'VariableDeclaration.symbols' should be of type 'Variable' rather than 'Expression' # to enable case-insensitive search here new_var_names = [v.name.lower() for v in new_vars] - subroutine.variables = [var for var in subroutine.variables if not var.name.lower() in new_var_names] - subroutine.variables += new_vars + + routine = subroutine.clone() + routine.variables = [var for var in routine.variables if not var.name.lower() in new_var_names] + routine.variables += new_vars + + old_decls = as_tuple([decl for decl in FindNodes(VariableDeclaration).visit(subroutine.spec) + if decl.symbols[0].name.lower() in new_var_names]) + new_decls = as_tuple([decl for decl in FindNodes(VariableDeclaration).visit(routine.spec) + if decl.symbols[0].name.lower() in new_var_names]) + node_map.update({old_decls: new_decls}) return node_map diff --git a/lint_rules/tests/test_debug_rules.py b/lint_rules/tests/test_debug_rules.py index 08a8e3036..a5697c541 100644 --- a/lint_rules/tests/test_debug_rules.py +++ b/lint_rules/tests/test_debug_rules.py @@ -11,7 +11,7 @@ import pytest from conftest import run_linter, available_frontends -from loki import Sourcefile, FindInlineCalls +from loki import Sourcefile, FindInlineCalls, FindNodes, VariableDeclaration from loki.lint import DefaultHandler @@ -229,4 +229,12 @@ def test_dynamic_ubound_checks(rules, frontend): assert all(s.name == d for s, d in zip(routine.variable_map['var0'].shape, shape)) assert all(s.name == d for s, d in zip(routine.variable_map['var2'].shape, shape)) + arg_names = ['klon', 'klev', 'nblk', 'var0', 'var1', 'var2'] + assert [arg.name.lower() for arg in routine.arguments] == arg_names + + # check that variable declarations have not been duplicated + symbols = [s.name.lower() for decl in FindNodes(VariableDeclaration).visit(routine.spec) for s in decl.symbols] + assert len(symbols) == 6 + assert set(symbols) == {'klon', 'klev', 'nblk', 'var0', 'var1', 'var2'} + os.remove(kernel.path)