diff --git a/ChangeLog b/ChangeLog index 4ea98fedfd..14dea92439 100644 --- a/ChangeLog +++ b/ChangeLog @@ -16,6 +16,11 @@ Release date: TBA Refs PyCQA/pylint#2567 +* Treat ``typing.NewType()`` values as normal subclasses. + + Closes PyCQA/pylint#2296 + Closes PyCQA/pylint#3162 + What's New in astroid 2.12.13? ============================== Release date: 2022-11-19 diff --git a/astroid/brain/brain_typing.py b/astroid/brain/brain_typing.py index b34b8bec50..f6111dc5fb 100644 --- a/astroid/brain/brain_typing.py +++ b/astroid/brain/brain_typing.py @@ -10,10 +10,11 @@ from collections.abc import Iterator from functools import partial -from astroid import context, extract_node, inference_tip +from astroid import context, extract_node, inference_tip, nodes from astroid.builder import _extract_single_node from astroid.const import PY38_PLUS, PY39_PLUS from astroid.exceptions import ( + AstroidImportError, AttributeInferenceError, InferenceError, UseInferenceDefault, @@ -35,8 +36,6 @@ from astroid.util import Uninferable TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"} -TYPING_TYPEVARS = {"TypeVar", "NewType"} -TYPING_TYPEVARS_QUALIFIED = {"typing.TypeVar", "typing.NewType"} TYPING_TYPE_TEMPLATE = """ class Meta(type): def __getitem__(self, item): @@ -49,6 +48,13 @@ def __args__(self): class {0}(metaclass=Meta): pass """ +# PEP484 suggests NewType is equivalent to this for typing purposes +# https://www.python.org/dev/peps/pep-0484/#newtype-helper-function +TYPING_NEWTYPE_TEMPLATE = """ +class {derived}({base}): + def __init__(self, val: {base}) -> None: + ... +""" TYPING_MEMBERS = set(getattr(typing, "__all__", [])) TYPING_ALIAS = frozenset( @@ -103,24 +109,33 @@ def __class_getitem__(cls, item): """ -def looks_like_typing_typevar_or_newtype(node): +def looks_like_typing_typevar(node: nodes.Call) -> bool: func = node.func if isinstance(func, Attribute): - return func.attrname in TYPING_TYPEVARS + return func.attrname == "TypeVar" if isinstance(func, Name): - return func.name in TYPING_TYPEVARS + return func.name == "TypeVar" return False -def infer_typing_typevar_or_newtype(node, context_itton=None): - """Infer a typing.TypeVar(...) or typing.NewType(...) call""" +def looks_like_typing_newtype(node: nodes.Call) -> bool: + func = node.func + if isinstance(func, Attribute): + return func.attrname == "NewType" + if isinstance(func, Name): + return func.name == "NewType" + return False + + +def infer_typing_typevar( + node: nodes.Call, ctx: context.InferenceContext | None = None +) -> Iterator[nodes.ClassDef]: + """Infer a typing.TypeVar(...) call""" try: - func = next(node.func.infer(context=context_itton)) + next(node.func.infer(context=ctx)) except (InferenceError, StopIteration) as exc: raise UseInferenceDefault from exc - if func.qname() not in TYPING_TYPEVARS_QUALIFIED: - raise UseInferenceDefault if not node.args: raise UseInferenceDefault # Cannot infer from a dynamic class name (f-string) @@ -129,7 +144,135 @@ def infer_typing_typevar_or_newtype(node, context_itton=None): typename = node.args[0].as_string().strip("'") node = extract_node(TYPING_TYPE_TEMPLATE.format(typename)) - return node.infer(context=context_itton) + return node.infer(context=ctx) + + +def infer_typing_newtype( + node: nodes.Call, ctx: context.InferenceContext | None = None +) -> Iterator[nodes.ClassDef]: + """Infer a typing.NewType(...) call""" + try: + next(node.func.infer(context=ctx)) + except (InferenceError, StopIteration) as exc: + raise UseInferenceDefault from exc + + if len(node.args) != 2: + raise UseInferenceDefault + + # Cannot infer from a dynamic class name (f-string) + if isinstance(node.args[0], JoinedStr) or isinstance(node.args[1], JoinedStr): + raise UseInferenceDefault + + derived, base = node.args + derived_name = derived.as_string().strip("'") + base_name = base.as_string().strip("'") + + new_node: ClassDef = extract_node( + TYPING_NEWTYPE_TEMPLATE.format(derived=derived_name, base=base_name) + ) + new_node.parent = node.parent + + new_bases: list[NodeNG] = [] + + if not isinstance(base, nodes.Const): + # Base type arg is a normal reference, so no need to do special lookups + new_bases = [base] + elif isinstance(base, nodes.Const) and isinstance(base.value, str): + # If the base type is given as a string (e.g. for a forward reference), + # make a naive attempt to find the corresponding node. + _, resolved_base = node.frame().lookup(base_name) + if resolved_base: + base_node = resolved_base[0] + + # If the value is from an "import from" statement, follow the import chain + if isinstance(base_node, nodes.ImportFrom): + ctx = ctx.clone() if ctx else context.InferenceContext() + ctx.lookupname = base_name + base_node = next(base_node.infer(context=ctx)) + + new_bases = [base_node] + elif "." in base.value: + possible_base = _try_find_imported_object_from_str(node, base.value, ctx) + if possible_base: + new_bases = [possible_base] + + if new_bases: + new_node.postinit( + bases=new_bases, body=new_node.body, decorators=new_node.decorators + ) + + return new_node.infer(context=ctx) + + +def _try_find_imported_object_from_str( + node: nodes.Call, + name: str, + ctx: context.InferenceContext | None, +) -> nodes.NodeNG | None: + for statement_mod_name, _ in _possible_module_object_splits(name): + # Find import statements that may pull in the appropriate modules + # The name used to find this statement may not correspond to the name of the module actually being imported + # For example, "import email.charset" is found by lookup("email") + _, resolved_bases = node.frame().lookup(statement_mod_name) + if not resolved_bases: + continue + + resolved_base = resolved_bases[0] + if isinstance(resolved_base, nodes.Import): + # Extract the names of the module as they are accessed from actual code + scope_names = {(alias or name) for (name, alias) in resolved_base.names} + aliases = {alias: name for (name, alias) in resolved_base.names if alias} + + # Find potential mod_name, obj_name splits that work with the available names + # for the module in this scope + import_targets = [ + (mod_name, obj_name) + for (mod_name, obj_name) in _possible_module_object_splits(name) + if mod_name in scope_names + ] + if not import_targets: + continue + + import_target, name_in_mod = import_targets[0] + import_target = aliases.get(import_target, import_target) + + # Try to import the module and find the object in it + try: + resolved_mod: nodes.Module = resolved_base.do_import_module( + import_target + ) + except AstroidImportError: + # If the module doesn't actually exist, try the next option + continue + + # Try to find the appropriate ClassDef or other such node in the target module + _, object_results_in_mod = resolved_mod.lookup(name_in_mod) + if not object_results_in_mod: + continue + + base_node = object_results_in_mod[0] + + # If the value is from an "import from" statement, follow the import chain + if isinstance(base_node, nodes.ImportFrom): + ctx = ctx.clone() if ctx else context.InferenceContext() + ctx.lookupname = name_in_mod + base_node = next(base_node.infer(context=ctx)) + + return base_node + + return None + + +def _possible_module_object_splits( + dot_str: str, +) -> Iterator[tuple[str, str]]: + components = dot_str.split(".") + popped = [] + + while components: + popped.append(components.pop()) + + yield ".".join(components), ".".join(reversed(popped)) def _looks_like_typing_subscript(node): @@ -404,8 +547,13 @@ def infer_typing_cast( AstroidManager().register_transform( Call, - inference_tip(infer_typing_typevar_or_newtype), - looks_like_typing_typevar_or_newtype, + inference_tip(infer_typing_typevar), + looks_like_typing_typevar, +) +AstroidManager().register_transform( + Call, + inference_tip(infer_typing_newtype), + looks_like_typing_newtype, ) AstroidManager().register_transform( Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index 3789c3bf41..8466076966 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -1718,6 +1718,26 @@ def test_typing_types(self) -> None: inferred = next(node.infer()) self.assertIsInstance(inferred, nodes.ClassDef, node.as_string()) + def test_typing_typevar_bad_args(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import TypeVar + + T = TypeVar() + T #@ + + U = TypeVar(f"U") + U #@ + """ + ) + assert isinstance(ast_nodes, list) + + no_args_node = ast_nodes[0] + assert list(no_args_node.infer()) == [util.Uninferable] + + fstr_node = ast_nodes[1] + assert list(fstr_node.infer()) == [util.Uninferable] + def test_typing_type_without_tip(self): """Regression test for https://github.com/PyCQA/pylint/issues/5770""" node = builder.extract_node( @@ -1729,7 +1749,337 @@ def make_new_type(t): """ ) with self.assertRaises(UseInferenceDefault): - astroid.brain.brain_typing.infer_typing_typevar_or_newtype(node.value) + astroid.brain.brain_typing.infer_typing_newtype(node.value) + + def test_typing_newtype_attrs(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import NewType + import decimal + from decimal import Decimal + + NewType("Foo", str) #@ + NewType("Bar", "int") #@ + NewType("Baz", Decimal) #@ + NewType("Qux", decimal.Decimal) #@ + """ + ) + assert isinstance(ast_nodes, list) + + # Base type given by reference + foo_node = ast_nodes[0] + + # Should be unambiguous + foo_inferred_all = list(foo_node.infer()) + assert len(foo_inferred_all) == 1 + + foo_inferred = foo_inferred_all[0] + assert isinstance(foo_inferred, astroid.ClassDef) + + # Check base type method is inferred by accessing one of its methods + foo_base_class_method = foo_inferred.getattr("endswith")[0] + assert isinstance(foo_base_class_method, astroid.FunctionDef) + assert foo_base_class_method.qname() == "builtins.str.endswith" + + # Base type given by string (i.e. "int") + bar_node = ast_nodes[1] + bar_inferred_all = list(bar_node.infer()) + assert len(bar_inferred_all) == 1 + bar_inferred = bar_inferred_all[0] + assert isinstance(bar_inferred, astroid.ClassDef) + + bar_base_class_method = bar_inferred.getattr("bit_length")[0] + assert isinstance(bar_base_class_method, astroid.FunctionDef) + assert bar_base_class_method.qname() == "builtins.int.bit_length" + + # Decimal may be reexported from an implementation-defined module. For + # example, in CPython 3.10 this is _decimal, but in PyPy 7.3 it's + # _pydecimal. So the expected qname needs to be grabbed dynamically. + decimal_quant_node = builder.extract_node( + """ + from decimal import Decimal + Decimal.quantize #@ + """ + ) + assert isinstance(decimal_quant_node, nodes.NodeNG) + + # Just grab the first result, since infer() may return values for both + # _decimal and _pydecimal + decimal_quant_qname = next(decimal_quant_node.infer()).qname() + + # Base type is from an "import from" + baz_node = ast_nodes[2] + baz_inferred_all = list(baz_node.infer()) + assert len(baz_inferred_all) == 1 + baz_inferred = baz_inferred_all[0] + assert isinstance(baz_inferred, astroid.ClassDef) + + baz_base_class_method = baz_inferred.getattr("quantize")[0] + assert isinstance(baz_base_class_method, astroid.FunctionDef) + assert decimal_quant_qname == baz_base_class_method.qname() + + # Base type is from an import + qux_node = ast_nodes[3] + qux_inferred_all = list(qux_node.infer()) + qux_inferred = qux_inferred_all[0] + assert isinstance(qux_inferred, astroid.ClassDef) + + qux_base_class_method = qux_inferred.getattr("quantize")[0] + assert isinstance(qux_base_class_method, astroid.FunctionDef) + assert decimal_quant_qname == qux_base_class_method.qname() + + def test_typing_newtype_bad_args(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import NewType + + NoArgs = NewType() + NoArgs #@ + + OneArg = NewType("OneArg") + OneArg #@ + + ThreeArgs = NewType("ThreeArgs", int, str) + ThreeArgs #@ + + DynamicArg = NewType(f"DynamicArg", int) + DynamicArg #@ + + DynamicBase = NewType("DynamicBase", f"int") + DynamicBase #@ + """ + ) + assert isinstance(ast_nodes, list) + + node: nodes.NodeNG + for node in ast_nodes: + assert list(node.infer()) == [util.Uninferable] + + def test_typing_newtype_user_defined(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import NewType + + class A: + def __init__(self, value: int): + self.value = value + + a = A(5) + a #@ + + B = NewType("B", A) + b = B(5) + b #@ + """ + ) + assert isinstance(ast_nodes, list) + + for node in ast_nodes: + self._verify_node_has_expected_attr(node) + + def test_typing_newtype_forward_reference(self) -> None: + # Similar to the test above, but using a forward reference for "A" + ast_nodes = builder.extract_node( + """ + from typing import NewType + + B = NewType("B", "A") + + class A: + def __init__(self, value: int): + self.value = value + + a = A(5) + a #@ + + b = B(5) + b #@ + """ + ) + assert isinstance(ast_nodes, list) + + for node in ast_nodes: + self._verify_node_has_expected_attr(node) + + def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None: + inferred_all = list(node.infer()) + assert len(inferred_all) == 1 + inferred = inferred_all[0] + assert isinstance(inferred, astroid.Instance) + + # Should be able to infer that the "value" attr is present on both types + val = inferred.getattr("value")[0] + assert isinstance(val, astroid.AssignAttr) + + # Sanity check: nonexistent attr is not inferred + with self.assertRaises(AttributeInferenceError): + inferred.getattr("bad_attr") + + def test_typing_newtype_forward_reference_imported(self) -> None: + all_ast_nodes = builder.extract_node( + """ + from typing import NewType + + A = NewType("A", "decimal.Decimal") + B = NewType("B", "decimal_mod_alias.Decimal") + C = NewType("C", "Decimal") + D = NewType("D", "DecimalAlias") + + import decimal + import decimal as decimal_mod_alias + from decimal import Decimal + from decimal import Decimal as DecimalAlias + + Decimal #@ + + a = A(decimal.Decimal(2)) + a #@ + b = B(decimal_mod_alias.Decimal(2)) + b #@ + c = C(Decimal(2)) + c #@ + d = D(DecimalAlias(2)) + d #@ + """ + ) + assert isinstance(all_ast_nodes, list) + + real_dec, *ast_nodes = all_ast_nodes + + real_quantize = next(real_dec.infer()).getattr("quantize") + + for node in ast_nodes: + all_inferred = list(node.infer()) + assert len(all_inferred) == 1 + inferred = all_inferred[0] + assert isinstance(inferred, astroid.Instance) + + assert inferred.getattr("quantize") == real_quantize + + def test_typing_newtype_forward_ref_bad_base(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import NewType + + A = NewType("A", "DoesntExist") + + a = A() + a #@ + + # Valid name, but not actually imported + B = NewType("B", "decimal.Decimal") + + b = B() + b #@ + + # AST works out, but can't import the module + import not_a_real_module + + C = NewType("C", "not_a_real_module.SomeClass") + c = C() + c #@ + + # Real module, fake base class name + import email.charset + + D = NewType("D", "email.charset.BadClassRef") + d = D() + d #@ + + # Real module, but aliased differently than used + import email.header as header_mod + + E = NewType("E", "email.header.Header") + e = E(header_mod.Header()) + e #@ + """ + ) + assert isinstance(ast_nodes, list) + + for ast_node in ast_nodes: + inferred = next(ast_node.infer()) + + with self.assertRaises(astroid.AttributeInferenceError): + inferred.getattr("value") + + def test_typing_newtype_forward_ref_nested_module(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import NewType + + A = NewType("A", "email.charset.Charset") + B = NewType("B", "charset.Charset") + + # header is unused in both cases, but verifies that module name is properly checked + import email.header, email.charset + from email import header, charset + + real = charset.Charset() + real #@ + + a = A(email.charset.Charset()) + a #@ + + b = B(charset.Charset()) + """ + ) + assert isinstance(ast_nodes, list) + + real, *newtypes = ast_nodes + + real_inferred_all = list(real.infer()) + assert len(real_inferred_all) == 1 + real_inferred = real_inferred_all[0] + + real_method = real_inferred.getattr("get_body_encoding") + + for newtype_node in newtypes: + newtype_inferred_all = list(newtype_node.infer()) + assert len(newtype_inferred_all) == 1 + newtype_inferred = newtype_inferred_all[0] + + newtype_method = newtype_inferred.getattr("get_body_encoding") + + assert real_method == newtype_method + + def test_typing_newtype_forward_ref_nested_class(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import NewType + + A = NewType("A", "SomeClass.Nested") + + class SomeClass: + class Nested: + def method(self) -> None: + pass + + real = SomeClass.Nested() + real #@ + + a = A(SomeClass.Nested()) + a #@ + """ + ) + assert isinstance(ast_nodes, list) + + real, newtype = ast_nodes + + real_all_inferred = list(real.infer()) + assert len(real_all_inferred) == 1 + real_inferred = real_all_inferred[0] + real_method = real_inferred.getattr("method") + + newtype_all_inferred = list(newtype.infer()) + assert len(newtype_all_inferred) == 1 + newtype_inferred = newtype_all_inferred[0] + + # This could theoretically work, but for now just here to check that + # the "forward-declared module" inference doesn't totally break things + with self.assertRaises(astroid.AttributeInferenceError): + newtype_method = newtype_inferred.getattr("method") + + assert real_method == newtype_method def test_namedtuple_nested_class(self): result = builder.extract_node(