Skip to content

Commit e4a32ba

Browse files
dcreagersharkdp
andauthored
[ty] Constraint sets compare generic callables correctly (#21392)
Constraint sets can now track subtyping/assignability/etc of generic callables correctly. For instance: ```py def identity[T](t: T) -> T: return t constraints = ConstraintSet.always() static_assert(constraints.implies_subtype_of(TypeOf[identity], Callable[[int], int])) static_assert(constraints.implies_subtype_of(TypeOf[identity], Callable[[str], str])) ``` A generic callable can be considered an intersection of all of its possible specializations, and an assignability check with an intersection as the lhs side succeeds of _any_ of the intersected types satisfies the check. Put another way, if someone expects to receive any function with a signature of `(int) -> int`, we can give them `identity`. Note that the corresponding check using `is_subtype_of` directly does not yet work, since #20093 has not yet hooked up the core typing relationship logic to use constraint sets: ```py # These currently fail static_assert(is_subtype_of(TypeOf[identity], Callable[[int], int])) static_assert(is_subtype_of(TypeOf[identity], Callable[[str], str])) ``` To do this, we add a new _existential quantification_ operation on constraint sets. This takes in a list of typevars and _removes_ those typevars from the constraint set. Conceptually, we return a new constraint set that evaluates to `true` when there was _any_ assignment of the removed typevars that caused the old constraint set to evaluate to `true`. When comparing a generic constraint set, we add its typevars to the `inferable` set, and figure out whatever constraints would allow any specialization to satisfy the check. We then use the new existential quantification operator to remove those new typevars, since the caller doesn't (and shouldn't) know anything about them. --------- Co-authored-by: David Peter <[email protected]>
1 parent ac2d07e commit e4a32ba

File tree

8 files changed

+388
-52
lines changed

8 files changed

+388
-52
lines changed

crates/ty_python_semantic/resources/mdtest/protocols.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2099,18 +2099,14 @@ static_assert(is_equivalent_to(LegacyFunctionScoped, NewStyleFunctionScoped)) #
20992099

21002100
static_assert(is_assignable_to(NominalNewStyle, NewStyleFunctionScoped))
21012101
static_assert(is_assignable_to(NominalNewStyle, LegacyFunctionScoped))
2102-
# TODO: should pass
2103-
static_assert(is_subtype_of(NominalNewStyle, NewStyleFunctionScoped)) # error: [static-assert-error]
2104-
# TODO: should pass
2105-
static_assert(is_subtype_of(NominalNewStyle, LegacyFunctionScoped)) # error: [static-assert-error]
2102+
static_assert(is_subtype_of(NominalNewStyle, NewStyleFunctionScoped))
2103+
static_assert(is_subtype_of(NominalNewStyle, LegacyFunctionScoped))
21062104
static_assert(not is_assignable_to(NominalNewStyle, UsesSelf))
21072105

21082106
static_assert(is_assignable_to(NominalLegacy, NewStyleFunctionScoped))
21092107
static_assert(is_assignable_to(NominalLegacy, LegacyFunctionScoped))
2110-
# TODO: should pass
2111-
static_assert(is_subtype_of(NominalLegacy, NewStyleFunctionScoped)) # error: [static-assert-error]
2112-
# TODO: should pass
2113-
static_assert(is_subtype_of(NominalLegacy, LegacyFunctionScoped)) # error: [static-assert-error]
2108+
static_assert(is_subtype_of(NominalLegacy, NewStyleFunctionScoped))
2109+
static_assert(is_subtype_of(NominalLegacy, LegacyFunctionScoped))
21142110
static_assert(not is_assignable_to(NominalLegacy, UsesSelf))
21152111

21162112
static_assert(not is_assignable_to(NominalWithSelf, NewStyleFunctionScoped))

crates/ty_python_semantic/resources/mdtest/type_properties/implies_subtype_of.md

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,4 +349,101 @@ def mutually_constrained[T, U]():
349349
static_assert(not given_int.implies_subtype_of(Invariant[str], Invariant[T]))
350350
```
351351

352+
## Generic callables
353+
354+
A generic callable can be considered equivalent to an intersection of all of its possible
355+
specializations. That means that a generic callable is a subtype of any particular specialization.
356+
(If someone expects a function that works with a particular specialization, it's fine to hand them
357+
the generic callable.)
358+
359+
```py
360+
from typing import Callable
361+
from ty_extensions import CallableTypeOf, ConstraintSet, TypeOf, is_subtype_of, static_assert
362+
363+
def identity[T](t: T) -> T:
364+
return t
365+
366+
type GenericIdentity[T] = Callable[[T], T]
367+
368+
constraints = ConstraintSet.always()
369+
370+
static_assert(constraints.implies_subtype_of(TypeOf[identity], Callable[[int], int]))
371+
static_assert(constraints.implies_subtype_of(TypeOf[identity], Callable[[str], str]))
372+
static_assert(not constraints.implies_subtype_of(TypeOf[identity], Callable[[str], int]))
373+
374+
static_assert(constraints.implies_subtype_of(CallableTypeOf[identity], Callable[[int], int]))
375+
static_assert(constraints.implies_subtype_of(CallableTypeOf[identity], Callable[[str], str]))
376+
static_assert(not constraints.implies_subtype_of(CallableTypeOf[identity], Callable[[str], int]))
377+
378+
static_assert(constraints.implies_subtype_of(TypeOf[identity], GenericIdentity[int]))
379+
static_assert(constraints.implies_subtype_of(TypeOf[identity], GenericIdentity[str]))
380+
# This gives us the default specialization, GenericIdentity[Unknown], which does
381+
# not participate in subtyping.
382+
static_assert(not constraints.implies_subtype_of(TypeOf[identity], GenericIdentity))
383+
```
384+
385+
The reverse is not true — if someone expects a generic function that can be called with any
386+
specialization, we cannot hand them a function that only works with one specialization.
387+
388+
```py
389+
static_assert(not constraints.implies_subtype_of(Callable[[int], int], TypeOf[identity]))
390+
static_assert(not constraints.implies_subtype_of(Callable[[str], str], TypeOf[identity]))
391+
static_assert(not constraints.implies_subtype_of(Callable[[str], int], TypeOf[identity]))
392+
393+
static_assert(not constraints.implies_subtype_of(Callable[[int], int], CallableTypeOf[identity]))
394+
static_assert(not constraints.implies_subtype_of(Callable[[str], str], CallableTypeOf[identity]))
395+
static_assert(not constraints.implies_subtype_of(Callable[[str], int], CallableTypeOf[identity]))
396+
397+
static_assert(not constraints.implies_subtype_of(GenericIdentity[int], TypeOf[identity]))
398+
static_assert(not constraints.implies_subtype_of(GenericIdentity[str], TypeOf[identity]))
399+
# This gives us the default specialization, GenericIdentity[Unknown], which does
400+
# not participate in subtyping.
401+
static_assert(not constraints.implies_subtype_of(GenericIdentity, TypeOf[identity]))
402+
```
403+
404+
Unrelated typevars in the constraint set do not affect whether the subtyping check succeeds or
405+
fails.
406+
407+
```py
408+
def unrelated[T]():
409+
# Note that even though this typevar is also named T, it is not the same typevar as T@identity!
410+
constraints = ConstraintSet.range(bool, T, int)
411+
412+
static_assert(constraints.implies_subtype_of(TypeOf[identity], Callable[[int], int]))
413+
static_assert(constraints.implies_subtype_of(TypeOf[identity], Callable[[str], str]))
414+
static_assert(not constraints.implies_subtype_of(TypeOf[identity], Callable[[str], int]))
415+
static_assert(constraints.implies_subtype_of(TypeOf[identity], GenericIdentity[int]))
416+
static_assert(constraints.implies_subtype_of(TypeOf[identity], GenericIdentity[str]))
417+
418+
static_assert(not constraints.implies_subtype_of(Callable[[int], int], TypeOf[identity]))
419+
static_assert(not constraints.implies_subtype_of(Callable[[str], str], TypeOf[identity]))
420+
static_assert(not constraints.implies_subtype_of(Callable[[str], int], TypeOf[identity]))
421+
static_assert(not constraints.implies_subtype_of(GenericIdentity[int], TypeOf[identity]))
422+
static_assert(not constraints.implies_subtype_of(GenericIdentity[str], TypeOf[identity]))
423+
```
424+
425+
The generic callable's typevar _also_ does not affect whether the subtyping check succeeds or fails!
426+
427+
```py
428+
def identity2[T](t: T) -> T:
429+
# This constraint set refers to the same typevar as the generic function types below!
430+
constraints = ConstraintSet.range(bool, T, int)
431+
432+
static_assert(constraints.implies_subtype_of(TypeOf[identity2], Callable[[int], int]))
433+
static_assert(constraints.implies_subtype_of(TypeOf[identity2], Callable[[str], str]))
434+
# TODO: no error
435+
# error: [static-assert-error]
436+
static_assert(not constraints.implies_subtype_of(TypeOf[identity2], Callable[[str], int]))
437+
static_assert(constraints.implies_subtype_of(TypeOf[identity2], GenericIdentity[int]))
438+
static_assert(constraints.implies_subtype_of(TypeOf[identity2], GenericIdentity[str]))
439+
440+
static_assert(not constraints.implies_subtype_of(Callable[[int], int], TypeOf[identity2]))
441+
static_assert(not constraints.implies_subtype_of(Callable[[str], str], TypeOf[identity2]))
442+
static_assert(not constraints.implies_subtype_of(Callable[[str], int], TypeOf[identity2]))
443+
static_assert(not constraints.implies_subtype_of(GenericIdentity[int], TypeOf[identity2]))
444+
static_assert(not constraints.implies_subtype_of(GenericIdentity[str], TypeOf[identity2]))
445+
446+
return t
447+
```
448+
352449
[subtyping]: https://typing.python.org/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence

crates/ty_python_semantic/resources/mdtest/type_properties/is_assignable_to.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Assignable-to relation
22

3+
```toml
4+
[environment]
5+
python-version = "3.12"
6+
```
7+
38
The `is_assignable_to(S, T)` relation below checks if type `S` is assignable to type `T` (target).
49
This allows us to check if a type `S` can be used in a context where a type `T` is expected
510
(function arguments, variable assignments). See the [typing documentation] for a precise definition
@@ -1227,6 +1232,46 @@ from ty_extensions import static_assert, is_assignable_to
12271232
static_assert(is_assignable_to(type, Callable[..., Any]))
12281233
```
12291234

1235+
### Generic callables
1236+
1237+
A generic callable can be considered equivalent to an intersection of all of its possible
1238+
specializations. That means that a generic callable is assignable to any particular specialization.
1239+
(If someone expects a function that works with a particular specialization, it's fine to hand them
1240+
the generic callable.)
1241+
1242+
```py
1243+
from typing import Callable
1244+
from ty_extensions import CallableTypeOf, TypeOf, is_assignable_to, static_assert
1245+
1246+
def identity[T](t: T) -> T:
1247+
return t
1248+
1249+
static_assert(is_assignable_to(TypeOf[identity], Callable[[int], int]))
1250+
static_assert(is_assignable_to(TypeOf[identity], Callable[[str], str]))
1251+
# TODO: no error
1252+
# error: [static-assert-error]
1253+
static_assert(not is_assignable_to(TypeOf[identity], Callable[[str], int]))
1254+
1255+
static_assert(is_assignable_to(CallableTypeOf[identity], Callable[[int], int]))
1256+
static_assert(is_assignable_to(CallableTypeOf[identity], Callable[[str], str]))
1257+
# TODO: no error
1258+
# error: [static-assert-error]
1259+
static_assert(not is_assignable_to(CallableTypeOf[identity], Callable[[str], int]))
1260+
```
1261+
1262+
The reverse is not true — if someone expects a generic function that can be called with any
1263+
specialization, we cannot hand them a function that only works with one specialization.
1264+
1265+
```py
1266+
static_assert(not is_assignable_to(Callable[[int], int], TypeOf[identity]))
1267+
static_assert(not is_assignable_to(Callable[[str], str], TypeOf[identity]))
1268+
static_assert(not is_assignable_to(Callable[[str], int], TypeOf[identity]))
1269+
1270+
static_assert(not is_assignable_to(Callable[[int], int], CallableTypeOf[identity]))
1271+
static_assert(not is_assignable_to(Callable[[str], str], CallableTypeOf[identity]))
1272+
static_assert(not is_assignable_to(Callable[[str], int], CallableTypeOf[identity]))
1273+
```
1274+
12301275
## Generics
12311276

12321277
### Assignability of generic types parameterized by gradual types

crates/ty_python_semantic/resources/mdtest/type_properties/is_subtype_of.md

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2207,6 +2207,54 @@ static_assert(is_subtype_of(CallableTypeOf[overload_ab], CallableTypeOf[overload
22072207
static_assert(is_subtype_of(CallableTypeOf[overload_ba], CallableTypeOf[overload_ab]))
22082208
```
22092209

2210+
### Generic callables
2211+
2212+
A generic callable can be considered equivalent to an intersection of all of its possible
2213+
specializations. That means that a generic callable is a subtype of any particular specialization.
2214+
(If someone expects a function that works with a particular specialization, it's fine to hand them
2215+
the generic callable.)
2216+
2217+
```py
2218+
from typing import Callable
2219+
from ty_extensions import CallableTypeOf, TypeOf, is_subtype_of, static_assert
2220+
2221+
def identity[T](t: T) -> T:
2222+
return t
2223+
2224+
# TODO: Confusingly, these are not the same results as the corresponding checks in
2225+
# is_assignable_to.md, even though all of these types are fully static. We have some heuristics that
2226+
# currently conflict with each other, that we are in the process of removing with the constraint set
2227+
# work.
2228+
# TODO: no error
2229+
# error: [static-assert-error]
2230+
static_assert(is_subtype_of(TypeOf[identity], Callable[[int], int]))
2231+
# TODO: no error
2232+
# error: [static-assert-error]
2233+
static_assert(is_subtype_of(TypeOf[identity], Callable[[str], str]))
2234+
static_assert(not is_subtype_of(TypeOf[identity], Callable[[str], int]))
2235+
2236+
# TODO: no error
2237+
# error: [static-assert-error]
2238+
static_assert(is_subtype_of(CallableTypeOf[identity], Callable[[int], int]))
2239+
# TODO: no error
2240+
# error: [static-assert-error]
2241+
static_assert(is_subtype_of(CallableTypeOf[identity], Callable[[str], str]))
2242+
static_assert(not is_subtype_of(CallableTypeOf[identity], Callable[[str], int]))
2243+
```
2244+
2245+
The reverse is not true — if someone expects a generic function that can be called with any
2246+
specialization, we cannot hand them a function that only works with one specialization.
2247+
2248+
```py
2249+
static_assert(not is_subtype_of(Callable[[int], int], TypeOf[identity]))
2250+
static_assert(not is_subtype_of(Callable[[str], str], TypeOf[identity]))
2251+
static_assert(not is_subtype_of(Callable[[str], int], TypeOf[identity]))
2252+
2253+
static_assert(not is_subtype_of(Callable[[int], int], CallableTypeOf[identity]))
2254+
static_assert(not is_subtype_of(Callable[[str], str], CallableTypeOf[identity]))
2255+
static_assert(not is_subtype_of(Callable[[str], int], CallableTypeOf[identity]))
2256+
```
2257+
22102258
[gradual form]: https://typing.python.org/en/latest/spec/glossary.html#term-gradual-form
22112259
[gradual tuple]: https://typing.python.org/en/latest/spec/tuples.html#tuple-type-form
22122260
[special case for float and complex]: https://typing.python.org/en/latest/spec/special-types.html#special-cases-for-float-and-complex

crates/ty_python_semantic/src/types.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1944,9 +1944,7 @@ impl<'db> Type<'db> {
19441944
})
19451945
}
19461946

1947-
(Type::TypeVar(bound_typevar), _)
1948-
if bound_typevar.is_inferable(db, inferable) && relation.is_assignability() =>
1949-
{
1947+
(Type::TypeVar(bound_typevar), _) if bound_typevar.is_inferable(db, inferable) => {
19501948
// The implicit lower bound of a typevar is `Never`, which means
19511949
// that it is always assignable to any other type.
19521950

@@ -2086,9 +2084,12 @@ impl<'db> Type<'db> {
20862084
}
20872085

20882086
// TODO: Infer specializations here
2089-
(Type::TypeVar(bound_typevar), _) | (_, Type::TypeVar(bound_typevar))
2090-
if bound_typevar.is_inferable(db, inferable) =>
2091-
{
2087+
(_, Type::TypeVar(bound_typevar)) if bound_typevar.is_inferable(db, inferable) => {
2088+
ConstraintSet::from(false)
2089+
}
2090+
(Type::TypeVar(bound_typevar), _) => {
2091+
// All inferable cases should have been handled above
2092+
assert!(!bound_typevar.is_inferable(db, inferable));
20922093
ConstraintSet::from(false)
20932094
}
20942095

@@ -2542,13 +2543,8 @@ impl<'db> Type<'db> {
25422543
disjointness_visitor,
25432544
),
25442545

2545-
// Other than the special cases enumerated above, nominal-instance types,
2546-
// newtype-instance types, and typevars are never subtypes of any other variants
2547-
(Type::TypeVar(bound_typevar), _) => {
2548-
// All inferable cases should have been handled above
2549-
assert!(!bound_typevar.is_inferable(db, inferable));
2550-
ConstraintSet::from(false)
2551-
}
2546+
// Other than the special cases enumerated above, nominal-instance types, and
2547+
// newtype-instance types are never subtypes of any other variants
25522548
(Type::NominalInstance(_), _) => ConstraintSet::from(false),
25532549
(Type::NewTypeInstance(_), _) => ConstraintSet::from(false),
25542550
}

0 commit comments

Comments
 (0)