Skip to content

Commit 51d8ef8

Browse files
committed
simple subtyping for bidirectional inference
1 parent edc6ed5 commit 51d8ef8

File tree

10 files changed

+189
-54
lines changed

10 files changed

+189
-54
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,8 +542,7 @@ e: list[Any] | None = [1]
542542
reveal_type(e) # revealed: list[Any]
543543

544544
f: list[Any] | None = f2(1)
545-
# TODO: Better constraint solver.
546-
reveal_type(f) # revealed: list[int] | None
545+
reveal_type(f) # revealed: list[Any] | None
547546

548547
g: list[Any] | dict[Any, Any] = f3(1)
549548
# TODO: Better constraint solver.
@@ -600,6 +599,48 @@ reveal_type(x7) # revealed: Contravariant[Any]
600599
reveal_type(x8) # revealed: Invariant[Any]
601600
```
602601

602+
## Declared type preference sees through subtyping
603+
604+
```toml
605+
[environment]
606+
python-version = "3.12"
607+
```
608+
609+
```py
610+
from typing import Any, Iterable, Literal, MutableSequence, Sequence
611+
612+
x1: Sequence[Any] = [1, 2, 3]
613+
reveal_type(x1) # revealed: list[Any]
614+
615+
x2: MutableSequence[Any] = [1, 2, 3]
616+
reveal_type(x2) # revealed: list[Any]
617+
618+
x3: Iterable[Any] = [1, 2, 3]
619+
reveal_type(x3) # revealed: list[Any]
620+
621+
class X[T]:
622+
value: T
623+
624+
def __init__(self, value: T): ...
625+
626+
class A[T](X[T]): ...
627+
628+
def a[T](value: T) -> A[T]:
629+
return A(value)
630+
631+
x4: A[object] = A(1)
632+
reveal_type(x4) # revealed: A[object]
633+
634+
x5: X[object] = A(1)
635+
reveal_type(x5) # revealed: A[object]
636+
637+
x6: X[object] | None = A(1)
638+
reveal_type(x6) # revealed: A[object]
639+
640+
x7: X[object] | None = a(1)
641+
reveal_type(x7) # revealed: A[object]
642+
```
643+
603644
## Narrow generic unions
604645

605646
```toml

crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class Data:
3737
content: list[int] = field(default_factory=list)
3838
timestamp: datetime = field(default_factory=datetime.now, init=False)
3939

40-
# revealed: (self: Data, content: list[int] = list[int]) -> None
40+
# revealed: (self: Data, content: list[int] = Unknown) -> None
4141
reveal_type(Data.__init__)
4242

4343
data = Data([1, 2, 3])

crates/ty_python_semantic/resources/mdtest/literal_promotion.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,40 @@ reveal_type(x21) # revealed: X[Literal[1]]
341341
x22: X[Literal[1]] | None = x(1)
342342
reveal_type(x22) # revealed: X[Literal[1]]
343343
```
344+
345+
## Literal annotations see through subtyping
346+
347+
```py
348+
from typing import Iterable, Literal, MutableSequence, Sequence
349+
350+
x1: Sequence[Literal[1, 2, 3]] = [1, 2, 3]
351+
reveal_type(x1) # revealed: list[Literal[1, 2, 3]]
352+
353+
x2: MutableSequence[Literal[1, 2, 3]] = [1, 2, 3]
354+
reveal_type(x2) # revealed: list[Literal[1, 2, 3]]
355+
356+
x3: Iterable[Literal[1, 2, 3]] = [1, 2, 3]
357+
reveal_type(x3) # revealed: list[Literal[1, 2, 3]]
358+
359+
class X[T]:
360+
value: T
361+
362+
def __init__(self, value: T): ...
363+
364+
class A[T](X[T]): ...
365+
366+
def a[T](value: T) -> A[T]:
367+
return A(value)
368+
369+
x4: A[Literal[1]] = A(1)
370+
reveal_type(x4) # revealed: A[Literal[1]]
371+
372+
x5: X[Literal[1]] = A(1)
373+
reveal_type(x5) # revealed: A[Literal[1]]
374+
375+
x6: X[Literal[1]] | None = A(1)
376+
reveal_type(x6) # revealed: A[Literal[1]]
377+
378+
x7: X[Literal[1]] | None = a(1)
379+
reveal_type(x7) # revealed: A[Literal[1]]
380+
```

crates/ty_python_semantic/resources/mdtest/type_compendium/tuple.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def f(x: Iterable[int], y: list[str], z: Never, aa: list[Never], bb: LiskovUncom
5555

5656
reveal_type(tuple((1, 2))) # revealed: tuple[Literal[1], Literal[2]]
5757

58-
reveal_type(tuple([1])) # revealed: tuple[Unknown | int, ...]
58+
reveal_type(tuple([1])) # revealed: tuple[object, ...]
5959

6060
# error: [invalid-argument-type]
6161
reveal_type(tuple[int]([1])) # revealed: tuple[int]

crates/ty_python_semantic/src/types.rs

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -999,7 +999,10 @@ impl<'db> Type<'db> {
999999
}
10001000

10011001
/// If this type is a class instance, returns its specialization.
1002-
pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option<Specialization<'db>> {
1002+
pub(crate) fn class_specialization(
1003+
self,
1004+
db: &'db dyn Db,
1005+
) -> Option<(ClassType<'db>, Specialization<'db>)> {
10031006
self.specialization_of_optional(db, None)
10041007
}
10051008

@@ -1010,15 +1013,17 @@ impl<'db> Type<'db> {
10101013
expected_class: ClassLiteral<'_>,
10111014
) -> Option<Specialization<'db>> {
10121015
self.specialization_of_optional(db, Some(expected_class))
1016+
.map(|(_, specialization)| specialization)
10131017
}
10141018

10151019
fn specialization_of_optional(
10161020
self,
10171021
db: &'db dyn Db,
10181022
expected_class: Option<ClassLiteral<'_>>,
1019-
) -> Option<Specialization<'db>> {
1023+
) -> Option<(ClassType<'db>, Specialization<'db>)> {
10201024
let class_type = match self {
10211025
Type::NominalInstance(instance) => instance,
1026+
Type::ProtocolInstance(instance) => instance.to_nominal_instance()?,
10221027
Type::TypeAlias(alias) => alias.value_type(db).as_nominal_instance()?,
10231028
_ => return None,
10241029
}
@@ -1029,7 +1034,47 @@ impl<'db> Type<'db> {
10291034
return None;
10301035
}
10311036

1032-
specialization
1037+
Some((class_type, specialization?))
1038+
}
1039+
1040+
/// Given a type variable `T` from the generic context of a class `C`:
1041+
/// - If `self` is a specialized instance of `C`, returns the type assigned to `T` on `self`.
1042+
/// - If `self` is a specialized instance of some class `A`, and `C` is a subclass of `A`
1043+
/// such that the type variable `U` on `A` is specialized to `T`, returns the type
1044+
/// assigned to `U` on `self`.
1045+
pub(crate) fn find_type_var_from(
1046+
self,
1047+
db: &'db dyn Db,
1048+
bound_typevar: BoundTypeVarInstance<'db>,
1049+
class: ClassLiteral<'db>,
1050+
) -> Option<Type<'db>> {
1051+
if let Some(specialization) = self.specialization_of(db, class) {
1052+
return specialization.get(db, bound_typevar);
1053+
}
1054+
1055+
// TODO: We should use the constraint solver here to determine the type mappings for more
1056+
// complex subtyping relationships, e.g., `type[C[T]]` to `Callable[..., T]`, or unions
1057+
// containing multiple generic elements.
1058+
for base in class.iter_mro(db, None) {
1059+
let Some(ClassType::Generic(class)) = base.into_class() else {
1060+
continue;
1061+
};
1062+
1063+
for (base_typevar, base_ty) in class
1064+
.specialization(db)
1065+
.generic_context(db)
1066+
.variables(db)
1067+
.zip(class.specialization(db).types(db))
1068+
{
1069+
if *base_ty == Type::TypeVar(bound_typevar) {
1070+
if let Some(ty) = self.find_type_var_from(db, base_typevar, class.origin(db)) {
1071+
return Some(ty);
1072+
}
1073+
}
1074+
}
1075+
}
1076+
1077+
None
10331078
}
10341079

10351080
/// Returns the top materialization (or upper bound materialization) of this type, which is the
@@ -3842,20 +3887,20 @@ impl<'db> Type<'db> {
38423887
return;
38433888
};
38443889

3845-
let tcx_specialization = tcx.annotation.and_then(|tcx| {
3846-
tcx.filter_union(db, |ty| ty.specialization_of(db, class_literal).is_some())
3847-
.specialization_of(db, class_literal)
3848-
});
3849-
3850-
for (typevar, ty) in specialization
3890+
for (type_var, ty) in specialization
38513891
.generic_context(db)
38523892
.variables(db)
38533893
.zip(specialization.types(db))
38543894
{
3855-
let variance = typevar.variance_with_polarity(db, polarity);
3856-
let tcx = TypeContext::new(tcx_specialization.and_then(|spec| spec.get(db, typevar)));
3895+
let variance = type_var.variance_with_polarity(db, polarity);
3896+
let tcx = tcx.and_then(|tcx| {
3897+
tcx.filter_union(db, |ty| {
3898+
ty.find_type_var_from(db, type_var, class_literal).is_some()
3899+
})
3900+
.find_type_var_from(db, type_var, class_literal)
3901+
});
38573902

3858-
f(typevar, *ty, variance, tcx);
3903+
f(type_var, *ty, variance, tcx);
38593904

38603905
visitor.visit(*ty, || {
38613906
ty.visit_specialization_impl(db, tcx, variance, f, visitor);

crates/ty_python_semantic/src/types/bound_super.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ impl<'db> BoundSuperType<'db> {
321321
Type::NominalInstance(instance) => SuperOwnerKind::Instance(instance),
322322

323323
Type::ProtocolInstance(protocol) => {
324-
if let Some(nominal_instance) = protocol.as_nominal_type() {
324+
if let Some(nominal_instance) = protocol.to_nominal_instance() {
325325
SuperOwnerKind::Instance(nominal_instance)
326326
} else {
327327
return Err(BoundSuperError::AbstractOwnerType {

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2818,10 +2818,32 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
28182818

28192819
// Prefer the declared type of generic classes.
28202820
let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| {
2821-
tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some())
2822-
.class_specialization(self.db)?;
2821+
let tcx = tcx.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
2822+
let return_ty =
2823+
return_ty.filter_union(self.db, |ty| ty.class_specialization(self.db).is_some());
2824+
2825+
// TODO: We should use the constraint solver here to determine the type mappings for more
2826+
// complex subtyping relationships, e.g., `type[C[T]]` to `Callable[..., T]`, or unions
2827+
// containing multiple generic elements.
2828+
if let Some((class_literal, _)) = return_ty.class_specialization(self.db)
2829+
&& let Some(generic_alias) = class_literal.into_generic_alias()
2830+
{
2831+
let specialization = generic_alias.specialization(self.db);
2832+
for (class_type_var, return_ty) in specialization
2833+
.generic_context(self.db)
2834+
.variables(self.db)
2835+
.zip(specialization.types(self.db))
2836+
{
2837+
if let Some(ty) = tcx.find_type_var_from(
2838+
self.db,
2839+
class_type_var,
2840+
generic_alias.origin(self.db),
2841+
) {
2842+
builder.infer(*return_ty, ty).ok()?;
2843+
}
2844+
}
2845+
}
28232846

2824-
builder.infer(return_ty, tcx).ok()?;
28252847
Some(builder.type_mappings().clone())
28262848
});
28272849

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,12 @@ impl<'db> TypeContext<'db> {
380380
}
381381
}
382382

383+
pub(crate) fn and_then(self, f: impl FnOnce(Type<'db>) -> Option<Type<'db>>) -> Self {
384+
Self {
385+
annotation: self.annotation.and_then(f),
386+
}
387+
}
388+
383389
pub(crate) fn is_typealias(&self) -> bool {
384390
self.annotation
385391
.is_some_and(|ty| ty.is_typealias_special_form())

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7497,41 +7497,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
74977497
annotation.filter_disjoint_elements(self.db(), collection_ty, inferable)
74987498
});
74997499

7500-
// Extract the annotated type of `T`, if provided.
7501-
let annotated_elt_tys = tcx
7502-
.known_specialization(self.db(), collection_class)
7503-
.map(|specialization| specialization.types(self.db()));
7504-
75057500
// Create a set of constraints to infer a precise type for `T`.
75067501
let mut builder = SpecializationBuilder::new(self.db(), inferable);
75077502

7508-
match annotated_elt_tys {
7509-
// The annotated type acts as a constraint for `T`.
7510-
//
7511-
// Note that we infer the annotated type _before_ the elements, to more closely match the
7512-
// order of any unions as written in the type annotation.
7513-
Some(annotated_elt_tys) => {
7514-
for (elt_ty, annotated_elt_ty) in iter::zip(elt_tys.clone(), annotated_elt_tys) {
7515-
builder
7516-
.infer(Type::TypeVar(elt_ty), *annotated_elt_ty)
7517-
.ok()?;
7518-
}
7519-
}
7503+
for elt_ty in elt_tys.clone() {
7504+
let elt_tcx = tcx
7505+
.annotation
7506+
// The annotated type acts as a constraint for `T`.
7507+
//
7508+
// Note that we infer the annotated type _before_ the elements, to more closely match the
7509+
// order of any unions as written in the type annotation.
7510+
.and_then(|tcx| tcx.find_type_var_from(self.db(), elt_ty, class_literal))
7511+
// If a valid type annotation was not provided, avoid restricting the type of the collection
7512+
// by unioning the inferred type with `Unknown`.
7513+
.unwrap_or(Type::unknown());
75207514

7521-
// If a valid type annotation was not provided, avoid restricting the type of the collection
7522-
// by unioning the inferred type with `Unknown`.
7523-
None => {
7524-
for elt_ty in elt_tys.clone() {
7525-
builder.infer(Type::TypeVar(elt_ty), Type::unknown()).ok()?;
7526-
}
7527-
}
7515+
builder.infer(Type::TypeVar(elt_ty), elt_tcx).ok()?;
75287516
}
75297517

7530-
let elt_tcxs = match annotated_elt_tys {
7531-
None => Either::Left(iter::repeat(TypeContext::default())),
7532-
Some(tys) => Either::Right(tys.iter().map(|ty| TypeContext::new(Some(*ty)))),
7533-
};
7534-
75357518
for elts in elts {
75367519
// An unpacking expression for a dictionary.
75377520
if let &[None, Some(value)] = elts.as_slice() {
@@ -7554,10 +7537,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
75547537
}
75557538

75567539
// The inferred type of each element acts as an additional constraint on `T`.
7557-
for (elt, elt_ty, elt_tcx) in itertools::izip!(elts, elt_tys.clone(), elt_tcxs.clone())
7558-
{
7540+
for (elt, elt_ty) in iter::zip(elts, elt_tys.clone()) {
75597541
let Some(elt) = elt else { continue };
75607542

7543+
let elt_tcx =
7544+
tcx.and_then(|tcx| tcx.find_type_var_from(self.db(), elt_ty, class_literal));
75617545
let inferred_elt_ty = infer_elt_expression(self, elt, elt_tcx);
75627546

75637547
// Simplify the inference based on the declared type of the element.

crates/ty_python_semantic/src/types/instance.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ impl<'db> Type<'db> {
165165
// This matches the behaviour of other type checkers, and is required for us to
166166
// recognise `str` as a subtype of `Container[str]`.
167167
structurally_satisfied.or(db, || {
168-
let Some(nominal_instance) = protocol.as_nominal_type() else {
168+
let Some(nominal_instance) = protocol.to_nominal_instance() else {
169169
return ConstraintSet::from(false);
170170
};
171171

@@ -175,7 +175,7 @@ impl<'db> Type<'db> {
175175
// `Q`'s members in a Liskov-incompatible way.
176176
let type_to_test = self
177177
.as_protocol_instance()
178-
.and_then(ProtocolInstanceType::as_nominal_type)
178+
.and_then(ProtocolInstanceType::to_nominal_instance)
179179
.map(Type::NominalInstance)
180180
.unwrap_or(self);
181181

@@ -650,7 +650,7 @@ impl<'db> ProtocolInstanceType<'db> {
650650
/// If this is a synthesized protocol that does not correspond to a class definition
651651
/// in source code, return `None`. These are "pure" abstract types, that cannot be
652652
/// treated in a nominal way.
653-
pub(super) fn as_nominal_type(self) -> Option<NominalInstanceType<'db>> {
653+
pub(super) fn to_nominal_instance(self) -> Option<NominalInstanceType<'db>> {
654654
match self.inner {
655655
Protocol::FromClass(class) => {
656656
Some(NominalInstanceType(NominalInstanceInner::NonTuple(*class)))

0 commit comments

Comments
 (0)