Skip to content

Commit e2ddd78

Browse files
committed
feat: add tests for union/optional types
1 parent b317c0f commit e2ddd78

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from types import UnionType
2+
from typing import Union
3+
from typing import get_args
4+
from typing import get_origin
5+
6+
NoneType = type(None)
7+
"""Helpful alias for `type(None)`."""
8+
9+
10+
def is_union(type_: type) -> bool:
11+
"""
12+
True if `type_` is a union type.
13+
14+
When declared with `Union[T, ...]` or `Optional[T]`, `get_origin()` returns `typing.Union`.
15+
When declared with PEP604 syntax `T | ...`, `get_origin()` returns `types.UnionType`.
16+
17+
Args:
18+
type_: The type to check.
19+
20+
Returns:
21+
True if `type_` is a union type.
22+
"""
23+
return get_origin(type_) is Union or get_origin(type_) is UnionType
24+
25+
26+
def is_optional(type_: type) -> bool:
27+
"""
28+
True if `_type` is `Optional`, or the union of a single type and `None`.
29+
30+
Args:
31+
type_: The type to check.
32+
33+
Returns:
34+
True if `_type` is `Optional[T]`, `Union[T, None]` or `T | None`.
35+
"""
36+
type_args: tuple[type] = get_args(type_)
37+
38+
return is_union(type_) and (NoneType in type_args) and (len(type_args) == 2)

tests/_lib/test_typing_extensions

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Union, Optional, Any
2+
from dataclass_io._lib.typing_extensions import is_optional
3+
from dataclass_io._lib.typing_extensions import is_union
4+
5+
6+
def test_is_union() -> None:
7+
"""Test that we can identify a union type."""
8+
9+
assert is_union(Union[int, None])
10+
assert is_union(Union[int, str])
11+
assert is_union(Union[int, str, float])
12+
assert is_union(Optional[int])
13+
14+
assert not is_union(int)
15+
assert not is_union(str)
16+
assert not is_union(dict[int, str])
17+
18+
19+
def test_is_optional() -> None:
20+
"""Test that we can identify an Optional type."""
21+
22+
assert is_optional(Union[int, None])
23+
assert is_optional(int | None)
24+
assert is_optional(Optional[int])
25+
26+
assert is_optional(Union[dict[str, Any], None])
27+
assert is_optional(dict[str, Any] | None)
28+
assert is_optional(Optional[dict[str, Any]])
29+
30+
assert not is_optional(int)
31+
assert not is_optional(Union[int, str])
32+
assert not is_optional(dict[str, Any])

0 commit comments

Comments
 (0)