Skip to content
32 changes: 24 additions & 8 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@
find_member,
infer_class_variances,
is_callable_compatible,
is_enum_value_pair,
is_equivalent,
is_more_precise,
is_proper_subtype,
Expand Down Expand Up @@ -6567,6 +6568,7 @@ def equality_type_narrowing_helper(
if operator in {"is", "is not"}:
is_valid_target: Callable[[Type], bool] = is_singleton_type
coerce_only_in_literal_context = False
no_custom_eq = True
should_narrow_by_identity = True
else:

Expand All @@ -6582,21 +6584,31 @@ def has_no_custom_eq_checks(t: Type) -> bool:
coerce_only_in_literal_context = True

expr_types = [operand_types[i] for i in expr_indices]
should_narrow_by_identity = all(
map(has_no_custom_eq_checks, expr_types)
) and not is_ambiguous_mix_of_enums(expr_types)
no_custom_eq = all(map(has_no_custom_eq_checks, expr_types))
should_narrow_by_identity = not is_ambiguous_mix_of_enums(expr_types)

if_map: TypeMap = {}
else_map: TypeMap = {}
if should_narrow_by_identity:
if_map, else_map = self.refine_identity_comparison_expression(
if no_custom_eq:
# Try to narrow the types or at least identify unreachable blocks.
# If there's some mix of enums and values, we do not want to narrow enums
# to literals, but still want to detect unreachable branches.
if_map_optimistic, else_map_optimistic = self.refine_identity_comparison_expression(
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash.keys(),
is_valid_target,
coerce_only_in_literal_context,
)
if should_narrow_by_identity:
if_map = if_map_optimistic
else_map = else_map_optimistic
else:
if if_map_optimistic is None:
if_map = None
if else_map_optimistic is None:
else_map = None

if if_map == {} and else_map == {}:
if_map, else_map = self.refine_away_none_in_comparison(
Expand Down Expand Up @@ -6844,13 +6856,16 @@ def should_coerce_inner(typ: Type) -> bool:
expr_type = coerce_to_literal(expr_type)
if not is_valid_target(get_proper_type(expr_type)):
continue
if target and not is_same_type(target, expr_type):
if (
target is not None
and not is_same_type(target, expr_type)
and not is_enum_value_pair(target, expr_type)
):
# We have multiple disjoint target types. So the 'if' branch
# must be unreachable.
return None, {}
target = expr_type
possible_target_indices.append(i)

# There's nothing we can currently infer if none of the operands are valid targets,
# so we end early and infer nothing.
if target is None:
Expand Down Expand Up @@ -9193,7 +9208,8 @@ def _ambiguous_enum_variants(types: list[Type]) -> set[str]:
if t.last_known_value:
result.update(_ambiguous_enum_variants([t.last_known_value]))
elif t.type.is_enum and any(
base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in t.type.mro
base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int")
for base in t.type.mro
):
result.add(t.type.fullname)
elif not t.type.is_enum:
Expand Down
10 changes: 9 additions & 1 deletion mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
are_parameters_compatible,
find_member,
is_callable_compatible,
is_enum_value_pair,
is_equivalent,
is_proper_subtype,
is_same_type,
Expand Down Expand Up @@ -547,9 +548,16 @@ def _type_object_overlap(left: Type, right: Type) -> bool:
right = right.fallback

if isinstance(left, LiteralType) and isinstance(right, LiteralType):
if left.value == right.value:
if (
left.value == right.value
and left.fallback.type.is_enum == right.fallback.type.is_enum
or is_enum_value_pair(left, right)
):
# If values are the same, we still need to check if fallbacks are overlapping,
# this is done below.
# Enums are more interesting:
# * if both sides are enums, they should have same values
# * if exactly one of them is a enum, fallback compatibibility is enough
left = left.fallback
right = right.fallback
else:
Expand Down
30 changes: 30 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from mypy.options import Options
from mypy.state import state
from mypy.types import (
ELLIPSIS_TYPE_NAMES,
MYPYC_NATIVE_INT_NAMES,
TUPLE_LIKE_INSTANCE_NAMES,
TYPED_NAMEDTUPLE_NAMES,
Expand Down Expand Up @@ -286,6 +287,35 @@ def is_same_type(
)


def is_enum_value_pair(a: Type, b: Type) -> bool:
a = get_proper_type(a)
b = get_proper_type(b)

if not isinstance(a, LiteralType) or not isinstance(b, LiteralType):
return False
if b.fallback.type.is_enum:
a, b = b, a
if b.fallback.type.is_enum or not a.fallback.type.is_enum:
return False
# At this point we have a pair (enum literal, non-enum literal).
# Check that the non-enum fallback is compatible
if not is_subtype(a.fallback, b.fallback):
return False
assert isinstance(a.value, str)
enum_value = a.fallback.type.get(a.value)
if enum_value is None or enum_value.type is None:
return False
proper_value = get_proper_type(enum_value.type)
return isinstance(proper_value, Instance) and (
proper_value.last_known_value == b
# TODO: this is too lax and should only be applied for enums defined in stubs,
# but checking that strictly requires access to the checker. This function
# is needed in `is_overlapping_types` and operates on a lower level,
# so doing this properly would be more difficult.
or proper_value.type.fullname in ELLIPSIS_TYPE_NAMES
)


# This is a common entry point for subtyping checks (both proper and non-proper).
# Never call this private function directly, use the public versions.
def _is_subtype(
Expand Down
Loading
Loading