Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions crates/ty_python_semantic/resources/mdtest/call/union.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,17 +227,22 @@ def _(literals_2: Literal[0, 1], b: bool, flag: bool):
literals_16 = 4 * literals_4 + literals_4 # Literal[0, 1, .., 15]
literals_64 = 4 * literals_16 + literals_4 # Literal[0, 1, .., 63]
literals_128 = 2 * literals_64 + literals_2 # Literal[0, 1, .., 127]
literals_256 = 2 * literals_128 + literals_2 # Literal[0, 1, .., 255]

# Going beyond the MAX_UNION_LITERALS limit (currently 200):
literals_256 = 16 * literals_16 + literals_16
reveal_type(literals_256) # revealed: int
# Going beyond the MAX_UNION_LITERALS limit (currently 512):
literals_512 = 2 * literals_256 + literals_2 # Literal[0, 1, .., 511]
reveal_type(literals_512 if flag else 512) # revealed: int

# Going beyond the limit when another type is already part of the union
bool_and_literals_128 = b if flag else literals_128 # bool | Literal[0, 1, ..., 127]
literals_128_shifted = literals_128 + 128 # Literal[128, 129, ..., 255]
literals_256_shifted = literals_256 + 256 # Literal[256, 257, ..., 511]

# Now union the two:
reveal_type(bool_and_literals_128 if flag else literals_128_shifted) # revealed: int
two = bool_and_literals_128 if flag else literals_128_shifted
# revealed: bool | Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]
reveal_type(two)
reveal_type(two if flag else literals_256_shifted) # revealed: int
```

## Simplifying gradually-equivalent types
Expand Down
45 changes: 39 additions & 6 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use crate::semantic_index::scope::ScopeId;
use crate::semantic_index::{imported_modules, place_table, semantic_index};
use crate::suppression::check_suppressions;
use crate::types::bound_super::BoundSuperType;
use crate::types::builder::RecursivelyDefined;
use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding};
pub(crate) use crate::types::class_base::ClassBase;
use crate::types::constraints::{
Expand Down Expand Up @@ -9514,6 +9515,7 @@ impl<'db> TypeVarInstance<'db> {
.skip(1)
.map(|arg| definition_expression_type(db, definition, arg))
.collect::<Box<_>>(),
RecursivelyDefined::No,
)
}
_ => return None,
Expand Down Expand Up @@ -9944,6 +9946,7 @@ impl<'db> TypeVarBoundOrConstraints<'db> {
.iter()
.map(|ty| ty.normalized_impl(db, visitor))
.collect::<Box<_>>(),
constraints.recursively_defined(db),
))
}
}
Expand All @@ -9967,6 +9970,7 @@ impl<'db> TypeVarBoundOrConstraints<'db> {
.iter()
.map(|ty| ty.materialize(db, materialization_kind, visitor))
.collect::<Box<_>>(),
RecursivelyDefined::No,
))
}
}
Expand Down Expand Up @@ -12922,6 +12926,9 @@ pub struct UnionType<'db> {
/// The union type includes values in any of these types.
#[returns(deref)]
pub elements: Box<[Type<'db>]>,
/// Whether the value pointed to by this type is recursively defined.
/// If `Yes`, union literal widening is performed early.
recursively_defined: RecursivelyDefined,
}

pub(crate) fn walk_union<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
Expand Down Expand Up @@ -13006,7 +13013,14 @@ impl<'db> UnionType<'db> {
db: &'db dyn Db,
transform_fn: impl FnMut(&Type<'db>) -> Type<'db>,
) -> Type<'db> {
Self::from_elements(db, self.elements(db).iter().map(transform_fn))
self.elements(db)
.iter()
.map(transform_fn)
.fold(UnionBuilder::new(db), |builder, element| {
builder.add(element)
})
.recursively_defined(self.recursively_defined(db))
.build()
}

/// A fallible version of [`UnionType::map`].
Expand All @@ -13021,7 +13035,12 @@ impl<'db> UnionType<'db> {
db: &'db dyn Db,
transform_fn: impl FnMut(&Type<'db>) -> Option<Type<'db>>,
) -> Option<Type<'db>> {
Self::try_from_elements(db, self.elements(db).iter().map(transform_fn))
let mut builder = UnionBuilder::new(db);
for element in self.elements(db).iter().map(transform_fn) {
builder = builder.add(element?);
}
builder = builder.recursively_defined(self.recursively_defined(db));
Some(builder.build())
}

pub(crate) fn to_instance(self, db: &'db dyn Db) -> Option<Type<'db>> {
Expand All @@ -13033,7 +13052,14 @@ impl<'db> UnionType<'db> {
db: &'db dyn Db,
mut f: impl FnMut(&Type<'db>) -> bool,
) -> Type<'db> {
Self::from_elements(db, self.elements(db).iter().filter(|ty| f(ty)))
self.elements(db)
.iter()
.filter(|ty| f(ty))
.fold(UnionBuilder::new(db), |builder, element| {
builder.add(*element)
})
.recursively_defined(self.recursively_defined(db))
.build()
}

pub(crate) fn map_with_boundness(
Expand Down Expand Up @@ -13068,7 +13094,9 @@ impl<'db> UnionType<'db> {
Place::Undefined
} else {
Place::Defined(
builder.build(),
builder
.recursively_defined(self.recursively_defined(db))
.build(),
origin,
if possibly_unbound {
Definedness::PossiblyUndefined
Expand Down Expand Up @@ -13116,7 +13144,9 @@ impl<'db> UnionType<'db> {
Place::Undefined
} else {
Place::Defined(
builder.build(),
builder
.recursively_defined(self.recursively_defined(db))
.build(),
origin,
if possibly_unbound {
Definedness::PossiblyUndefined
Expand Down Expand Up @@ -13151,6 +13181,7 @@ impl<'db> UnionType<'db> {
.unpack_aliases(true),
UnionBuilder::add,
)
.recursively_defined(self.recursively_defined(db))
.build()
}

Expand All @@ -13164,7 +13195,8 @@ impl<'db> UnionType<'db> {
let mut builder = UnionBuilder::new(db)
.order_elements(false)
.unpack_aliases(false)
.cycle_recovery(true);
.cycle_recovery(true)
.recursively_defined(self.recursively_defined(db));
let mut empty = true;
for ty in self.elements(db) {
if nested {
Expand All @@ -13179,6 +13211,7 @@ impl<'db> UnionType<'db> {
// `Divergent` in a union type does not mean true divergence, so we skip it if not nested.
// e.g. T | Divergent == T | (T | (T | (T | ...))) == T
if ty == &div {
builder = builder.recursively_defined(RecursivelyDefined::Yes);
continue;
}
builder = builder.add(
Expand Down
89 changes: 80 additions & 9 deletions crates/ty_python_semantic/src/types/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,30 @@ enum ReduceResult<'db> {
Type(Type<'db>),
}

// TODO increase this once we extend `UnionElement` throughout all union/intersection
// representations, so that we can make large unions of literals fast in all operations.
//
// For now (until we solve https://github.com/astral-sh/ty/issues/957), keep this number
// below 200, which is the salsa fixpoint iteration limit.
const MAX_UNION_LITERALS: usize = 190;
/// Flag recording whether a value is recursively defined.
///
/// The flag is carried by union builders and stored on the union types they produce.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, get_size2::GetSize)]
pub enum RecursivelyDefined {
/// The value is recursively defined.
Yes,
/// The value is not recursively defined.
No,
}

impl RecursivelyDefined {
const fn is_yes(self) -> bool {
matches!(self, RecursivelyDefined::Yes)
}

const fn or(self, other: RecursivelyDefined) -> RecursivelyDefined {
match (self, other) {
(RecursivelyDefined::Yes, _) | (_, RecursivelyDefined::Yes) => RecursivelyDefined::Yes,
_ => RecursivelyDefined::No,
}
}
}

/// If the value is defined recursively, widening is performed from fewer literal elements,
/// resulting in faster convergence of the fixed-point iteration.
const MAX_RECURSIVE_UNION_LITERALS: usize = 10;
/// If the value is defined non-recursively, the fixed-point iteration will converge in one go,
/// so in principle we can have as many literal elements as we want, but to avoid unintended
/// huge computational loads, we limit it to 256.
const MAX_NON_RECURSIVE_UNION_LITERALS: usize = 256;

pub(crate) struct UnionBuilder<'db> {
elements: Vec<UnionElement<'db>>,
Expand All @@ -217,6 +235,7 @@ pub(crate) struct UnionBuilder<'db> {
// This is enabled when joining types in a `cycle_recovery` function.
// Since a cycle cannot be created within a `cycle_recovery` function, execution of `is_redundant_with` is skipped.
cycle_recovery: bool,
recursively_defined: RecursivelyDefined,
}

impl<'db> UnionBuilder<'db> {
Expand All @@ -227,6 +246,7 @@ impl<'db> UnionBuilder<'db> {
unpack_aliases: true,
order_elements: false,
cycle_recovery: false,
recursively_defined: RecursivelyDefined::No,
}
}

Expand All @@ -248,6 +268,11 @@ impl<'db> UnionBuilder<'db> {
self
}

/// Sets the builder's `recursively_defined` flag to `val`, replacing any
/// previous setting, and hands the builder back so calls can be chained.
pub(crate) fn recursively_defined(self, val: RecursivelyDefined) -> Self {
    let mut builder = self;
    builder.recursively_defined = val;
    builder
}

/// Returns `true` if this builder currently holds no elements.
pub(crate) fn is_empty(&self) -> bool {
self.elements.is_empty()
}
Expand All @@ -258,6 +283,27 @@ impl<'db> UnionBuilder<'db> {
self.elements.push(UnionElement::Type(Type::object()));
}

/// Adds, for every literal group currently held by the builder, the instance
/// type of the matching builtin class: `int` for integer literals, `str` for
/// string literals and `bytes` for bytes literals. Non-literal elements
/// contribute nothing.
///
/// The replacement types are collected up front because adding a type needs
/// mutable access to the builder while the element list is still being read.
fn widen_literal_types(&mut self, seen_aliases: &mut Vec<Type<'db>>) {
    let db = self.db;
    let widened: Vec<Type<'db>> = self
        .elements
        .iter()
        .filter_map(|element| {
            let class = match element {
                UnionElement::IntLiterals(_) => KnownClass::Int,
                UnionElement::StringLiterals(_) => KnownClass::Str,
                UnionElement::BytesLiterals(_) => KnownClass::Bytes,
                UnionElement::Type(_) => return None,
            };
            Some(class.to_instance(db))
        })
        .collect();

    for instance_ty in widened {
        self.add_in_place_impl(instance_ty, seen_aliases);
    }
}

/// Adds a type to this union.
pub(crate) fn add(mut self, ty: Type<'db>) -> Self {
self.add_in_place(ty);
Expand All @@ -270,13 +316,36 @@ impl<'db> UnionBuilder<'db> {
}

pub(crate) fn add_in_place_impl(&mut self, ty: Type<'db>, seen_aliases: &mut Vec<Type<'db>>) {
let cycle_recovery = self.cycle_recovery;
let should_widen = |literals, recursively_defined: RecursivelyDefined| {
if recursively_defined.is_yes() && cycle_recovery {
literals >= MAX_RECURSIVE_UNION_LITERALS
} else {
literals >= MAX_NON_RECURSIVE_UNION_LITERALS
}
};

match ty {
Type::Union(union) => {
let new_elements = union.elements(self.db);
self.elements.reserve(new_elements.len());
for element in new_elements {
self.add_in_place_impl(*element, seen_aliases);
}
self.recursively_defined = self
.recursively_defined
.or(union.recursively_defined(self.db));
if self.cycle_recovery && self.recursively_defined.is_yes() {
let literals = self.elements.iter().fold(0, |acc, elem| match elem {
UnionElement::IntLiterals(literals) => acc + literals.len(),
UnionElement::StringLiterals(literals) => acc + literals.len(),
UnionElement::BytesLiterals(literals) => acc + literals.len(),
UnionElement::Type(_) => acc,
});
if should_widen(literals, self.recursively_defined) {
self.widen_literal_types(seen_aliases);
}
}
}
// Adding `Never` to a union is a no-op.
Type::Never => {}
Expand All @@ -300,7 +369,7 @@ impl<'db> UnionBuilder<'db> {
for (index, element) in self.elements.iter_mut().enumerate() {
match element {
UnionElement::StringLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
if should_widen(literals.len(), self.recursively_defined) {
let replace_with = KnownClass::Str.to_instance(self.db);
self.add_in_place_impl(replace_with, seen_aliases);
return;
Expand Down Expand Up @@ -345,7 +414,7 @@ impl<'db> UnionBuilder<'db> {
for (index, element) in self.elements.iter_mut().enumerate() {
match element {
UnionElement::BytesLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
if should_widen(literals.len(), self.recursively_defined) {
let replace_with = KnownClass::Bytes.to_instance(self.db);
self.add_in_place_impl(replace_with, seen_aliases);
return;
Expand Down Expand Up @@ -390,7 +459,7 @@ impl<'db> UnionBuilder<'db> {
for (index, element) in self.elements.iter_mut().enumerate() {
match element {
UnionElement::IntLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
if should_widen(literals.len(), self.recursively_defined) {
let replace_with = KnownClass::Int.to_instance(self.db);
self.add_in_place_impl(replace_with, seen_aliases);
return;
Expand Down Expand Up @@ -585,6 +654,7 @@ impl<'db> UnionBuilder<'db> {
_ => Some(Type::Union(UnionType::new(
self.db,
types.into_boxed_slice(),
self.recursively_defined,
))),
}
}
Expand Down Expand Up @@ -696,6 +766,7 @@ impl<'db> IntersectionBuilder<'db> {
enum_member_literals(db, instance.class_literal(db), None)
.expect("Calling `enum_member_literals` on an enum class")
.collect::<Box<[_]>>(),
RecursivelyDefined::No,
)),
seen_aliases,
)
Expand Down
2 changes: 2 additions & 0 deletions crates/ty_python_semantic/src/types/infer/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ use crate::semantic_index::{
ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table,
};
use crate::subscript::{PyIndex, PySlice};
use crate::types::builder::RecursivelyDefined;
use crate::types::call::bind::{CallableDescription, MatchingOverloadIndex};
use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind};
use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator};
Expand Down Expand Up @@ -3269,6 +3270,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
elts.iter()
.map(|expr| self.infer_type_expression(expr))
.collect::<Box<[_]>>(),
RecursivelyDefined::No,
));
self.store_expression_type(expr, ty);
}
Expand Down
8 changes: 6 additions & 2 deletions crates/ty_python_semantic/src/types/subclass_of.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ impl<'db> SubclassOfInner<'db> {
)
}
Some(TypeVarBoundOrConstraints::Constraints(constraints)) => {
let constraints = constraints
let constraints_types = constraints
.elements(db)
.iter()
.map(|constraint| {
Expand All @@ -426,7 +426,11 @@ impl<'db> SubclassOfInner<'db> {
})
.collect::<Box<_>>();

TypeVarBoundOrConstraints::Constraints(UnionType::new(db, constraints))
TypeVarBoundOrConstraints::Constraints(UnionType::new(
db,
constraints_types,
constraints.recursively_defined(db),
))
}
})
});
Expand Down
3 changes: 2 additions & 1 deletion crates/ty_python_semantic/src/types/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use itertools::{Either, EitherOrBoth, Itertools};

use crate::semantic_index::definition::Definition;
use crate::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError};
use crate::types::builder::RecursivelyDefined;
use crate::types::class::{ClassType, KnownClass};
use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension};
use crate::types::generics::InferableTypeVars;
Expand Down Expand Up @@ -1462,7 +1463,7 @@ impl<'db> Tuple<Type<'db>> {
// those techniques ensure that union elements are deduplicated and unions are eagerly simplified
// into other types where necessary. Here, however, we know that there are no duplicates
// in this union, so it's probably more efficient to use `UnionType::new()` directly.
Type::Union(UnionType::new(db, elements))
Type::Union(UnionType::new(db, elements, RecursivelyDefined::No))
};

TupleSpec::heterogeneous([
Expand Down
Loading