Skip to content

Commit fe0ea6f

Browse files
Extend type checking to tuple and type annotations.
1 parent a1ebf69 commit fe0ea6f

File tree

2 files changed

+56
-9
lines changed

2 files changed

+56
-9
lines changed

spec_classes/utils/type_checking.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import inspect
22
import numbers
3+
import sys
4+
import types
35
from collections.abc import Sequence as SequenceMutator
46
from collections.abc import Set as SetMutator
57
from typing import (
@@ -37,36 +39,54 @@ def check_type(value: Any, attr_type: Type) -> bool:
3739
"""
3840
Check whether a given object `value` matches the provided `attr_type`.
3941
"""
40-
if attr_type is Any:
42+
if attr_type is Any or isinstance(attr_type, TypeVar):
4143
return True
4244

4345
if attr_type is float:
4446
attr_type = numbers.Real
4547

48+
if sys.version_info >= (3, 10) and isinstance(attr_type, types.UnionType):
49+
return any(check_type(value, type_) for type_ in attr_type.__args__)
50+
4651
if hasattr(attr_type, "__origin__"): # we are dealing with a `typing` object.
4752
if attr_type.__origin__ is Union:
4853
return any(check_type(value, type_) for type_ in attr_type.__args__)
4954

5055
if attr_type.__origin__ in (Literal, LiteralExtension):
5156
return value in attr_type.__args__
5257

53-
if isinstance(attr_type, _GenericAlias):
58+
if (
59+
isinstance(attr_type, _GenericAlias)
60+
or sys.version_info >= (3, 9)
61+
and isinstance(attr_type, types.GenericAlias)
62+
):
5463
if not isinstance(value, attr_type.__origin__):
5564
return False
56-
if attr_type._name in ("List", "Set") and not isinstance(
57-
attr_type.__args__[0], TypeVar
58-
): # pylint: disable=protected-access
65+
if attr_type.__origin__ in (list, set):
5966
for item in value:
6067
if not check_type(item, attr_type.__args__[0]):
6168
return False
62-
elif attr_type._name == "Dict" and not isinstance(
63-
attr_type.__args__[0], TypeVar
64-
): # pylint: disable=protected-access
69+
elif attr_type.__origin__ == dict:
6570
for k, v in value.items():
6671
if not check_type(k, attr_type.__args__[0]):
6772
return False
6873
if not check_type(v, attr_type.__args__[1]):
6974
return False
75+
elif attr_type.__origin__ == tuple:
76+
if len(attr_type.__args__) == 2 and attr_type.__args__[1] is Ellipsis:
77+
for item in value:
78+
if not check_type(item, attr_type.__args__[0]):
79+
return False
80+
else:
81+
if len(value) != len(attr_type.__args__):
82+
return False
83+
for i, item in enumerate(value):
84+
if not check_type(item, attr_type.__args__[i]):
85+
return False
86+
elif attr_type.__origin__ == type:
87+
if not issubclass(value, attr_type.__args__[0]):
88+
return False
89+
7090
return True
7191

7292
return isinstance(

tests/utils/test_type_checking.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Callable, Dict, List, Set, TypeVar, Union
1+
import sys
2+
from typing import Any, Callable, Dict, List, Set, Tuple, Type, TypeVar, Union
23

34
from typing_extensions import Literal
45

@@ -38,6 +39,11 @@ def test_type_checking(self):
3839
assert check_type(["a", "b"], List[str])
3940
assert not check_type([1, 2], List[str])
4041

42+
assert check_type((), Tuple)
43+
assert check_type((1, "a"), Tuple[int, str])
44+
assert not check_type((1,), Tuple[str])
45+
assert not check_type(("1", 2), Tuple[str, ...])
46+
4147
assert check_type({}, Dict)
4248
assert not check_type("a", Dict)
4349
assert check_type({"a": 1, "b": 2}, Dict[str, int])
@@ -56,9 +62,30 @@ def test_type_checking(self):
5662
assert check_type("hi", Literal["hi"])
5763
assert not check_type(1, Literal["hi"])
5864

65+
class MyType:
66+
pass
67+
68+
class SubType(MyType):
69+
pass
70+
71+
assert check_type(MyType, Type[MyType])
72+
assert check_type(SubType, Type[MyType])
73+
5974
assert check_type(1, float)
6075
assert not check_type(1.0, int)
6176

77+
if sys.version_info >= (3, 9):
78+
assert check_type(["a", "b"], list[str])
79+
assert check_type((1, 2, 3, 4), tuple[int, ...])
80+
assert check_type((1, 2, 3), tuple[int, int, float])
81+
assert not check_type((1, 2, 3), tuple[int, int])
82+
assert not check_type({1: "a", 2: "b"}, dict[str, int])
83+
assert not check_type({1, 2}, set[str])
84+
assert not check_type(str, type[MyType])
85+
86+
if sys.version_info >= (3, 10):
87+
assert check_type([1, "a"], list[str | int])
88+
6289
def test_get_collection_item_type(self):
6390
assert get_collection_item_type(list) is Any
6491
assert get_collection_item_type(List) is Any

0 commit comments

Comments
 (0)