Skip to content

Commit b7a9f44

Browse files
committed
fix disjointness checks with type-of @final classes
1 parent 38334c8 commit b7a9f44

File tree

5 files changed

+275
-108
lines changed

5 files changed

+275
-108
lines changed

crates/ty_python_semantic/resources/mdtest/narrow/type.md

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,7 @@ def f(x: A[int] | B):
9292
reveal_type(x) # revealed: A[int] | B
9393

9494
if type(x) is A:
95-
# TODO: this should be `A[int]`, but `A[int] | B` would be better than `Never`
96-
reveal_type(x) # revealed: Never
95+
reveal_type(x) # revealed: A[int]
9796
else:
9897
reveal_type(x) # revealed: A[int] | B
9998

@@ -111,8 +110,7 @@ def f(x: A[int] | B):
111110
if type(x) is not A:
112111
reveal_type(x) # revealed: A[int] | B
113112
else:
114-
# TODO: this should be `A[int]`, but `A[int] | B` would be better than `Never`
115-
reveal_type(x) # revealed: Never
113+
reveal_type(x) # revealed: A[int]
116114

117115
if type(x) is not B:
118116
reveal_type(x) # revealed: A[int] | B
@@ -217,8 +215,7 @@ class B: ...
217215

218216
def _[T](x: A | B):
219217
if type(x) is A[str]:
220-
# TODO: `type()` never returns a generic alias, so `type(x)` cannot be `A[str]`
221-
reveal_type(x) # revealed: A[int] | B
218+
reveal_type(x) # revealed: Never
222219
else:
223220
reveal_type(x) # revealed: A[int] | B
224221
```

crates/ty_python_semantic/resources/mdtest/type_of/basic.md

Lines changed: 168 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ same also applies to enum classes with members, which are implicitly final:
215215

216216
```toml
217217
[environment]
218-
python-version = "3.10"
218+
python-version = "3.12"
219219
```
220220

221221
```py
@@ -235,3 +235,170 @@ def _(x: type[Foo], y: type[EllipsisType], z: type[Answer]):
235235
reveal_type(y) # revealed: <class 'EllipsisType'>
236236
reveal_type(z) # revealed: <class 'Answer'>
237237
```
238+
239+
## Subtyping `@final` classes
240+
241+
```toml
242+
[environment]
243+
python-version = "3.12"
244+
```
245+
246+
```py
247+
from typing import final, Any
248+
from ty_extensions import is_assignable_to, is_subtype_of, is_disjoint_from, static_assert
249+
250+
class Biv[T]: ...
251+
252+
class Cov[T]:
253+
def pop(self) -> T:
254+
raise NotImplementedError
255+
256+
class Contra[T]:
257+
def push(self, value: T) -> None:
258+
pass
259+
260+
class Inv[T]:
261+
x: T
262+
263+
@final
264+
class BivSub[T](Biv[T]): ...
265+
266+
@final
267+
class CovSub[T](Cov[T]): ...
268+
269+
@final
270+
class ContraSub[T](Contra[T]): ...
271+
272+
@final
273+
class InvSub[T](Inv[T]): ...
274+
275+
def _[T, U]():
276+
static_assert(is_subtype_of(type[BivSub[T]], type[BivSub[U]]))
277+
static_assert(not is_disjoint_from(type[BivSub[U]], type[BivSub[T]]))
278+
279+
static_assert(not is_subtype_of(type[CovSub[T]], type[CovSub[U]]))
280+
static_assert(not is_disjoint_from(type[CovSub[U]], type[CovSub[T]]))
281+
282+
static_assert(not is_subtype_of(type[ContraSub[T]], type[ContraSub[U]]))
283+
static_assert(not is_disjoint_from(type[ContraSub[U]], type[ContraSub[T]]))
284+
285+
static_assert(not is_subtype_of(type[InvSub[T]], type[InvSub[U]]))
286+
static_assert(not is_disjoint_from(type[InvSub[U]], type[InvSub[T]]))
287+
288+
def _():
289+
static_assert(is_subtype_of(type[BivSub[bool]], type[BivSub[int]]))
290+
static_assert(is_subtype_of(type[BivSub[int]], type[BivSub[bool]]))
291+
static_assert(is_disjoint_from(type[BivSub[int]], type[BivSub[str]]))
292+
static_assert(not is_disjoint_from(type[BivSub[bool]], type[BivSub[int]]))
293+
294+
static_assert(is_subtype_of(type[CovSub[bool]], type[CovSub[int]]))
295+
static_assert(not is_subtype_of(type[CovSub[int]], type[CovSub[bool]]))
296+
static_assert(is_disjoint_from(type[CovSub[int]], type[CovSub[str]]))
297+
static_assert(not is_disjoint_from(type[CovSub[bool]], type[CovSub[int]]))
298+
299+
static_assert(not is_subtype_of(type[ContraSub[bool]], type[ContraSub[int]]))
300+
static_assert(is_subtype_of(type[ContraSub[int]], type[ContraSub[bool]]))
301+
static_assert(is_disjoint_from(type[ContraSub[int]], type[ContraSub[str]]))
302+
static_assert(not is_disjoint_from(type[ContraSub[bool]], type[ContraSub[int]]))
303+
304+
static_assert(not is_subtype_of(type[InvSub[bool]], type[InvSub[int]]))
305+
static_assert(not is_subtype_of(type[InvSub[int]], type[InvSub[bool]]))
306+
static_assert(is_disjoint_from(type[InvSub[int]], type[InvSub[str]]))
307+
static_assert(not is_disjoint_from(type[InvSub[bool]], type[InvSub[int]]))
308+
309+
def _[T]():
310+
static_assert(is_subtype_of(type[BivSub[T]], type[BivSub[Any]]))
311+
static_assert(is_subtype_of(type[BivSub[Any]], type[BivSub[T]]))
312+
static_assert(is_assignable_to(type[BivSub[T]], type[BivSub[Any]]))
313+
static_assert(is_assignable_to(type[BivSub[Any]], type[BivSub[T]]))
314+
static_assert(not is_disjoint_from(type[BivSub[T]], type[BivSub[Any]]))
315+
316+
static_assert(not is_subtype_of(type[CovSub[T]], type[CovSub[Any]]))
317+
static_assert(not is_subtype_of(type[CovSub[Any]], type[CovSub[T]]))
318+
static_assert(is_assignable_to(type[CovSub[T]], type[CovSub[Any]]))
319+
static_assert(is_assignable_to(type[CovSub[Any]], type[CovSub[T]]))
320+
static_assert(not is_disjoint_from(type[CovSub[T]], type[CovSub[Any]]))
321+
322+
static_assert(not is_subtype_of(type[ContraSub[T]], type[ContraSub[Any]]))
323+
static_assert(not is_subtype_of(type[ContraSub[Any]], type[ContraSub[T]]))
324+
static_assert(is_assignable_to(type[ContraSub[T]], type[ContraSub[Any]]))
325+
static_assert(is_assignable_to(type[ContraSub[Any]], type[ContraSub[T]]))
326+
static_assert(not is_disjoint_from(type[ContraSub[T]], type[ContraSub[Any]]))
327+
328+
static_assert(not is_subtype_of(type[InvSub[T]], type[InvSub[Any]]))
329+
static_assert(not is_subtype_of(type[InvSub[Any]], type[InvSub[T]]))
330+
static_assert(is_assignable_to(type[InvSub[T]], type[InvSub[Any]]))
331+
static_assert(is_assignable_to(type[InvSub[Any]], type[InvSub[T]]))
332+
static_assert(not is_disjoint_from(type[InvSub[T]], type[InvSub[Any]]))
333+
334+
def _[T, U]():
335+
static_assert(is_subtype_of(type[BivSub[T]], type[Biv[T]]))
336+
static_assert(not is_subtype_of(type[Biv[T]], type[BivSub[T]]))
337+
static_assert(not is_disjoint_from(type[BivSub[T]], type[Biv[T]]))
338+
static_assert(not is_disjoint_from(type[BivSub[U]], type[Biv[T]]))
339+
static_assert(not is_disjoint_from(type[BivSub[U]], type[Biv[U]]))
340+
341+
static_assert(is_subtype_of(type[CovSub[T]], type[Cov[T]]))
342+
static_assert(not is_subtype_of(type[Cov[T]], type[CovSub[T]]))
343+
static_assert(not is_disjoint_from(type[CovSub[T]], type[Cov[T]]))
344+
static_assert(not is_disjoint_from(type[CovSub[U]], type[Cov[T]]))
345+
static_assert(not is_disjoint_from(type[CovSub[U]], type[Cov[U]]))
346+
347+
static_assert(is_subtype_of(type[ContraSub[T]], type[Contra[T]]))
348+
static_assert(not is_subtype_of(type[Contra[T]], type[ContraSub[T]]))
349+
static_assert(not is_disjoint_from(type[ContraSub[T]], type[Contra[T]]))
350+
static_assert(not is_disjoint_from(type[ContraSub[U]], type[Contra[T]]))
351+
static_assert(not is_disjoint_from(type[ContraSub[U]], type[Contra[U]]))
352+
353+
static_assert(is_subtype_of(type[InvSub[T]], type[Inv[T]]))
354+
static_assert(not is_subtype_of(type[Inv[T]], type[InvSub[T]]))
355+
static_assert(not is_disjoint_from(type[InvSub[T]], type[Inv[T]]))
356+
static_assert(not is_disjoint_from(type[InvSub[U]], type[Inv[T]]))
357+
static_assert(not is_disjoint_from(type[InvSub[U]], type[Inv[U]]))
358+
359+
def _():
360+
static_assert(is_subtype_of(type[BivSub[bool]], type[Biv[int]]))
361+
static_assert(is_subtype_of(type[BivSub[int]], type[Biv[bool]]))
362+
static_assert(not is_disjoint_from(type[BivSub[bool]], type[Biv[int]]))
363+
static_assert(not is_disjoint_from(type[BivSub[int]], type[Biv[bool]]))
364+
365+
static_assert(is_subtype_of(type[CovSub[bool]], type[Cov[int]]))
366+
static_assert(not is_subtype_of(type[CovSub[int]], type[Cov[bool]]))
367+
static_assert(not is_disjoint_from(type[CovSub[bool]], type[Cov[int]]))
368+
static_assert(not is_disjoint_from(type[CovSub[int]], type[Cov[bool]]))
369+
370+
static_assert(not is_subtype_of(type[ContraSub[bool]], type[Contra[int]]))
371+
static_assert(is_subtype_of(type[ContraSub[int]], type[Contra[bool]]))
372+
static_assert(not is_disjoint_from(type[ContraSub[int]], type[Contra[bool]]))
373+
static_assert(not is_disjoint_from(type[ContraSub[bool]], type[Contra[int]]))
374+
375+
static_assert(not is_subtype_of(type[InvSub[bool]], type[Inv[int]]))
376+
static_assert(not is_subtype_of(type[InvSub[int]], type[Inv[bool]]))
377+
static_assert(not is_disjoint_from(type[InvSub[bool]], type[Inv[int]]))
378+
static_assert(not is_disjoint_from(type[InvSub[int]], type[Inv[bool]]))
379+
380+
def _[T]():
381+
static_assert(is_subtype_of(type[BivSub[T]], type[Biv[Any]]))
382+
static_assert(is_subtype_of(type[BivSub[Any]], type[Biv[T]]))
383+
static_assert(is_assignable_to(type[BivSub[T]], type[Biv[Any]]))
384+
static_assert(is_assignable_to(type[BivSub[Any]], type[Biv[T]]))
385+
static_assert(not is_disjoint_from(type[BivSub[T]], type[Biv[Any]]))
386+
387+
static_assert(not is_subtype_of(type[CovSub[T]], type[Cov[Any]]))
388+
static_assert(not is_subtype_of(type[CovSub[Any]], type[Cov[T]]))
389+
static_assert(is_assignable_to(type[CovSub[T]], type[Cov[Any]]))
390+
static_assert(is_assignable_to(type[CovSub[Any]], type[Cov[T]]))
391+
static_assert(not is_disjoint_from(type[CovSub[T]], type[Cov[Any]]))
392+
393+
static_assert(not is_subtype_of(type[ContraSub[T]], type[Contra[Any]]))
394+
static_assert(not is_subtype_of(type[ContraSub[Any]], type[Contra[T]]))
395+
static_assert(is_assignable_to(type[ContraSub[T]], type[Contra[Any]]))
396+
static_assert(is_assignable_to(type[ContraSub[Any]], type[Contra[T]]))
397+
static_assert(not is_disjoint_from(type[ContraSub[T]], type[Contra[Any]]))
398+
399+
static_assert(not is_subtype_of(type[InvSub[T]], type[Inv[Any]]))
400+
static_assert(not is_subtype_of(type[InvSub[Any]], type[Inv[T]]))
401+
static_assert(is_assignable_to(type[InvSub[T]], type[Inv[Any]]))
402+
static_assert(is_assignable_to(type[InvSub[Any]], type[Inv[T]]))
403+
static_assert(not is_disjoint_from(type[InvSub[T]], type[Inv[Any]]))
404+
```

crates/ty_python_semantic/src/types.rs

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2724,6 +2724,17 @@ impl<'db> Type<'db> {
27242724
)
27252725
}
27262726

2727+
(Type::GenericAlias(self_alias), Type::GenericAlias(target_alias)) => {
2728+
ClassType::from(self_alias).has_relation_to_impl(
2729+
db,
2730+
ClassType::from(target_alias),
2731+
inferable,
2732+
relation,
2733+
relation_visitor,
2734+
disjointness_visitor,
2735+
)
2736+
}
2737+
27272738
// `Literal[str]` is a subtype of `type` because the `str` class object is an instance of its metaclass `type`.
27282739
// `Literal[abc.ABC]` is a subtype of `abc.ABCMeta` because the `abc.ABC` class object
27292740
// is an instance of its metaclass `abc.ABCMeta`.
@@ -3307,7 +3318,6 @@ impl<'db> Type<'db> {
33073318
| Type::WrapperDescriptor(..)
33083319
| Type::ModuleLiteral(..)
33093320
| Type::ClassLiteral(..)
3310-
| Type::GenericAlias(..)
33113321
| Type::SpecialForm(..)
33123322
| Type::KnownInstance(..)),
33133323
right @ (Type::BooleanLiteral(..)
@@ -3321,7 +3331,6 @@ impl<'db> Type<'db> {
33213331
| Type::WrapperDescriptor(..)
33223332
| Type::ModuleLiteral(..)
33233333
| Type::ClassLiteral(..)
3324-
| Type::GenericAlias(..)
33253334
| Type::SpecialForm(..)
33263335
| Type::KnownInstance(..)),
33273336
) => ConstraintSet::from(left != right),
@@ -3504,13 +3513,25 @@ impl<'db> Type<'db> {
35043513
ConstraintSet::from(true)
35053514
}
35063515

3516+
(Type::GenericAlias(left_alias), Type::GenericAlias(right_alias)) => {
3517+
ConstraintSet::from(left_alias.origin(db) != right_alias.origin(db)).or(db, || {
3518+
left_alias.specialization(db).is_disjoint_from_impl(
3519+
db,
3520+
right_alias.specialization(db),
3521+
inferable,
3522+
disjointness_visitor,
3523+
relation_visitor,
3524+
)
3525+
})
3526+
}
3527+
35073528
(Type::SubclassOf(subclass_of_ty), Type::ClassLiteral(class_b))
35083529
| (Type::ClassLiteral(class_b), Type::SubclassOf(subclass_of_ty)) => {
35093530
match subclass_of_ty.subclass_of() {
35103531
SubclassOfInner::Dynamic(_) => ConstraintSet::from(false),
3511-
SubclassOfInner::Class(class_a) => {
3512-
class_b.when_subclass_of(db, None, class_a).negate(db)
3513-
}
3532+
SubclassOfInner::Class(class_a) => ConstraintSet::from(
3533+
!class_a.could_exist_in_mro_of(db, ClassType::NonGeneric(class_b)),
3534+
),
35143535
SubclassOfInner::TypeVar(_) => unreachable!(),
35153536
}
35163537
}
@@ -3519,9 +3540,9 @@ impl<'db> Type<'db> {
35193540
| (Type::GenericAlias(alias_b), Type::SubclassOf(subclass_of_ty)) => {
35203541
match subclass_of_ty.subclass_of() {
35213542
SubclassOfInner::Dynamic(_) => ConstraintSet::from(false),
3522-
SubclassOfInner::Class(class_a) => ClassType::from(alias_b)
3523-
.when_subclass_of(db, class_a, inferable)
3524-
.negate(db),
3543+
SubclassOfInner::Class(class_a) => ConstraintSet::from(
3544+
!class_a.could_exist_in_mro_of(db, ClassType::Generic(alias_b)),
3545+
),
35253546
SubclassOfInner::TypeVar(_) => unreachable!(),
35263547
}
35273548
}
@@ -3815,6 +3836,8 @@ impl<'db> Type<'db> {
38153836
relation_visitor,
38163837
)
38173838
}
3839+
3840+
(Type::GenericAlias(_), _) | (_, Type::GenericAlias(_)) => ConstraintSet::from(true),
38183841
}
38193842
}
38203843

crates/ty_python_semantic/src/types/class.rs

Lines changed: 29 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,30 @@ impl<'db> ClassType<'db> {
706706
.find_map(|base| base.as_disjoint_base(db))
707707
}
708708

709+
/// Return `true` if this class could exist in the MRO of `other`.
710+
pub(super) fn could_exist_in_mro_of(self, db: &'db dyn Db, other: Self) -> bool {
711+
other.iter_mro(db)
712+
.filter_map(ClassBase::into_class)
713+
.any(|class| match (self, class) {
714+
(ClassType::NonGeneric(this_class), ClassType::NonGeneric(other_class)) => {
715+
this_class == other_class
716+
}
717+
(ClassType::Generic(this_alias), ClassType::Generic(other_alias)) => {
718+
this_alias.origin(db) == other_alias.origin(db)
719+
&& this_alias
720+
.specialization(db)
721+
.is_disjoint_from(
722+
db,
723+
other_alias.specialization(db),
724+
InferableTypeVars::None,
725+
)
726+
.is_never_satisfied(db)
727+
}
728+
(ClassType::NonGeneric(_), ClassType::Generic(_))
729+
| (ClassType::Generic(_), ClassType::NonGeneric(_)) => false,
730+
})
731+
}
732+
709733
/// Return `true` if this class could coexist in an MRO with `other`.
710734
///
711735
/// For two given classes `A` and `B`, it is often possible to say for sure
@@ -716,34 +740,12 @@ impl<'db> ClassType<'db> {
716740
return true;
717741
}
718742

719-
if self.is_final(db) || other.is_final(db) {
720-
let (this, other) = if self.is_final(db) {
721-
(self, other)
722-
} else {
723-
(other, self)
724-
};
743+
if self.is_final(db) {
744+
return other.could_exist_in_mro_of(db, self);
745+
}
725746

726-
return this
727-
.iter_mro(db)
728-
.filter_map(ClassBase::into_class)
729-
.any(|class| match (class, other) {
730-
(ClassType::NonGeneric(this_class), ClassType::NonGeneric(other_class)) => {
731-
this_class == other_class
732-
}
733-
(ClassType::Generic(this_alias), ClassType::Generic(other_alias)) => {
734-
this_alias.origin(db) == other_alias.origin(db)
735-
&& this_alias
736-
.specialization(db)
737-
.is_disjoint_from(
738-
db,
739-
other_alias.specialization(db),
740-
InferableTypeVars::None,
741-
)
742-
.is_never_satisfied(db)
743-
}
744-
(ClassType::NonGeneric(_), ClassType::Generic(_))
745-
| (ClassType::Generic(_), ClassType::NonGeneric(_)) => false,
746-
});
747+
if other.is_final(db) {
748+
return self.could_exist_in_mro_of(db, other);
747749
}
748750

749751
// Two disjoint bases can only coexist in an MRO if one is a subclass of the other.
@@ -1859,15 +1861,6 @@ impl<'db> ClassLiteral<'db> {
18591861
.contains(&ClassBase::Class(other))
18601862
}
18611863

1862-
pub(super) fn when_subclass_of(
1863-
self,
1864-
db: &'db dyn Db,
1865-
specialization: Option<Specialization<'db>>,
1866-
other: ClassType<'db>,
1867-
) -> ConstraintSet<'db> {
1868-
ConstraintSet::from(self.is_subclass_of(db, specialization, other))
1869-
}
1870-
18711864
/// Return `true` if this class constitutes a typed dict specification (inherits from
18721865
/// `typing.TypedDict`, either directly or indirectly).
18731866
#[salsa::tracked(cycle_initial=is_typed_dict_cycle_initial,

0 commit comments

Comments
 (0)