@@ -94,7 +94,9 @@ use crate::types::mro::MroErrorKind;
9494use crate::types::newtype::NewType;
9595use crate::types::signatures::Signature;
9696use crate::types::subclass_of::SubclassOfInner;
97- use crate::types::tuple::{Tuple, TupleLength, TupleSpec, TupleType};
97+ use crate::types::tuple::{
98+ Tuple, TupleLength, TupleSpec, TupleSpecBuilder, TupleType, VariableLengthTuple,
99+ };
98100use crate::types::typed_dict::{
99101 TypedDictAssignmentKind, validate_typed_dict_constructor, validate_typed_dict_dict_literal,
100102 validate_typed_dict_key_assignment,
@@ -6926,7 +6928,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
69266928 ast::Expr::If(if_expression) => self.infer_if_expression(if_expression, tcx),
69276929 ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression),
69286930 ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx),
6929- ast::Expr::Starred(starred) => self.infer_starred_expression(starred),
6931+ ast::Expr::Starred(starred) => self.infer_starred_expression(starred, tcx ),
69306932 ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression),
69316933 ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from),
69326934 ast::Expr::Await(await_expression) => self.infer_await_expression(await_expression),
@@ -7151,25 +7153,66 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
71517153 )
71527154 });
71537155
7156+ let mut is_homogeneous_tuple_annotation = false;
7157+
71547158 let annotated_tuple = tcx
71557159 .known_specialization(self.db(), KnownClass::Tuple)
71567160 .and_then(|specialization| {
7157- specialization
7161+ let spec = specialization
71587162 .tuple(self.db())
7159- .expect("the specialization of `KnownClass::Tuple` must have a tuple spec")
7160- .resize(self.db(), TupleLength::Fixed(elts.len()))
7161- .ok()
7163+ .expect("the specialization of `KnownClass::Tuple` must have a tuple spec");
7164+
7165+ if matches!(
7166+ spec,
7167+ Tuple::Variable(VariableLengthTuple { prefix, variable: _, suffix})
7168+ if prefix.is_empty() && suffix.is_empty()
7169+ ) {
7170+ is_homogeneous_tuple_annotation = true;
7171+ }
7172+
7173+ spec.resize(self.db(), TupleLength::Fixed(elts.len())).ok()
71627174 });
71637175
71647176 let mut annotated_elt_tys = annotated_tuple.as_ref().map(Tuple::all_elements);
71657177
71667178 let db = self.db();
7167- let element_types = elts.iter().map(|element| {
7168- let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied();
7169- self.infer_expression(element, TypeContext::new(annotated_elt_ty))
7170- });
71717179
7172- Type::heterogeneous_tuple(db, element_types)
7180+ let can_use_type_context =
7181+ is_homogeneous_tuple_annotation || elts.iter().all(|elt| !elt.is_starred_expr());
7182+
7183+ let mut infer_element = |elt: &ast::Expr| {
7184+ if can_use_type_context {
7185+ let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied();
7186+ let context = if let ast::Expr::Starred(starred) = elt {
7187+ annotated_elt_ty
7188+ .map(|expected_element_type| {
7189+ TypeContext::for_starred_expression(db, expected_element_type, starred)
7190+ })
7191+ .unwrap_or_default()
7192+ } else {
7193+ TypeContext::new(annotated_elt_ty)
7194+ };
7195+ self.infer_expression(elt, context)
7196+ } else {
7197+ self.infer_expression(elt, TypeContext::default())
7198+ }
7199+ };
7200+
7201+ let mut builder = TupleSpecBuilder::with_capacity(elts.len());
7202+
7203+ for element in elts {
7204+ if element.is_starred_expr() {
7205+ let element_type = infer_element(element);
7206+ // Fine to use `iterate` rather than `try_iterate` here:
7207+ // errors from iterating over something not iterable will have been
7208+ // emitted in the `infer_element` call above.
7209+ builder = builder.concat(db, &element_type.iterate(db));
7210+ } else {
7211+ builder.push(infer_element(element).fallback_to_divergent(db));
7212+ }
7213+ }
7214+
7215+ Type::tuple(TupleType::new(db, &builder.build()))
71737216 }
71747217
71757218 fn infer_list_expression(&mut self, list: &ast::ExprList, tcx: TypeContext<'db>) -> Type<'db> {
@@ -7326,7 +7369,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
73267369
73277370 let inferable = generic_context.inferable_typevars(self.db());
73287371
7329- // Remove any union elements of that are unrelated to the collection type.
7372+ // Remove any union elements of the annotation that are unrelated to the collection type.
73307373 //
73317374 // For example, we only want the `list[int]` from `annotation: list[int] | None` if
73327375 // `collection_ty` is `list`.
@@ -7366,8 +7409,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
73667409 }
73677410
73687411 let elt_tcxs = match annotated_elt_tys {
7369- None => Either::Left(iter::repeat(TypeContext::default() )),
7370- Some(tys) => Either::Right(tys.iter().map(|ty| TypeContext::new( Some(*ty)) )),
7412+ None => Either::Left(iter::repeat(None )),
7413+ Some(tys) => Either::Right(tys.iter().copied(). map(Some)),
73717414 };
73727415
73737416 for elts in elts {
@@ -7396,6 +7439,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
73967439 {
73977440 let Some(elt) = elt else { continue };
73987441
7442+ let elt_tcx = if let ast::Expr::Starred(starred) = elt {
7443+ elt_tcx
7444+ .map(|ty| TypeContext::for_starred_expression(self.db(), ty, starred))
7445+ .unwrap_or_default()
7446+ } else {
7447+ TypeContext::new(elt_tcx)
7448+ };
7449+
73997450 let inferred_elt_ty = infer_elt_expression(self, elt, elt_tcx);
74007451
74017452 // Simplify the inference based on the declared type of the element.
@@ -7409,7 +7460,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
74097460 // unions for large nested list literals, which the constraint solver struggles with.
74107461 let inferred_elt_ty = inferred_elt_ty.promote_literals(self.db(), elt_tcx);
74117462
7412- builder.infer(Type::TypeVar(elt_ty), inferred_elt_ty).ok()?;
7463+ builder
7464+ .infer(
7465+ Type::TypeVar(elt_ty),
7466+ if elt.is_starred_expr() {
7467+ inferred_elt_ty
7468+ .iterate(self.db())
7469+ .homogeneous_element_type(self.db())
7470+ } else {
7471+ inferred_elt_ty
7472+ },
7473+ )
7474+ .ok()?;
74137475 }
74147476 }
74157477
@@ -8204,25 +8266,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
82048266 }
82058267 }
82068268
8207- fn infer_starred_expression(&mut self, starred: &ast::ExprStarred) -> Type<'db> {
8269+ fn infer_starred_expression(
8270+ &mut self,
8271+ starred: &ast::ExprStarred,
8272+ tcx: TypeContext<'db>,
8273+ ) -> Type<'db> {
82088274 let ast::ExprStarred {
82098275 range: _,
82108276 node_index: _,
82118277 value,
82128278 ctx: _,
82138279 } = starred;
82148280
8215- let iterable_type = self.infer_expression(value, TypeContext::default());
8281+ let db = self.db();
8282+ let iterable_type = self.infer_expression(value, tcx);
8283+
82168284 iterable_type
8217- .try_iterate(self.db() )
8218- .map(|tuple| tuple.homogeneous_element_type(self.db( )))
8285+ .try_iterate(db )
8286+ .map(|spec| Type:: tuple(TupleType::new(db, &spec )))
82198287 .unwrap_or_else(|err| {
82208288 err.report_diagnostic(&self.context, iterable_type, value.as_ref().into());
8221- err.fallback_element_type(self.db())
8222- });
8223-
8224- // TODO
8225- todo_type!("starred expression")
8289+ Type::homogeneous_tuple(db, err.fallback_element_type(db))
8290+ })
82268291 }
82278292
82288293 fn infer_yield_expression(&mut self, yield_expression: &ast::ExprYield) -> Type<'db> {
0 commit comments