Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 61 additions & 2 deletions src/strawchemy/dto/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ class DTOConfig:
tags: set[str] = field(default_factory=set)

def __post_init__(self) -> None:
"""
Validate and normalize DTOConfig after initialization.

Performs consistency checks and normalizes dependent fields:
- Raises ValueError if both `aliases` and `alias_generator` are provided.
- Raises ValueError if `exclude` is non-empty while `include` is a specific collection (not `"all"` or empty).
- If `exclude` is non-empty, sets `include` to `"all"`.
"""
if self.aliases and self.alias_generator is not None:
msg = "You must set `aliases` or `alias_generator`, not both"
raise ValueError(msg)
Expand All @@ -187,6 +195,16 @@ def __post_init__(self) -> None:

@classmethod
def from_include(cls, include: IncludeFields | None = None, purpose: Purpose = Purpose.READ) -> Self:
"""
Create a DTOConfig initialized with the given inclusion set and purpose.

Parameters:
include (IncludeFields | None): Fields to include for the DTO. If None, the include set will be an empty set.
purpose (Purpose): The purpose that the resulting DTOConfig will use.

Returns:
A DTOConfig instance with `purpose` set to the provided value and `include` normalized to the provided collection or an empty set when `include` is None.
"""
return cls(purpose, include=set() if include is None else include)

def copy_with(
Expand All @@ -206,7 +224,30 @@ def copy_with(
exclude_from_scope: bool | type[DTOUnset] = DTOUnset,
tags: set[str] | type[DTOUnset] = DTOUnset,
) -> DTOConfig:
"""Create a copy of the DTOConfig with the specified changes."""
"""
Return a new DTOConfig with the provided fields replaced.

If a parameter is passed as the DTOUnset sentinel (the default for most parameters), the corresponding value from the original DTOConfig is retained. If both `include` and `exclude` are None, their values are copied from the original config; otherwise `include`/`exclude` are normalized to empty collections when falsy.

Parameters:
purpose: The purpose to set on the new config or DTOUnset to keep the original.
include: Field names to include, the literal "all", or None to use original include behavior.
exclude: Field names to exclude, or None to use original exclude behavior.
partial: Whether the DTO is partial, or DTOUnset to keep original.
unset_sentinel: Sentinel value used to represent unset fields, or DTOUnset to keep original.
type_overrides: Mapping of type overrides or DTOUnset to keep original.
annotation_overrides: Mapping of annotation overrides or DTOUnset to keep original.
aliases: Mapping of field name -> alias or DTOUnset to keep original.
exclude_defaults: Whether to exclude default-valued fields or DTOUnset to keep original.
alias_generator: Callable to generate aliases or DTOUnset to keep original.
partial_default: Default value for partial fields or DTOUnset to keep original.
scope: DTO scope ("global" or "dto") or DTOUnset to keep original.
exclude_from_scope: Whether to exclude fields from scope or DTOUnset to keep original.
tags: Set of tags to apply or DTOUnset to keep original.

Returns:
DTOConfig: A new configuration object reflecting the specified overrides.
"""
if include is None and exclude is None:
include, exclude = self.include, self.exclude
else:
Expand Down Expand Up @@ -264,11 +305,29 @@ def with_base_annotations(self, base: type[Any]) -> DTOConfig:
)

def alias(self, name: str) -> str | None:
"""
Resolve the configured external name (alias) for a DTO field.

Parameters:
name (str): The DTO field name to resolve.

Returns:
str | None: The alias for `name` if configured; prefers an explicit entry in `aliases` over `alias_generator`. Returns `None` if no alias is configured.
"""
if self.aliases:
return self.aliases.get(name)
if self.alias_generator is not None:
return self.alias_generator(name)
return None

def is_field_included(self, name: str) -> bool:
return (name in self.include or self.include == "all") and name not in self.exclude
"""
Determine whether a given field name is selected by the config's include/exclude rules.

Parameters:
name (str): The field name to check.

Returns:
`true` if the field is included (explicitly or because `include` is `"all"`) and not excluded, `false` otherwise.
"""
return (name in self.include or self.include == "all") and name not in self.exclude
111 changes: 110 additions & 1 deletion src/strawchemy/schema/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,25 @@ def _type_info(
paginate: IncludeFields | None = None,
default_pagination: None | DefaultOffsetPagination = None,
) -> RegistryTypeInfo:
"""
Builds a RegistryTypeInfo describing how a DTO should be registered with the mapper.

Parameters:
dto: The DTO class to describe.
dto_config: Configuration for the DTO (scope, exclude_from_scope, purpose, etc.).
current_node: Optional registry node representing the DTO's position in a relationship graph; used to disambiguate names on conflicts.
override: If True, marks the resulting type as allowed to override existing registrations.
user_defined: If True, marks the type as provided by the user (affects conflict handling).
order: Optional set of child field names that should be treated as orderable for this type.
paginate: Optional set of child field names that should be treated as paginable for this type.
default_pagination: Optional pagination configuration to apply as the type's default pagination behavior.

Returns:
A RegistryTypeInfo populated with the DTO's registration metadata (name, default_name, graphql_type, scope,
model if mapped, pagination/order/paginate sets, override/user_defined flags, and exclude_from_scope).
If the mapper registry reports a name clash and `current_node` is provided, the returned `name` is replaced
with a path-derived unique name.
"""
graphql_type = self.graphql_type(dto_config)
model: type[DeclarativeBase] | None = dto.__dto_model__ if issubclass(dto, MappedStrawberryGraphQLDTO) else None # type: ignore[reportGeneralTypeIssues]
default_name = self.root_dto_name(model, dto_config, current_node) if model else dto.__name__
Expand Down Expand Up @@ -136,6 +155,24 @@ def _register_type(
paginate: IncludeFields | None = None,
default_pagination: None | DefaultOffsetPagination = None,
) -> type[StrawchemyDTOT]:
"""
Register a DTO type in the mapper registry with GraphQL-related metadata.

Parameters:
dto (type[StrawchemyDTOT]): The DTO class to register.
dto_config (DTOConfig): Configuration used to build the registry type info.
current_node (Node[Relation[Any, GraphQLDTOT], None] | None): Optional registry node representing the DTO's position in relation graphs; used when constructing type metadata.
description (str | None): Optional schema description to store for the type; if omitted the DTO's internal description is used.
directives (Sequence[object] | None): Optional GraphQL directives to attach to the registered type.
override (bool): If True, allow replacing an existing overridable type in the registry.
user_defined (bool): If True, mark the registered type as provided by user code.
order (IncludeFields | None): Optional set/list of child field names to expose for ordering; when None no order-by fields are added.
paginate (IncludeFields | None): Optional set/list of child field names to expose for pagination; when None pagination is not enabled for children.
default_pagination (DefaultOffsetPagination | None): Optional default pagination strategy to apply when pagination is enabled.

Returns:
type[StrawchemyDTOT]: The DTO class that was registered.
"""
type_info = self._type_info(
dto,
dto_config,
Expand Down Expand Up @@ -231,6 +268,36 @@ def _type_wrapper(
purpose: Purpose = Purpose.READ,
scope: DTOScope | None = None,
) -> Callable[[type[Any]], type[GraphQLDTOT]]:
"""
Create a decorator that builds and registers a GraphQL DTO class for the given model and GraphQL purpose.

The returned decorator accepts a user-defined base class and produces a Strawberry-compatible DTO configured with the provided include/exclude rules, mapping aliases, pagination and ordering settings, hooks, and GraphQL metadata.

Parameters:
model: The mapped domain model type the DTO represents.
mode: GraphQL purpose tag for the DTO (e.g., "type", "input", "create", "update"); used as the DTO's purpose and tag.
include: Fields to explicitly include from the model.
exclude: Fields to explicitly exclude from the model.
partial: If True, make fields optional to represent partial inputs.
type_map: Custom type mappings for field types.
aliases: Explicit field name aliases.
alias_generator: Callable to generate field aliases from names.
paginate: Fields eligible for pagination on child relations.
order: Fields eligible for ordering on child relations.
default_pagination: Default pagination strategy to attach to the type when pagination is enabled.
filter_input: DTO class used for filtering when the created DTO is mapped.
order_by: DTO class used for ordering when the created DTO is mapped.
name: Explicit GraphQL name for the generated type.
description: GraphQL description text for the generated type.
directives: Sequence of GraphQL directives to attach to the type.
query_hook: Single or list of query hooks to be attached to the created DTO to modify query behavior.
override: If True, allow registering a type that overrides an existing non-user-defined registered type.
purpose: Higher-level purpose enum guiding DTO creation (defaults to read).
scope: DTO scope to register the type under (e.g., "global" or "dto").

Returns:
A decorator that accepts a user base class and returns the constructed GraphQL DTO type.
"""
def wrapper(class_: type[Any]) -> type[GraphQLDTOT]:
dto_config = config(
purpose,
Expand Down Expand Up @@ -459,6 +526,34 @@ def type(
scope: TypeScope | None = None,
mode: GraphQLPurpose = "type",
) -> Callable[[type[Any]], type[MappedGraphQLDTO[T]]]:
"""
Create a decorator that builds a mapped GraphQL DTO for the given SQLAlchemy model with GraphQL-specific options.

Parameters:
model: The mapped SQLAlchemy model type to base the DTO on.
include: Fields to include in the DTO; defaults to None (use model defaults).
exclude: Fields to exclude from the DTO; defaults to None.
partial: If True, make fields optional for partial updates; defaults to None.
type_map: Custom type mapping for field types.
aliases: Explicit field name aliases.
alias_generator: Callable to generate field aliases from attribute names.
paginate: Fields that should support pagination when exposed as child relations.
order: Fields that should expose ordering controls when exposed as child relations.
default_pagination: Default pagination strategy to apply for paginated child fields.
filter_input: DTO class used for boolean/filter input generation.
order_by: DTO class used for order-by input generation.
name: Explicit GraphQL type name override.
description: GraphQL type description.
directives: GraphQL directives to attach to the type.
query_hook: Hook or list of hooks to modify queries for this DTO.
override: If True, allow overriding an existing registered type with the same name.
purpose: Purpose of the DTO (read/write); defaults to read.
scope: TypeScope to convert into a DTOScope for registration.
mode: GraphQL purpose mode ("type", "input", etc.) controlling DTO shape.

Returns:
A decorator that, when applied to a class, returns the constructed mapped GraphQL DTO type.
"""
return self._type_wrapper(
model=model,
include=include,
Expand Down Expand Up @@ -611,6 +706,20 @@ def type(
purpose: Purpose = Purpose.READ,
mode: GraphQLPurpose = "type",
) -> Callable[[type[Any]], type[UnmappedStrawberryGraphQLDTO[T]]]:
"""
Create a decorator that builds an unmapped Strawberry GraphQL DTO for the given model.

Parameters:
paginate (IncludeFields | None): Fields eligible for per-child pagination; if `None`, no per-field pagination is configured.
order (IncludeFields | None): Fields eligible for per-child ordering; if `None`, no per-field ordering is configured.
default_pagination (DefaultOffsetPagination | None): Default pagination strategy to attach to the created DTO when pagination is enabled.
mode (GraphQLPurpose): GraphQL role for the generated type (e.g., `"type"` for object types or input modes for write operations).
purpose (Purpose): Overall DTO purpose (read vs write) which influences GraphQL kind and available behaviors.
query_hook (QueryHook[T] | list[QueryHook[T]] | None): Optional hook(s) applied to queries produced for the DTO.

Returns:
decorator: A class decorator that, when applied, produces an unmapped Strawberry-based GraphQL DTO configured with the provided options.
"""
return self._type_wrapper(
model=model,
include=include,
Expand All @@ -631,4 +740,4 @@ def type(
override=override,
purpose=purpose,
mode=mode,
)
)
21 changes: 20 additions & 1 deletion src/strawchemy/schema/factories/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,25 @@ def iter_field_definitions(
field_map: dict[DTOKey, GraphQLFieldDefinition] | None = None,
**kwargs: Any,
) -> Generator[DTOFieldDefinition[DeclarativeBase, QueryableAttribute[Any]]]:
"""
Yield field definitions for the DTO being built, adjusting types for relational fields, adding aggregation sub-fields when requested, and marking fields as unset/missing by default.

If a field is a relation, its type is widened to allow None; if it is a scalar field, its type is set to Optional of the appropriate comparison/order type. When `aggregate_filters` is True, an additional aggregation field is created for relation fields, registered into `field_map` under a key composed from the current DTO node plus the aggregation field name, and yielded before the original relation field. Each yielded field has its default set to `UNSET` and its `default_factory` set to `DTOMissing`.

Parameters:
name: The DTO name being constructed.
model: The ORM model class the DTO is derived from.
dto_config: Configuration used to build nested DTOs; for aggregation fields this function copies the config with `partial=True` and `partial_default=UNSET`.
base: Optional base DTO class for inheritance.
node: The current DTO construction node, used to derive DTOKey entries.
raise_if_no_fields: If True, raises when no fields are produced (propagated to super).
aggregate_filters: If True, produce and yield aggregation sub-fields for relation fields.
field_map: Mutable mapping where newly created aggregation fields will be registered; if None, an internal map is used.
**kwargs: Forwarded to the superclass implementation.

@returns:
Generator that yields DTOFieldDefinition objects for each field in the DTO, including any generated aggregation fields when `aggregate_filters` is True.
"""
field_map = field_map if field_map is not None else {}
for field in super().iter_field_definitions(
name, model, dto_config, base, node, raise_if_no_fields, field_map=field_map, **kwargs
Expand Down Expand Up @@ -472,4 +491,4 @@ def factory(
**kwargs,
)
dto.__strawchemy_description__ = "Ordering options"
return dto
return dto
Loading