Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -542,8 +542,7 @@ e: list[Any] | None = [1]
reveal_type(e) # revealed: list[Any]

f: list[Any] | None = f2(1)
# TODO: Better constraint solver.
reveal_type(f) # revealed: list[int] | None
reveal_type(f) # revealed: list[Any] | None

g: list[Any] | dict[Any, Any] = f3(1)
# TODO: Better constraint solver.
Expand Down Expand Up @@ -600,6 +599,48 @@ reveal_type(x7) # revealed: Contravariant[Any]
reveal_type(x8) # revealed: Invariant[Any]
```

## Declared type preference sees through subtyping

```toml
[environment]
python-version = "3.12"
```

```py
from typing import Any, Iterable, Literal, MutableSequence, Sequence

x1: Sequence[Any] = [1, 2, 3]
reveal_type(x1) # revealed: list[Any]

x2: MutableSequence[Any] = [1, 2, 3]
reveal_type(x2) # revealed: list[Any]

x3: Iterable[Any] = [1, 2, 3]
reveal_type(x3) # revealed: list[Any]

class X[T]:
value: T

def __init__(self, value: T): ...

class A[T](X[T]): ...

def a[T](value: T) -> A[T]:
return A(value)

x4: A[object] = A(1)
reveal_type(x4) # revealed: A[object]

x5: X[object] = A(1)
reveal_type(x5) # revealed: A[object]

x6: X[object] | None = A(1)
reveal_type(x6) # revealed: A[object]

x7: X[object] | None = a(1)
reveal_type(x7) # revealed: A[object]
```

## Narrow generic unions

```toml
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Data:
content: list[int] = field(default_factory=list)
timestamp: datetime = field(default_factory=datetime.now, init=False)

# revealed: (self: Data, content: list[int] = list[int]) -> None
# revealed: (self: Data, content: list[int] = Unknown) -> None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any idea what's going on here?

reveal_type(Data.__init__)

data = Data([1, 2, 3])
Expand Down
49 changes: 49 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/literal_promotion.md
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,52 @@ reveal_type(x21) # revealed: X[Literal[1]]
x22: X[Literal[1]] | None = x(1)
reveal_type(x22) # revealed: X[Literal[1]]
```

## Literal annotations see through subtyping

```py
from typing import Any, Iterable, Literal, MutableSequence, Sequence

x1: Sequence[Literal[1, 2, 3]] = [1, 2, 3]
reveal_type(x1) # revealed: list[Literal[1, 2, 3]]

x2: MutableSequence[Literal[1, 2, 3]] = [1, 2, 3]
reveal_type(x2) # revealed: list[Literal[1, 2, 3]]

x3: Iterable[Literal[1, 2, 3]] = [1, 2, 3]
reveal_type(x3) # revealed: list[Literal[1, 2, 3]]

class Sup1[T]:
value: T

class Sub1[T](Sup1[T]): ...

def sub1[T](value: T) -> Sub1[T]:
return Sub1()

x4: Sub1[Literal[1]] = sub1(1)
reveal_type(x4) # revealed: Sub1[Literal[1]]

x5: Sup1[Literal[1]] = sub1(1)
reveal_type(x5) # revealed: Sub1[Literal[1]]

x6: Sup1[Literal[1]] | None = sub1(1)
reveal_type(x6) # revealed: Sub1[Literal[1]]

x7: Sup1[Literal[1]] | None = sub1(1)
reveal_type(x7) # revealed: Sub1[Literal[1]]

class Sup2[T, U]:
value: tuple[T, U]

class Sub2[T, U](Sup2[T, Any], Sup2[Any, U]): ...

def sub2[T, U](x: T, y: U) -> Sub2[T, U]:
return Sub2()

x8 = sub2(1, 2)
reveal_type(x8) # revealed: Sub2[int, int]

x9: Sup2[Literal[1], Literal[2]] = sub2(1, 2)
reveal_type(x9) # revealed: Sub2[Literal[1], Literal[2]]
```
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ reveal_type(tuple((1, 2))) # revealed: tuple[Literal[1], Literal[2]]

reveal_type(tuple([1])) # revealed: tuple[Unknown | int, ...]

x1: tuple[int, ...] = tuple([1])
reveal_type(x1) # revealed: tuple[int, ...]

# error: [invalid-argument-type]
reveal_type(tuple[int]([1])) # revealed: tuple[int]

Expand Down
104 changes: 86 additions & 18 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use compact_str::{CompactString, ToCompactString};
use infer::nearest_enclosing_class;
use itertools::{Either, Itertools};
use ruff_diagnostics::{Edit, Fix};
use rustc_hash::FxHashSet;

use std::borrow::Cow;
use std::time::Duration;
Expand Down Expand Up @@ -1000,7 +1001,10 @@ impl<'db> Type<'db> {
}

/// If this type is a class instance, returns its specialization.
pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option<Specialization<'db>> {
pub(crate) fn class_specialization(
self,
db: &'db dyn Db,
) -> Option<(ClassType<'db>, Specialization<'db>)> {
self.specialization_of_optional(db, None)
}

Expand All @@ -1011,15 +1015,17 @@ impl<'db> Type<'db> {
expected_class: ClassLiteral<'_>,
) -> Option<Specialization<'db>> {
self.specialization_of_optional(db, Some(expected_class))
.map(|(_, specialization)| specialization)
}

fn specialization_of_optional(
self,
db: &'db dyn Db,
expected_class: Option<ClassLiteral<'_>>,
) -> Option<Specialization<'db>> {
) -> Option<(ClassType<'db>, Specialization<'db>)> {
let class_type = match self {
Type::NominalInstance(instance) => instance,
Type::ProtocolInstance(instance) => instance.to_nominal_instance()?,
Type::TypeAlias(alias) => alias.value_type(db).as_nominal_instance()?,
_ => return None,
}
Expand All @@ -1030,7 +1036,64 @@ impl<'db> Type<'db> {
return None;
}

specialization
Some((class_type, specialization?))
}

/// Given a type variable `T` from the generic context of a class `C`:
/// - If `self` is a specialized instance of `C`, returns the type assigned to `T` on `self`.
/// - If `self` is a specialized instance of some class `A`, and `C` is a subclass of `A`
/// such that the type variable `U` on `A` is specialized to `T`, returns the type
/// assigned to `U` on `self`.
pub(crate) fn find_type_var_from(
self,
db: &'db dyn Db,
bound_typevar: BoundTypeVarInstance<'db>,
class: ClassLiteral<'db>,
) -> Option<Type<'db>> {
self.find_type_var_from_impl(db, bound_typevar, class, &mut FxHashSet::default())
}

pub(crate) fn find_type_var_from_impl(
self,
db: &'db dyn Db,
bound_typevar: BoundTypeVarInstance<'db>,
class: ClassLiteral<'db>,
visited: &mut FxHashSet<(ClassLiteral<'db>, BoundTypeVarIdentity<'db>)>,
) -> Option<Type<'db>> {
if let Some(specialization) = self.specialization_of(db, class) {
return specialization.get(db, bound_typevar);
}

// TODO: We should use the constraint solver here to determine the type mappings for more
// complex subtyping relationships, e.g., `type[C[T]]` to `Callable[..., T]`, or unions
// containing multiple generic elements.
Comment on lines +1067 to +1069
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could possibly also mention protocols here, since although this PR tackles simple cases where a nominal type actually subclasses a protocol, there are also lots of cases where a nominal type is a subtype of a protocol without actually subclassing that protocol

for base in class.iter_mro(db, None) {
let Some((origin, Some(specialization))) =
base.into_class().map(|class| class.class_literal(db))
else {
continue;
};

for (base_typevar, base_ty) in specialization
.generic_context(db)
.variables(db)
.zip(specialization.types(db))
{
if *base_ty == Type::TypeVar(bound_typevar) {
if !visited.insert((origin, base_typevar.identity(db))) {
return None;
}

if let Some(ty) =
self.find_type_var_from_impl(db, base_typevar, origin, visited)
{
return Some(ty);
}
}
}
Comment on lines +1077 to +1093
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do a similar loop through the MRO in the (old) constraint solver which doesn't seem to need to do the FxHashSet memoization thing -- could you also take a similar approach here? Might that be more performant?

// Extract formal_alias if this is a generic class
let formal_alias = match formal {
Type::NominalInstance(formal_nominal) => {
formal_nominal.class(self.db).into_generic_alias()
}
// TODO: This will only handle classes that explicit implement a generic protocol
// by listing it as a base class. To handle classes that implicitly implement a
// generic protocol, we will need to check the types of the protocol members to be
// able to infer the specialization of the protocol that the class implements.
Type::ProtocolInstance(ProtocolInstanceType {
inner: Protocol::FromClass(class),
..
}) => class.into_generic_alias(),
_ => None,
};
if let Some(formal_alias) = formal_alias {
let formal_origin = formal_alias.origin(self.db);
for base in actual_nominal.class(self.db).iter_mro(self.db) {
let ClassBase::Class(ClassType::Generic(base_alias)) = base else {
continue;
};
if formal_origin != base_alias.origin(self.db) {
continue;
}
let generic_context = formal_alias
.specialization(self.db)
.generic_context(self.db)
.variables(self.db);
let formal_specialization =
formal_alias.specialization(self.db).types(self.db);
let base_specialization = base_alias.specialization(self.db).types(self.db);
for (typevar, formal_ty, base_ty) in itertools::izip!(
generic_context,
formal_specialization,
base_specialization
) {
let variance = typevar.variance_with_polarity(self.db, polarity);
self.infer_map_impl(*formal_ty, *base_ty, variance, &mut f)?;
}
return Ok(());
}
}
}

}

None
}

/// Returns the top materialization (or upper bound materialization) of this type, which is the
Expand Down Expand Up @@ -3852,20 +3915,20 @@ impl<'db> Type<'db> {
return;
};

let tcx_specialization = tcx.annotation.and_then(|tcx| {
tcx.filter_union(db, |ty| ty.specialization_of(db, class_literal).is_some())
.specialization_of(db, class_literal)
});

for (typevar, ty) in specialization
for (type_var, ty) in specialization
.generic_context(db)
.variables(db)
.zip(specialization.types(db))
{
let variance = typevar.variance_with_polarity(db, polarity);
let tcx = TypeContext::new(tcx_specialization.and_then(|spec| spec.get(db, typevar)));
let variance = type_var.variance_with_polarity(db, polarity);
let tcx = tcx.and_then(|tcx| {
tcx.filter_union(db, |ty| {
ty.find_type_var_from(db, type_var, class_literal).is_some()
})
.find_type_var_from(db, type_var, class_literal)
});

f(typevar, *ty, variance, tcx);
f(type_var, *ty, variance, tcx);

visitor.visit(*ty, || {
ty.visit_specialization_impl(db, tcx, variance, f, visitor);
Expand Down Expand Up @@ -6059,30 +6122,35 @@ impl<'db> Type<'db> {
}

Some(KnownClass::Tuple) => {
let object = Type::object();
let element_ty =
BoundTypeVarInstance::synthetic(db, "T", TypeVarVariance::Covariant);

// ```py
// class tuple:
// class tuple(Sequence[_T_co]):
// @overload
// def __new__(cls) -> tuple[()]: ...
// @overload
// def __new__(cls, iterable: Iterable[object]) -> tuple[object, ...]: ...
// def __new__(cls, iterable: Iterable[_T_co]) -> tuple[_T_co, ...]: ...
// ```
Comment on lines 6128 to 6134
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, sorry... I guess I baked this bug in a long time ago for you to find now :-)

CallableBinding::from_overloads(
self,
[
Signature::new(Parameters::empty(), Some(Type::empty_tuple(db))),
Signature::new(
Signature::new_generic(
Some(GenericContext::from_typevar_instances(db, [element_ty])),
Parameters::new(
db,
[Parameter::positional_only(Some(Name::new_static(
"iterable",
)))
.with_annotated_type(
KnownClass::Iterable.to_specialized_instance(db, [object]),
KnownClass::Iterable.to_specialized_instance(
db,
[Type::TypeVar(element_ty)],
),
)],
),
Some(Type::homogeneous_tuple(db, object)),
Some(Type::homogeneous_tuple(db, Type::TypeVar(element_ty))),
),
],
)
Expand Down
2 changes: 1 addition & 1 deletion crates/ty_python_semantic/src/types/bound_super.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ impl<'db> BoundSuperType<'db> {
Type::NominalInstance(instance) => SuperOwnerKind::Instance(instance),

Type::ProtocolInstance(protocol) => {
if let Some(nominal_instance) = protocol.as_nominal_type() {
if let Some(nominal_instance) = protocol.to_nominal_instance() {
SuperOwnerKind::Instance(nominal_instance)
} else {
return Err(BoundSuperError::AbstractOwnerType {
Expand Down
28 changes: 25 additions & 3 deletions crates/ty_python_semantic/src/types/call/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2818,10 +2818,32 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {

// Prefer the declared type of generic classes.
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some())
.class_specialization(self.db)?;
let tcx = tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
let return_ty =
return_ty.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());

// TODO: We should use the constraint solver here to determine the type mappings for more
// complex subtyping relationships, e.g., `type[C[T]]` to `Callable[..., T]`, or unions
// containing multiple generic elements.
if let Some((class_literal, _)) = return_ty.class_specialization(self.db)
&& let Some(generic_alias) = class_literal.into_generic_alias()
{
let specialization = generic_alias.specialization(self.db);
for (class_type_var, return_ty) in specialization
.generic_context(self.db)
.variables(self.db)
.zip(specialization.types(self.db))
{
if let Some(ty) = tcx.find_type_var_from(
self.db,
class_type_var,
generic_alias.origin(self.db),
) {
builder.infer(*return_ty, ty).ok()?;
}
}
}

builder.infer(return_ty, tcx).ok()?;
Some(builder.type_mappings().clone())
});

Expand Down
6 changes: 6 additions & 0 deletions crates/ty_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,12 @@ impl<'db> TypeContext<'db> {
}
}

pub(crate) fn and_then(self, f: impl FnOnce(Type<'db>) -> Option<Type<'db>>) -> Self {
Self {
annotation: self.annotation.and_then(f),
}
}

pub(crate) fn is_typealias(&self) -> bool {
self.annotation
.is_some_and(|ty| ty.is_typealias_special_form())
Expand Down
Loading
Loading