Skip to content

Commit

Permalink
feat(django_getter): use function cache to increase performance
Browse files Browse the repository at this point in the history
Update schema.py
  • Loading branch information
Yorick Rommers authored and yorickr-sendcloud committed Aug 2, 2024
1 parent eecb05f commit 0bee043
Showing 1 changed file with 61 additions and 51 deletions.
112 changes: 61 additions & 51 deletions ninja/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ def resolve_name(obj):
"""

from __future__ import annotations

import warnings
from functools import partial
from typing import (
Any,
Callable,
Dict,
Type,
ClassVar,
TypeVar,
Union,
no_type_check,
)

Expand All @@ -39,6 +40,7 @@ def resolve_name(obj):
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from typing_extensions import dataclass_transform

from ninja.constants import NOT_SET
from ninja.signature.utils import get_args_names, has_kwargs
from ninja.types import DictStrAny

Expand All @@ -50,45 +52,68 @@ def resolve_name(obj):
S = TypeVar("S", bound="Schema")


def dict_getter(key: str, obj: DjangoGetter) -> Any:
if key not in obj._obj:
raise AttributeError(key)
return obj._obj[key]


def attr_getter(key: str, obj: DjangoGetter) -> Any:
try:
return Variable(key).resolve(obj._obj)
except VariableDoesNotExist as e:
raise AttributeError(key) from e


def resolver(resolve_func: Callable, _: str, obj: DjangoGetter) -> Any:
return resolve_func(getter=obj)


def get_attr(key: str, obj: DjangoGetter) -> Any:
return getattr(obj._obj, key)


class DjangoGetter:
__slots__ = ("_obj", "_schema_cls", "_context", "__dict__")
__slots__ = ("_obj", "_schema_cls", "_context", "__dict__", "_cache_key")
_cache: ClassVar[dict[str, Callable]] = {}

def __init__(self, obj: Any, schema_cls: Type[S], context: Any = None):
def __init__(self, obj: Any, schema_cls: type[S], context: Any = None) -> None:
self._obj = obj
self._schema_cls = schema_cls
self._context = context
self._cache_key = f"{self._schema_cls.__module__}.{self._schema_cls.__name__}.{self._obj.__class__.__name__}"

def __getattr__(self, key: str) -> Any:
# if key.startswith("__pydantic"):
# return getattr(self._obj, key)

resolver = self._schema_cls._ninja_resolvers.get(key)
if resolver:
value = resolver(getter=self)
else:
if isinstance(self._obj, dict):
if key not in self._obj:
raise AttributeError(key)
value = self._obj[key]
else:
try:
value = getattr(self._obj, key)
except AttributeError:
try:
# value = attrgetter(key)(self._obj)
value = Variable(key).resolve(self._obj)
# TODO: Variable(key) __init__ is actually slower than
# Variable.resolve - so it better be cached
except VariableDoesNotExist as e:
raise AttributeError(key) from e
cache_key = f"{self._cache_key}.{key}"
if cache_key in DjangoGetter._cache:
# Use cached function, if available.
value = DjangoGetter._cache[cache_key](key, self)
return self._convert_result(value)

stored_resolver = self._schema_cls._ninja_resolvers.get(key)
if stored_resolver:
# Use resolver when provided for this key.
value = stored_resolver(getter=self)
# bind resolver of this key to the _cache
DjangoGetter._cache[cache_key] = partial(resolver, stored_resolver)
return self._convert_result(value)

if isinstance(self._obj, dict):
# Use dict lookup, faster than getattr
value = dict_getter(key, self)
DjangoGetter._cache[cache_key] = dict_getter
return self._convert_result(value)

value = getattr(self._obj, key, NOT_SET)
if value is not NOT_SET:
# If getattr worked, use that.
DjangoGetter._cache[cache_key] = get_attr
return self._convert_result(value)
# Finally, fallback to attr_getter
value = attr_getter(key, self)
DjangoGetter._cache[cache_key] = attr_getter
return self._convert_result(value)

# def get(self, key: Any, default: Any = None) -> Any:
# try:
# return self[key]
# except KeyError:
# return default

def _convert_result(self, result: Any) -> Any:
if isinstance(result, Manager):
return list(result.all())
Expand Down Expand Up @@ -116,7 +141,7 @@ class Resolver:
_func: Any
_takes_context: bool

def __init__(self, func: Union[Callable, staticmethod]):
def __init__(self, func: Callable | staticmethod):
if isinstance(func, staticmethod):
self._static = True
self._func = func.__func__
Expand All @@ -139,25 +164,10 @@ def __call__(self, getter: DjangoGetter) -> Any:
) # pragma: no cover
# return self._func(self._fake_instance(getter), getter._obj)

# def _fake_instance(self, getter: DjangoGetter) -> "Schema":
# """
# Generate a partial schema instance that can be used as the ``self``
# attribute of resolver functions.
# """

# class PartialSchema(Schema):
# def __getattr__(self, key: str) -> Any:
# value = getattr(getter, key)
# field = getter._schema_cls.model_fields[key]
# value = field.validate(value, values={}, loc=key, cls=None)[0]
# return value

# return PartialSchema()


@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
class ResolverMetaclass(ModelMetaclass):
_ninja_resolvers: Dict[str, Resolver]
_ninja_resolvers: dict[str, Resolver]

@no_type_check
def __new__(cls, name, bases, namespace, **kwargs):
Expand Down Expand Up @@ -228,7 +238,7 @@ def _run_root_validator(
return handler(values)

@classmethod
def from_orm(cls: Type[S], obj: Any, **kw: Any) -> S:
def from_orm(cls: type[S], obj: Any, **kw: Any) -> S:
return cls.model_validate(obj, **kw)

def dict(self, *a: Any, **kw: Any) -> DictStrAny:
Expand Down

0 comments on commit 0bee043

Please sign in to comment.