From 1e70f1de980e47951cbaf63deea0d89701741a0b Mon Sep 17 00:00:00 2001 From: Ckk3 Date: Tue, 3 Mar 2026 20:29:39 -0300 Subject: [PATCH 1/3] feat: add hstore column filter support for PostgreSQL --- README.md | 97 +++++++++++ src/strawchemy/dto/inspectors/sqlalchemy.py | 31 ++-- src/strawchemy/schema/filters/__init__.py | 6 + src/strawchemy/schema/filters/base.py | 32 +++- src/strawchemy/schema/filters/inputs.py | 39 ++++- src/strawchemy/schema/scalars/__init__.py | 10 +- src/strawchemy/schema/scalars/base.py | 11 +- tests/integration/conftest.py | 2 + .../data_types/__snapshots__/test_hstore.ambr | 160 ++++++++++++++++++ tests/integration/data_types/test_hstore.py | 126 ++++++++++++++ tests/integration/fixtures.py | 24 ++- tests/integration/models.py | 26 ++- tests/integration/types/postgres.py | 26 ++- 13 files changed, 551 insertions(+), 39 deletions(-) create mode 100644 tests/integration/data_types/__snapshots__/test_hstore.ambr create mode 100644 tests/integration/data_types/test_hstore.py diff --git a/README.md b/README.md index e1c24392..64cad17f 100644 --- a/README.md +++ b/README.md @@ -706,6 +706,7 @@ Strawchemy supports a wide range of filter operations: | **Numeric types (Int, Float, Decimal)** | `gt`, `gte`, `lt`, `lte` | | **String** | order filter, plus `like`, `nlike`, `ilike`, `nilike`, `regexp`, `iregexp`, `nregexp`, `inregexp`, `startswith`, `endswith`, `contains`, `istartswith`, `iendswith`, `icontains` | | **JSON** | `contains`, `containedIn`, `hasKey`, `hasKeyAll`, `hasKeyAny` | +| **HStore** (PostgreSQL) | `contains`, `containedIn`, `hasKey`, `hasKeyAll`, `hasKeyAny` | | **Array** | `contains`, `containedIn`, `overlap` | | **Date** | order filters on plain dates, plus `year`, `month`, `day`, `weekDay`, `week`, `quarter`, `isoYear` and `isoWeekDay` filters | | **DateTime** | All Date filters plus `hour`, `minute`, `second` | @@ -713,6 +714,102 @@ Strawchemy supports a wide range of filter operations: | **Interval** | order filters on plain intervals, plus `days`, `hours`, `minutes` and `seconds` filters | | **Logical** | `_and`, `_or`, `_not` | +### HStore Filters + +Strawchemy supports filtering on PostgreSQL [`hstore`](https://www.postgresql.org/docs/current/hstore.html) columns. +To use HStore filters, the `hstore` extension must be enabled in your PostgreSQL database. + +
+HStore filters example + +Define models and types: + +```python +from sqlalchemy import MetaData +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + metadata = MetaData() + +class Config(Base): + __tablename__ = "config" + + id: Mapped[int] = mapped_column(primary_key=True) + settings: Mapped[dict[str, str]] = mapped_column(postgresql.HSTORE, default=dict) + + +@strawchemy.type(Config, include="all") +class ConfigType: ... + + +@strawchemy.filter(Config, include="all") +class ConfigFilter: ... + + +@strawberry.type +class Query: + configs: list[ConfigType] = strawchemy.field(filter_input=ConfigFilter) +``` + +**Important:** When creating your Strawberry schema, add `HSTORE_SCALAR_OVERRIDES` so that `dict[str, str]` is +correctly mapped to the `HStore` GraphQL scalar: + +```python +from strawchemy.schema.scalars import HSTORE_SCALAR_OVERRIDES + +schema = strawberry.Schema( + query=Query, + scalar_overrides={**HSTORE_SCALAR_OVERRIDES}, +) +``` + +Then you can use the following HStore filter operations in your GraphQL queries: + +```graphql +{ + # Find configs where settings contain a specific key-value pair + configs(filter: { settings: { contains: { theme: "dark" } } }) { + id + settings + } + + # Find configs where settings are contained within the given dict + configs(filter: { settings: { containedIn: { theme: "dark", lang: "en", mode: "advanced" } } }) { + id + settings + } + + # Find configs that have a specific key + configs(filter: { settings: { hasKey: "theme" } }) { + id + settings + } + + # Find configs that have all specified keys + configs(filter: { settings: { hasKeyAll: ["theme", "lang"] } }) { + id + settings + } + + # Find configs that have any of the specified keys + configs(filter: { settings: { hasKeyAny: ["theme", "notifications"] } }) { + id + settings + } +} +``` + +
+ +Strawchemy supports the following HStore filter operations: + +- **contains**: Filters for HStore values that contain the given key-value pairs +- **containedIn**: Filters for HStore values that are contained within the given dict +- **hasKey**: Filters for HStore values that have the given key +- **hasKeyAll**: Filters for HStore values that have all of the given keys +- **hasKeyAny**: Filters for HStore values that have any of the given keys + ### Geo Filters Strawchemy supports spatial filtering capabilities for geometry fields diff --git a/src/strawchemy/dto/inspectors/sqlalchemy.py b/src/strawchemy/dto/inspectors/sqlalchemy.py index f5158623..3351b9c2 100644 --- a/src/strawchemy/dto/inspectors/sqlalchemy.py +++ b/src/strawchemy/dto/inspectors/sqlalchemy.py @@ -10,20 +10,6 @@ from inspect import getmodule, signature from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast, get_args, get_origin, get_type_hints -from sqlalchemy import ( - ARRAY, - Column, - ColumnElement, - PrimaryKeyConstraint, - Sequence, - SQLColumnExpression, - Table, - UniqueConstraint, - event, - inspect, - orm, - sql, -) from sqlalchemy.dialects import postgresql from sqlalchemy.orm import ( NO_VALUE, @@ -40,6 +26,20 @@ ) from typing_extensions import TypeIs, override +from sqlalchemy import ( + ARRAY, + Column, + ColumnElement, + PrimaryKeyConstraint, + Sequence, + SQLColumnExpression, + Table, + UniqueConstraint, + event, + inspect, + orm, + sql, +) from strawchemy.config.databases import DatabaseFeatures from strawchemy.constants import GEO_INSTALLED from strawchemy.dto.base import TYPING_NS, DTOFieldDefinition, Relation @@ -58,6 +58,7 @@ TimeComparison, TimeDeltaComparison, make_full_json_comparison_input, + make_hstore_comparison_input, make_sqlite_json_comparison_input, ) from strawchemy.utils.annotation import is_type_hint_optional @@ -599,6 +600,8 @@ def get_field_comparison( field_type = field_definition.model_field.type if isinstance(field_type, ARRAY) and self.db_features.dialect == "postgresql": return ArrayComparison[field_type.item_type.python_type] + if isinstance(field_type, postgresql.HSTORE) and self.db_features.dialect == "postgresql": + return make_hstore_comparison_input() return self.get_type_comparison(self.model_field_type(field_definition)) def get_type_comparison(self, type_: type[Any]) -> type[GraphQLComparison]: diff --git a/src/strawchemy/schema/filters/__init__.py b/src/strawchemy/schema/filters/__init__.py index 66345e0e..b6b5cc10 100644 --- a/src/strawchemy/schema/filters/__init__.py +++ b/src/strawchemy/schema/filters/__init__.py @@ -8,6 +8,7 @@ DateTimeFilter, EqualityFilter, FilterProtocol, + HStoreFilter, JSONFilter, OrderFilter, TextFilter, @@ -26,9 +27,11 @@ TextComparison, TimeComparison, TimeDeltaComparison, + _HStoreComparison, _JSONComparison, _SQLiteJSONComparison, make_full_json_comparison_input, + make_hstore_comparison_input, make_sqlite_json_comparison_input, ) @@ -48,6 +51,7 @@ "GraphQLComparison", "GraphQLComparisonT", "GraphQLFilter", + "HStoreFilter", "JSONFilter", "OrderComparison", "OrderFilter", @@ -57,8 +61,10 @@ "TimeDeltaComparison", "TimeDeltaFilter", "TimeFilter", + "_HStoreComparison", "_JSONComparison", "_SQLiteJSONComparison", "make_full_json_comparison_input", + "make_hstore_comparison_input", "make_sqlite_json_comparison_input", ) diff --git a/src/strawchemy/schema/filters/base.py b/src/strawchemy/schema/filters/base.py index 23bae439..f02788c2 100644 --- a/src/strawchemy/schema/filters/base.py +++ b/src/strawchemy/schema/filters/base.py @@ -3,13 +3,14 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast -from sqlalchemy import ARRAY, JSON, ColumnElement, Dialect, Integer, Text, and_, func, not_, null, or_, type_coerce -from sqlalchemy import cast as sqla_cast from sqlalchemy.dialects import mysql from sqlalchemy.dialects import postgresql as pg -from strawberry import UNSET from typing_extensions import override +from sqlalchemy import ARRAY, JSON, ColumnElement, Dialect, Integer, Text, and_, func, not_, null, or_, type_coerce +from sqlalchemy import cast as sqla_cast +from strawberry import UNSET + if TYPE_CHECKING: from datetime import date, timedelta @@ -25,6 +26,7 @@ TextComparison, TimeComparison, TimeDeltaComparison, + _HStoreComparison, _JSONComparison, ) @@ -241,6 +243,30 @@ def to_expressions( return expressions +@dataclass(frozen=True) +class HStoreFilter(EqualityFilter): + comparison: _HStoreComparison + + @override + def to_expressions( + self, dialect: Dialect, model_attribute: QueryableAttribute[Any] | ColumnElement[Any] + ) -> list[ColumnElement[bool]]: + expressions: list[ColumnElement[bool]] = super().to_expressions(dialect, model_attribute) + as_hstore = type_coerce(model_attribute, pg.HSTORE) + + if self.comparison.contains is not UNSET: + expressions.append(as_hstore.contains(self.comparison.contains)) + if self.comparison.contained_in is not UNSET: + expressions.append(as_hstore.contained_by(self.comparison.contained_in)) + if self.comparison.has_key is not UNSET: + expressions.append(as_hstore.has_key(self.comparison.has_key)) + if self.comparison.has_key_all is not UNSET: + expressions.append(as_hstore.has_all(sqla_cast(self.comparison.has_key_all, pg.ARRAY(Text)))) + if self.comparison.has_key_any is not UNSET: + expressions.append(as_hstore.has_any(sqla_cast(self.comparison.has_key_any, pg.ARRAY(Text)))) + return expressions + + @dataclass(frozen=True) class ArrayFilter(EqualityFilter): comparison: ArrayComparison[Any] diff --git a/src/strawchemy/schema/filters/inputs.py b/src/strawchemy/schema/filters/inputs.py index 621bda20..a9039d87 100644 --- a/src/strawchemy/schema/filters/inputs.py +++ b/src/strawchemy/schema/filters/inputs.py @@ -16,13 +16,13 @@ import strawberry from strawberry import UNSET, Private - from strawchemy.schema.filters import ( ArrayFilter, DateFilter, DateTimeFilter, EqualityFilter, FilterProtocol, + HStoreFilter, JSONFilter, OrderFilter, TextFilter, @@ -32,9 +32,9 @@ from strawchemy.typing import QueryNodeType if TYPE_CHECKING: - from sqlalchemy import ColumnElement, Dialect from sqlalchemy.orm import QueryableAttribute + from sqlalchemy import ColumnElement, Dialect from strawchemy.dto.strawberry import OrderByEnum __all__ = ( @@ -46,15 +46,17 @@ "TextComparison", "TimeComparison", "TimeDeltaComparison", + "_HStoreComparison", "_JSONComparison", "make_full_json_comparison_input", + "make_hstore_comparison_input", "make_sqlite_json_comparison_input", ) T = TypeVar("T") GraphQLComparisonT = TypeVar("GraphQLComparisonT", bound="GraphQLComparison") GraphQLFilter: TypeAlias = "GraphQLComparison | OrderByEnum" -AnyGraphQLComparison: TypeAlias = "EqualityComparison[Any] | OrderComparison[Any] | TextComparison | DateComparison | TimeComparison | DateTimeComparison | TimeDeltaComparison | ArrayComparison[Any] | _JSONComparison | _SQLiteJSONComparison" +AnyGraphQLComparison: TypeAlias = "EqualityComparison[Any] | OrderComparison[Any] | TextComparison | DateComparison | TimeComparison | DateTimeComparison | TimeDeltaComparison | ArrayComparison[Any] | _JSONComparison | _SQLiteJSONComparison | _HStoreComparison" AnyOrderGraphQLComparison: TypeAlias = ( "OrderComparison[Any] | TextComparison | DateComparison | TimeComparison | DateTimeComparison | TimeDeltaComparison" ) @@ -306,6 +308,30 @@ class _JSONComparison(EqualityComparison[dict[str, Any]]): has_key_any: list[str] | None = UNSET +class _HStoreComparison(EqualityComparison[dict[str, str]]): + """HStore comparison class for GraphQL filters. + + This class provides a set of HStore comparison operators that can be + used to filter data based on containment, key existence, and other + HStore-specific properties. PostgreSQL only. + + Attributes: + contains: Filters for HStore values that contain this key-value pair. + contained_in: Filters for HStore values that are contained in this dict. + has_key: Filters for HStore values that have this key. + has_key_all: Filters for HStore values that have all of these keys. + has_key_any: Filters for HStore values that have any of these keys. + """ + + __strawchemy_filter__ = HStoreFilter + + contains: dict[str, str] | None = UNSET + contained_in: dict[str, str] | None = UNSET + has_key: str | None = UNSET + has_key_all: list[str] | None = UNSET + has_key_any: list[str] | None = UNSET + + class _SQLiteJSONComparison(EqualityComparison[dict[str, Any]]): """JSON comparison class for GraphQL filters. @@ -340,3 +366,10 @@ def make_sqlite_json_comparison_input() -> type[_SQLiteJSONComparison]: return strawberry.input(name="JSONComparison", description=_DESCRIPTION.format(field="JSON fields"))( _SQLiteJSONComparison ) + + +@cache +def make_hstore_comparison_input() -> type[_HStoreComparison]: + return strawberry.input(name="HStoreComparison", description=_DESCRIPTION.format(field="HStore fields"))( + _HStoreComparison + ) diff --git a/src/strawchemy/schema/scalars/__init__.py b/src/strawchemy/schema/scalars/__init__.py index a53503dc..aafb56b7 100644 --- a/src/strawchemy/schema/scalars/__init__.py +++ b/src/strawchemy/schema/scalars/__init__.py @@ -1,5 +1,11 @@ from __future__ import annotations -from strawchemy.schema.scalars.base import Date, DateTime, Interval, Time +from typing import Any -__all__ = ("Date", "DateTime", "Interval", "Time") +from strawchemy.schema.scalars.base import Date, DateTime, HStore, Interval, Time + +__all__ = ("HSTORE_SCALAR_OVERRIDES", "Date", "DateTime", "HStore", "Interval", "Time") + +HSTORE_SCALAR_OVERRIDES: dict[object, type[Any]] = { + dict[str, str]: HStore, +} diff --git a/src/strawchemy/schema/scalars/base.py b/src/strawchemy/schema/scalars/base.py index 3be8a015..260682f7 100644 --- a/src/strawchemy/schema/scalars/base.py +++ b/src/strawchemy/schema/scalars/base.py @@ -5,15 +5,15 @@ from typing import TYPE_CHECKING, TypeVar from msgspec import json -from strawberry import scalar from strawberry.schema.types.base_scalars import wrap_parser +from strawberry import scalar from strawchemy.utils.annotation import new_type if TYPE_CHECKING: from typing import Any -__all__ = ("Date", "DateTime", "Interval", "Time") +__all__ = ("Date", "DateTime", "HStore", "Interval", "Time") UTC = timezone.utc @@ -58,3 +58,10 @@ def _serialize(value: timedelta) -> str: serialize=_serialize_date, parse_value=wrap_parser(datetime.fromisoformat, "DateTime"), ) + +HStore = scalar( + new_type("HStore", dict), + description="The `HStore` scalar type represents a PostgreSQL hstore value, a flat mapping of string keys to string values.", + serialize=lambda val: val, + parse_value=lambda val: {str(k): str(v) for k, v in val.items()} if isinstance(val, dict) else val, +) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 5761a284..ecfd3356 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -33,6 +33,7 @@ raw_geo, raw_geo_flipped, raw_groups, + raw_hstore, raw_intervals, raw_json, raw_topics, @@ -80,6 +81,7 @@ "raw_geo", "raw_geo_flipped", "raw_groups", + "raw_hstore", "raw_intervals", "raw_json", "raw_topics", diff --git a/tests/integration/data_types/__snapshots__/test_hstore.ambr b/tests/integration/data_types/__snapshots__/test_hstore.ambr new file mode 100644 index 00000000..ee927f65 --- /dev/null +++ b/tests/integration/data_types/__snapshots__/test_hstore.ambr @@ -0,0 +1,160 @@ +# serializer version: 1 +# name: test_hstore_filters[session-tracked-async-containedIn-asyncpg_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col <@ $1 + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-async-containedIn-psycopg_async_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col <@ %(param_1)s + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-async-contains-asyncpg_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col @> $1 + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-async-contains-psycopg_async_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col @> %(param_1)s + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-async-hasKey-asyncpg_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col ? $1::VARCHAR + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-async-hasKey-psycopg_async_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col ? %(param_1)s::VARCHAR + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-async-hasKeyAll-asyncpg_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col ? & CAST($1::TEXT[] AS TEXT[]) + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-async-hasKeyAll-psycopg_async_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col ? & CAST(%(param_1)s::TEXT[] AS TEXT[]) + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-async-hasKeyAny-asyncpg_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col ? | CAST($1::TEXT[] AS TEXT[]) + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-async-hasKeyAny-psycopg_async_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col ? | CAST(%(param_1)s::TEXT[] AS TEXT[]) + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-sync-containedIn-psycopg_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col <@ %(param_1)s + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-sync-contains-psycopg_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col @> %(param_1)s + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-sync-hasKey-psycopg_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col ? %(param_1)s::VARCHAR + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-sync-hasKeyAll-psycopg_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col ? & CAST(%(param_1)s::TEXT[] AS TEXT[]) + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_filters[session-tracked-sync-hasKeyAny-psycopg_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + WHERE hstore_model.hstore_col ? | CAST(%(param_1)s::TEXT[] AS TEXT[]) + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_output[session-tracked-async-asyncpg_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_output[session-tracked-async-psycopg_async_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + ORDER BY hstore_model.id ASC + ''' +# --- +# name: test_hstore_output[session-tracked-sync-psycopg_engine] + ''' + SELECT hstore_model.hstore_col, + hstore_model.id + FROM hstore_model AS hstore_model + ORDER BY hstore_model.id ASC + ''' +# --- diff --git a/tests/integration/data_types/test_hstore.py b/tests/integration/data_types/test_hstore.py new file mode 100644 index 00000000..63fd1f01 --- /dev/null +++ b/tests/integration/data_types/test_hstore.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +from sqlalchemy import Executable, Insert, MetaData, insert, text +from tests.integration.models import HStoreModel, hstore_metadata +from tests.integration.types import postgres as postgres_types +from tests.integration.utils import to_graphql_representation +from tests.utils import maybe_async + +if TYPE_CHECKING: + from syrupy.assertion import SnapshotAssertion + + from strawchemy.typing import SupportedDialect + from tests.integration.fixtures import QueryTracker + from tests.integration.typing import RawRecordData + from tests.typing import AnyQueryExecutor + + +@pytest.fixture +def before_create_all_statements() -> list[Executable]: + return [text("CREATE EXTENSION IF NOT EXISTS hstore")] + + +@pytest.fixture +def metadata() -> MetaData: + return hstore_metadata + + +@pytest.fixture +def seed_insert_statements(raw_hstore: RawRecordData) -> list[Insert]: + return [insert(HStoreModel).values(raw_hstore)] + + +@pytest.fixture +def async_query(dialect: SupportedDialect) -> type[Any]: + if dialect == "postgresql": + return postgres_types.HStoreAsyncQuery + pytest.skip(f"HStore tests can only be run on postgresql, not {dialect}") + + +@pytest.fixture +def sync_query(dialect: SupportedDialect) -> type[Any]: + if dialect == "postgresql": + return postgres_types.HStoreSyncQuery + pytest.skip(f"HStore tests can only be run on postgresql, not {dialect}") + + +@pytest.mark.parametrize( + ("filter_name", "value", "expected_ids"), + [ + pytest.param("contains", {"key1": "value1"}, [0], id="contains"), + pytest.param( + "containedIn", + {"key1": "value1", "key2": "value2", "key3": "value3", "extra": "value"}, + [0, 2], + id="containedIn", + ), + pytest.param("hasKey", "key1", [0], id="hasKey"), + pytest.param("hasKeyAll", ["key1", "key2"], [0], id="hasKeyAll"), + pytest.param("hasKeyAny", ["key1", "status"], [0, 1], id="hasKeyAny"), + ], +) +@pytest.mark.snapshot +async def test_hstore_filters( + filter_name: str, + value: Any, + expected_ids: list[int], + any_query: AnyQueryExecutor, + raw_hstore: RawRecordData, + query_tracker: QueryTracker, + sql_snapshot: SnapshotAssertion, +) -> None: + if isinstance(value, list): + value_str = ", ".join(to_graphql_representation(v, "input") for v in value) + value_repr = f"[{value_str}]" + else: + value_repr = to_graphql_representation(value, "input") + + query = f""" + {{ + hstore(filter: {{ hstoreCol: {{ {filter_name}: {value_repr} }} }}) {{ + id + hstoreCol + }} + }} + """ + result = await maybe_async(any_query(query)) + assert not result.errors + assert result.data + assert len(result.data["hstore"]) == len(expected_ids) + + for i, expected_id in enumerate(expected_ids): + assert result.data["hstore"][i]["id"] == raw_hstore[expected_id]["id"] + + assert query_tracker.query_count == 1 + assert query_tracker[0].statement_formatted == sql_snapshot + + +@pytest.mark.snapshot +async def test_hstore_output( + any_query: AnyQueryExecutor, + raw_hstore: RawRecordData, + query_tracker: QueryTracker, + sql_snapshot: SnapshotAssertion, +) -> None: + query = """ + { + hstore { + id + hstoreCol + } + } + """ + result = await maybe_async(any_query(query)) + assert not result.errors + assert result.data + + for hstore in result.data["hstore"]: + expected = next(f for f in raw_hstore if f["id"] == hstore["id"]) + assert hstore["hstoreCol"] == expected["hstore_col"] + + assert query_tracker.query_count == 1 + assert query_tracker[0].statement_formatted == sql_snapshot diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py index e1153281..1ec808a6 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -13,6 +13,12 @@ import sqlparse from pytest_databases.docker.postgres import _provide_postgres_service from pytest_lazy_fixtures import lf +from sqlalchemy.event import listens_for +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import Session, sessionmaker +from strawberry.scalars import JSON +from typing_extensions import Self + from sqlalchemy import ( URL, ClauseElement, @@ -31,15 +37,9 @@ create_engine, insert, ) -from sqlalchemy.event import listens_for -from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.orm import Session, sessionmaker -from strawberry.scalars import JSON -from typing_extensions import Self - from strawchemy.config.databases import DatabaseFeatures from strawchemy.constants import GEO_INSTALLED -from strawchemy.schema.scalars import Date, DateTime, Interval, Time +from strawchemy.schema.scalars import HSTORE_SCALAR_OVERRIDES, Date, DateTime, Interval, Time from tests.fixtures import DefaultQuery from tests.integration.models import ( Color, @@ -102,6 +102,7 @@ datetime: DateTime, } engine_plugins: list[str] = [] +scalar_overrides |= HSTORE_SCALAR_OVERRIDES if GEO_INSTALLED: from strawchemy.schema.scalars.geo import GEO_SCALAR_OVERRIDES @@ -345,6 +346,15 @@ def raw_intervals() -> RawRecordData: ] +@pytest.fixture +def raw_hstore() -> RawRecordData: + return [ + {"id": 1, "hstore_col": {"key1": "value1", "key2": "value2", "key3": "value3"}}, + {"id": 2, "hstore_col": {"status": "pending", "key3": "value3"}}, + {"id": 3, "hstore_col": {}}, + ] + + @pytest.fixture def raw_json() -> RawRecordData: return [ diff --git a/tests/integration/models.py b/tests/integration/models.py index 2e4ee01f..cf3cfcb5 100644 --- a/tests/integration/models.py +++ b/tests/integration/models.py @@ -5,6 +5,12 @@ from datetime import date, datetime, time, timedelta from typing import Any +from sqlalchemy.dialects import mysql, postgresql, sqlite +from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, column_property, mapped_column, relationship +from sqlalchemy.orm import registry as Registry # noqa: N812 + from sqlalchemy import ( ARRAY, JSON, @@ -23,18 +29,13 @@ Time, UniqueConstraint, ) -from sqlalchemy.dialects import mysql, postgresql, sqlite -from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, column_property, mapped_column, relationship -from sqlalchemy.orm import registry as Registry # noqa: N812 - from strawchemy.dto.utils import PRIVATE, READ_ONLY metadata = MetaData() geo_metadata = MetaData() dc_metadata = MetaData() json_metadata = MetaData() +hstore_metadata = MetaData() array_metadata = MetaData() interval_metadata = MetaData() date_time_metadata = MetaData() @@ -88,6 +89,11 @@ class GeoUUIDBase(BaseColumns, DeclarativeBase): registry = Registry(metadata=geo_metadata) +class HStoreBase(BaseColumns, DeclarativeBase): + __abstract__ = True + registry = Registry(metadata=hstore_metadata) + + class ArrayBase(BaseColumns, DeclarativeBase): __abstract__ = True registry = Registry(metadata=array_metadata) @@ -240,6 +246,14 @@ class JSONModel(JSONBase): dict_col: Mapped[dict[str, Any]] = mapped_column(JSONType, default=dict) +class HStoreModel(HStoreBase): + __tablename__ = "hstore_model" + + registry = Registry(metadata=hstore_metadata) + + hstore_col: Mapped[dict[str, str]] = mapped_column(postgresql.HSTORE, default=dict) + + class DateTimeModel(DateTimeBase): __tablename__ = "date_time_model" diff --git a/tests/integration/types/postgres.py b/tests/integration/types/postgres.py index 12c688a7..70a5d0d2 100644 --- a/tests/integration/types/postgres.py +++ b/tests/integration/types/postgres.py @@ -3,12 +3,12 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Annotated, Any, TypeAlias, cast -import strawberry from pydantic import AfterValidator -from sqlalchemy import Select, select from strawberry.extensions.field_extension import FieldExtension from typing_extensions import override +import strawberry +from sqlalchemy import Select, select from strawchemy import ( Input, InputValidationError, @@ -27,6 +27,7 @@ DateTimeModel, Fruit, FruitFarm, + HStoreModel, IntervalModel, JSONModel, RankedUser, @@ -298,6 +299,17 @@ class JSONFilter: ... class JSONType: ... +# HStore type + + +@strawchemy.filter(HStoreModel, include="all") +class HStoreFilter: ... + + +@strawchemy.type(HStoreModel, include="all") +class HStoreType: ... + + # Date/Time @@ -507,6 +519,16 @@ class JSONSyncQuery: json: list[JSONType] = strawchemy.field(filter_input=JSONFilter, repository_type=StrawchemySyncRepository) +@strawberry.type +class HStoreAsyncQuery: + hstore: list[HStoreType] = strawchemy.field(filter_input=HStoreFilter, repository_type=StrawchemyAsyncRepository) + + +@strawberry.type +class HStoreSyncQuery: + hstore: list[HStoreType] = strawchemy.field(filter_input=HStoreFilter, repository_type=StrawchemySyncRepository) + + @strawberry.type class DateTimeAsyncQuery: date_times: list[DateTimeType] = strawchemy.field( From ff49b531d4cf7b1b429c96d72e551679643d44f3 Mon Sep 17 00:00:00 2001 From: Ckk3 Date: Tue, 3 Mar 2026 20:49:17 -0300 Subject: [PATCH 2/3] fix: replace untyped lambdas with named functions in hstore scalar to satisfy pyright --- src/strawchemy/repository/sqlalchemy/_sync.py | 2 +- src/strawchemy/schema/scalars/base.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/strawchemy/repository/sqlalchemy/_sync.py b/src/strawchemy/repository/sqlalchemy/_sync.py index 63845f24..450cdb08 100644 --- a/src/strawchemy/repository/sqlalchemy/_sync.py +++ b/src/strawchemy/repository/sqlalchemy/_sync.py @@ -6,9 +6,9 @@ from inspect import isclass from typing import TYPE_CHECKING, Any, TypeVar -from sqlalchemy import ColumnElement, Row, and_, delete, inspect, select, update from sqlalchemy.orm import RelationshipProperty +from sqlalchemy import ColumnElement, Row, and_, delete, inspect, select, update from strawchemy.repository.sqlalchemy._base import InsertData, MutationData, SQLAlchemyGraphQLRepository from strawchemy.repository.typing import AnySyncSession, DeclarativeT from strawchemy.schema.mutation import RelationType, UpsertData diff --git a/src/strawchemy/schema/scalars/base.py b/src/strawchemy/schema/scalars/base.py index 260682f7..c2e64a9d 100644 --- a/src/strawchemy/schema/scalars/base.py +++ b/src/strawchemy/schema/scalars/base.py @@ -59,9 +59,18 @@ def _serialize(value: timedelta) -> str: parse_value=wrap_parser(datetime.fromisoformat, "DateTime"), ) + +def _serialize_hstore(value: dict[str, str]) -> dict[str, str]: + return dict(value) + + +def _parse_hstore(value: Any) -> dict[str, str]: + return {str(k): str(v) for k, v in value.items()} if isinstance(value, dict) else value + + HStore = scalar( new_type("HStore", dict), description="The `HStore` scalar type represents a PostgreSQL hstore value, a flat mapping of string keys to string values.", - serialize=lambda val: val, - parse_value=lambda val: {str(k): str(v) for k, v in val.items()} if isinstance(val, dict) else val, + serialize=_serialize_hstore, + parse_value=_parse_hstore, ) From d3ae8978e0eea5294d69f03d191ed3de04cfc66c Mon Sep 17 00:00:00 2001 From: Ckk3 Date: Thu, 5 Mar 2026 12:07:17 -0300 Subject: [PATCH 3/3] fix: address CodeRabbit review comments for HStore PR --- README.md | 2 +- src/strawchemy/schema/scalars/base.py | 9 ++++++--- tests/unit/test_scalars.py | 27 +++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 tests/unit/test_scalars.py diff --git a/README.md b/README.md index 64cad17f..38b2c1cf 100644 --- a/README.md +++ b/README.md @@ -807,7 +807,7 @@ Strawchemy supports the following HStore filter operations: - **contains**: Filters for HStore values that contain the given key-value pairs - **containedIn**: Filters for HStore values that are contained within the given dict - **hasKey**: Filters for HStore values that have the given key -- **hasKeyAll**: Filters for HStore values that have all of the given keys +- **hasKeyAll**: Filters for HStore values that have all the given keys - **hasKeyAny**: Filters for HStore values that have any of the given keys ### Geo Filters diff --git a/src/strawchemy/schema/scalars/base.py b/src/strawchemy/schema/scalars/base.py index c2e64a9d..1778e0a3 100644 --- a/src/strawchemy/schema/scalars/base.py +++ b/src/strawchemy/schema/scalars/base.py @@ -5,9 +5,9 @@ from typing import TYPE_CHECKING, TypeVar from msgspec import json +from strawberry import scalar from strawberry.schema.types.base_scalars import wrap_parser -from strawberry import scalar from strawchemy.utils.annotation import new_type if TYPE_CHECKING: @@ -64,8 +64,11 @@ def _serialize_hstore(value: dict[str, str]) -> dict[str, str]: return dict(value) -def _parse_hstore(value: Any) -> dict[str, str]: - return {str(k): str(v) for k, v in value.items()} if isinstance(value, dict) else value +def _parse_hstore(value: object) -> dict[str, str]: + if not isinstance(value, dict): + msg = f"HStore value must be a dict, got {type(value).__name__}" + raise TypeError(msg) + return {str(k): str(v) for k, v in value.items()} HStore = scalar( diff --git a/tests/unit/test_scalars.py b/tests/unit/test_scalars.py new file mode 100644 index 00000000..e359787d --- /dev/null +++ b/tests/unit/test_scalars.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import pytest + +from strawchemy.schema.scalars.base import _parse_hstore + + +def test_parse_hstore_valid_dict() -> None: + assert _parse_hstore({"key": "value"}) == {"key": "value"} + + +def test_parse_hstore_coerces_to_strings() -> None: + assert _parse_hstore({1: 2}) == {"1": "2"} + + +@pytest.mark.parametrize( + ("value", "type_name"), + [ + ("not a dict", "str"), + ([1, 2, 3], "list"), + (42, "int"), + (None, "NoneType"), + ], +) +def test_parse_hstore_rejects_non_dict(value: object, type_name: str) -> None: + with pytest.raises(TypeError, match=f"HStore value must be a dict, got {type_name}"): + _parse_hstore(value)