Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ In your GraphQL queries, you can use the `offset` and `limit` parameters:
You can also enable pagination for nested relationships:

```python
@strawchemy.type(User, include="all", child_pagination=True)
@strawchemy.type(User, include="all", paginate="all")
class UserType:
pass
```
Expand Down
11 changes: 10 additions & 1 deletion src/strawchemy/dto/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, final, get_type_hints

from typing_extensions import override
from typing_extensions import Self, override

from strawchemy.utils.annotation import get_annotations

Expand Down Expand Up @@ -185,6 +185,10 @@ def __post_init__(self) -> None:
if self.exclude:
self.include = "all"

@classmethod
def from_include(cls, include: IncludeFields | None = None, purpose: Purpose = Purpose.READ) -> Self:
return cls(purpose, include=set() if include is None else include)
Comment on lines +199 to +220
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick | 🔵 Trivial

Consider the past suggestion to make include keyword-only.

A previous review recommended making the include parameter keyword-only to avoid potential misuse with boolean positional arguments. While the current implementation is functional, the suggestion remains valid for improved API safety.

🔎 Reference to past review

The past review suggested:

@classmethod
def from_include(cls, *, include: IncludeFields | None = None, purpose: Purpose = Purpose.READ) -> Self:

This would require callers to explicitly use include=... when calling the method.

🤖 Prompt for AI Agents
In src/strawchemy/dto/types.py around lines 188 to 209, change the classmethod
signature to make include a keyword-only parameter to prevent accidental
positional misuse: update the signature to accept a leading "*" (e.g., def
from_include(cls, *, include: IncludeFields | None = None, purpose: Purpose =
Purpose.READ) -> Self), keep the existing logic that converts include None to an
empty set, and ensure any call sites are updated to pass include=... (and
purpose=... if needed).


def copy_with(
self,
purpose: Purpose | type[DTOUnset] = DTOUnset,
Expand Down Expand Up @@ -265,3 +269,8 @@ def alias(self, name: str) -> str | None:
if self.alias_generator is not None:
return self.alias_generator(name)
return None

def is_field_included(self, name: str) -> bool:
if self.include == "all":
return True
return name in self.include and name not in self.exclude
49 changes: 31 additions & 18 deletions src/strawchemy/schema/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from strawchemy.dto.utils import config
from strawchemy.exceptions import StrawchemyError
from strawchemy.instance import MapperModelInstance
from strawchemy.schema.pagination import DefaultOffsetPagination
from strawchemy.transpiler import hook
from strawchemy.typing import GraphQLDTOT, GraphQLPurpose, GraphQLType, MappedGraphQLDTO
from strawchemy.utils.annotation import get_annotations
Expand All @@ -51,6 +50,7 @@
from strawchemy import Strawchemy
from strawchemy.dto.inspectors import SQLAlchemyGraphQLInspector
from strawchemy.dto.types import DTOConfig, ExcludeFields, IncludeFields
from strawchemy.schema.pagination import DefaultOffsetPagination
from strawchemy.transpiler.hook import QueryHook
from strawchemy.utils.graph import Node
from strawchemy.validation.pydantic import MappedPydanticGraphQLDTO
Expand Down Expand Up @@ -97,9 +97,10 @@ def _type_info(
current_node: Node[Relation[Any, GraphQLDTOT], None] | None,
override: bool = False,
user_defined: bool = False,
child_options: ChildOptions | None = None,
order: IncludeFields | None = None,
paginate: IncludeFields | None = None,
default_pagination: None | DefaultOffsetPagination = None,
) -> RegistryTypeInfo:
child_options = child_options or ChildOptions()
graphql_type = self.graphql_type(dto_config)
model: type[DeclarativeBase] | None = dto.__dto_model__ if issubclass(dto, MappedStrawberryGraphQLDTO) else None # type: ignore[reportGeneralTypeIssues]
default_name = self.root_dto_name(model, dto_config, current_node) if model else dto.__name__
Expand All @@ -109,8 +110,9 @@ def _type_info(
graphql_type=graphql_type,
override=override,
user_defined=user_defined,
pagination=DefaultOffsetPagination() if child_options.pagination is True else child_options.pagination,
order_by=child_options.order_by,
pagination=default_pagination,
order=frozenset() if order is None else frozenset(order),
paginate=frozenset() if paginate is None else frozenset(paginate),
scope=dto_config.scope,
model=model,
exclude_from_scope=dto_config.exclude_from_scope,
Expand All @@ -130,14 +132,18 @@ def _register_type(
directives: Sequence[object] | None = (),
override: bool = False,
user_defined: bool = False,
child_options: ChildOptions | None = None,
order: IncludeFields | None = None,
paginate: IncludeFields | None = None,
default_pagination: None | DefaultOffsetPagination = None,
) -> type[StrawchemyDTOT]:
type_info = self._type_info(
dto,
dto_config,
override=override,
user_defined=user_defined,
child_options=child_options,
order=order,
paginate=paginate,
default_pagination=default_pagination,
current_node=current_node,
)
self._raise_if_type_conflicts(type_info)
Expand Down Expand Up @@ -212,8 +218,9 @@ def _type_wrapper(
type_map: Mapping[Any, Any] | None = None,
aliases: Mapping[str, str] | None = None,
alias_generator: Callable[[str], str] | None = None,
child_pagination: bool | DefaultOffsetPagination = False,
child_order_by: bool = False,
paginate: IncludeFields | None = None,
order: IncludeFields | None = None,
default_pagination: None | DefaultOffsetPagination = None,
filter_input: type[BooleanFilterDTO] | None = None,
order_by: type[OrderByDTO] | None = None,
name: str | None = None,
Expand Down Expand Up @@ -247,7 +254,9 @@ def wrapper(class_: type[Any]) -> type[GraphQLDTOT]:
override=override,
user_defined=True,
mode=mode,
child_options=ChildOptions(pagination=child_pagination, order_by=child_order_by),
paginate=paginate,
order=order,
default_pagination=default_pagination,
)
dto.__strawchemy_query_hook__ = query_hook
if issubclass(dto, MappedStrawberryGraphQLDTO):
Expand Down Expand Up @@ -436,8 +445,9 @@ def type(
type_map: Mapping[Any, Any] | None = None,
aliases: Mapping[str, str] | None = None,
alias_generator: Callable[[str], str] | None = None,
child_pagination: bool | DefaultOffsetPagination = False,
child_order_by: bool = False,
paginate: IncludeFields | None = None,
order: IncludeFields | None = None,
default_pagination: None | DefaultOffsetPagination = None,
filter_input: type[BooleanFilterDTO] | None = None,
order_by: type[OrderByDTO] | None = None,
name: str | None = None,
Expand All @@ -457,8 +467,9 @@ def type(
type_map=type_map,
aliases=aliases,
alias_generator=alias_generator,
child_pagination=child_pagination,
child_order_by=child_order_by,
paginate=paginate,
order=order,
default_pagination=default_pagination,
filter_input=filter_input,
order_by=order_by,
name=name,
Expand Down Expand Up @@ -587,8 +598,9 @@ def type(
type_map: Mapping[Any, Any] | None = None,
aliases: Mapping[str, str] | None = None,
alias_generator: Callable[[str], str] | None = None,
child_pagination: bool | DefaultOffsetPagination = False,
child_order_by: bool = False,
order: IncludeFields | None = None,
paginate: IncludeFields | None = None,
default_pagination: None | DefaultOffsetPagination = None,
filter_input: type[BooleanFilterDTO] | None = None,
order_by: type[OrderByDTO] | None = None,
name: str | None = None,
Expand All @@ -607,8 +619,9 @@ def type(
type_map=type_map,
aliases=aliases,
alias_generator=alias_generator,
child_pagination=child_pagination,
child_order_by=child_order_by,
paginate=paginate,
order=order,
default_pagination=default_pagination,
filter_input=filter_input,
order_by=order_by,
name=name,
Expand Down
4 changes: 3 additions & 1 deletion src/strawchemy/schema/factories/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def iter_field_definitions(
if field.uselist and field.related_dto:
field.type_ = Union[field.related_dto, None]
if aggregate_filters:
aggregation_field = self._aggregation_field(field, dto_config.copy_with(partial_default=UNSET))
aggregation_field = self._aggregation_field(
field, dto_config.copy_with(partial_default=UNSET, partial=True)
)
field_map[key + aggregation_field.name] = aggregation_field
yield aggregation_field
else:
Expand Down
52 changes: 26 additions & 26 deletions src/strawchemy/schema/factories/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
GraphQLFieldDefinition,
MappedStrawberryGraphQLDTO,
)
from strawchemy.dto.types import DTOConfig, DTOMissing, Purpose
from strawchemy.dto.types import DTOConfig, DTOMissing, IncludeFields, Purpose
from strawchemy.dto.utils import read_all_partial_config, read_partial, write_all_config
from strawchemy.exceptions import EmptyDTOError
from strawchemy.schema.factories import (
AggregationInspector,
ChildOptions,
EnumDTOFactory,
GraphQLDTOFactory,
MappedGraphQLDTOT,
Expand Down Expand Up @@ -104,27 +103,34 @@ def _aggregation_field(
related_dto=dto,
)

def _update_fields(
def _add_fields_arguments(
self,
dto: type[GraphQLDTOT],
base: type[Any] | None,
pagination: bool | DefaultOffsetPagination = False,
order_by: bool = False,
order: IncludeFields | None = None,
paginate: IncludeFields | None = None,
default_pagination: None | DefaultOffsetPagination = None,
) -> type[GraphQLDTOT]:
attributes: dict[str, Any] = {}
annotations: dict[str, Any] = {}
order_config = DTOConfig.from_include(order)
paginate_config = DTOConfig.from_include(paginate)

for field in dto.__strawchemy_field_map__.values():
# Add pagination and ordering arguments for relations
if field.is_relation and field.uselist:
related = Self if field.related_dto is dto else field.related_dto
type_annotation = list[related] if related is not None else field.type_
assert field.related_model
order_by_input = None
if order_by:
order_by_input, pagination = None, False
if order_config.is_field_included(field.model_field_name):
order_by_input = self._order_by_factory.factory(field.related_model, read_all_partial_config)
if paginate_config.is_field_included(field.model_field_name):
pagination = default_pagination or True
strawberry_field = self._mapper.field(pagination=pagination, order_by=order_by_input, root_field=False)
attributes[field.name] = strawberry_field
annotations[field.name] = type_annotation
# Add path filtering argument for JSON fields
elif (
not field.is_relation
and field.has_model_field
Expand Down Expand Up @@ -152,18 +158,6 @@ def _update_fields(
setattr(dto, name, value)
return dto

@override
def _cache_key(
self,
model: type[Any],
dto_config: DTOConfig,
node: Node[Relation[Any, MappedGraphQLDTOT], None],
*,
child_options: ChildOptions,
**factory_kwargs: Any,
) -> Hashable:
return (super()._cache_key(model, dto_config, node, **factory_kwargs), child_options)

@override
def dto_name(
self, base_name: str, dto_config: DTOConfig, node: Node[Relation[Any, MappedGraphQLDTOT], None] | None = None
Expand Down Expand Up @@ -208,7 +202,9 @@ def factory(
tags: set[str] | None = None,
backend_kwargs: dict[str, Any] | None = None,
*,
child_options: ChildOptions | None = None,
default_pagination: None | DefaultOffsetPagination = None,
order: IncludeFields | None = None,
paginate: IncludeFields | None = None,
aggregations: bool = True,
description: str | None = None,
directives: Sequence[object] | None = (),
Expand All @@ -230,12 +226,15 @@ def factory(
aggregations=aggregations if dto_config.purpose is Purpose.READ else False,
register_type=False,
override=override,
child_options=child_options,
paginate=paginate if paginate == "all" else None,
order=order if order == "all" else None,
default_pagination=default_pagination,
**kwargs,
)
child_options = child_options or ChildOptions()
if self.graphql_type(dto_config) == "object":
dto = self._update_fields(dto, base, pagination=child_options.pagination, order_by=child_options.order_by)
dto = self._add_fields_arguments(
dto, base, default_pagination=default_pagination, order=order, paginate=paginate
)
if register_type:
return self._register_type(
dto,
Expand All @@ -244,7 +243,9 @@ def factory(
directives=directives,
override=override,
user_defined=user_defined,
child_options=child_options,
default_pagination=default_pagination,
order=order,
paginate=paginate,
current_node=current_node,
)
return dto
Expand Down Expand Up @@ -558,12 +559,11 @@ def _cache_key(
dto_config: DTOConfig,
node: Node[Relation[Any, MappedGraphQLDTOT], None],
*,
child_options: ChildOptions,
mode: GraphQLPurpose,
**factory_kwargs: Any,
) -> Hashable:
return (
super()._cache_key(model, dto_config, node, child_options=child_options, **factory_kwargs),
super()._cache_key(model, dto_config, node, **factory_kwargs),
node.root.value.model,
mode,
)
Expand Down
1 change: 1 addition & 0 deletions src/strawchemy/schema/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
self.pagination: DefaultOffsetPagination | Literal[False] = (
DefaultOffsetPagination() if pagination is True else pagination
)

self.id_field_name = id_field_name

self._filter = filter_type
Expand Down
7 changes: 4 additions & 3 deletions src/strawchemy/utils/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,17 @@ class RegistryTypeInfo:
default_name: str | None = None
user_defined: bool = False
override: bool = False
pagination: DefaultOffsetPagination | Literal[False] = False
order_by: bool = False
pagination: DefaultOffsetPagination | None = None
order: frozenset[str] | Literal["all"] = dataclasses.field(default_factory=frozenset)
paginate: frozenset[str] | Literal["all"] = dataclasses.field(default_factory=frozenset)
scope: DTOScope | None = None
model: type[DeclarativeBase] | None = None
tags: frozenset[str] = dataclasses.field(default_factory=frozenset)
exclude_from_scope: bool = False

@property
def scoped_id(self) -> Hashable:
return (self.model, self.graphql_type, self.tags)
return self.model, self.graphql_type, self.tags


class StrawberryRegistry:
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/types/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class OrderedFruitType: ...
class FruitAggregationType: ...


@strawchemy.type(Fruit, include="all", child_pagination=True, child_order_by=True)
@strawchemy.type(Fruit, include="all", paginate="all", order="all")
class FruitTypeWithPaginationAndOrderBy: ...


Expand Down Expand Up @@ -182,7 +182,7 @@ class FruitUpsertConflictFields: ...
# Color


@strawchemy.type(Color, include="all", override=True, child_order_by=True)
@strawchemy.type(Color, include="all", override=True, order="all")
class ColorType: ...


Expand All @@ -194,7 +194,7 @@ class ColorOrder: ...
class ColorDistinctOn: ...


@strawchemy.type(Color, include="all", child_pagination=True)
@strawchemy.type(Color, include="all", paginate="all")
class ColorTypeWithPagination: ...


Expand Down
6 changes: 3 additions & 3 deletions tests/integration/types/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class OrderedFruitType: ...
class FruitAggregationType: ...


@strawchemy.type(Fruit, include="all", child_pagination=True, child_order_by=True)
@strawchemy.type(Fruit, include="all", paginate="all", order="all")
class FruitTypeWithPaginationAndOrderBy: ...


Expand Down Expand Up @@ -190,7 +190,7 @@ class FruitUpsertConflictFields: ...
# Color


@strawchemy.type(Color, include="all", override=True, child_order_by=True)
@strawchemy.type(Color, include="all", override=True, order="all")
class ColorType: ...


Expand All @@ -202,7 +202,7 @@ class ColorOrder: ...
class ColorDistinctOn: ...


@strawchemy.type(Color, include="all", child_pagination=True)
@strawchemy.type(Color, include="all", paginate="all")
class ColorTypeWithPagination: ...


Expand Down
Loading
Loading