diff --git a/docs/filters.rst b/docs/filters.rst index ac36803..47cb872 100644 --- a/docs/filters.rst +++ b/docs/filters.rst @@ -152,7 +152,7 @@ this query will return all pets which have a person named "Ben" in their ``peopl } -and this one will return all pets which hvae a person list that contains exactly the people "Ada" and "Ben" and no fewer or people with other names. +and this one will return all pets which have a person list that contains exactly the people "Ada" and "Ben" and no fewer or people with other names. .. code:: diff --git a/examples/filters/database.py b/examples/filters/database.py index 8f6522f..c842c41 100644 --- a/examples/filters/database.py +++ b/examples/filters/database.py @@ -26,8 +26,8 @@ def init_db(): person1 = Person(name="A") person2 = Person(name="B") - pet1 = Pet(name="Spot") - pet2 = Pet(name="Milo") + pet1 = Pet(name="Spot", kind="dog") + pet2 = Pet(name="Milo", kind="cat") toy1 = Toy(name="disc") toy2 = Toy(name="ball") diff --git a/examples/filters/models.py b/examples/filters/models.py index 1b22956..f5d4bf9 100644 --- a/examples/filters/models.py +++ b/examples/filters/models.py @@ -1,14 +1,17 @@ import sqlalchemy from database import Base -from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy import Column, Enum, ForeignKey, Integer, String from sqlalchemy.orm import relationship +PetKind = Enum("cat", "dog", name="pet_kind") + class Pet(Base): __tablename__ = "pets" id = Column(Integer(), primary_key=True) name = Column(String(30)) age = Column(Integer()) + kind = Column(PetKind) person_id = Column(Integer(), ForeignKey("people.id")) diff --git a/graphene_sqlalchemy/filters.py b/graphene_sqlalchemy/filters.py index cbe3d09..335fe80 100644 --- a/graphene_sqlalchemy/filters.py +++ b/graphene_sqlalchemy/filters.py @@ -79,11 +79,7 @@ def __init_subclass_with_meta__( new_filter_fields = {} # Generate Graphene Fields from the filter functions based on type hints for field_name, _annotations in logic_functions: - assert ( - "val" in _annotations - ), "Each filter method must have a value field with valid type annotations" # If type is generic, replace with actual type of filter class - replace_type_vars = {BaseTypeFilterSelf: cls} field_type = convert_sqlalchemy_type( _annotations.get("val", str), replace_type_vars=replace_type_vars @@ -170,7 +166,7 @@ def execute_filters( field_filter_type = input_field.type else: field_filter_type = cls._meta.fields[field].type - # raise Exception + # TODO we need to save the relationship props in the meta fields array # to conduct joins and alias the joins (in case there are duplicate joins: A->B A->C B->C) if field == "and": @@ -252,9 +248,11 @@ def __init_subclass_with_meta__(cls, graphene_type=None, _meta=None, **options): new_filter_fields = {} # Generate Graphene Fields from the filter functions based on type hints for field_name, _annotations in filter_functions: - assert ( - "val" in _annotations - ), "Each filter method must have a value field with valid type annotations" + if "val" not in _annotations: + raise TypeError( + "Each filter method must have a 'val' field with valid type annotations." + ) + # If type is generic, replace with actual type of filter class replace_type_vars = {ScalarFilterInputType: _meta.graphene_type} field_type = convert_sqlalchemy_type( @@ -309,10 +307,7 @@ def execute_filters( class SQLEnumFilter(FieldFilter): - """Basic Filter for Scalars in Graphene. - We want this filter to use Dynamic fields so it provides the base - filtering methods ("eq, nEq") for different types of scalars. - The Dynamic fields will resolve to Meta.filtered_type""" + """Basic Filter for SQL Enums in Graphene.""" class Meta: graphene_type = graphene.Enum @@ -332,10 +327,7 @@ def n_eq_filter( class PyEnumFilter(FieldFilter): - """Basic Filter for Scalars in Graphene. - We want this filter to use Dynamic fields so it provides the base - filtering methods ("eq, nEq") for different types of scalars. - The Dynamic fields will resolve to Meta.filtered_type""" + """Basic Filter for Python Enums in Graphene.""" class Meta: graphene_type = graphene.Enum @@ -441,7 +433,8 @@ def __init_subclass_with_meta__( cls, base_type_filter=None, model=None, _meta=None, **options ): if not base_type_filter: - raise Exception("Relationship Filters must be specific to an object type") + raise TypeError("Relationship Filters must be specific to an object type.") + # Init meta options class if it doesn't exist already if not _meta: _meta = InputObjectTypeOptions(cls) @@ -453,9 +446,6 @@ def __init_subclass_with_meta__( # Generate Graphene Fields from the filter functions based on type hints for field_name, _annotations in filter_functions: - assert ( - "val" in _annotations - ), "Each filter method must have a value field with valid type annotations" # If type is generic, replace with actual type of filter class if is_list(_annotations["val"]): relationship_filters.update( diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index b959d22..09e23a4 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -181,19 +181,6 @@ def get_filter_for_scalar_type( return filter_type - # TODO register enums automatically - def register_filter_for_enum_type( - self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"] - ): - from .filters import FieldFilter - - if not issubclass(enum_type, graphene.Enum): - raise TypeError("Expected Enum, but got: {!r}".format(enum_type)) - - if not issubclass(filter_obj, FieldFilter): - raise TypeError("Expected FieldFilter, but got: {!r}".format(filter_obj)) - self._registry_scalar_filters[enum_type] = filter_obj - # Filter Base Types def register_filter_for_base_type( self, diff --git a/graphene_sqlalchemy/tests/test_filters.py b/graphene_sqlalchemy/tests/test_filters.py index 4acf89a..dd5ad8a 100644 --- a/graphene_sqlalchemy/tests/test_filters.py +++ b/graphene_sqlalchemy/tests/test_filters.py @@ -5,7 +5,7 @@ from graphene import Connection, relay from ..fields import SQLAlchemyConnectionField -from ..filters import FloatFilter +from ..filters import FloatFilter, RelationshipFilter from ..types import ORMField, SQLAlchemyObjectType from .models import ( Article, @@ -1199,3 +1199,33 @@ async def test_additional_filters(session): schema = graphene.Schema(query=Query) result = await schema.execute_async(query, context_value={"session": session}) assert_and_raise_result(result, expected) + + +# Test that exceptions are called correctly +@pytest.mark.asyncio +async def test_filter_relationship_no_base_type(session): + with pytest.raises( + TypeError, + match=r"(.*)Relationship Filters must be specific to an object type.(.*)", + ): + RelationshipFilter.create_type( + "InvalidRelationshipFilter", base_type_filter=None, model=Article + ) + + +@pytest.mark.asyncio +async def test_filter_invalid_filter_method(session): + + # Field filter + with pytest.raises( + TypeError, + match=r"(.*)Each filter method must have a 'val' field with valid type annotations.(.*)", + ): + + class InvalidFieldFilter(FloatFilter): + class Meta: + graphene_type = graphene.Float + + @classmethod + def invalid_filter(cls, query, field) -> bool: + return False diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 7053988..533e348 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -135,19 +135,15 @@ def get_or_create_relationship_filter( relationship_filter = registry.get_relationship_filter_for_base_type(base_type) if not relationship_filter: - try: - base_type_filter = registry.get_filter_for_base_type(base_type) - relationship_filter = RelationshipFilter.create_type( - f"{base_type.__name__}RelationshipFilter", - base_type_filter=base_type_filter, - model=base_type._meta.model, - ) - registry.register_relationship_filter_for_base_type( - base_type, relationship_filter - ) - except Exception as e: - print("e") - raise e + base_type_filter = registry.get_filter_for_base_type(base_type) + relationship_filter = RelationshipFilter.create_type( + f"{base_type.__name__}RelationshipFilter", + base_type_filter=base_type_filter, + model=base_type._meta.model, + ) + registry.register_relationship_filter_for_base_type( + base_type, relationship_filter + ) return relationship_filter