16
16
from mypy .messages import MessageBuilder
17
17
from mypy .nodes import (
18
18
ARG_POS ,
19
+ AssignmentExpr ,
19
20
Context ,
20
21
Expression ,
21
22
IndexExpr ,
22
23
IntExpr ,
24
+ ListExpr ,
23
25
MemberExpr ,
24
26
NameExpr ,
27
+ TupleExpr ,
25
28
TypeAlias ,
26
29
TypeInfo ,
27
30
UnaryExpr ,
@@ -194,7 +197,8 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
194
197
capture_expr_keys : set [Key | None ] = set ()
195
198
# Collect captures from the first subpattern
196
199
for expr , typ in pattern_types [0 ].captures .items ():
197
- node = get_var (expr )
200
+ if (node := get_var (expr )) is None :
201
+ continue
198
202
key = literal_hash (expr )
199
203
capture_types [node ][key ].append ((expr , typ ))
200
204
if isinstance (expr , NameExpr ):
@@ -209,7 +213,8 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
209
213
# Only fail for directly captured names (with NameExpr)
210
214
self .msg .fail (message_registry .OR_PATTERN_ALTERNATIVE_NAMES , o .patterns [i ])
211
215
for expr , typ in pattern_type .captures .items ():
212
- node = get_var (expr )
216
+ if (node := get_var (expr )) is None :
217
+ continue
213
218
key = literal_hash (expr )
214
219
capture_types [node ][key ].append ((expr , typ ))
215
220
@@ -312,15 +317,26 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
312
317
inner_types , star_position , required_patterns
313
318
)
314
319
current_subjects : list [list [Expression ]] = [[] for _ in range (len (contracted_inner_types ))]
320
+ end_pos = len (contracted_inner_types ) if star_position is None else star_position
315
321
for subject in self .subject_context [- 1 ]:
316
- # Support x[0], x[1], ... lookup until wildcard
317
- end_pos = len (contracted_inner_types ) if star_position is None else star_position
318
- for i in range (end_pos ):
319
- current_subjects [i ].append (IndexExpr (subject , IntExpr (i )))
320
- # For everything after wildcard use x[-2], x[-1]
321
- for i in range ((star_position or - 1 ) + 1 , len (contracted_inner_types )):
322
- offset = len (contracted_inner_types ) - i
323
- current_subjects [i ].append (IndexExpr (subject , UnaryExpr ("-" , IntExpr (offset ))))
322
+ if isinstance (subject , (ListExpr , TupleExpr )):
323
+ # For list and tuple expressions, lookup expression in items
324
+ for i in range (end_pos ):
325
+ if i < len (subject .items ):
326
+ current_subjects [i ].append (subject .items [i ])
327
+ if star_position is not None :
328
+ for i in range (star_position + 1 , len (contracted_inner_types )):
329
+ offset = len (contracted_inner_types ) - i
330
+ if offset <= len (subject .items ):
331
+ current_subjects [i ].append (subject .items [- offset ])
332
+ else :
333
+ # Support x[0], x[1], ... lookup until wildcard
334
+ for i in range (end_pos ):
335
+ current_subjects [i ].append (IndexExpr (subject , IntExpr (i )))
336
+ # For everything after wildcard use x[-2], x[-1]
337
+ for i in range ((star_position or - 1 ) + 1 , len (contracted_inner_types )):
338
+ offset = len (contracted_inner_types ) - i
339
+ current_subjects [i ].append (IndexExpr (subject , UnaryExpr ("-" , IntExpr (offset ))))
324
340
for p , t , s in zip (o .patterns , contracted_inner_types , current_subjects ):
325
341
pattern_type = self .accept (p , t , s )
326
342
typ , rest , type_map = pattern_type
@@ -794,7 +810,8 @@ def update_type_map(
794
810
already_captured = {literal_hash (expr ) for expr in original_type_map }
795
811
for expr , typ in extra_type_map .items ():
796
812
if literal_hash (expr ) in already_captured :
797
- node = get_var (expr )
813
+ if (node := get_var (expr )) is None :
814
+ continue
798
815
self .msg .fail (
799
816
message_registry .MULTIPLE_ASSIGNMENTS_IN_PATTERN .format (node .name ), expr
800
817
)
@@ -849,7 +866,7 @@ def get_match_arg_names(typ: TupleType) -> list[str | None]:
849
866
return args
850
867
851
868
852
- def get_var (expr : Expression ) -> Var :
869
+ def get_var (expr : Expression ) -> Var | None :
853
870
"""
854
871
Warning: this in only true for expressions captured by a match statement.
855
872
Don't call it from anywhere else
@@ -858,7 +875,10 @@ def get_var(expr: Expression) -> Var:
858
875
return get_var (expr .expr )
859
876
if isinstance (expr , IndexExpr ):
860
877
return get_var (expr .base )
861
- assert isinstance (expr , NameExpr ), expr
878
+ if isinstance (expr , AssignmentExpr ):
879
+ return get_var (expr .target )
880
+ if not isinstance (expr , NameExpr ):
881
+ return None
862
882
node = expr .node
863
883
assert isinstance (node , Var ), node
864
884
return node
0 commit comments