Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -706,13 +706,110 @@ Strawchemy supports a wide range of filter operations:
| **Numeric types (Int, Float, Decimal)** | `gt`, `gte`, `lt`, `lte` |
| **String** | order filter, plus `like`, `nlike`, `ilike`, `nilike`, `regexp`, `iregexp`, `nregexp`, `inregexp`, `startswith`, `endswith`, `contains`, `istartswith`, `iendswith`, `icontains` |
| **JSON** | `contains`, `containedIn`, `hasKey`, `hasKeyAll`, `hasKeyAny` |
| **HStore** (PostgreSQL) | `contains`, `containedIn`, `hasKey`, `hasKeyAll`, `hasKeyAny` |
| **Array** | `contains`, `containedIn`, `overlap` |
| **Date** | order filters on plain dates, plus `year`, `month`, `day`, `weekDay`, `week`, `quarter`, `isoYear` and `isoWeekDay` filters |
| **DateTime** | All Date filters plus `hour`, `minute`, `second` |
| **Time** | order filters on plain times, plus `hour`, `minute` and `second` filters |
| **Interval** | order filters on plain intervals, plus `days`, `hours`, `minutes` and `seconds` filters |
| **Logical** | `_and`, `_or`, `_not` |

### HStore Filters

Strawchemy supports filtering on PostgreSQL [`hstore`](https://www.postgresql.org/docs/current/hstore.html) columns.
To use HStore filters, the `hstore` extension must be enabled in your PostgreSQL database.

<details>
<summary>HStore filters example</summary>

Define models and types:

```python
from sqlalchemy import MetaData
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
metadata = MetaData()

class Config(Base):
__tablename__ = "config"

id: Mapped[int] = mapped_column(primary_key=True)
settings: Mapped[dict[str, str]] = mapped_column(postgresql.HSTORE, default=dict)


@strawchemy.type(Config, include="all")
class ConfigType: ...


@strawchemy.filter(Config, include="all")
class ConfigFilter: ...


@strawberry.type
class Query:
configs: list[ConfigType] = strawchemy.field(filter_input=ConfigFilter)
```

**Important:** When creating your Strawberry schema, add `HSTORE_SCALAR_OVERRIDES` so that `dict[str, str]` is
correctly mapped to the `HStore` GraphQL scalar:

```python
from strawchemy.schema.scalars import HSTORE_SCALAR_OVERRIDES

schema = strawberry.Schema(
query=Query,
scalar_overrides={**HSTORE_SCALAR_OVERRIDES},
)
```

Then you can use the following HStore filter operations in your GraphQL queries:

```graphql
{
# Find configs where settings contain a specific key-value pair
configs(filter: { settings: { contains: { theme: "dark" } } }) {
id
settings
}

# Find configs where settings are contained within the given dict
configs(filter: { settings: { containedIn: { theme: "dark", lang: "en", mode: "advanced" } } }) {
id
settings
}

# Find configs that have a specific key
configs(filter: { settings: { hasKey: "theme" } }) {
id
settings
}

# Find configs that have all specified keys
configs(filter: { settings: { hasKeyAll: ["theme", "lang"] } }) {
id
settings
}

# Find configs that have any of the specified keys
configs(filter: { settings: { hasKeyAny: ["theme", "notifications"] } }) {
id
settings
}
}
```

</details>

Strawchemy supports the following HStore filter operations:

- **contains**: Filters for HStore values that contain the given key-value pairs
- **containedIn**: Filters for HStore values that are contained within the given dict
- **hasKey**: Filters for HStore values that have the given key
- **hasKeyAll**: Filters for HStore values that have all the given keys
- **hasKeyAny**: Filters for HStore values that have any of the given keys

### Geo Filters

Strawchemy supports spatial filtering capabilities for geometry fields
Expand Down
31 changes: 17 additions & 14 deletions src/strawchemy/dto/inspectors/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,6 @@
from inspect import getmodule, signature
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast, get_args, get_origin, get_type_hints

from sqlalchemy import (
ARRAY,
Column,
ColumnElement,
PrimaryKeyConstraint,
Sequence,
SQLColumnExpression,
Table,
UniqueConstraint,
event,
inspect,
orm,
sql,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import (
NO_VALUE,
Expand All @@ -40,6 +26,20 @@
)
from typing_extensions import TypeIs, override

from sqlalchemy import (
ARRAY,
Column,
ColumnElement,
PrimaryKeyConstraint,
Sequence,
SQLColumnExpression,
Table,
UniqueConstraint,
event,
inspect,
orm,
sql,
)
from strawchemy.config.databases import DatabaseFeatures
from strawchemy.constants import GEO_INSTALLED
from strawchemy.dto.base import TYPING_NS, DTOFieldDefinition, Relation
Expand All @@ -58,6 +58,7 @@
TimeComparison,
TimeDeltaComparison,
make_full_json_comparison_input,
make_hstore_comparison_input,
make_sqlite_json_comparison_input,
)
from strawchemy.utils.annotation import is_type_hint_optional
Expand Down Expand Up @@ -599,6 +600,8 @@ def get_field_comparison(
field_type = field_definition.model_field.type
if isinstance(field_type, ARRAY) and self.db_features.dialect == "postgresql":
return ArrayComparison[field_type.item_type.python_type]
if isinstance(field_type, postgresql.HSTORE) and self.db_features.dialect == "postgresql":
return make_hstore_comparison_input()
return self.get_type_comparison(self.model_field_type(field_definition))

def get_type_comparison(self, type_: type[Any]) -> type[GraphQLComparison]:
Expand Down
2 changes: 1 addition & 1 deletion src/strawchemy/repository/sqlalchemy/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from inspect import isclass
from typing import TYPE_CHECKING, Any, TypeVar

from sqlalchemy import ColumnElement, Row, and_, delete, inspect, select, update
from sqlalchemy.orm import RelationshipProperty

from sqlalchemy import ColumnElement, Row, and_, delete, inspect, select, update
from strawchemy.repository.sqlalchemy._base import InsertData, MutationData, SQLAlchemyGraphQLRepository
from strawchemy.repository.typing import AnySyncSession, DeclarativeT
from strawchemy.schema.mutation import RelationType, UpsertData
Expand Down
6 changes: 6 additions & 0 deletions src/strawchemy/schema/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
DateTimeFilter,
EqualityFilter,
FilterProtocol,
HStoreFilter,
JSONFilter,
OrderFilter,
TextFilter,
Expand All @@ -26,9 +27,11 @@
TextComparison,
TimeComparison,
TimeDeltaComparison,
_HStoreComparison,
_JSONComparison,
_SQLiteJSONComparison,
make_full_json_comparison_input,
make_hstore_comparison_input,
make_sqlite_json_comparison_input,
)

Expand All @@ -48,6 +51,7 @@
"GraphQLComparison",
"GraphQLComparisonT",
"GraphQLFilter",
"HStoreFilter",
"JSONFilter",
"OrderComparison",
"OrderFilter",
Expand All @@ -57,8 +61,10 @@
"TimeDeltaComparison",
"TimeDeltaFilter",
"TimeFilter",
"_HStoreComparison",
"_JSONComparison",
"_SQLiteJSONComparison",
"make_full_json_comparison_input",
"make_hstore_comparison_input",
"make_sqlite_json_comparison_input",
)
32 changes: 29 additions & 3 deletions src/strawchemy/schema/filters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast

from sqlalchemy import ARRAY, JSON, ColumnElement, Dialect, Integer, Text, and_, func, not_, null, or_, type_coerce
from sqlalchemy import cast as sqla_cast
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import postgresql as pg
from strawberry import UNSET
from typing_extensions import override

from sqlalchemy import ARRAY, JSON, ColumnElement, Dialect, Integer, Text, and_, func, not_, null, or_, type_coerce
from sqlalchemy import cast as sqla_cast
from strawberry import UNSET

if TYPE_CHECKING:
from datetime import date, timedelta

Expand All @@ -25,6 +26,7 @@
TextComparison,
TimeComparison,
TimeDeltaComparison,
_HStoreComparison,
_JSONComparison,
)

Expand Down Expand Up @@ -241,6 +243,30 @@ def to_expressions(
return expressions


@dataclass(frozen=True)
class HStoreFilter(EqualityFilter):
comparison: _HStoreComparison

@override
def to_expressions(
self, dialect: Dialect, model_attribute: QueryableAttribute[Any] | ColumnElement[Any]
) -> list[ColumnElement[bool]]:
expressions: list[ColumnElement[bool]] = super().to_expressions(dialect, model_attribute)
as_hstore = type_coerce(model_attribute, pg.HSTORE)

if self.comparison.contains is not UNSET:
expressions.append(as_hstore.contains(self.comparison.contains))
if self.comparison.contained_in is not UNSET:
expressions.append(as_hstore.contained_by(self.comparison.contained_in))
if self.comparison.has_key is not UNSET:
expressions.append(as_hstore.has_key(self.comparison.has_key))
if self.comparison.has_key_all is not UNSET:
expressions.append(as_hstore.has_all(sqla_cast(self.comparison.has_key_all, pg.ARRAY(Text))))
if self.comparison.has_key_any is not UNSET:
expressions.append(as_hstore.has_any(sqla_cast(self.comparison.has_key_any, pg.ARRAY(Text))))
return expressions


@dataclass(frozen=True)
class ArrayFilter(EqualityFilter):
comparison: ArrayComparison[Any]
Expand Down
39 changes: 36 additions & 3 deletions src/strawchemy/schema/filters/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

import strawberry
from strawberry import UNSET, Private

from strawchemy.schema.filters import (
ArrayFilter,
DateFilter,
DateTimeFilter,
EqualityFilter,
FilterProtocol,
HStoreFilter,
JSONFilter,
OrderFilter,
TextFilter,
Expand All @@ -32,9 +32,9 @@
from strawchemy.typing import QueryNodeType

if TYPE_CHECKING:
from sqlalchemy import ColumnElement, Dialect
from sqlalchemy.orm import QueryableAttribute

from sqlalchemy import ColumnElement, Dialect
from strawchemy.dto.strawberry import OrderByEnum

__all__ = (
Expand All @@ -46,15 +46,17 @@
"TextComparison",
"TimeComparison",
"TimeDeltaComparison",
"_HStoreComparison",
"_JSONComparison",
"make_full_json_comparison_input",
"make_hstore_comparison_input",
"make_sqlite_json_comparison_input",
)

T = TypeVar("T")
GraphQLComparisonT = TypeVar("GraphQLComparisonT", bound="GraphQLComparison")
GraphQLFilter: TypeAlias = "GraphQLComparison | OrderByEnum"
AnyGraphQLComparison: TypeAlias = "EqualityComparison[Any] | OrderComparison[Any] | TextComparison | DateComparison | TimeComparison | DateTimeComparison | TimeDeltaComparison | ArrayComparison[Any] | _JSONComparison | _SQLiteJSONComparison"
AnyGraphQLComparison: TypeAlias = "EqualityComparison[Any] | OrderComparison[Any] | TextComparison | DateComparison | TimeComparison | DateTimeComparison | TimeDeltaComparison | ArrayComparison[Any] | _JSONComparison | _SQLiteJSONComparison | _HStoreComparison"
AnyOrderGraphQLComparison: TypeAlias = (
"OrderComparison[Any] | TextComparison | DateComparison | TimeComparison | DateTimeComparison | TimeDeltaComparison"
)
Expand Down Expand Up @@ -306,6 +308,30 @@ class _JSONComparison(EqualityComparison[dict[str, Any]]):
has_key_any: list[str] | None = UNSET


class _HStoreComparison(EqualityComparison[dict[str, str]]):
"""HStore comparison class for GraphQL filters.

This class provides a set of HStore comparison operators that can be
used to filter data based on containment, key existence, and other
HStore-specific properties. PostgreSQL only.

Attributes:
contains: Filters for HStore values that contain this key-value pair.
contained_in: Filters for HStore values that are contained in this dict.
has_key: Filters for HStore values that have this key.
has_key_all: Filters for HStore values that have all of these keys.
has_key_any: Filters for HStore values that have any of these keys.
"""

__strawchemy_filter__ = HStoreFilter

contains: dict[str, str] | None = UNSET
contained_in: dict[str, str] | None = UNSET
has_key: str | None = UNSET
has_key_all: list[str] | None = UNSET
has_key_any: list[str] | None = UNSET


class _SQLiteJSONComparison(EqualityComparison[dict[str, Any]]):
"""JSON comparison class for GraphQL filters.

Expand Down Expand Up @@ -340,3 +366,10 @@ def make_sqlite_json_comparison_input() -> type[_SQLiteJSONComparison]:
return strawberry.input(name="JSONComparison", description=_DESCRIPTION.format(field="JSON fields"))(
_SQLiteJSONComparison
)


@cache
def make_hstore_comparison_input() -> type[_HStoreComparison]:
return strawberry.input(name="HStoreComparison", description=_DESCRIPTION.format(field="HStore fields"))(
_HStoreComparison
)
10 changes: 8 additions & 2 deletions src/strawchemy/schema/scalars/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from __future__ import annotations

from strawchemy.schema.scalars.base import Date, DateTime, Interval, Time
from typing import Any

__all__ = ("Date", "DateTime", "Interval", "Time")
from strawchemy.schema.scalars.base import Date, DateTime, HStore, Interval, Time

__all__ = ("HSTORE_SCALAR_OVERRIDES", "Date", "DateTime", "HStore", "Interval", "Time")

HSTORE_SCALAR_OVERRIDES: dict[object, type[Any]] = {
dict[str, str]: HStore,
}
Loading
Loading