Skip to content

Commit

Permalink
Add support for collections.abc.Mapping
Browse files Browse the repository at this point in the history
Work towards #29135

stacked-commit: true
  • Loading branch information
msuozzo committed Feb 15, 2025
1 parent 62b5214 commit 7130947
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 2 deletions.
12 changes: 12 additions & 0 deletions sdks/python/apache_beam/typehints/native_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
collections.abc.MutableSet,
collections.abc.Collection,
collections.abc.Sequence,
collections.abc.Mapping,
]


Expand Down Expand Up @@ -149,6 +150,10 @@ def _match_is_exactly_sequence(user_type):
return getattr(user_type, '__origin__', None) is collections.abc.Sequence


def _match_is_exactly_mapping(user_type):
  """Return True iff ``user_type`` is a parameterized collections.abc.Mapping.

  Only an object whose ``__origin__`` attribute is exactly
  ``collections.abc.Mapping`` matches. Objects without an ``__origin__``
  attribute, and generic aliases with any other origin (for example ``dict``
  or ``collections.abc.MutableMapping``), do not match.

  Args:
    user_type: Any object, typically a type or a generic alias.

  Returns:
    bool: Whether ``user_type.__origin__`` is ``collections.abc.Mapping``.
  """
  # Identity (not issubclass) comparison: subclasses of Mapping must not match.
  return getattr(user_type, '__origin__', None) is collections.abc.Mapping


def match_is_named_tuple(user_type):
return (
_safe_issubclass(user_type, typing.Tuple) and
Expand Down Expand Up @@ -414,6 +419,10 @@ def convert_to_beam_type(typ):
match=_match_is_exactly_sequence,
arity=1,
beam_type=typehints.Sequence),
_TypeMapEntry(
match=_match_is_exactly_mapping,
arity=2,
beam_type=typehints.Mapping),
]

# Find the first matching entry.
Expand Down Expand Up @@ -534,6 +543,9 @@ def convert_to_python_type(typ):
return collections.abc.Sequence[convert_to_python_type(typ.inner_type)]
if isinstance(typ, typehints.IteratorTypeConstraint):
return collections.abc.Iterator[convert_to_python_type(typ.yielded_type)]
if isinstance(typ, typehints.MappingTypeConstraint):
return collections.abc.Mapping[convert_to_python_type(typ.key_type),
convert_to_python_type(typ.value_type)]

raise ValueError('Failed to convert Beam type: %s' % typ)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,9 @@ def test_convert_to_beam_type_with_collections_types(self):
collections.abc.Iterable[tuple[str, int]],
typehints.Iterable[typehints.Tuple[str, int]]),
(
'mapping not caught',
'mapping',
collections.abc.Mapping[str, int],
collections.abc.Mapping[str, int]),
typehints.Mapping[str, int]),
('set', collections.abc.Set[int], typehints.Set[int]),
('mutable set', collections.abc.MutableSet[int], typehints.Set[int]),
(
Expand Down
141 changes: 141 additions & 0 deletions sdks/python/apache_beam/typehints/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,146 @@ def __getitem__(self, type_param):
ABCSequenceTypeConstraint = SequenceHint.ABCSequenceTypeConstraint


class MappingHint(CompositeTypeHint):
  """A Mapping type-hint.

  Mapping[K, V] represents any mapping (dict-like object) where all keys are
  of type K and all values are of type V. This is more general than Dict as it
  supports any object implementing the Mapping ABC.

  Examples of valid mappings include:
    - dict
    - collections.defaultdict
    - collections.OrderedDict
    - types.MappingProxyType
  """
  class MappingTypeConstraint(TypeConstraint):
    """Constraint created by ``Mapping[K, V]``; stores the key/value types."""
    def __init__(self, key_type, value_type):
      # Both parameters are passed through normalize() before being stored.
      self.key_type = normalize(key_type)
      self.value_type = normalize(value_type)

    def __repr__(self):
      return 'Mapping[%s, %s]' % (repr(self.key_type), repr(self.value_type))

    def __eq__(self, other):
      # Exact type match is required, so a DictConstraint with the same
      # key/value types never compares equal to a Mapping constraint.
      return (
          type(self) == type(other) and self.key_type == other.key_type and
          self.value_type == other.value_type)

    def __hash__(self):
      # Kept in step with __eq__: same fields, including the concrete type.
      return hash((type(self), self.key_type, self.value_type))

    def _inner_types(self):
      """Yield the key type followed by the value type."""
      yield self.key_type
      yield self.value_type

    def _consistent_with_check_(self, sub):
      """Return whether ``sub`` may be used where this constraint is expected.

      ``sub`` qualifies when it is a constraint of this class or a
      DictConstraint, and its key and value types are each consistent (per
      is_consistent_with) with this constraint's key and value types.
      """
      # A Dict is consistent with a Mapping of the same types.
      # Other Mapping subtypes are also consistent.
      return (
          isinstance(sub, (self.__class__, DictConstraint)) and
          is_consistent_with(sub.key_type, self.key_type) and
          is_consistent_with(sub.value_type, self.value_type))

    def _raise_type_error(self, is_key, instance, inner_error_message=''):
      """Raise a CompositeTypeHintError describing a key or value violation.

      Args:
        is_key: True when the offending object is a key, False for a value.
        instance: The offending key or value.
        inner_error_message: Optional message from a nested check. When
          non-empty it is reported in place of the instance and its class
          name.

      Raises:
        CompositeTypeHintError: Always.
      """
      type_desc = 'key' if is_key else 'value'
      expected_type = self.key_type if is_key else self.value_type

      if inner_error_message:
        raise CompositeTypeHintError(
            '%s hint %s-type constraint violated. All %ss should be of type '
            '%s. Instead: %s' % (
                repr(self),
                type_desc,
                type_desc,
                repr(expected_type),
                inner_error_message,
            ))
      raise CompositeTypeHintError(
          '%s hint %s-type constraint violated. All %ss should be of '
          'type %s. Instead, %s is of type %s.' % (
              repr(self),
              type_desc,
              type_desc,
              repr(expected_type),
              instance,
              instance.__class__.__name__,
          ))

    def _check_component(self, is_key, item):
      """Check one key or value, re-raising failures with Mapping context."""
      expected_type = self.key_type if is_key else self.value_type
      try:
        check_constraint(expected_type, item)
      except CompositeTypeHintError as e:
        # Preserve the nested constraint's explanation in the new message.
        self._raise_type_error(is_key, item, str(e))
      except SimpleTypeHintError:
        self._raise_type_error(is_key, item)

    def type_check(self, instance):
      """Verify that ``instance`` is a mapping with conforming keys and values.

      Args:
        instance: The object to check.

      Raises:
        CompositeTypeHintError: If ``instance`` is not an ``abc.Mapping``, or
          if any key or value fails its constraint. Each entry's key is
          checked before its value and the first failure is reported.
      """
      if not isinstance(instance, abc.Mapping):
        raise CompositeTypeHintError(
            'Mapping type-constraint violated. All passed instances must be of '
            'type Mapping. %s is of type %s.' %
            (instance, instance.__class__.__name__))

      for key, value in instance.items():
        self._check_component(True, key)
        self._check_component(False, value)

    def match_type_variables(self, concrete_type):
      """Return type-variable bindings obtained by matching ``concrete_type``.

      Bindings from the key types and then the value types are merged into one
      dict when ``concrete_type`` is a Mapping or Dict constraint; any other
      input yields an empty dict.
      """
      if isinstance(concrete_type, (MappingTypeConstraint, DictConstraint)):
        bindings = {}
        bindings.update(
            match_type_variables(self.key_type, concrete_type.key_type))
        bindings.update(
            match_type_variables(self.value_type, concrete_type.value_type))
        return bindings
      return {}

    def bind_type_variables(self, bindings):
      """Return this constraint with ``bindings`` applied to key/value types.

      Returns ``self`` unchanged when binding alters neither type; otherwise a
      new ``Mapping[...]`` constraint built from the bound types.
      """
      bound_key_type = bind_type_variables(self.key_type, bindings)
      bound_value_type = bind_type_variables(self.value_type, bindings)
      if (bound_key_type, bound_value_type) == (self.key_type,
                                                self.value_type):
        return self
      return Mapping[bound_key_type, bound_value_type]

  def __getitem__(self, type_params):
    """Build a MappingTypeConstraint from ``Mapping[key_type, value_type]``.

    Args:
      type_params: A 2-tuple ``(key_type, value_type)``.

    Returns:
      A ``MappingTypeConstraint`` for the given key and value types.

    Raises:
      TypeError: If ``type_params`` is not a tuple or does not have exactly
        two elements. Each element is additionally passed to
        validate_composite_type_param, which may reject it.
    """
    if not isinstance(type_params, tuple):
      raise TypeError(
          'Parameter to Mapping type-hint must be a tuple of types: '
          'Mapping[.., ..].')

    if len(type_params) != 2:
      raise TypeError(
          'Length of parameters to a Mapping type-hint must be exactly 2. '
          'Passed parameters: %s, have a length of %s.' %
          (type_params, len(type_params)))

    key_type, value_type = type_params

    validate_composite_type_param(
        key_type, error_msg_prefix='Key-type parameter to a Mapping hint')
    validate_composite_type_param(
        value_type, error_msg_prefix='Value-type parameter to a Mapping hint')

    return self.MappingTypeConstraint(key_type, value_type)


# Module-level alias so the nested constraint class can be referenced directly
# by name (e.g. in isinstance checks) without going through MappingHint.
MappingTypeConstraint = MappingHint.MappingTypeConstraint


class IterableHint(CompositeTypeHint):
"""An Iterable type-hint.
Expand Down Expand Up @@ -1292,6 +1432,7 @@ def __getitem__(self, type_params):
FrozenSet = FrozenSetHint()
Collection = CollectionHint()
Sequence = SequenceHint()
Mapping = MappingHint()
Iterable = IterableHint()
Iterator = IteratorHint()
Generator = GeneratorHint()
Expand Down
122 changes: 122 additions & 0 deletions sdks/python/apache_beam/typehints/typehints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,128 @@ def test_builtin_and_type_compatibility(self):
dict[str, list[int]], typing.Dict[str, typing.List[int]])


class MappingHintTestCase(TypeHintTestCase):
  """Tests for the ``typehints.Mapping`` hint and its type constraint."""
  def test_getitem_param_must_be_tuple(self):
    with self.assertRaises(TypeError) as e:
      typehints.Mapping[4]

    self.assertEqual(
        'Parameter to Mapping type-hint must be a tuple of '
        'types: Mapping[.., ..].',
        e.exception.args[0])

  def test_getitem_param_must_have_length_2(self):
    with self.assertRaises(TypeError) as e:
      typehints.Mapping[float, int, bool]

    self.assertEqual(
        "Length of parameters to a Mapping type-hint must be "
        "exactly 2. Passed parameters: ({}, {}, {}), have a "
        "length of 3.".format(float, int, bool),
        e.exception.args[0])

  def test_key_type_must_be_valid_composite_param(self):
    # A bare built-in container type is an acceptable key parameter.
    try:
      typehints.Mapping[list, int]
    except TypeError:
      self.fail("built-in composite raised TypeError unexpectedly")

  def test_value_type_must_be_valid_composite_param(self):
    with self.assertRaises(TypeError):
      typehints.Mapping[str, 5]

  def test_compatibility(self):
    hint1 = typehints.Mapping[int, str]
    hint2 = typehints.Mapping[bool, int]
    hint3 = typehints.Mapping[
        int, typehints.List[typehints.Tuple[str, str, str]]]
    hint4 = typehints.Mapping[int, int]

    self.assertCompatible(hint1, hint1)
    self.assertCompatible(hint3, hint3)
    self.assertNotCompatible(hint3, 4)
    self.assertNotCompatible(hint2, hint1)  # Key incompatibility.
    self.assertNotCompatible(hint1, hint4)  # Value incompatibility.

  def test_repr(self):
    hint3 = typehints.Mapping[
        int, typehints.List[typehints.Tuple[str, str, str]]]
    self.assertEqual(
        'Mapping[<class \'int\'>, List[Tuple[<class \'str\'>, '
        '<class \'str\'>, <class \'str\'>]]]',
        repr(hint3))

  def test_type_checks_not_dict(self):
    hint = typehints.Mapping[int, str]
    l = [1, 2]
    with self.assertRaises(TypeError) as e:
      hint.type_check(l)
    self.assertEqual(
        'Mapping type-constraint violated. All passed instances '
        'must be of type Mapping. [1, 2] is of type list.',
        e.exception.args[0])

  def test_type_check_invalid_key_type(self):
    hint = typehints.Mapping[
        typehints.Tuple[int, int, int], typehints.List[str]]
    d = {(1, 2): ['m', '1', '2', '3']}
    with self.assertRaises(TypeError) as e:
      hint.type_check(d)
    self.assertEqual(
        'Mapping[Tuple[<class \'int\'>, <class \'int\'>, <class \'int\'>], '
        'List[<class \'str\'>]] hint key-type '
        'constraint violated. All keys should be of type '
        'Tuple[<class \'int\'>, <class \'int\'>, <class \'int\'>]. Instead: '
        'Passed object instance is of the proper type, but differs in '
        'length from the hinted type. Expected a tuple of '
        'length 3, received a tuple of length 2.',
        e.exception.args[0])

  def test_type_check_invalid_value_type(self):
    hint = typehints.Mapping[str, typehints.Mapping[int, str]]
    d = {'f': [1, 2, 3]}
    with self.assertRaises(TypeError) as e:
      hint.type_check(d)
    self.assertEqual(
        "Mapping[<class 'str'>, Mapping[<class 'int'>, <class 'str'>]] hint"
        ' value-type constraint violated. All values should be of type'
        " Mapping[<class 'int'>, <class 'str'>]. Instead: Mapping"
        ' type-constraint violated. All passed instances must be of type'
        ' Mapping. [1, 2, 3] is of type list.',
        e.exception.args[0],
    )

  def test_type_check_valid_simple_type(self):
    hint = typehints.Mapping[int, str]
    d = {4: 'f', 9: 'k'}
    self.assertIsNone(hint.type_check(d))

  def test_type_check_valid_composite_type(self):
    hint = typehints.Mapping[typehints.Tuple[str, str], typehints.List[int]]
    d = {('f', 'k'): [1, 2, 3], ('m', 'r'): [4, 6, 9]}
    self.assertIsNone(hint.type_check(d))

  def test_type_check_valid_non_dict_mapping(self):
    # Mapping must accept mapping implementations other than plain dict.
    hint = typehints.Mapping[int, str]
    d = collections.OrderedDict([(4, 'f'), (9, 'k')])
    self.assertIsNone(hint.type_check(d))

  def test_match_type_variables(self):
    S = typehints.TypeVariable('S')  # pylint: disable=invalid-name
    T = typehints.TypeVariable('T')  # pylint: disable=invalid-name
    hint = typehints.Mapping[S, T]
    self.assertEqual(
        {S: int, T: str},
        hint.match_type_variables(typehints.Mapping[int, str]))

  def test_builtin_and_type_compatibility(self):
    self.assertCompatible(typing.Mapping, dict)
    self.assertCompatible(typing.Mapping[str, int], dict[str, int])
    self.assertCompatible(
        typing.Mapping[str, typing.List[int]], dict[str, list[int]])
    self.assertCompatible(typing.Iterable[str], typing.Mapping[str, int])

  def test_collections_compatibility(self):
    self.assertCompatible(typing.Mapping, collections.defaultdict)
    self.assertCompatible(
        typing.Mapping[str, int], collections.defaultdict[str, int])
    self.assertCompatible(
        typing.Mapping[str, int], collections.OrderedDict[str, int])


class BaseSetHintTest:
class CommonTests(TypeHintTestCase):
def test_getitem_invalid_composite_type_param(self):
Expand Down

0 comments on commit 7130947

Please sign in to comment.