Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,183 changes: 630 additions & 553 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/testapp/testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from datetime import datetime, timezone
from uuid import UUID, uuid4

from sqlalchemy import Column, DateTime, ForeignKey, MetaData, Table
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

from sqlalchemy import Column, DateTime, ForeignKey, MetaData, Table
from strawchemy.dto.utils import READ_ONLY

UTC = timezone.utc
Expand Down
1 change: 1 addition & 0 deletions examples/testapp/testapp/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING

import strawberry

from strawchemy.validation.pydantic import PydanticValidation
from testapp.types import (
CustomerCreate,
Expand Down
21 changes: 11 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ footer = """
trim = true
# postprocessors
postprocessors = [
# { pattern = '<REPO>', replace = "https://github.com/orhun/git-cliff" }, # replace repository URL
# { pattern = '<REPO>', replace = "https://github.com/orhun/git-cliff" }, # replace strawberry URL
]

# render body even when there are no releases to process
Expand Down Expand Up @@ -403,18 +403,19 @@ update_docstrings = true
cache = true

[tool.unasyncd.files]
"src/strawchemy/sqlalchemy/repository/_async.py" = "src/strawchemy/sqlalchemy/repository/_sync.py"
"src/strawchemy/strawberry/repository/_async.py" = "src/strawchemy/strawberry/repository/_sync.py"
"src/strawchemy/repository/sqlalchemy/_async.py" = "src/strawchemy/repository/sqlalchemy/_sync.py"
"src/strawchemy/repository/strawberry/_async.py" = "src/strawchemy/repository/strawberry/_sync.py"

[tool.unasyncd.per_file_add_replacements."src/strawchemy/sqlalchemy/repository/_async.py"]
"strawchemy.sqlalchemy._executor.AsyncQueryExecutor" = "strawchemy.sqlalchemy._executor.SyncQueryExecutor"
[tool.unasyncd.per_file_add_replacements."src/strawchemy/repository/sqlalchemy/_async.py"]
"strawchemy.transpiler.AsyncQueryExecutor" = "strawchemy.transpiler.SyncQueryExecutor"
SQLAlchemyGraphQLAsyncRepository = "SQLAlchemyGraphQLSyncRepository"
"strawchemy.sqlalchemy.typing.AnyAsyncSession" = "strawchemy.sqlalchemy.typing.AnySyncSession"
"strawchemy.repository.typing.AnyAsyncSession" = "strawchemy.repository.typing.AnySyncSession"

[tool.unasyncd.per_file_add_replacements."src/strawchemy/strawberry/repository/_async.py"]
"strawchemy.sqlalchemy.repository.SQLAlchemyGraphQLAsyncRepository" = "strawchemy.sqlalchemy.repository.SQLAlchemyGraphQLSyncRepository"
"strawchemy.sqlalchemy.typing.AnyAsyncSession" = "strawchemy.sqlalchemy.typing.AnySyncSession"
"strawchemy.strawberry.typing.AsyncSessionGetter" = "strawchemy.strawberry.typing.SyncSessionGetter"
[tool.unasyncd.per_file_add_replacements."src/strawchemy/repository/strawberry/_async.py"]
"strawchemy.repository.strawberry.base.IS_ASYNC_REPOSITORY" = "strawchemy.repository.strawberry.base.IS_SYNC_REPOSITORY"
"strawchemy.repository.sqlalchemy.SQLAlchemyGraphQLAsyncRepository" = "strawchemy.repository.sqlalchemy.SQLAlchemyGraphQLSyncRepository"
"strawchemy.repository.typing.AnyAsyncSession" = "strawchemy.repository.typing.AnySyncSession"
"strawchemy.repository.typing.AsyncSessionGetter" = "strawchemy.repository.typing.SyncSessionGetter"
StrawchemyAsyncRepository = "StrawchemySyncRepository"

[tool.uv]
Expand Down
Binary file added src/.DS_Store
Binary file not shown.
Binary file added src/strawchemy/.DS_Store
Binary file not shown.
14 changes: 7 additions & 7 deletions src/strawchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
from __future__ import annotations

from strawchemy.config.base import StrawchemyConfig
from strawchemy.instance import ModelInstance
from strawchemy.mapper import Strawchemy
from strawchemy.sqlalchemy.hook import QueryHook
from strawchemy.strawberry import ModelInstance
from strawchemy.strawberry.mutation.input import Input
from strawchemy.strawberry.mutation.types import (
ErrorType,
from strawchemy.repository.strawberry import StrawchemyAsyncRepository, StrawchemySyncRepository
from strawchemy.schema.interfaces import ErrorType
from strawchemy.schema.mutation import (
Input,
RequiredToManyUpdateInput,
RequiredToOneInput,
ToManyCreateInput,
ToManyUpdateInput,
ToOneInput,
ValidationErrorType,
)
from strawchemy.strawberry.repository import StrawchemyAsyncRepository, StrawchemySyncRepository
from strawchemy.validation.base import InputValidationError
from strawchemy.transpiler.hook import QueryHook
from strawchemy.validation import InputValidationError

__all__ = (
"ErrorType",
Expand Down
16 changes: 16 additions & 0 deletions src/strawchemy/__metadata__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Metadata for the Project."""

from importlib.metadata import PackageNotFoundError, metadata, version # pragma: no cover

__all__ = ("__project__", "__version__") # pragma: no cover

try: # pragma: no cover
__version__ = version("strawchemy")
"""Version of the project."""
__project__ = metadata("strawchemy")["Name"]
"""Name of the project."""
except PackageNotFoundError: # pragma: no cover
__version__ = "0.0.1"
__project__ = "strawchemy"
finally: # pragma: no cover
del version, PackageNotFoundError, metadata
21 changes: 9 additions & 12 deletions src/strawchemy/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,15 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from strawchemy.sqlalchemy.inspector import SQLAlchemyGraphQLInspector
from strawchemy.strawberry import default_session_getter
from strawchemy.strawberry.repository import StrawchemySyncRepository
from strawchemy.dto.inspectors import SQLAlchemyGraphQLInspector
from strawchemy.repository.strawberry import StrawchemySyncRepository
from strawchemy.utils.strawberry import default_session_getter

if TYPE_CHECKING:
from typing import Any

from strawchemy.sqlalchemy.typing import FilterMap
from strawchemy.strawberry.typing import AnySessionGetter
from strawchemy.typing import AnyRepository, SupportedDialect
from strawchemy.repository.typing import AnySessionGetter, FilterMap
from strawchemy.typing import AnyRepositoryType, SupportedDialect


@dataclass
Expand All @@ -27,7 +24,7 @@ class StrawchemyConfig:
auto_snake_case: Automatically convert snake cased names to camel case.
repository_type: Repository class to use for auto resolvers.
filter_overrides: Override default filters with custom filters.
execution_options: SQLAlchemy execution options for repository operations.
execution_options: SQLAlchemy execution options for strawberry operations.
pagination_default_limit: Default pagination limit when `pagination=True`.
pagination: Enable/disable pagination on list resolvers.
default_id_field_name: Name for primary key fields arguments on primary key resolvers.
Expand All @@ -40,12 +37,12 @@ class StrawchemyConfig:
"""Function to retrieve SQLAlchemy session from strawberry `Info` object."""
auto_snake_case: bool = True
"""Automatically convert snake cased names to camel case"""
repository_type: AnyRepository = StrawchemySyncRepository
repository_type: AnyRepositoryType = StrawchemySyncRepository
"""Repository class to use for auto resolvers."""
filter_overrides: FilterMap | None = None
"""Override default filters with custom filters."""
execution_options: dict[str, Any] | None = None
"""SQLAlchemy execution options for repository operations."""
"""SQLAlchemy execution options for strawberry operations."""
pagination_default_limit: int = 100
"""Default pagination limit when `pagination=True`."""
pagination: bool = False
Expand Down
3 changes: 1 addition & 2 deletions src/strawchemy/config/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from strawchemy.exceptions import StrawchemyError

if TYPE_CHECKING:
from strawchemy.strawberry.typing import AggregationFunction
from strawchemy.typing import SupportedDialect
from strawchemy.typing import AggregationFunction, SupportedDialect


@dataclass(frozen=True)
Expand Down
22 changes: 11 additions & 11 deletions src/strawchemy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@

GEO_INSTALLED: bool = all(find_spec(package) is not None for package in ("geoalchemy2", "shapely"))

LIMIT_KEY = "limit"
OFFSET_KEY = "offset"
ORDER_BY_KEY = "order_by"
FILTER_KEY = "filter"
DISTINCT_ON_KEY = "distinct_on"
LIMIT_KEY: str = "limit"
OFFSET_KEY: str = "offset"
ORDER_BY_KEY: str = "order_by"
FILTER_KEY: str = "filter"
DISTINCT_ON_KEY: str = "distinct_on"

AGGREGATIONS_KEY = "aggregations"
NODES_KEY = "nodes"
AGGREGATIONS_KEY: str = "aggregations"
NODES_KEY: str = "nodes"

DATA_KEY = "data"
JSON_PATH_KEY = "path"
DATA_KEY: str = "data"
JSON_PATH_KEY: str = "path"

UPSERT_UPDATE_FIELDS = "update_fields"
UPSERT_CONFLICT_FIELDS = "conflict_fields"
UPSERT_UPDATE_FIELDS: str = "update_fields"
UPSERT_CONFLICT_FIELDS: str = "conflict_fields"
Binary file added src/strawchemy/dto/.DS_Store
Binary file not shown.
6 changes: 4 additions & 2 deletions src/strawchemy/dto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@

from __future__ import annotations

from strawchemy.dto.base import DTOFieldDefinition, ModelFieldT, ModelInspector, ModelT
from strawchemy.dto.base import DTOFieldDefinition, MappedDTO, ModelFieldT, ModelT, ToMappedProtocol, VisitorProtocol
from strawchemy.dto.types import DTOConfig, Purpose, PurposeConfig
from strawchemy.dto.utils import config, field

__all__ = (
"DTOConfig",
"DTOFieldDefinition",
"MappedDTO",
"ModelFieldT",
"ModelInspector",
"ModelT",
"Purpose",
"PurposeConfig",
"ToMappedProtocol",
"VisitorProtocol",
"config",
"field",
)
2 changes: 1 addition & 1 deletion src/strawchemy/dto/backend/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from strawchemy.dto.base import DTOBackend, DTOBase, DTOFieldDefinition, MappedDTO, ModelFieldT, ModelT
from strawchemy.dto.types import DTOMissing
from strawchemy.utils import get_annotations
from strawchemy.utils.annotation import get_annotations

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down
6 changes: 3 additions & 3 deletions src/strawchemy/dto/backend/strawberry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@
from types import new_class
from typing import TYPE_CHECKING, Any, TypeVar, get_origin

import strawberry
from strawberry.types.field import StrawberryField
from typing_extensions import override

import strawberry
from strawchemy.dto.base import DTOBackend, DTOBase, MappedDTO, ModelFieldT, ModelT
from strawchemy.dto.types import DTOMissing
from strawchemy.utils import get_annotations
from strawchemy.utils.annotation import get_annotations

if TYPE_CHECKING:
from collections.abc import Iterable

from strawchemy.dto.base import DTOFieldDefinition

__all__ = ("AnnotatedDTOT", "StrawberrryDTOBackend", "StrawberryDTO", "StrawberryDTO")
__all__ = ("AnnotatedDTOT", "MappedStrawberryDTO", "StrawberrryDTOBackend", "StrawberryDTO", "StrawberryDTO")

AnnotatedDTOT = TypeVar("AnnotatedDTOT", bound="StrawberryDTO[Any] | MappedStrawberryDTO[Any]")

Expand Down
69 changes: 20 additions & 49 deletions src/strawchemy/dto/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from types import new_class
from typing import (
TYPE_CHECKING,
Annotated,
ClassVar,
ForwardRef,
Generic,
Expand All @@ -20,15 +19,12 @@
TypeAlias,
TypeVar,
cast,
get_args,
get_origin,
get_type_hints,
runtime_checkable,
)

from typing_extensions import Self, override

from strawchemy.dto.exceptions import DTOError, EmptyDTOError
from strawchemy.dto.types import (
DTOAuto,
DTOConfig,
Expand All @@ -42,17 +38,32 @@
PurposeConfig,
)
from strawchemy.dto.utils import config
from strawchemy.graph import Node
from strawchemy.utils import is_type_hint_optional, non_optional_type_hint
from strawchemy.exceptions import DTOError, EmptyDTOError
from strawchemy.utils.annotation import is_type_hint_optional, non_optional_type_hint
from strawchemy.utils.graph import Node

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
from typing import Any

from strawchemy.dto.inspectors import ModelInspector


__all__ = (
"DTOBackend",
"DTOBase",
"DTOFactory",
"DTOFieldDefinition",
"MappedDTO",
"ModelFieldT",
"ModelT",
"Relation",
"ToMappedProtocol",
"ToMappedProtocolT",
"VisitorProtocol",
)

__all__ = ("DTOFactory", "DTOFieldDefinition", "MappedDTO", "ModelInspector")

T = TypeVar("T")
T = TypeVar("T", bound="Any")
DTOBaseT = TypeVar("DTOBaseT", bound="DTOBase[Any]")
ModelT = TypeVar("ModelT")
ToMappedProtocolT = TypeVar("ToMappedProtocolT", bound="ToMappedProtocol[Any]")
Expand Down Expand Up @@ -226,46 +237,6 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.model.__name__})"


class ModelInspector(Protocol, Generic[ModelT, ModelFieldT]):
def field_definitions(
self, model: type[Any], dto_config: DTOConfig
) -> Iterable[tuple[str, DTOFieldDefinition[ModelT, ModelFieldT]]]: ...

def id_field_definitions(
self, model: type[Any], dto_config: DTOConfig
) -> list[tuple[str, DTOFieldDefinition[ModelT, ModelFieldT]]]: ...

def field_definition(
self, model_field: ModelFieldT, dto_config: DTOConfig
) -> DTOFieldDefinition[ModelT, ModelFieldT]: ...

def get_type_hints(self, type_: type[Any], include_extras: bool = True) -> dict[str, Any]: ...

def relation_model(self, model_field: ModelFieldT) -> type[Any]: ...

def model_field_type(self, field_definition: DTOFieldDefinition[ModelT, ModelFieldT]) -> Any:
type_hint = (
field_definition.type_hint_override if field_definition.has_type_override else field_definition.type_hint
)
if get_origin(type_hint) is Annotated:
return get_args(type_hint)[0]
return non_optional_type_hint(type_hint)

def relation_cycle(
self, field: DTOFieldDefinition[Any, ModelFieldT], node: Node[Relation[ModelT, Any], None]
) -> bool: ...

def has_default(self, model_field: ModelFieldT) -> bool: ...

def required(self, model_field: ModelFieldT) -> bool: ...

def is_foreign_key(self, model_field: ModelFieldT) -> bool: ...

def is_primary_key(self, model_field: ModelFieldT) -> bool: ...

def reverse_relation_required(self, model_field: ModelFieldT) -> bool: ...


@dataclass(slots=True)
class DTOFieldDefinition(Generic[ModelT, ModelFieldT]):
dto_config: DTOConfig
Expand Down
12 changes: 0 additions & 12 deletions src/strawchemy/dto/exceptions.py

This file was deleted.

6 changes: 6 additions & 0 deletions src/strawchemy/dto/inspectors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from __future__ import annotations

from strawchemy.dto.inspectors.base import ModelInspector
from strawchemy.dto.inspectors.sqlalchemy import SQLAlchemyGraphQLInspector, SQLAlchemyInspector

__all__ = ("ModelInspector", "SQLAlchemyGraphQLInspector", "SQLAlchemyInspector")
Loading