diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d31bf9b..061b9c9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -6,11 +6,11 @@ env: on: push: - branches: [main, ci/*, dependabot/*, renovate/*] + branches: [ main, ci/*, dependabot/*, renovate/* ] tags: - "v*.*.*" pull_request: - branches: [main] + branches: [ main ] concurrency: group: ${{ github.head_ref || github.run_id }} @@ -57,7 +57,7 @@ jobs: tests: name: 🔬 ${{ matrix.session.job_name }} - needs: [pre, generate-jobs-tests] + needs: [ pre, generate-jobs-tests ] if: github.ref_type == 'tag' || needs.pre.outputs.should_skip != 'true' runs-on: ubuntu-latest strategy: @@ -121,7 +121,7 @@ jobs: upload-coverage: name: 🆙 Upload Coverage - needs: [tests, generate-jobs-tests] + needs: [ tests, generate-jobs-tests ] runs-on: ubuntu-latest steps: @@ -155,7 +155,7 @@ jobs: upload-test-results: name: 📊 Upload test results - needs: [tests, generate-jobs-tests] + needs: [ tests, generate-jobs-tests ] runs-on: ubuntu-latest strategy: matrix: @@ -196,6 +196,9 @@ jobs: cache: true log_level: debug + - name: Install dependencies + run: mise run uv:install + - name: Run linting run: mise run lint diff --git a/.gitignore b/.gitignore index 4b27c3b..5a127ea 100644 --- a/.gitignore +++ b/.gitignore @@ -192,3 +192,5 @@ mise.local.toml .windsurf CLAUDE.md + +.serena diff --git a/mise.toml b/mise.toml index cfab27b..cc0d28d 100644 --- a/mise.toml +++ b/mise.toml @@ -41,20 +41,15 @@ cleanable_paths = '''{{ # Dependencies # ############### -[tasks._install] +[tasks."uv:install"] description = "Install dependencies" -hide = true run = "uv sync --all-extras --dev" -[tasks."_install:geo"] -description = "Install dependencies with geo extras" -hide = true -run = "uv sync --extra=geo --dev" [tasks.install] description = "Install dependencies and pre-commit hooks" alias = "i" -depends = ["install:pre-commit", "_install"] +depends = ["install:pre-commit", "uv:install"] [tasks."install:pre-commit"] description = "Install pre-commit hooks" @@ -66,86 +61,86 @@ run = "pre-commit install --install-hooks 2>&1" [tasks.test] description = "Run tests" -depends = "_install" +depends = "uv:install" alias = "t" usage = 'arg "" default=""' run = 'uv run pytest {{vars.local_pytest_options}} {{arg(name="test", var=true, default="")}}' [tasks."test:coverage"] description = "Run tests with coverage" -depends = "_install" +depends = "uv:install" alias = "tc" usage = 'arg "" default=""' run = 'uv run pytest {{vars.local_pytest_options}} {{vars.pytest_coverage_options}} {{arg(name="test", var=true, default="")}}' [tasks."test:unit"] description = "Run unit tests" -depends = "_install" +depends = "uv:install" alias = "tu" usage = 'arg "" default=""' run = 'uv run nox -r -P {{option(name="python", default="3.13")}} -t unit -- tests/unit {{vars.local_pytest_options}} {{arg(name="test", var=true, default="")}}' [tasks."test:unit:no-extras"] description = "Run unit tests without extras dependencies" -depends = "_install" +depends = "uv:install" alias = "tug" run = 'uv run nox -r -P {{option(name="python", default="3.13")}} -s unit-no-extras -- tests/unit {{vars.local_pytest_options}} {{arg(name="test", var=true, default="")}}' [tasks."test:unit:coverage"] description = "Run unit tests with coverage" -depends = "_install" +depends = "uv:install" alias = "tuc" usage = 'arg "" default=""' run = 'uv run nox -r -P {{option(name="python", default="3.13")}} -t unit -- tests/unit {{vars.local_pytest_options}} {{vars.pytest_coverage_options}} {{arg(name="test", var=true, default="")}}' [tasks."test:integration"] description = "Run integration tests" -depends = "_install" +depends = "uv:install" alias = "ti" run = 'uv run nox -r -P {{option(name="python", default="3.13")}} -s integration -- {{vars.local_pytest_options}} {{arg(name="test", var=true, default="")}}' [tasks."test:integration-postgres"] description = "Run integration tests" -depends = "_install" +depends = "uv:install" alias = "ti-postgres" run = 'uv run nox -r -P {{option(name="python", default="3.13")}} -t postgres -- {{vars.local_pytest_options}} {{arg(name="test", var=true, default="")}}' [tasks."test:integration-mysql"] description = "Run integration tests" -depends = "_install" +depends = "uv:install" alias = "ti-mysql" run = 'uv run nox -r -P {{option(name="python", default="3.13")}} -t mysql -- {{vars.local_pytest_options}} {{arg(name="test", var=true, default="")}}' [tasks."test:integration-sqlite"] description = "Run integration tests" -depends = "_install" +depends = "uv:install" alias = "ti-sqlite" run = 'uv run nox -r -P {{option(name="python", default="3.13")}} -t sqlite -- {{vars.local_pytest_options}} {{arg(name="test", var=true, default="")}}' [tasks."test:integration:coverage"] description = "Run integration tests with coverage" -depends = "_install" +depends = "uv:install" alias = "tic" usage = 'arg "" default=""' run = 'uv run nox -r -P {{option(name="python", default="3.13")}} -s integration -- {{vars.local_pytest_options}} {{vars.pytest_coverage_options}} {{arg(name="test", var=true, default="")}}' [tasks."test:unit-all"] description = "Run unit tests on all supported python versions" -depends = "_install" +depends = "uv:install" alias = "tua" usage = 'arg "" default=""' run = 'uv run nox -r -P {{option(name="python", default="3.13")}} -t unit -- {{vars.local_pytest_options}} {{arg(name="test", var=true, default="")}}' [tasks."test:integration-all"] description = "Run integration tests on all supported python versions" -depends = "_install" +depends = "uv:install" alias = "tia" usage = 'arg "" default=""' run = 'uv run nox -r -P {{option(name="python", default="3.13")}} -s integration -- {{vars.local_pytest_options}} {{arg(name="test", var=true, default="")}}' [tasks."test:update-snapshots"] description = "Run snapshot-based tests and update snapshots" -depends = "_install" +depends = "uv:install" run = "uv run pytest {{vars.local_pytest_options}} -m snapshot --snapshot-update" # ############### @@ -194,7 +189,6 @@ run = "uv run nox --json -t tests -l | jq 'map(.name) | unique'" [tasks."ruff:check"] description = "Check ruff formatting" -depends = "_install" run = "uv run ruff check" [tasks."ruff:fix"] @@ -211,7 +205,6 @@ run = "uv run ruff format --check" [tasks.pyright] description = "Run basedpyright" -depends = "_install" run = "uv run basedpyright" [tasks.vulture] @@ -221,7 +214,7 @@ run = "uv run --only-group lint vulture" [tasks.lint] description = "Lint the code" alias = "l" -depends = ["vulture", "pyright", "ruff:check", "ruff:format:check"] +depends = ["vulture", "pyright", "ruff:check", "ruff:format:check", "slotscheck"] [tasks."lint:pre-commit"] description = "Lint the code in pre-commit hook" @@ -232,6 +225,10 @@ description = "Run pre-commit checks" depends = "install:pre-commit" run = "pre-commit run --color=always --all-files" +[tasks.slotscheck] +description = "Run slotscheck" +run = "uv run slotscheck src/strawchemy" + # ############### # Tools # ############### @@ -239,7 +236,7 @@ run = "pre-commit run --color=always --all-files" [tasks.auto-bump] description = "Auto bump the version" confirm = "Are you sure you want to auto bump the version?" -depends = "_install" +depends = "uv:install" run = "uv run bump-my-version bump --new-version $(uv run git cliff --unreleased --bumped-version)" [tasks.clean] @@ -247,14 +244,14 @@ description = "Clean working directory" alias = "c" confirm = "Are you sure you want to clean the working directory? This will remove test caches, build artifacts, and other temporary files." run = [ - "rm -rf {{vars.cleanable_paths}} >/dev/null 2>&1", - "find . -name '*.egg-info' -exec rm -rf {} + >/dev/null 2>&1", - "find . -type f -name '*.egg' -exec rm -f {} + >/dev/null 2>&1", - "find . -name '*.pyc' -exec rm -f {} + >/dev/null 2>&1", - "find . -name '*.pyo' -exec rm -f {} + >/dev/null 2>&1", - "find . -name '*~' -exec rm -f {} + >/dev/null 2>&1", - "find . -name '__pycache__' -exec rm -rf {} + >/dev/null 2>&1", - "find . -name '.ipynb_checkpoints' -exec rm -rf {} + >/dev/null 2>&1", + "rm -rf {{vars.cleanable_paths}} >/dev/null 2>&1", + "find . -name '*.egg-info' -exec rm -rf {} + >/dev/null 2>&1", + "find . -type f -name '*.egg' -exec rm -f {} + >/dev/null 2>&1", + "find . -name '*.pyc' -exec rm -f {} + >/dev/null 2>&1", + "find . -name '*.pyo' -exec rm -f {} + >/dev/null 2>&1", + "find . -name '*~' -exec rm -f {} + >/dev/null 2>&1", + "find . -name '__pycache__' -exec rm -rf {} + >/dev/null 2>&1", + "find . -name '.ipynb_checkpoints' -exec rm -rf {} + >/dev/null 2>&1", ] [tasks."render:usage"] diff --git a/pyproject.toml b/pyproject.toml index 3be715e..8a1773f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,7 @@ dev = [ ] codeflash = ["codeflash"] doc = ["git-cliff>=2.6.1"] -lint = ["basedpyright", "ruff", "vulture"] +lint = ["basedpyright", "ruff", "vulture", "slotscheck>=0.16.5"] mysql = ["asyncmy", "cryptography"] postgres = ["asyncpg>=0.29.0", "psycopg[binary,pool]>=3.2.3"] test = [ @@ -309,7 +309,6 @@ asyncio_default_fixture_loop_scope = "function" [tool.ruff] line-length = 120 -fix = true target-version = "py310" exclude = [ ".bzr", @@ -378,6 +377,24 @@ known-first-party = ["strawchemy", "tests"] [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" +[tool.slotscheck] +exclude-modules = ''' +( + (^|\.)test_ + |^tests\.* + |^tools\.* + |^docs\.* + |^examples\.* + |^sqlalchemy\.( + testing + |ext\.mypy # see slotscheck/issues/178 + ) +) +''' +include-modules = "strawchemy.*" +require-superclass = false +strict-imports = true + [tool.unasyncd] add_editors_note = true ruff_fix = true diff --git a/src/strawchemy/factories.py b/src/strawchemy/factories.py new file mode 100644 index 0000000..1d23fa3 --- /dev/null +++ b/src/strawchemy/factories.py @@ -0,0 +1,135 @@ +"""Factory container for organizing Strawchemy DTO factories.""" + +from __future__ import annotations + +from dataclasses import dataclass +from functools import partial +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from strawchemy.mapper import Strawchemy + from strawchemy.strawberry.factories.aggregations import EnumDTOFactory + from strawchemy.strawberry.factories.inputs import AggregateFilterDTOFactory, BooleanFilterDTOFactory + from strawchemy.strawberry.factories.types import ( + DistinctOnFieldsDTOFactory, + InputFactory, + OrderByDTOFactory, + RootAggregateTypeDTOFactory, + TypeDTOFactory, + UpsertConflictFieldsDTOFactory, + ) + + +@dataclass +class StrawchemyFactories: + """Container for all Strawchemy DTO factories. + + This class encapsulates the initialization and management of all factory + instances used by Strawchemy, providing a cleaner separation of concerns + and easier testing. + + Attributes: + aggregate_filter: Factory for aggregate filter DTOs. + order_by: Factory for order by DTOs. + distinct_on_enum: Factory for distinct on enum DTOs. + type_factory: Factory for output type DTOs. + input_factory: Factory for input type DTOs. + aggregation: Factory for root aggregate type DTOs. + enum_factory: Factory for enum DTOs. + filter_factory: Factory for boolean filter DTOs. + upsert_conflict: Factory for upsert conflict fields DTOs. + """ + + aggregate_filter: AggregateFilterDTOFactory + order_by: OrderByDTOFactory + distinct_on_enum: DistinctOnFieldsDTOFactory + type_factory: TypeDTOFactory # type: ignore[type-arg] + input_factory: InputFactory # type: ignore[type-arg] + aggregation: RootAggregateTypeDTOFactory # type: ignore[type-arg] + enum_factory: EnumDTOFactory + filter_factory: BooleanFilterDTOFactory + upsert_conflict: UpsertConflictFieldsDTOFactory + + @classmethod + def create(cls, mapper: Strawchemy) -> StrawchemyFactories: + """Create all factories with proper dependencies. + + Args: + mapper: The Strawchemy instance that will own these factories. + + Returns: + A StrawchemyFactories instance with all factories initialized. + """ + # Imports inside method to avoid circular dependencies at module load time + from strawchemy.dto.backend.strawberry import StrawberrryDTOBackend # noqa: PLC0415 + from strawchemy.strawberry.dto import MappedStrawberryGraphQLDTO # noqa: PLC0415 + from strawchemy.strawberry.factories.aggregations import EnumDTOFactory # noqa: PLC0415 + from strawchemy.strawberry.factories.enum import ( # noqa: PLC0415 + EnumDTOBackend, + UpsertConflictFieldsEnumDTOBackend, + ) + from strawchemy.strawberry.factories.inputs import ( # noqa: PLC0415 + AggregateFilterDTOFactory, + BooleanFilterDTOFactory, + ) + from strawchemy.strawberry.factories.types import ( # noqa: PLC0415 + DistinctOnFieldsDTOFactory, + InputFactory, + OrderByDTOFactory, + RootAggregateTypeDTOFactory, + TypeDTOFactory, + UpsertConflictFieldsDTOFactory, + ) + + config = mapper.config + + # Create backend instances + strawberry_backend = StrawberrryDTOBackend(MappedStrawberryGraphQLDTO) + enum_backend = EnumDTOBackend(config.auto_snake_case) + upsert_conflict_fields_enum_backend = UpsertConflictFieldsEnumDTOBackend( + config.inspector, config.auto_snake_case + ) + + # Create factory instances + aggregate_filter = AggregateFilterDTOFactory(mapper) + order_by = OrderByDTOFactory(mapper) + distinct_on_enum = DistinctOnFieldsDTOFactory(config.inspector) + type_factory = TypeDTOFactory(mapper, strawberry_backend, order_by_factory=order_by) + input_factory = InputFactory(mapper, strawberry_backend) + aggregation = RootAggregateTypeDTOFactory(mapper, strawberry_backend, type_factory=type_factory) + enum_factory = EnumDTOFactory(config.inspector, enum_backend) + filter_factory = BooleanFilterDTOFactory(mapper, aggregate_filter_factory=aggregate_filter) + upsert_conflict = UpsertConflictFieldsDTOFactory(config.inspector, upsert_conflict_fields_enum_backend) + + return cls( + aggregate_filter=aggregate_filter, + order_by=order_by, + distinct_on_enum=distinct_on_enum, + type_factory=type_factory, + input_factory=input_factory, + aggregation=aggregation, + enum_factory=enum_factory, + filter_factory=filter_factory, + upsert_conflict=upsert_conflict, + ) + + def create_public_api(self) -> dict[str, Any]: + """Create the public API mappings for factory methods. + + Returns: + A dictionary mapping public API names to factory methods. + """ + return { + "filter": self.filter_factory.input, + "aggregate_filter": partial(self.aggregate_filter.input, mode="aggregate_filter"), + "distinct_on": self.distinct_on_enum.decorator, + "input": self.input_factory.input, + "create_input": partial(self.input_factory.input, mode="create_input"), + "pk_update_input": partial(self.input_factory.input, mode="update_by_pk_input"), + "filter_update_input": partial(self.input_factory.input, mode="update_by_filter_input"), + "order": partial(self.order_by.input, mode="order_by"), + "type": self.type_factory.type, + "aggregate": partial(self.aggregation.type, mode="aggregate_type"), + "upsert_update_fields": self.enum_factory.input, + "upsert_conflict_fields": self.upsert_conflict.input, + } diff --git a/src/strawchemy/mapper.py b/src/strawchemy/mapper.py index c75a348..b824bd6 100644 --- a/src/strawchemy/mapper.py +++ b/src/strawchemy/mapper.py @@ -1,15 +1,15 @@ from __future__ import annotations import dataclasses -from functools import cached_property, partial +from functools import cached_property from typing import TYPE_CHECKING, Any, TypeVar, overload from strawberry.annotation import StrawberryAnnotation from strawberry.schema.config import StrawberryConfig from strawchemy.config.base import StrawchemyConfig -from strawchemy.dto.backend.strawberry import StrawberrryDTOBackend from strawchemy.dto.base import TYPING_NS +from strawchemy.factories import StrawchemyFactories from strawchemy.strawberry._field import ( StrawchemyCreateMutationField, StrawchemyDeleteMutationField, @@ -18,19 +18,9 @@ StrawchemyUpsertMutationField, ) from strawchemy.strawberry._registry import StrawberryRegistry -from strawchemy.strawberry.dto import BooleanFilterDTO, EnumDTO, MappedStrawberryGraphQLDTO, OrderByDTO, OrderByEnum -from strawchemy.strawberry.factories.aggregations import EnumDTOFactory -from strawchemy.strawberry.factories.enum import EnumDTOBackend, UpsertConflictFieldsEnumDTOBackend -from strawchemy.strawberry.factories.inputs import AggregateFilterDTOFactory, BooleanFilterDTOFactory -from strawchemy.strawberry.factories.types import ( - DistinctOnFieldsDTOFactory, - InputFactory, - OrderByDTOFactory, - RootAggregateTypeDTOFactory, - TypeDTOFactory, - UpsertConflictFieldsDTOFactory, -) +from strawchemy.strawberry.dto import BooleanFilterDTO, EnumDTO, OrderByDTO, OrderByEnum from strawchemy.strawberry.mutation import types +from strawchemy.strawberry.mutation.builder import MutationFieldBuilder from strawchemy.types import DefaultOffsetPagination if TYPE_CHECKING: @@ -100,38 +90,37 @@ def __init__( self.config = StrawchemyConfig(config) if isinstance(config, str) else config self.registry = StrawberryRegistry(strawberry_config or StrawberryConfig()) - strawberry_backend = StrawberrryDTOBackend(MappedStrawberryGraphQLDTO) - enum_backend = EnumDTOBackend(self.config.auto_snake_case) - upsert_conflict_fields_enum_backend = UpsertConflictFieldsEnumDTOBackend( - self.config.inspector, self.config.auto_snake_case - ) - - self._aggregate_filter_factory = AggregateFilterDTOFactory(self) - self._order_by_factory = OrderByDTOFactory(self) - self._distinct_on_enum_factory = DistinctOnFieldsDTOFactory(self.config.inspector) - self._type_factory = TypeDTOFactory(self, strawberry_backend, order_by_factory=self._order_by_factory) - self._input_factory = InputFactory(self, strawberry_backend) - self._aggregation_factory = RootAggregateTypeDTOFactory( - self, strawberry_backend, type_factory=self._type_factory - ) - self._enum_factory = EnumDTOFactory(self.config.inspector, enum_backend) - self._filter_factory = BooleanFilterDTOFactory(self, aggregate_filter_factory=self._aggregate_filter_factory) - self._upsert_conflict_factory = UpsertConflictFieldsDTOFactory( - self.config.inspector, upsert_conflict_fields_enum_backend - ) - - self.filter = self._filter_factory.input - self.aggregate_filter = partial(self._aggregate_filter_factory.input, mode="aggregate_filter") - self.distinct_on = self._distinct_on_enum_factory.decorator - self.input = self._input_factory.input - self.create_input = partial(self._input_factory.input, mode="create_input") - self.pk_update_input = partial(self._input_factory.input, mode="update_by_pk_input") - self.filter_update_input = partial(self._input_factory.input, mode="update_by_filter_input") - self.order = partial(self._order_by_factory.input, mode="order_by") - self.type = self._type_factory.type - self.aggregate = partial(self._aggregation_factory.type, mode="aggregate_type") - self.upsert_update_fields = self._enum_factory.input - self.upsert_conflict_fields = self._upsert_conflict_factory.input + # Initialize all factories through the container + factories = StrawchemyFactories.create(self) + + # Store factory references for internal use + self._aggregate_filter_factory = factories.aggregate_filter + self._order_by_factory = factories.order_by + self._distinct_on_enum_factory = factories.distinct_on_enum + self._type_factory = factories.type_factory + self._input_factory = factories.input_factory + self._aggregation_factory = factories.aggregation + self._enum_factory = factories.enum_factory + self._filter_factory = factories.filter_factory + self._upsert_conflict_factory = factories.upsert_conflict + + # Expose public factory API + public_api = factories.create_public_api() + self.filter = public_api["filter"] + self.aggregate_filter = public_api["aggregate_filter"] + self.distinct_on = public_api["distinct_on"] + self.input = public_api["input"] + self.create_input = public_api["create_input"] + self.pk_update_input = public_api["pk_update_input"] + self.filter_update_input = public_api["filter_update_input"] + self.order = public_api["order"] + self.type = public_api["type"] + self.aggregate = public_api["aggregate"] + self.upsert_update_fields = public_api["upsert_update_fields"] + self.upsert_conflict_fields = public_api["upsert_conflict_fields"] + + # Initialize mutation field builder + self._mutation_builder = MutationFieldBuilder(self.config, self._annotation_namespace) # Register common types self.registry.register_enum(OrderByEnum, "OrderByEnum") @@ -372,30 +361,23 @@ def create( A `StrawchemyCreateMutationField` instance, which is a specialized StrawberryField configured for create mutations. """ - namespace = self._annotation_namespace() - type_annotation = StrawberryAnnotation.from_annotation(graphql_type, namespace) if graphql_type else None - repository_type_ = repository_type if repository_type is not None else self.config.repository_type - - field = StrawchemyCreateMutationField( - input_type, - config=self.config, - repository_type=repository_type_, - python_name=None, - graphql_name=name, - type_annotation=type_annotation, - is_subscription=False, - permission_classes=permission_classes or [], + return self._mutation_builder.build( + StrawchemyCreateMutationField, + resolver, + repository_type=repository_type, + graphql_type=graphql_type, + name=name, + description=description, + permission_classes=permission_classes, deprecation_reason=deprecation_reason, default=default, default_factory=default_factory, metadata=metadata, directives=directives, - extensions=extensions or [], - registry_namespace=namespace, - description=description, + extensions=extensions, + input_type=input_type, validation=validation, ) - return field(resolver) if resolver else field def upsert( self, @@ -455,32 +437,25 @@ def upsert( A `StrawchemyUpsertMutationField` instance, which is a specialized StrawberryField configured for upsert mutations. """ - namespace = self._annotation_namespace() - type_annotation = StrawberryAnnotation.from_annotation(graphql_type, namespace) if graphql_type else None - repository_type_ = repository_type if repository_type is not None else self.config.repository_type - - field = StrawchemyUpsertMutationField( - input_type, - update_fields_enum=update_fields, - conflict_fields_enum=conflict_fields, - config=self.config, - repository_type=repository_type_, - python_name=None, - graphql_name=name, - type_annotation=type_annotation, - is_subscription=False, - permission_classes=permission_classes or [], + return self._mutation_builder.build( + StrawchemyUpsertMutationField, + resolver, + repository_type=repository_type, + graphql_type=graphql_type, + name=name, + description=description, + permission_classes=permission_classes, deprecation_reason=deprecation_reason, default=default, default_factory=default_factory, metadata=metadata, directives=directives, - extensions=extensions or [], - registry_namespace=namespace, - description=description, + extensions=extensions, + input_type=input_type, + update_fields_enum=update_fields, + conflict_fields_enum=conflict_fields, validation=validation, ) - return field(resolver) if resolver else field def update( self, @@ -538,31 +513,24 @@ def update( A `StrawchemyUpdateMutationField` instance, which is a specialized StrawberryField configured for update mutations. """ - namespace = self._annotation_namespace() - type_annotation = StrawberryAnnotation.from_annotation(graphql_type, namespace) if graphql_type else None - repository_type_ = repository_type if repository_type is not None else self.config.repository_type - - field = StrawchemyUpdateMutationField( - config=self.config, - input_type=input_type, - filter_type=filter_input, - repository_type=repository_type_, - python_name=None, - graphql_name=name, - type_annotation=type_annotation, - is_subscription=False, - permission_classes=permission_classes or [], + return self._mutation_builder.build( + StrawchemyUpdateMutationField, + resolver, + repository_type=repository_type, + graphql_type=graphql_type, + name=name, + description=description, + permission_classes=permission_classes, deprecation_reason=deprecation_reason, default=default, default_factory=default_factory, metadata=metadata, directives=directives, - extensions=extensions or [], - registry_namespace=namespace, - description=description, + extensions=extensions, + input_type=input_type, + filter_type=filter_input, validation=validation, ) - return field(resolver) if resolver else field def update_by_ids( self, @@ -618,30 +586,23 @@ def update_by_ids( A `StrawchemyUpdateMutationField` instance, specialized for updates by ID. """ - namespace = self._annotation_namespace() - type_annotation = StrawberryAnnotation.from_annotation(graphql_type, namespace) if graphql_type else None - repository_type_ = repository_type if repository_type is not None else self.config.repository_type - - field = StrawchemyUpdateMutationField( - config=self.config, - input_type=input_type, - repository_type=repository_type_, - python_name=None, - graphql_name=name, - type_annotation=type_annotation, - is_subscription=False, - permission_classes=permission_classes or [], + return self._mutation_builder.build( + StrawchemyUpdateMutationField, + resolver, + repository_type=repository_type, + graphql_type=graphql_type, + name=name, + description=description, + permission_classes=permission_classes, deprecation_reason=deprecation_reason, default=default, default_factory=default_factory, metadata=metadata, directives=directives, - extensions=extensions or [], - registry_namespace=namespace, - description=description, + extensions=extensions, + input_type=input_type, validation=validation, ) - return field(resolver) if resolver else field def delete( self, @@ -696,26 +657,19 @@ def delete( A `StrawchemyDeleteMutationField` instance, which is a specialized StrawberryField configured for delete mutations. """ - namespace = self._annotation_namespace() - type_annotation = StrawberryAnnotation.from_annotation(graphql_type, namespace) if graphql_type else None - repository_type_ = repository_type if repository_type is not None else self.config.repository_type - - field = StrawchemyDeleteMutationField( - filter_input, - config=self.config, - repository_type=repository_type_, - python_name=None, - graphql_name=name, - type_annotation=type_annotation, - is_subscription=False, - permission_classes=permission_classes or [], + return self._mutation_builder.build( + StrawchemyDeleteMutationField, + resolver, + repository_type=repository_type, + graphql_type=graphql_type, + name=name, + description=description, + permission_classes=permission_classes, deprecation_reason=deprecation_reason, default=default, default_factory=default_factory, metadata=metadata, directives=directives, - extensions=extensions or [], - registry_namespace=namespace, - description=description, + extensions=extensions, + input_type=filter_input, ) - return field(resolver) if resolver else field diff --git a/src/strawchemy/strawberry/factories/aggregations.py b/src/strawchemy/strawberry/factories/aggregations.py index 45777fe..eff61e7 100644 --- a/src/strawchemy/strawberry/factories/aggregations.py +++ b/src/strawchemy/strawberry/factories/aggregations.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass, field from datetime import date, datetime, time, timedelta from decimal import Decimal from functools import cached_property @@ -37,6 +38,19 @@ T = TypeVar("T") +@dataclass(frozen=True, slots=True) +class _TypeFilterConfig: + """Configuration for type-filtered DTO factories. + + Attributes: + types: Set of Python types to filter fields by. + suffix: Suffix to append to the base name for DTO naming. + """ + + suffix: str + types: frozenset[type[Any]] = field(default_factory=frozenset) + + class _CountFieldsDTOFactory(EnumDTOFactory): @override def dto_name( @@ -88,10 +102,14 @@ def iter_field_definitions( function: FunctionInfo | None = None, **kwargs: Any, ) -> Generator[DTOFieldDefinition[DeclarativeBase, QueryableAttribute[Any]]]: - for field in super().iter_field_definitions( + for field_def in super().iter_field_definitions( name, model, dto_config, base, node, raise_if_no_fields, field_map=field_map, **kwargs ): - yield (FunctionArgFieldDefinition.from_field(field, function=function) if function is not None else field) + yield ( + FunctionArgFieldDefinition.from_field(field_def, function=function) + if function is not None + else field_def + ) @override def factory( @@ -146,99 +164,36 @@ def enum_factory( return self._enum_backend.build(name, model, list(field_defs), base) -class _NumericFieldsDTOFactory(_FunctionArgDTOFactory): - types: ClassVar[set[type[Any]]] = {int, float, Decimal} - - @override - def dto_name( - self, - base_name: str, - dto_config: DTOConfig, - node: Node[Relation[Any, UnmappedStrawberryGraphQLDTO[ModelT]], None] | None = None, - ) -> str: - return f"{base_name}NumericFields" - - -class _MinMaxFieldsDTOFactory(_FunctionArgDTOFactory): - types: ClassVar[set[type[Any]]] = {int, float, str, Decimal, date, datetime, time} - - @override - def dto_name( - self, - base_name: str, - dto_config: DTOConfig, - node: Node[Relation[Any, UnmappedStrawberryGraphQLDTO[ModelT]], None] | None = None, - ) -> str: - return f"{base_name}MinMaxFields" - - -class _MinMaxDateFieldsDTOFactory(_FunctionArgDTOFactory): - types: ClassVar[set[type[Any]]] = {date} - - @override - def dto_name( - self, - base_name: str, - dto_config: DTOConfig, - node: Node[Relation[Any, UnmappedStrawberryGraphQLDTO[ModelT]], None] | None = None, - ) -> str: - return f"{base_name}MinMaxDateFields" - - -class _MinMaxDateTimeFieldsDTOFactory(_FunctionArgDTOFactory): - types: ClassVar[set[type[Any]]] = {datetime} - - @override - def dto_name( - self, - base_name: str, - dto_config: DTOConfig, - node: Node[Relation[Any, UnmappedStrawberryGraphQLDTO[ModelT]], None] | None = None, - ) -> str: - return f"{base_name}MinMaxDateTimeFields" - - -class _MinMaxNumericFieldsDTOFactory(_FunctionArgDTOFactory): - types: ClassVar[set[type[Any]]] = {int, float, Decimal} +class _TypeFilteredFunctionArgDTOFactory(_FunctionArgDTOFactory): + """Generic factory for type-filtered aggregation field DTOs. - @override - def dto_name( - self, - base_name: str, - dto_config: DTOConfig, - node: Node[Relation[Any, UnmappedStrawberryGraphQLDTO[ModelT]], None] | None = None, - ) -> str: - return f"{base_name}MinMaxNumericFields" + This factory replaces multiple nearly-identical factory classes by using + a configuration object to specify the types and naming suffix. + """ - -class _MinMaxStringFieldsDTOFactory(_FunctionArgDTOFactory): - types: ClassVar[set[type[Any]]] = {str} - - @override - def dto_name( + def __init__( self, - base_name: str, - dto_config: DTOConfig, - node: Node[Relation[Any, UnmappedStrawberryGraphQLDTO[ModelT]], None] | None = None, - ) -> str: - return f"{base_name}MinMaxStringFields" - - -class _MinMaxTimeFieldsDTOFactory(_FunctionArgDTOFactory): - types: ClassVar[set[type[Any]]] = {time} + mapper: Strawchemy, + filter_config: _TypeFilterConfig, + backend: DTOBackend[UnmappedStrawberryGraphQLDTO[DeclarativeBase]] | None = None, + ) -> None: + super().__init__(mapper, backend) + self._filter_types = set(filter_config.types) + self._suffix = filter_config.suffix @override - def dto_name( + def should_exclude_field( self, - base_name: str, + field: DTOFieldDefinition[Any, QueryableAttribute[Any]], dto_config: DTOConfig, - node: Node[Relation[Any, UnmappedStrawberryGraphQLDTO[ModelT]], None] | None = None, - ) -> str: - return f"{base_name}MinMaxTimeFields" - - -class _SumFieldsDTOFactory(_FunctionArgDTOFactory): - types: ClassVar[set[type[Any]]] = {int, float, str, Decimal, timedelta} + node: Node[Relation[Any, UnmappedStrawberryGraphQLDTO[DeclarativeBase]], None], + has_override: bool = False, + ) -> bool: + return ( + super(_FunctionArgDTOFactory, self).should_exclude_field(field, dto_config, node, has_override) + or field.is_relation + or self.inspector.model_field_type(field) not in self._filter_types + ) @override def dto_name( @@ -247,21 +202,30 @@ def dto_name( dto_config: DTOConfig, node: Node[Relation[Any, UnmappedStrawberryGraphQLDTO[ModelT]], None] | None = None, ) -> str: - return f"{base_name}SumFields" + return f"{base_name}{self._suffix}" class AggregationInspector: + _aggregation_type_filters: ClassVar[dict[str, _TypeFilterConfig]] = { + "numeric": _TypeFilterConfig("NumericFields", frozenset({int, float, Decimal})), + "sum": _TypeFilterConfig("SumFields", frozenset({int, float, str, Decimal, timedelta})), + "min_max": _TypeFilterConfig("MinMaxFields", frozenset({int, float, str, Decimal, date, datetime, time})), + "min_max_numeric": _TypeFilterConfig("MinMaxNumericFields", frozenset({int, float, Decimal})), + "min_max_datetime": _TypeFilterConfig("MinMaxDateTimeFields", frozenset({datetime})), + "min_max_date": _TypeFilterConfig("MinMaxDateFields", frozenset({date})), + "min_max_string": _TypeFilterConfig("MinMaxStringFields", frozenset({str})), + "min_max_time": _TypeFilterConfig("MinMaxTimeFields", frozenset({time})), + } + def __init__(self, mapper: Strawchemy) -> None: self._inspector = mapper.config.inspector self._count_fields_factory = _CountFieldsDTOFactory(self._inspector) - self._numeric_fields_factory = _NumericFieldsDTOFactory(mapper) - self._sum_fields_factory = _SumFieldsDTOFactory(mapper) - self._min_max_numeric_fields_factory = _MinMaxNumericFieldsDTOFactory(mapper) - self._min_max_datetime_fields_factory = _MinMaxDateTimeFieldsDTOFactory(mapper) - self._min_max_date_fields_factory = _MinMaxDateFieldsDTOFactory(mapper) - self._min_max_string_fields_factory = _MinMaxStringFieldsDTOFactory(mapper) - self._min_max_time_fields_factory = _MinMaxTimeFieldsDTOFactory(mapper) - self._min_max_fields_factory = _MinMaxFieldsDTOFactory(mapper) + + # Create type-filtered factories from configuration + self._type_filtered_factories: dict[str, _TypeFilteredFunctionArgDTOFactory] = { + key: _TypeFilteredFunctionArgDTOFactory(mapper, config) + for key, config in self._aggregation_type_filters.items() + } def _supports_aggregations(self, *function: AggregationFunction) -> bool: return set(function).issubset(self._inspector.db_features.aggregation_functions) @@ -273,7 +237,7 @@ def _statistical_aggregations(self) -> list[AggregationFunction]: - cast("set[AggregationFunction]", {"min", "max", "sum", "count"}) ) - def _min_max_filters(self, model: type[Any], dto_config: DTOConfig) -> list[FilterFunctionInfo]: + def _min_max_filters(self, model: type[DeclarativeBase], dto_config: DTOConfig) -> list[FilterFunctionInfo]: aggregations: list[FilterFunctionInfo] = [] if min_max_numeric_fields := self.arguments_type(model, dto_config, "min_max_numeric"): @@ -375,20 +339,10 @@ def arguments_type( self, model: type[DeclarativeBase], dto_config: DTOConfig, aggregation: AggregationType ) -> type[EnumDTO] | None: try: - if aggregation == "numeric": - dto = self._numeric_fields_factory.enum_factory(model, dto_config, raise_if_no_fields=True) - elif aggregation == "sum": - dto = self._sum_fields_factory.enum_factory(model, dto_config, raise_if_no_fields=True) - elif aggregation == "min_max_date": - dto = self._min_max_date_fields_factory.enum_factory(model, dto_config, raise_if_no_fields=True) - elif aggregation == "min_max_datetime": - dto = self._min_max_datetime_fields_factory.enum_factory(model, dto_config, raise_if_no_fields=True) - elif aggregation == "min_max_string": - dto = self._min_max_string_fields_factory.enum_factory(model, dto_config, raise_if_no_fields=True) - elif aggregation == "min_max_numeric": - dto = self._min_max_numeric_fields_factory.enum_factory(model, dto_config, raise_if_no_fields=True) - elif aggregation == "min_max_time": - dto = self._min_max_time_fields_factory.enum_factory(model, dto_config, raise_if_no_fields=True) + factory = self._type_filtered_factories.get(aggregation) + if factory is None: + return None + dto = factory.enum_factory(model, dto_config, raise_if_no_fields=True) except DTOError: return None return dto @@ -397,7 +351,8 @@ def numeric_field_type( self, model: type[DeclarativeBase], dto_config: DTOConfig ) -> type[UnmappedStrawberryGraphQLDTO[DeclarativeBase]] | None: try: - dto = self._numeric_fields_factory.factory(model=model, dto_config=dto_config, raise_if_no_fields=True) + factory = self._type_filtered_factories["numeric"] + dto = factory.factory(model=model, dto_config=dto_config, raise_if_no_fields=True) except DTOError: return None return dto @@ -406,7 +361,8 @@ def min_max_field_type( self, model: type[DeclarativeBase], dto_config: DTOConfig ) -> type[UnmappedStrawberryGraphQLDTO[DeclarativeBase]] | None: try: - dto = self._min_max_fields_factory.factory(model=model, dto_config=dto_config, raise_if_no_fields=True) + factory = self._type_filtered_factories["min_max"] + dto = factory.factory(model=model, dto_config=dto_config, raise_if_no_fields=True) except DTOError: return None return dto @@ -415,12 +371,13 @@ def sum_field_type( self, model: type[DeclarativeBase], dto_config: DTOConfig ) -> type[UnmappedStrawberryGraphQLDTO[DeclarativeBase]] | None: try: - dto = self._sum_fields_factory.factory(model=model, dto_config=dto_config, raise_if_no_fields=True) + factory = self._type_filtered_factories["sum"] + dto = factory.factory(model=model, dto_config=dto_config, raise_if_no_fields=True) except DTOError: return None return dto - def output_functions(self, model: type[Any], dto_config: DTOConfig) -> list[OutputFunctionInfo]: + def output_functions(self, model: type[DeclarativeBase], dto_config: DTOConfig) -> list[OutputFunctionInfo]: int_as_float_config = dto_config.copy_with( type_overrides={int: Optional[float], Optional[int]: Optional[float]} ) @@ -454,7 +411,7 @@ def output_functions(self, model: type[Any], dto_config: DTOConfig) -> list[Outp ) return sorted(aggregations, key=lambda aggregation: aggregation.function) - def filter_functions(self, model: type[Any], dto_config: DTOConfig) -> list[FilterFunctionInfo]: + def filter_functions(self, model: type[DeclarativeBase], dto_config: DTOConfig) -> list[FilterFunctionInfo]: count_fields = self._count_fields_factory.factory(model=model, dto_config=dto_config) numeric_arg_fields = self.arguments_type(model, dto_config, "numeric") sum_arg_fields = self.arguments_type(model, dto_config, "sum") diff --git a/src/strawchemy/strawberry/factories/inputs.py b/src/strawchemy/strawberry/factories/inputs.py index 6656ef7..26396cc 100644 --- a/src/strawchemy/strawberry/factories/inputs.py +++ b/src/strawchemy/strawberry/factories/inputs.py @@ -394,7 +394,7 @@ def _order_by_aggregation_fields( dto, RegistryTypeInfo(dto.__name__, "input", default_name=self.root_dto_name(model, dto_config)) ) - def _order_by_aggregation(self, model: type[Any], dto_config: DTOConfig) -> type[OrderByDTO]: + def _order_by_aggregation(self, model: type[DeclarativeBase], dto_config: DTOConfig) -> type[OrderByDTO]: field_definitions: list[GraphQLFieldDefinition] = [] for aggregation in self._aggregation_filter_factory.aggregation_builder.filter_functions(model, dto_config): if aggregation.require_arguments: diff --git a/src/strawchemy/strawberry/mutation/builder.py b/src/strawchemy/strawberry/mutation/builder.py new file mode 100644 index 0000000..4acc595 --- /dev/null +++ b/src/strawchemy/strawberry/mutation/builder.py @@ -0,0 +1,109 @@ +"""Builder for Strawchemy mutation fields with common configuration.""" + +from __future__ import annotations + +import dataclasses +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from strawberry.annotation import StrawberryAnnotation + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping, Sequence + + from strawberry.extensions.field_extension import FieldExtension + + from strawberry import BasePermission + from strawchemy.config.base import StrawchemyConfig + from strawchemy.strawberry._field import ( + StrawchemyCreateMutationField, + StrawchemyDeleteMutationField, + StrawchemyUpdateMutationField, + StrawchemyUpsertMutationField, + ) + from strawchemy.typing import AnyRepository + + +@dataclass +class MutationFieldBuilder: + """Builder for Strawchemy mutation fields with common configuration. + + This builder encapsulates the common logic for creating mutation fields + (create, update, upsert, delete) to eliminate code duplication and provide + a consistent interface for mutation field creation. + """ + + config: StrawchemyConfig + registry_namespace_getter: Callable[[], dict[str, Any]] + + def build( + self, + field_class: type[ + StrawchemyCreateMutationField + | StrawchemyUpdateMutationField + | StrawchemyUpsertMutationField + | StrawchemyDeleteMutationField + ], + resolver: Any | None = None, + *, + repository_type: AnyRepository | None = None, + graphql_type: Any | None = None, + name: str | None = None, + description: str | None = None, + permission_classes: list[type[BasePermission]] | None = None, + deprecation_reason: str | None = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[..., object] | object = dataclasses.MISSING, + metadata: Mapping[Any, Any] | None = None, + directives: Sequence[object] = (), + extensions: list[FieldExtension] | None = None, + **field_specific_kwargs: Any, + ) -> Any: + """Build a mutation field with common configuration. + + Args: + field_class: The specific mutation field class to instantiate + (e.g., StrawchemyCreateMutationField). + resolver: An optional custom resolver function for the mutation. + repository_type: An optional custom repository class. Defaults to + the repository configured in StrawchemyConfig. + graphql_type: The GraphQL return type of the mutation. + name: The name of the GraphQL mutation field. + description: The description of the GraphQL mutation field. + permission_classes: A list of permission classes for the field. + deprecation_reason: The reason for deprecating the field. + default: The default value for the field. + default_factory: A factory function to generate the default value. + metadata: Additional metadata for the field. + directives: A sequence of directives for the field. + extensions: A list of Strawberry FieldExtensions. + **field_specific_kwargs: Additional keyword arguments specific to + the field type (e.g., input_type, filter_input, update_fields, etc.). + + Returns: + A configured mutation field instance, either wrapped with the resolver + or as a standalone field. + """ + namespace = self.registry_namespace_getter() + type_annotation = StrawberryAnnotation.from_annotation(graphql_type, namespace) if graphql_type else None + repository_type_ = repository_type if repository_type is not None else self.config.repository_type + + field = field_class( + config=self.config, + repository_type=repository_type_, + python_name=None, + graphql_name=name, + type_annotation=type_annotation, + is_subscription=False, + permission_classes=permission_classes or [], + deprecation_reason=deprecation_reason, + default=default, + default_factory=default_factory, + metadata=metadata, + directives=directives, + extensions=extensions or [], + registry_namespace=namespace, + description=description, + **field_specific_kwargs, + ) + return field(resolver) if resolver else field diff --git a/tasks.md b/tasks.md index 8e07cd3..a0f38fe 100644 --- a/tasks.md +++ b/tasks.md @@ -1,6 +1,6 @@ ## `auto-bump` -- Depends: _install +- Depends: uv:install - **Usage**: `auto-bump` @@ -55,7 +55,7 @@ Clean working directory ## `install` -- Depends: install:pre-commit, _install +- Depends: install:pre-commit, uv:install - **Usage**: `install` - **Aliases**: `i` @@ -70,7 +70,7 @@ Install pre-commit hooks ## `lint` -- Depends: vulture, pyright, ruff:check, ruff:format:check +- Depends: vulture, pyright, ruff:check, ruff:format:check, slotscheck - **Usage**: `lint` - **Aliases**: `l` @@ -95,8 +95,6 @@ Run pre-commit checks ## `pyright` -- Depends: _install - - **Usage**: `pyright` Run basedpyright @@ -109,8 +107,6 @@ Generate tasks documentation ## `ruff:check` -- Depends: _install - - **Usage**: `ruff:check` Check ruff formatting @@ -133,9 +129,15 @@ Format code Format code +## `slotscheck` + +- **Usage**: `slotscheck` + +Run slotscheck + ## `test` -- Depends: _install +- Depends: uv:install - **Usage**: `test [test]` - **Aliases**: `t` @@ -146,9 +148,11 @@ Run tests #### `[test]` +**Default:** `` + ## `test:coverage` -- Depends: _install +- Depends: uv:install - **Usage**: `test:coverage [test]` - **Aliases**: `tc` @@ -159,9 +163,11 @@ Run tests with coverage #### `[test]` +**Default:** `` + ## `test:integration` -- Depends: _install +- Depends: uv:install - **Usage**: `test:integration [--python [python]] …` - **Aliases**: `ti` @@ -172,6 +178,8 @@ Run integration tests #### `…` +**Default:** `` + ### Flags #### `--python [python]` @@ -180,7 +188,7 @@ Run integration tests ## `test:integration-all` -- Depends: _install +- Depends: uv:install - **Usage**: `test:integration-all [--python [python]] [test]` - **Aliases**: `tia` @@ -191,6 +199,8 @@ Run integration tests on all supported python versions #### `[test]` +**Default:** `` + ### Flags #### `--python [python]` @@ -199,7 +209,7 @@ Run integration tests on all supported python versions ## `test:integration-mysql` -- Depends: _install +- Depends: uv:install - **Usage**: `test:integration-mysql [--python [python]] …` - **Aliases**: `ti-mysql` @@ -210,6 +220,8 @@ Run integration tests #### `…` +**Default:** `` + ### Flags #### `--python [python]` @@ -218,7 +230,7 @@ Run integration tests ## `test:integration-postgres` -- Depends: _install +- Depends: uv:install - **Usage**: `test:integration-postgres [--python [python]] …` - **Aliases**: `ti-postgres` @@ -229,6 +241,8 @@ Run integration tests #### `…` +**Default:** `` + ### Flags #### `--python [python]` @@ -237,7 +251,7 @@ Run integration tests ## `test:integration-sqlite` -- Depends: _install +- Depends: uv:install - **Usage**: `test:integration-sqlite [--python [python]] …` - **Aliases**: `ti-sqlite` @@ -248,6 +262,8 @@ Run integration tests #### `…` +**Default:** `` + ### Flags #### `--python [python]` @@ -256,7 +272,7 @@ Run integration tests ## `test:integration:coverage` -- Depends: _install +- Depends: uv:install - **Usage**: `test:integration:coverage [--python [python]] [test]` - **Aliases**: `tic` @@ -267,6 +283,8 @@ Run integration tests with coverage #### `[test]` +**Default:** `` + ### Flags #### `--python [python]` @@ -275,7 +293,7 @@ Run integration tests with coverage ## `test:unit` -- Depends: _install +- Depends: uv:install - **Usage**: `test:unit [--python [python]] [test]` - **Aliases**: `tu` @@ -286,6 +304,8 @@ Run unit tests #### `[test]` +**Default:** `` + ### Flags #### `--python [python]` @@ -294,7 +314,7 @@ Run unit tests ## `test:unit-all` -- Depends: _install +- Depends: uv:install - **Usage**: `test:unit-all [--python [python]] [test]` - **Aliases**: `tua` @@ -305,6 +325,8 @@ Run unit tests on all supported python versions #### `[test]` +**Default:** `` + ### Flags #### `--python [python]` @@ -313,7 +335,7 @@ Run unit tests on all supported python versions ## `test:unit:coverage` -- Depends: _install +- Depends: uv:install - **Usage**: `test:unit:coverage [--python [python]] [test]` - **Aliases**: `tuc` @@ -324,6 +346,8 @@ Run unit tests with coverage #### `[test]` +**Default:** `` + ### Flags #### `--python [python]` @@ -332,7 +356,7 @@ Run unit tests with coverage ## `test:unit:no-extras` -- Depends: _install +- Depends: uv:install - **Usage**: `test:unit:no-extras [--python [python]] …` - **Aliases**: `tug` @@ -343,6 +367,8 @@ Run unit tests without extras dependencies #### `…` +**Default:** `` + ### Flags #### `--python [python]` @@ -351,12 +377,18 @@ Run unit tests without extras dependencies ## `test:update-snapshots` -- Depends: _install +- Depends: uv:install - **Usage**: `test:update-snapshots` Run snapshot-based tests and update snapshots +## `uv:install` + +- **Usage**: `uv:install` + +Install dependencies + ## `vulture` - **Usage**: `vulture` diff --git a/tests/unit/mapping/test_schemas.py b/tests/unit/mapping/test_schemas.py index cf9f8d2..7ddd515 100644 --- a/tests/unit/mapping/test_schemas.py +++ b/tests/unit/mapping/test_schemas.py @@ -62,9 +62,9 @@ class InputType: id: auto name: auto - user = InputType(id=1, name="user") # pyright: ignore[reportCallIssue] - assert user.id == 1 # pyright: ignore[reportAttributeAccessIssue] - assert user.name == "user" # pyright: ignore[reportAttributeAccessIssue] + user = InputType(id=1, name="user") + assert user.id == 1 + assert user.name == "user" def test_field_metadata_default(strawchemy: Strawchemy) -> None: diff --git a/tests/unit/test_mutation_input.py b/tests/unit/test_mutation_input.py index d1f3c2d..b59eff5 100644 --- a/tests/unit/test_mutation_input.py +++ b/tests/unit/test_mutation_input.py @@ -19,9 +19,9 @@ def test_add_non_input_relationships( @strawchemy.create_input(color_model, include="all") class ColorInput: ... - color = ColorInput(name="Blue") # pyright: ignore[reportCallIssue] + color = ColorInput(name="Blue") color_input = Input(color) assert len(color_input.relations) == 0 - color_input.instances[0].fruits.append(fruit_model(name="Apple", color_id=uuid4(), sweetness=1, color=None)) # pyright: ignore[reportArgumentType] + color_input.instances[0].fruits.append(fruit_model(name="Apple", color_id=uuid4(), sweetness=1, color=None)) color_input.add_non_input_relations() assert len(color_input.relations) == 1 diff --git a/uv.lock b/uv.lock index 77c2328..52a8b65 100644 --- a/uv.lock +++ b/uv.lock @@ -3124,6 +3124,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "slotscheck" +version = "0.19.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click", version = "8.1.8", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'group-10-strawchemy-build' or extra == 'group-10-strawchemy-dev'" }, + { name = "click", version = "8.3.1", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'group-10-strawchemy-codeflash' or (extra != 'group-10-strawchemy-build' and extra != 'group-10-strawchemy-dev')" }, + { name = "tomli", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'group-10-strawchemy-build') or (python_full_version < '3.11' and extra == 'group-10-strawchemy-dev') or (extra == 'group-10-strawchemy-build' and extra == 'group-10-strawchemy-codeflash') or (extra == 'group-10-strawchemy-codeflash' and extra == 'group-10-strawchemy-dev')" }, + { name = "tomli", version = "2.3.0", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and extra == 'group-10-strawchemy-codeflash') or (python_full_version < '3.11' and extra != 'group-10-strawchemy-build' and extra != 'group-10-strawchemy-dev') or (extra == 'group-10-strawchemy-build' and extra == 'group-10-strawchemy-codeflash') or (extra == 'group-10-strawchemy-codeflash' and extra == 'group-10-strawchemy-dev')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/57/6fcb8df11e7c76eb87b23bfa931408e47f051c6161749c531b4060a45516/slotscheck-0.19.1.tar.gz", hash = "sha256:6146b7747f8db335a00a66b782f86011b74b995f61746dc5b36a9e77d5326013", size = 16050, upload-time = "2024-10-19T13:30:53.369Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/da/32/bd569256267f80b76b87d21a09795741a175778b954bee1d7b1a89852b6f/slotscheck-0.19.1-py3-none-any.whl", hash = "sha256:bff9926f8d6408ea21b6c6bbaa4389cea1682962e73ee4f30084b6d2b89260ee", size = 16995, upload-time = "2024-10-19T13:30:51.23Z" }, +] + [[package]] name = "smmap" version = "5.0.2" @@ -3266,6 +3281,7 @@ dev = [ { name = "pytest-pretty" }, { name = "pytest-xdist" }, { name = "ruff" }, + { name = "slotscheck" }, { name = "sqlparse" }, { name = "syrupy" }, { name = "testapp" }, @@ -3278,6 +3294,7 @@ doc = [ lint = [ { name = "basedpyright" }, { name = "ruff" }, + { name = "slotscheck" }, { name = "vulture" }, ] mysql = [ @@ -3352,6 +3369,7 @@ dev = [ { name = "pytest-pretty" }, { name = "pytest-xdist" }, { name = "ruff" }, + { name = "slotscheck", specifier = ">=0.16.5" }, { name = "sqlparse" }, { name = "syrupy" }, { name = "testapp", editable = "examples/testapp" }, @@ -3362,6 +3380,7 @@ doc = [{ name = "git-cliff", specifier = ">=2.6.1" }] lint = [ { name = "basedpyright" }, { name = "ruff" }, + { name = "slotscheck", specifier = ">=0.16.5" }, { name = "vulture" }, ] mysql = [