Commit 83134fb

[ty] Handle nested types when creating specializations from constraint sets (#21530)
#21414 added the ability to create a specialization from a constraint set. It handled mutually constrained typevars just fine, e.g. given `T ≤ int ∧ U = T` we can infer `T = int, U = int`. But it didn't handle _nested_ constraints correctly, e.g. `T ≤ int ∧ U = list[T]`. Now we do! This requires doing a fixed-point "apply the specialization to itself" step to propagate the assignments of any nested typevars, and then a cycle-detection check to make sure we don't have an infinite expansion in the specialization.

This gets at an interesting nuance in our constraint set structure that @sharkdp has asked about before. Constraint sets are BDDs, and each internal node represents an _individual constraint_, of the form `lower ≤ T ≤ upper`. `lower` and `upper` are allowed to be other typevars, but only if they appear "later" in the arbitrary ordering that we establish over typevars. The main purpose of this is to avoid infinite expansion for mutually constrained typevars.

However, that restriction doesn't help us here, because it only applies when `lower` and `upper` _are_ typevars, not when they _contain_ typevars. That distinction is important, since it means the restriction does not affect our expressiveness: we can always rewrite `Never ≤ T ≤ U` (a constraint on `T`) into `T ≤ U ≤ object` (a constraint on `U`). The same is not true of `Never ≤ T ≤ list[U]` — there is no "inverse" of `list` that we could apply to both sides to transform this into a constraint on a bare `U`.
1 parent 0d47334 commit 83134fb
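To make the fixed-point "apply the specialization to itself" step concrete, here is a minimal standalone sketch in plain Python. It is not ty's implementation: typevars are modeled as strings, generic types as tuples, and `substitute` / `apply_to_itself` are illustrative helpers only.

```py
def substitute(ty, assignments):
    """Replace any typevar names appearing in `ty` with their assigned types."""
    if isinstance(ty, str) and ty in assignments:  # a bare typevar, e.g. "T"
        return assignments[ty]
    if isinstance(ty, tuple):  # a generic type, e.g. ("list", "T") standing in for list[T]
        return tuple(substitute(arg, assignments) for arg in ty)
    return ty  # a concrete type, e.g. "int"


def apply_to_itself(assignments):
    """Apply the specialization to itself until a fixed point is reached."""
    while True:
        updated = {var: substitute(ty, assignments) for var, ty in assignments.items()}
        if updated == assignments:
            return updated
        assignments = updated


# From `T ≤ int ∧ U = list[T]`, the first inference pass assigns T = int, U = list[T];
# one round of self-application then propagates T's assignment into U's.
print(apply_to_itself({"T": "int", "U": ("list", "T")}))
# {'T': 'int', 'U': ('list', 'int')}
```

On a cyclic input such as `T = list[U] ∧ U = list[T]`, this loop would never terminate (the substituted types keep growing), which is why the commit runs a cycle-detection check before attempting the fixed point.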

File tree

3 files changed: +179 -5 lines changed


crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md

Lines changed: 30 additions & 0 deletions
@@ -303,3 +303,33 @@ def mutually_bound[T: Base, U]():
     # revealed: ty_extensions.Specialization[T@mutually_bound = Base, U@mutually_bound = Sub]
     reveal_type(generic_context(mutually_bound).specialize_constrained(ConstraintSet.range(Never, U, Sub) & ConstraintSet.range(Never, U, T)))
 ```
+
+## Nested typevars
+
+A typevar's constraint can _mention_ another typevar without _constraining_ it. In this example, `U`
+must be specialized to `list[T]`, but it cannot affect what `T` is specialized to.
+
+```py
+from typing import Never
+from ty_extensions import ConstraintSet, generic_context
+
+def mentions[T, U]():
+    constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(list[T], U, list[T])
+    # revealed: ty_extensions.ConstraintSet[((T@mentions ≤ int) ∧ (U@mentions = list[T@mentions]))]
+    reveal_type(constraints)
+    # revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = list[int]]
+    reveal_type(generic_context(mentions).specialize_constrained(constraints))
+```
+
+If the constraint set contains mutually recursive bounds, specialization inference will not
+converge. This test ensures that our cycle detection prevents an endless loop or stack overflow in
+this case.
+
+```py
+def divergent[T, U]():
+    constraints = ConstraintSet.range(list[U], T, list[U]) & ConstraintSet.range(list[T], U, list[T])
+    # revealed: ty_extensions.ConstraintSet[((T@divergent = list[U@divergent]) ∧ (U@divergent = list[T@divergent]))]
+    reveal_type(constraints)
+    # revealed: None
+    reveal_type(generic_context(divergent).specialize_constrained(constraints))
+```

crates/ty_python_semantic/src/types/constraints.rs

Lines changed: 104 additions & 2 deletions
@@ -53,6 +53,7 @@
 //!
 //! [bdd]: https://en.wikipedia.org/wiki/Binary_decision_diagram

+use std::cell::RefCell;
 use std::cmp::Ordering;
 use std::fmt::Display;
 use std::ops::Range;
@@ -62,9 +63,10 @@ use rustc_hash::{FxHashMap, FxHashSet};
 use salsa::plumbing::AsId;

 use crate::types::generics::{GenericContext, InferableTypeVars, Specialization};
+use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard};
 use crate::types::{
     BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeRelation,
-    TypeVarBoundOrConstraints, UnionType,
+    TypeVarBoundOrConstraints, UnionType, walk_bound_type_var_type,
 };
 use crate::{Db, FxOrderSet};

@@ -213,6 +215,100 @@ impl<'db> ConstraintSet<'db> {
         self.node.is_always_satisfied(db)
     }

+    /// Returns whether this constraint set contains any cycles between typevars. If it does, then
+    /// we cannot create a specialization from this constraint set.
+    ///
+    /// We have restrictions in place that ensure that there are no cycles in the _lower and upper
+    /// bounds_ of each constraint, but it's still possible for a constraint to _mention_ another
+    /// typevar without _constraining_ it. For instance, `(T ≤ int) ∧ (U ≤ list[T])` is a valid
+    /// constraint set, which we can create a specialization from (`T = int, U = list[int]`). But
+    /// `(T ≤ list[U]) ∧ (U ≤ list[T])` does not violate our lower/upper bounds restrictions, since
+    /// neither bound _is_ a typevar. And it's not something we can create a specialization from,
+    /// since we would endlessly substitute until we stack overflow.
+    pub(crate) fn is_cyclic(self, db: &'db dyn Db) -> bool {
+        #[derive(Default)]
+        struct CollectReachability<'db> {
+            reachable_typevars: RefCell<FxHashSet<BoundTypeVarIdentity<'db>>>,
+            recursion_guard: TypeCollector<'db>,
+        }
+
+        impl<'db> TypeVisitor<'db> for CollectReachability<'db> {
+            fn should_visit_lazy_type_attributes(&self) -> bool {
+                true
+            }
+
+            fn visit_bound_type_var_type(
+                &self,
+                db: &'db dyn Db,
+                bound_typevar: BoundTypeVarInstance<'db>,
+            ) {
+                self.reachable_typevars
+                    .borrow_mut()
+                    .insert(bound_typevar.identity(db));
+                walk_bound_type_var_type(db, bound_typevar, self);
+            }
+
+            fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) {
+                walk_type_with_recursion_guard(db, ty, self, &self.recursion_guard);
+            }
+        }
+
+        fn visit_dfs<'db>(
+            reachable_typevars: &mut FxHashMap<
+                BoundTypeVarIdentity<'db>,
+                FxHashSet<BoundTypeVarIdentity<'db>>,
+            >,
+            discovered: &mut FxHashSet<BoundTypeVarIdentity<'db>>,
+            bound_typevar: BoundTypeVarIdentity<'db>,
+        ) -> bool {
+            discovered.insert(bound_typevar);
+            let outgoing = reachable_typevars
+                .remove(&bound_typevar)
+                .expect("should not visit typevar twice in DFS");
+            for outgoing in outgoing {
+                if discovered.contains(&outgoing) {
+                    return true;
+                }
+                if reachable_typevars.contains_key(&outgoing) {
+                    if visit_dfs(reachable_typevars, discovered, outgoing) {
+                        return true;
+                    }
+                }
+            }
+            discovered.remove(&bound_typevar);
+            false
+        }
+
+        // First find all of the typevars that each constraint directly mentions.
+        let mut reachable_typevars: FxHashMap<
+            BoundTypeVarIdentity<'db>,
+            FxHashSet<BoundTypeVarIdentity<'db>>,
+        > = FxHashMap::default();
+        self.node.for_each_constraint(db, &mut |constraint| {
+            let visitor = CollectReachability::default();
+            visitor.visit_type(db, constraint.lower(db));
+            visitor.visit_type(db, constraint.upper(db));
+            reachable_typevars
+                .entry(constraint.typevar(db).identity(db))
+                .or_default()
+                .extend(visitor.reachable_typevars.into_inner());
+        });
+
+        // Then perform a depth-first search to see if there are any cycles.
+        let mut discovered: FxHashSet<BoundTypeVarIdentity<'db>> = FxHashSet::default();
+        while let Some(bound_typevar) = reachable_typevars.keys().copied().next() {
+            if !discovered.contains(&bound_typevar) {
+                let cycle_found =
+                    visit_dfs(&mut reachable_typevars, &mut discovered, bound_typevar);
+                if cycle_found {
+                    return true;
+                }
+            }
+        }
+
+        false
+    }
+
     /// Returns the constraints under which `lhs` is a subtype of `rhs`, assuming that the
     /// constraints in this constraint set hold. Panics if neither of the types being compared are
     /// a typevar. (That case is handled by `Type::has_relation_to`.)
@@ -2964,6 +3060,12 @@ impl<'db> GenericContext<'db> {
         db: &'db dyn Db,
         constraints: ConstraintSet<'db>,
     ) -> Result<Specialization<'db>, ()> {
+        // If the constraint set is cyclic, don't even try to construct a specialization.
+        if constraints.is_cyclic(db) {
+            // TODO: Better error
+            return Err(());
+        }
+
         // First we intersect with the valid specializations of all of the typevars. We need all of
         // valid specializations to hold simultaneously, so we do this once before abstracting over
         // each typevar.
@@ -3020,7 +3122,7 @@ impl<'db> GenericContext<'db> {
             types[i] = least_upper_bound;
         }

-        Ok(self.specialize(db, types.into_boxed_slice()))
+        Ok(self.specialize_recursive(db, types.into_boxed_slice()))
     }
 }

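For readers who want the cycle check in isolation: the `is_cyclic` method above first builds a "which typevars does each constraint mention" graph and then looks for a cycle with a depth-first search. Below is a rough standalone model of that idea in plain Python; `has_cycle` and its string-keyed graph are illustrative only, not the Rust API.

```py
def has_cycle(mentions):
    """Return True if the typevar-mentions graph contains a cycle.

    `mentions` maps each constrained typevar to the set of typevars that its
    lower and upper bounds mention, directly or nested inside another type.
    """
    visiting, finished = set(), set()

    def dfs(var):
        visiting.add(var)
        for other in mentions.get(var, ()):
            if other in visiting:  # back edge: we found a cycle
                return True
            if other not in finished and dfs(other):
                return True
        visiting.discard(var)
        finished.add(var)
        return False

    return any(var not in finished and dfs(var) for var in mentions)


# `(T ≤ int) ∧ (U = list[T])`: U mentions T, but there is no cycle.
print(has_cycle({"T": set(), "U": {"T"}}))  # False
# `(T = list[U]) ∧ (U = list[T])`: T and U mention each other, so inference would diverge.
print(has_cycle({"T": {"U"}, "U": {"T"}}))  # True
```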

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 45 additions & 3 deletions
@@ -500,9 +500,16 @@ impl<'db> GenericContext<'db> {
     }

     /// Creates a specialization of this generic context. Panics if the length of `types` does not
-    /// match the number of typevars in the generic context. You must provide a specific type for
-    /// each typevar; no defaults are used. (Use [`specialize_partial`](Self::specialize_partial)
-    /// if you might not have types for every typevar.)
+    /// match the number of typevars in the generic context.
+    ///
+    /// You must provide a specific type for each typevar; no defaults are used. (Use
+    /// [`specialize_partial`](Self::specialize_partial) if you might not have types for every
+    /// typevar.)
+    ///
+    /// The types you provide should not mention any of the typevars in this generic context;
+    /// otherwise, you will be left with a partial specialization. (Use
+    /// [`specialize_recursive`](Self::specialize_recursive) if your types might mention typevars
+    /// in this generic context.)
     pub(crate) fn specialize(
         self,
         db: &'db dyn Db,
@@ -512,6 +519,41 @@ impl<'db> GenericContext<'db> {
         Specialization::new(db, self, types, None, None)
     }

+    /// Creates a specialization of this generic context. Panics if the length of `types` does not
+    /// match the number of typevars in the generic context.
+    ///
+    /// You are allowed to provide types that mention the typevars in this generic context.
+    pub(crate) fn specialize_recursive(
+        self,
+        db: &'db dyn Db,
+        mut types: Box<[Type<'db>]>,
+    ) -> Specialization<'db> {
+        let len = types.len();
+        assert!(self.len(db) == len);
+        loop {
+            let mut any_changed = false;
+            for i in 0..len {
+                let partial = PartialSpecialization {
+                    generic_context: self,
+                    types: &types,
+                };
+                let updated = types[i].apply_type_mapping(
+                    db,
+                    &TypeMapping::PartialSpecialization(partial),
+                    TypeContext::default(),
+                );
+                if updated != types[i] {
+                    types[i] = updated;
+                    any_changed = true;
+                }
+            }
+
+            if !any_changed {
+                return Specialization::new(db, self, types, None, None);
+            }
+        }
+    }
+
     /// Creates a specialization of this generic context for the `tuple` class.
     pub(crate) fn specialize_tuple(
         self,
