From faab0f841df9f5bb9eba060ba88e3a26dbbb568a Mon Sep 17 00:00:00 2001 From: "Terence D. Honles" Date: Tue, 23 Mar 2021 13:34:22 -0700 Subject: [PATCH] track `if typing.TYPE_CHECKING` to warn about non runtime bindings When importing or defining values in ``if typing.TYPE_CHECKING`` blocks the bound names will not be available at runtime and may cause errors when used in the following way:: import typing if typing.TYPE_CHECKING: from module import Type # some slow import or circular reference def method(value) -> Type: # the import is needed by the type checker assert isinstance(value, Type) # this is a runtime error This change allows pyflakes to track what names are bound for runtime use, and allows it to warn when a non runtime name is used in a runtime context. --- pyflakes/checker.py | 142 ++++++++++++++++++------- pyflakes/messages.py | 8 ++ pyflakes/test/test_type_annotations.py | 66 ++++++++++++ 3 files changed, 177 insertions(+), 39 deletions(-) diff --git a/pyflakes/checker.py b/pyflakes/checker.py index 629dacf0..1ba48482 100644 --- a/pyflakes/checker.py +++ b/pyflakes/checker.py @@ -215,10 +215,11 @@ class Binding: the node that this binding was last used. """ - def __init__(self, name, source): + def __init__(self, name, source, *, runtime=True): self.name = name self.source = source self.used = False + self.runtime = runtime def __str__(self): return self.name @@ -249,8 +250,8 @@ def redefines(self, other): class Builtin(Definition): """A definition created for all Python builtins.""" - def __init__(self, name): - super().__init__(name, None) + def __init__(self, name, *, runtime=True): + super().__init__(name, None, runtime=runtime) def __repr__(self): return '<{} object {!r} at 0x{:x}>'.format( @@ -294,10 +295,10 @@ class Importation(Definition): @type fullName: C{str} """ - def __init__(self, name, source, full_name=None): + def __init__(self, name, source, full_name=None, *, runtime=True): self.fullName = full_name or name self.redefined = [] - super().__init__(name, source) + super().__init__(name, source, runtime=runtime) def redefines(self, other): if isinstance(other, SubmoduleImportation): @@ -342,11 +343,11 @@ class SubmoduleImportation(Importation): name is also the same, to avoid false positives. """ - def __init__(self, name, source): + def __init__(self, name, source, *, runtime=True): # A dot should only appear in the name when it is a submodule import assert '.' in name and (not source or isinstance(source, ast.Import)) package_name = name.split('.')[0] - super().__init__(package_name, source) + super().__init__(package_name, source, runtime=runtime) self.fullName = name def redefines(self, other): @@ -364,7 +365,9 @@ def source_statement(self): class ImportationFrom(Importation): - def __init__(self, name, source, module, real_name=None): + def __init__( + self, name, source, module, real_name=None, *, runtime=True + ): self.module = module self.real_name = real_name or name @@ -373,7 +376,7 @@ def __init__(self, name, source, module, real_name=None): else: full_name = module + '.' + self.real_name - super().__init__(name, source, full_name) + super().__init__(name, source, full_name, runtime=runtime) def __str__(self): """Return import full name with alias.""" @@ -393,8 +396,8 @@ def source_statement(self): class StarImportation(Importation): """A binding created by a 'from x import *' statement.""" - def __init__(self, name, source): - super().__init__('*', source) + def __init__(self, name, source, *, runtime=True): + super().__init__('*', source, runtime=runtime) # Each star importation needs a unique name, and # may not be the module name otherwise it will be deemed imported self.name = name + '.*' @@ -483,7 +486,7 @@ class ExportBinding(Binding): C{__all__} will not have an unused import warning reported for them. """ - def __init__(self, name, source, scope): + def __init__(self, name, source, scope, *, runtime=True): if '__all__' in scope and isinstance(source, ast.AugAssign): self.names = list(scope['__all__'].names) else: @@ -514,7 +517,7 @@ def _add_to_names(container): # If not list concatenation else: break - super().__init__(name, source) + super().__init__(name, source, runtime=runtime) class Scope(dict): @@ -722,10 +725,6 @@ class Checker: ast.DictComp: GeneratorScope, } - nodeDepth = 0 - offset = None - _in_annotation = AnnotationState.NONE - builtIns = set(builtin_vars).union(_MAGIC_GLOBALS) _customBuiltIns = os.environ.get('PYFLAKES_BUILTINS') if _customBuiltIns: @@ -734,6 +733,10 @@ class Checker: def __init__(self, tree, filename='(none)', builtins=None, withDoctest='PYFLAKES_DOCTEST' in os.environ, file_tokens=()): + self.nodeDepth = 0 + self.offset = None + self._in_annotation = AnnotationState.NONE + self._in_type_check_guard = False self._nodeHandlers = {} self._deferred = collections.deque() self.deadScopes = [] @@ -1000,9 +1003,11 @@ def addBinding(self, node, value): # then assume the rebound name is used as a global or within a loop value.used = self.scope[value.name].used - # don't treat annotations as assignments if there is an existing value - # in scope - if value.name not in self.scope or not isinstance(value, Annotation): + # always allow the first assignment or if not already a runtime value, + # but do not shadow an existing assignment with an annotation or non + # runtime value. + if (not existing or not existing.runtime + or (not isinstance(value, Annotation) and value.runtime)): if isinstance(value, NamedExprAssignment): # PEP 572: use scope in which outermost generator is defined scope = next( @@ -1080,20 +1085,28 @@ def handleNodeLoad(self, node, parent): self.report(messages.InvalidPrintSyntax, node) try: - scope[name].used = (self.scope, node) + binding = scope[name] + except KeyError: + pass + else: + # check if the binding is used in the wrong context + if (not binding.runtime + and not (self._in_type_check_guard or self._in_annotation)): + self.report(messages.TypeCheckingOnly, node, name) + return + + # mark the binding as used + binding.used = (self.scope, node) # if the name of SubImportation is same as # alias of other Importation and the alias # is used, SubImportation also should be marked as used. - n = scope[name] - if isinstance(n, Importation) and n._has_alias(): + if isinstance(binding, Importation) and binding._has_alias(): try: - scope[n.fullName].used = (self.scope, node) + scope[binding.fullName].used = (self.scope, node) except KeyError: pass - except KeyError: - pass - else: + return importStarred = importStarred or scope.importStarred @@ -1150,12 +1163,13 @@ def handleNodeStore(self, node): break parent_stmt = self.getParent(node) + runtime = not self._in_type_check_guard if isinstance(parent_stmt, ast.AnnAssign) and parent_stmt.value is None: binding = Annotation(name, node) elif isinstance(parent_stmt, (FOR_TYPES, ast.comprehension)) or ( parent_stmt != node._pyflakes_parent and not self.isLiteralTupleUnpacking(parent_stmt)): - binding = Binding(name, node) + binding = Binding(name, node, runtime=runtime) elif ( name == '__all__' and isinstance(self.scope, ModuleScope) and @@ -1164,11 +1178,13 @@ def handleNodeStore(self, node): (ast.Assign, ast.AugAssign, ast.AnnAssign) ) ): - binding = ExportBinding(name, node._pyflakes_parent, self.scope) + binding = ExportBinding( + name, node._pyflakes_parent, self.scope, runtime=runtime + ) elif isinstance(parent_stmt, ast.NamedExpr): - binding = NamedExprAssignment(name, node) + binding = NamedExprAssignment(name, node, runtime=runtime) else: - binding = Assignment(name, node) + binding = Assignment(name, node, runtime=runtime) self.addBinding(node, binding) def handleNodeDelete(self, node): @@ -1832,7 +1848,38 @@ def DICT(self, node): def IF(self, node): if isinstance(node.test, ast.Tuple) and node.test.elts != []: self.report(messages.IfTuple, node) - self.handleChildren(node) + + # check for typing.TYPE_CHECKING, and if so handle each node specifically + if_type_checking = _is_typing(node.test, 'TYPE_CHECKING', self.scopeStack) + if if_type_checking or ( + # check for else TYPE_CHECKING + isinstance(node.test, ast.UnaryOp) + and isinstance(node.test.op, ast.Not) + and _is_typing(node.test.operand, 'TYPE_CHECKING', self.scopeStack) + ): + self.handleNode(node.test, node) + try: + _in_type_check_guard = self._in_type_check_guard + + # update the current TYPE_CHECKING state and handle the if-node(s) + self._in_type_check_guard = if_type_checking + if isinstance(node.body, list): + for child in node.body: + self.handleNode(child, node) + else: + self.handleNode(node.body, node) + + # update the current TYPE_CHECKING state and handle the else-node(s) + self._in_type_check_guard = not if_type_checking or _in_type_check_guard + if isinstance(node.orelse, list): + for child in node.orelse: + self.handleNode(child, node) + else: + self.handleNode(node.orelse, node) + finally: + self._in_type_check_guard = _in_type_check_guard + else: + self.handleChildren(node) IFEXP = IF @@ -1943,7 +1990,12 @@ def FUNCTIONDEF(self, node): with self._type_param_scope(node): self.LAMBDA(node) - self.addBinding(node, FunctionDefinition(node.name, node)) + self.addBinding( + node, + FunctionDefinition( + node.name, node, runtime=not self._in_type_check_guard + ), + ) # doctest does not process doctest within a doctest, # or in nested functions. if (self.withDoctest and @@ -2028,7 +2080,12 @@ def CLASSDEF(self, node): for stmt in node.body: self.handleNode(stmt, node) - self.addBinding(node, ClassDefinition(node.name, node)) + self.addBinding( + node, + ClassDefinition( + node.name, node, runtime=not self._in_type_check_guard + ), + ) def AUGASSIGN(self, node): self.handleNodeLoad(node.target, node) @@ -2061,12 +2118,17 @@ def TUPLE(self, node): LIST = TUPLE def IMPORT(self, node): + runtime = not self._in_type_check_guard for alias in node.names: if '.' in alias.name and not alias.asname: - importation = SubmoduleImportation(alias.name, node) + importation = SubmoduleImportation( + alias.name, node, runtime=runtime + ) else: name = alias.asname or alias.name - importation = Importation(name, node, alias.name) + importation = Importation( + name, node, alias.name, runtime=runtime + ) self.addBinding(node, importation) def IMPORTFROM(self, node): @@ -2078,6 +2140,7 @@ def IMPORTFROM(self, node): module = ('.' * node.level) + (node.module or '') + runtime = not self._in_type_check_guard for alias in node.names: name = alias.asname or alias.name if node.module == '__future__': @@ -2095,10 +2158,11 @@ def IMPORTFROM(self, node): self.scope.importStarred = True self.report(messages.ImportStarUsed, node, module) - importation = StarImportation(module, node) + importation = StarImportation(module, node, runtime=runtime) else: - importation = ImportationFrom(name, node, - module, alias.name) + importation = ImportationFrom( + name, node, module, alias.name, runtime=runtime + ) self.addBinding(node, importation) def TRY(self, node): diff --git a/pyflakes/messages.py b/pyflakes/messages.py index 405dc72f..df9fda09 100644 --- a/pyflakes/messages.py +++ b/pyflakes/messages.py @@ -65,6 +65,14 @@ def __init__(self, filename, loc, name, from_list): self.message_args = (name, from_list) +class TypeCheckingOnly(Message): + message = 'name only defined for TYPE_CHECKIN: %r' + + def __init__(self, filename, loc, name): + Message.__init__(self, filename, loc) + self.message_args = (name,) + + class UndefinedName(Message): message = 'undefined name %r' diff --git a/pyflakes/test/test_type_annotations.py b/pyflakes/test/test_type_annotations.py index f4f8ded5..73e049e9 100644 --- a/pyflakes/test/test_type_annotations.py +++ b/pyflakes/test/test_type_annotations.py @@ -664,6 +664,55 @@ def f() -> T: pass """) + def test_typing_guard_import(self): + # T is imported for runtime use + self.flakes(""" + from typing import TYPE_CHECKING + + if TYPE_CHECKING: + from t import T + + def f(x) -> T: + from t import T + + assert isinstance(x, T) + return x + """) + # T is defined at runtime in one side of the if/else block + self.flakes(""" + from typing import TYPE_CHECKING, Union + + if TYPE_CHECKING: + from t import T + else: + T = object + + if not TYPE_CHECKING: + U = object + else: + from t import U + + def f(x) -> Union[T, U]: + assert isinstance(x, (T, U)) + return x + """) + + def test_typing_guard_import_runtime_error(self): + # T and U are not bound for runtime use + self.flakes(""" + from typing import TYPE_CHECKING, Union + + if TYPE_CHECKING: + from t import T + + class U: + pass + + def f(x) -> Union[T, U]: + assert isinstance(x, (T, U)) + return x + """, m.TypeCheckingOnly, m.TypeCheckingOnly) + def test_typing_guard_for_protocol(self): self.flakes(""" from typing import TYPE_CHECKING @@ -678,6 +727,23 @@ def f() -> int: pass """) + def test_typing_guard_with_elif_branch(self): + # This test will not raise an error even though Protocol is not + # defined outside TYPE_CHECKING because Pyflakes does not do case + # analysis. + self.flakes(""" + from typing import TYPE_CHECKING + if TYPE_CHECKING: + from typing import Protocol + elif False: + Protocol = object + else: + pass + class C(Protocol): + def f(): # type: () -> int + pass + """) + def test_typednames_correct_forward_ref(self): self.flakes(""" from typing import TypedDict, List, NamedTuple