Skip to content

Commit 244b1e4

Browse files
committed
Handle more subject expressions
1 parent c132ec4 commit 244b1e4

File tree

3 files changed

+94
-17
lines changed

3 files changed

+94
-17
lines changed

mypy/checkpattern.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
from mypy.messages import MessageBuilder
1717
from mypy.nodes import (
1818
ARG_POS,
19+
AssignmentExpr,
1920
Context,
2021
Expression,
2122
IndexExpr,
2223
IntExpr,
24+
ListExpr,
2325
MemberExpr,
2426
NameExpr,
27+
TupleExpr,
2528
TypeAlias,
2629
TypeInfo,
2730
UnaryExpr,
@@ -194,7 +197,8 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
194197
capture_expr_keys: set[Key | None] = set()
195198
# Collect captures from the first subpattern
196199
for expr, typ in pattern_types[0].captures.items():
197-
node = get_var(expr)
200+
if (node := get_var(expr)) is None:
201+
continue
198202
key = literal_hash(expr)
199203
capture_types[node][key].append((expr, typ))
200204
if isinstance(expr, NameExpr):
@@ -209,7 +213,8 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
209213
# Only fail for directly captured names (with NameExpr)
210214
self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i])
211215
for expr, typ in pattern_type.captures.items():
212-
node = get_var(expr)
216+
if (node := get_var(expr)) is None:
217+
continue
213218
key = literal_hash(expr)
214219
capture_types[node][key].append((expr, typ))
215220

@@ -312,15 +317,26 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
312317
inner_types, star_position, required_patterns
313318
)
314319
current_subjects: list[list[Expression]] = [[] for _ in range(len(contracted_inner_types))]
320+
end_pos = len(contracted_inner_types) if star_position is None else star_position
315321
for subject in self.subject_context[-1]:
316-
# Support x[0], x[1], ... lookup until wildcard
317-
end_pos = len(contracted_inner_types) if star_position is None else star_position
318-
for i in range(end_pos):
319-
current_subjects[i].append(IndexExpr(subject, IntExpr(i)))
320-
# For everything after wildcard use x[-2], x[-1]
321-
for i in range((star_position or -1) + 1, len(contracted_inner_types)):
322-
offset = len(contracted_inner_types) - i
323-
current_subjects[i].append(IndexExpr(subject, UnaryExpr("-", IntExpr(offset))))
322+
if isinstance(subject, (ListExpr, TupleExpr)):
323+
# For list and tuple expressions, lookup expression in items
324+
for i in range(end_pos):
325+
if i < len(subject.items):
326+
current_subjects[i].append(subject.items[i])
327+
if star_position is not None:
328+
for i in range(star_position + 1, len(contracted_inner_types)):
329+
offset = len(contracted_inner_types) - i
330+
if offset <= len(subject.items):
331+
current_subjects[i].append(subject.items[-offset])
332+
else:
333+
# Support x[0], x[1], ... lookup until wildcard
334+
for i in range(end_pos):
335+
current_subjects[i].append(IndexExpr(subject, IntExpr(i)))
336+
# For everything after wildcard use x[-2], x[-1]
337+
for i in range((star_position or -1) + 1, len(contracted_inner_types)):
338+
offset = len(contracted_inner_types) - i
339+
current_subjects[i].append(IndexExpr(subject, UnaryExpr("-", IntExpr(offset))))
324340
for p, t, s in zip(o.patterns, contracted_inner_types, current_subjects):
325341
pattern_type = self.accept(p, t, s)
326342
typ, rest, type_map = pattern_type
@@ -794,7 +810,8 @@ def update_type_map(
794810
already_captured = {literal_hash(expr) for expr in original_type_map}
795811
for expr, typ in extra_type_map.items():
796812
if literal_hash(expr) in already_captured:
797-
node = get_var(expr)
813+
if (node := get_var(expr)) is None:
814+
continue
798815
self.msg.fail(
799816
message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr
800817
)
@@ -849,7 +866,7 @@ def get_match_arg_names(typ: TupleType) -> list[str | None]:
849866
return args
850867

851868

852-
def get_var(expr: Expression) -> Var:
869+
def get_var(expr: Expression) -> Var | None:
853870
"""
854871
Warning: this in only true for expressions captured by a match statement.
855872
Don't call it from anywhere else
@@ -858,7 +875,10 @@ def get_var(expr: Expression) -> Var:
858875
return get_var(expr.expr)
859876
if isinstance(expr, IndexExpr):
860877
return get_var(expr.base)
861-
assert isinstance(expr, NameExpr), expr
878+
if isinstance(expr, AssignmentExpr):
879+
return get_var(expr.target)
880+
if not isinstance(expr, NameExpr):
881+
return None
862882
node = expr.node
863883
assert isinstance(node, Var), node
864884
return node

mypy/fastparse.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ def __init__(
398398
# 'C' for class, 'D' for function signature, 'F' for function, 'L' for lambda
399399
self.class_and_function_stack: list[Literal["C", "D", "F", "L"]] = []
400400
self.imports: list[ImportBase] = []
401+
self.match_stmt_subject = False
401402

402403
self.options = options
403404
self.is_stub = is_stub
@@ -1760,7 +1761,7 @@ def visit_Name(self, n: Name) -> NameExpr:
17601761
# List(expr* elts, expr_context ctx)
17611762
def visit_List(self, n: ast3.List) -> ListExpr | TupleExpr:
17621763
expr_list: list[Expression] = [self.visit(e) for e in n.elts]
1763-
if isinstance(n.ctx, ast3.Store):
1764+
if isinstance(n.ctx, ast3.Store) or self.match_stmt_subject:
17641765
# [x, y] = z and (x, y) = z means exactly the same thing
17651766
e: ListExpr | TupleExpr = TupleExpr(expr_list)
17661767
else:
@@ -1793,8 +1794,11 @@ def visit_Index(self, n: Index) -> Node:
17931794

17941795
# Match(expr subject, match_case* cases) # python 3.10 and later
17951796
def visit_Match(self, n: Match) -> MatchStmt:
1797+
self.match_stmt_subject = True
1798+
subject = self.visit(n.subject)
1799+
self.match_stmt_subject = False
17961800
node = MatchStmt(
1797-
self.visit(n.subject),
1801+
subject,
17981802
[self.visit(c.pattern) for c in n.cases],
17991803
[self.visit(c.guard) for c in n.cases],
18001804
[self.as_required_block(c.body) for c in n.cases],

test-data/unit/check-python310.test

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ match x:
318318
pass
319319

320320
[case testMatchSequencePatternWithInvalidClassPattern]
321+
# flags: --warn-unreachable
321322
class Example:
322323
__match_args__ = ("value",)
323324
def __init__(self, value: str) -> None:
@@ -327,8 +328,10 @@ SubClass: type[Example]
327328

328329
match [SubClass("a"), SubClass("b")]:
329330
case [SubClass(value), *rest]: # E: Expected type in class pattern; found "type[__main__.Example]"
330-
reveal_type(value) # E: Cannot determine type of "value" \
331-
# N: Revealed type is "Any"
331+
reveal_type(value) # E: Statement is unreachable
332+
reveal_type(rest)
333+
case [Example(value), *rest]:
334+
reveal_type(value) # N: Revealed type is "builtins.str"
332335
reveal_type(rest) # N: Revealed type is "builtins.list[__main__.Example]"
333336
[builtins fixtures/tuple.pyi]
334337

@@ -3048,3 +3051,53 @@ match m:
30483051
reveal_type(m.a) # N: Revealed type is "Any"
30493052
reveal_type(m.a[0]) # N: Revealed type is "__main__.A"
30503053
reveal_type(m.a[0].a) # N: Revealed type is "Union[Literal['Hello'], Literal['World']]"
3054+
3055+
[case testMatchSubjectExpression]
3056+
# flags: --warn-unreachable
3057+
m: object
3058+
n: object
3059+
o: object
3060+
def func(): ...
3061+
3062+
match (m, n, o):
3063+
case [1, 2, 3] | [2, 3, 4]:
3064+
reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal[2]]"
3065+
reveal_type(n) # N: Revealed type is "Union[Literal[2], Literal[3]]"
3066+
reveal_type(o) # N: Revealed type is "Union[Literal[3], Literal[4]]"
3067+
case [1, 2, 3, 4] | [2, 3, 4, 5]:
3068+
# No match -> don't crash
3069+
reveal_type(m) # E: Statement is unreachable
3070+
case [1, *_, 3] | [2, *_, 4]:
3071+
reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal[2]]"
3072+
reveal_type(n) # N: Revealed type is "builtins.object"
3073+
reveal_type(o) # N: Revealed type is "Union[Literal[3], Literal[4]]"
3074+
case [1, *_, 3, 4, 5] | [2, *_, 3, 4, 5]:
3075+
# No match -> don't crash
3076+
reveal_type(m) # E: Statement is unreachable
3077+
3078+
match [m, n, o]:
3079+
case [1, 2, 3] | [2, 3, 4]:
3080+
reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal[2]]"
3081+
reveal_type(n) # N: Revealed type is "Union[Literal[2], Literal[3]]"
3082+
reveal_type(o) # N: Revealed type is "Union[Literal[3], Literal[4]]"
3083+
case [1, 2, 3, 4] | [2, 3, 4, 5]:
3084+
# No match -> don't crash
3085+
reveal_type(m) # E: Statement is unreachable
3086+
case [1, *_, 3] | [2, *_, 4]:
3087+
reveal_type(m) # N: Revealed type is "Union[Literal[1], Literal[2]]"
3088+
reveal_type(n) # N: Revealed type is "builtins.object"
3089+
reveal_type(o) # N: Revealed type is "Union[Literal[3], Literal[4]]"
3090+
case [1, *_, 3, 4, 5] | [2, *_, 3, 4, 5]:
3091+
# No match -> don't crash
3092+
reveal_type(m) # E: Statement is unreachable
3093+
3094+
match a := m:
3095+
case [1, 2] | [3, 4]:
3096+
reveal_type(a) # N: Revealed type is "typing.Sequence[builtins.int]"
3097+
reveal_type(a[0]) # N: Revealed type is "Union[Literal[1], Literal[3]]"
3098+
3099+
match func():
3100+
# Don't crash for subject expressions which can't be narrowed
3101+
case [1, 2] | [3, 4]:
3102+
...
3103+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)