Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,33 @@ def mutually_bound[T: Base, U]():
# revealed: ty_extensions.Specialization[T@mutually_bound = Base, U@mutually_bound = Sub]
reveal_type(generic_context(mutually_bound).specialize_constrained(ConstraintSet.range(Never, U, Sub) & ConstraintSet.range(Never, U, T)))
```

## Nested typevars

A typevar's constraint can _mention_ another typevar without _constraining_ it. In this example, `U`
must be specialized to `list[T]`, but it cannot affect what `T` is specialized to.

```py
from typing import Never
from ty_extensions import ConstraintSet, generic_context

def mentions[T, U]():
constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(list[T], U, list[T])
# revealed: ty_extensions.ConstraintSet[((T@mentions ≤ int) ∧ (U@mentions = list[T@mentions]))]
reveal_type(constraints)
# revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = list[int]]
reveal_type(generic_context(mentions).specialize_constrained(constraints))
```

If the constraint set contains mutually recursive bounds, specialization inference will not
converge. This test ensures that our cycle detection prevents an endless loop or stack overflow in
this case.

```py
def divergent[T, U]():
constraints = ConstraintSet.range(list[U], T, list[U]) & ConstraintSet.range(list[T], U, list[T])
# revealed: ty_extensions.ConstraintSet[((T@divergent = list[U@divergent]) ∧ (U@divergent = list[T@divergent]))]
reveal_type(constraints)
# revealed: None
reveal_type(generic_context(divergent).specialize_constrained(constraints))
```
110 changes: 108 additions & 2 deletions crates/ty_python_semantic/src/types/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
//!
//! [bdd]: https://en.wikipedia.org/wiki/Binary_decision_diagram

use std::cell::RefCell;
use std::cmp::Ordering;
use std::fmt::Display;
use std::ops::Range;
Expand All @@ -62,9 +63,10 @@ use rustc_hash::{FxHashMap, FxHashSet};
use salsa::plumbing::AsId;

use crate::types::generics::{GenericContext, InferableTypeVars, Specialization};
use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard};
use crate::types::{
BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeRelation,
TypeVarBoundOrConstraints, UnionType,
TypeVarBoundOrConstraints, UnionType, walk_bound_type_var_type,
};
use crate::{Db, FxOrderSet};

Expand Down Expand Up @@ -213,6 +215,104 @@ impl<'db> ConstraintSet<'db> {
self.node.is_always_satisfied(db)
}

/// Returns whether this constraint set contains any cycles between typevars. If it does, then
/// we cannot create a specialization from this constraint set.
///
/// We have restrictions in place that ensure that there are no cycles in the _lower and upper
/// bounds_ of each constraint, but it's still possible for a constraint to _mention_ another
/// typevar without _constraining_ it. For instance, `(T ≤ int) ∧ (U ≤ list[T])` is a valid
/// constraint set, which we can create a specialization from (`T = int, U = list[int]`). But
/// `(T ≤ list[U]) ∧ (U ≤ list[T])` does not violate our lower/upper bounds restrictions, since
/// neither bound _is_ a typevar. And it's not something we can create a specialization from,
/// since we would endlessly substitute until we stack overflow.
pub(crate) fn is_cyclic(self, db: &'db dyn Db) -> bool {
    /// Visitor that records the identity of every typevar reachable from the types it visits.
    #[derive(Default)]
    struct CollectReachability<'db> {
        // `RefCell` because the visitor trait methods take `&self`.
        reachable_typevars: RefCell<FxHashSet<BoundTypeVarIdentity<'db>>>,
        recursion_guard: TypeCollector<'db>,
    }

    impl<'db> TypeVisitor<'db> for CollectReachability<'db> {
        fn should_visit_lazy_type_attributes(&self) -> bool {
            true
        }

        fn visit_bound_type_var_type(
            &self,
            db: &'db dyn Db,
            bound_typevar: BoundTypeVarInstance<'db>,
        ) {
            self.reachable_typevars
                .borrow_mut()
                .insert(bound_typevar.identity(db));
            walk_bound_type_var_type(db, bound_typevar, self);
        }

        fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) {
            walk_type_with_recursion_guard(db, ty, self, &self.recursion_guard);
        }
    }

    /// Depth-first search for a cycle in the typevar reachability graph.
    ///
    /// `discovered` holds the typevars on the current DFS path — encountering one of them
    /// again is a back-edge, i.e. a cycle. `finished` holds typevars whose entire reachable
    /// subgraph has already been verified acyclic, so they never need to be revisited.
    fn visit_dfs<'db>(
        reachable_typevars: &FxHashMap<
            BoundTypeVarIdentity<'db>,
            FxHashSet<BoundTypeVarIdentity<'db>>,
        >,
        discovered: &mut FxHashSet<BoundTypeVarIdentity<'db>>,
        finished: &mut FxHashSet<BoundTypeVarIdentity<'db>>,
        bound_typevar: BoundTypeVarIdentity<'db>,
    ) -> bool {
        discovered.insert(bound_typevar);
        for outgoing in reachable_typevars.get(&bound_typevar).into_iter().flatten() {
            if discovered.contains(outgoing) {
                // Back-edge to a typevar on the current path: cycle found.
                return true;
            }
            if !finished.contains(outgoing)
                && visit_dfs(reachable_typevars, discovered, finished, *outgoing)
            {
                return true;
            }
        }
        discovered.remove(&bound_typevar);
        finished.insert(bound_typevar);
        false
    }

    // First find all of the typevars that each constraint directly mentions.
    let mut reachable_typevars: FxHashMap<
        BoundTypeVarIdentity<'db>,
        FxHashSet<BoundTypeVarIdentity<'db>>,
    > = FxHashMap::default();
    self.node.for_each_constraint(db, &mut |constraint| {
        let visitor = CollectReachability::default();
        visitor.visit_type(db, constraint.lower(db));
        visitor.visit_type(db, constraint.upper(db));
        reachable_typevars
            .entry(constraint.typevar(db).identity(db))
            .or_default()
            .extend(visitor.reachable_typevars.into_inner());
    });

    // Then perform a depth-first search to see if there are any cycles.
    let mut discovered: FxHashSet<BoundTypeVarIdentity<'db>> = FxHashSet::default();
    let mut finished: FxHashSet<BoundTypeVarIdentity<'db>> = FxHashSet::default();
    for bound_typevar in reachable_typevars.keys() {
        if !discovered.contains(bound_typevar) && !finished.contains(bound_typevar) {
            let cycle_found = visit_dfs(
                &reachable_typevars,
                &mut discovered,
                &mut finished,
                *bound_typevar,
            );
            if cycle_found {
                return true;
            }
        }
    }

    false
}

/// Returns the constraints under which `lhs` is a subtype of `rhs`, assuming that the
/// constraints in this constraint set hold. Panics if neither of the types being compared are
/// a typevar. (That case is handled by `Type::has_relation_to`.)
Expand Down Expand Up @@ -2964,6 +3064,12 @@ impl<'db> GenericContext<'db> {
db: &'db dyn Db,
constraints: ConstraintSet<'db>,
) -> Result<Specialization<'db>, ()> {
// If the constraint set is cyclic, don't even try to construct a specialization.
if constraints.is_cyclic(db) {
// TODO: Better error
return Err(());
}

// First we intersect with the valid specializations of all of the typevars. We need all of
// valid specializations to hold simultaneously, so we do this once before abstracting over
// each typevar.
Expand Down Expand Up @@ -3020,7 +3126,7 @@ impl<'db> GenericContext<'db> {
types[i] = least_upper_bound;
}

Ok(self.specialize(db, types.into_boxed_slice()))
Ok(self.specialize_recursive(db, types.into_boxed_slice()))
}
}

Expand Down
48 changes: 45 additions & 3 deletions crates/ty_python_semantic/src/types/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -500,9 +500,16 @@ impl<'db> GenericContext<'db> {
}

/// Creates a specialization of this generic context. Panics if the length of `types` does not
/// match the number of typevars in the generic context. You must provide a specific type for
/// each typevar; no defaults are used. (Use [`specialize_partial`](Self::specialize_partial)
/// if you might not have types for every typevar.)
/// match the number of typevars in the generic context.
///
/// You must provide a specific type for each typevar; no defaults are used. (Use
/// [`specialize_partial`](Self::specialize_partial) if you might not have types for every
/// typevar.)
///
/// The types you provide should not mention any of the typevars in this generic context;
/// otherwise, you will be left with a partial specialization. (Use
/// [`specialize_recursive`](Self::specialize_recursive) if your types might mention typevars
/// in this generic context.)
pub(crate) fn specialize(
self,
db: &'db dyn Db,
Expand All @@ -512,6 +519,41 @@ impl<'db> GenericContext<'db> {
Specialization::new(db, self, types, None, None)
}

/// Creates a specialization of this generic context. Panics if the length of `types` does not
/// match the number of typevars in the generic context.
///
/// You are allowed to provide types that mention the typevars in this generic context.
pub(crate) fn specialize_recursive(
    self,
    db: &'db dyn Db,
    mut types: Box<[Type<'db>]>,
) -> Specialization<'db> {
    let len = types.len();
    assert!(self.len(db) == len);
    // Iterate to a fixed point: keep substituting the current types back into themselves
    // until a full pass leaves every entry unchanged. The caller is responsible for
    // ensuring the substitution terminates (e.g. via a cycle check on the constraints).
    loop {
        let mut changed = false;
        for index in 0..len {
            let partial = PartialSpecialization {
                generic_context: self,
                types: &types,
            };
            let mapping = TypeMapping::PartialSpecialization(partial);
            let substituted =
                types[index].apply_type_mapping(db, &mapping, TypeContext::default());
            if substituted != types[index] {
                types[index] = substituted;
                changed = true;
            }
        }

        if !changed {
            // Stable: no typevar mentions remain that we can resolve further.
            break Specialization::new(db, self, types, None, None);
        }
    }
}

/// Creates a specialization of this generic context for the `tuple` class.
pub(crate) fn specialize_tuple(
self,
Expand Down
Loading