diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index 36f53afe4d7fed..6c21913b16cb9d 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -542,8 +542,7 @@ e: list[Any] | None = [1] reveal_type(e) # revealed: list[Any] f: list[Any] | None = f2(1) -# TODO: Better constraint solver. -reveal_type(f) # revealed: list[int] | None +reveal_type(f) # revealed: list[Any] | None g: list[Any] | dict[Any, Any] = f3(1) # TODO: Better constraint solver. @@ -600,6 +599,48 @@ reveal_type(x7) # revealed: Contravariant[Any] reveal_type(x8) # revealed: Invariant[Any] ``` +## Declared type preference sees through subtyping + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Any, Iterable, Literal, MutableSequence, Sequence + +x1: Sequence[Any] = [1, 2, 3] +reveal_type(x1) # revealed: list[Any] + +x2: MutableSequence[Any] = [1, 2, 3] +reveal_type(x2) # revealed: list[Any] + +x3: Iterable[Any] = [1, 2, 3] +reveal_type(x3) # revealed: list[Any] + +class X[T]: + value: T + + def __init__(self, value: T): ... + +class A[T](X[T]): ... + +def a[T](value: T) -> A[T]: + return A(value) + +x4: A[object] = A(1) +reveal_type(x4) # revealed: A[object] + +x5: X[object] = A(1) +reveal_type(x5) # revealed: A[object] + +x6: X[object] | None = A(1) +reveal_type(x6) # revealed: A[object] + +x7: X[object] | None = a(1) +reveal_type(x7) # revealed: A[object] +``` + ## Narrow generic unions ```toml diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md index 28a69081e57cf0..1727439778250d 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md @@ -37,7 +37,7 @@ class Data: content: list[int] = field(default_factory=list) timestamp: datetime = field(default_factory=datetime.now, init=False) -# revealed: (self: Data, content: list[int] = list[int]) -> None +# revealed: (self: Data, content: list[int] = Unknown) -> None reveal_type(Data.__init__) data = Data([1, 2, 3]) diff --git a/crates/ty_python_semantic/resources/mdtest/literal_promotion.md b/crates/ty_python_semantic/resources/mdtest/literal_promotion.md index eb79c44b6c2fc2..f51cc9eaaddfb8 100644 --- a/crates/ty_python_semantic/resources/mdtest/literal_promotion.md +++ b/crates/ty_python_semantic/resources/mdtest/literal_promotion.md @@ -341,3 +341,52 @@ reveal_type(x21) # revealed: X[Literal[1]] x22: X[Literal[1]] | None = x(1) reveal_type(x22) # revealed: X[Literal[1]] ``` + +## Literal annotations see through subtyping + +```py +from typing import Any, Iterable, Literal, MutableSequence, Sequence + +x1: Sequence[Literal[1, 2, 3]] = [1, 2, 3] +reveal_type(x1) # revealed: list[Literal[1, 2, 3]] + +x2: MutableSequence[Literal[1, 2, 3]] = [1, 2, 3] +reveal_type(x2) # revealed: list[Literal[1, 2, 3]] + +x3: Iterable[Literal[1, 2, 3]] = [1, 2, 3] +reveal_type(x3) # revealed: list[Literal[1, 2, 3]] + +class Sup1[T]: + value: T + +class Sub1[T](Sup1[T]): ... + +def sub1[T](value: T) -> Sub1[T]: + return Sub1() + +x4: Sub1[Literal[1]] = sub1(1) +reveal_type(x4) # revealed: Sub1[Literal[1]] + +x5: Sup1[Literal[1]] = sub1(1) +reveal_type(x5) # revealed: Sub1[Literal[1]] + +x6: Sup1[Literal[1]] | None = sub1(1) +reveal_type(x6) # revealed: Sub1[Literal[1]] + +x7: Sup1[Literal[1]] | None = sub1(1) +reveal_type(x7) # revealed: Sub1[Literal[1]] + +class Sup2[T, U]: + value: tuple[T, U] + +class Sub2[T, U](Sup2[T, Any], Sup2[Any, U]): ... + +def sub2[T, U](x: T, y: U) -> Sub2[T, U]: + return Sub2() + +x8 = sub2(1, 2) +reveal_type(x8) # revealed: Sub2[int, int] + +x9: Sup2[Literal[1], Literal[2]] = sub2(1, 2) +reveal_type(x9) # revealed: Sub2[Literal[1], Literal[2]] +``` diff --git a/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md b/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md index e323d25a17beab..e2c45cc7f1c44e 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md +++ b/crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md @@ -57,6 +57,9 @@ reveal_type(tuple((1, 2))) # revealed: tuple[Literal[1], Literal[2]] reveal_type(tuple([1])) # revealed: tuple[Unknown | int, ...] +x1: tuple[int, ...] = tuple([1]) +reveal_type(x1) # revealed: tuple[int, ...] + # error: [invalid-argument-type] reveal_type(tuple[int]([1])) # revealed: tuple[int] diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 7c65a4715e8d0a..14d5a9248722e0 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -2,6 +2,7 @@ use compact_str::{CompactString, ToCompactString}; use infer::nearest_enclosing_class; use itertools::{Either, Itertools}; use ruff_diagnostics::{Edit, Fix}; +use rustc_hash::FxHashSet; use std::borrow::Cow; use std::time::Duration; @@ -1000,7 +1001,10 @@ impl<'db> Type<'db> { } /// If this type is a class instance, returns its specialization. - pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option> { + pub(crate) fn class_specialization( + self, + db: &'db dyn Db, + ) -> Option<(ClassType<'db>, Specialization<'db>)> { self.specialization_of_optional(db, None) } @@ -1011,15 +1015,17 @@ impl<'db> Type<'db> { expected_class: ClassLiteral<'_>, ) -> Option> { self.specialization_of_optional(db, Some(expected_class)) + .map(|(_, specialization)| specialization) } fn specialization_of_optional( self, db: &'db dyn Db, expected_class: Option>, - ) -> Option> { + ) -> Option<(ClassType<'db>, Specialization<'db>)> { let class_type = match self { Type::NominalInstance(instance) => instance, + Type::ProtocolInstance(instance) => instance.to_nominal_instance()?, Type::TypeAlias(alias) => alias.value_type(db).as_nominal_instance()?, _ => return None, } @@ -1030,7 +1036,64 @@ impl<'db> Type<'db> { return None; } - specialization + Some((class_type, specialization?)) + } + + /// Given a type variable `T` from the generic context of a class `C`: + /// - If `self` is a specialized instance of `C`, returns the type assigned to `T` on `self`. + /// - If `self` is a specialized instance of some class `A`, and `C` is a subclass of `A` + /// such that the type variable `U` on `A` is specialized to `T`, returns the type + /// assigned to `U` on `self`. + pub(crate) fn find_type_var_from( + self, + db: &'db dyn Db, + bound_typevar: BoundTypeVarInstance<'db>, + class: ClassLiteral<'db>, + ) -> Option> { + self.find_type_var_from_impl(db, bound_typevar, class, &mut FxHashSet::default()) + } + + pub(crate) fn find_type_var_from_impl( + self, + db: &'db dyn Db, + bound_typevar: BoundTypeVarInstance<'db>, + class: ClassLiteral<'db>, + visited: &mut FxHashSet<(ClassLiteral<'db>, BoundTypeVarIdentity<'db>)>, + ) -> Option> { + if let Some(specialization) = self.specialization_of(db, class) { + return specialization.get(db, bound_typevar); + } + + // TODO: We should use the constraint solver here to determine the type mappings for more + // complex subtyping relationships, e.g., `type[C[T]]` to `Callable[..., T]`, or unions + // containing multiple generic elements. + for base in class.iter_mro(db, None) { + let Some((origin, Some(specialization))) = + base.into_class().map(|class| class.class_literal(db)) + else { + continue; + }; + + for (base_typevar, base_ty) in specialization + .generic_context(db) + .variables(db) + .zip(specialization.types(db)) + { + if *base_ty == Type::TypeVar(bound_typevar) { + if !visited.insert((origin, base_typevar.identity(db))) { + return None; + } + + if let Some(ty) = + self.find_type_var_from_impl(db, base_typevar, origin, visited) + { + return Some(ty); + } + } + } + } + + None } /// Returns the top materialization (or upper bound materialization) of this type, which is the @@ -3852,20 +3915,20 @@ impl<'db> Type<'db> { return; }; - let tcx_specialization = tcx.annotation.and_then(|tcx| { - tcx.filter_union(db, |ty| ty.specialization_of(db, class_literal).is_some()) - .specialization_of(db, class_literal) - }); - - for (typevar, ty) in specialization + for (type_var, ty) in specialization .generic_context(db) .variables(db) .zip(specialization.types(db)) { - let variance = typevar.variance_with_polarity(db, polarity); - let tcx = TypeContext::new(tcx_specialization.and_then(|spec| spec.get(db, typevar))); + let variance = type_var.variance_with_polarity(db, polarity); + let tcx = tcx.and_then(|tcx| { + tcx.filter_union(db, |ty| { + ty.find_type_var_from(db, type_var, class_literal).is_some() + }) + .find_type_var_from(db, type_var, class_literal) + }); - f(typevar, *ty, variance, tcx); + f(type_var, *ty, variance, tcx); visitor.visit(*ty, || { ty.visit_specialization_impl(db, tcx, variance, f, visitor); @@ -6059,30 +6122,35 @@ impl<'db> Type<'db> { } Some(KnownClass::Tuple) => { - let object = Type::object(); + let element_ty = + BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant); // ```py - // class tuple: + // class tuple(Sequence[_T_co]): // @overload // def __new__(cls) -> tuple[()]: ... // @overload - // def __new__(cls, iterable: Iterable[object]) -> tuple[object, ...]: ... + // def __new__(cls, iterable: Iterable[_T_co]) -> tuple[_T_co, ...]: ... // ``` CallableBinding::from_overloads( self, [ Signature::new(Parameters::empty(), Some(Type::empty_tuple(db))), - Signature::new( + Signature::new_generic( + Some(GenericContext::from_typevar_instances(db, [element_ty])), Parameters::new( db, [Parameter::positional_only(Some(Name::new_static( "iterable", ))) .with_annotated_type( - KnownClass::Iterable.to_specialized_instance(db, [object]), + KnownClass::Iterable.to_specialized_instance( + db, + [Type::TypeVar(element_ty)], + ), )], ), - Some(Type::homogeneous_tuple(db, object)), + Some(Type::homogeneous_tuple(db, Type::TypeVar(element_ty))), ), ], ) diff --git a/crates/ty_python_semantic/src/types/bound_super.rs b/crates/ty_python_semantic/src/types/bound_super.rs index 04cd24e40e6a15..bd8ac8a6a4f59b 100644 --- a/crates/ty_python_semantic/src/types/bound_super.rs +++ b/crates/ty_python_semantic/src/types/bound_super.rs @@ -321,7 +321,7 @@ impl<'db> BoundSuperType<'db> { Type::NominalInstance(instance) => SuperOwnerKind::Instance(instance), Type::ProtocolInstance(protocol) => { - if let Some(nominal_instance) = protocol.as_nominal_type() { + if let Some(nominal_instance) = protocol.to_nominal_instance() { SuperOwnerKind::Instance(nominal_instance) } else { return Err(BoundSuperError::AbstractOwnerType { diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 2093f2f377b963..bf7be4917a51d1 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -2818,10 +2818,32 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { // Prefer the declared type of generic classes. let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| { - tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some()) - .class_specialization(self.db)?; + let tcx = tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some()); + let return_ty = + return_ty.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some()); + + // TODO: We should use the constraint solver here to determine the type mappings for more + // complex subtyping relationships, e.g., `type[C[T]]` to `Callable[..., T]`, or unions + // containing multiple generic elements. + if let Some((class_literal, _)) = return_ty.class_specialization(self.db) + && let Some(generic_alias) = class_literal.into_generic_alias() + { + let specialization = generic_alias.specialization(self.db); + for (class_type_var, return_ty) in specialization + .generic_context(self.db) + .variables(self.db) + .zip(specialization.types(self.db)) + { + if let Some(ty) = tcx.find_type_var_from( + self.db, + class_type_var, + generic_alias.origin(self.db), + ) { + builder.infer(*return_ty, ty).ok()?; + } + } + } - builder.infer(return_ty, tcx).ok()?; Some(builder.type_mappings().clone()) }); diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index b9adc93eb28ff1..0dd8803baad77a 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -381,6 +381,12 @@ impl<'db> TypeContext<'db> { } } + pub(crate) fn and_then(self, f: impl FnOnce(Type<'db>) -> Option>) -> Self { + Self { + annotation: self.annotation.and_then(f), + } + } + pub(crate) fn is_typealias(&self) -> bool { self.annotation .is_some_and(|ty| ty.is_typealias_special_form()) diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index d3a65e88b20558..a5734b0be1d1f6 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -7497,41 +7497,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { annotation.filter_disjoint_elements(self.db(), collection_ty, inferable) }); - // Extract the annotated type of `T`, if provided. - let annotated_elt_tys = tcx - .known_specialization(self.db(), collection_class) - .map(|specialization| specialization.types(self.db())); - // Create a set of constraints to infer a precise type for `T`. let mut builder = SpecializationBuilder::new(self.db(), inferable); - match annotated_elt_tys { - // The annotated type acts as a constraint for `T`. - // - // Note that we infer the annotated type _before_ the elements, to more closely match the - // order of any unions as written in the type annotation. - Some(annotated_elt_tys) => { - for (elt_ty, annotated_elt_ty) in iter::zip(elt_tys.clone(), annotated_elt_tys) { - builder - .infer(Type::TypeVar(elt_ty), *annotated_elt_ty) - .ok()?; - } - } + for elt_ty in elt_tys.clone() { + let elt_tcx = tcx + .annotation + // The annotated type acts as a constraint for `T`. + // + // Note that we infer the annotated type _before_ the elements, to more closely match the + // order of any unions as written in the type annotation. + .and_then(|tcx| tcx.find_type_var_from(self.db(), elt_ty, class_literal)) + // If a valid type annotation was not provided, avoid restricting the type of the collection + // by unioning the inferred type with `Unknown`. + .unwrap_or(Type::unknown()); - // If a valid type annotation was not provided, avoid restricting the type of the collection - // by unioning the inferred type with `Unknown`. - None => { - for elt_ty in elt_tys.clone() { - builder.infer(Type::TypeVar(elt_ty), Type::unknown()).ok()?; - } - } + builder.infer(Type::TypeVar(elt_ty), elt_tcx).ok()?; } - let elt_tcxs = match annotated_elt_tys { - None => Either::Left(iter::repeat(TypeContext::default())), - Some(tys) => Either::Right(tys.iter().map(|ty| TypeContext::new(Some(*ty)))), - }; - for elts in elts { // An unpacking expression for a dictionary. if let &[None, Some(value)] = elts.as_slice() { @@ -7554,10 +7537,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } // The inferred type of each element acts as an additional constraint on `T`. - for (elt, elt_ty, elt_tcx) in itertools::izip!(elts, elt_tys.clone(), elt_tcxs.clone()) - { + for (elt, elt_ty) in iter::zip(elts, elt_tys.clone()) { let Some(elt) = elt else { continue }; + let elt_tcx = + tcx.and_then(|tcx| tcx.find_type_var_from(self.db(), elt_ty, class_literal)); let inferred_elt_ty = infer_elt_expression(self, elt, elt_tcx); // Simplify the inference based on the declared type of the element. diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index fb53f10ef4d509..f6e198d591465a 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -165,7 +165,7 @@ impl<'db> Type<'db> { // This matches the behaviour of other type checkers, and is required for us to // recognise `str` as a subtype of `Container[str]`. structurally_satisfied.or(db, || { - let Some(nominal_instance) = protocol.as_nominal_type() else { + let Some(nominal_instance) = protocol.to_nominal_instance() else { return ConstraintSet::from(false); }; @@ -175,7 +175,7 @@ impl<'db> Type<'db> { // `Q`'s members in a Liskov-incompatible way. let type_to_test = self .as_protocol_instance() - .and_then(ProtocolInstanceType::as_nominal_type) + .and_then(ProtocolInstanceType::to_nominal_instance) .map(Type::NominalInstance) .unwrap_or(self); @@ -650,7 +650,7 @@ impl<'db> ProtocolInstanceType<'db> { /// If this is a synthesized protocol that does not correspond to a class definition /// in source code, return `None`. These are "pure" abstract types, that cannot be /// treated in a nominal way. - pub(super) fn as_nominal_type(self) -> Option> { + pub(super) fn to_nominal_instance(self) -> Option> { match self.inner { Protocol::FromClass(class) => { Some(NominalInstanceType(NominalInstanceInner::NonTuple(*class)))