Skip to content

Commit

Permalink
feat: Add field equivalent to origin or annotation. (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanCardin authored Aug 16, 2024
1 parent 0181ad8 commit 715a598
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 45 deletions.
23 changes: 23 additions & 0 deletions tests/test_type_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
ForwardRef,
List,
Literal,
Optional,
Sequence,
Tuple,
TypedDict,
TypeVar,
Expand Down Expand Up @@ -362,3 +364,24 @@ def test_repr_type() -> None:

if sys.version_info >= (3, 9):
assert TypeView(set[bool]).repr_type == "set[bool]"


def test_instantiatable_origin() -> None:
assert TypeView(int).instantiable_origin == int
assert TypeView(list).instantiable_origin == list
assert TypeView(List[int]).instantiable_origin == list
assert TypeView(Dict[int, int]).instantiable_origin == dict
assert TypeView(Sequence[int]).instantiable_origin == list
assert TypeView(TypeView).instantiable_origin == TypeView


def test_fallback_origin() -> None:
assert TypeView(int).fallback_origin == int
assert TypeView(list).fallback_origin == list
assert TypeView(List).fallback_origin == list
assert TypeView(List[str]).fallback_origin == list
assert TypeView(Literal[1]).fallback_origin == Literal
assert TypeView(Literal).fallback_origin == Literal

if sys.version_info >= (3, 9):
assert TypeView(set[bool]).fallback_origin == set
10 changes: 6 additions & 4 deletions type_lens/type_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing_extensions import Annotated, NotRequired, Required, get_args, get_origin

from type_lens.types.builtins import UNION_TYPES, NoneType
from type_lens.utils import get_instantiable_origin, get_safe_generic_origin, unwrap_annotation
from type_lens.utils import INSTANTIABLE_TYPE_MAPPING, SAFE_GENERIC_ORIGIN_MAP, unwrap_annotation

__all__ = ("TypeView",)

Expand All @@ -24,6 +24,7 @@ class TypeView(Generic[T]):
"inner_types": "The type's generic args parsed as ParsedType, if applicable.",
"metadata": "Any metadata associated with the annotation via Annotated.",
"origin": "The result of calling get_origin(annotation) after unwrapping Annotated, e.g. list.",
"fallback_origin": "The unsubscripted version of a type, distinct from 'origin' in that for non-generics, this is the original type.",
"raw": "The annotation exactly as received.",
"_wrappers": "A set of wrapper types that were removed from the annotation.",
}
Expand All @@ -47,6 +48,7 @@ def __init__(self, annotation: T) -> None:
self.raw: Final[T] = annotation
self.annotation: Final = unwrapped
self.origin: Final = origin
self.fallback_origin: Final = origin or unwrapped
self.args: Final = args
self.metadata: Final = metadata
self._wrappers: Final = wrappers
Expand Down Expand Up @@ -101,7 +103,7 @@ def instantiable_origin(self) -> Any:
Returns:
An instantiable type that is consistent with the origin type of the annotation.
"""
return get_instantiable_origin(self)
return INSTANTIABLE_TYPE_MAPPING.get(self.fallback_origin, self.fallback_origin)

@property
def is_annotated(self) -> bool:
Expand Down Expand Up @@ -182,14 +184,14 @@ def is_variadic_tuple(self) -> bool:

@property
def safe_generic_origin(self) -> Any:
"""An object safe to be used as a generic type across all supported Python versions.
"""A type, safe to be used as a generic type across all supported Python versions.
Examples:
>>> from type_lens import TypeView
>>> TypeView(dict[str, int]).safe_generic_origin
typing.Dict
"""
return get_safe_generic_origin(self)
return SAFE_GENERIC_ORIGIN_MAP.get(self.fallback_origin)

def has_inner_subtype_of(self, typ: type[Any] | tuple[type[Any], ...]) -> bool:
"""Whether any generic args are a subclass of the given type.
Expand Down
44 changes: 3 additions & 41 deletions type_lens/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,9 @@

from type_lens.types.builtins import UNION_TYPES

if t.TYPE_CHECKING:
from type_lens import TypeView
__all__ = ("unwrap_annotation", "SAFE_GENERIC_ORIGIN_MAP", "INSTANTIABLE_TYPE_MAPPING")

__all__ = (
"get_instantiable_origin",
"get_safe_generic_origin",
"unwrap_annotation",
)

_SAFE_GENERIC_ORIGIN_MAP: te.Final[dict[object, object]] = {
SAFE_GENERIC_ORIGIN_MAP: te.Final[dict[object, object]] = {
set: t.AbstractSet,
defaultdict: t.DefaultDict,
deque: t.Deque,
Expand Down Expand Up @@ -56,7 +49,7 @@
_WRAPPER_TYPES: te.Final = {te.Annotated, te.Required, te.NotRequired}
"""Types that always contain a wrapped type annotation as their first arg."""

_INSTANTIABLE_TYPE_MAPPING: te.Final = {
INSTANTIABLE_TYPE_MAPPING: te.Final = {
t.AbstractSet: set,
t.DefaultDict: defaultdict,
t.Deque: deque,
Expand Down Expand Up @@ -87,37 +80,6 @@
"""A mapping of types to equivalent types that are safe to instantiate."""


def get_instantiable_origin(type_view: TypeView) -> t.Any:
"""Get a type that is safe to instantiate for the given origin type.
If a builtin collection type is annotated without generic args, e.g, ``a: dict``, then the origin type will be
``None``. In this case, we can use the annotation to determine the correct instantiable type, if one exists.
Args:
type_view: A :class:`TypeView` instance.
Returns:
A builtin type that is safe to instantiate for the given origin type.
"""
if type_view.origin is None:
return _INSTANTIABLE_TYPE_MAPPING.get(type_view.annotation)
return _INSTANTIABLE_TYPE_MAPPING.get(type_view.origin, type_view.origin)


def get_safe_generic_origin(type_view: TypeView) -> t.Any | None:
"""Get a type that is safe to use as a generic type across all supported Python versions.
Args:
type_view: A :class:`TypeView` instance.
Returns:
A type that is safe to use as a generic type across all supported Python versions.
"""
if type_view.origin is None:
return _SAFE_GENERIC_ORIGIN_MAP.get(type_view.annotation)
return _SAFE_GENERIC_ORIGIN_MAP.get(type_view.origin)


def unwrap_annotation(annotation: t.Any) -> tuple[t.Any, tuple[t.Any, ...], set[t.Any]]:
"""Remove "wrapper" annotation types, such as ``Annotated``, ``Required``, and ``NotRequired``.
Expand Down

0 comments on commit 715a598

Please sign in to comment.