diff --git a/README.md b/README.md index 8249b3d9..e1c24392 100644 --- a/README.md +++ b/README.md @@ -23,13 +23,17 @@ Generates GraphQL types, inputs, queries and resolvers directly from SQLAlchemy - ⚡ **Sync/Async**: Works with both sync and async SQLAlchemy sessions - 🛢 **Supported databases**: - - PostgreSQL (using [asyncpg](https://github.com/MagicStack/asyncpg) or [psycopg3 sync/async](https://www.psycopg.org/psycopg3/)) - - MySQL (using [asyncmy](https://github.com/long2ice/asyncmy)) - - SQLite (using [aiosqlite](https://aiosqlite.omnilib.dev/en/stable/) or [sqlite](https://docs.python.org/3/library/sqlite3.html)) + - PostgreSQL (using [asyncpg](https://github.com/MagicStack/asyncpg) + or [psycopg3 sync/async](https://www.psycopg.org/psycopg3/)) + - MySQL (using [asyncmy](https://github.com/long2ice/asyncmy)) + - SQLite (using [aiosqlite](https://aiosqlite.omnilib.dev/en/stable/) + or [sqlite](https://docs.python.org/3/library/sqlite3.html)) > [!Warning] > -> Please note that strawchemy is currently in a pre-release stage of development. This means that the library is still under active development and the initial API is subject to change. We encourage you to experiment with strawchemy and provide feedback, but be sure to pin and update carefully until a stable release is available. +> Please note that strawchemy is currently in a pre-release stage of development. This means that the library is still +> under active development and the initial API is subject to change. We encourage you to experiment with strawchemy and +> provide feedback, but be sure to pin and update carefully until a stable release is available. ## Table of Contents @@ -142,44 +146,46 @@ class Query: users: list[UserType] = strawchemy.field(filter_input=UserFilter, order_by=UserOrderBy, pagination=True) posts: list[PostType] = strawchemy.field(filter_input=PostFilter, order_by=PostOrderBy, pagination=True) + # Create schema schema = strawberry.Schema(query=Query) ``` ```graphql { - # Users with pagination, filtering, and ordering - users( - offset: 0 - limit: 10 - filter: { name: { contains: "John" } } - orderBy: { name: ASC } - ) { - id - name - posts { - id - title - content + # Users with pagination, filtering, and ordering + users( + offset: 0 + limit: 10 + filter: { name: { contains: "John" } } + orderBy: { name: ASC } + ) { + id + name + posts { + id + title + content + } } - } - # Posts with exact title match - posts(filter: { title: { eq: "Introduction to GraphQL" } }) { - id - title - content - author { - id - name + # Posts with exact title match + posts(filter: { title: { eq: "Introduction to GraphQL" } }) { + id + title + content + author { + id + name + } } - } } ``` ## Mapping SQLAlchemy Models -Strawchemy provides an easy way to map SQLAlchemy models to GraphQL types using the `@strawchemy.type` decorator. You can include/exclude specific fields or have strawchemy map all columns/relationships of the model and it's children. +Strawchemy provides an easy way to map SQLAlchemy models to GraphQL types using the `@strawchemy.type` decorator. You +can include/exclude specific fields or have strawchemy map all columns/relationships of the model and it's children.
Mapping example @@ -200,6 +206,7 @@ strawchemy = Strawchemy("postgresql") class Base(DeclarativeBase): pass + class User(Base): __tablename__ = "user" id: Mapped[int] = mapped_column(primary_key=True) @@ -246,6 +253,7 @@ Add a custom fields ```python from strawchemy import ModelInstance + class User(Base): __tablename__ = "user" @@ -269,7 +277,8 @@ See the [custom resolvers](#custom-resolvers) for more details ### Type Override -When generating types for relationships, Strawchemy creates default names (e.g., `Type`). If you have already defined a Python class with that same name, it will cause a name collision. +When generating types for relationships, Strawchemy creates default names (e.g., `Type`). If you have already +defined a Python class with that same name, it will cause a name collision. The `override=True` parameter tells Strawchemy that your definition should be used, resolving the conflict. @@ -284,6 +293,7 @@ class Author(Base): id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] + class Book(Base): __tablename__ = "book" id: Mapped[int] = mapped_column(primary_key=True) @@ -292,7 +302,9 @@ class Book(Base): author: Mapped[Author] = relationship() ``` -If you define a type for `Book`, Strawchemy will inspect the `author` relationship and attempt to auto-generate a type for the `Author` model, naming it `AuthorType` by default. If you have already defined a class with that name, it will cause a name collision. +If you define a type for `Book`, Strawchemy will inspect the `author` relationship and attempt to auto-generate a type +for the `Author` model, naming it `AuthorType` by default. If you have already defined a class with that name, it will +cause a name collision. ```python # Let's say you've already defined this class @@ -300,6 +312,7 @@ If you define a type for `Book`, Strawchemy will inspect the `author` relationsh class BookType: pass + # This will cause an error because Strawchemy has already created `AuthorType` when generating `BookType` @strawchemy.type(Book, include="all") class AuthorType: @@ -308,13 +321,15 @@ class AuthorType: You would see an error like: `Type 'AuthorType' cannot be auto generated because it's already declared.` -To solve this, you can create a single, definitive `AuthorType` and mark it with `override=True`. This tells Strawchemy to use your version instead of generating a new one. +To solve this, you can create a single, definitive `AuthorType` and mark it with `override=True`. This tells Strawchemy +to use your version instead of generating a new one. ```python @strawchemy.type(Author, include="all", override=True) class AuthorType: pass + # Now this works, because Strawchemy knows to use your `AuthorType` @strawchemy.type(Book, include="all") class BookType: @@ -327,12 +342,15 @@ class BookType: While `override=True` solves name collisions, `scope="global"` is used to promote consistency and reuse. -By defining a type with `scope="global"`, you register it as the canonical type for a given SQLAlchemy model and purpose (e.g. a strawberry `type`, `filter`, or `input`). Strawchemy will then automatically use this globally-scoped type everywhere it's needed in your schema, rather than generating new ones. +By defining a type with `scope="global"`, you register it as the canonical type for a given SQLAlchemy model and +purpose (e.g. a strawberry `type`, `filter`, or `input`). Strawchemy will then automatically use this globally-scoped +type everywhere it's needed in your schema, rather than generating new ones.
Using `scope="global"` -Let's define a global type for the `Color` model. This type will now be the default for the `Color` model across the entire schema. +Let's define a global type for the `Color` model. This type will now be the default for the `Color` model across the +entire schema. ```python # This becomes the canonical type for the `Color` model @@ -340,6 +358,7 @@ Let's define a global type for the `Color` model. This type will now be the defa class ColorType: pass + # Another type that references the Color model @strawchemy.type(Fruit, include="all") class FruitType: @@ -348,13 +367,15 @@ class FruitType: # without needing an explicit annotation. ``` -This ensures that the `Color` model is represented consistently as `ColorType` in all parts of your GraphQL schema, such as in the `FruitType`'s `color` field, without needing to manually specify it every time. +This ensures that the `Color` model is represented consistently as `ColorType` in all parts of your GraphQL schema, such +as in the `FruitType`'s `color` field, without needing to manually specify it every time.
## Resolver Generation -Strawchemy automatically generates resolvers for your GraphQL fields. You can use the `strawchemy.field()` function to generate fields that query your database +Strawchemy automatically generates resolvers for your GraphQL fields. You can use the `strawchemy.field()` function to +generate fields that query your database
Resolvers example @@ -372,12 +393,15 @@ class Query:
-While Strawchemy automatically generates resolvers for most use cases, you can also create custom resolvers for more complex scenarios. There are two main approaches to creating custom resolvers: +While Strawchemy automatically generates resolvers for most use cases, you can also create custom resolvers for more +complex scenarios. There are two main approaches to creating custom resolvers: ### Using Repository Directly -When using `strawchemy.field()` as a function, strawchemy creates a resolver that delegates data fetching to the `StrawchemySyncRepository` or `StrawchemyAsyncRepository` classes depending on the SQLAlchemy session type. -You can create custom resolvers by using the `@strawchemy.field` as a decorator and working directly with the repository: +When using `strawchemy.field()` as a function, strawchemy creates a resolver that delegates data fetching to the +`StrawchemySyncRepository` or `StrawchemyAsyncRepository` classes depending on the SQLAlchemy session type. +You can create custom resolvers by using the `@strawchemy.field` as a decorator and working directly with the +repository:
Custom resolvers using repository @@ -386,18 +410,19 @@ You can create custom resolvers by using the `@strawchemy.field` as a decorator from sqlalchemy import select, true from strawchemy import StrawchemySyncRepository + @strawberry.type class Query: @strawchemy.field def red_color(self, info: strawberry.Info) -> ColorType: - # Create a repository with a predefined filter + # Create a strawberry with a predefined filter repo = StrawchemySyncRepository(ColorType, info, filter_statement=select(Color).where(Color.name == "Red")) # Return a single result (will raise an exception if not found) return repo.get_one().graphql_type() @strawchemy.field def get_color_by_name(self, info: strawberry.Info, color: str) -> ColorType | None: - # Create a repository with a custom filter statement + # Create a strawberry with a custom filter statement repo = StrawchemySyncRepository(ColorType, info, filter_statement=select(Color).where(Color.name == color)) # Return a single result or None if not found return repo.get_one_or_none().graphql_type_or_none() @@ -420,6 +445,7 @@ For async resolvers, use `StrawchemyAsyncRepository` which is the async variant ```python from strawchemy import StrawchemyAsyncRepository + @strawberry.type class Query: @strawchemy.field @@ -439,7 +465,8 @@ The repository provides several methods for fetching data: ### Query Hooks -Strawchemy provides query hooks that allow you to customize query behavior. Query hooks give you fine-grained control over how SQL queries are constructed and executed. +Strawchemy provides query hooks that allow you to customize query behavior. Query hooks give you fine-grained control +over how SQL queries are constructed and executed.
Using query hooks @@ -448,13 +475,16 @@ The `QueryHook` base class provides several methods that you can override to cus #### Modifying the statement -You can subclass `QueryHook` and override the `apply_hook` method apply changes to the statement. By default, it returns it unchanged. This method is only for filtering or ordering customizations, if you want to explicitly load columns or relationships, use the `load` parameter instead. +You can subclass `QueryHook` and override the `apply_hook` method apply changes to the statement. By default, it returns +it unchanged. This method is only for filtering or ordering customizations, if you want to explicitly load columns or +relationships, use the `load` parameter instead. ```python from strawchemy import ModelInstance, QueryHook from sqlalchemy import Select, select from sqlalchemy.orm.util import AliasedClass + # Define a model and type class Fruit(Base): __tablename__ = "fruit" @@ -462,6 +492,7 @@ class Fruit(Base): name: Mapped[str] adjectives: Mapped[list[str]] = mapped_column(ARRAY(String)) + # Apply the hook at the field level @strawchemy.type(Fruit, exclude={"color"}) class FruitTypeWithDescription: @@ -472,12 +503,14 @@ class FruitTypeWithDescription: def description(self) -> str: return f"The {self.instance.name} is {', '.join(self.instance.adjectives)}" + # Create a custom query hook for filtering class FilterFruitHook(QueryHook[Fruit]): def apply_hook(self, statement: Select[tuple[Fruit]], alias: AliasedClass[Fruit]) -> Select[tuple[Fruit]]: # Add a custom WHERE clause return statement.where(alias.name == "Apple") + # Apply the hook at the type level @strawchemy.type(Fruit, exclude={"color"}, query_hook=FilterFruitHook()) class FilteredFruitType: @@ -486,14 +519,16 @@ class FilteredFruitType: Important notes when implementing `apply_hooks`: -- You must use the provided `alias` parameter to refer to columns of the model on which the hook is applied. Otherwise, the statement may fail. +- You must use the provided `alias` parameter to refer to columns of the model on which the hook is applied. Otherwise, + the statement may fail. - The GraphQL context is available through `self.info` within hook methods. - You must set a `ModelInstance` typed attribute if you want to access the model instance values. The `instance` attribute is matched by the `ModelInstance[Fruit]` type hint, so you can give it any name you want. #### Load specific columns/relationships -The `load` parameter specify columns and relationships that should always be loaded, even if not directly requested in the GraphQL query. This is useful for: +The `load` parameter specify columns and relationships that should always be loaded, even if not directly requested in +the GraphQL query. This is useful for: - Ensuring data needed for computed properties is available - Loading columns or relationships required for custom resolvers @@ -506,16 +541,19 @@ Examples of using the `load` parameter: def description(self) -> str: return f"The {self.instance.name} is {', '.join(self.instance.adjectives)}" + # Load a relationship without specifying columns @strawchemy.field(query_hook=QueryHook(load=[Fruit.farms])) def pretty_farms(self) -> str: return f"Farms are: {', '.join(farm.name for farm in self.instance.farms)}" + # Load a relationship with specific columns @strawchemy.field(query_hook=QueryHook(load=[(Fruit.color, [Color.name, Color.created_at])])) def pretty_color(self) -> str: return f"Color is {self.instance.color.name}" if self.instance.color else "No color!" + # Load nested relationships @strawchemy.field(query_hook=QueryHook(load=[(Color.fruits, [(Fruit.farms, [FruitFarm.name])])])) def farms(self) -> str: @@ -534,7 +572,8 @@ Strawchemy supports offset-based pagination out of the box. Enable pagination on fields: ```python -from strawchemy.types import DefaultOffsetPagination +from strawchemy.schema.pagination import DefaultOffsetPagination + @strawberry.type class Query: @@ -548,10 +587,10 @@ In your GraphQL queries, you can use the `offset` and `limit` parameters: ```graphql { - users(offset: 0, limit: 10) { - id - name - } + users(offset: 0, limit: 10) { + id + name + } } ``` @@ -567,14 +606,14 @@ Then in your GraphQL queries: ```graphql { - users { - id - name - posts(offset: 0, limit: 5) { - id - title + users { + id + name + posts(offset: 0, limit: 5) { + id + title + } } - } } ``` @@ -607,53 +646,53 @@ Now you can use various filter operations in your GraphQL queries: ```graphql { - # Equality filter - users(filter: { name: { eq: "John" } }) { - id - name - } + # Equality filter + users(filter: { name: { eq: "John" } }) { + id + name + } - # Comparison filters - users(filter: { age: { gt: 18, lte: 30 } }) { - id - name - age - } + # Comparison filters + users(filter: { age: { gt: 18, lte: 30 } }) { + id + name + age + } - # String filters - users(filter: { name: { contains: "oh", ilike: "%OHN%" } }) { - id - name - } + # String filters + users(filter: { name: { contains: "oh", ilike: "%OHN%" } }) { + id + name + } - # Logical operators - users(filter: { _or: [{ name: { eq: "John" } }, { name: { eq: "Jane" } }] }) { - id - name - } - # Nested filters - users(filter: { posts: { title: { contains: "GraphQL" } } }) { - id - name - posts { - id - title + # Logical operators + users(filter: { _or: [{ name: { eq: "John" } }, { name: { eq: "Jane" } }] }) { + id + name + } + # Nested filters + users(filter: { posts: { title: { contains: "GraphQL" } } }) { + id + name + posts { + id + title + } } - } - # Compare interval component - tasks(filter: { duration: { days: { gt: 2 } } }) { - id - name - duration - } + # Compare interval component + tasks(filter: { duration: { days: { gt: 2 } } }) { + id + name + duration + } - # Direct interval comparison - tasks(filter: { duration: { gt: "P2DT5H" } }) { - id - name - duration - } + # Direct interval comparison + tasks(filter: { duration: { gt: "P2DT5H" } }) { + id + name + duration + } } ``` @@ -662,7 +701,7 @@ Now you can use various filter operations in your GraphQL queries: Strawchemy supports a wide range of filter operations: | Data Type/Category | Filter Operations | -| --------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +|-----------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | **Common to most types** | `eq`, `neq`, `isNull`, `in`, `nin` | | **Numeric types (Int, Float, Decimal)** | `gt`, `gte`, `lt`, `lte` | | **String** | order filter, plus `like`, `nlike`, `ilike`, `nilike`, `regexp`, `iregexp`, `nregexp`, `inregexp`, `startswith`, `endswith`, `contains`, `istartswith`, `iendswith`, `icontains` | @@ -676,7 +715,9 @@ Strawchemy supports a wide range of filter operations: ### Geo Filters -Strawchemy supports spatial filtering capabilities for geometry fields using [GeoJSON](https://datatracker.ietf.org/doc/html/rfc7946). To use geo filters, you need to have PostGIS installed and enabled in your PostgreSQL database. +Strawchemy supports spatial filtering capabilities for geometry fields +using [GeoJSON](https://datatracker.ietf.org/doc/html/rfc7946). To use geo filters, you need to have PostGIS installed +and enabled in your PostgreSQL database.
Geo filters example @@ -692,15 +733,18 @@ class GeoModel(Base): point: Mapped[WKBElement | None] = mapped_column(Geometry("POINT", srid=4326), nullable=True) polygon: Mapped[WKBElement | None] = mapped_column(Geometry("POLYGON", srid=4326), nullable=True) + @strawchemy.type(GeoModel, include="all") class GeoType: ... + @strawchemy.filter(GeoModel, include="all") class GeoFieldsFilter: ... + @strawberry.type class Query: -geo: list[GeoType] = strawchemy.field(filter_input=GeoFieldsFilter) + geo: list[GeoType] = strawchemy.field(filter_input=GeoFieldsFilter) ``` @@ -708,35 +752,35 @@ Then you can use the following geo filter operations in your GraphQL queries: ```graphql { - # Find geometries that contain a point - geo( - filter: { - polygon: { containsGeometry: { type: "Point", coordinates: [0.5, 0.5] } } + # Find geometries that contain a point + geo( + filter: { + polygon: { containsGeometry: { type: "Point", coordinates: [0.5, 0.5] } } + } + ) { + id + polygon } - ) { - id - polygon - } - # Find geometries that are within a polygon - geo( - filter: { - point: { - withinGeometry: { - type: "Polygon" - coordinates: [[[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]]] + # Find geometries that are within a polygon + geo( + filter: { + point: { + withinGeometry: { + type: "Polygon" + coordinates: [[[0, 0], [0, 2], [2, 2], [2, 0], [0, 0]]] + } + } } - } + ) { + id + point } - ) { - id - point - } - # Find records with null geometry - geo(filter: { point: { isNull: true } }) { - id - } + # Find records with null geometry + geo(filter: { point: { isNull: true } }) { + id + } } ``` @@ -762,7 +806,8 @@ These filters work with all geometry types supported by PostGIS, including: Strawchemy automatically exposes aggregation fields for list relationships. -When you define a model with a list relationship, the corresponding GraphQL type will include an aggregation field for that relationship, named `Aggregate`. +When you define a model with a list relationship, the corresponding GraphQL type will include an aggregation field for +that relationship, named `Aggregate`.
Basic aggregation example: @@ -803,20 +848,20 @@ You can query aggregations on the `posts` relationship: ```graphql { - users { - id - name - postsAggregate { - count - min { - title - } - max { - title - } - # Other aggregation functions are also available + users { + id + name + postsAggregate { + count + min { + title + } + max { + title + } + # Other aggregation functions are also available + } } - } } ``` @@ -846,17 +891,17 @@ For example, to find users who have more than 5 posts: ```graphql { - users( - filter: { - postsAggregate: { count: { arguments: [id], predicate: { gt: 5 } } } - } - ) { - id - name - postsAggregate { - count + users( + filter: { + postsAggregate: { count: { arguments: [id], predicate: { gt: 5 } } } + } + ) { + id + name + postsAggregate { + count + } } - } } ``` @@ -865,32 +910,32 @@ You can use various predicates for filtering: ```graphql # Users with exactly 3 posts users(filter: { - postsAggregate: { - count: { - arguments: [id] - predicate: { eq: 3 } - } - } +postsAggregate: { +count: { +arguments: [id] +predicate: { eq: 3 } +} +} }) # Users with posts containing "GraphQL" in the title users(filter: { - postsAggregate: { - maxString: { - arguments: [title] - predicate: { contains: "GraphQL" } - } - } +postsAggregate: { +maxString: { +arguments: [title] +predicate: { contains: "GraphQL" } +} +} }) # Users with an average post length greater than 1000 characters users(filter: { - postsAggregate: { - avg: { - arguments: [contentLength] - predicate: { gt: 1000 } - } - } +postsAggregate: { +avg: { +arguments: [contentLength] +predicate: { gt: 1000 } +} +} }) ``` @@ -905,16 +950,16 @@ You can also use the `distinct` parameter to count only distinct values: ```graphql { - users( - filter: { - postsAggregate: { - count: { arguments: [category], predicate: { gt: 2 }, distinct: true } - } + users( + filter: { + postsAggregate: { + count: { arguments: [category], predicate: { gt: 2 }, distinct: true } + } + } + ) { + id + name } - ) { - id - name - } } ``` @@ -949,43 +994,43 @@ Now you can use aggregation functions on the result of your query: ```graphql { - usersAggregations { - aggregations { - # Basic aggregations - count - - sum { - age - } - - avg { - age - } - - min { - age - createdAt - } - max { - age - createdAt - } - - # Statistical aggregations - stddev { - age - } - variance { - age - } - } - # Access the actual data - nodes { - id - name - age + usersAggregations { + aggregations { + # Basic aggregations + count + + sum { + age + } + + avg { + age + } + + min { + age + createdAt + } + max { + age + createdAt + } + + # Statistical aggregations + stddev { + age + } + variance { + age + } + } + # Access the actual data + nodes { + id + name + age + } } - } } ``` @@ -993,7 +1038,8 @@ Now you can use aggregation functions on the result of your query: ## Mutations -Strawchemy provides a powerful way to create GraphQL mutations for your SQLAlchemy models. These mutations allow you to create, update, and delete data through your GraphQL API. +Strawchemy provides a powerful way to create GraphQL mutations for your SQLAlchemy models. These mutations allow you to +create, update, and delete data through your GraphQL API.
Mutations example @@ -1005,19 +1051,23 @@ from strawchemy import Strawchemy, StrawchemySyncRepository, StrawchemyAsyncRepo # Initialize the strawchemy mapper strawchemy = Strawchemy("postgresql") + # Define input types for mutations @strawchemy.input(User, include=["name", "email"]) class UserCreateInput: pass + @strawchemy.input(User, include=["id", "name", "email"]) class UserUpdateInput: pass + @strawchemy.filter(User, include="all") class UserFilter: pass + # Define GraphQL mutation fields @strawberry.type class Mutation: @@ -1034,6 +1084,7 @@ class Mutation: delete_users: list[UserType] = strawchemy.delete() # Delete all delete_users_filter: list[UserType] = strawchemy.delete(UserFilter) # Delete with filter + # Create schema with mutations schema = strawberry.Schema(query=Query, mutation=Mutation) ``` @@ -1058,6 +1109,7 @@ Create mutations allow you to insert new records into your database. Strawchemy class ColorCreateInput: pass + @strawberry.type class Mutation: # Single entity creation @@ -1072,18 +1124,18 @@ GraphQL usage: ```graphql # Create a single color mutation { - createColor(data: { name: "Purple" }) { - id - name - } + createColor(data: { name: "Purple" }) { + id + name + } } # Create multiple colors in one operation mutation { - createColors(data: [{ name: "Teal" }, { name: "Magenta" }]) { - id - name - } + createColors(data: [{ name: "Teal" }, { name: "Magenta" }]) { + id + name + } } ``` @@ -1114,55 +1166,55 @@ GraphQL usage: ```graphql # Set an existing relationship mutation { - createFruit( - data: { - name: "Apple" - adjectives: ["sweet", "crunchy"] - color: { set: { id: "123e4567-e89b-12d3-a456-426614174000" } } - } - ) { - id - name - color { - id - name + createFruit( + data: { + name: "Apple" + adjectives: ["sweet", "crunchy"] + color: { set: { id: "123e4567-e89b-12d3-a456-426614174000" } } + } + ) { + id + name + color { + id + name + } } - } } # Create a new related entity mutation { - createFruit( - data: { - name: "Banana" - adjectives: ["yellow", "soft"] - color: { create: { name: "Yellow" } } - } - ) { - id - name - color { - id - name + createFruit( + data: { + name: "Banana" + adjectives: ["yellow", "soft"] + color: { create: { name: "Yellow" } } + } + ) { + id + name + color { + id + name + } } - } } # Set relationship to null mutation { - createFruit( - data: { - name: "Strawberry" - adjectives: ["red", "sweet"] - color: { set: null } - } - ) { - id - name - color { - id + createFruit( + data: { + name: "Strawberry" + adjectives: ["red", "sweet"] + color: { set: null } + } + ) { + id + name + color { + id + } } - } } ``` @@ -1180,58 +1232,58 @@ GraphQL usage: ```graphql # Set existing to-many relationships mutation { - createColor( - data: { - name: "Red" - fruits: { set: [{ id: "123e4567-e89b-12d3-a456-426614174000" }] } - } - ) { - id - name - fruits { - id - name + createColor( + data: { + name: "Red" + fruits: { set: [{ id: "123e4567-e89b-12d3-a456-426614174000" }] } + } + ) { + id + name + fruits { + id + name + } } - } } # Add to existing to-many relationships mutation { - createColor( - data: { - name: "Green" - fruits: { add: [{ id: "123e4567-e89b-12d3-a456-426614174000" }] } - } - ) { - id - name - fruits { - id - name + createColor( + data: { + name: "Green" + fruits: { add: [{ id: "123e4567-e89b-12d3-a456-426614174000" }] } + } + ) { + id + name + fruits { + id + name + } } - } } # Create new related entities mutation { - createColor( - data: { - name: "Blue" - fruits: { - create: [ - { name: "Blueberry", adjectives: ["small", "blue"] } - { name: "Plum", adjectives: ["juicy", "purple"] } - ] - } - } - ) { - id - name - fruits { - id - name + createColor( + data: { + name: "Blue" + fruits: { + create: [ + { name: "Blueberry", adjectives: ["small", "blue"] } + { name: "Plum", adjectives: ["juicy", "purple"] } + ] + } + } + ) { + id + name + fruits { + id + name + } } - } } ``` @@ -1241,28 +1293,28 @@ You can create deeply nested relationships: ```graphql mutation { - createColor( - data: { - name: "White" - fruits: { - create: [ - { - name: "Grape" - adjectives: ["tangy", "juicy"] - farms: { create: [{ name: "Bio farm" }] } - } - ] - } - } - ) { - name - fruits { - name - farms { + createColor( + data: { + name: "White" + fruits: { + create: [ + { + name: "Grape" + adjectives: ["tangy", "juicy"] + farms: { create: [{ name: "Bio farm" }] } + } + ] + } + } + ) { name - } + fruits { + name + farms { + name + } + } } - } } ``` @@ -1287,10 +1339,12 @@ Update mutations allow you to modify existing records. Strawchemy provides sever class ColorUpdateInput: pass + @strawchemy.filter(Color, include="all") class ColorFilter: pass + @strawberry.type class Mutation: # Update by ID @@ -1308,36 +1362,36 @@ GraphQL usage: ```graphql # Update by ID mutation { - updateColor( - data: { id: "123e4567-e89b-12d3-a456-426614174000", name: "Crimson" } - ) { - id - name - } + updateColor( + data: { id: "123e4567-e89b-12d3-a456-426614174000", name: "Crimson" } + ) { + id + name + } } # Batch update by IDs mutation { - updateColors( - data: [ - { id: "123e4567-e89b-12d3-a456-426614174000", name: "Crimson" } - { id: "223e4567-e89b-12d3-a456-426614174000", name: "Navy" } - ] - ) { - id - name - } + updateColors( + data: [ + { id: "123e4567-e89b-12d3-a456-426614174000", name: "Crimson" } + { id: "223e4567-e89b-12d3-a456-426614174000", name: "Navy" } + ] + ) { + id + name + } } # Update with filter mutation { - updateColorsFilter( - data: { name: "Bright Red" } - filter: { name: { eq: "Red" } } - ) { - id - name - } + updateColorsFilter( + data: { name: "Bright Red" } + filter: { name: { eq: "Red" } } + ) { + id + name + } } ``` @@ -1364,55 +1418,55 @@ GraphQL usage: ```graphql # Set an existing relationship mutation { - updateFruit( - data: { - id: "123e4567-e89b-12d3-a456-426614174000" - name: "Red Apple" - color: { set: { id: "223e4567-e89b-12d3-a456-426614174000" } } - } - ) { - id - name - color { - id - name + updateFruit( + data: { + id: "123e4567-e89b-12d3-a456-426614174000" + name: "Red Apple" + color: { set: { id: "223e4567-e89b-12d3-a456-426614174000" } } + } + ) { + id + name + color { + id + name + } } - } } # Create a new related entity mutation { - updateFruit( - data: { - id: "123e4567-e89b-12d3-a456-426614174000" - name: "Green Apple" - color: { create: { name: "Green" } } - } - ) { - id - name - color { - id - name + updateFruit( + data: { + id: "123e4567-e89b-12d3-a456-426614174000" + name: "Green Apple" + color: { create: { name: "Green" } } + } + ) { + id + name + color { + id + name + } } - } } # Set relationship to null mutation { - updateFruit( - data: { - id: "123e4567-e89b-12d3-a456-426614174000" - name: "Plain Apple" - color: { set: null } - } - ) { - id - name - color { - id + updateFruit( + data: { + id: "123e4567-e89b-12d3-a456-426614174000" + name: "Plain Apple" + color: { set: null } + } + ) { + id + name + color { + id + } } - } } ``` @@ -1430,79 +1484,79 @@ GraphQL usage: ```graphql # Set (replace) to-many relationships mutation { - updateColor( - data: { - id: "123e4567-e89b-12d3-a456-426614174000" - name: "Red" - fruits: { set: [{ id: "223e4567-e89b-12d3-a456-426614174000" }] } - } - ) { - id - name - fruits { - id - name + updateColor( + data: { + id: "123e4567-e89b-12d3-a456-426614174000" + name: "Red" + fruits: { set: [{ id: "223e4567-e89b-12d3-a456-426614174000" }] } + } + ) { + id + name + fruits { + id + name + } } - } } # Add to existing to-many relationships mutation { - updateColor( - data: { - id: "123e4567-e89b-12d3-a456-426614174000" - name: "Red" - fruits: { add: [{ id: "223e4567-e89b-12d3-a456-426614174000" }] } - } - ) { - id - name - fruits { - id - name + updateColor( + data: { + id: "123e4567-e89b-12d3-a456-426614174000" + name: "Red" + fruits: { add: [{ id: "223e4567-e89b-12d3-a456-426614174000" }] } + } + ) { + id + name + fruits { + id + name + } } - } } # Remove from to-many relationships mutation { - updateColor( - data: { - id: "123e4567-e89b-12d3-a456-426614174000" - name: "Red" - fruits: { remove: [{ id: "223e4567-e89b-12d3-a456-426614174000" }] } - } - ) { - id - name - fruits { - id - name + updateColor( + data: { + id: "123e4567-e89b-12d3-a456-426614174000" + name: "Red" + fruits: { remove: [{ id: "223e4567-e89b-12d3-a456-426614174000" }] } + } + ) { + id + name + fruits { + id + name + } } - } } # Create new related entities mutation { - updateColor( - data: { - id: "123e4567-e89b-12d3-a456-426614174000" - name: "Red" - fruits: { - create: [ - { name: "Cherry", adjectives: ["small", "red"] } - { name: "Strawberry", adjectives: ["sweet", "red"] } - ] - } - } - ) { - id - name - fruits { - id - name + updateColor( + data: { + id: "123e4567-e89b-12d3-a456-426614174000" + name: "Red" + fruits: { + create: [ + { name: "Cherry", adjectives: ["small", "red"] } + { name: "Strawberry", adjectives: ["sweet", "red"] } + ] + } + } + ) { + id + name + fruits { + id + name + } } - } } ``` @@ -1512,23 +1566,23 @@ You can combine `add` and `create` operations in a single update: ```graphql mutation { - updateColor( - data: { - id: "123e4567-e89b-12d3-a456-426614174000" - name: "Red" - fruits: { - add: [{ id: "223e4567-e89b-12d3-a456-426614174000" }] - create: [{ name: "Raspberry", adjectives: ["tart", "red"] }] - } - } - ) { - id - name - fruits { - id - name + updateColor( + data: { + id: "123e4567-e89b-12d3-a456-426614174000" + name: "Red" + fruits: { + add: [{ id: "223e4567-e89b-12d3-a456-426614174000" }] + create: [{ name: "Raspberry", adjectives: ["tart", "red"] }] + } + } + ) { + id + name + fruits { + id + name + } } - } } ``` @@ -1551,6 +1605,7 @@ Delete mutations allow you to remove records from your database. Strawchemy prov class UserFilter: pass + @strawberry.type class Mutation: # Delete all users @@ -1565,18 +1620,18 @@ GraphQL usage: ```graphql # Delete all users mutation { - deleteUsers { - id - name - } + deleteUsers { + id + name + } } # Delete users that match a filter mutation { - deleteUsersFilter(filter: { name: { eq: "Alice" } }) { - id - name - } + deleteUsersFilter(filter: { name: { eq: "Alice" } }) { + id + name + } } ``` @@ -1586,7 +1641,9 @@ The returned data contains the records that were deleted. ### Upsert Mutations -Upsert mutations provide "insert or update" functionality, allowing you to create new records or update existing ones based on conflict resolution. This is particularly useful when you want to ensure data exists without worrying about whether it's already in the database. +Upsert mutations provide "insert or update" functionality, allowing you to create new records or update existing ones +based on conflict resolution. This is particularly useful when you want to ensure data exists without worrying about +whether it's already in the database. Strawchemy supports upsert operations for: @@ -1606,16 +1663,19 @@ First, define the necessary input types and enums: class FruitCreateInput: pass + # Define which fields can be updated during upsert @strawchemy.upsert_update_fields(Fruit, include=["sweetness", "waterPercent"]) class FruitUpsertFields: pass + # Define which fields are used for conflict detection @strawchemy.upsert_conflict_fields(Fruit) class FruitUpsertConflictFields: pass + @strawberry.type class Mutation: # Single entity upsert @@ -1638,42 +1698,43 @@ class Mutation: ```graphql # Upsert a single fruit (will create if name doesn't exist, update if it does) mutation { - upsertFruit( - data: { name: "Apple", sweetness: 8, waterPercent: 0.85 } - conflictFields: name - ) { - id - name - sweetness - waterPercent - } + upsertFruit( + data: { name: "Apple", sweetness: 8, waterPercent: 0.85 } + conflictFields: name + ) { + id + name + sweetness + waterPercent + } } # Batch upsert multiple fruits mutation { - upsertFruits( - data: [ - { name: "Apple", sweetness: 8, waterPercent: 0.85 } - { name: "Orange", sweetness: 6, waterPercent: 0.87 } - ] - conflictFields: name - ) { - id - name - sweetness - waterPercent - } + upsertFruits( + data: [ + { name: "Apple", sweetness: 8, waterPercent: 0.85 } + { name: "Orange", sweetness: 6, waterPercent: 0.87 } + ] + conflictFields: name + ) { + id + name + sweetness + waterPercent + } } ``` #### How Upsert Works 1. **Conflict Detection**: The `conflictFields` parameter specifies which field(s) to check for existing records -2. **Update Fields**: The `updateFields` parameter (optional) specifies which fields should be updated if a conflict is found +2. **Update Fields**: The `updateFields` parameter (optional) specifies which fields should be updated if a conflict is + found 3. **Database Support**: - - **PostgreSQL**: Uses `ON CONFLICT DO UPDATE` - - **MySQL**: Uses `ON DUPLICATE KEY UPDATE` - - **SQLite**: Uses `ON CONFLICT DO UPDATE` + - **PostgreSQL**: Uses `ON CONFLICT DO UPDATE` + - **MySQL**: Uses `ON DUPLICATE KEY UPDATE` + - **SQLite**: Uses `ON CONFLICT DO UPDATE` #### Upsert in Relationships @@ -1688,29 +1749,29 @@ class ColorUpdateInput: ```graphql # Update a color and upsert related fruits mutation { - updateColor( - data: { - id: 1 - name: "Bright Red" - fruits: { - upsert: { - create: [ - { name: "Cherry", sweetness: 7, waterPercent: 0.87 } - { name: "Strawberry", sweetness: 8, waterPercent: 0.91 } - ] - conflictFields: name + updateColor( + data: { + id: 1 + name: "Bright Red" + fruits: { + upsert: { + create: [ + { name: "Cherry", sweetness: 7, waterPercent: 0.87 } + { name: "Strawberry", sweetness: 8, waterPercent: 0.91 } + ] + conflictFields: name + } + } + } + ) { + id + name + fruits { + id + name + sweetness } - } - } - ) { - id - name - fruits { - id - name - sweetness } - } } ``` @@ -1725,9 +1786,11 @@ mutation { ### Input Validation -Strawchemy supports input validation using Pydantic models. You can define validation schemas and apply them to mutations to ensure data meets specific requirements before being processed. +Strawchemy supports input validation using Pydantic models. You can define validation schemas and apply them to +mutations to ensure data meets specific requirements before being processed. -Create Pydantic models for the input type where you want the validation, and set the `validation` parameter on `strawchemy.field`: +Create Pydantic models for the input type where you want the validation, and set the `validation` parameter on +`strawchemy.field`:
Validation example @@ -1739,6 +1802,7 @@ from pydantic import AfterValidator from strawchemy import InputValidationError, ValidationErrorType from strawchemy.validation.pydantic import PydanticValidation + def _check_lower_case(value: str) -> str: if not value.islower(): raise ValueError("Name must be lower cased") @@ -1758,30 +1822,32 @@ class UserCreateValidation: @strawberry.type class Mutation: - create_user: UserType | ValidationErrorType = strawchemy.create(UserCreate, validation=PydanticValidation(UserCreateValidation)) + create_user: UserType | ValidationErrorType = strawchemy.create(UserCreate, + validation=PydanticValidation(UserCreateValidation)) ``` > To get the validation errors exposed in the schema, you need to add `ValidationErrorType` in the field union type -When validation fails, the query will returns a `ValidationErrorType` with detailed error information from pydantic validation: +When validation fails, the query will returns a `ValidationErrorType` with detailed error information from pydantic +validation: ```graphql mutation { - createUser(data: { name: "Bob" }) { - __typename - ... on UserType { - name - } - ... on ValidationErrorType { - id - errors { - id - loc - message - type - } + createUser(data: { name: "Bob" }) { + __typename + ... on UserType { + name + } + ... on ValidationErrorType { + id + errors { + id + loc + message + type + } + } } - } } ``` @@ -1794,7 +1860,9 @@ mutation { "errors": [ { "id": "ERROR", - "loc": ["name"], + "loc": [ + "name" + ], "message": "Value error, Name must be lower cased", "type": "value_error" } @@ -1808,25 +1876,25 @@ Validation also works with nested relationships: ```graphql mutation { - createUser( - data: { - name: "bob" - group: { - create: { - name: "Group" # This will be validated - tag: { set: { id: "..." } } + createUser( + data: { + name: "bob" + group: { + create: { + name: "Group" # This will be validated + tag: { set: { id: "..." } } + } + } + } + ) { + __typename + ... on ValidationErrorType { + errors { + loc + message + } } - } - } - ) { - __typename - ... on ValidationErrorType { - errors { - loc - message - } } - } } ``` @@ -1834,23 +1902,27 @@ mutation { ## Async Support -Strawchemy supports both synchronous and asynchronous operations. You can use either `StrawchemySyncRepository` or `StrawchemyAsyncRepository` depending on your needs: +Strawchemy supports both synchronous and asynchronous operations. You can use either `StrawchemySyncRepository` or +`StrawchemyAsyncRepository` depending on your needs: ```python from strawchemy import StrawchemySyncRepository, StrawchemyAsyncRepository + # Synchronous resolver @strawchemy.field def get_color(self, info: strawberry.Info, color: str) -> ColorType | None: repo = StrawchemySyncRepository(ColorType, info, filter_statement=select(Color).where(Color.name == color)) return repo.get_one_or_none().graphql_type_or_none() + # Asynchronous resolver @strawchemy.field async def get_color(self, info: strawberry.Info, color: str) -> ColorType | None: repo = StrawchemyAsyncRepository(ColorType, info, filter_statement=select(Color).where(Color.name == color)) return await repo.get_one_or_none().graphql_type_or_none() + # Synchronous mutation @strawberry.type class Mutation: @@ -1859,6 +1931,7 @@ class Mutation: repository_type=StrawchemySyncRepository ) + # Asynchronous mutation @strawberry.type class AsyncMutation: @@ -1868,7 +1941,8 @@ class AsyncMutation: ) ``` -By default, Strawchemy uses the StrawchemySyncRepository as its repository type. You can override this behavior by specifying a different repository using the `repository_type` configuration option. +By default, Strawchemy uses the StrawchemySyncRepository as its repository type. You can override this behavior by +specifying a different repository using the `repository_type` configuration option. ## Configuration @@ -1877,7 +1951,7 @@ Configuration is made by passing a `StrawchemyConfig` to the `Strawchemy` instan ### Configuration Options | Option | Type | Default | Description | -| -------------------------- | ----------------------------------------------------------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------- | +|----------------------------|-------------------------------------------------------------|----------------------------|------------------------------------------------------------------------------------------------------------------------------------------| | `dialect` | `SupportedDialect` | | Database dialect to use. Supported dialects are "postgresql", "mysql", "sqlite". | | `session_getter` | `Callable[[Info], Session]` | `default_session_getter` | Function to retrieve SQLAlchemy session from strawberry `Info` object. By default, it retrieves the session from `info.context.session`. | | `auto_snake_case` | `bool` | `True` | Automatically convert snake cased names to camel case in GraphQL schema. | @@ -1894,26 +1968,29 @@ Configuration is made by passing a `StrawchemyConfig` to the `Strawchemy` instan ```python from strawchemy import Strawchemy, StrawchemyConfig + # Custom session getter function def get_session_from_context(info): return info.context.db_session + # Initialize with custom configuration strawchemy = Strawchemy( StrawchemyConfig( - "postgresql", - session_getter=get_session_from_context, - auto_snake_case=True, - pagination=True, - pagination_default_limit=50, - default_id_field_name="pk", + "postgresql", + session_getter=get_session_from_context, + auto_snake_case=True, + pagination=True, + pagination_default_limit=50, + default_id_field_name="pk", ) ) ``` ## Contributing -Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details on how to contribute to this project. +Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details on how to contribute to this +project. ## License diff --git a/examples/testapp/testapp/models.py b/examples/testapp/testapp/models.py index 659eccdb..8bdca0b3 100644 --- a/examples/testapp/testapp/models.py +++ b/examples/testapp/testapp/models.py @@ -3,9 +3,9 @@ from datetime import datetime, timezone from uuid import UUID, uuid4 +from sqlalchemy import Column, DateTime, ForeignKey, MetaData, Table from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship -from sqlalchemy import Column, DateTime, ForeignKey, MetaData, Table from strawchemy.dto.utils import READ_ONLY UTC = timezone.utc diff --git a/examples/testapp/testapp/schema.py b/examples/testapp/testapp/schema.py index 38d9da4e..aff7d22b 100644 --- a/examples/testapp/testapp/schema.py +++ b/examples/testapp/testapp/schema.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING import strawberry + from strawchemy.validation.pydantic import PydanticValidation from testapp.types import ( CustomerCreate, diff --git a/pyproject.toml b/pyproject.toml index 28c24cf5..ae9489fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -233,7 +233,7 @@ footer = """ trim = true # postprocessors postprocessors = [ - # { pattern = '', replace = "https://github.com/orhun/git-cliff" }, # replace repository URL + # { pattern = '', replace = "https://github.com/orhun/git-cliff" }, # replace strawberry URL ] # render body even when there are no releases to process @@ -403,18 +403,19 @@ update_docstrings = true cache = true [tool.unasyncd.files] -"src/strawchemy/sqlalchemy/repository/_async.py" = "src/strawchemy/sqlalchemy/repository/_sync.py" -"src/strawchemy/strawberry/repository/_async.py" = "src/strawchemy/strawberry/repository/_sync.py" +"src/strawchemy/repository/sqlalchemy/_async.py" = "src/strawchemy/repository/sqlalchemy/_sync.py" +"src/strawchemy/repository/strawberry/_async.py" = "src/strawchemy/repository/strawberry/_sync.py" -[tool.unasyncd.per_file_add_replacements."src/strawchemy/sqlalchemy/repository/_async.py"] -"strawchemy.sqlalchemy._executor.AsyncQueryExecutor" = "strawchemy.sqlalchemy._executor.SyncQueryExecutor" +[tool.unasyncd.per_file_add_replacements."src/strawchemy/repository/sqlalchemy/_async.py"] +"strawchemy.transpiler.AsyncQueryExecutor" = "strawchemy.transpiler.SyncQueryExecutor" SQLAlchemyGraphQLAsyncRepository = "SQLAlchemyGraphQLSyncRepository" -"strawchemy.sqlalchemy.typing.AnyAsyncSession" = "strawchemy.sqlalchemy.typing.AnySyncSession" +"strawchemy.repository.typing.AnyAsyncSession" = "strawchemy.repository.typing.AnySyncSession" -[tool.unasyncd.per_file_add_replacements."src/strawchemy/strawberry/repository/_async.py"] -"strawchemy.sqlalchemy.repository.SQLAlchemyGraphQLAsyncRepository" = "strawchemy.sqlalchemy.repository.SQLAlchemyGraphQLSyncRepository" -"strawchemy.sqlalchemy.typing.AnyAsyncSession" = "strawchemy.sqlalchemy.typing.AnySyncSession" -"strawchemy.strawberry.typing.AsyncSessionGetter" = "strawchemy.strawberry.typing.SyncSessionGetter" +[tool.unasyncd.per_file_add_replacements."src/strawchemy/repository/strawberry/_async.py"] +"strawchemy.repository.strawberry.base.IS_ASYNC_REPOSITORY" = "strawchemy.repository.strawberry.base.IS_SYNC_REPOSITORY" +"strawchemy.repository.sqlalchemy.SQLAlchemyGraphQLAsyncRepository" = "strawchemy.repository.sqlalchemy.SQLAlchemyGraphQLSyncRepository" +"strawchemy.repository.typing.AnyAsyncSession" = "strawchemy.repository.typing.AnySyncSession" +"strawchemy.repository.typing.AsyncSessionGetter" = "strawchemy.repository.typing.SyncSessionGetter" StrawchemyAsyncRepository = "StrawchemySyncRepository" [tool.uv] diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 00000000..b8805c1f Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/strawchemy/.DS_Store b/src/strawchemy/.DS_Store new file mode 100644 index 00000000..d9e9269b Binary files /dev/null and b/src/strawchemy/.DS_Store differ diff --git a/src/strawchemy/__init__.py b/src/strawchemy/__init__.py index 204383a0..c93724b4 100644 --- a/src/strawchemy/__init__.py +++ b/src/strawchemy/__init__.py @@ -3,12 +3,12 @@ from __future__ import annotations from strawchemy.config.base import StrawchemyConfig +from strawchemy.instance import ModelInstance from strawchemy.mapper import Strawchemy -from strawchemy.sqlalchemy.hook import QueryHook -from strawchemy.strawberry import ModelInstance -from strawchemy.strawberry.mutation.input import Input -from strawchemy.strawberry.mutation.types import ( - ErrorType, +from strawchemy.repository.strawberry import StrawchemyAsyncRepository, StrawchemySyncRepository +from strawchemy.schema.interfaces import ErrorType +from strawchemy.schema.mutation import ( + Input, RequiredToManyUpdateInput, RequiredToOneInput, ToManyCreateInput, @@ -16,8 +16,8 @@ ToOneInput, ValidationErrorType, ) -from strawchemy.strawberry.repository import StrawchemyAsyncRepository, StrawchemySyncRepository -from strawchemy.validation.base import InputValidationError +from strawchemy.transpiler.hook import QueryHook +from strawchemy.validation import InputValidationError __all__ = ( "ErrorType", diff --git a/src/strawchemy/__metadata__.py b/src/strawchemy/__metadata__.py new file mode 100644 index 00000000..4b99d387 --- /dev/null +++ b/src/strawchemy/__metadata__.py @@ -0,0 +1,16 @@ +"""Metadata for the Project.""" + +from importlib.metadata import PackageNotFoundError, metadata, version # pragma: no cover + +__all__ = ("__project__", "__version__") # pragma: no cover + +try: # pragma: no cover + __version__ = version("strawchemy") + """Version of the project.""" + __project__ = metadata("strawchemy")["Name"] + """Name of the project.""" +except PackageNotFoundError: # pragma: no cover + __version__ = "0.0.1" + __project__ = "strawchemy" +finally: # pragma: no cover + del version, PackageNotFoundError, metadata diff --git a/src/strawchemy/config/base.py b/src/strawchemy/config/base.py index d02bce25..f8e7b9cd 100644 --- a/src/strawchemy/config/base.py +++ b/src/strawchemy/config/base.py @@ -3,18 +3,15 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from strawchemy.sqlalchemy.inspector import SQLAlchemyGraphQLInspector -from strawchemy.strawberry import default_session_getter -from strawchemy.strawberry.repository import StrawchemySyncRepository +from strawchemy.dto.inspectors import SQLAlchemyGraphQLInspector +from strawchemy.repository.strawberry import StrawchemySyncRepository +from strawchemy.utils.strawberry import default_session_getter if TYPE_CHECKING: - from typing import Any - - from strawchemy.sqlalchemy.typing import FilterMap - from strawchemy.strawberry.typing import AnySessionGetter - from strawchemy.typing import AnyRepository, SupportedDialect + from strawchemy.repository.typing import AnySessionGetter, FilterMap + from strawchemy.typing import AnyRepositoryType, SupportedDialect @dataclass @@ -27,7 +24,7 @@ class StrawchemyConfig: auto_snake_case: Automatically convert snake cased names to camel case. repository_type: Repository class to use for auto resolvers. filter_overrides: Override default filters with custom filters. - execution_options: SQLAlchemy execution options for repository operations. + execution_options: SQLAlchemy execution options for strawberry operations. pagination_default_limit: Default pagination limit when `pagination=True`. pagination: Enable/disable pagination on list resolvers. default_id_field_name: Name for primary key fields arguments on primary key resolvers. @@ -40,12 +37,12 @@ class StrawchemyConfig: """Function to retrieve SQLAlchemy session from strawberry `Info` object.""" auto_snake_case: bool = True """Automatically convert snake cased names to camel case""" - repository_type: AnyRepository = StrawchemySyncRepository + repository_type: AnyRepositoryType = StrawchemySyncRepository """Repository class to use for auto resolvers.""" filter_overrides: FilterMap | None = None """Override default filters with custom filters.""" execution_options: dict[str, Any] | None = None - """SQLAlchemy execution options for repository operations.""" + """SQLAlchemy execution options for strawberry operations.""" pagination_default_limit: int = 100 """Default pagination limit when `pagination=True`.""" pagination: bool = False diff --git a/src/strawchemy/config/databases.py b/src/strawchemy/config/databases.py index 0290f0e2..e9420c11 100644 --- a/src/strawchemy/config/databases.py +++ b/src/strawchemy/config/databases.py @@ -8,8 +8,7 @@ from strawchemy.exceptions import StrawchemyError if TYPE_CHECKING: - from strawchemy.strawberry.typing import AggregationFunction - from strawchemy.typing import SupportedDialect + from strawchemy.typing import AggregationFunction, SupportedDialect @dataclass(frozen=True) diff --git a/src/strawchemy/constants.py b/src/strawchemy/constants.py index f0e3de44..0a611f3f 100644 --- a/src/strawchemy/constants.py +++ b/src/strawchemy/constants.py @@ -19,17 +19,17 @@ GEO_INSTALLED: bool = all(find_spec(package) is not None for package in ("geoalchemy2", "shapely")) -LIMIT_KEY = "limit" -OFFSET_KEY = "offset" -ORDER_BY_KEY = "order_by" -FILTER_KEY = "filter" -DISTINCT_ON_KEY = "distinct_on" +LIMIT_KEY: str = "limit" +OFFSET_KEY: str = "offset" +ORDER_BY_KEY: str = "order_by" +FILTER_KEY: str = "filter" +DISTINCT_ON_KEY: str = "distinct_on" -AGGREGATIONS_KEY = "aggregations" -NODES_KEY = "nodes" +AGGREGATIONS_KEY: str = "aggregations" +NODES_KEY: str = "nodes" -DATA_KEY = "data" -JSON_PATH_KEY = "path" +DATA_KEY: str = "data" +JSON_PATH_KEY: str = "path" -UPSERT_UPDATE_FIELDS = "update_fields" -UPSERT_CONFLICT_FIELDS = "conflict_fields" +UPSERT_UPDATE_FIELDS: str = "update_fields" +UPSERT_CONFLICT_FIELDS: str = "conflict_fields" diff --git a/src/strawchemy/dto/.DS_Store b/src/strawchemy/dto/.DS_Store new file mode 100644 index 00000000..5e818281 Binary files /dev/null and b/src/strawchemy/dto/.DS_Store differ diff --git a/src/strawchemy/dto/__init__.py b/src/strawchemy/dto/__init__.py index f111c82c..162ab24c 100644 --- a/src/strawchemy/dto/__init__.py +++ b/src/strawchemy/dto/__init__.py @@ -2,18 +2,20 @@ from __future__ import annotations -from strawchemy.dto.base import DTOFieldDefinition, ModelFieldT, ModelInspector, ModelT +from strawchemy.dto.base import DTOFieldDefinition, MappedDTO, ModelFieldT, ModelT, ToMappedProtocol, VisitorProtocol from strawchemy.dto.types import DTOConfig, Purpose, PurposeConfig from strawchemy.dto.utils import config, field __all__ = ( "DTOConfig", "DTOFieldDefinition", + "MappedDTO", "ModelFieldT", - "ModelInspector", "ModelT", "Purpose", "PurposeConfig", + "ToMappedProtocol", + "VisitorProtocol", "config", "field", ) diff --git a/src/strawchemy/dto/backend/pydantic.py b/src/strawchemy/dto/backend/pydantic.py index 2e61b7e3..052a0163 100644 --- a/src/strawchemy/dto/backend/pydantic.py +++ b/src/strawchemy/dto/backend/pydantic.py @@ -16,7 +16,7 @@ from strawchemy.dto.base import DTOBackend, DTOBase, DTOFieldDefinition, MappedDTO, ModelFieldT, ModelT from strawchemy.dto.types import DTOMissing -from strawchemy.utils import get_annotations +from strawchemy.utils.annotation import get_annotations if TYPE_CHECKING: from collections.abc import Iterable diff --git a/src/strawchemy/dto/backend/strawberry.py b/src/strawchemy/dto/backend/strawberry.py index 0964891f..fad4060a 100644 --- a/src/strawchemy/dto/backend/strawberry.py +++ b/src/strawchemy/dto/backend/strawberry.py @@ -5,20 +5,20 @@ from types import new_class from typing import TYPE_CHECKING, Any, TypeVar, get_origin +import strawberry from strawberry.types.field import StrawberryField from typing_extensions import override -import strawberry from strawchemy.dto.base import DTOBackend, DTOBase, MappedDTO, ModelFieldT, ModelT from strawchemy.dto.types import DTOMissing -from strawchemy.utils import get_annotations +from strawchemy.utils.annotation import get_annotations if TYPE_CHECKING: from collections.abc import Iterable from strawchemy.dto.base import DTOFieldDefinition -__all__ = ("AnnotatedDTOT", "StrawberrryDTOBackend", "StrawberryDTO", "StrawberryDTO") +__all__ = ("AnnotatedDTOT", "MappedStrawberryDTO", "StrawberrryDTOBackend", "StrawberryDTO", "StrawberryDTO") AnnotatedDTOT = TypeVar("AnnotatedDTOT", bound="StrawberryDTO[Any] | MappedStrawberryDTO[Any]") diff --git a/src/strawchemy/dto/base.py b/src/strawchemy/dto/base.py index ab3321f8..b78f55bb 100644 --- a/src/strawchemy/dto/base.py +++ b/src/strawchemy/dto/base.py @@ -11,7 +11,6 @@ from types import new_class from typing import ( TYPE_CHECKING, - Annotated, ClassVar, ForwardRef, Generic, @@ -20,15 +19,12 @@ TypeAlias, TypeVar, cast, - get_args, - get_origin, get_type_hints, runtime_checkable, ) from typing_extensions import Self, override -from strawchemy.dto.exceptions import DTOError, EmptyDTOError from strawchemy.dto.types import ( DTOAuto, DTOConfig, @@ -42,17 +38,32 @@ PurposeConfig, ) from strawchemy.dto.utils import config -from strawchemy.graph import Node -from strawchemy.utils import is_type_hint_optional, non_optional_type_hint +from strawchemy.exceptions import DTOError, EmptyDTOError +from strawchemy.utils.annotation import is_type_hint_optional, non_optional_type_hint +from strawchemy.utils.graph import Node if TYPE_CHECKING: from collections.abc import Callable, Generator, Hashable, Iterable, Mapping from typing import Any + from strawchemy.dto.inspectors import ModelInspector + + +__all__ = ( + "DTOBackend", + "DTOBase", + "DTOFactory", + "DTOFieldDefinition", + "MappedDTO", + "ModelFieldT", + "ModelT", + "Relation", + "ToMappedProtocol", + "ToMappedProtocolT", + "VisitorProtocol", +) -__all__ = ("DTOFactory", "DTOFieldDefinition", "MappedDTO", "ModelInspector") - -T = TypeVar("T") +T = TypeVar("T", bound="Any") DTOBaseT = TypeVar("DTOBaseT", bound="DTOBase[Any]") ModelT = TypeVar("ModelT") ToMappedProtocolT = TypeVar("ToMappedProtocolT", bound="ToMappedProtocol[Any]") @@ -226,46 +237,6 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.model.__name__})" -class ModelInspector(Protocol, Generic[ModelT, ModelFieldT]): - def field_definitions( - self, model: type[Any], dto_config: DTOConfig - ) -> Iterable[tuple[str, DTOFieldDefinition[ModelT, ModelFieldT]]]: ... - - def id_field_definitions( - self, model: type[Any], dto_config: DTOConfig - ) -> list[tuple[str, DTOFieldDefinition[ModelT, ModelFieldT]]]: ... - - def field_definition( - self, model_field: ModelFieldT, dto_config: DTOConfig - ) -> DTOFieldDefinition[ModelT, ModelFieldT]: ... - - def get_type_hints(self, type_: type[Any], include_extras: bool = True) -> dict[str, Any]: ... - - def relation_model(self, model_field: ModelFieldT) -> type[Any]: ... - - def model_field_type(self, field_definition: DTOFieldDefinition[ModelT, ModelFieldT]) -> Any: - type_hint = ( - field_definition.type_hint_override if field_definition.has_type_override else field_definition.type_hint - ) - if get_origin(type_hint) is Annotated: - return get_args(type_hint)[0] - return non_optional_type_hint(type_hint) - - def relation_cycle( - self, field: DTOFieldDefinition[Any, ModelFieldT], node: Node[Relation[ModelT, Any], None] - ) -> bool: ... - - def has_default(self, model_field: ModelFieldT) -> bool: ... - - def required(self, model_field: ModelFieldT) -> bool: ... - - def is_foreign_key(self, model_field: ModelFieldT) -> bool: ... - - def is_primary_key(self, model_field: ModelFieldT) -> bool: ... - - def reverse_relation_required(self, model_field: ModelFieldT) -> bool: ... - - @dataclass(slots=True) class DTOFieldDefinition(Generic[ModelT, ModelFieldT]): dto_config: DTOConfig diff --git a/src/strawchemy/dto/exceptions.py b/src/strawchemy/dto/exceptions.py deleted file mode 100644 index 0891559d..00000000 --- a/src/strawchemy/dto/exceptions.py +++ /dev/null @@ -1,12 +0,0 @@ -from __future__ import annotations - -__all__ = ("DTOError", "EmptyDTOError", "ModelInspectorError") - - -class DTOError(Exception): ... - - -class EmptyDTOError(DTOError): ... - - -class ModelInspectorError(DTOError): ... diff --git a/src/strawchemy/dto/inspectors/__init__.py b/src/strawchemy/dto/inspectors/__init__.py index e69de29b..74eb2aaf 100644 --- a/src/strawchemy/dto/inspectors/__init__.py +++ b/src/strawchemy/dto/inspectors/__init__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from strawchemy.dto.inspectors.base import ModelInspector +from strawchemy.dto.inspectors.sqlalchemy import SQLAlchemyGraphQLInspector, SQLAlchemyInspector + +__all__ = ("ModelInspector", "SQLAlchemyGraphQLInspector", "SQLAlchemyInspector") diff --git a/src/strawchemy/dto/inspectors/base.py b/src/strawchemy/dto/inspectors/base.py new file mode 100644 index 00000000..4309c17a --- /dev/null +++ b/src/strawchemy/dto/inspectors/base.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated, Any, Generic, Protocol, get_args, get_origin + +from strawchemy.dto import DTOConfig, DTOFieldDefinition, ModelFieldT, ModelT +from strawchemy.utils.annotation import non_optional_type_hint + +if TYPE_CHECKING: + from collections.abc import Iterable + + from strawchemy.dto.base import Relation + from strawchemy.utils.graph import Node + + +class ModelInspector(Protocol, Generic[ModelT, ModelFieldT]): + def field_definitions( + self, model: type[Any], dto_config: DTOConfig + ) -> Iterable[tuple[str, DTOFieldDefinition[ModelT, ModelFieldT]]]: ... + + def id_field_definitions( + self, model: type[Any], dto_config: DTOConfig + ) -> list[tuple[str, DTOFieldDefinition[ModelT, ModelFieldT]]]: ... + + def field_definition( + self, model_field: ModelFieldT, dto_config: DTOConfig + ) -> DTOFieldDefinition[ModelT, ModelFieldT]: ... + + def get_type_hints(self, type_: type[Any], include_extras: bool = True) -> dict[str, Any]: ... + + def relation_model(self, model_field: ModelFieldT) -> type[Any]: ... + + def model_field_type(self, field_definition: DTOFieldDefinition[ModelT, ModelFieldT]) -> Any: + type_hint = ( + field_definition.type_hint_override if field_definition.has_type_override else field_definition.type_hint + ) + if get_origin(type_hint) is Annotated: + return get_args(type_hint)[0] + return non_optional_type_hint(type_hint) + + def relation_cycle( + self, field: DTOFieldDefinition[Any, ModelFieldT], node: Node[Relation[ModelT, Any], None] + ) -> bool: ... + + def has_default(self, model_field: ModelFieldT) -> bool: ... + + def required(self, model_field: ModelFieldT) -> bool: ... + + def is_foreign_key(self, model_field: ModelFieldT) -> bool: ... + + def is_primary_key(self, model_field: ModelFieldT) -> bool: ... + + def reverse_relation_required(self, model_field: ModelFieldT) -> bool: ... diff --git a/src/strawchemy/dto/inspectors/sqlalchemy.py b/src/strawchemy/dto/inspectors/sqlalchemy.py index 22770e13..f5158623 100644 --- a/src/strawchemy/dto/inspectors/sqlalchemy.py +++ b/src/strawchemy/dto/inspectors/sqlalchemy.py @@ -2,11 +2,28 @@ import builtins import contextlib +from collections import OrderedDict from dataclasses import MISSING as DATACLASS_MISSING from dataclasses import Field, fields +from datetime import date, datetime, time, timedelta +from decimal import Decimal from inspect import getmodule, signature from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast, get_args, get_origin, get_type_hints +from sqlalchemy import ( + ARRAY, + Column, + ColumnElement, + PrimaryKeyConstraint, + Sequence, + SQLColumnExpression, + Table, + UniqueConstraint, + event, + inspect, + orm, + sql, +) from sqlalchemy.dialects import postgresql from sqlalchemy.orm import ( NO_VALUE, @@ -23,25 +40,27 @@ ) from typing_extensions import TypeIs, override -from sqlalchemy import ( - Column, - ColumnElement, - PrimaryKeyConstraint, - Sequence, - SQLColumnExpression, - Table, - UniqueConstraint, - event, - inspect, - orm, - sql, -) +from strawchemy.config.databases import DatabaseFeatures from strawchemy.constants import GEO_INSTALLED -from strawchemy.dto.base import TYPING_NS, DTOFieldDefinition, ModelInspector, Relation +from strawchemy.dto.base import TYPING_NS, DTOFieldDefinition, Relation from strawchemy.dto.constants import DTO_INFO_KEY -from strawchemy.dto.exceptions import ModelInspectorError +from strawchemy.dto.inspectors import ModelInspector from strawchemy.dto.types import DTOConfig, DTOFieldConfig, DTOMissing, DTOUnset, Purpose -from strawchemy.utils import is_type_hint_optional +from strawchemy.exceptions import ModelInspectorError +from strawchemy.schema.filters import ( + ArrayComparison, + DateComparison, + DateTimeComparison, + EqualityComparison, + GraphQLComparison, + OrderComparison, + TextComparison, + TimeComparison, + TimeDeltaComparison, + make_full_json_comparison_input, + make_sqlite_json_comparison_input, +) +from strawchemy.utils.annotation import is_type_hint_optional if TYPE_CHECKING: from collections.abc import Callable, Generator, Iterable @@ -51,17 +70,33 @@ from sqlalchemy.orm import MapperProperty from sqlalchemy.sql.schema import ColumnCollectionConstraint - from strawchemy.graph import Node + from strawchemy.repository.typing import FilterMap + from strawchemy.typing import SupportedDialect + from strawchemy.utils.graph import Node -__all__ = ("SQLAlchemyInspector",) +__all__ = ("SQLAlchemyGraphQLInspector", "SQLAlchemyInspector") T = TypeVar("T", bound=Any) +_SQLA_NS = {**vars(orm), **vars(sql)} + +_DEFAULT_FILTERS_MAP: FilterMap = OrderedDict( + { + (timedelta,): TimeDeltaComparison, + (datetime,): DateTimeComparison, + (time,): TimeComparison, + (date,): DateComparison, + (bool,): EqualityComparison, + (int, float, Decimal): OrderComparison, + (str,): TextComparison, + } +) _shapely_geometry_map: dict[str, type[Geometry]] = {} + if GEO_INSTALLED: from shapely import ( Geometry, @@ -88,7 +123,20 @@ } -_SQLA_NS = {**vars(orm), **vars(sql)} +def loaded_attributes(model: DeclarativeBase) -> set[str]: + """Identifies attributes of a SQLAlchemy model instance that have been loaded. + + This function inspects the given SQLAlchemy model instance and returns a set + of attribute names for which the value has been loaded from the database + (i.e., the value is not `strawberry.orm.NO_VALUE`). + + Args: + model: The SQLAlchemy `DeclarativeBase` instance to inspect. + + Returns: + A set of strings, where each string is the name of a loaded attribute. + """ + return {name for name, attr in inspect(model).attrs.items() if attr.loaded_value is not NO_VALUE} class SQLAlchemyInspector(ModelInspector[DeclarativeBase, QueryableAttribute[Any]]): @@ -431,3 +479,146 @@ def unique_constraints(cls, model: type[DeclarativeBase]) -> list[ColumnCollecti if isinstance(constraint, (PrimaryKeyConstraint, UniqueConstraint, postgresql.ExcludeConstraint)) ] return sorted(constraints, key=lambda cons: "_".join(col.key for col in cons.columns)) + + +class SQLAlchemyGraphQLInspector(SQLAlchemyInspector): + """Inspects SQLAlchemy models to determine appropriate GraphQL filter types. + + This inspector extends `SQLAlchemyInspector` to provide mappings from + SQLAlchemy model attributes and Python types to specific GraphQL comparison + filter input types (e.g., `TextComparison`, `OrderComparison`). + + It takes into account the database dialect's features (via `DatabaseFeatures`) + to select suitable filters, for example, for JSON or geospatial types. + Custom filter mappings can also be provided through `filter_overrides`. + + Key methods `get_field_comparison` and `get_type_comparison` are used to + retrieve the corresponding filter types. + """ + + def __init__( + self, + dialect: SupportedDialect, + registries: list[registry] | None = None, + filter_overrides: FilterMap | None = None, + ) -> None: + """Initializes the SQLAlchemyGraphQLInspector. + + Args: + dialect: The SQL dialect of the target database (e.g., "postgresql", "sqlite"). + registries: An optional list of SQLAlchemy registries to inspect. + If None, the default registry is used. + filter_overrides: An optional mapping to override or extend the default + Python type to GraphQL filter type mappings. + """ + super().__init__(registries) + self.db_features = DatabaseFeatures.new(dialect) + self.filters_map = self._filter_map() + self.filters_map |= filter_overrides or {} + + def _filter_map(self) -> FilterMap: + """Constructs the map of Python types to GraphQL filter comparison types. + + Starts with a default set of filters (`_DEFAULT_FILTERS_MAP`). + If GeoAlchemy is installed (`GEO_INSTALLED`), it adds mappings for + geospatial types to `GeoComparison`. + It then adds mappings for `dict` to appropriate JSON comparison + types based on whether the dialect is SQLite or another database + that supports more advanced JSON operations. + + Returns: + The constructed `FilterMap`. + """ + filters_map = _DEFAULT_FILTERS_MAP + + if GEO_INSTALLED: + from geoalchemy2 import WKBElement, WKTElement # noqa: PLC0415 + from shapely import Geometry # noqa: PLC0415 + + from strawchemy.schema.filters.geo import GeoComparison # noqa: PLC0415 + + filters_map |= {(Geometry, WKBElement, WKTElement): GeoComparison} + if self.db_features.dialect == "sqlite": + filters_map[(dict, dict)] = make_sqlite_json_comparison_input() + else: + filters_map[(dict, dict)] = make_full_json_comparison_input() + return filters_map + + @classmethod + def _is_specialized(cls, type_: type[Any]) -> bool: + """Checks if a generic type is fully specialized. + + A type is considered specialized if it has no type parameters (`__parameters__`) + or if all its type parameters are concrete types (not `TypeVar`). + + Args: + type_: The type to check. + + Returns: + True if the type is specialized, False otherwise. + """ + return not hasattr(type_, "__parameters__") or all( + not isinstance(param, TypeVar) for param in type_.__parameters__ + ) + + @classmethod + def _filter_type(cls, type_: type[Any], sqlalchemy_filter: type[GraphQLComparison]) -> type[GraphQLComparison]: + """Potentially specializes a generic GraphQL filter type with a Python type. + + If the provided `sqlalchemy_filter` is a generic type (e.g., `OrderComparison[T]`) + and is not yet specialized, this method specializes it using `type_` + (e.g., `OrderComparison[int]`). If `sqlalchemy_filter` is already specialized + or not generic, it's returned as is. + + Args: + type_: The Python type to use for specialization if needed. + sqlalchemy_filter: The GraphQL filter type, which might be generic. + + Returns: + The (potentially specialized) GraphQL filter type. + """ + return sqlalchemy_filter if cls._is_specialized(sqlalchemy_filter) else sqlalchemy_filter[type_] # pyright: ignore[reportInvalidTypeArguments] + + def get_field_comparison( + self, field_definition: DTOFieldDefinition[DeclarativeBase, QueryableAttribute[Any]] + ) -> type[GraphQLComparison]: + """Determines the GraphQL comparison filter type for a DTO field. + + This method inspects the type of the given DTO field. + For `ARRAY` types on PostgreSQL, it returns a specialized `ArrayComparison`. + Otherwise, it delegates to `get_type_comparison` using the Python type + of the model field. + + Args: + field_definition: The DTO field definition, which contains information + about the model attribute and its type. + + Returns: + The GraphQL comparison filter type suitable for the field. + """ + field_type = field_definition.model_field.type + if isinstance(field_type, ARRAY) and self.db_features.dialect == "postgresql": + return ArrayComparison[field_type.item_type.python_type] + return self.get_type_comparison(self.model_field_type(field_definition)) + + def get_type_comparison(self, type_: type[Any]) -> type[GraphQLComparison]: + """Determines the GraphQL comparison filter type for a Python type. + + It iterates through the `self.filters_map` (which includes default + and dialect-specific filters) to find a filter type that matches + the provided Python `type_`. + If a direct match or a superclass match is found, the corresponding + filter type is returned, potentially specialized using `_filter_type`. + If no specific filter is found in the map, it defaults to + `EqualityComparison` specialized with the given `type_`. + + Args: + type_: The Python type for which to find a GraphQL filter. + + Returns: + The GraphQL comparison filter type suitable for the Python type. + """ + for types, sqlalchemy_filter in self.filters_map.items(): + if issubclass(type_, types): + return self._filter_type(type_, sqlalchemy_filter) + return EqualityComparison[type_] diff --git a/src/strawchemy/dto/pydantic.py b/src/strawchemy/dto/pydantic.py index 9502c57f..5b95e8ff 100644 --- a/src/strawchemy/dto/pydantic.py +++ b/src/strawchemy/dto/pydantic.py @@ -6,7 +6,7 @@ from strawchemy.dto.backend.pydantic import MappedPydanticDTO, PydanticDTOBackend from strawchemy.dto.base import DTOFactory -from strawchemy.dto.inspectors.sqlalchemy import SQLAlchemyInspector +from strawchemy.dto.inspectors import SQLAlchemyInspector __all__ = ("factory", "pydantic_dto") diff --git a/src/strawchemy/strawberry/dto.py b/src/strawchemy/dto/strawberry.py similarity index 96% rename from src/strawchemy/strawberry/dto.py rename to src/strawchemy/dto/strawberry.py index 3c6e6fed..a4c545b5 100644 --- a/src/strawchemy/strawberry/dto.py +++ b/src/strawchemy/dto/strawberry.py @@ -30,34 +30,34 @@ from dataclasses import dataclass from enum import Enum from functools import cached_property -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Generic, - Literal, - TypeVar, - overload, -) +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, overload +import strawberry from msgspec import Struct, field, json from sqlalchemy.orm import DeclarativeBase, QueryableAttribute from typing_extensions import Self, override -import strawberry from strawchemy.dto.backend.strawberry import MappedStrawberryDTO, StrawberryDTO from strawchemy.dto.base import DTOBase, DTOFieldDefinition, ModelFieldT, ModelT from strawchemy.dto.types import DTOConfig, DTOFieldConfig, DTOMissing, Purpose -from strawchemy.graph import AnyNode, GraphMetadata, MatchOn, Node, NodeMetadata, NodeT -from strawchemy.sqlalchemy.hook import QueryHook # noqa: TC001 -from strawchemy.strawberry.typing import GraphQLPurpose, OrderByDTOT, QueryNodeType -from strawchemy.utils import camel_to_snake +from strawchemy.transpiler.hook import ( + QueryHook, # noqa: TC001 msgspec does not support resolving references dynamically +) +from strawchemy.typing import ( + AggregationFunction, + AggregationType, + FunctionInfo, + GraphQLPurpose, + OrderByDTOT, + QueryNodeType, +) +from strawchemy.utils.graph import AnyNode, GraphMetadata, MatchOn, Node, NodeMetadata, NodeT +from strawchemy.utils.text import camel_to_snake if TYPE_CHECKING: from collections.abc import Callable, Hashable, Sequence - from strawchemy.strawberry.filters import EqualityComparison, GraphQLComparison - from strawchemy.strawberry.typing import AggregationFunction, AggregationType, FunctionInfo + from strawchemy.schema.filters import EqualityComparison, GraphQLComparison T = TypeVar("T") @@ -343,7 +343,7 @@ class QueryNode(Node[GraphQLFieldDefinition, QueryNodeMetadata]): @classmethod @override def _node_hash_identity(cls, node: Node[GraphQLFieldDefinition, QueryNodeMetadata]) -> Hashable: - return (super()._node_hash_identity(node), node.metadata.data.relation_filter) + return super()._node_hash_identity(node), node.metadata.data.relation_filter @override def _update_new_child(self, child: NodeT) -> NodeT: diff --git a/src/strawchemy/dto/types.py b/src/strawchemy/dto/types.py index 7b3f9722..9722d3c6 100644 --- a/src/strawchemy/dto/types.py +++ b/src/strawchemy/dto/types.py @@ -9,17 +9,27 @@ from typing_extensions import override -from strawchemy.utils import get_annotations +from strawchemy.utils.annotation import get_annotations if TYPE_CHECKING: from collections.abc import Callable, Mapping -__all__ = ("DTOAuto", "DTOConfig", "DTOFieldConfig", "DTOMissing", "ExcludeFields", "IncludeFields", "Purpose") +__all__ = ( + "DTOAuto", + "DTOConfig", + "DTOFieldConfig", + "DTOMissing", + "DTOScope", + "DTOSkip", + "DTOUnset", + "ExcludeFields", + "IncludeFields", + "Purpose", + "PurposeConfig", +) DTOScope: TypeAlias = Literal["global", "dto"] - - IncludeFields: TypeAlias = "list[str] | set[str] | Literal['all']" ExcludeFields: TypeAlias = "list[str] | set[str]" diff --git a/src/strawchemy/exceptions.py b/src/strawchemy/exceptions.py index 1b3dc3ab..e78ed26f 100644 --- a/src/strawchemy/exceptions.py +++ b/src/strawchemy/exceptions.py @@ -7,7 +7,18 @@ if TYPE_CHECKING: from typing import Any -__all__ = ("SessionNotFoundError", "StrawchemyError") +__all__ = ( + "DTOError", + "EmptyDTOError", + "GraphError", + "ModelInspectorError", + "QueryHookError", + "QueryResultError", + "SessionNotFoundError", + "StrawchemyError", + "StrawchemyFieldError", + "TranspilingError", +) class StrawchemyError(Exception): @@ -41,3 +52,30 @@ def __str__(self) -> str: class SessionNotFoundError(StrawchemyError): ... + + +class StrawchemyFieldError(StrawchemyError): ... + + +class DTOError(StrawchemyError): ... + + +class EmptyDTOError(DTOError): ... + + +class ModelInspectorError(DTOError): ... + + +class TranspilingError(StrawchemyError): + """Raised when an error occurs during transpiling.""" + + +class QueryResultError(StrawchemyError): + """Raised when an error occurs during query result processing or mapping.""" + + +class QueryHookError(StrawchemyError): + """Raised when an error occurs within a query hook's execution.""" + + +class GraphError(StrawchemyError): ... diff --git a/src/strawchemy/factories.py b/src/strawchemy/factories.py deleted file mode 100644 index 1d23fa3b..00000000 --- a/src/strawchemy/factories.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Factory container for organizing Strawchemy DTO factories.""" - -from __future__ import annotations - -from dataclasses import dataclass -from functools import partial -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from strawchemy.mapper import Strawchemy - from strawchemy.strawberry.factories.aggregations import EnumDTOFactory - from strawchemy.strawberry.factories.inputs import AggregateFilterDTOFactory, BooleanFilterDTOFactory - from strawchemy.strawberry.factories.types import ( - DistinctOnFieldsDTOFactory, - InputFactory, - OrderByDTOFactory, - RootAggregateTypeDTOFactory, - TypeDTOFactory, - UpsertConflictFieldsDTOFactory, - ) - - -@dataclass -class StrawchemyFactories: - """Container for all Strawchemy DTO factories. - - This class encapsulates the initialization and management of all factory - instances used by Strawchemy, providing a cleaner separation of concerns - and easier testing. - - Attributes: - aggregate_filter: Factory for aggregate filter DTOs. - order_by: Factory for order by DTOs. - distinct_on_enum: Factory for distinct on enum DTOs. - type_factory: Factory for output type DTOs. - input_factory: Factory for input type DTOs. - aggregation: Factory for root aggregate type DTOs. - enum_factory: Factory for enum DTOs. - filter_factory: Factory for boolean filter DTOs. - upsert_conflict: Factory for upsert conflict fields DTOs. - """ - - aggregate_filter: AggregateFilterDTOFactory - order_by: OrderByDTOFactory - distinct_on_enum: DistinctOnFieldsDTOFactory - type_factory: TypeDTOFactory # type: ignore[type-arg] - input_factory: InputFactory # type: ignore[type-arg] - aggregation: RootAggregateTypeDTOFactory # type: ignore[type-arg] - enum_factory: EnumDTOFactory - filter_factory: BooleanFilterDTOFactory - upsert_conflict: UpsertConflictFieldsDTOFactory - - @classmethod - def create(cls, mapper: Strawchemy) -> StrawchemyFactories: - """Create all factories with proper dependencies. - - Args: - mapper: The Strawchemy instance that will own these factories. - - Returns: - A StrawchemyFactories instance with all factories initialized. - """ - # Imports inside method to avoid circular dependencies at module load time - from strawchemy.dto.backend.strawberry import StrawberrryDTOBackend # noqa: PLC0415 - from strawchemy.strawberry.dto import MappedStrawberryGraphQLDTO # noqa: PLC0415 - from strawchemy.strawberry.factories.aggregations import EnumDTOFactory # noqa: PLC0415 - from strawchemy.strawberry.factories.enum import ( # noqa: PLC0415 - EnumDTOBackend, - UpsertConflictFieldsEnumDTOBackend, - ) - from strawchemy.strawberry.factories.inputs import ( # noqa: PLC0415 - AggregateFilterDTOFactory, - BooleanFilterDTOFactory, - ) - from strawchemy.strawberry.factories.types import ( # noqa: PLC0415 - DistinctOnFieldsDTOFactory, - InputFactory, - OrderByDTOFactory, - RootAggregateTypeDTOFactory, - TypeDTOFactory, - UpsertConflictFieldsDTOFactory, - ) - - config = mapper.config - - # Create backend instances - strawberry_backend = StrawberrryDTOBackend(MappedStrawberryGraphQLDTO) - enum_backend = EnumDTOBackend(config.auto_snake_case) - upsert_conflict_fields_enum_backend = UpsertConflictFieldsEnumDTOBackend( - config.inspector, config.auto_snake_case - ) - - # Create factory instances - aggregate_filter = AggregateFilterDTOFactory(mapper) - order_by = OrderByDTOFactory(mapper) - distinct_on_enum = DistinctOnFieldsDTOFactory(config.inspector) - type_factory = TypeDTOFactory(mapper, strawberry_backend, order_by_factory=order_by) - input_factory = InputFactory(mapper, strawberry_backend) - aggregation = RootAggregateTypeDTOFactory(mapper, strawberry_backend, type_factory=type_factory) - enum_factory = EnumDTOFactory(config.inspector, enum_backend) - filter_factory = BooleanFilterDTOFactory(mapper, aggregate_filter_factory=aggregate_filter) - upsert_conflict = UpsertConflictFieldsDTOFactory(config.inspector, upsert_conflict_fields_enum_backend) - - return cls( - aggregate_filter=aggregate_filter, - order_by=order_by, - distinct_on_enum=distinct_on_enum, - type_factory=type_factory, - input_factory=input_factory, - aggregation=aggregation, - enum_factory=enum_factory, - filter_factory=filter_factory, - upsert_conflict=upsert_conflict, - ) - - def create_public_api(self) -> dict[str, Any]: - """Create the public API mappings for factory methods. - - Returns: - A dictionary mapping public API names to factory methods. - """ - return { - "filter": self.filter_factory.input, - "aggregate_filter": partial(self.aggregate_filter.input, mode="aggregate_filter"), - "distinct_on": self.distinct_on_enum.decorator, - "input": self.input_factory.input, - "create_input": partial(self.input_factory.input, mode="create_input"), - "pk_update_input": partial(self.input_factory.input, mode="update_by_pk_input"), - "filter_update_input": partial(self.input_factory.input, mode="update_by_filter_input"), - "order": partial(self.order_by.input, mode="order_by"), - "type": self.type_factory.type, - "aggregate": partial(self.aggregation.type, mode="aggregate_type"), - "upsert_update_fields": self.enum_factory.input, - "upsert_conflict_fields": self.upsert_conflict.input, - } diff --git a/src/strawchemy/strawberry/_instance.py b/src/strawchemy/instance.py similarity index 100% rename from src/strawchemy/strawberry/_instance.py rename to src/strawchemy/instance.py diff --git a/src/strawchemy/mapper.py b/src/strawchemy/mapper.py index b824bd66..5769cded 100644 --- a/src/strawchemy/mapper.py +++ b/src/strawchemy/mapper.py @@ -1,40 +1,52 @@ from __future__ import annotations import dataclasses -from functools import cached_property +from functools import cached_property, partial from typing import TYPE_CHECKING, Any, TypeVar, overload from strawberry.annotation import StrawberryAnnotation from strawberry.schema.config import StrawberryConfig from strawchemy.config.base import StrawchemyConfig +from strawchemy.dto.backend.strawberry import StrawberrryDTOBackend from strawchemy.dto.base import TYPING_NS -from strawchemy.factories import StrawchemyFactories -from strawchemy.strawberry._field import ( +from strawchemy.dto.strawberry import BooleanFilterDTO, EnumDTO, MappedStrawberryGraphQLDTO, OrderByDTO, OrderByEnum +from strawchemy.schema.factories import ( + AggregateFilterDTOFactory, + BooleanFilterDTOFactory, + DistinctOnFieldsDTOFactory, + EnumDTOBackend, + EnumDTOFactory, + InputFactory, + OrderByDTOFactory, + RootAggregateTypeDTOFactory, + TypeDTOFactory, + UpsertConflictFieldsDTOFactory, + UpsertConflictFieldsEnumDTOBackend, +) +from strawchemy.schema.field import StrawchemyField +from strawchemy.schema.mutation import types +from strawchemy.schema.mutation.field_builder import MutationFieldBuilder +from strawchemy.schema.mutation.fields import ( StrawchemyCreateMutationField, StrawchemyDeleteMutationField, - StrawchemyField, StrawchemyUpdateMutationField, StrawchemyUpsertMutationField, ) -from strawchemy.strawberry._registry import StrawberryRegistry -from strawchemy.strawberry.dto import BooleanFilterDTO, EnumDTO, OrderByDTO, OrderByEnum -from strawchemy.strawberry.mutation import types -from strawchemy.strawberry.mutation.builder import MutationFieldBuilder -from strawchemy.types import DefaultOffsetPagination +from strawchemy.schema.pagination import DefaultOffsetPagination +from strawchemy.utils.registry import StrawberryRegistry if TYPE_CHECKING: from collections.abc import Callable, Mapping, Sequence from sqlalchemy.orm import DeclarativeBase + from strawberry import BasePermission from strawberry.extensions.field_extension import FieldExtension from strawberry.types.arguments import StrawberryArgument - from strawberry import BasePermission - from strawchemy.sqlalchemy.hook import QueryHook - from strawchemy.sqlalchemy.typing import QueryHookCallable - from strawchemy.strawberry.typing import FilterStatementCallable, MappedGraphQLDTO - from strawchemy.typing import AnyRepository, SupportedDialect + from strawchemy.repository.typing import QueryHookCallable + from strawchemy.transpiler.hook import QueryHook + from strawchemy.typing import AnyRepositoryType, FilterStatementCallable, MappedGraphQLDTO, SupportedDialect from strawchemy.validation.base import ValidationProtocol from strawchemy.validation.pydantic import PydanticMapper @@ -90,38 +102,40 @@ def __init__( self.config = StrawchemyConfig(config) if isinstance(config, str) else config self.registry = StrawberryRegistry(strawberry_config or StrawberryConfig()) - # Initialize all factories through the container - factories = StrawchemyFactories.create(self) - - # Store factory references for internal use - self._aggregate_filter_factory = factories.aggregate_filter - self._order_by_factory = factories.order_by - self._distinct_on_enum_factory = factories.distinct_on_enum - self._type_factory = factories.type_factory - self._input_factory = factories.input_factory - self._aggregation_factory = factories.aggregation - self._enum_factory = factories.enum_factory - self._filter_factory = factories.filter_factory - self._upsert_conflict_factory = factories.upsert_conflict - - # Expose public factory API - public_api = factories.create_public_api() - self.filter = public_api["filter"] - self.aggregate_filter = public_api["aggregate_filter"] - self.distinct_on = public_api["distinct_on"] - self.input = public_api["input"] - self.create_input = public_api["create_input"] - self.pk_update_input = public_api["pk_update_input"] - self.filter_update_input = public_api["filter_update_input"] - self.order = public_api["order"] - self.type = public_api["type"] - self.aggregate = public_api["aggregate"] - self.upsert_update_fields = public_api["upsert_update_fields"] - self.upsert_conflict_fields = public_api["upsert_conflict_fields"] + strawberry_backend = StrawberrryDTOBackend(MappedStrawberryGraphQLDTO) + enum_backend = EnumDTOBackend(self.config.auto_snake_case) + upsert_conflict_fields_enum_backend = UpsertConflictFieldsEnumDTOBackend( + self.config.inspector, self.config.auto_snake_case + ) + + self._aggregate_filter_factory = AggregateFilterDTOFactory(self) + self._order_by_factory = OrderByDTOFactory(self) + self._distinct_on_enum_factory = DistinctOnFieldsDTOFactory(self.config.inspector) + self._type_factory = TypeDTOFactory(self, strawberry_backend, order_by_factory=self._order_by_factory) + self._input_factory = InputFactory(self, strawberry_backend) + self._aggregation_factory = RootAggregateTypeDTOFactory( + self, strawberry_backend, type_factory=self._type_factory + ) + self._enum_factory = EnumDTOFactory(self.config.inspector, enum_backend) + self._filter_factory = BooleanFilterDTOFactory(self, aggregate_filter_factory=self._aggregate_filter_factory) + self._upsert_conflict_factory = UpsertConflictFieldsDTOFactory( + self.config.inspector, upsert_conflict_fields_enum_backend + ) + self.filter = self._filter_factory.input + self.aggregate_filter = partial(self._aggregate_filter_factory.input, mode="aggregate_filter") + self.distinct_on = self._distinct_on_enum_factory.decorator + self.input = self._input_factory.input + self.create_input = partial(self._input_factory.input, mode="create_input") + self.pk_update_input = partial(self._input_factory.input, mode="update_by_pk_input") + self.filter_update_input = partial(self._input_factory.input, mode="update_by_filter_input") + self.order = partial(self._order_by_factory.input, mode="order_by") + self.type = self._type_factory.type + self.aggregate = partial(self._aggregation_factory.type, mode="aggregate_type") + self.upsert_update_fields = self._enum_factory.input + self.upsert_conflict_fields = self._upsert_conflict_factory.input # Initialize mutation field builder self._mutation_builder = MutationFieldBuilder(self.config, self._annotation_namespace) - # Register common types self.registry.register_enum(OrderByEnum, "OrderByEnum") @@ -164,7 +178,7 @@ def field( filter_statement: FilterStatementCallable | None = None, execution_options: dict[str, Any] | None = None, query_hook: QueryHook[Any] | Sequence[QueryHook[Any]] | None = None, - repository_type: AnyRepository | None = None, + repository_type: AnyRepositoryType | None = None, name: str | None = None, description: str | None = None, permission_classes: list[type[BasePermission]] | None = None, @@ -192,7 +206,7 @@ def field( filter_statement: FilterStatementCallable | None = None, execution_options: dict[str, Any] | None = None, query_hook: QueryHookCallable[Any] | Sequence[QueryHookCallable[Any]] | None = None, - repository_type: AnyRepository | None = None, + repository_type: AnyRepositoryType | None = None, name: str | None = None, description: str | None = None, permission_classes: list[type[BasePermission]] | None = None, @@ -220,7 +234,7 @@ def field( filter_statement: FilterStatementCallable | None = None, execution_options: dict[str, Any] | None = None, query_hook: QueryHookCallable[Any] | Sequence[QueryHookCallable[Any]] | None = None, - repository_type: AnyRepository | None = None, + repository_type: AnyRepositoryType | None = None, name: str | None = None, description: str | None = None, permission_classes: list[type[BasePermission]] | None = None, @@ -253,7 +267,7 @@ def field( filter_statement: A callable to generate a filter statement for the query. execution_options: SQLAlchemy execution options for the query. query_hook: A callable or sequence of callables to modify the SQLAlchemy query. - repository_type: A custom repository class for data fetching logic. + repository_type: A custom strawberry class for data fetching logic. name: The name of the GraphQL field. description: The description of the GraphQL field. permission_classes: A list of permission classes for the field. @@ -315,7 +329,7 @@ def create( input_type: type[MappedGraphQLDTO[T]], resolver: Any | None = None, *, - repository_type: AnyRepository | None = None, + repository_type: AnyRepositoryType | None = None, name: str | None = None, description: str | None = None, permission_classes: list[type[BasePermission]] | None = None, @@ -332,7 +346,7 @@ def create( This method generates a mutation field that handles the creation of SQLAlchemy model instances based on the provided input type. It integrates - with Strawchemy's repository system for data persistence and allows for + with Strawchemy's strawberry system for data persistence and allows for custom validation. Args: @@ -340,8 +354,8 @@ def create( a new model instance. This should be a `MappedGraphQLDTO`. resolver: An optional custom resolver function for the mutation. If not provided, Strawchemy will use a default resolver. - repository_type: An optional custom repository class for data fetching - and persistence logic. Defaults to the repository configured in + repository_type: An optional custom strawberry class for data fetching + and persistence logic. Defaults to the strawberry configured in `StrawchemyConfig`. name: The name of the GraphQL mutation field. description: The description of the GraphQL mutation field. @@ -386,7 +400,7 @@ def upsert( conflict_fields: type[EnumDTO], resolver: Any | None = None, *, - repository_type: AnyRepository | None = None, + repository_type: AnyRepositoryType | None = None, name: str | None = None, description: str | None = None, permission_classes: list[type[BasePermission]] | None = None, @@ -404,7 +418,7 @@ def upsert( This method generates a mutation field that handles the "upsert" (update or insert) of SQLAlchemy model instances. It uses the provided input type, update fields enum, and conflict fields enum to determine - the behavior on conflict. It integrates with Strawchemy's repository + the behavior on conflict. It integrates with Strawchemy's strawberry system and allows for custom validation. Args: @@ -416,8 +430,8 @@ def upsert( conflict detection (e.g., primary key or unique constraints). resolver: An optional custom resolver function for the mutation. If not provided, Strawchemy will use a default resolver. - repository_type: An optional custom repository class for data fetching - and persistence logic. Defaults to the repository configured in + repository_type: An optional custom strawberry class for data fetching + and persistence logic. Defaults to the strawberry configured in `StrawchemyConfig`. name: The name of the GraphQL mutation field. description: The description of the GraphQL mutation field. @@ -463,7 +477,7 @@ def update( filter_input: type[BooleanFilterDTO], resolver: Any | None = None, *, - repository_type: AnyRepository | None = None, + repository_type: AnyRepositoryType | None = None, name: str | None = None, description: str | None = None, permission_classes: list[type[BasePermission]] | None = None, @@ -481,7 +495,7 @@ def update( This method generates a mutation field that handles updating existing SQLAlchemy model instances based on filter criteria. It uses the provided input type for the update data and a filter input type to specify which - records to update. It integrates with Strawchemy's repository system and + records to update. It integrates with Strawchemy's strawberry system and allows for custom validation. Args: @@ -491,8 +505,8 @@ def update( instances should be updated. This should be a `BooleanFilterDTO`. resolver: An optional custom resolver function for the mutation. If not provided, Strawchemy will use a default resolver. - repository_type: An optional custom repository class for data fetching - and persistence logic. Defaults to the repository configured in + repository_type: An optional custom strawberry class for data fetching + and persistence logic. Defaults to the strawberry configured in `StrawchemyConfig`. name: The name of the GraphQL mutation field. description: The description of the GraphQL mutation field. @@ -537,7 +551,7 @@ def update_by_ids( input_type: type[MappedGraphQLDTO[T]], resolver: Any | None = None, *, - repository_type: AnyRepository | None = None, + repository_type: AnyRepositoryType | None = None, name: str | None = None, description: str | None = None, permission_classes: list[type[BasePermission]] | None = None, @@ -555,7 +569,7 @@ def update_by_ids( This method generates a mutation field that handles updating existing SQLAlchemy model instances based on their primary key(s). The input type should typically include the ID(s) of the record(s) to update and the - data to apply. It integrates with Strawchemy's repository system and + data to apply. It integrates with Strawchemy's strawberry system and allows for custom validation. Args: @@ -564,8 +578,8 @@ def update_by_ids( generated by `pk_update_input`, which includes primary key fields. resolver: An optional custom resolver function for the mutation. If not provided, Strawchemy will use a default resolver. - repository_type: An optional custom repository class for data fetching - and persistence logic. Defaults to the repository configured in + repository_type: An optional custom strawberry class for data fetching + and persistence logic. Defaults to the strawberry configured in `StrawchemyConfig`. name: The name of the GraphQL mutation field. description: The description of the GraphQL mutation field. @@ -609,7 +623,7 @@ def delete( filter_input: type[BooleanFilterDTO] | None = None, resolver: Any | None = None, *, - repository_type: AnyRepository | None = None, + repository_type: AnyRepositoryType | None = None, name: str | None = None, description: str | None = None, permission_classes: list[type[BasePermission]] | None = None, @@ -626,7 +640,7 @@ def delete( This method generates a mutation field that handles the deletion of SQLAlchemy model instances. Deletion can be based on filter criteria provided via `filter_input` or by ID if the `filter_input` is structured - to accept primary key(s). It integrates with Strawchemy's repository + to accept primary key(s). It integrates with Strawchemy's strawberry system for data persistence. Args: @@ -637,8 +651,8 @@ def delete( record based on an ID passed directly (implementation dependent). resolver: An optional custom resolver function for the mutation. If not provided, Strawchemy will use a default resolver. - repository_type: An optional custom repository class for data fetching - and persistence logic. Defaults to the repository configured in + repository_type: An optional custom strawberry class for data fetching + and persistence logic. Defaults to the strawberry configured in `StrawchemyConfig`. name: The name of the GraphQL mutation field. description: The description of the GraphQL mutation field. diff --git a/src/py.typed b/src/strawchemy/py.typed similarity index 100% rename from src/py.typed rename to src/strawchemy/py.typed diff --git a/src/strawchemy/repository/__init__.py b/src/strawchemy/repository/__init__.py new file mode 100644 index 00000000..804a2917 --- /dev/null +++ b/src/strawchemy/repository/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from strawchemy.repository.strawberry import StrawchemyAsyncRepository, StrawchemySyncRepository + +__all__ = ("StrawchemyAsyncRepository", "StrawchemySyncRepository") diff --git a/src/strawchemy/sqlalchemy/repository/__init__.py b/src/strawchemy/repository/sqlalchemy/__init__.py similarity index 51% rename from src/strawchemy/sqlalchemy/repository/__init__.py rename to src/strawchemy/repository/sqlalchemy/__init__.py index 20166143..a2d7d543 100644 --- a/src/strawchemy/sqlalchemy/repository/__init__.py +++ b/src/strawchemy/repository/sqlalchemy/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations -from strawchemy.sqlalchemy.repository._async import SQLAlchemyGraphQLAsyncRepository -from strawchemy.sqlalchemy.repository._base import SQLAlchemyGraphQLRepository -from strawchemy.sqlalchemy.repository._sync import SQLAlchemyGraphQLSyncRepository +from strawchemy.repository.sqlalchemy._async import SQLAlchemyGraphQLAsyncRepository +from strawchemy.repository.sqlalchemy._base import SQLAlchemyGraphQLRepository +from strawchemy.repository.sqlalchemy._sync import SQLAlchemyGraphQLSyncRepository __all__ = ("SQLAlchemyGraphQLAsyncRepository", "SQLAlchemyGraphQLRepository", "SQLAlchemyGraphQLSyncRepository") diff --git a/src/strawchemy/sqlalchemy/repository/_async.py b/src/strawchemy/repository/sqlalchemy/_async.py similarity index 96% rename from src/strawchemy/sqlalchemy/repository/_async.py rename to src/strawchemy/repository/sqlalchemy/_async.py index e23afc6a..ad3ac829 100644 --- a/src/strawchemy/sqlalchemy/repository/_async.py +++ b/src/strawchemy/repository/sqlalchemy/_async.py @@ -4,15 +4,13 @@ from inspect import isclass from typing import TYPE_CHECKING, Any, TypeVar +from sqlalchemy import ColumnElement, Row, and_, delete, inspect, select, update from sqlalchemy.orm import RelationshipProperty -from sqlalchemy import ColumnElement, Row, and_, delete, inspect, select, update -from strawchemy.sqlalchemy._executor import AsyncQueryExecutor, QueryResult -from strawchemy.sqlalchemy._transpiler import QueryTranspiler -from strawchemy.sqlalchemy.repository._base import InsertData, MutationData, SQLAlchemyGraphQLRepository -from strawchemy.sqlalchemy.typing import AnyAsyncSession, DeclarativeT -from strawchemy.strawberry.mutation.input import UpsertData -from strawchemy.strawberry.mutation.types import RelationType +from strawchemy.repository.sqlalchemy._base import InsertData, MutationData, SQLAlchemyGraphQLRepository +from strawchemy.repository.typing import AnyAsyncSession, DeclarativeT +from strawchemy.schema.mutation import RelationType, UpsertData +from strawchemy.transpiler import AsyncQueryExecutor, QueryResult, QueryTranspiler if TYPE_CHECKING: from collections.abc import Sequence @@ -20,11 +18,11 @@ from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm.util import AliasedClass - from strawchemy.sqlalchemy.hook import QueryHook - from strawchemy.sqlalchemy.repository._base import InsertOrUpdate, RowLike - from strawchemy.strawberry.dto import BooleanFilterDTO, EnumDTO, OrderByDTO - from strawchemy.strawberry.mutation.input import Input, LevelInput - from strawchemy.strawberry.typing import QueryNodeType + from strawchemy.dto.strawberry import BooleanFilterDTO, EnumDTO, OrderByDTO + from strawchemy.repository.sqlalchemy._base import InsertOrUpdate, RowLike + from strawchemy.schema.mutation import Input, LevelInput + from strawchemy.transpiler.hook import QueryHook + from strawchemy.typing import QueryNodeType __all__ = ("SQLAlchemyGraphQLAsyncRepository",) @@ -266,7 +264,7 @@ async def _list_by_ids( ) -> QueryResult[DeclarativeT]: """Retrieves multiple records by their primary keys with optional selection. - Fetches records from the repository's main model that match the provided + Fetches records from the strawberry's main model that match the provided primary key combinations. Allows specifying a GraphQL selection Args: @@ -301,7 +299,7 @@ async def list( ) -> QueryResult[DeclarativeT]: """Retrieves a list of records based on filtering, ordering, and pagination. - Fetches records from the repository's main model, applying optional + Fetches records from the strawberry's main model, applying optional filtering, ordering, pagination (limit/offset), and distinct constraints. Supports GraphQL selection sets for optimized data retrieval and query hooks for customization. diff --git a/src/strawchemy/sqlalchemy/repository/_base.py b/src/strawchemy/repository/sqlalchemy/_base.py similarity index 95% rename from src/strawchemy/sqlalchemy/repository/_base.py rename to src/strawchemy/repository/sqlalchemy/_base.py index c2755e84..f44eeed2 100644 --- a/src/strawchemy/sqlalchemy/repository/_base.py +++ b/src/strawchemy/repository/sqlalchemy/_base.py @@ -5,32 +5,31 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Literal, NamedTuple, TypeAlias, TypeVar, cast +from sqlalchemy import Column, Function, Insert, Row, Table, func, insert, inspect from sqlalchemy.dialects import mysql, postgresql, sqlite from sqlalchemy.orm import RelationshipProperty -from sqlalchemy import Column, Function, Insert, Row, Table, func, insert, inspect -from strawchemy.dto.inspectors.sqlalchemy import SQLAlchemyInspector +from strawchemy.dto.inspectors import SQLAlchemyInspector from strawchemy.exceptions import StrawchemyError -from strawchemy.sqlalchemy._transpiler import QueryTranspiler -from strawchemy.sqlalchemy.typing import DeclarativeT, QueryExecutorT, SessionT -from strawchemy.strawberry.mutation.types import RelationType +from strawchemy.repository.typing import DeclarativeT, QueryExecutorT, SessionT +from strawchemy.schema.mutation.types import RelationType +from strawchemy.transpiler import QueryTranspiler if TYPE_CHECKING: from collections.abc import Mapping, Sequence + from sqlalchemy import Select from sqlalchemy.orm import DeclarativeBase from sqlalchemy.sql.base import ReadOnlyColumnCollection from sqlalchemy.sql.elements import KeyedColumnElement - from sqlalchemy import Select - from strawchemy.sqlalchemy.hook import QueryHook - from strawchemy.strawberry.dto import BooleanFilterDTO, EnumDTO, OrderByDTO - from strawchemy.strawberry.mutation.input import Input, LevelInput, UpsertData - from strawchemy.strawberry.typing import QueryNodeType - from strawchemy.typing import SupportedDialect + from strawchemy.dto.strawberry import BooleanFilterDTO, EnumDTO, OrderByDTO + from strawchemy.schema.mutation import Input, LevelInput, UpsertData + from strawchemy.transpiler.hook import QueryHook + from strawchemy.typing import QueryNodeType, SupportedDialect -__all__ = ("InsertOrUpdate", "RowLike", "SQLAlchemyGraphQLRepository") +__all__ = ("InsertData", "InsertOrUpdate", "MutationData", "RowLike", "SQLAlchemyGraphQLRepository") T = TypeVar("T", bound=Any) diff --git a/src/strawchemy/sqlalchemy/repository/_sync.py b/src/strawchemy/repository/sqlalchemy/_sync.py similarity index 96% rename from src/strawchemy/sqlalchemy/repository/_sync.py rename to src/strawchemy/repository/sqlalchemy/_sync.py index 83bc2d3d..63845f24 100644 --- a/src/strawchemy/sqlalchemy/repository/_sync.py +++ b/src/strawchemy/repository/sqlalchemy/_sync.py @@ -1,20 +1,18 @@ # Do not edit this file directly. It has been autogenerated from -# src/strawchemy/sqlalchemy/repository/_async.py +# src/strawchemy/repository/sqlalchemy/_async.py from __future__ import annotations from collections import defaultdict, namedtuple from inspect import isclass from typing import TYPE_CHECKING, Any, TypeVar +from sqlalchemy import ColumnElement, Row, and_, delete, inspect, select, update from sqlalchemy.orm import RelationshipProperty -from sqlalchemy import ColumnElement, Row, and_, delete, inspect, select, update -from strawchemy.sqlalchemy._executor import QueryResult, SyncQueryExecutor -from strawchemy.sqlalchemy._transpiler import QueryTranspiler -from strawchemy.sqlalchemy.repository._base import InsertData, MutationData, SQLAlchemyGraphQLRepository -from strawchemy.sqlalchemy.typing import AnySyncSession, DeclarativeT -from strawchemy.strawberry.mutation.input import UpsertData -from strawchemy.strawberry.mutation.types import RelationType +from strawchemy.repository.sqlalchemy._base import InsertData, MutationData, SQLAlchemyGraphQLRepository +from strawchemy.repository.typing import AnySyncSession, DeclarativeT +from strawchemy.schema.mutation import RelationType, UpsertData +from strawchemy.transpiler import QueryResult, QueryTranspiler, SyncQueryExecutor if TYPE_CHECKING: from collections.abc import Sequence @@ -22,11 +20,11 @@ from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm.util import AliasedClass - from strawchemy.sqlalchemy.hook import QueryHook - from strawchemy.sqlalchemy.repository._base import InsertOrUpdate, RowLike - from strawchemy.strawberry.dto import BooleanFilterDTO, EnumDTO, OrderByDTO - from strawchemy.strawberry.mutation.input import Input, LevelInput - from strawchemy.strawberry.typing import QueryNodeType + from strawchemy.dto.strawberry import BooleanFilterDTO, EnumDTO, OrderByDTO + from strawchemy.repository.sqlalchemy._base import InsertOrUpdate, RowLike + from strawchemy.schema.mutation import Input, LevelInput + from strawchemy.transpiler.hook import QueryHook + from strawchemy.typing import QueryNodeType __all__ = () @@ -268,7 +266,7 @@ def _list_by_ids( ) -> QueryResult[DeclarativeT]: """Retrieves multiple records by their primary keys with optional selection. - Fetches records from the repository's main model that match the provided + Fetches records from the strawberry's main model that match the provided primary key combinations. Allows specifying a GraphQL selection Args: @@ -303,7 +301,7 @@ def list( ) -> QueryResult[DeclarativeT]: """Retrieves a list of records based on filtering, ordering, and pagination. - Fetches records from the repository's main model, applying optional + Fetches records from the strawberry's main model, applying optional filtering, ordering, pagination (limit/offset), and distinct constraints. Supports GraphQL selection sets for optimized data retrieval and query hooks for customization. diff --git a/src/strawchemy/strawberry/repository/__init__.py b/src/strawchemy/repository/strawberry/__init__.py similarity index 50% rename from src/strawchemy/strawberry/repository/__init__.py rename to src/strawchemy/repository/strawberry/__init__.py index 637b27b3..7ccecb18 100644 --- a/src/strawchemy/strawberry/repository/__init__.py +++ b/src/strawchemy/repository/strawberry/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from strawchemy.strawberry.repository._async import StrawchemyAsyncRepository -from strawchemy.strawberry.repository._sync import StrawchemySyncRepository +from strawchemy.repository.strawberry._async import StrawchemyAsyncRepository +from strawchemy.repository.strawberry._sync import StrawchemySyncRepository __all__ = ("StrawchemyAsyncRepository", "StrawchemySyncRepository") diff --git a/src/strawchemy/strawberry/repository/_async.py b/src/strawchemy/repository/strawberry/_async.py similarity index 92% rename from src/strawchemy/strawberry/repository/_async.py rename to src/strawchemy/repository/strawberry/_async.py index b806ebbc..1908b680 100644 --- a/src/strawchemy/strawberry/repository/_async.py +++ b/src/strawchemy/repository/strawberry/_async.py @@ -1,6 +1,6 @@ -"""Asynchronous repository implementation for Strawchemy. +"""Asynchronous strawberry implementation for Strawchemy. -This module provides an asynchronous implementation of the Strawchemy repository +This module provides an asynchronous implementation of the Strawchemy strawberry pattern, built on top of SQLAlchemy's asynchronous API. """ @@ -9,17 +9,17 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, TypeVar -from strawchemy.sqlalchemy.repository import SQLAlchemyGraphQLAsyncRepository -from strawchemy.strawberry._utils import default_session_getter, dto_model_from_type, strawberry_contained_user_type -from strawchemy.strawberry.repository._base import GraphQLResult, StrawchemyRepository +from strawchemy.repository.sqlalchemy import SQLAlchemyGraphQLAsyncRepository +from strawchemy.repository.strawberry.base import IS_ASYNC_REPOSITORY, GraphQLResult, StrawchemyRepository +from strawchemy.utils.strawberry import default_session_getter, dto_model_from_type, strawberry_contained_user_type if TYPE_CHECKING: from sqlalchemy import Select from strawberry import Info - from strawchemy.sqlalchemy.typing import AnyAsyncSession - from strawchemy.strawberry.dto import BooleanFilterDTO, EnumDTO, OrderByDTO - from strawchemy.strawberry.mutation.input import Input, InputModel - from strawchemy.strawberry.typing import AsyncSessionGetter + + from strawchemy.dto.strawberry import BooleanFilterDTO, EnumDTO, OrderByDTO + from strawchemy.repository.typing import AnyAsyncSession, AsyncSessionGetter + from strawchemy.schema.mutation import Input, InputModel __all__ = ("StrawchemyAsyncRepository",) @@ -28,13 +28,13 @@ @dataclass class StrawchemyAsyncRepository(StrawchemyRepository[T]): - """Asynchronous repository implementation for GraphQL data access. + """Asynchronous strawberry implementation for GraphQL data access. This class provides asynchronous methods for querying and mutating data through GraphQL, using SQLAlchemy's asynchronous API under the hood. Args: - type: The Strawberry GraphQL type this repository works with + type: The Strawberry GraphQL type this strawberry works with info: The GraphQL resolver info object session_getter: Callable to get an async database session session: Optional explicit async database session to use @@ -43,10 +43,12 @@ class StrawchemyAsyncRepository(StrawchemyRepository[T]): deterministic_ordering: Whether to ensure deterministic ordering of results """ + is_async = IS_ASYNC_REPOSITORY + type: type[T] info: Info[Any, Any] - # sqlalchemy related settings + # strawberry related settings session_getter: AsyncSessionGetter = default_session_getter session: AnyAsyncSession | None = None filter_statement: Select[tuple[Any]] | None = None @@ -54,7 +56,7 @@ class StrawchemyAsyncRepository(StrawchemyRepository[T]): deterministic_ordering: bool = False def graphql_repository(self) -> SQLAlchemyGraphQLAsyncRepository[Any]: - """Create and configure the underlying async SQLAlchemy GraphQL repository. + """Create and configure the underlying async SQLAlchemy GraphQL strawberry. Returns: A configured SQLAlchemyGraphQLAsyncRepository instance diff --git a/src/strawchemy/strawberry/repository/_node.py b/src/strawchemy/repository/strawberry/_node.py similarity index 92% rename from src/strawchemy/strawberry/repository/_node.py rename to src/strawchemy/repository/strawberry/_node.py index 3d8c4862..44423f72 100644 --- a/src/strawchemy/strawberry/repository/_node.py +++ b/src/strawchemy/repository/strawberry/_node.py @@ -8,19 +8,17 @@ from strawberry.utils.typing import type_has_annotation from strawchemy.constants import AGGREGATIONS_KEY, NODES_KEY +from strawchemy.dto.strawberry import QueryNode from strawchemy.dto.types import DTOMissing -from strawchemy.graph import GraphError -from strawchemy.sqlalchemy import SQLAlchemyGraphQLRepository -from strawchemy.strawberry._instance import MapperModelInstance -from strawchemy.strawberry.dto import QueryNode +from strawchemy.exceptions import GraphError +from strawchemy.instance import MapperModelInstance +from strawchemy.repository.sqlalchemy import SQLAlchemyGraphQLRepository if TYPE_CHECKING: from collections.abc import Sequence - from strawchemy.sqlalchemy._executor import NodeResult, QueryResult - from strawchemy.strawberry.typing import QueryNodeType - from strawchemy.typing import DataclassProtocol - + from strawchemy.transpiler import NodeResult, QueryResult + from strawchemy.typing import DataclassProtocol, QueryNodeType __all__ = ("SQLAlchemyGraphQLRepository", "StrawberryQueryNode") diff --git a/src/strawchemy/strawberry/repository/_sync.py b/src/strawchemy/repository/strawberry/_sync.py similarity index 91% rename from src/strawchemy/strawberry/repository/_sync.py rename to src/strawchemy/repository/strawberry/_sync.py index 09416265..b6141a36 100644 --- a/src/strawchemy/strawberry/repository/_sync.py +++ b/src/strawchemy/repository/strawberry/_sync.py @@ -1,8 +1,8 @@ # Do not edit this file directly. It has been autogenerated from -# src/strawchemy/strawberry/repository/_async.py -"""Asynchronous repository implementation for Strawchemy. +# src/strawchemy/repository/strawberry/_async.py +"""Asynchronous strawberry implementation for Strawchemy. -This module provides an asynchronous implementation of the Strawchemy repository +This module provides an asynchronous implementation of the Strawchemy strawberry pattern, built on top of SQLAlchemy's asynchronous API. """ @@ -11,17 +11,21 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, TypeVar -from strawchemy.sqlalchemy.repository import SQLAlchemyGraphQLSyncRepository -from strawchemy.strawberry._utils import default_session_getter, dto_model_from_type, strawberry_contained_user_type -from strawchemy.strawberry.repository._base import GraphQLResult, StrawchemyRepository +from strawchemy.repository.sqlalchemy import SQLAlchemyGraphQLSyncRepository +from strawchemy.repository.strawberry.base import ( + IS_SYNC_REPOSITORY, + GraphQLResult, + StrawchemyRepository, +) +from strawchemy.utils.strawberry import default_session_getter, dto_model_from_type, strawberry_contained_user_type if TYPE_CHECKING: from sqlalchemy import Select from strawberry import Info - from strawchemy.sqlalchemy.typing import AnySyncSession - from strawchemy.strawberry.dto import BooleanFilterDTO, EnumDTO, OrderByDTO - from strawchemy.strawberry.mutation.input import Input, InputModel - from strawchemy.strawberry.typing import SyncSessionGetter + + from strawchemy.dto.strawberry import BooleanFilterDTO, EnumDTO, OrderByDTO + from strawchemy.repository.typing import AnySyncSession, SyncSessionGetter + from strawchemy.schema.mutation import Input, InputModel __all__ = () @@ -30,13 +34,13 @@ @dataclass class StrawchemySyncRepository(StrawchemyRepository[T]): - """Asynchronous repository implementation for GraphQL data access. + """Asynchronous strawberry implementation for GraphQL data access. This class provides asynchronous methods for querying and mutating data through GraphQL, using SQLAlchemy's asynchronous API under the hood. Args: - type: The Strawberry GraphQL type this repository works with + type: The Strawberry GraphQL type this strawberry works with info: The GraphQL resolver info object session_getter: Callable to get an async database session session: Optional explicit async database session to use @@ -45,10 +49,12 @@ class StrawchemySyncRepository(StrawchemyRepository[T]): deterministic_ordering: Whether to ensure deterministic ordering of results """ + is_async = IS_SYNC_REPOSITORY + type: type[T] info: Info[Any, Any] - # sqlalchemy related settings + # strawberry related settings session_getter: SyncSessionGetter = default_session_getter session: AnySyncSession | None = None filter_statement: Select[tuple[Any]] | None = None @@ -56,7 +62,7 @@ class StrawchemySyncRepository(StrawchemyRepository[T]): deterministic_ordering: bool = False def graphql_repository(self) -> SQLAlchemyGraphQLSyncRepository[Any]: - """Create and configure the underlying async SQLAlchemy GraphQL repository. + """Create and configure the underlying async SQLAlchemy GraphQL strawberry. Returns: A configured SQLAlchemyGraphQLAsyncRepository instance diff --git a/src/strawchemy/strawberry/repository/_base.py b/src/strawchemy/repository/strawberry/base.py similarity index 90% rename from src/strawchemy/strawberry/repository/_base.py rename to src/strawchemy/repository/strawberry/base.py index 0e688199..98b016fb 100644 --- a/src/strawchemy/strawberry/repository/_base.py +++ b/src/strawchemy/repository/strawberry/base.py @@ -1,6 +1,6 @@ -"""Base repository module for Strawchemy framework. +"""Base strawberry module for Strawchemy framework. -This module provides the core repository implementation for GraphQL data access +This module provides the core strawberry implementation for GraphQL data access in Strawchemy, including base classes for query building and result handling. """ @@ -19,10 +19,7 @@ from strawchemy.constants import JSON_PATH_KEY, ORDER_BY_KEY from strawchemy.dto.base import ModelT -from strawchemy.exceptions import StrawchemyError -from strawchemy.graph import NodeMetadata -from strawchemy.strawberry._utils import dto_model_from_type, strawberry_contained_user_type -from strawchemy.strawberry.dto import ( +from strawchemy.dto.strawberry import ( DTOKey, OrderByRelationFilterDTO, QueryNode, @@ -30,22 +27,27 @@ RelationFilterDTO, StrawchemyDTOAttributes, ) -from strawchemy.strawberry.mutation.types import error_type_names -from strawchemy.strawberry.repository._node import StrawberryQueryNode -from strawchemy.utils import camel_to_snake, snake_keys +from strawchemy.exceptions import StrawchemyError +from strawchemy.repository.strawberry._node import StrawberryQueryNode +from strawchemy.schema.mutation import error_type_names +from strawchemy.utils.graph import NodeMetadata +from strawchemy.utils.strawberry import dto_model_from_type, strawberry_contained_user_type +from strawchemy.utils.text import camel_to_snake, snake_keys if TYPE_CHECKING: + from strawberry import Info from strawberry.types.field import StrawberryField - from strawberry import Info - from strawchemy.sqlalchemy._executor import QueryResult - from strawchemy.sqlalchemy.hook import QueryHook - from strawchemy.strawberry.typing import QueryNodeType, StrawchemyTypeWithStrawberryObjectDefinition + from strawchemy.transpiler import QueryHook, QueryResult + from strawchemy.typing import QueryNodeType, StrawchemyTypeWithStrawberryObjectDefinition -__all__ = ("GraphQLResult", "StrawchemyRepository") +__all__ = ("IS_ASYNC_REPOSITORY", "IS_SYNC_REPOSITORY", "GraphQLResult", "StrawchemyRepository") T = TypeVar("T") +IS_ASYNC_REPOSITORY: bool = True +IS_SYNC_REPOSITORY: bool = not IS_ASYNC_REPOSITORY + @dataclass class GraphQLResult(Generic[ModelT, T]): @@ -132,13 +134,13 @@ def instance(self) -> ModelT: @dataclass class StrawchemyRepository(Generic[T]): - """Base repository for GraphQL data access in Strawchemy. + """Base strawberry for GraphQL data access in Strawchemy. This class provides the core functionality for building and executing GraphQL queries against a database, with support for filtering, ordering, and field selection. Args: - type: The Strawberry GraphQL type this repository works with + type: The Strawberry GraphQL type this strawberry works with info: The GraphQL resolver info object root_aggregations: Whether to enable root-level aggregations auto_snake_case: Whether to automatically convert field names to snake_case @@ -151,6 +153,8 @@ class StrawchemyRepository(Generic[T]): _ignored_field_names: ClassVar[frozenset[str]] = frozenset({"__typename"}) + is_async: ClassVar[bool] + type: type[T] info: Info[Any, Any] root_aggregations: bool = False @@ -190,7 +194,7 @@ def _relation_filter( @classmethod def _get_field_hooks(cls, field: StrawberryField) -> QueryHook[Any] | Sequence[QueryHook[Any]] | None: - from strawchemy.strawberry._field import StrawchemyField # noqa: PLC0415 + from strawchemy.schema.field import StrawchemyField # noqa: PLC0415 return field.query_hook if isinstance(field, StrawchemyField) else None diff --git a/src/strawchemy/sqlalchemy/typing.py b/src/strawchemy/repository/typing.py similarity index 78% rename from src/strawchemy/sqlalchemy/typing.py rename to src/strawchemy/repository/typing.py index 84859424..a0d07bb5 100644 --- a/src/strawchemy/sqlalchemy/typing.py +++ b/src/strawchemy/repository/typing.py @@ -6,15 +6,15 @@ from collections import OrderedDict from collections.abc import Callable + from sqlalchemy import Function from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session from sqlalchemy.orm import DeclarativeBase, Session, scoped_session from sqlalchemy.sql import SQLColumnExpression + from strawberry import Info - from sqlalchemy import Function - from strawchemy.sqlalchemy._executor import QueryExecutor - from strawchemy.sqlalchemy.hook import QueryHook - from strawchemy.strawberry.dto import OrderByEnum - from strawchemy.strawberry.filters.base import GraphQLComparison + from strawchemy.dto.strawberry import OrderByEnum + from strawchemy.schema.filters import GraphQLComparison + from strawchemy.transpiler import QueryExecutor, QueryHook __all__ = ( @@ -46,3 +46,6 @@ AnyAsyncSession: TypeAlias = "AsyncSession | async_scoped_session[AsyncSession]" AnySession: TypeAlias = "AnySyncSession | AnyAsyncSession" OrderBySpec: TypeAlias = "tuple[SQLColumnExpression[Any], OrderByEnum]" +AsyncSessionGetter: TypeAlias = "Callable[[Info[Any, Any]], AnyAsyncSession]" +SyncSessionGetter: TypeAlias = "Callable[[Info[Any, Any]], AnySyncSession]" +AnySessionGetter: TypeAlias = "AsyncSessionGetter | SyncSessionGetter" diff --git a/src/strawchemy/schema/__init__.py b/src/strawchemy/schema/__init__.py new file mode 100644 index 00000000..9d48db4f --- /dev/null +++ b/src/strawchemy/schema/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/src/strawchemy/schema/factories/__init__.py b/src/strawchemy/schema/factories/__init__.py new file mode 100644 index 00000000..d20c7434 --- /dev/null +++ b/src/strawchemy/schema/factories/__init__.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from strawchemy.schema.factories.aggregations import AggregationInspector +from strawchemy.schema.factories.base import ( + ChildOptions, + GraphQLDTOFactory, + MappedGraphQLDTOT, + StrawchemyMappedFactory, + StrawchemyUnMappedDTOFactory, + UnmappedGraphQLDTOT, +) +from strawchemy.schema.factories.enum import EnumDTOBackend, EnumDTOFactory, UpsertConflictFieldsEnumDTOBackend +from strawchemy.schema.factories.inputs import AggregateFilterDTOFactory, BooleanFilterDTOFactory, OrderByDTOFactory +from strawchemy.schema.factories.types import ( + AggregateDTOFactory, + DistinctOnFieldsDTOFactory, + InputFactory, + RootAggregateTypeDTOFactory, + TypeDTOFactory, + UpsertConflictFieldsDTOFactory, +) + +__all__ = ( + "AggregateDTOFactory", + "AggregateFilterDTOFactory", + "AggregationInspector", + "BooleanFilterDTOFactory", + "ChildOptions", + "DistinctOnFieldsDTOFactory", + "EnumDTOBackend", + "EnumDTOFactory", + "GraphQLDTOFactory", + "InputFactory", + "MappedGraphQLDTOT", + "OrderByDTOFactory", + "RootAggregateTypeDTOFactory", + "StrawchemyMappedFactory", + "StrawchemyUnMappedDTOFactory", + "TypeDTOFactory", + "UnmappedGraphQLDTOT", + "UpsertConflictFieldsDTOFactory", + "UpsertConflictFieldsEnumDTOBackend", +) diff --git a/src/strawchemy/strawberry/factories/aggregations.py b/src/strawchemy/schema/factories/aggregations.py similarity index 97% rename from src/strawchemy/strawberry/factories/aggregations.py rename to src/strawchemy/schema/factories/aggregations.py index eff61e7a..a53103a0 100644 --- a/src/strawchemy/strawberry/factories/aggregations.py +++ b/src/strawchemy/schema/factories/aggregations.py @@ -10,8 +10,7 @@ from typing_extensions import override from strawchemy.dto.backend.strawberry import StrawberrryDTOBackend -from strawchemy.dto.exceptions import DTOError -from strawchemy.strawberry.dto import ( +from strawchemy.dto.strawberry import ( DTOKey, EnumDTO, FilterFunctionInfo, @@ -20,8 +19,9 @@ OutputFunctionInfo, UnmappedStrawberryGraphQLDTO, ) -from strawchemy.strawberry.factories.base import GraphQLDTOFactory -from strawchemy.strawberry.factories.enum import EnumDTOBackend, EnumDTOFactory +from strawchemy.exceptions import DTOError +from strawchemy.schema.factories.base import GraphQLDTOFactory +from strawchemy.schema.factories.enum import EnumDTOBackend, EnumDTOFactory if TYPE_CHECKING: from collections.abc import Generator @@ -30,10 +30,10 @@ from strawchemy.dto.base import DTOBackend, DTOBase, DTOFieldDefinition, ModelT, Relation from strawchemy.dto.types import DTOConfig - from strawchemy.graph import Node from strawchemy.mapper import Strawchemy - from strawchemy.sqlalchemy.typing import DeclarativeT - from strawchemy.strawberry.typing import AggregationFunction, AggregationType, FunctionInfo + from strawchemy.repository.typing import DeclarativeT + from strawchemy.typing import AggregationFunction, AggregationType, FunctionInfo + from strawchemy.utils.graph import Node T = TypeVar("T") diff --git a/src/strawchemy/strawberry/factories/base.py b/src/strawchemy/schema/factories/base.py similarity index 95% rename from src/strawchemy/strawberry/factories/base.py rename to src/strawchemy/schema/factories/base.py index e4e75c3f..8fb7bd99 100644 --- a/src/strawchemy/strawberry/factories/base.py +++ b/src/strawchemy/schema/factories/base.py @@ -15,25 +15,18 @@ from __future__ import annotations import dataclasses -from collections.abc import Generator, Sequence from functools import cached_property from typing import TYPE_CHECKING, Any, Literal, Optional, TypeAlias, TypeVar, get_type_hints from sqlalchemy.orm import DeclarativeBase, QueryableAttribute +from strawberry import UNSET from strawberry.types.auto import StrawberryAuto from strawberry.utils.typing import type_has_annotation from typing_extensions import dataclass_transform, override -from strawberry import UNSET +from strawchemy import typing as strawchemy_typing from strawchemy.dto.base import DTOBackend, DTOBase, DTOFactory, DTOFieldDefinition, Relation -from strawchemy.dto.types import DTOAuto, DTOConfig, DTOScope, Purpose -from strawchemy.dto.utils import config -from strawchemy.exceptions import StrawchemyError -from strawchemy.graph import Node -from strawchemy.strawberry import typing as strawchemy_typing -from strawchemy.strawberry._instance import MapperModelInstance -from strawchemy.strawberry._registry import RegistryTypeInfo -from strawchemy.strawberry.dto import ( +from strawchemy.dto.strawberry import ( BooleanFilterDTO, DTOKey, GraphQLFieldDefinition, @@ -42,22 +35,27 @@ StrawchemyDTOAttributes, UnmappedStrawberryGraphQLDTO, ) -from strawchemy.strawberry.typing import GraphQLDTOT, GraphQLPurpose, MappedGraphQLDTO -from strawchemy.types import DefaultOffsetPagination -from strawchemy.utils import get_annotations +from strawchemy.dto.types import DTOAuto, DTOScope, Purpose +from strawchemy.dto.utils import config +from strawchemy.exceptions import StrawchemyError +from strawchemy.instance import MapperModelInstance +from strawchemy.schema.pagination import DefaultOffsetPagination +from strawchemy.transpiler import hook +from strawchemy.typing import GraphQLDTOT, GraphQLPurpose, GraphQLType, MappedGraphQLDTO +from strawchemy.utils.annotation import get_annotations +from strawchemy.utils.registry import RegistryTypeInfo if TYPE_CHECKING: from collections.abc import Callable, Generator, Mapping, Sequence from strawchemy import Strawchemy + from strawchemy.dto.inspectors import SQLAlchemyGraphQLInspector from strawchemy.dto.types import DTOConfig, ExcludeFields, IncludeFields - from strawchemy.graph import Node - from strawchemy.sqlalchemy.hook import QueryHook - from strawchemy.sqlalchemy.inspector import SQLAlchemyGraphQLInspector - from strawchemy.strawberry.typing import GraphQLType + from strawchemy.transpiler.hook import QueryHook + from strawchemy.utils.graph import Node from strawchemy.validation.pydantic import MappedPydanticGraphQLDTO -__all__ = ("GraphQLDTOFactory",) +__all__ = ("GraphQLDTOFactory", "StrawchemyMappedFactory", "StrawchemyUnMappedDTOFactory") T = TypeVar("T", bound="DeclarativeBase") PydanticGraphQLDTOT = TypeVar("PydanticGraphQLDTOT", bound="MappedPydanticGraphQLDTO[Any]") @@ -73,7 +71,7 @@ def type_scope_to_dto_scope(scope: TypeScope) -> DTOScope: @dataclasses.dataclass(eq=True, frozen=True) -class _ChildOptions: +class ChildOptions: pagination: DefaultOffsetPagination | bool = False order_by: bool = False @@ -99,9 +97,9 @@ def _type_info( current_node: Node[Relation[Any, GraphQLDTOT], None] | None, override: bool = False, user_defined: bool = False, - child_options: _ChildOptions | None = None, + child_options: ChildOptions | None = None, ) -> RegistryTypeInfo: - child_options = child_options or _ChildOptions() + child_options = child_options or ChildOptions() graphql_type = self.graphql_type(dto_config) model: type[DeclarativeBase] | None = dto.__dto_model__ if issubclass(dto, MappedStrawberryGraphQLDTO) else None # type: ignore[reportGeneralTypeIssues] default_name = self.root_dto_name(model, dto_config, current_node) if model else dto.__name__ @@ -132,7 +130,7 @@ def _register_type( directives: Sequence[object] | None = (), override: bool = False, user_defined: bool = False, - child_options: _ChildOptions | None = None, + child_options: ChildOptions | None = None, ) -> type[StrawchemyDTOT]: type_info = self._type_info( dto, @@ -249,7 +247,7 @@ def wrapper(class_: type[Any]) -> type[GraphQLDTOT]: override=override, user_defined=True, mode=mode, - child_options=_ChildOptions(pagination=child_pagination, order_by=child_order_by), + child_options=ChildOptions(pagination=child_pagination, order_by=child_order_by), ) dto.__strawchemy_query_hook__ = query_hook if issubclass(dto, MappedStrawberryGraphQLDTO): @@ -310,8 +308,6 @@ def wrapper(class_: type[Any]) -> type[GraphQLDTOT]: @cached_property def _namespace(self) -> dict[str, Any]: - from strawchemy.sqlalchemy import hook # noqa: PLC0415 - return vars(strawchemy_typing) | vars(hook) @classmethod diff --git a/src/strawchemy/strawberry/factories/enum.py b/src/strawchemy/schema/factories/enum.py similarity index 96% rename from src/strawchemy/strawberry/factories/enum.py rename to src/strawchemy/schema/factories/enum.py index 84a49a19..05d445ab 100644 --- a/src/strawchemy/strawberry/factories/enum.py +++ b/src/strawchemy/schema/factories/enum.py @@ -9,15 +9,15 @@ from typing_extensions import override from strawchemy.dto.base import DTOBackend, DTOBase, DTOFactory, DTOFieldDefinition, Relation +from strawchemy.dto.strawberry import EnumDTO, GraphQLFieldDefinition from strawchemy.dto.types import DTOConfig, ExcludeFields, IncludeFields, Purpose -from strawchemy.strawberry.dto import EnumDTO, GraphQLFieldDefinition -from strawchemy.utils import snake_to_lower_camel_case +from strawchemy.utils.text import snake_to_lower_camel_case if TYPE_CHECKING: from collections.abc import Callable, Generator, Iterable, Mapping - from strawchemy.graph import Node - from strawchemy.sqlalchemy.inspector import SQLAlchemyGraphQLInspector + from strawchemy.dto.inspectors import SQLAlchemyGraphQLInspector + from strawchemy.utils.graph import Node T = TypeVar("T") diff --git a/src/strawchemy/strawberry/factories/inputs.py b/src/strawchemy/schema/factories/inputs.py similarity index 95% rename from src/strawchemy/strawberry/factories/inputs.py rename to src/strawchemy/schema/factories/inputs.py index 26396cc8..02869d39 100644 --- a/src/strawchemy/strawberry/factories/inputs.py +++ b/src/strawchemy/schema/factories/inputs.py @@ -1,18 +1,12 @@ from __future__ import annotations -from collections.abc import Generator, Sequence from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union -from sqlalchemy.orm import DeclarativeBase, QueryableAttribute +from strawberry import UNSET from typing_extensions import override -from strawberry import UNSET from strawchemy.dto.backend.strawberry import StrawberrryDTOBackend -from strawchemy.dto.base import DTOBackend, DTOBase, DTOFieldDefinition, Relation -from strawchemy.dto.types import DTOConfig, DTOMissing, Purpose -from strawchemy.graph import Node -from strawchemy.strawberry._registry import RegistryTypeInfo -from strawchemy.strawberry.dto import ( +from strawchemy.dto.strawberry import ( AggregateFieldDefinition, AggregateFilterDTO, AggregationFunctionFilterDTO, @@ -25,10 +19,11 @@ OrderByDTO, OrderByEnum, ) -from strawchemy.strawberry.factories.aggregations import AggregationInspector -from strawchemy.strawberry.factories.base import StrawchemyUnMappedDTOFactory, UnmappedGraphQLDTOT -from strawchemy.strawberry.typing import AggregationFunction, GraphQLFilterDTOT, GraphQLPurpose -from strawchemy.utils import snake_to_camel +from strawchemy.dto.types import DTOConfig, DTOMissing, Purpose +from strawchemy.schema.factories import AggregationInspector, StrawchemyUnMappedDTOFactory, UnmappedGraphQLDTOT +from strawchemy.typing import AggregationFunction, GraphQLFilterDTOT, GraphQLPurpose, GraphQLType +from strawchemy.utils.registry import RegistryTypeInfo +from strawchemy.utils.text import snake_to_camel if TYPE_CHECKING: from collections.abc import Callable, Generator, Mapping, Sequence @@ -38,11 +33,9 @@ from strawchemy import Strawchemy from strawchemy.dto.base import DTOBackend, DTOBase, DTOFieldDefinition, ModelFieldT, Relation from strawchemy.dto.types import ExcludeFields, IncludeFields - from strawchemy.graph import Node - from strawchemy.sqlalchemy.typing import DeclarativeT - from strawchemy.strawberry.filters import GraphQLFilter - from strawchemy.strawberry.typing import GraphQLType - + from strawchemy.repository.typing import DeclarativeT + from strawchemy.schema.filters import GraphQLFilter + from strawchemy.utils.graph import Node T = TypeVar("T") diff --git a/src/strawchemy/strawberry/factories/types.py b/src/strawchemy/schema/factories/types.py similarity index 95% rename from src/strawchemy/strawberry/factories/types.py rename to src/strawchemy/schema/factories/types.py index b6fce185..861325bd 100644 --- a/src/strawchemy/strawberry/factories/types.py +++ b/src/strawchemy/schema/factories/types.py @@ -1,21 +1,17 @@ from __future__ import annotations -from collections.abc import Generator from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from sqlalchemy import JSON from sqlalchemy.orm import DeclarativeBase, QueryableAttribute from strawberry.annotation import StrawberryAnnotation from strawberry.types.arguments import StrawberryArgument from typing_extensions import Self, override -from sqlalchemy import JSON from strawchemy.constants import AGGREGATIONS_KEY, JSON_PATH_KEY, NODES_KEY from strawchemy.dto.backend.strawberry import StrawberrryDTOBackend from strawchemy.dto.base import DTOFactory, DTOFieldDefinition, MappedDTO -from strawchemy.dto.exceptions import EmptyDTOError -from strawchemy.dto.types import DTOConfig, DTOMissing, Purpose -from strawchemy.dto.utils import read_all_partial_config, read_partial, write_all_config -from strawchemy.strawberry.dto import ( +from strawchemy.dto.strawberry import ( AggregateDTO, AggregateFieldDefinition, DTOKey, @@ -24,24 +20,29 @@ GraphQLFieldDefinition, MappedStrawberryGraphQLDTO, ) -from strawchemy.strawberry.factories.aggregations import AggregationInspector -from strawchemy.strawberry.factories.base import ( +from strawchemy.dto.types import DTOConfig, DTOMissing, Purpose +from strawchemy.dto.utils import read_all_partial_config, read_partial, write_all_config +from strawchemy.exceptions import EmptyDTOError +from strawchemy.schema.factories import ( + AggregationInspector, + ChildOptions, + EnumDTOFactory, GraphQLDTOFactory, MappedGraphQLDTOT, + OrderByDTOFactory, StrawchemyMappedFactory, - _ChildOptions, + UpsertConflictFieldsEnumDTOBackend, ) -from strawchemy.strawberry.factories.enum import EnumDTOFactory, UpsertConflictFieldsEnumDTOBackend -from strawchemy.strawberry.factories.inputs import OrderByDTOFactory -from strawchemy.strawberry.mutation.types import ( +from strawchemy.schema.mutation import ( RequiredToManyUpdateInput, RequiredToOneInput, ToManyCreateInput, ToManyUpdateInput, ToOneInput, ) -from strawchemy.strawberry.typing import AggregateDTOT, GraphQLDTOT, GraphQLPurpose -from strawchemy.utils import get_annotations, non_optional_type_hint, snake_to_camel +from strawchemy.typing import AggregateDTOT, GraphQLDTOT, GraphQLPurpose +from strawchemy.utils.annotation import get_annotations, non_optional_type_hint +from strawchemy.utils.text import snake_to_camel if TYPE_CHECKING: from collections.abc import Generator, Hashable, Sequence @@ -49,13 +50,20 @@ from strawchemy import Strawchemy from strawchemy.dto.base import DTOBackend, DTOBase, Relation - from strawchemy.graph import Node - from strawchemy.sqlalchemy.inspector import SQLAlchemyGraphQLInspector - from strawchemy.sqlalchemy.typing import DeclarativeT - from strawchemy.types import DefaultOffsetPagination - - -__all__ = ("AggregateDTOFactory", "DistinctOnFieldsDTOFactory", "RootAggregateTypeDTOFactory", "TypeDTOFactory") + from strawchemy.dto.inspectors import SQLAlchemyGraphQLInspector + from strawchemy.repository.typing import DeclarativeT + from strawchemy.schema.pagination import DefaultOffsetPagination + from strawchemy.utils.graph import Node + + +__all__ = ( + "AggregateDTOFactory", + "DistinctOnFieldsDTOFactory", + "InputFactory", + "RootAggregateTypeDTOFactory", + "TypeDTOFactory", + "UpsertConflictFieldsDTOFactory", +) T = TypeVar("T") @@ -151,7 +159,7 @@ def _cache_key( dto_config: DTOConfig, node: Node[Relation[Any, MappedGraphQLDTOT], None], *, - child_options: _ChildOptions, + child_options: ChildOptions, **factory_kwargs: Any, ) -> Hashable: return (super()._cache_key(model, dto_config, node, **factory_kwargs), child_options) @@ -200,7 +208,7 @@ def factory( tags: set[str] | None = None, backend_kwargs: dict[str, Any] | None = None, *, - child_options: _ChildOptions | None = None, + child_options: ChildOptions | None = None, aggregations: bool = True, description: str | None = None, directives: Sequence[object] | None = (), @@ -225,7 +233,7 @@ def factory( child_options=child_options, **kwargs, ) - child_options = child_options or _ChildOptions() + child_options = child_options or ChildOptions() if self.graphql_type(dto_config) == "object": dto = self._update_fields(dto, base, pagination=child_options.pagination, order_by=child_options.order_by) if register_type: @@ -550,7 +558,7 @@ def _cache_key( dto_config: DTOConfig, node: Node[Relation[Any, MappedGraphQLDTOT], None], *, - child_options: _ChildOptions, + child_options: ChildOptions, mode: GraphQLPurpose, **factory_kwargs: Any, ) -> Hashable: diff --git a/src/strawchemy/strawberry/_field.py b/src/strawchemy/schema/field.py similarity index 56% rename from src/strawchemy/strawberry/_field.py rename to src/strawchemy/schema/field.py index 791e734d..32a21f04 100644 --- a/src/strawchemy/strawberry/_field.py +++ b/src/strawchemy/schema/field.py @@ -1,102 +1,60 @@ from __future__ import annotations import dataclasses -from collections.abc import Sequence from functools import cached_property from inspect import isclass -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeAlias, TypeVar, cast, get_args, get_origin +from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, cast from strawberry.annotation import StrawberryAnnotation from strawberry.types import get_object_definition from strawberry.types.arguments import StrawberryArgument -from strawberry.types.base import StrawberryList, StrawberryOptional, StrawberryType, WithStrawberryObjectDefinition +from strawberry.types.base import StrawberryOptional from strawberry.types.field import UNRESOLVED, StrawberryField from typing_extensions import Self, TypeIs, override -from strawchemy.constants import ( - DATA_KEY, - DISTINCT_ON_KEY, - FILTER_KEY, - LIMIT_KEY, - NODES_KEY, - OFFSET_KEY, - ORDER_BY_KEY, - UPSERT_CONFLICT_FIELDS, - UPSERT_UPDATE_FIELDS, -) +from strawchemy.constants import DISTINCT_ON_KEY, FILTER_KEY, LIMIT_KEY, NODES_KEY, OFFSET_KEY, ORDER_BY_KEY from strawchemy.dto.base import MappedDTO +from strawchemy.dto.strawberry import MappedStrawberryGraphQLDTO, StrawchemyDTOAttributes from strawchemy.dto.types import DTOConfig, Purpose -from strawchemy.strawberry._utils import dto_model_from_type, strawberry_contained_types, strawberry_contained_user_type -from strawchemy.strawberry.dto import ( - BooleanFilterDTO, - EnumDTO, - MappedStrawberryGraphQLDTO, - OrderByDTO, - StrawchemyDTOAttributes, +from strawchemy.exceptions import StrawchemyFieldError +from strawchemy.schema.pagination import DefaultOffsetPagination +from strawchemy.utils.annotation import is_type_hint_optional +from strawchemy.utils.strawberry import ( + dto_model_from_type, + is_list, + strawberry_contained_types, + strawberry_contained_user_type, ) -from strawchemy.strawberry.exceptions import StrawchemyFieldError -from strawchemy.strawberry.mutation.input import Input -from strawchemy.strawberry.repository import StrawchemyAsyncRepository -from strawchemy.types import DefaultOffsetPagination -from strawchemy.typing import UNION_TYPES -from strawchemy.utils import is_type_hint_optional -from strawchemy.validation.base import InputValidationError if TYPE_CHECKING: - from collections.abc import Awaitable, Callable, Coroutine, Mapping + from collections.abc import Awaitable, Callable, Coroutine, Mapping, Sequence + from sqlalchemy import Select from sqlalchemy.orm import DeclarativeBase + from strawberry import BasePermission, Info from strawberry.extensions.field_extension import FieldExtension from strawberry.types.base import StrawberryObjectDefinition, StrawberryType, WithStrawberryObjectDefinition from strawberry.types.fields.resolver import StrawberryResolver - from sqlalchemy import Select - from strawberry import BasePermission, Info from strawchemy import StrawchemyConfig - from strawchemy.sqlalchemy.typing import QueryHookCallable - from strawchemy.strawberry.dto import BooleanFilterDTO, EnumDTO, OrderByDTO - from strawchemy.strawberry.mutation.types import ValidationErrorType - from strawchemy.strawberry.repository import StrawchemySyncRepository - from strawchemy.strawberry.repository._base import GraphQLResult - from strawchemy.strawberry.typing import ( - AnyMappedDTO, + from strawchemy.dto.strawberry import BooleanFilterDTO, EnumDTO, OrderByDTO + from strawchemy.repository.strawberry import StrawchemyAsyncRepository, StrawchemySyncRepository + from strawchemy.repository.strawberry.base import GraphQLResult + from strawchemy.repository.typing import QueryHookCallable + from strawchemy.typing import ( + AnyRepository, + AnyRepositoryType, + CreateOrUpdateResolverResult, FilterStatementCallable, - MappedGraphQLDTO, + GetByIdResolverResult, + ListResolverResult, StrawchemyTypeWithStrawberryObjectDefinition, ) - from strawchemy.typing import AnyRepository - from strawchemy.validation.base import ValidationProtocol - -__all__ = ("StrawchemyCreateMutationField", "StrawchemyDeleteMutationField", "StrawchemyField") +__all__ = ("StrawchemyField",) T = TypeVar("T", bound="DeclarativeBase") -_OneOrManyResult: TypeAlias = ( - "Sequence[StrawchemyTypeWithStrawberryObjectDefinition] | StrawchemyTypeWithStrawberryObjectDefinition" -) -_ListResolverResult: TypeAlias = _OneOrManyResult -_GetByIdResolverResult: TypeAlias = "StrawchemyTypeWithStrawberryObjectDefinition | None" -_CreateOrUpdateResolverResult: TypeAlias = "_OneOrManyResult | ValidationErrorType | Sequence[ValidationErrorType]" - - -_OPTIONAL_UNION_ARG_SIZE: int = 2 - - -def _is_list( - type_: StrawberryType | type[WithStrawberryObjectDefinition] | object | str, -) -> TypeIs[type[list[Any]] | StrawberryList]: - if isinstance(type_, StrawberryOptional): - type_ = type_.of_type - if origin := get_origin(type_): - type_ = origin - if origin is Optional: - type_ = get_args(type_)[0] - if origin in UNION_TYPES and len(args := get_args(type_)) == _OPTIONAL_UNION_ARG_SIZE: - type_ = args[0] if args[0] is not type(None) else args[1] - - return isinstance(type_, StrawberryList) or type_ is list - class StrawchemyField(StrawberryField): """A custom field class for Strawberry GraphQL that allows explicit handling of resolver arguments. @@ -116,7 +74,7 @@ class StrawchemyField(StrawberryField): def __init__( self, config: StrawchemyConfig, - repository_type: AnyRepository, + repository_type: AnyRepositoryType, filter_type: type[BooleanFilterDTO] | None = None, order_by: type[OrderByDTO] | None = None, distinct_on: type[EnumDTO] | None = None, @@ -207,26 +165,29 @@ def _get_repository(self, info: Info[Any, Any]) -> StrawchemySyncRepository[Any] deterministic_ordering=self._config.deterministic_ordering, ) - async def _list_result_async(self, repository_call: Awaitable[GraphQLResult[Any, Any]]) -> _ListResolverResult: + def _is_repo_async(self, repository: AnyRepository | type[AnyRepository]) -> TypeIs[StrawchemyAsyncRepository[Any]]: + return repository.is_async + + async def _list_result_async(self, repository_call: Awaitable[GraphQLResult[Any, Any]]) -> ListResolverResult: return (await repository_call).graphql_list(root_aggregations=self.root_aggregations) - def _list_result_sync(self, repository_call: GraphQLResult[Any, Any]) -> _ListResolverResult: + def _list_result_sync(self, repository_call: GraphQLResult[Any, Any]) -> ListResolverResult: return repository_call.graphql_list(root_aggregations=self.root_aggregations) async def _get_by_id_result_async( self, repository_call: Awaitable[GraphQLResult[Any, Any]] - ) -> _GetByIdResolverResult: + ) -> GetByIdResolverResult: result = await repository_call return result.graphql_type_or_none() if self.is_optional else result.graphql_type() - def _get_by_id_result_sync(self, repository_call: GraphQLResult[Any, Any]) -> _GetByIdResolverResult: + def _get_by_id_result_sync(self, repository_call: GraphQLResult[Any, Any]) -> GetByIdResolverResult: return repository_call.graphql_type_or_none() if self.is_optional else repository_call.graphql_type() def _get_by_id_resolver( self, info: Info, **kwargs: Any - ) -> _GetByIdResolverResult | Coroutine[_GetByIdResolverResult, Any, Any]: + ) -> GetByIdResolverResult | Coroutine[GetByIdResolverResult, Any, Any]: repository = self._get_repository(info) - if isinstance(repository, StrawchemyAsyncRepository): + if self._is_repo_async(repository): return self._get_by_id_result_async(repository.get_by_id(**kwargs)) return self._get_by_id_result_sync(repository.get_by_id(**kwargs)) @@ -238,9 +199,9 @@ def _list_resolver( distinct_on: list[EnumDTO] | None = None, limit: int | None = None, offset: int | None = None, - ) -> _ListResolverResult | Coroutine[_ListResolverResult, Any, Any]: + ) -> ListResolverResult | Coroutine[ListResolverResult, Any, Any]: repository = self._get_repository(info) - if isinstance(repository, StrawchemyAsyncRepository): + if self._is_repo_async(repository): return self._list_result_async(repository.list(filter_input, order_by, distinct_on, limit, offset)) return self._list_result_sync(repository.list(filter_input, order_by, distinct_on, limit, offset)) @@ -347,7 +308,7 @@ def filter_statement(self, info: Info[Any, Any]) -> Select[tuple[DeclarativeBase @cached_property def is_list(self) -> bool: - return True if self.root_aggregations else _is_list(self._type_or_annotation()) + return True if self.root_aggregations else is_list(self._type_or_annotation()) @cached_property def is_optional(self) -> bool: @@ -362,7 +323,7 @@ def is_basic_field(self) -> bool: @cached_property @override def is_async(self) -> bool: - return issubclass(self._repository_type, StrawchemyAsyncRepository) + return self._is_repo_async(self._repository_type) @override def __copy__(self) -> Self: @@ -456,13 +417,13 @@ def resolve_type( def resolver(self, info: Info[Any, Any], *args: Any, **kwargs: Any) -> ( ( - _ListResolverResult - | Coroutine[_ListResolverResult, Any, Any] - | _GetByIdResolverResult - | Coroutine[_GetByIdResolverResult, Any, Any] + ListResolverResult + | Coroutine[ListResolverResult, Any, Any] + | GetByIdResolverResult + | Coroutine[GetByIdResolverResult, Any, Any] ) - | _CreateOrUpdateResolverResult - ) | Coroutine[_CreateOrUpdateResolverResult, Any, Any]: + | CreateOrUpdateResolverResult + ) | Coroutine[CreateOrUpdateResolverResult, Any, Any]: if self.is_list: return self._list_resolver(info, *args, **kwargs) return self._get_by_id_resolver(info, *args, **kwargs) @@ -475,227 +436,3 @@ def get_result( assert info return self.resolver(info, *args, **kwargs) return super().get_result(source, info, args, kwargs) - - -class _StrawchemyInputMutationField(StrawchemyField): - def __init__( - self, - input_type: type[MappedGraphQLDTO[T]], - *args: Any, - validation: ValidationProtocol[T] | None = None, - **kwargs: Any, - ) -> None: - super().__init__(*args, **kwargs) - self.is_root_field = True - self._input_type = input_type - self._validation = validation - - -class _StrawchemyMutationField: - async def _input_result_async( - self, repository_call: Awaitable[GraphQLResult[Any, Any]], input_data: Input[Any] - ) -> _ListResolverResult: - result = await repository_call - return result.graphql_list() if input_data.list_input else result.graphql_type() - - def _input_result_sync( - self, repository_call: GraphQLResult[Any, Any], input_data: Input[Any] - ) -> _ListResolverResult: - return repository_call.graphql_list() if input_data.list_input else repository_call.graphql_type() - - -class StrawchemyCreateMutationField(_StrawchemyInputMutationField, _StrawchemyMutationField): - def _create_resolver( - self, info: Info, data: AnyMappedDTO | Sequence[AnyMappedDTO] - ) -> _CreateOrUpdateResolverResult | Coroutine[_CreateOrUpdateResolverResult, Any, Any]: - repository = self._get_repository(info) - try: - input_data = Input(data, self._validation) - except InputValidationError as error: - return error.graphql_type() - if isinstance(repository, StrawchemyAsyncRepository): - return self._input_result_async(repository.create(input_data), input_data) - return self._input_result_sync(repository.create(input_data), input_data) - - @override - def auto_arguments(self) -> list[StrawberryArgument]: - if self.is_list: - return [StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(list[self._input_type]))] - return [StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(self._input_type))] - - @override - def resolver( - self, info: Info[Any, Any], *args: Any, **kwargs: Any - ) -> _CreateOrUpdateResolverResult | Coroutine[_CreateOrUpdateResolverResult, Any, Any]: - return self._create_resolver(info, *args, **kwargs) - - -class StrawchemyUpsertMutationField(_StrawchemyInputMutationField, _StrawchemyMutationField): - def __init__( - self, - input_type: type[MappedGraphQLDTO[T]], - update_fields_enum: type[EnumDTO], - conflict_fields_enum: type[EnumDTO], - *args: Any, - **kwargs: Any, - ) -> None: - super().__init__(input_type, *args, **kwargs) - self._update_fields_enum = update_fields_enum - self._conflict_fields_enum = conflict_fields_enum - - def _upsert_resolver( - self, - info: Info, - data: AnyMappedDTO | Sequence[AnyMappedDTO], - filter_input: BooleanFilterDTO | None = None, - update_fields: list[EnumDTO] | None = None, - conflict_fields: EnumDTO | None = None, - ) -> _CreateOrUpdateResolverResult | Coroutine[_CreateOrUpdateResolverResult, Any, Any]: - repository = self._get_repository(info) - try: - input_data = Input(data, self._validation) - except InputValidationError as error: - return error.graphql_type() - if isinstance(repository, StrawchemyAsyncRepository): - return self._input_result_async( - repository.upsert(input_data, filter_input, update_fields, conflict_fields), input_data - ) - return self._input_result_sync( - repository.upsert(input_data, filter_input, update_fields, conflict_fields), input_data - ) - - @override - def auto_arguments(self) -> list[StrawberryArgument]: - arguments = [ - StrawberryArgument( - UPSERT_UPDATE_FIELDS, - None, - type_annotation=StrawberryAnnotation(Optional[list[self._update_fields_enum]]), - default=None, - ), - StrawberryArgument( - UPSERT_CONFLICT_FIELDS, - None, - type_annotation=StrawberryAnnotation(Optional[self._conflict_fields_enum]), - default=None, - ), - ] - if self.is_list: - arguments.append( - StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(list[self._input_type])) - ) - else: - arguments.append(StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(self._input_type))) - return arguments - - @override - def resolver( - self, info: Info[Any, Any], *args: Any, **kwargs: Any - ) -> _CreateOrUpdateResolverResult | Coroutine[_CreateOrUpdateResolverResult, Any, Any]: - return self._upsert_resolver(info, *args, **kwargs) - - -class StrawchemyUpdateMutationField(_StrawchemyInputMutationField, _StrawchemyMutationField): - @override - def _validate_type(self, type_: StrawberryType | type[WithStrawberryObjectDefinition] | Any) -> None: - if self._filter is not None and not _is_list(type_): - msg = f"Type of update mutation by filter must be a list: {self.name}" - raise StrawchemyFieldError(msg) - - def _update_by_ids_resolver( - self, info: Info, data: AnyMappedDTO | Sequence[AnyMappedDTO], **_: Any - ) -> _CreateOrUpdateResolverResult | Coroutine[_CreateOrUpdateResolverResult, Any, Any]: - repository = self._get_repository(info) - try: - input_data = Input(data, self._validation) - except InputValidationError as error: - error_result = error.graphql_type() - return [error_result] if isinstance(data, Sequence) else error_result - - if isinstance(repository, StrawchemyAsyncRepository): - return self._input_result_async(repository.update_by_id(input_data), input_data) - return self._input_result_sync(repository.update_by_id(input_data), input_data) - - def _update_by_filter_resolver( - self, info: Info, data: AnyMappedDTO, filter_input: BooleanFilterDTO - ) -> _CreateOrUpdateResolverResult | Coroutine[_CreateOrUpdateResolverResult, Any, Any]: - repository = self._get_repository(info) - try: - input_data = Input(data, self._validation) - except InputValidationError as error: - return [error.graphql_type()] - if isinstance(repository, StrawchemyAsyncRepository): - return self._list_result_async(repository.update_by_filter(input_data, filter_input)) - return self._list_result_sync(repository.update_by_filter(input_data, filter_input)) - - @override - def auto_arguments(self) -> list[StrawberryArgument]: - if self.filter: - return [ - StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(self._input_type)), - StrawberryArgument( - python_name="filter_input", - graphql_name=FILTER_KEY, - type_annotation=StrawberryAnnotation(Optional[self.filter]), - default=None, - ), - ] - if self.is_list: - return [StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(list[self._input_type]))] - return [StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(self._input_type))] - - @override - def resolver( - self, info: Info[Any, Any], *args: Any, **kwargs: Any - ) -> _CreateOrUpdateResolverResult | Coroutine[_CreateOrUpdateResolverResult, Any, Any]: - if self._filter is None: - return self._update_by_ids_resolver(info, *args, **kwargs) - return self._update_by_filter_resolver(info, *args, **kwargs) - - -class StrawchemyDeleteMutationField(StrawchemyField, _StrawchemyMutationField): - def __init__( - self, - input_type: type[BooleanFilterDTO] | None = None, - *args: Any, - **kwargs: Any, - ) -> None: - super().__init__(*args, **kwargs) - self.is_root_field = True - self._input_type = input_type - - def _delete_resolver( - self, - info: Info, - filter_input: BooleanFilterDTO | None = None, - ) -> _CreateOrUpdateResolverResult | Coroutine[_CreateOrUpdateResolverResult, Any, Any]: - repository = self._get_repository(info) - if isinstance(repository, StrawchemyAsyncRepository): - return self._list_result_async(repository.delete(filter_input)) - return self._list_result_sync(repository.delete(filter_input)) - - @override - def _validate_type(self, type_: StrawberryType | type[WithStrawberryObjectDefinition] | Any) -> None: - # Calling self.is_list cause a recursion loop - if not _is_list(type_): - msg = f"Type of delete mutation must be a list: {self.name}" - raise StrawchemyFieldError(msg) - - @override - def auto_arguments(self) -> list[StrawberryArgument]: - if self._input_type: - return [ - StrawberryArgument( - python_name="filter_input", - graphql_name=FILTER_KEY, - default=None, - type_annotation=StrawberryAnnotation(self._input_type), - ) - ] - return [] - - @override - def resolver( - self, info: Info[Any, Any], *args: Any, **kwargs: Any - ) -> _CreateOrUpdateResolverResult | Coroutine[_CreateOrUpdateResolverResult, Any, Any]: - return self._delete_resolver(info, *args, **kwargs) diff --git a/src/strawchemy/schema/filters/__init__.py b/src/strawchemy/schema/filters/__init__.py new file mode 100644 index 00000000..66345e0e --- /dev/null +++ b/src/strawchemy/schema/filters/__init__.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from strawchemy.schema.filters.base import ( + ArrayFilter, + BaseDateFilter, + BaseTimeFilter, + DateFilter, + DateTimeFilter, + EqualityFilter, + FilterProtocol, + JSONFilter, + OrderFilter, + TextFilter, + TimeDeltaFilter, + TimeFilter, +) +from strawchemy.schema.filters.inputs import ( + ArrayComparison, + DateComparison, + DateTimeComparison, + EqualityComparison, + GraphQLComparison, + GraphQLComparisonT, + GraphQLFilter, + OrderComparison, + TextComparison, + TimeComparison, + TimeDeltaComparison, + _JSONComparison, + _SQLiteJSONComparison, + make_full_json_comparison_input, + make_sqlite_json_comparison_input, +) + +__all__ = ( + "ArrayComparison", + "ArrayFilter", + "BaseDateFilter", + "BaseTimeFilter", + "DateComparison", + "DateFilter", + "DateTimeComparison", + "DateTimeFilter", + "EqualityComparison", + "EqualityFilter", + "FilterProtocol", + "FilterProtocol", + "GraphQLComparison", + "GraphQLComparisonT", + "GraphQLFilter", + "JSONFilter", + "OrderComparison", + "OrderFilter", + "TextComparison", + "TextFilter", + "TimeComparison", + "TimeDeltaComparison", + "TimeDeltaFilter", + "TimeFilter", + "_JSONComparison", + "_SQLiteJSONComparison", + "make_full_json_comparison_input", + "make_sqlite_json_comparison_input", +) diff --git a/src/strawchemy/strawberry/filters/base.py b/src/strawchemy/schema/filters/base.py similarity index 99% rename from src/strawchemy/strawberry/filters/base.py rename to src/strawchemy/schema/filters/base.py index 0798befa..23bae439 100644 --- a/src/strawchemy/strawberry/filters/base.py +++ b/src/strawchemy/schema/filters/base.py @@ -3,20 +3,19 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast -from sqlalchemy.dialects import mysql -from sqlalchemy.dialects import postgresql as pg -from typing_extensions import override - from sqlalchemy import ARRAY, JSON, ColumnElement, Dialect, Integer, Text, and_, func, not_, null, or_, type_coerce from sqlalchemy import cast as sqla_cast +from sqlalchemy.dialects import mysql +from sqlalchemy.dialects import postgresql as pg from strawberry import UNSET +from typing_extensions import override if TYPE_CHECKING: from datetime import date, timedelta from sqlalchemy.orm import QueryableAttribute - from strawchemy.strawberry.filters.inputs import ( + from strawchemy.schema.filters import ( ArrayComparison, DateComparison, DateTimeComparison, diff --git a/src/strawchemy/strawberry/filters/geo.py b/src/strawchemy/schema/filters/geo.py similarity index 92% rename from src/strawchemy/strawberry/filters/geo.py rename to src/strawchemy/schema/filters/geo.py index f8ca87d9..f1049f16 100644 --- a/src/strawchemy/strawberry/filters/geo.py +++ b/src/strawchemy/schema/filters/geo.py @@ -4,16 +4,15 @@ from dataclasses import dataclass from typing import Any, TypeVar +import strawberry from geoalchemy2 import functions as geo_func +from sqlalchemy import ColumnElement, Dialect, null from sqlalchemy.orm import QueryableAttribute +from strawberry import UNSET from typing_extensions import override -import strawberry -from sqlalchemy import ColumnElement, Dialect, null -from strawberry import UNSET -from strawchemy.strawberry.filters.base import FilterProtocol -from strawchemy.strawberry.filters.inputs import GraphQLComparison -from strawchemy.strawberry.geo import GeoJSON +from strawchemy.schema.filters import FilterProtocol, GraphQLComparison +from strawchemy.schema.scalars.geo import GeoJSON __all__ = ("GeoComparison",) diff --git a/src/strawchemy/strawberry/filters/inputs.py b/src/strawchemy/schema/filters/inputs.py similarity index 97% rename from src/strawchemy/strawberry/filters/inputs.py rename to src/strawchemy/schema/filters/inputs.py index fec2706e..621bda20 100644 --- a/src/strawchemy/strawberry/filters/inputs.py +++ b/src/strawchemy/schema/filters/inputs.py @@ -15,9 +15,9 @@ from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar import strawberry -from sqlalchemy import Dialect from strawberry import UNSET, Private -from strawchemy.strawberry.filters.base import ( + +from strawchemy.schema.filters import ( ArrayFilter, DateFilter, DateTimeFilter, @@ -29,13 +29,13 @@ TimeDeltaFilter, TimeFilter, ) -from strawchemy.strawberry.typing import QueryNodeType +from strawchemy.typing import QueryNodeType if TYPE_CHECKING: + from sqlalchemy import ColumnElement, Dialect from sqlalchemy.orm import QueryableAttribute - from sqlalchemy import ColumnElement - from strawchemy.strawberry.dto import OrderByEnum + from strawchemy.dto.strawberry import OrderByEnum __all__ = ( "ArrayComparison", @@ -47,6 +47,8 @@ "TimeComparison", "TimeDeltaComparison", "_JSONComparison", + "make_full_json_comparison_input", + "make_sqlite_json_comparison_input", ) T = TypeVar("T") diff --git a/src/strawchemy/schema/interfaces.py b/src/strawchemy/schema/interfaces.py new file mode 100644 index 00000000..46a35f92 --- /dev/null +++ b/src/strawchemy/schema/interfaces.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any, ClassVar + +import strawberry + + +class ErrorId(Enum): + ERROR = "ERROR" + VALIDATION_ERROR = "VALIDATION_ERROR" + LOCALIZED_VALIDATION_ERROR = "LOCALIZED_VALIDATION_ERROR" + + +@strawberry.interface(description="Base interface for expected errors", name="ErrorType") +class ErrorType: + """Base class for GraphQL errors.""" + + __error_types__: ClassVar[set[type[Any]]] = set() + + id: str = ErrorId.ERROR.value + + def __init_subclass__(cls) -> None: + if not cls.__error_types__: + cls.__error_types__.add(ErrorType) + cls.__error_types__.add(cls) diff --git a/src/strawchemy/schema/mutation/__init__.py b/src/strawchemy/schema/mutation/__init__.py new file mode 100644 index 00000000..fde81cda --- /dev/null +++ b/src/strawchemy/schema/mutation/__init__.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from strawchemy.schema.mutation.input import Input, InputModel, LevelInput, UpsertData +from strawchemy.schema.mutation.types import ( + LocalizedErrorType, + RelationType, + RequiredToManyUpdateInput, + RequiredToOneInput, + ToManyCreateInput, + ToManyUpdateInput, + ToManyUpsertInput, + ToOneInput, + ToOneUpsertInput, + ValidationErrorType, + error_type_names, +) + +__all__ = ( + "Input", + "InputModel", + "LevelInput", + "LocalizedErrorType", + "RelationType", + "RequiredToManyUpdateInput", + "RequiredToOneInput", + "ToManyCreateInput", + "ToManyUpdateInput", + "ToManyUpsertInput", + "ToOneInput", + "ToOneUpsertInput", + "UpsertData", + "ValidationErrorType", + "error_type_names", +) diff --git a/src/strawchemy/strawberry/mutation/builder.py b/src/strawchemy/schema/mutation/field_builder.py similarity index 93% rename from src/strawchemy/strawberry/mutation/builder.py rename to src/strawchemy/schema/mutation/field_builder.py index 4acc5952..72903bb2 100644 --- a/src/strawchemy/strawberry/mutation/builder.py +++ b/src/strawchemy/schema/mutation/field_builder.py @@ -11,17 +11,17 @@ if TYPE_CHECKING: from collections.abc import Callable, Mapping, Sequence + from strawberry import BasePermission from strawberry.extensions.field_extension import FieldExtension - from strawberry import BasePermission from strawchemy.config.base import StrawchemyConfig - from strawchemy.strawberry._field import ( + from strawchemy.schema.mutation.fields import ( StrawchemyCreateMutationField, StrawchemyDeleteMutationField, StrawchemyUpdateMutationField, StrawchemyUpsertMutationField, ) - from strawchemy.typing import AnyRepository + from strawchemy.typing import AnyRepositoryType @dataclass @@ -46,7 +46,7 @@ def build( ], resolver: Any | None = None, *, - repository_type: AnyRepository | None = None, + repository_type: AnyRepositoryType | None = None, graphql_type: Any | None = None, name: str | None = None, description: str | None = None, @@ -65,8 +65,8 @@ def build( field_class: The specific mutation field class to instantiate (e.g., StrawchemyCreateMutationField). resolver: An optional custom resolver function for the mutation. - repository_type: An optional custom repository class. Defaults to - the repository configured in StrawchemyConfig. + repository_type: An optional custom strawberry class. Defaults to + the strawberry configured in StrawchemyConfig. graphql_type: The GraphQL return type of the mutation. name: The name of the GraphQL mutation field. description: The description of the GraphQL mutation field. diff --git a/src/strawchemy/schema/mutation/fields.py b/src/strawchemy/schema/mutation/fields.py new file mode 100644 index 00000000..1fcc66e5 --- /dev/null +++ b/src/strawchemy/schema/mutation/fields.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Coroutine, Sequence +from typing import TYPE_CHECKING, Any, Optional, TypeVar + +from strawberry.annotation import StrawberryAnnotation +from strawberry.types.arguments import StrawberryArgument +from typing_extensions import override + +from strawchemy.constants import DATA_KEY, FILTER_KEY, UPSERT_CONFLICT_FIELDS, UPSERT_UPDATE_FIELDS +from strawchemy.exceptions import StrawchemyFieldError +from strawchemy.schema.field import StrawchemyField +from strawchemy.schema.mutation.input import Input +from strawchemy.utils.strawberry import is_list +from strawchemy.validation import InputValidationError + +if TYPE_CHECKING: + from sqlalchemy.orm import DeclarativeBase + from strawberry import Info + from strawberry.types.base import StrawberryType, WithStrawberryObjectDefinition + + from strawchemy.dto.strawberry import BooleanFilterDTO, EnumDTO + from strawchemy.repository.strawberry.base import GraphQLResult + from strawchemy.typing import AnyMappedDTO, CreateOrUpdateResolverResult, ListResolverResult, MappedGraphQLDTO + from strawchemy.validation import ValidationProtocol + + +__all__ = ( + "StrawchemyCreateMutationField", + "StrawchemyDeleteMutationField", + "StrawchemyUpdateMutationField", + "StrawchemyUpsertMutationField", +) + +T = TypeVar("T", bound="DeclarativeBase") + + +class _StrawchemyInputMutationField(StrawchemyField): + def __init__( + self, + input_type: type[MappedGraphQLDTO[T]], + *args: Any, + validation: ValidationProtocol[T] | None = None, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self.is_root_field = True + self._input_type = input_type + self._validation = validation + + +class _StrawchemyMutationField: + async def _input_result_async( + self, repository_call: Awaitable[GraphQLResult[Any, Any]], input_data: Input[Any] + ) -> ListResolverResult: + result = await repository_call + return result.graphql_list() if input_data.list_input else result.graphql_type() + + def _input_result_sync( + self, repository_call: GraphQLResult[Any, Any], input_data: Input[Any] + ) -> ListResolverResult: + return repository_call.graphql_list() if input_data.list_input else repository_call.graphql_type() + + +class StrawchemyCreateMutationField(_StrawchemyInputMutationField, _StrawchemyMutationField): + def _create_resolver( + self, info: Info, data: AnyMappedDTO | Sequence[AnyMappedDTO] + ) -> CreateOrUpdateResolverResult | Coroutine[CreateOrUpdateResolverResult, Any, Any]: + repository = self._get_repository(info) + try: + input_data = Input(data, self._validation) + except InputValidationError as error: + return error.graphql_type() + if self._is_repo_async(repository): + return self._input_result_async(repository.create(input_data), input_data) + return self._input_result_sync(repository.create(input_data), input_data) + + @override + def auto_arguments(self) -> list[StrawberryArgument]: + if self.is_list: + return [StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(list[self._input_type]))] + return [StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(self._input_type))] + + @override + def resolver( + self, info: Info[Any, Any], *args: Any, **kwargs: Any + ) -> CreateOrUpdateResolverResult | Coroutine[CreateOrUpdateResolverResult, Any, Any]: + return self._create_resolver(info, *args, **kwargs) + + +class StrawchemyUpsertMutationField(_StrawchemyInputMutationField, _StrawchemyMutationField): + def __init__( + self, + input_type: type[MappedGraphQLDTO[T]], + update_fields_enum: type[EnumDTO], + conflict_fields_enum: type[EnumDTO], + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(input_type, *args, **kwargs) + self._update_fields_enum = update_fields_enum + self._conflict_fields_enum = conflict_fields_enum + + def _upsert_resolver( + self, + info: Info, + data: AnyMappedDTO | Sequence[AnyMappedDTO], + filter_input: BooleanFilterDTO | None = None, + update_fields: list[EnumDTO] | None = None, + conflict_fields: EnumDTO | None = None, + ) -> CreateOrUpdateResolverResult | Coroutine[CreateOrUpdateResolverResult, Any, Any]: + repository = self._get_repository(info) + try: + input_data = Input(data, self._validation) + except InputValidationError as error: + return error.graphql_type() + if self._is_repo_async(repository): + return self._input_result_async( + repository.upsert(input_data, filter_input, update_fields, conflict_fields), input_data + ) + return self._input_result_sync( + repository.upsert(input_data, filter_input, update_fields, conflict_fields), input_data + ) + + @override + def auto_arguments(self) -> list[StrawberryArgument]: + arguments = [ + StrawberryArgument( + UPSERT_UPDATE_FIELDS, + None, + type_annotation=StrawberryAnnotation(Optional[list[self._update_fields_enum]]), + default=None, + ), + StrawberryArgument( + UPSERT_CONFLICT_FIELDS, + None, + type_annotation=StrawberryAnnotation(Optional[self._conflict_fields_enum]), + default=None, + ), + ] + if self.is_list: + arguments.append( + StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(list[self._input_type])) + ) + else: + arguments.append(StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(self._input_type))) + return arguments + + @override + def resolver( + self, info: Info[Any, Any], *args: Any, **kwargs: Any + ) -> CreateOrUpdateResolverResult | Coroutine[CreateOrUpdateResolverResult, Any, Any]: + return self._upsert_resolver(info, *args, **kwargs) + + +class StrawchemyUpdateMutationField(_StrawchemyInputMutationField, _StrawchemyMutationField): + @override + def _validate_type(self, type_: StrawberryType | type[WithStrawberryObjectDefinition] | Any) -> None: + if self._filter is not None and not is_list(type_): + msg = f"Type of update mutation by filter must be a list: {self.name}" + raise StrawchemyFieldError(msg) + + def _update_by_ids_resolver( + self, info: Info, data: AnyMappedDTO | Sequence[AnyMappedDTO], **_: Any + ) -> CreateOrUpdateResolverResult | Coroutine[CreateOrUpdateResolverResult, Any, Any]: + repository = self._get_repository(info) + try: + input_data = Input(data, self._validation) + except InputValidationError as error: + error_result = error.graphql_type() + return [error_result] if isinstance(data, Sequence) else error_result + + if self._is_repo_async(repository): + return self._input_result_async(repository.update_by_id(input_data), input_data) + return self._input_result_sync(repository.update_by_id(input_data), input_data) + + def _update_by_filter_resolver( + self, info: Info, data: AnyMappedDTO, filter_input: BooleanFilterDTO + ) -> CreateOrUpdateResolverResult | Coroutine[CreateOrUpdateResolverResult, Any, Any]: + repository = self._get_repository(info) + try: + input_data = Input(data, self._validation) + except InputValidationError as error: + return [error.graphql_type()] + if self._is_repo_async(repository): + return self._list_result_async(repository.update_by_filter(input_data, filter_input)) + return self._list_result_sync(repository.update_by_filter(input_data, filter_input)) + + @override + def auto_arguments(self) -> list[StrawberryArgument]: + if self.filter: + return [ + StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(self._input_type)), + StrawberryArgument( + python_name="filter_input", + graphql_name=FILTER_KEY, + type_annotation=StrawberryAnnotation(Optional[self.filter]), + default=None, + ), + ] + if self.is_list: + return [StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(list[self._input_type]))] + return [StrawberryArgument(DATA_KEY, None, type_annotation=StrawberryAnnotation(self._input_type))] + + @override + def resolver( + self, info: Info[Any, Any], *args: Any, **kwargs: Any + ) -> CreateOrUpdateResolverResult | Coroutine[CreateOrUpdateResolverResult, Any, Any]: + if self._filter is None: + return self._update_by_ids_resolver(info, *args, **kwargs) + return self._update_by_filter_resolver(info, *args, **kwargs) + + +class StrawchemyDeleteMutationField(StrawchemyField, _StrawchemyMutationField): + def __init__( + self, + input_type: type[BooleanFilterDTO] | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self.is_root_field = True + self._input_type = input_type + + def _delete_resolver( + self, + info: Info, + filter_input: BooleanFilterDTO | None = None, + ) -> CreateOrUpdateResolverResult | Coroutine[CreateOrUpdateResolverResult, Any, Any]: + repository = self._get_repository(info) + if self._is_repo_async(repository): + return self._list_result_async(repository.delete(filter_input)) + return self._list_result_sync(repository.delete(filter_input)) + + @override + def _validate_type(self, type_: StrawberryType | type[WithStrawberryObjectDefinition] | Any) -> None: + # Calling self.is_list cause a recursion loop + if not is_list(type_): + msg = f"Type of delete mutation must be a list: {self.name}" + raise StrawchemyFieldError(msg) + + @override + def auto_arguments(self) -> list[StrawberryArgument]: + if self._input_type: + return [ + StrawberryArgument( + python_name="filter_input", + graphql_name=FILTER_KEY, + default=None, + type_annotation=StrawberryAnnotation(self._input_type), + ) + ] + return [] + + @override + def resolver( + self, info: Info[Any, Any], *args: Any, **kwargs: Any + ) -> CreateOrUpdateResolverResult | Coroutine[CreateOrUpdateResolverResult, Any, Any]: + return self._delete_resolver(info, *args, **kwargs) diff --git a/src/strawchemy/strawberry/mutation/input.py b/src/strawchemy/schema/mutation/input.py similarity index 85% rename from src/strawchemy/strawberry/mutation/input.py rename to src/strawchemy/schema/mutation/input.py index e9963d25..2c82653a 100644 --- a/src/strawchemy/strawberry/mutation/input.py +++ b/src/strawchemy/schema/mutation/input.py @@ -5,13 +5,13 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast, final +from sqlalchemy import event, inspect from sqlalchemy.orm import MapperProperty, RelationshipDirection, object_mapper from typing_extensions import Self, override -from sqlalchemy import event, inspect from strawchemy.dto.base import DTOFieldDefinition, MappedDTO, ToMappedProtocol, VisitorProtocol -from strawchemy.dto.inspectors.sqlalchemy import SQLAlchemyInspector -from strawchemy.strawberry.mutation.types import ( +from strawchemy.dto.inspectors import SQLAlchemyInspector +from strawchemy.schema.mutation.types import ( RelationType, ToManyCreateInput, ToManyUpdateInput, @@ -26,12 +26,12 @@ from sqlalchemy.orm import DeclarativeBase, QueryableAttribute - from strawchemy.strawberry.dto import EnumDTO - from strawchemy.strawberry.typing import MappedGraphQLDTO + from strawchemy.dto.strawberry import EnumDTO + from strawchemy.typing import MappedGraphQLDTO from strawchemy.validation.base import ValidationProtocol -__all__ = ("Input", "LevelInput", "RelationType") +__all__ = ("Input", "InputModel", "LevelInput", "RelationType") T = TypeVar("T", bound=MappedDTO[Any]) DeclarativeBaseT = TypeVar("DeclarativeBaseT", bound="DeclarativeBase") @@ -68,30 +68,22 @@ def __iter__(self) -> Iterator[DeclarativeBase]: return iter(self.instances) +@dataclass class _UnboundRelationInput: - def __init__( - self, - attribute: MapperProperty[Any], - related: type[DeclarativeBase], - relation_type: RelationType, - set_: list[DeclarativeBase] | None | type[_Unset] = _Unset, - add: list[DeclarativeBase] | None = None, - remove: list[DeclarativeBase] | None = None, - create: list[DeclarativeBase] | None = None, - upsert: UpsertData | None = None, - input_index: int = -1, - level: int = 0, - ) -> None: - self.attribute = attribute - self.related = related - self.relation_type = relation_type - self.set: list[DeclarativeBase] | None = set_ if set_ is not _Unset else [] - self.add = add if add is not None else [] - self.remove = remove if remove is not None else [] - self.create = create if create is not None else [] - self.upsert = upsert - self.input_index = input_index - self.level = level + attribute: MapperProperty[Any] + related: type[DeclarativeBase] + relation_type: RelationType + set_: list[DeclarativeBase] | None | type[_Unset] = _Unset + add: list[DeclarativeBase] = field(default_factory=list) + remove: list[DeclarativeBase] = field(default_factory=list) + create: list[DeclarativeBase] = field(default_factory=list) + upsert: UpsertData | None = None + input_index: int = -1 + level: int = 0 + set: list[DeclarativeBase] | None = field(init=False) + + def __post_init__(self) -> None: + self.set = self.set_ if self.set_ is not _Unset else [] def add_instance(self, model: DeclarativeBase) -> None: if not _has_record(model): @@ -108,34 +100,12 @@ def __bool__(self) -> bool: return bool(self.set or self.add or self.remove or self.create or self.upsert) or self.set is None +@dataclass class RelationInput(_UnboundRelationInput): - def __init__( - self, - attribute: MapperProperty[Any], - related: type[DeclarativeBase], - parent: DeclarativeBase, - relation_type: RelationType, - set_: list[DeclarativeBase] | None | type[_Unset] = _Unset, - add: list[DeclarativeBase] | None = None, - remove: list[DeclarativeBase] | None = None, - create: list[DeclarativeBase] | None = None, - upsert: UpsertData | None = None, - input_index: int = -1, - level: int = 0, - ) -> None: - super().__init__( - attribute=attribute, - related=related, - relation_type=relation_type, - set_=set_, - add=add, - remove=remove, - create=create, - upsert=upsert, - input_index=input_index, - level=level, - ) - self.parent = parent + parent: DeclarativeBase = field(kw_only=True) + + def __post_init__(self) -> None: + super().__post_init__() if self.relation_type is RelationType.TO_ONE: event.listens_for(self.attribute, "set")(self._set_event) diff --git a/src/strawchemy/strawberry/mutation/types.py b/src/strawchemy/schema/mutation/types.py similarity index 90% rename from src/strawchemy/strawberry/mutation/types.py rename to src/strawchemy/schema/mutation/types.py index a47e59d2..a0d64c54 100644 --- a/src/strawchemy/strawberry/mutation/types.py +++ b/src/strawchemy/schema/mutation/types.py @@ -1,20 +1,33 @@ from __future__ import annotations from enum import Enum, auto -from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar +import strawberry +from strawberry import UNSET from strawberry.types import get_object_definition from typing_extensions import override -import strawberry -from strawberry import UNSET -from strawchemy.dto.base import MappedDTO, ToMappedProtocol, VisitorProtocol +from strawchemy.dto import MappedDTO, ToMappedProtocol, VisitorProtocol from strawchemy.dto.types import DTOUnset +from strawchemy.schema.interfaces import ErrorId, ErrorType if TYPE_CHECKING: - from strawchemy.strawberry.dto import EnumDTO - -__all__ = ("RelationType",) + from strawchemy.dto.strawberry import EnumDTO + +__all__ = ( + "LocalizedErrorType", + "RelationType", + "RequiredToManyUpdateInput", + "RequiredToOneInput", + "ToManyCreateInput", + "ToManyUpdateInput", + "ToManyUpsertInput", + "ToOneInput", + "ToOneUpsertInput", + "ValidationErrorType", + "error_type_names", +) T = TypeVar("T", bound="MappedDTO[Any]") UpdateFieldsT = TypeVar("UpdateFieldsT", bound="EnumDTO") @@ -30,12 +43,6 @@ def error_type_names() -> set[str]: return {get_object_definition(type_, strict=True).name for type_ in ErrorType.__error_types__} -class ErrorId(Enum): - ERROR = "ERROR" - VALIDATION_ERROR = "VALIDATION_ERROR" - LOCALIZED_VALIDATION_ERROR = "LOCALIZED_VALIDATION_ERROR" - - class RelationType(Enum): TO_ONE = auto() TO_MANY = auto() @@ -181,20 +188,6 @@ def to_mapped( return super().to_mapped(visitor, level=level, override=override) -@strawberry.interface(description="Base interface for expected errors", name="ErrorType") -class ErrorType: - """Base class for GraphQL errors.""" - - __error_types__: ClassVar[set[type[Any]]] = set() - - id: str = ErrorId.ERROR.value - - def __init_subclass__(cls) -> None: - if not cls.__error_types__: - cls.__error_types__.add(ErrorType) - cls.__error_types__.add(cls) - - @strawberry.type(description="Indicate validation error type and location.", name="LocalizedErrorType") class LocalizedErrorType(ErrorType): """Match inner shape of pydantic ValidationError.""" diff --git a/src/strawchemy/types.py b/src/strawchemy/schema/pagination.py similarity index 100% rename from src/strawchemy/types.py rename to src/strawchemy/schema/pagination.py diff --git a/src/strawchemy/schema/scalars/__init__.py b/src/strawchemy/schema/scalars/__init__.py new file mode 100644 index 00000000..a53503dc --- /dev/null +++ b/src/strawchemy/schema/scalars/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from strawchemy.schema.scalars.base import Date, DateTime, Interval, Time + +__all__ = ("Date", "DateTime", "Interval", "Time") diff --git a/src/strawchemy/strawberry/scalars.py b/src/strawchemy/schema/scalars/base.py similarity index 88% rename from src/strawchemy/strawberry/scalars.py rename to src/strawchemy/schema/scalars/base.py index 7cf80400..3be8a015 100644 --- a/src/strawchemy/strawberry/scalars.py +++ b/src/strawchemy/schema/scalars/base.py @@ -2,18 +2,23 @@ from datetime import date, datetime, time, timedelta, timezone from functools import partial -from typing import NewType, TypeVar +from typing import TYPE_CHECKING, TypeVar from msgspec import json +from strawberry import scalar from strawberry.schema.types.base_scalars import wrap_parser -from strawberry import scalar +from strawchemy.utils.annotation import new_type + +if TYPE_CHECKING: + from typing import Any __all__ = ("Date", "DateTime", "Interval", "Time") + UTC = timezone.utc -T = TypeVar("T") +T = TypeVar("T", bound="Any") def _serialize_time(value: time | timedelta | str) -> str: @@ -30,11 +35,6 @@ def _serialize(value: timedelta) -> str: return json.encode(value).decode() -def new_type(name: str, type_: type[T]) -> type[T]: - # Needed for pyright - return NewType(name, type_) # pyright: ignore[reportArgumentType] - - Interval = scalar( new_type("Interval", timedelta), description=( diff --git a/src/strawchemy/strawberry/geo.py b/src/strawchemy/schema/scalars/geo.py similarity index 99% rename from src/strawchemy/strawberry/geo.py rename to src/strawchemy/schema/scalars/geo.py index f3dc085d..a2653f23 100644 --- a/src/strawchemy/strawberry/geo.py +++ b/src/strawchemy/schema/scalars/geo.py @@ -6,6 +6,7 @@ from typing import Any import shapely +import strawberry from geoalchemy2 import WKBElement, WKTElement from geoalchemy2.shape import to_shape from geojson_pydantic.geometries import Geometry as PydanticGeometry @@ -21,8 +22,7 @@ from pydantic import TypeAdapter from shapely import Geometry, to_geojson -import strawberry -from strawchemy.strawberry.scalars import new_type +from strawchemy.utils.annotation import new_type __all__ = ( "GEO_SCALAR_OVERRIDES", diff --git a/src/strawchemy/sqlalchemy/__init__.py b/src/strawchemy/sqlalchemy/__init__.py deleted file mode 100644 index 591a708e..00000000 --- a/src/strawchemy/sqlalchemy/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -"""SQLAlchemy integration for Strawchemy. - -This package provides SQLAlchemy-based repository implementations for use with -Strawberry GraphQL. -""" - -from __future__ import annotations - -from strawchemy.sqlalchemy.repository import ( - SQLAlchemyGraphQLAsyncRepository, - SQLAlchemyGraphQLRepository, - SQLAlchemyGraphQLSyncRepository, -) - -__all__ = ("SQLAlchemyGraphQLAsyncRepository", "SQLAlchemyGraphQLRepository", "SQLAlchemyGraphQLSyncRepository") diff --git a/src/strawchemy/sqlalchemy/exceptions.py b/src/strawchemy/sqlalchemy/exceptions.py deleted file mode 100644 index 044fa3f5..00000000 --- a/src/strawchemy/sqlalchemy/exceptions.py +++ /dev/null @@ -1,17 +0,0 @@ -"""SQLAlchemy DTO exceptions.""" - -from __future__ import annotations - -__all__ = ("QueryHookError", "QueryResultError", "TranspilingError") - - -class TranspilingError(Exception): - """Raised when an error occurs during transpiling.""" - - -class QueryResultError(Exception): - """Raised when an error occurs during query result processing or mapping.""" - - -class QueryHookError(Exception): - """Raised when an error occurs within a query hook's execution.""" diff --git a/src/strawchemy/sqlalchemy/inspector.py b/src/strawchemy/sqlalchemy/inspector.py deleted file mode 100644 index 5bd5ac53..00000000 --- a/src/strawchemy/sqlalchemy/inspector.py +++ /dev/null @@ -1,219 +0,0 @@ -"""Provides an inspector for SQLAlchemy models to determine GraphQL filter types. - -This module defines `SQLAlchemyGraphQLInspector`, which extends the base -`SQLAlchemyInspector` from `strawchemy.dto.inspectors.sqlalchemy`. It maps -SQLAlchemy column types and model attributes to appropriate GraphQL comparison -filter types (e.g., `TextComparison`, `OrderComparison`). This process considers -database-specific features (via `DatabaseFeatures`) and allows for custom -filter overrides. The module also includes utility functions like `loaded_attributes`. -""" - -from __future__ import annotations - -from collections import OrderedDict -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from typing import TYPE_CHECKING, Any, TypeVar - -from sqlalchemy.orm import NO_VALUE, DeclarativeBase, QueryableAttribute, registry -from sqlalchemy.types import ARRAY - -from sqlalchemy import inspect -from strawchemy.config.databases import DatabaseFeatures -from strawchemy.constants import GEO_INSTALLED -from strawchemy.dto.inspectors.sqlalchemy import SQLAlchemyInspector -from strawchemy.strawberry.filters import ( - ArrayComparison, - DateComparison, - DateTimeComparison, - EqualityComparison, - GraphQLComparison, - OrderComparison, - TextComparison, - TimeComparison, - TimeDeltaComparison, -) -from strawchemy.strawberry.filters.inputs import make_full_json_comparison_input, make_sqlite_json_comparison_input - -if TYPE_CHECKING: - from strawchemy.dto.base import DTOFieldDefinition - from strawchemy.sqlalchemy.typing import FilterMap - from strawchemy.typing import SupportedDialect - - -__all__ = ("SQLAlchemyGraphQLInspector", "loaded_attributes") - - -T = TypeVar("T", bound=Any) - - -_DEFAULT_FILTERS_MAP: FilterMap = OrderedDict( - { - (timedelta,): TimeDeltaComparison, - (datetime,): DateTimeComparison, - (time,): TimeComparison, - (date,): DateComparison, - (bool,): EqualityComparison, - (int, float, Decimal): OrderComparison, - (str,): TextComparison, - } -) - - -def loaded_attributes(model: DeclarativeBase) -> set[str]: - """Identifies attributes of a SQLAlchemy model instance that have been loaded. - - This function inspects the given SQLAlchemy model instance and returns a set - of attribute names for which the value has been loaded from the database - (i.e., the value is not `sqlalchemy.orm.NO_VALUE`). - - Args: - model: The SQLAlchemy `DeclarativeBase` instance to inspect. - - Returns: - A set of strings, where each string is the name of a loaded attribute. - """ - return {name for name, attr in inspect(model).attrs.items() if attr.loaded_value is not NO_VALUE} - - -class SQLAlchemyGraphQLInspector(SQLAlchemyInspector): - """Inspects SQLAlchemy models to determine appropriate GraphQL filter types. - - This inspector extends `SQLAlchemyInspector` to provide mappings from - SQLAlchemy model attributes and Python types to specific GraphQL comparison - filter input types (e.g., `TextComparison`, `OrderComparison`). - - It takes into account the database dialect's features (via `DatabaseFeatures`) - to select suitable filters, for example, for JSON or geospatial types. - Custom filter mappings can also be provided through `filter_overrides`. - - Key methods `get_field_comparison` and `get_type_comparison` are used to - retrieve the corresponding filter types. - """ - - def __init__( - self, - dialect: SupportedDialect, - registries: list[registry] | None = None, - filter_overrides: FilterMap | None = None, - ) -> None: - """Initializes the SQLAlchemyGraphQLInspector. - - Args: - dialect: The SQL dialect of the target database (e.g., "postgresql", "sqlite"). - registries: An optional list of SQLAlchemy registries to inspect. - If None, the default registry is used. - filter_overrides: An optional mapping to override or extend the default - Python type to GraphQL filter type mappings. - """ - super().__init__(registries) - self.db_features = DatabaseFeatures.new(dialect) - self.filters_map = self._filter_map() - self.filters_map |= filter_overrides or {} - - def _filter_map(self) -> FilterMap: - """Constructs the map of Python types to GraphQL filter comparison types. - - Starts with a default set of filters (`_DEFAULT_FILTERS_MAP`). - If GeoAlchemy is installed (`GEO_INSTALLED`), it adds mappings for - geospatial types to `GeoComparison`. - It then adds mappings for `dict` to appropriate JSON comparison - types based on whether the dialect is SQLite or another database - that supports more advanced JSON operations. - - Returns: - The constructed `FilterMap`. - """ - filters_map = _DEFAULT_FILTERS_MAP - - if GEO_INSTALLED: - from geoalchemy2 import WKBElement, WKTElement # noqa: PLC0415 - from shapely import Geometry # noqa: PLC0415 - - from strawchemy.strawberry.filters.geo import GeoComparison # noqa: PLC0415 - - filters_map |= {(Geometry, WKBElement, WKTElement): GeoComparison} - if self.db_features.dialect == "sqlite": - filters_map[(dict, dict)] = make_sqlite_json_comparison_input() - else: - filters_map[(dict, dict)] = make_full_json_comparison_input() - return filters_map - - @classmethod - def _is_specialized(cls, type_: type[Any]) -> bool: - """Checks if a generic type is fully specialized. - - A type is considered specialized if it has no type parameters (`__parameters__`) - or if all its type parameters are concrete types (not `TypeVar`). - - Args: - type_: The type to check. - - Returns: - True if the type is specialized, False otherwise. - """ - return not hasattr(type_, "__parameters__") or all( - not isinstance(param, TypeVar) for param in type_.__parameters__ - ) - - @classmethod - def _filter_type(cls, type_: type[Any], sqlalchemy_filter: type[GraphQLComparison]) -> type[GraphQLComparison]: - """Potentially specializes a generic GraphQL filter type with a Python type. - - If the provided `sqlalchemy_filter` is a generic type (e.g., `OrderComparison[T]`) - and is not yet specialized, this method specializes it using `type_` - (e.g., `OrderComparison[int]`). If `sqlalchemy_filter` is already specialized - or not generic, it's returned as is. - - Args: - type_: The Python type to use for specialization if needed. - sqlalchemy_filter: The GraphQL filter type, which might be generic. - - Returns: - The (potentially specialized) GraphQL filter type. - """ - return sqlalchemy_filter if cls._is_specialized(sqlalchemy_filter) else sqlalchemy_filter[type_] # pyright: ignore[reportInvalidTypeArguments] - - def get_field_comparison( - self, field_definition: DTOFieldDefinition[DeclarativeBase, QueryableAttribute[Any]] - ) -> type[GraphQLComparison]: - """Determines the GraphQL comparison filter type for a DTO field. - - This method inspects the type of the given DTO field. - For `ARRAY` types on PostgreSQL, it returns a specialized `ArrayComparison`. - Otherwise, it delegates to `get_type_comparison` using the Python type - of the model field. - - Args: - field_definition: The DTO field definition, which contains information - about the model attribute and its type. - - Returns: - The GraphQL comparison filter type suitable for the field. - """ - field_type = field_definition.model_field.type - if isinstance(field_type, ARRAY) and self.db_features.dialect == "postgresql": - return ArrayComparison[field_type.item_type.python_type] - return self.get_type_comparison(self.model_field_type(field_definition)) - - def get_type_comparison(self, type_: type[Any]) -> type[GraphQLComparison]: - """Determines the GraphQL comparison filter type for a Python type. - - It iterates through the `self.filters_map` (which includes default - and dialect-specific filters) to find a filter type that matches - the provided Python `type_`. - If a direct match or a superclass match is found, the corresponding - filter type is returned, potentially specialized using `_filter_type`. - If no specific filter is found in the map, it defaults to - `EqualityComparison` specialized with the given `type_`. - - Args: - type_: The Python type for which to find a GraphQL filter. - - Returns: - The GraphQL comparison filter type suitable for the Python type. - """ - for types, sqlalchemy_filter in self.filters_map.items(): - if issubclass(type_, types): - return self._filter_type(type_, sqlalchemy_filter) - return EqualityComparison[type_] diff --git a/src/strawchemy/strawberry/__init__.py b/src/strawchemy/strawberry/__init__.py deleted file mode 100644 index 038db7ff..00000000 --- a/src/strawchemy/strawberry/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from __future__ import annotations - -from strawchemy.strawberry._instance import ModelInstance -from strawchemy.strawberry._utils import default_session_getter -from strawchemy.types import DefaultOffsetPagination - -__all__ = ("DefaultOffsetPagination", "ModelInstance", "default_session_getter") diff --git a/src/strawchemy/strawberry/exceptions.py b/src/strawchemy/strawberry/exceptions.py deleted file mode 100644 index f4024a76..00000000 --- a/src/strawchemy/strawberry/exceptions.py +++ /dev/null @@ -1,6 +0,0 @@ -from __future__ import annotations - -__all__ = ("StrawchemyFieldError",) - - -class StrawchemyFieldError(Exception): ... diff --git a/src/strawchemy/strawberry/filters/__init__.py b/src/strawchemy/strawberry/filters/__init__.py deleted file mode 100644 index 26b4388c..00000000 --- a/src/strawchemy/strawberry/filters/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations - -from strawchemy.strawberry.filters.inputs import ( - ArrayComparison, - DateComparison, - DateTimeComparison, - EqualityComparison, - GraphQLComparison, - GraphQLComparisonT, - GraphQLFilter, - OrderComparison, - TextComparison, - TimeComparison, - TimeDeltaComparison, - _JSONComparison, - _SQLiteJSONComparison, -) - -__all__ = ( - "ArrayComparison", - "DateComparison", - "DateTimeComparison", - "EqualityComparison", - "GraphQLComparison", - "GraphQLComparisonT", - "GraphQLFilter", - "OrderComparison", - "TextComparison", - "TimeComparison", - "TimeDeltaComparison", - "_JSONComparison", - "_SQLiteJSONComparison", -) diff --git a/src/strawchemy/strawberry/mutation/__init__.py b/src/strawchemy/strawberry/mutation/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/strawchemy/strawberry/typing.py b/src/strawchemy/strawberry/typing.py deleted file mode 100644 index 3b56b4df..00000000 --- a/src/strawchemy/strawberry/typing.py +++ /dev/null @@ -1,75 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypeVar - -if TYPE_CHECKING: - from collections.abc import Callable - - from strawberry.types.base import WithStrawberryObjectDefinition - - from sqlalchemy import Select - from strawberry import Info - from strawchemy.graph import Node - from strawchemy.sqlalchemy.typing import AnyAsyncSession, AnySyncSession - from strawchemy.strawberry.dto import ( - AggregateDTO, - FilterFunctionInfo, - GraphQLFieldDefinition, - GraphQLFilterDTO, - MappedStrawberryGraphQLDTO, - OrderByDTO, - OutputFunctionInfo, - QueryNodeMetadata, - StrawchemyDTOAttributes, - UnmappedStrawberryGraphQLDTO, - ) - from strawchemy.validation.pydantic import MappedPydanticGraphQLDTO - -__all__ = ( - "AnySessionGetter", - "AsyncSessionGetter", - "FilterStatementCallable", - "StrawchemyTypeWithStrawberryObjectDefinition", - "SyncSessionGetter", -) - - -_T = TypeVar("_T") -QueryObject = TypeVar("QueryObject", bound="Any") -GraphQLFilterDTOT = TypeVar("GraphQLFilterDTOT", bound="GraphQLFilterDTO") -AggregateDTOT = TypeVar("AggregateDTOT", bound="AggregateDTO") -GraphQLDTOT = TypeVar("GraphQLDTOT", bound="GraphQLDTO[Any]") -OrderByDTOT = TypeVar("OrderByDTOT", bound="OrderByDTO") - -AggregationFunction = Literal["min", "max", "sum", "avg", "count", "stddev_samp", "stddev_pop", "var_samp", "var_pop"] -AggregationType = Literal[ - "sum", "numeric", "min_max_datetime", "min_max_date", "min_max_time", "min_max_string", "min_max_numeric" -] - -GraphQLType = Literal["input", "object", "interface", "enum"] -AsyncSessionGetter: TypeAlias = "Callable[[Info[Any, Any]], AnyAsyncSession]" -SyncSessionGetter: TypeAlias = "Callable[[Info[Any, Any]], AnySyncSession]" -AnySessionGetter: TypeAlias = "AsyncSessionGetter | SyncSessionGetter" -FilterStatementCallable: TypeAlias = "Callable[[Info[Any, Any]], Select[tuple[Any]]]" -GraphQLPurpose: TypeAlias = Literal[ - "type", - "aggregate_type", - "create_input", - "update_by_pk_input", - "update_by_filter_input", - "filter", - "aggregate_filter", - "order_by", - "upsert_update_fields", - "upsert_conflict_fields", -] -FunctionInfo: TypeAlias = "FilterFunctionInfo | OutputFunctionInfo" -StrawberryGraphQLDTO: TypeAlias = "MappedStrawberryGraphQLDTO[_T] | UnmappedStrawberryGraphQLDTO[_T]" -GraphQLDTO: TypeAlias = "StrawberryGraphQLDTO[_T] | MappedPydanticGraphQLDTO[_T]" -MappedGraphQLDTO: TypeAlias = "MappedStrawberryGraphQLDTO[_T] | MappedPydanticGraphQLDTO[_T]" -AnyMappedDTO: TypeAlias = "MappedStrawberryGraphQLDTO[Any] | MappedPydanticGraphQLDTO[Any]" -QueryNodeType: TypeAlias = "Node[GraphQLFieldDefinition, QueryNodeMetadata]" - -if TYPE_CHECKING: - - class StrawchemyTypeWithStrawberryObjectDefinition(StrawchemyDTOAttributes, WithStrawberryObjectDefinition): ... diff --git a/src/strawchemy/testing/pytest_plugin.py b/src/strawchemy/testing/pytest_plugin.py index c9107f84..c0f6c44e 100644 --- a/src/strawchemy/testing/pytest_plugin.py +++ b/src/strawchemy/testing/pytest_plugin.py @@ -5,17 +5,17 @@ from unittest.mock import MagicMock import pytest - from sqlalchemy import Result -from strawchemy.sqlalchemy import _executor as executor -from strawchemy.sqlalchemy._scope import AggregationFunctionInfo + +from strawchemy.transpiler import _executor as executor +from strawchemy.transpiler._scope import AggregationFunctionInfo if TYPE_CHECKING: from collections.abc import Awaitable, Callable from strawchemy.dto import ModelT - from strawchemy.sqlalchemy.typing import AnySession, DeclarativeT - from strawchemy.strawberry.dto import QueryNode + from strawchemy.dto.strawberry import QueryNode + from strawchemy.repository.typing import AnySession, DeclarativeT from strawchemy.typing import SupportedDialect diff --git a/src/strawchemy/transpiler/__init__.py b/src/strawchemy/transpiler/__init__.py new file mode 100644 index 00000000..118e4b75 --- /dev/null +++ b/src/strawchemy/transpiler/__init__.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from strawchemy.transpiler._executor import ( + AsyncQueryExecutor, + NodeResult, + QueryExecutor, + QueryResult, + SyncQueryExecutor, +) +from strawchemy.transpiler._transpiler import QueryTranspiler +from strawchemy.transpiler.hook import ColumnLoadingMode, QueryHook + +__all__ = ( + "AsyncQueryExecutor", + "ColumnLoadingMode", + "NodeResult", + "QueryExecutor", + "QueryHook", + "QueryResult", + "QueryTranspiler", + "SyncQueryExecutor", +) diff --git a/src/strawchemy/sqlalchemy/_executor.py b/src/strawchemy/transpiler/_executor.py similarity index 97% rename from src/strawchemy/sqlalchemy/_executor.py rename to src/strawchemy/transpiler/_executor.py index 5326dc01..9e9c7dae 100644 --- a/src/strawchemy/sqlalchemy/_executor.py +++ b/src/strawchemy/transpiler/_executor.py @@ -15,18 +15,19 @@ from typing_extensions import Self from strawchemy.dto import ModelT -from strawchemy.sqlalchemy.exceptions import QueryResultError -from strawchemy.sqlalchemy.typing import AnyAsyncSession, AnySyncSession, DeclarativeT +from strawchemy.exceptions import QueryResultError +from strawchemy.repository.typing import AnyAsyncSession, AnySyncSession, DeclarativeT if TYPE_CHECKING: from collections.abc import Callable, Generator, Sequence from sqlalchemy import Label, Result, Select, StatementLambdaElement - from strawchemy.sqlalchemy._scope import QueryScope - from strawchemy.strawberry.typing import QueryNodeType + from strawchemy.transpiler._scope import QueryScope + from strawchemy.typing import QueryNodeType -__all__ = ("AsyncQueryExecutor", "NodeResult", "QueryExecutor", "SyncQueryExecutor") + +__all__ = ("AsyncQueryExecutor", "NodeResult", "QueryExecutor", "QueryResult", "SyncQueryExecutor") @dataclass diff --git a/src/strawchemy/sqlalchemy/_query.py b/src/strawchemy/transpiler/_query.py similarity index 96% rename from src/strawchemy/sqlalchemy/_query.py rename to src/strawchemy/transpiler/_query.py index af633614..f675ab96 100644 --- a/src/strawchemy/sqlalchemy/_query.py +++ b/src/strawchemy/transpiler/_query.py @@ -2,22 +2,10 @@ import dataclasses from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import cached_property from typing import TYPE_CHECKING, Any, Generic, cast -from sqlalchemy.orm import ( - QueryableAttribute, - RelationshipDirection, - RelationshipProperty, - aliased, - class_mapper, - raiseload, -) -from sqlalchemy.orm.util import AliasedClass -from sqlalchemy.sql.elements import NamedColumn -from typing_extensions import Self - from sqlalchemy import ( CTE, AliasedReturnsRows, @@ -32,11 +20,20 @@ null, select, ) +from sqlalchemy.orm import ( + QueryableAttribute, + RelationshipDirection, + RelationshipProperty, + aliased, + class_mapper, + raiseload, +) +from sqlalchemy.orm.util import AliasedClass +from sqlalchemy.sql.elements import NamedColumn +from typing_extensions import Self + from strawchemy.constants import AGGREGATIONS_KEY, NODES_KEY -from strawchemy.graph import merge_trees -from strawchemy.sqlalchemy.exceptions import TranspilingError -from strawchemy.sqlalchemy.typing import DeclarativeT, OrderBySpec -from strawchemy.strawberry.dto import ( +from strawchemy.dto.strawberry import ( BooleanFilterDTO, EnumDTO, Filter, @@ -45,6 +42,9 @@ OrderByEnum, QueryNode, ) +from strawchemy.exceptions import TranspilingError +from strawchemy.repository.typing import DeclarativeT, OrderBySpec +from strawchemy.utils.graph import merge_trees if TYPE_CHECKING: from collections.abc import Sequence @@ -55,13 +55,14 @@ from sqlalchemy.sql.selectable import NamedFromClause from strawchemy.config.databases import DatabaseFeatures - from strawchemy.sqlalchemy._scope import QueryScope - from strawchemy.sqlalchemy.hook import ColumnLoadingMode, QueryHook - from strawchemy.strawberry.typing import QueryNodeType + from strawchemy.transpiler import ColumnLoadingMode, QueryHook + from strawchemy.transpiler._scope import QueryScope + from strawchemy.typing import QueryNodeType __all__ = ("AggregationJoin", "Conjunction", "DistinctOn", "Join", "OrderBy", "QueryGraph", "Where") +@dataclass class Join: """Represents a join to be applied to a SQLAlchemy query. @@ -77,19 +78,11 @@ class Join: particularly relevant for ordered relationships. """ - def __init__( - self, - target: QueryableAttribute[Any] | NamedFromClause | AliasedClass[Any], - node: QueryNodeType, - onclause: _OnClauseArgument | None = None, - is_outer: bool = False, - order_nodes: list[QueryNodeType] | None = None, - ) -> None: - self.target = target - self.node = node - self.onclause = onclause - self.is_outer = is_outer - self.order_nodes = order_nodes if order_nodes is not None else [] + target: QueryableAttribute[Any] | NamedFromClause | AliasedClass[Any] + node: QueryNodeType + onclause: _OnClauseArgument | None = None + is_outer: bool = False + order_nodes: list[QueryNodeType] = dataclasses.field(default_factory=list) @property def _relationship(self) -> RelationshipProperty[Any]: @@ -138,6 +131,7 @@ def __ge__(self, other: Self) -> bool: return self.order >= other.order +@dataclass(kw_only=True) class AggregationJoin(Join): """Represents a join specifically for aggregation purposes, often involving a subquery. @@ -150,19 +144,10 @@ class AggregationJoin(Join): _column_names: Internal tracking of column names within the subquery to ensure uniqueness. """ - def __init__( - self, - target: QueryableAttribute[Any] | NamedFromClause | AliasedClass[Any], - node: QueryNodeType, - subquery_alias: AliasedClass[Any], - onclause: _OnClauseArgument | None = None, - is_outer: bool = False, - order_nodes: list[QueryNodeType] | None = None, - ) -> None: - super().__init__(target, node, onclause, is_outer, order_nodes) - self.subquery_alias = subquery_alias - self._column_names: defaultdict[str, int] = defaultdict(int) + subquery_alias: AliasedClass[Any] + _column_names: defaultdict[str, int] = field(default_factory=lambda: defaultdict(int)) + def __post_init__(self) -> None: # Initialize the _column_names mapping from the subquery's selected columns for column in self._inner_select.selected_columns: if isinstance(column, NamedColumn): diff --git a/src/strawchemy/sqlalchemy/_scope.py b/src/strawchemy/transpiler/_scope.py similarity index 98% rename from src/strawchemy/sqlalchemy/_scope.py rename to src/strawchemy/transpiler/_scope.py index 6d7f0272..e74d8b05 100644 --- a/src/strawchemy/sqlalchemy/_scope.py +++ b/src/strawchemy/transpiler/_scope.py @@ -24,21 +24,20 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias +from sqlalchemy import ColumnElement, FromClause, Function, Label, Select, func, inspect +from sqlalchemy import cast as sqla_cast +from sqlalchemy import distinct as sqla_distinct from sqlalchemy.dialects import postgresql from sqlalchemy.orm import DeclarativeBase, Mapper, MapperProperty, QueryableAttribute, RelationshipProperty, aliased -from sqlalchemy.orm.util import AliasedClass from typing_extensions import Self, override -from sqlalchemy import ColumnElement, FromClause, Function, Label, Select, func, inspect -from sqlalchemy import cast as sqla_cast -from sqlalchemy import distinct as sqla_distinct from strawchemy.constants import NODES_KEY +from strawchemy.dto.inspectors import SQLAlchemyInspector +from strawchemy.dto.strawberry import GraphQLFieldDefinition, QueryNode from strawchemy.dto.types import DTOConfig, Purpose -from strawchemy.graph import Node -from strawchemy.sqlalchemy.exceptions import TranspilingError -from strawchemy.sqlalchemy.inspector import SQLAlchemyInspector -from strawchemy.sqlalchemy.typing import DeclarativeT -from strawchemy.strawberry.dto import GraphQLFieldDefinition, QueryNode +from strawchemy.exceptions import TranspilingError +from strawchemy.repository.typing import DeclarativeT +from strawchemy.utils.graph import Node if TYPE_CHECKING: from collections.abc import Callable @@ -46,9 +45,8 @@ from sqlalchemy.orm.util import AliasedClass from sqlalchemy.sql.elements import NamedColumn - from strawchemy.sqlalchemy.typing import DeclarativeSubT, FunctionGenerator, RelationshipSide - from strawchemy.strawberry.typing import QueryNodeType - from strawchemy.typing import SupportedDialect + from strawchemy.repository.typing import DeclarativeSubT, FunctionGenerator, RelationshipSide + from strawchemy.typing import QueryNodeType, SupportedDialect __all__ = ("NodeInspect", "QueryScope") @@ -460,7 +458,7 @@ def filter_function( It retrieves the function using `AggregationFunctionInfo`. Arguments for the function are derived from the children of the current node, adapted to the given `alias`. - If `distinct` is True, `sqlalchemy.distinct()` is applied to the arguments. + If `distinct` is True, `strawberry.distinct()` is applied to the arguments. The label for the function is determined by the scope key of either the first child (if there's only one, implying the function applies to that @@ -468,7 +466,7 @@ def filter_function( Args: alias: The `AliasedClass` to adapt function arguments to. - distinct: If True, applies `sqlalchemy.distinct()` to the function + distinct: If True, applies `strawberry.distinct()` to the function arguments. Defaults to None (no distinct). Returns: @@ -615,8 +613,8 @@ class QueryScope(Generic[DeclarativeT]): naming conflicts and ensuring the query is valid. Example: - >>> from sqlalchemy.orm import declarative_base - >>> from sqlalchemy import Column, Integer, String + >>> from strawberry.orm import declarative_base + >>> from strawberry import Column, Integer, String >>> Base = declarative_base() >>> class User(Base): ... __tablename__ = 'users' diff --git a/src/strawchemy/sqlalchemy/_transpiler.py b/src/strawchemy/transpiler/_transpiler.py similarity index 98% rename from src/strawchemy/sqlalchemy/_transpiler.py rename to src/strawchemy/transpiler/_transpiler.py index 8f1daaaf..160c81e6 100644 --- a/src/strawchemy/sqlalchemy/_transpiler.py +++ b/src/strawchemy/transpiler/_transpiler.py @@ -14,10 +14,6 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Generic, cast -from sqlalchemy.orm import Mapper, RelationshipProperty, aliased, class_mapper, contains_eager, load_only, raiseload -from sqlalchemy.sql.elements import ColumnElement -from typing_extensions import Self, override - from sqlalchemy import ( Dialect, Label, @@ -34,9 +30,26 @@ text, true, ) +from sqlalchemy.orm import Mapper, RelationshipProperty, aliased, class_mapper, contains_eager, load_only, raiseload +from typing_extensions import Self, override + from strawchemy.constants import AGGREGATIONS_KEY -from strawchemy.sqlalchemy._executor import SyncQueryExecutor -from strawchemy.sqlalchemy._query import ( +from strawchemy.dto.inspectors import SQLAlchemyGraphQLInspector +from strawchemy.dto.strawberry import ( + AggregationFilter, + BooleanFilterDTO, + EnumDTO, + Filter, + OrderByDTO, + OrderByEnum, + OrderByRelationFilterDTO, + QueryNode, +) +from strawchemy.exceptions import TranspilingError +from strawchemy.repository.typing import DeclarativeT, OrderBySpec, QueryExecutorT +from strawchemy.schema.filters import GraphQLComparison +from strawchemy.transpiler._executor import SyncQueryExecutor +from strawchemy.transpiler._query import ( AggregationJoin, Conjunction, DistinctOn, @@ -48,21 +61,7 @@ SubqueryBuilder, Where, ) -from strawchemy.sqlalchemy._scope import QueryScope -from strawchemy.sqlalchemy.exceptions import TranspilingError -from strawchemy.sqlalchemy.inspector import SQLAlchemyGraphQLInspector -from strawchemy.sqlalchemy.typing import DeclarativeT, OrderBySpec, QueryExecutorT -from strawchemy.strawberry.dto import ( - AggregationFilter, - BooleanFilterDTO, - EnumDTO, - Filter, - OrderByDTO, - OrderByEnum, - OrderByRelationFilterDTO, - QueryNode, -) -from strawchemy.strawberry.filters import GraphQLComparison +from strawchemy.transpiler._scope import QueryScope if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence @@ -72,9 +71,8 @@ from sqlalchemy.sql import ColumnElement, SQLColumnExpression from sqlalchemy.sql.elements import NamedColumn - from strawchemy.sqlalchemy.hook import QueryHook - from strawchemy.strawberry.typing import QueryNodeType - from strawchemy.typing import SupportedDialect + from strawchemy.transpiler.hook import QueryHook + from strawchemy.typing import QueryNodeType, SupportedDialect __all__ = ("QueryTranspiler",) diff --git a/src/strawchemy/sqlalchemy/hook.py b/src/strawchemy/transpiler/hook.py similarity index 97% rename from src/strawchemy/sqlalchemy/hook.py rename to src/strawchemy/transpiler/hook.py index 84b99dd5..55da0f8f 100644 --- a/src/strawchemy/sqlalchemy/hook.py +++ b/src/strawchemy/transpiler/hook.py @@ -13,20 +13,17 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeAlias from sqlalchemy.orm import ColumnProperty, RelationshipProperty, joinedload, selectinload, undefer -from sqlalchemy.orm.strategy_options import _AbstractLoad -from sqlalchemy.orm.util import AliasedClass -from strawchemy.sqlalchemy.exceptions import QueryHookError -from strawchemy.sqlalchemy.typing import DeclarativeT +from strawchemy.exceptions import QueryHookError +from strawchemy.repository.typing import DeclarativeT if TYPE_CHECKING: from collections.abc import Sequence + from sqlalchemy import Select from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm.strategy_options import _AbstractLoad from sqlalchemy.orm.util import AliasedClass - - from sqlalchemy import Select from strawberry import Info diff --git a/src/strawchemy/typing.py b/src/strawchemy/typing.py index 8e8653b9..c0984dc9 100644 --- a/src/strawchemy/typing.py +++ b/src/strawchemy/typing.py @@ -1,21 +1,112 @@ from __future__ import annotations from types import UnionType -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, TypeAlias, Union +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union -UNION_TYPES = (Union, UnionType) +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from sqlalchemy import Select + from strawberry import Info + from strawberry.types.base import WithStrawberryObjectDefinition -if TYPE_CHECKING: - from strawchemy import StrawchemyAsyncRepository, StrawchemySyncRepository + from strawchemy import StrawchemyAsyncRepository, StrawchemySyncRepository, ValidationErrorType + from strawchemy.dto.strawberry import ( + AggregateDTO, + FilterFunctionInfo, + GraphQLFieldDefinition, + GraphQLFilterDTO, + MappedStrawberryGraphQLDTO, + OrderByDTO, + OutputFunctionInfo, + QueryNodeMetadata, + StrawchemyDTOAttributes, + UnmappedStrawberryGraphQLDTO, + ) + from strawchemy.utils.graph import Node + from strawchemy.validation.pydantic import MappedPydanticGraphQLDTO -__all__ = ("UNION_TYPES", "AnyRepository", "DataclassProtocol", "SupportedDialect") +__all__ = ( + "UNION_TYPES", + "AggregateDTOT", + "AggregationFunction", + "AggregationType", + "AnyMappedDTO", + "AnyRepository", + "AnyRepositoryType", + "CreateOrUpdateResolverResult", + "DataclassProtocol", + "FilterStatementCallable", + "FunctionInfo", + "GetByIdResolverResult", + "GraphQLDTO", + "GraphQLDTOT", + "GraphQLFilterDTOT", + "GraphQLPurpose", + "GraphQLType", + "ListResolverResult", + "MappedGraphQLDTO", + "OneOrManyResult", + "OrderByDTOT", + "QueryNodeType", + "QueryObject", + "StrawberryGraphQLDTO", + "StrawchemyTypeWithStrawberryObjectDefinition", + "SupportedDialect", +) + +UNION_TYPES = (Union, UnionType) -class DataclassProtocol(Protocol): - __dataclass_fields__: ClassVar[dict[str, Any]] +T = TypeVar("T", bound="Any") +QueryObject = TypeVar("QueryObject", bound="Any") +GraphQLFilterDTOT = TypeVar("GraphQLFilterDTOT", bound="GraphQLFilterDTO") +AggregateDTOT = TypeVar("AggregateDTOT", bound="AggregateDTO") +GraphQLDTOT = TypeVar("GraphQLDTOT", bound="GraphQLDTO[Any]") +OrderByDTOT = TypeVar("OrderByDTOT", bound="OrderByDTO") -AnyRepository: TypeAlias = "type[StrawchemySyncRepository[Any] | StrawchemyAsyncRepository[Any]]" SupportedDialect: TypeAlias = Literal["postgresql", "mysql", "sqlite"] """Must match SQLAlchemy dialect.""" + +AggregationFunction = Literal["min", "max", "sum", "avg", "count", "stddev_samp", "stddev_pop", "var_samp", "var_pop"] +AggregationType = Literal[ + "sum", "numeric", "min_max_datetime", "min_max_date", "min_max_time", "min_max_string", "min_max_numeric" +] +GraphQLType = Literal["input", "object", "interface", "enum"] + +AnyRepository: TypeAlias = "StrawchemySyncRepository[Any] | StrawchemyAsyncRepository[Any]" +AnyRepositoryType: TypeAlias = "type[AnyRepository]" +FilterStatementCallable: TypeAlias = "Callable[[Info[Any, Any]], Select[tuple[Any]]]" +GraphQLPurpose: TypeAlias = Literal[ + "type", + "aggregate_type", + "create_input", + "update_by_pk_input", + "update_by_filter_input", + "filter", + "aggregate_filter", + "order_by", + "upsert_update_fields", + "upsert_conflict_fields", +] +FunctionInfo: TypeAlias = "FilterFunctionInfo | OutputFunctionInfo" +StrawberryGraphQLDTO: TypeAlias = "MappedStrawberryGraphQLDTO[T] | UnmappedStrawberryGraphQLDTO[T]" +GraphQLDTO: TypeAlias = "StrawberryGraphQLDTO[T] | MappedPydanticGraphQLDTO[T]" +MappedGraphQLDTO: TypeAlias = "MappedStrawberryGraphQLDTO[T] | MappedPydanticGraphQLDTO[T]" +AnyMappedDTO: TypeAlias = "MappedStrawberryGraphQLDTO[Any] | MappedPydanticGraphQLDTO[Any]" +QueryNodeType: TypeAlias = "Node[GraphQLFieldDefinition, QueryNodeMetadata]" +OneOrManyResult: TypeAlias = ( + "Sequence[StrawchemyTypeWithStrawberryObjectDefinition] | StrawchemyTypeWithStrawberryObjectDefinition" +) +ListResolverResult: TypeAlias = OneOrManyResult +GetByIdResolverResult: TypeAlias = "StrawchemyTypeWithStrawberryObjectDefinition | None" +CreateOrUpdateResolverResult: TypeAlias = "OneOrManyResult | ValidationErrorType | Sequence[ValidationErrorType]" + + +if TYPE_CHECKING: + + class DataclassProtocol(Protocol): + __dataclass_fields__: ClassVar[dict[str, Any]] + + class StrawchemyTypeWithStrawberryObjectDefinition(StrawchemyDTOAttributes, WithStrawberryObjectDefinition): ... diff --git a/src/strawchemy/strawberry/factories/__init__.py b/src/strawchemy/utils/__init__.py similarity index 100% rename from src/strawchemy/strawberry/factories/__init__.py rename to src/strawchemy/utils/__init__.py diff --git a/src/strawchemy/utils/annotation.py b/src/strawchemy/utils/annotation.py new file mode 100644 index 00000000..9aaafd8e --- /dev/null +++ b/src/strawchemy/utils/annotation.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import inspect +from typing import Any, NewType, Optional, TypeVar, Union, get_args, get_origin + +from strawchemy.typing import UNION_TYPES + +T = TypeVar("T", bound="Any") + + +def non_optional_type_hint(type_hint: Any) -> Any: + origin, args = get_origin(type_hint), get_args(type_hint) + if origin is Optional: + return args + if origin in UNION_TYPES: + union_args = tuple([arg for arg in args if arg not in (None, type(None))]) + if len(union_args) == 1: + return union_args[0] + return Union[union_args] + return type_hint + + +def is_type_hint_optional(type_hint: Any) -> bool: + """Whether the given type hint is considered as optional or not. + + Returns: + `True` if arguments of the given type hint are optional + + Three cases are considered: + ``` + Optional[str] + Union[str, None] + str | None + ``` + In any other form, the type hint will not be considered as optional + """ + origin = get_origin(type_hint) + if origin is None: + return False + if origin is Optional: + return True + if origin in UNION_TYPES: + args = get_args(type_hint) + return any(arg is type(None) for arg in args) + return False + + +def get_annotations(obj: Any) -> dict[str, Any]: + """Get the annotations of the given object.""" + return inspect.get_annotations(obj) + + +def new_type(name: str, type_: type[T]) -> type[T]: + # Needed for pyright + return NewType(name, type_) # pyright: ignore[reportArgumentType] diff --git a/src/strawchemy/graph.py b/src/strawchemy/utils/graph.py similarity index 98% rename from src/strawchemy/graph.py rename to src/strawchemy/utils/graph.py index 354c719d..7d9a1787 100644 --- a/src/strawchemy/graph.py +++ b/src/strawchemy/utils/graph.py @@ -8,10 +8,23 @@ from typing_extensions import Self, override +from strawchemy.exceptions import GraphError + if TYPE_CHECKING: from collections.abc import Callable, Generator, Hashable -__all__ = ("GraphMetadata", "IterationMode", "MatchOn", "Node", "NodeMetadataT", "NodeValueT", "merge_trees") +__all__ = ( + "AnyNode", + "GraphMetadata", + "IterationMode", + "MatchOn", + "Node", + "NodeMetadata", + "NodeMetadataT", + "NodeT", + "NodeValueT", + "merge_trees", +) T = TypeVar("T") NodeValueT = TypeVar("NodeValueT", bound="Any") @@ -22,9 +35,6 @@ AnyNode: TypeAlias = "Node[Any, Any]" -class GraphError(Exception): ... - - @dataclass class GraphMetadata(Generic[T]): metadata: T diff --git a/src/strawchemy/strawberry/_registry.py b/src/strawchemy/utils/registry.py similarity index 96% rename from src/strawchemy/strawberry/_registry.py rename to src/strawchemy/utils/registry.py index ae2be4c0..ad060ac8 100644 --- a/src/strawchemy/strawberry/_registry.py +++ b/src/strawchemy/utils/registry.py @@ -4,29 +4,18 @@ from collections import defaultdict from copy import copy from enum import Enum -from typing import ( - TYPE_CHECKING, - Any, - ForwardRef, - Literal, - NewType, - TypeVar, - cast, - get_args, - get_origin, - overload, -) +from typing import TYPE_CHECKING, Any, ForwardRef, Literal, NewType, TypeVar, cast, get_args, get_origin, overload +import strawberry from strawberry.annotation import StrawberryAnnotation from strawberry.types import get_object_definition, has_object_definition from strawberry.types.base import StrawberryContainer from strawberry.types.field import StrawberryField -import strawberry -from strawchemy.strawberry._utils import strawberry_contained_types +from strawchemy.utils.strawberry import strawberry_contained_types try: - from strawchemy.strawberry.filters.geo import GeoComparison + from strawchemy.schema.filters.geo import GeoComparison geo_comparison = GeoComparison except ModuleNotFoundError: # pragma: no cover @@ -42,9 +31,8 @@ from strawberry.types.base import WithStrawberryObjectDefinition from strawchemy.dto.types import DTOScope - from strawchemy.strawberry.typing import GraphQLType, StrawchemyTypeWithStrawberryObjectDefinition - from strawchemy.types import DefaultOffsetPagination - + from strawchemy.schema.pagination import DefaultOffsetPagination + from strawchemy.typing import GraphQLType, StrawchemyTypeWithStrawberryObjectDefinition __all__ = ("RegistryTypeInfo", "StrawberryRegistry") diff --git a/src/strawchemy/strawberry/_utils.py b/src/strawchemy/utils/strawberry.py similarity index 61% rename from src/strawchemy/strawberry/_utils.py rename to src/strawchemy/utils/strawberry.py index fb34d2c8..75d52bae 100644 --- a/src/strawchemy/strawberry/_utils.py +++ b/src/strawchemy/utils/strawberry.py @@ -1,19 +1,25 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any - -from strawberry.types.base import StrawberryContainer, StrawberryType -from strawberry.types.lazy_type import LazyType +from typing import TYPE_CHECKING, Any, Optional, get_args, get_origin + +from strawberry import Info, LazyType +from strawberry.types.base import ( + StrawberryContainer, + StrawberryList, + StrawberryOptional, + StrawberryType, + WithStrawberryObjectDefinition, +) from strawberry.types.union import StrawberryUnion from strawchemy.exceptions import SessionNotFoundError -from strawchemy.strawberry.mutation.types import ErrorType +from strawchemy.schema.interfaces import ErrorType +from strawchemy.typing import UNION_TYPES if TYPE_CHECKING: - from strawberry import Info - + from typing_extensions import TypeIs -__all__ = ("default_session_getter", "dto_model_from_type") +_OPTIONAL_UNION_ARG_SIZE: int = 2 def _get_or_subscribe(obj: Any, key: Any) -> Any: @@ -56,3 +62,18 @@ def strawberry_contained_user_type(type_: StrawberryType | Any) -> Any: inner_type for inner_type in strawberry_contained_types(type_) if inner_type not in ErrorType.__error_types__ ] return inner_types[0] + + +def is_list( + type_: StrawberryType | type[WithStrawberryObjectDefinition] | object | str, +) -> TypeIs[type[list[Any]] | StrawberryList]: + if isinstance(type_, StrawberryOptional): + type_ = type_.of_type + if origin := get_origin(type_): + type_ = origin + if origin is Optional: + type_ = get_args(type_)[0] + if origin in UNION_TYPES and len(args := get_args(type_)) == _OPTIONAL_UNION_ARG_SIZE: + type_ = args[0] if args[0] is not type(None) else args[1] + + return isinstance(type_, StrawberryList) or type_ is list diff --git a/src/strawchemy/utils.py b/src/strawchemy/utils/text.py similarity index 51% rename from src/strawchemy/utils.py rename to src/strawchemy/utils/text.py index a8c097ef..70b46398 100644 --- a/src/strawchemy/utils.py +++ b/src/strawchemy/utils/text.py @@ -1,23 +1,20 @@ from __future__ import annotations -import inspect import re -from typing import TYPE_CHECKING, Any, Optional, Union, get_args, get_origin - -from strawchemy.typing import UNION_TYPES +from typing import TYPE_CHECKING, Any, TypeVar if TYPE_CHECKING: from re import Pattern __all__ = ( "camel_to_snake", - "is_type_hint_optional", - "non_optional_type_hint", "snake_keys", "snake_to_camel", "snake_to_lower_camel_case", ) +T = TypeVar("T", bound="Any") + _camel_to_snake_pattern: Pattern[str] = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)(? dict[str, Any]: else: res[to_snake] = v return res - - -def non_optional_type_hint(type_hint: Any) -> Any: - origin, args = get_origin(type_hint), get_args(type_hint) - if origin is Optional: - return args - if origin in UNION_TYPES: - union_args = tuple([arg for arg in args if arg not in (None, type(None))]) - if len(union_args) == 1: - return union_args[0] - return Union[union_args] - return type_hint - - -def is_type_hint_optional(type_hint: Any) -> bool: - """Whether the given type hint is considered as optional or not. - - Returns: - `True` if arguments of the given type hint are optional - - Three cases are considered: - ``` - Optional[str] - Union[str, None] - str | None - ``` - In any other form, the type hint will not be considered as optional - """ - origin = get_origin(type_hint) - if origin is None: - return False - if origin is Optional: - return True - if origin in UNION_TYPES: - args = get_args(type_hint) - return any(arg is type(None) for arg in args) - return False - - -def get_annotations(obj: Any) -> dict[str, Any]: - """Get the annotations of the given object.""" - return inspect.get_annotations(obj) diff --git a/src/strawchemy/validation/__init__.py b/src/strawchemy/validation/__init__.py index e69de29b..d9a11f82 100644 --- a/src/strawchemy/validation/__init__.py +++ b/src/strawchemy/validation/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from strawchemy.validation.base import InputValidationError, ValidationProtocol + +__all__ = ("InputValidationError", "ValidationProtocol") diff --git a/src/strawchemy/validation/base.py b/src/strawchemy/validation/base.py index 93e41725..818177b5 100644 --- a/src/strawchemy/validation/base.py +++ b/src/strawchemy/validation/base.py @@ -8,14 +8,16 @@ from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar +from strawchemy.exceptions import StrawchemyError + if TYPE_CHECKING: from strawchemy.dto.base import MappedDTO - from strawchemy.strawberry.mutation.types import ValidationErrorType + from strawchemy.schema.mutation import ValidationErrorType T = TypeVar("T") -class InputValidationError(Exception): +class InputValidationError(StrawchemyError): """Exception raised when input validation fails. This exception wraps the original validation error and provides a method to convert diff --git a/src/strawchemy/validation/pydantic.py b/src/strawchemy/validation/pydantic.py index 371aad6d..93426ebd 100644 --- a/src/strawchemy/validation/pydantic.py +++ b/src/strawchemy/validation/pydantic.py @@ -1,21 +1,19 @@ from __future__ import annotations -from collections.abc import Callable from dataclasses import dataclass from functools import partial from typing import TYPE_CHECKING, Any, ClassVar from pydantic import ValidationError -from sqlalchemy.orm import DeclarativeBase from typing_extensions import override from strawchemy.dto.backend.pydantic import MappedPydanticDTO, PydanticDTOBackend from strawchemy.dto.base import ModelT +from strawchemy.dto.strawberry import StrawchemyDTOAttributes from strawchemy.dto.utils import read_partial -from strawchemy.strawberry.dto import StrawchemyDTOAttributes -from strawchemy.strawberry.factories.types import InputFactory -from strawchemy.strawberry.mutation.types import LocalizedErrorType, ValidationErrorType -from strawchemy.utils import snake_to_lower_camel_case +from strawchemy.schema.factories import InputFactory +from strawchemy.schema.mutation import LocalizedErrorType, ValidationErrorType +from strawchemy.utils.text import snake_to_lower_camel_case from strawchemy.validation.base import InputValidationError, T, ValidationProtocol if TYPE_CHECKING: @@ -27,9 +25,9 @@ from strawchemy import Strawchemy from strawchemy.dto.base import DTOFieldDefinition, MappedDTO, Relation from strawchemy.dto.types import DTOConfig, ExcludeFields, IncludeFields, Purpose - from strawchemy.graph import Node - from strawchemy.sqlalchemy.typing import DeclarativeT - from strawchemy.strawberry.typing import GraphQLPurpose + from strawchemy.repository.typing import DeclarativeT + from strawchemy.typing import GraphQLPurpose + from strawchemy.utils.graph import Node @dataclass diff --git a/tests/fixtures.py b/tests/fixtures.py index ee1d4038..e819ba61 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -4,10 +4,10 @@ from typing import TYPE_CHECKING import pytest +import strawberry from syrupy.assertion import SnapshotAssertion from syrupy.extensions.amber import AmberSnapshotExtension -import strawberry from strawchemy import Strawchemy, StrawchemyConfig from tests.syrupy import GraphQLFileExtension from tests.utils import sqlalchemy_pydantic_factory diff --git a/tests/integration/data_types/test_array.py b/tests/integration/data_types/test_array.py index 9a47e216..6b41c016 100644 --- a/tests/integration/data_types/test_array.py +++ b/tests/integration/data_types/test_array.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any import pytest - from sqlalchemy import Insert, MetaData, insert + from tests.integration.fixtures import QueryTracker from tests.integration.models import ArrayModel, array_metadata from tests.integration.types import postgres as postgres_types diff --git a/tests/integration/data_types/test_date_time.py b/tests/integration/data_types/test_date_time.py index 115490b6..2696eaaa 100644 --- a/tests/integration/data_types/test_date_time.py +++ b/tests/integration/data_types/test_date_time.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any import pytest - from sqlalchemy import Insert, MetaData, insert + from tests.integration.fixtures import QueryTracker from tests.integration.models import DateTimeModel, date_time_metadata from tests.integration.types import mysql as mysql_types diff --git a/tests/integration/data_types/test_interval.py b/tests/integration/data_types/test_interval.py index 5c149453..80aadbce 100644 --- a/tests/integration/data_types/test_interval.py +++ b/tests/integration/data_types/test_interval.py @@ -5,8 +5,8 @@ import msgspec import pytest - from sqlalchemy import Insert, MetaData, insert + from tests.integration.fixtures import QueryTracker from tests.integration.models import IntervalModel, interval_metadata from tests.integration.types import mysql as mysql_types diff --git a/tests/integration/data_types/test_json.py b/tests/integration/data_types/test_json.py index f8cdb3e7..83d683f2 100644 --- a/tests/integration/data_types/test_json.py +++ b/tests/integration/data_types/test_json.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any import pytest - from sqlalchemy import Insert, MetaData, insert + from tests.integration.models import JSONModel, json_metadata from tests.integration.types import mysql as mysql_types from tests.integration.types import postgres as postgres_types diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py index c82d9e0c..e1153281 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -13,11 +13,6 @@ import sqlparse from pytest_databases.docker.postgres import _provide_postgres_service from pytest_lazy_fixtures import lf -from sqlalchemy.event import listens_for -from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.orm import Session, sessionmaker -from typing_extensions import Self - from sqlalchemy import ( URL, ClauseElement, @@ -36,10 +31,15 @@ create_engine, insert, ) +from sqlalchemy.event import listens_for +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import Session, sessionmaker from strawberry.scalars import JSON +from typing_extensions import Self + from strawchemy.config.databases import DatabaseFeatures from strawchemy.constants import GEO_INSTALLED -from strawchemy.strawberry.scalars import Date, DateTime, Interval, Time +from strawchemy.schema.scalars import Date, DateTime, Interval, Time from tests.fixtures import DefaultQuery from tests.integration.models import ( Color, @@ -71,7 +71,7 @@ from syrupy.assertion import SnapshotAssertion from strawchemy import Strawchemy, StrawchemyConfig - from strawchemy.sqlalchemy.typing import AnySession + from strawchemy.repository.typing import AnySession from strawchemy.typing import SupportedDialect from tests.integration.typing import RawRecordData @@ -104,7 +104,7 @@ engine_plugins: list[str] = [] if GEO_INSTALLED: - from strawchemy.strawberry.geo import GEO_SCALAR_OVERRIDES + from strawchemy.schema.scalars.geo import GEO_SCALAR_OVERRIDES engine_plugins = ["geoalchemy2"] scalar_overrides |= GEO_SCALAR_OVERRIDES diff --git a/tests/integration/geo/models.py b/tests/integration/geo/models.py index 79b871cc..25907db1 100644 --- a/tests/integration/geo/models.py +++ b/tests/integration/geo/models.py @@ -1,10 +1,10 @@ from __future__ import annotations from geoalchemy2 import Geometry, WKBElement +from sqlalchemy import MetaData from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm import registry as Registry # noqa: N812 -from sqlalchemy import MetaData from tests.integration.models import BaseColumns metadata, geo_metadata = MetaData(), MetaData() diff --git a/tests/integration/geo/test_geo_filters.py b/tests/integration/geo/test_geo_filters.py index cf5724aa..a88e91c6 100644 --- a/tests/integration/geo/test_geo_filters.py +++ b/tests/integration/geo/test_geo_filters.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any from sqlalchemy import Executable, Insert, MetaData, insert, text + from tests.integration.fixtures import QueryTracker from tests.integration.geo.models import GeoModel, geo_metadata from tests.integration.geo.types import mysql as mysql_types diff --git a/tests/integration/geo/types/mysql.py b/tests/integration/geo/types/mysql.py index 2629ca75..b5b8cabd 100644 --- a/tests/integration/geo/types/mysql.py +++ b/tests/integration/geo/types/mysql.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy, StrawchemyAsyncRepository, StrawchemySyncRepository from tests.integration.geo.models import GeoModel diff --git a/tests/integration/geo/types/postgres.py b/tests/integration/geo/types/postgres.py index 164cc061..7faf4e8e 100644 --- a/tests/integration/geo/types/postgres.py +++ b/tests/integration/geo/types/postgres.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy, StrawchemyAsyncRepository, StrawchemySyncRepository from tests.integration.geo.models import GeoModel diff --git a/tests/integration/models.py b/tests/integration/models.py index b83bc73f..2e4ee01f 100644 --- a/tests/integration/models.py +++ b/tests/integration/models.py @@ -5,12 +5,6 @@ from datetime import date, datetime, time, timedelta from typing import Any -from sqlalchemy.dialects import mysql, postgresql, sqlite -from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, column_property, mapped_column, relationship -from sqlalchemy.orm import registry as Registry # noqa: N812 - from sqlalchemy import ( ARRAY, JSON, @@ -29,6 +23,12 @@ Time, UniqueConstraint, ) +from sqlalchemy.dialects import mysql, postgresql, sqlite +from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, column_property, mapped_column, relationship +from sqlalchemy.orm import registry as Registry # noqa: N812 + from strawchemy.dto.utils import PRIVATE, READ_ONLY metadata = MetaData() diff --git a/tests/integration/test_aggregations.py b/tests/integration/test_aggregations.py index 7ecc6937..e34e9259 100644 --- a/tests/integration/test_aggregations.py +++ b/tests/integration/test_aggregations.py @@ -4,7 +4,7 @@ import pytest -from strawchemy.types import DefaultOffsetPagination +from strawchemy.schema.pagination import DefaultOffsetPagination from tests.integration.fixtures import QueryTracker from tests.integration.models import Fruit from tests.integration.typing import RawRecordData diff --git a/tests/integration/test_mutations.py b/tests/integration/test_mutations.py index c5564648..11f70ecc 100644 --- a/tests/integration/test_mutations.py +++ b/tests/integration/test_mutations.py @@ -1776,7 +1776,6 @@ async def test_read_only_column_override(query_name: str, query: str, any_query: ), ], ) -@pytest.mark.filterwarnings("ignore::sqlalchemy.exc.SAWarning") async def test_relationship_to_many_override(query_name: str, query: str, any_query: AnyQueryExecutor) -> None: result = await maybe_async(any_query(query)) assert not result.errors @@ -1820,7 +1819,6 @@ async def test_relationship_to_many_override(query_name: str, query: str, any_qu ), ], ) -@pytest.mark.filterwarnings("ignore::sqlalchemy.exc.SAWarning") async def test_relationship_to_one_override(query_name: str, query: str, any_query: AnyQueryExecutor) -> None: result = await maybe_async(any_query(query)) assert not result.errors diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 420ac30d..93cff09f 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -1,6 +1,7 @@ from __future__ import annotations from strawberry import Schema + from tests.integration.fixtures import scalar_overrides from tests.integration.types import mysql, postgres diff --git a/tests/integration/types/mysql.py b/tests/integration/types/mysql.py index c65dd80d..910e6b98 100644 --- a/tests/integration/types/mysql.py +++ b/tests/integration/types/mysql.py @@ -3,12 +3,12 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Annotated, Any, TypeAlias, cast +import strawberry from pydantic import AfterValidator +from sqlalchemy import Select, select from strawberry.extensions.field_extension import FieldExtension from typing_extensions import override -import strawberry -from sqlalchemy import Select, select from strawchemy import ( Input, InputValidationError, @@ -20,7 +20,7 @@ StrawchemySyncRepository, ValidationErrorType, ) -from strawchemy.types import DefaultOffsetPagination +from strawchemy.schema.pagination import DefaultOffsetPagination from strawchemy.validation.pydantic import PydanticValidation from tests.integration.models import Color, DateTimeModel, Fruit, FruitFarm, IntervalModel, JSONModel, RankedUser, User @@ -31,7 +31,7 @@ from sqlalchemy.orm import Session from sqlalchemy.orm.util import AliasedClass - from strawchemy.sqlalchemy.hook import LoadType + from strawchemy.transpiler.hook import LoadType SyncExtensionResolver: TypeAlias = Callable[..., Any] AsyncExtensionResolver: TypeAlias = Callable[..., Awaitable[Any]] diff --git a/tests/integration/types/postgres.py b/tests/integration/types/postgres.py index 98fdefd5..12c688a7 100644 --- a/tests/integration/types/postgres.py +++ b/tests/integration/types/postgres.py @@ -3,12 +3,12 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Annotated, Any, TypeAlias, cast +import strawberry from pydantic import AfterValidator +from sqlalchemy import Select, select from strawberry.extensions.field_extension import FieldExtension from typing_extensions import override -import strawberry -from sqlalchemy import Select, select from strawchemy import ( Input, InputValidationError, @@ -19,7 +19,7 @@ StrawchemySyncRepository, ValidationErrorType, ) -from strawchemy.types import DefaultOffsetPagination +from strawchemy.schema.pagination import DefaultOffsetPagination from strawchemy.validation.pydantic import PydanticValidation from tests.integration.models import ( ArrayModel, @@ -40,7 +40,7 @@ from sqlalchemy.orm import Session from sqlalchemy.orm.util import AliasedClass - from strawchemy.sqlalchemy.hook import LoadType + from strawchemy.transpiler.hook import LoadType SyncExtensionResolver: TypeAlias = Callable[..., Any] AsyncExtensionResolver: TypeAlias = Callable[..., Awaitable[Any]] diff --git a/tests/integration/types/sqlite.py b/tests/integration/types/sqlite.py index 290faa44..d090a8e5 100644 --- a/tests/integration/types/sqlite.py +++ b/tests/integration/types/sqlite.py @@ -3,12 +3,12 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Annotated, Any, TypeAlias, cast +import strawberry from pydantic import AfterValidator +from sqlalchemy import Select, select from strawberry.extensions.field_extension import FieldExtension from typing_extensions import override -import strawberry -from sqlalchemy import Select, select from strawchemy import ( Input, InputValidationError, @@ -19,7 +19,7 @@ StrawchemySyncRepository, ValidationErrorType, ) -from strawchemy.types import DefaultOffsetPagination +from strawchemy.schema.pagination import DefaultOffsetPagination from strawchemy.validation.pydantic import PydanticValidation from tests.integration.models import Color, DateTimeModel, Fruit, FruitFarm, IntervalModel, JSONModel, RankedUser, User @@ -30,7 +30,7 @@ from sqlalchemy.orm import Session from sqlalchemy.orm.util import AliasedClass - from strawchemy.sqlalchemy.hook import LoadType + from strawchemy.transpiler.hook import LoadType SyncExtensionResolver: TypeAlias = Callable[..., Any] AsyncExtensionResolver: TypeAlias = Callable[..., Awaitable[Any]] diff --git a/tests/integration/utils.py b/tests/integration/utils.py index aa7e4610..d6536439 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -9,8 +9,8 @@ from uuid import UUID from pydantic import TypeAdapter - from sqlalchemy import inspect + from tests.integration.types import postgres as pg_types if TYPE_CHECKING: diff --git a/tests/unit/dc_models.py b/tests/unit/dc_models.py index 4879fa33..fe66deed 100644 --- a/tests/unit/dc_models.py +++ b/tests/unit/dc_models.py @@ -2,10 +2,10 @@ from uuid import UUID, uuid4 +from sqlalchemy import ForeignKey from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, column_property, mapped_column, relationship -from sqlalchemy import ForeignKey from strawchemy.dto import Purpose, PurposeConfig, field from strawchemy.dto.utils import WRITE_ONLY from tests.unit.models import validate_tomato_type diff --git a/tests/unit/mapping/test_schemas.py b/tests/unit/mapping/test_schemas.py index 7ddd5159..ae1d08f6 100644 --- a/tests/unit/mapping/test_schemas.py +++ b/tests/unit/mapping/test_schemas.py @@ -8,18 +8,15 @@ from typing import TYPE_CHECKING, Any import pytest +import strawberry +from strawberry import auto +from strawberry.scalars import JSON from strawberry.types import get_object_definition from strawberry.types.object_type import StrawberryObjectDefinition from syrupy.assertion import SnapshotAssertion -import strawberry -from strawberry import auto -from strawberry.scalars import JSON -from strawchemy.dto.exceptions import EmptyDTOError -from strawchemy.exceptions import StrawchemyError -from strawchemy.sqlalchemy.exceptions import QueryHookError -from strawchemy.strawberry.exceptions import StrawchemyFieldError -from strawchemy.strawberry.scalars import Interval +from strawchemy.exceptions import EmptyDTOError, QueryHookError, StrawchemyError, StrawchemyFieldError +from strawchemy.schema.scalars import Interval from strawchemy.testing.pytest_plugin import MockContext from tests.fixtures import DefaultQuery from tests.unit.models import Book as BookModel @@ -62,9 +59,9 @@ class InputType: id: auto name: auto - user = InputType(id=1, name="user") - assert user.id == 1 - assert user.name == "user" + user = InputType(id=1, name="user") # pyright: ignore[reportCallIssue] + assert user.id == 1 # pyright: ignore[reportAttributeAccessIssue] + assert user.name == "user" # pyright: ignore[reportAttributeAccessIssue] def test_field_metadata_default(strawchemy: Strawchemy) -> None: @@ -220,7 +217,7 @@ def test_query_schemas(path: str, graphql_snapshot: SnapshotAssertion) -> None: @pytest.mark.snapshot @pytest.mark.skipif(not find_spec("geoalchemy2"), reason="geoalchemy2 is not installed") def test_geo_schemas(path: str, graphql_snapshot: SnapshotAssertion) -> None: - from strawchemy.strawberry.geo import GEO_SCALAR_OVERRIDES + from strawchemy.schema.scalars.geo import GEO_SCALAR_OVERRIDES module, query_name = f"tests.unit.schemas.{path}".rsplit(".", maxsplit=1) query_class = getattr(import_module(module), query_name) diff --git a/tests/unit/models.py b/tests/unit/models.py index 61deab78..c2832440 100644 --- a/tests/unit/models.py +++ b/tests/unit/models.py @@ -8,11 +8,11 @@ from typing import Any from uuid import UUID, uuid4 +from sqlalchemy import VARCHAR, Column, DateTime, Enum, ForeignKey, Table, Text, UniqueConstraint from sqlalchemy.dialects import postgresql from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import DeclarativeBase, Mapped, column_property, mapped_column, relationship -from sqlalchemy import VARCHAR, Column, DateTime, Enum, ForeignKey, Table, Text, UniqueConstraint from strawchemy.constants import GEO_INSTALLED from strawchemy.dto.types import Purpose, PurposeConfig from strawchemy.dto.utils import PRIVATE, READ_ONLY, WRITE_ONLY, field diff --git a/tests/unit/schemas/aggregations/root_aggregations.py b/tests/unit/schemas/aggregations/root_aggregations.py index 88f94773..bcbb94b4 100644 --- a/tests/unit/schemas/aggregations/root_aggregations.py +++ b/tests/unit/schemas/aggregations/root_aggregations.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import SQLDataTypes diff --git a/tests/unit/schemas/aggregations/type_mismatch.py b/tests/unit/schemas/aggregations/type_mismatch.py index 7719f452..14e5efb8 100644 --- a/tests/unit/schemas/aggregations/type_mismatch.py +++ b/tests/unit/schemas/aggregations/type_mismatch.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Color diff --git a/tests/unit/schemas/custom_id_field_name.py b/tests/unit/schemas/custom_id_field_name.py index a9795f48..e25e314d 100644 --- a/tests/unit/schemas/custom_id_field_name.py +++ b/tests/unit/schemas/custom_id_field_name.py @@ -4,6 +4,7 @@ import strawberry from strawberry import Info, auto + from strawchemy import Strawchemy, StrawchemyConfig from tests.unit.models import Color, Fruit diff --git a/tests/unit/schemas/distinct.py b/tests/unit/schemas/distinct.py index f9a0876f..03d91171 100644 --- a/tests/unit/schemas/distinct.py +++ b/tests/unit/schemas/distinct.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Color diff --git a/tests/unit/schemas/enums.py b/tests/unit/schemas/enums.py index 1f200861..dad5db5e 100644 --- a/tests/unit/schemas/enums.py +++ b/tests/unit/schemas/enums.py @@ -2,6 +2,7 @@ import strawberry from strawberry import auto + from strawchemy import Strawchemy from tests.unit.models import Vegetable diff --git a/tests/unit/schemas/exclude/exclude_and_override_field.py b/tests/unit/schemas/exclude/exclude_and_override_field.py index e9ce66bf..07b729ad 100644 --- a/tests/unit/schemas/exclude/exclude_and_override_field.py +++ b/tests/unit/schemas/exclude/exclude_and_override_field.py @@ -2,6 +2,7 @@ import strawberry from strawberry import auto + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/exclude/exclude_and_override_type.py b/tests/unit/schemas/exclude/exclude_and_override_type.py index c4f22e06..3b59c0a0 100644 --- a/tests/unit/schemas/exclude/exclude_and_override_type.py +++ b/tests/unit/schemas/exclude/exclude_and_override_type.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/exclude/exclude_explicit.py b/tests/unit/schemas/exclude/exclude_explicit.py index c726fce3..8976c91a 100644 --- a/tests/unit/schemas/exclude/exclude_explicit.py +++ b/tests/unit/schemas/exclude/exclude_explicit.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/exclude/exclude_non_existent.py b/tests/unit/schemas/exclude/exclude_non_existent.py index 64b8fc6e..2780c395 100644 --- a/tests/unit/schemas/exclude/exclude_non_existent.py +++ b/tests/unit/schemas/exclude/exclude_non_existent.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/filters/filters.py b/tests/unit/schemas/filters/filters.py index f386126c..e3876c07 100644 --- a/tests/unit/schemas/filters/filters.py +++ b/tests/unit/schemas/filters/filters.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import SQLDataTypes diff --git a/tests/unit/schemas/filters/filters_aggregation.py b/tests/unit/schemas/filters/filters_aggregation.py index 681447fa..22d3e6de 100644 --- a/tests/unit/schemas/filters/filters_aggregation.py +++ b/tests/unit/schemas/filters/filters_aggregation.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Group diff --git a/tests/unit/schemas/filters/filters_base_array.py b/tests/unit/schemas/filters/filters_base_array.py index c3c3bc86..b5e6fb79 100644 --- a/tests/unit/schemas/filters/filters_base_array.py +++ b/tests/unit/schemas/filters/filters_base_array.py @@ -1,9 +1,9 @@ from __future__ import annotations -from sqlalchemy.orm import Mapped, mapped_column - import strawberry from sqlalchemy import ARRAY, Text +from sqlalchemy.orm import Mapped, mapped_column + from strawchemy import Strawchemy from tests.unit.models import UUIDBase diff --git a/tests/unit/schemas/filters/filters_base_json.py b/tests/unit/schemas/filters/filters_base_json.py index a8c8a863..ea6f5c6a 100644 --- a/tests/unit/schemas/filters/filters_base_json.py +++ b/tests/unit/schemas/filters/filters_base_json.py @@ -2,10 +2,10 @@ from typing import Any -from sqlalchemy.orm import Mapped, mapped_column - import strawberry from sqlalchemy import JSON +from sqlalchemy.orm import Mapped, mapped_column + from strawchemy import Strawchemy from tests.unit.models import UUIDBase diff --git a/tests/unit/schemas/filters/type_filter.py b/tests/unit/schemas/filters/type_filter.py index 3f6d1a2a..a34e8592 100644 --- a/tests/unit/schemas/filters/type_filter.py +++ b/tests/unit/schemas/filters/type_filter.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import SQLDataTypes diff --git a/tests/unit/schemas/geo/geo.py b/tests/unit/schemas/geo/geo.py index b376a715..5ff25bc9 100644 --- a/tests/unit/schemas/geo/geo.py +++ b/tests/unit/schemas/geo/geo.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import GeoModel diff --git a/tests/unit/schemas/geo/geo_filters.py b/tests/unit/schemas/geo/geo_filters.py index 1d572215..e898336f 100644 --- a/tests/unit/schemas/geo/geo_filters.py +++ b/tests/unit/schemas/geo/geo_filters.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import GeoModel diff --git a/tests/unit/schemas/include/all_fields.py b/tests/unit/schemas/include/all_fields.py index e3c6b9a5..b46323bf 100644 --- a/tests/unit/schemas/include/all_fields.py +++ b/tests/unit/schemas/include/all_fields.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/include/all_fields_filter.py b/tests/unit/schemas/include/all_fields_filter.py index 992e4336..1d97ed80 100644 --- a/tests/unit/schemas/include/all_fields_filter.py +++ b/tests/unit/schemas/include/all_fields_filter.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/include/all_fields_override.py b/tests/unit/schemas/include/all_fields_override.py index 85da8a59..89cc64b7 100644 --- a/tests/unit/schemas/include/all_fields_override.py +++ b/tests/unit/schemas/include/all_fields_override.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Color, Fruit diff --git a/tests/unit/schemas/include/all_order_by.py b/tests/unit/schemas/include/all_order_by.py index db8ebefa..85fa06b0 100644 --- a/tests/unit/schemas/include/all_order_by.py +++ b/tests/unit/schemas/include/all_order_by.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/include/include_explicit.py b/tests/unit/schemas/include/include_explicit.py index 8df228cc..fb588cc9 100644 --- a/tests/unit/schemas/include/include_explicit.py +++ b/tests/unit/schemas/include/include_explicit.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/include/include_non_existent.py b/tests/unit/schemas/include/include_non_existent.py index 425eccbc..bae5b6e3 100644 --- a/tests/unit/schemas/include/include_non_existent.py +++ b/tests/unit/schemas/include/include_non_existent.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/mutation_and_query.py b/tests/unit/schemas/mutation_and_query.py index 911f4669..157ab28e 100644 --- a/tests/unit/schemas/mutation_and_query.py +++ b/tests/unit/schemas/mutation_and_query.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Color, Fruit diff --git a/tests/unit/schemas/mutations/create.py b/tests/unit/schemas/mutations/create.py index f65a3e5f..000058b9 100644 --- a/tests/unit/schemas/mutations/create.py +++ b/tests/unit/schemas/mutations/create.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Group, SQLDataTypes diff --git a/tests/unit/schemas/mutations/create_no_id.py b/tests/unit/schemas/mutations/create_no_id.py index d44535b5..3ce54821 100644 --- a/tests/unit/schemas/mutations/create_no_id.py +++ b/tests/unit/schemas/mutations/create_no_id.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Group diff --git a/tests/unit/schemas/mutations/delete.py b/tests/unit/schemas/mutations/delete.py index 37fe6488..f4b0e930 100644 --- a/tests/unit/schemas/mutations/delete.py +++ b/tests/unit/schemas/mutations/delete.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Group diff --git a/tests/unit/schemas/mutations/delete_mutation_type_not_list.py b/tests/unit/schemas/mutations/delete_mutation_type_not_list.py index d7e43282..a5c8f3bf 100644 --- a/tests/unit/schemas/mutations/delete_mutation_type_not_list.py +++ b/tests/unit/schemas/mutations/delete_mutation_type_not_list.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Group diff --git a/tests/unit/schemas/mutations/invalid_filter_update_field.py b/tests/unit/schemas/mutations/invalid_filter_update_field.py index 9c97247a..729e214e 100644 --- a/tests/unit/schemas/mutations/invalid_filter_update_field.py +++ b/tests/unit/schemas/mutations/invalid_filter_update_field.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Group diff --git a/tests/unit/schemas/mutations/read_only_pk_with_update_input.py b/tests/unit/schemas/mutations/read_only_pk_with_update_input.py index 740ea1c0..ee782294 100644 --- a/tests/unit/schemas/mutations/read_only_pk_with_update_input.py +++ b/tests/unit/schemas/mutations/read_only_pk_with_update_input.py @@ -2,9 +2,9 @@ from uuid import UUID, uuid4 +from sqlalchemy import ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship -from sqlalchemy import ForeignKey from strawchemy import Strawchemy from strawchemy.dto.utils import READ_ONLY from tests.unit.models import UUIDBase diff --git a/tests/unit/schemas/mutations/update.py b/tests/unit/schemas/mutations/update.py index fbef1ff3..972c5991 100644 --- a/tests/unit/schemas/mutations/update.py +++ b/tests/unit/schemas/mutations/update.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Group, SQLDataTypes, Tag diff --git a/tests/unit/schemas/mutations/upsert.py b/tests/unit/schemas/mutations/upsert.py index 9281a574..b0b3e6e9 100644 --- a/tests/unit/schemas/mutations/upsert.py +++ b/tests/unit/schemas/mutations/upsert.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/order/auto_order_by.py b/tests/unit/schemas/order/auto_order_by.py index a178a232..a8a96e21 100644 --- a/tests/unit/schemas/order/auto_order_by.py +++ b/tests/unit/schemas/order/auto_order_by.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Group diff --git a/tests/unit/schemas/order/field_order_by.py b/tests/unit/schemas/order/field_order_by.py index 300c0ab0..4cb8bb1f 100644 --- a/tests/unit/schemas/order/field_order_by.py +++ b/tests/unit/schemas/order/field_order_by.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import SQLDataTypes diff --git a/tests/unit/schemas/order/type_order_by.py b/tests/unit/schemas/order/type_order_by.py index 02ffacdc..022c8828 100644 --- a/tests/unit/schemas/order/type_order_by.py +++ b/tests/unit/schemas/order/type_order_by.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import SQLDataTypes diff --git a/tests/unit/schemas/override/auto_type_existing.py b/tests/unit/schemas/override/auto_type_existing.py index ec935b74..f9e3d0e3 100644 --- a/tests/unit/schemas/override/auto_type_existing.py +++ b/tests/unit/schemas/override/auto_type_existing.py @@ -4,6 +4,7 @@ import strawberry from strawberry import Info, auto + from strawchemy import Strawchemy from tests.unit.models import Color, Fruit diff --git a/tests/unit/schemas/override/nested_overrides.py b/tests/unit/schemas/override/nested_overrides.py index 771b40d4..63901f3a 100644 --- a/tests/unit/schemas/override/nested_overrides.py +++ b/tests/unit/schemas/override/nested_overrides.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Group, Tag, User diff --git a/tests/unit/schemas/override/override_argument.py b/tests/unit/schemas/override/override_argument.py index 836c9496..a41b6520 100644 --- a/tests/unit/schemas/override/override_argument.py +++ b/tests/unit/schemas/override/override_argument.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/override/override_auto_type.py b/tests/unit/schemas/override/override_auto_type.py index 68928b36..8cf80710 100644 --- a/tests/unit/schemas/override/override_auto_type.py +++ b/tests/unit/schemas/override/override_auto_type.py @@ -2,6 +2,7 @@ import strawberry from strawberry import auto + from strawchemy import Strawchemy from tests.unit.models import Color, Fruit diff --git a/tests/unit/schemas/override/override_with_custom_name.py b/tests/unit/schemas/override/override_with_custom_name.py index cd1ac5ea..acc8f4aa 100644 --- a/tests/unit/schemas/override/override_with_custom_name.py +++ b/tests/unit/schemas/override/override_with_custom_name.py @@ -2,6 +2,7 @@ import strawberry from strawberry import auto + from strawchemy import Strawchemy from tests.unit.models import Color, Fruit diff --git a/tests/unit/schemas/pagination/children_pagination.py b/tests/unit/schemas/pagination/children_pagination.py index 46b2d9b0..2bcd129d 100644 --- a/tests/unit/schemas/pagination/children_pagination.py +++ b/tests/unit/schemas/pagination/children_pagination.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/pagination/children_pagination_defaults.py b/tests/unit/schemas/pagination/children_pagination_defaults.py index 04fe63b1..9c81a1c5 100644 --- a/tests/unit/schemas/pagination/children_pagination_defaults.py +++ b/tests/unit/schemas/pagination/children_pagination_defaults.py @@ -1,8 +1,9 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy -from strawchemy.types import DefaultOffsetPagination +from strawchemy.schema.pagination import DefaultOffsetPagination from tests.unit.models import Fruit strawchemy = Strawchemy("postgresql") diff --git a/tests/unit/schemas/pagination/pagination.py b/tests/unit/schemas/pagination/pagination.py index f07069ec..ff385a3a 100644 --- a/tests/unit/schemas/pagination/pagination.py +++ b/tests/unit/schemas/pagination/pagination.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/pagination/pagination_config_default.py b/tests/unit/schemas/pagination/pagination_config_default.py index 55b72fee..46134b59 100644 --- a/tests/unit/schemas/pagination/pagination_config_default.py +++ b/tests/unit/schemas/pagination/pagination_config_default.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy, StrawchemyConfig from tests.unit.models import Fruit diff --git a/tests/unit/schemas/pagination/pagination_default_limit.py b/tests/unit/schemas/pagination/pagination_default_limit.py index 27ddad1f..d815418c 100644 --- a/tests/unit/schemas/pagination/pagination_default_limit.py +++ b/tests/unit/schemas/pagination/pagination_default_limit.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy, StrawchemyConfig from tests.unit.models import Fruit diff --git a/tests/unit/schemas/pagination/pagination_defaults.py b/tests/unit/schemas/pagination/pagination_defaults.py index 0deac33f..a4f0cef4 100644 --- a/tests/unit/schemas/pagination/pagination_defaults.py +++ b/tests/unit/schemas/pagination/pagination_defaults.py @@ -1,8 +1,9 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy -from strawchemy.types import DefaultOffsetPagination +from strawchemy.schema.pagination import DefaultOffsetPagination from tests.unit.models import Fruit strawchemy = Strawchemy("postgresql") diff --git a/tests/unit/schemas/pydantic/validation.py b/tests/unit/schemas/pydantic/validation.py index c8b39e18..c44ea719 100644 --- a/tests/unit/schemas/pydantic/validation.py +++ b/tests/unit/schemas/pydantic/validation.py @@ -2,9 +2,9 @@ from typing import Annotated +import strawberry from pydantic import AfterValidator -import strawberry from strawchemy import Input, InputValidationError, Strawchemy, StrawchemySyncRepository, ValidationErrorType from strawchemy.validation.pydantic import PydanticValidation from tests.unit.models import Group, User diff --git a/tests/unit/schemas/query_hooks.py b/tests/unit/schemas/query_hooks.py index c3a60734..6df10466 100644 --- a/tests/unit/schemas/query_hooks.py +++ b/tests/unit/schemas/query_hooks.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import QueryHook, Strawchemy from tests.unit.models import Fruit diff --git a/tests/unit/schemas/resolver/custom_resolver.py b/tests/unit/schemas/resolver/custom_resolver.py index 6a3a023a..c26329d2 100644 --- a/tests/unit/schemas/resolver/custom_resolver.py +++ b/tests/unit/schemas/resolver/custom_resolver.py @@ -4,6 +4,7 @@ import strawberry from strawberry import Info, auto + from strawchemy import Strawchemy from tests.unit.models import Color, Fruit diff --git a/tests/unit/schemas/resolver/list_resolver.py b/tests/unit/schemas/resolver/list_resolver.py index 40850427..512e3ab9 100644 --- a/tests/unit/schemas/resolver/list_resolver.py +++ b/tests/unit/schemas/resolver/list_resolver.py @@ -4,6 +4,7 @@ import strawberry from strawberry import Info, auto + from strawchemy import Strawchemy from tests.unit.models import Color, Fruit diff --git a/tests/unit/schemas/resolver/primary_key_resolver.py b/tests/unit/schemas/resolver/primary_key_resolver.py index b5def013..861c6d60 100644 --- a/tests/unit/schemas/resolver/primary_key_resolver.py +++ b/tests/unit/schemas/resolver/primary_key_resolver.py @@ -4,6 +4,7 @@ import strawberry from strawberry import Info, auto + from strawchemy import Strawchemy from tests.unit.models import Color, Fruit diff --git a/tests/unit/schemas/scope/schema_after.py b/tests/unit/schemas/scope/schema_after.py index c4c7ca46..af043d18 100644 --- a/tests/unit/schemas/scope/schema_after.py +++ b/tests/unit/schemas/scope/schema_after.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Group, Tag, User diff --git a/tests/unit/schemas/scope/schema_before.py b/tests/unit/schemas/scope/schema_before.py index db22a9f8..229ac263 100644 --- a/tests/unit/schemas/scope/schema_before.py +++ b/tests/unit/schemas/scope/schema_before.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Group, Tag, User diff --git a/tests/unit/schemas/scope/schema_in_the_middle.py b/tests/unit/schemas/scope/schema_in_the_middle.py index ba8b17c0..c4dde334 100644 --- a/tests/unit/schemas/scope/schema_in_the_middle.py +++ b/tests/unit/schemas/scope/schema_in_the_middle.py @@ -1,6 +1,7 @@ from __future__ import annotations import strawberry + from strawchemy import Strawchemy from tests.unit.models import Group, Tag, User diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py new file mode 100644 index 00000000..4879d157 --- /dev/null +++ b/tests/unit/test_metadata.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from importlib.metadata import metadata, version + +from strawchemy.__metadata__ import __project__, __version__ + + +def test_version() -> None: + assert version("strawchemy") == __version__ + + +def test_project() -> None: + assert metadata("strawchemy")["Name"] == __project__ diff --git a/tests/unit/test_mutation_input.py b/tests/unit/test_mutation_input.py index b59eff52..5fe2ebf9 100644 --- a/tests/unit/test_mutation_input.py +++ b/tests/unit/test_mutation_input.py @@ -5,7 +5,7 @@ import pytest from strawchemy import Strawchemy -from strawchemy.strawberry.mutation.input import Input +from strawchemy.schema.mutation import Input from tests.unit.dc_models import ColorDataclass, FruitDataclass from tests.unit.models import Color, Fruit @@ -19,9 +19,9 @@ def test_add_non_input_relationships( @strawchemy.create_input(color_model, include="all") class ColorInput: ... - color = ColorInput(name="Blue") + color = ColorInput(name="Blue") # pyright: ignore[reportCallIssue] color_input = Input(color) assert len(color_input.relations) == 0 - color_input.instances[0].fruits.append(fruit_model(name="Apple", color_id=uuid4(), sweetness=1, color=None)) + color_input.instances[0].fruits.append(fruit_model(name="Apple", color_id=uuid4(), sweetness=1, color=None)) # pyright: ignore[reportArgumentType] color_input.add_non_input_relations() assert len(color_input.relations) == 1 diff --git a/tests/unit/test_pytest_plugin.py b/tests/unit/test_pytest_plugin.py index 6061cbf9..6e0ff0d4 100644 --- a/tests/unit/test_pytest_plugin.py +++ b/tests/unit/test_pytest_plugin.py @@ -36,7 +36,7 @@ def test_patch_query_fixture(query: str, pytester: pytest.Pytester) -> None: from strawberry.scalars import JSON from typing import Any from datetime import timedelta - from strawchemy.strawberry.scalars import Interval + from strawchemy.schema.scalars import Interval SCALAR_OVERRIDES: dict[object, Any] = {{dict[str, Any]: JSON, timedelta: Interval}} pytest_plugins = ["strawchemy.testing.pytest_plugin", "pytest_asyncio"] diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 0e7452ac..0833ff12 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -5,7 +5,7 @@ import pytest from strawchemy.exceptions import SessionNotFoundError -from strawchemy.strawberry import default_session_getter +from strawchemy.utils.strawberry import default_session_getter @pytest.mark.parametrize( diff --git a/tests/utils.py b/tests/utils.py index 004f6d98..ccb69798 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,13 +7,13 @@ from importlib.util import find_spec from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast, overload +import strawberry from sqlalchemy.ext.asyncio import AsyncSession from typing_extensions import TypeIs, override -import strawberry from strawchemy.dto.base import DTOFactory -from strawchemy.sqlalchemy.inspector import SQLAlchemyInspector -from strawchemy.utils import get_annotations +from strawchemy.dto.inspectors import SQLAlchemyInspector +from strawchemy.utils.annotation import get_annotations from tests.typing import AnyFactory, MappedPydanticFactory if TYPE_CHECKING: @@ -23,7 +23,7 @@ from sqlalchemy.orm import Session from strawberry.types.execution import ExecutionResult - from strawchemy.sqlalchemy.typing import AnySession + from strawchemy.repository.typing import AnySession from strawchemy.typing import DataclassProtocol from tests.typing import AnyQueryExecutor, AsyncQueryExecutor, SyncQueryExecutor