From bc1f2a2240c2a7f6b39bbcff003f6230e7a82361 Mon Sep 17 00:00:00 2001 From: Yorick Rommers Date: Fri, 3 May 2024 15:48:49 +0200 Subject: [PATCH] feat(django_getter): use function cache to increase performance Update schema.py --- ninja/schema.py | 112 ++++++++++++++++++++++++++---------------------- 1 file changed, 61 insertions(+), 51 deletions(-) diff --git a/ninja/schema.py b/ninja/schema.py index e51dd43f..9d952266 100644 --- a/ninja/schema.py +++ b/ninja/schema.py @@ -18,14 +18,15 @@ def resolve_name(obj): """ +from __future__ import annotations + import warnings +from functools import partial from typing import ( Any, Callable, - Dict, - Type, + ClassVar, TypeVar, - Union, no_type_check, ) @@ -39,6 +40,7 @@ def resolve_name(obj): from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from typing_extensions import dataclass_transform +from ninja.constants import NOT_SET from ninja.signature.utils import get_args_names, has_kwargs from ninja.types import DictStrAny @@ -50,45 +52,68 @@ def resolve_name(obj): S = TypeVar("S", bound="Schema") +def dict_getter(key: str, obj: DjangoGetter) -> Any: + if key not in obj._obj: + raise AttributeError(key) + return obj._obj[key] + + +def attr_getter(key: str, obj: DjangoGetter) -> Any: + try: + return Variable(key).resolve(obj._obj) + except VariableDoesNotExist as e: + raise AttributeError(key) from e + + +def resolver(resolve_func: Callable, _: str, obj: DjangoGetter) -> Any: + return resolve_func(getter=obj) + + +def get_attr(key: str, obj: DjangoGetter) -> Any: + return getattr(obj._obj, key) + + class DjangoGetter: - __slots__ = ("_obj", "_schema_cls", "_context", "__dict__") + __slots__ = ("_obj", "_schema_cls", "_context", "__dict__", "_cache_key") + _cache: ClassVar[dict[str, Callable]] = {} - def __init__(self, obj: Any, schema_cls: Type[S], context: Any = None): + def __init__(self, obj: Any, schema_cls: type[S], context: Any = None) -> None: self._obj = obj self._schema_cls = schema_cls self._context = context + self._cache_key = f"{self._schema_cls.__module__}.{self._schema_cls.__name__}.{self._obj.__class__.__name__}" def __getattr__(self, key: str) -> Any: - # if key.startswith("__pydantic"): - # return getattr(self._obj, key) - - resolver = self._schema_cls._ninja_resolvers.get(key) - if resolver: - value = resolver(getter=self) - else: - if isinstance(self._obj, dict): - if key not in self._obj: - raise AttributeError(key) - value = self._obj[key] - else: - try: - value = getattr(self._obj, key) - except AttributeError: - try: - # value = attrgetter(key)(self._obj) - value = Variable(key).resolve(self._obj) - # TODO: Variable(key) __init__ is actually slower than - # Variable.resolve - so it better be cached - except VariableDoesNotExist as e: - raise AttributeError(key) from e + cache_key = f"{self._cache_key}.{key}" + if cache_key in DjangoGetter._cache: + # Use cached function, if available. + value = DjangoGetter._cache[cache_key](key, self) + return self._convert_result(value) + + stored_resolver = self._schema_cls._ninja_resolvers.get(key) + if stored_resolver: + # Use resolver when provided for this key. + value = stored_resolver(getter=self) + # bind resolver of this key to the _cache + DjangoGetter._cache[cache_key] = partial(resolver, stored_resolver) + return self._convert_result(value) + + if isinstance(self._obj, dict): + # Use dict lookup, faster than getattr + value = dict_getter(key, self) + DjangoGetter._cache[cache_key] = dict_getter + return self._convert_result(value) + + value = getattr(self._obj, key, NOT_SET) + if value is not NOT_SET: + # If getattr worked, use that. + DjangoGetter._cache[cache_key] = get_attr + return self._convert_result(value) + # Finally, fallback to attr_getter + value = attr_getter(key, self) + DjangoGetter._cache[cache_key] = attr_getter return self._convert_result(value) - # def get(self, key: Any, default: Any = None) -> Any: - # try: - # return self[key] - # except KeyError: - # return default - def _convert_result(self, result: Any) -> Any: if isinstance(result, Manager): return list(result.all()) @@ -116,7 +141,7 @@ class Resolver: _func: Any _takes_context: bool - def __init__(self, func: Union[Callable, staticmethod]): + def __init__(self, func: Callable | staticmethod): if isinstance(func, staticmethod): self._static = True self._func = func.__func__ @@ -139,25 +164,10 @@ def __call__(self, getter: DjangoGetter) -> Any: ) # pragma: no cover # return self._func(self._fake_instance(getter), getter._obj) - # def _fake_instance(self, getter: DjangoGetter) -> "Schema": - # """ - # Generate a partial schema instance that can be used as the ``self`` - # attribute of resolver functions. - # """ - - # class PartialSchema(Schema): - # def __getattr__(self, key: str) -> Any: - # value = getattr(getter, key) - # field = getter._schema_cls.model_fields[key] - # value = field.validate(value, values={}, loc=key, cls=None)[0] - # return value - - # return PartialSchema() - @dataclass_transform(kw_only_default=True, field_specifiers=(Field,)) class ResolverMetaclass(ModelMetaclass): - _ninja_resolvers: Dict[str, Resolver] + _ninja_resolvers: dict[str, Resolver] @no_type_check def __new__(cls, name, bases, namespace, **kwargs): @@ -228,7 +238,7 @@ def _run_root_validator( return handler(values) @classmethod - def from_orm(cls: Type[S], obj: Any, **kw: Any) -> S: + def from_orm(cls: type[S], obj: Any, **kw: Any) -> S: return cls.model_validate(obj, **kw) def dict(self, *a: Any, **kw: Any) -> DictStrAny: