diff --git a/.gitignore b/.gitignore index e4070f31..f20e3a02 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ __pycache__/ # Distribution / packaging .Python +.venv/ env/ build/ develop-eggs/ diff --git a/examples/flask_sqlalchemy/app.py b/examples/flask_sqlalchemy/app.py index a4d3f29e..ee933382 100755 --- a/examples/flask_sqlalchemy/app.py +++ b/examples/flask_sqlalchemy/app.py @@ -1,12 +1,11 @@ #!/usr/bin/env python +from database import db_session, init_db from flask import Flask +from schema import schema from flask_graphql import GraphQLView -from .database import db_session, init_db -from .schema import schema - app = Flask(__name__) app.debug = True diff --git a/examples/flask_sqlalchemy/database.py b/examples/flask_sqlalchemy/database.py index 01e76ca6..ca4d4122 100644 --- a/examples/flask_sqlalchemy/database.py +++ b/examples/flask_sqlalchemy/database.py @@ -14,7 +14,7 @@ def init_db(): # import all modules here that might define models so that # they will be registered properly on the metadata. Otherwise # you will have to import them first before calling init_db() - from .models import Department, Employee, Role + from models import Department, Employee, Role Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) diff --git a/examples/flask_sqlalchemy/models.py b/examples/flask_sqlalchemy/models.py index e164c015..efbbe690 100644 --- a/examples/flask_sqlalchemy/models.py +++ b/examples/flask_sqlalchemy/models.py @@ -1,8 +1,7 @@ +from database import Base from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func from sqlalchemy.orm import backref, relationship -from .database import Base - class Department(Base): __tablename__ = 'department' diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index cbee081c..6c403002 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -1,12 +1,12 @@ +from models import Department as DepartmentModel +from models import Employee as EmployeeModel +from models import Role as RoleModel + import graphene from graphene import relay from graphene_sqlalchemy import (SQLAlchemyConnectionField, SQLAlchemyObjectType, utils) -from .models import Department as DepartmentModel -from .models import Employee as EmployeeModel -from .models import Role as RoleModel - class Department(SQLAlchemyObjectType): class Meta: @@ -26,8 +26,7 @@ class Meta: interfaces = (relay.Node, ) -SortEnumEmployee = utils.sort_enum_for_model(EmployeeModel, 'SortEnumEmployee', - lambda c, d: c.upper() + ('_ASC' if d else '_DESC')) +SortEnumEmployee = utils.get_sort_enum_for_model(EmployeeModel) class Query(graphene.ObjectType): diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index eee98090..b83cecf7 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -1,5 +1,5 @@ -from .types import SQLAlchemyObjectType from .fields import SQLAlchemyConnectionField +from .types import SQLAlchemyObjectType from .utils import get_query, get_session __version__ = "2.1.1" diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 7cc259e0..9968432f 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -7,6 +7,8 @@ String) from graphene.types.json import JSONString +from .registry import get_global_registry + try: from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType except ImportError: @@ -62,8 +64,6 @@ def convert_sqlalchemy_composite(composite, registry): def _register_composite_class(cls, registry=None): if registry is None: - from .registry import get_global_registry - registry = get_global_registry() def inner(fn): @@ -145,21 +145,16 @@ def convert_column_to_float(type, column, registry=None): @convert_sqlalchemy_type.register(types.Enum) def convert_enum_to_enum(type, column, registry=None): - enum_class = getattr(type, 'enum_class', None) - if enum_class: # Check if an enum.Enum type is used - graphene_type = Enum.from_enum(enum_class) - else: # Nope, just a list of string options - items = zip(type.enums, type.enums) - graphene_type = Enum(type.name, items) - return Field( - graphene_type, - description=get_column_doc(column), - required=not (is_column_nullable(column)), - ) + if registry is None: + registry = get_global_registry() + graphene_type = registry.get_type_for_enum(type) + return Field(graphene_type, + description=get_column_doc(column), + required=not(is_column_nullable(column))) @convert_sqlalchemy_type.register(ChoiceType) -def convert_column_to_enum(type, column, registry=None): +def convert_choice_to_enum(type, column, registry=None): name = "{}_{}".format(column.table.name, column.name).upper() return Enum(name, type.choices, description=get_column_doc(column)) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 4a46b749..c05fa370 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -8,7 +8,7 @@ from graphene.relay.connection import PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice -from .utils import get_query, sort_argument_for_model +from .utils import get_query, get_sort_argument_for_model log = logging.getLogger() @@ -85,7 +85,7 @@ def __init__(self, type, *args, **kwargs): # Let super class raise if type is not a Connection try: model = type.Edge.node._type._meta.model - kwargs.setdefault("sort", sort_argument_for_model(model)) + kwargs.setdefault("sort", get_sort_argument_for_model(model)) except Exception: raise Exception( 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 460053f2..d68581ba 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,22 +1,83 @@ +from collections import OrderedDict + +from sqlalchemy.types import Enum as SQLAlchemyEnumType + +from graphene import Enum + +from .utils import to_enum_value_name, to_type_name + + class Registry(object): - def __init__(self): + def __init__(self, check_duplicate_registration=False): + self.check_duplicate_registration = check_duplicate_registration self._registry = {} self._registry_models = {} self._registry_composites = {} + self._registry_enums = {} + self._registry_sort_params = {} def register(self, cls): from .types import SQLAlchemyObjectType - assert issubclass(cls, SQLAlchemyObjectType), ( - "Only classes of type SQLAlchemyObjectType can be registered, " - 'received "{}"' - ).format(cls.__name__) - assert cls._meta.registry == self, "Registry for a Model have to match." - # assert self.get_type_for_model(cls._meta.model) in [None, cls], ( - # 'SQLAlchemy model "{}" already associated with ' - # 'another type "{}".' - # ).format(cls._meta.model, self._registry[cls._meta.model]) - self._registry[cls._meta.model] = cls + if not issubclass(cls, SQLAlchemyObjectType): + raise TypeError( + "Only classes of type SQLAlchemyObjectType can be registered, " + 'received "{}"'.format(cls.__name__) + ) + if cls._meta.registry != self: + raise TypeError("Registry for a Model have to match.") + + registered_cls = ( + self._registry.get(cls._meta.model) + if self.check_duplicate_registration + else None + ) + if registered_cls: + if cls != registered_cls: + raise TypeError( + "Different object types registered for the same model {}:" + " tried to register {}, but {} existed already.".format( + cls._meta.model, cls, registered_cls + ) + ) + else: + self._registry[cls._meta.model] = cls + + def register_enum(self, name, members): + graphene_enum = self._registry_enums.get(name) + if graphene_enum: + registered_members = { + key: value.value + for key, value in graphene_enum._meta.enum.__members__.items() + } + if members != registered_members: + raise TypeError( + 'Different enums with the same name "{}":' + " tried to register {}, but {} existed already.".format( + name, members, registered_members + ) + ) + else: + graphene_enum = Enum(name, members) + self._registry_enums[name] = graphene_enum + return graphene_enum + + def register_sort_params(self, cls, sort_params): + registered_sort_params = ( + self._registry_sort_params.get(cls) + if self.check_duplicate_registration + else None + ) + if registered_sort_params: + if registered_sort_params != sort_params: + raise TypeError( + "Different sort args for the same model {}:" + " tried to register {}, but {} existed already.".format( + cls, sort_params, registered_sort_params + ) + ) + else: + self._registry_sort_params[cls] = sort_params def get_type_for_model(self, model): return self._registry.get(model) @@ -27,6 +88,34 @@ def register_composite_converter(self, composite, converter): def get_converter_for_composite(self, composite): return self._registry_composites.get(composite) + def get_type_for_enum(self, sql_type): + if not isinstance(sql_type, SQLAlchemyEnumType): + raise TypeError( + "Only sqlalchemy.Enum objects can be registered as enum, " + 'received "{}"'.format(sql_type) + ) + enum_class = sql_type.enum_class + if enum_class: + name = enum_class.__name__ + members = OrderedDict( + (to_enum_value_name(key), value.value) + for key, value in enum_class.__members__.items() + ) + else: + name = sql_type.name + name = ( + to_type_name(name) + if name + else "Enum{}".format(len(self._registry_enums) + 1) + ) + members = OrderedDict( + (to_enum_value_name(key), key) for key in sql_type.enums + ) + return self.register_enum(name, members) + + def get_sort_params_for_model(self, model): + return self._registry_sort_params.get(model) + registry = None diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 3ba23a8a..12781cc5 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -6,8 +6,10 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import mapper, relationship +PetKind = Enum("cat", "dog", name="pet_kind") -class Hairkind(enum.Enum): + +class HairKind(enum.Enum): LONG = 'long' SHORT = 'short' @@ -32,8 +34,8 @@ class Pet(Base): __tablename__ = "pets" id = Column(Integer(), primary_key=True) name = Column(String(30)) - pet_kind = Column(Enum("cat", "dog", name="pet_kind"), nullable=False) - hair_kind = Column(Enum(Hairkind, name="hair_kind"), nullable=False) + pet_kind = Column(PetKind, nullable=False) + hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) reporter_id = Column(Integer(), ForeignKey("reporters.id")) @@ -43,6 +45,7 @@ class Reporter(Base): first_name = Column(String(30)) last_name = Column(String(30)) email = Column(String()) + favorite_pet_kind = Column(PetKind) pets = relationship("Pet", secondary=association_table, backref="reporters") articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 5cc16e79..0492218a 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -86,19 +86,37 @@ def test_should_unicodetext_convert_string(): def test_should_enum_convert_enum(): - field = assert_column_conversion( - types.Enum(enum.Enum("one", "two")), graphene.Field - ) + field = assert_column_conversion(types.Enum("one", "two"), graphene.Field) field_type = field.type() + assert field_type.__class__.__name__.startswith("Enum") assert isinstance(field_type, graphene.Enum) - assert hasattr(field_type, "two") + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + field = assert_column_conversion( types.Enum("one", "two", name="two_numbers"), graphene.Field ) field_type = field.type() - assert field_type.__class__.__name__ == "two_numbers" + assert field_type.__class__.__name__ == "TwoNumbers" + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + + +def test_conflicting_enum_should_raise_error(): + some_type = types.Enum(enum.Enum("ConflictingEnum", "cat cow")) + field = assert_column_conversion(some_type, graphene.Field) + field_type = field.type() assert isinstance(field_type, graphene.Enum) - assert hasattr(field_type, "two") + assert hasattr(field_type, "COW") + same_type = types.Enum(enum.Enum("ConflictingEnum", "cat cow")) + field = assert_column_conversion(same_type, graphene.Field) + assert field_type == field.type() + conflicting_type = types.Enum(enum.Enum("ConflictingEnum", "cat horse")) + with raises(TypeError): + assert_column_conversion(conflicting_type, graphene.Field) def test_should_small_integer_convert_int(): @@ -277,19 +295,20 @@ def test_should_postgresql_enum_convert(): postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field ) field_type = field.type() - assert field_type.__class__.__name__ == "two_numbers" + assert field_type.__class__.__name__ == "TwoNumbers" assert isinstance(field_type, graphene.Enum) - assert hasattr(field_type, "two") + assert hasattr(field_type, "TWO") def test_should_postgresql_py_enum_convert(): field = assert_column_conversion( - postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), graphene.Field + postgresql.ENUM(enum.Enum("TwoNumbersEnum", "one two"), name="two_numbers"), + graphene.Field, ) field_type = field.type() - assert field_type.__class__.__name__ == "TwoNumbers" + assert field_type.__class__.__name__ == "TwoNumbersEnum" assert isinstance(field_type, graphene.Enum) - assert hasattr(field_type, "two") + assert hasattr(field_type, "TWO") def test_should_postgresql_array_convert(): @@ -309,7 +328,7 @@ def test_should_postgresql_hstore_convert(): def test_should_composite_convert(): - class CompositeClass(object): + class CompositeClass: def __init__(self, col1, col2): self.col1 = col1 self.col2 = col2 diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index ff616b30..a94bfa7f 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -4,7 +4,7 @@ from ..fields import SQLAlchemyConnectionField from ..types import SQLAlchemyObjectType -from ..utils import sort_argument_for_model +from ..utils import get_sort_argument_for_model from .models import Editor from .models import Pet as PetModel @@ -22,7 +22,7 @@ class Meta: def test_sort_added_by_default(): arg = SQLAlchemyConnectionField(PetConn) assert "sort" in arg.args - assert arg.args["sort"] == sort_argument_for_model(PetModel) + assert arg.args["sort"] == get_sort_argument_for_model(PetModel) def test_sort_can_be_removed(): @@ -31,8 +31,8 @@ def test_sort_can_be_removed(): def test_custom_sort(): - arg = SQLAlchemyConnectionField(PetConn, sort=sort_argument_for_model(Editor)) - assert arg.args["sort"] == sort_argument_for_model(Editor) + arg = SQLAlchemyConnectionField(PetConn, sort=get_sort_argument_for_model(Editor)) + assert arg.args["sort"] == get_sort_argument_for_model(Editor) def test_init_raises(): diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 146c54e6..aacfafb5 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -8,10 +8,10 @@ from ..fields import SQLAlchemyConnectionField from ..registry import reset_global_registry from ..types import SQLAlchemyObjectType -from ..utils import sort_argument_for_model, sort_enum_for_model -from .models import Article, Base, Editor, Hairkind, Pet, Reporter +from ..utils import get_sort_argument_for_model, get_sort_enum_for_model +from .models import Article, Base, Editor, HairKind, Pet, Reporter -db = create_engine("sqlite:///test_sqlalchemy.sqlite3") +db = create_engine("sqlite://") # use in-memory database @pytest.yield_fixture(scope="function") @@ -34,16 +34,22 @@ def session(): def setup_fixtures(session): - pet = Pet(name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG) - session.add(pet) - reporter = Reporter(first_name="ABA", last_name="X") + reporter = Reporter( + first_name='John', last_name='Doe', favorite_pet_kind='cat') session.add(reporter) - reporter2 = Reporter(first_name="ABO", last_name="Y") - session.add(reporter2) - article = Article(headline="Hi!") + pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT) + session.add(pet) + pet.reporters.append(reporter) + article = Article(headline='Hi!') article.reporter = reporter session.add(article) - editor = Editor(name="John") + reporter = Reporter( + first_name='Jane', last_name='Roe', favorite_pet_kind='dog') + session.add(reporter) + pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG) + pet.reporters.append(reporter) + session.add(pet) + editor = Editor(name="Jack") session.add(editor) session.commit() @@ -51,6 +57,10 @@ def setup_fixtures(session): def test_should_query_well(session): setup_fixtures(session) + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter @@ -58,33 +68,68 @@ class Meta: class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) reporters = graphene.List(ReporterType) + pets = graphene.List(PetType, kind=graphene.Argument( + PetType._meta.fields['pet_kind'].type)) - def resolve_reporter(self, *args, **kwargs): + def resolve_reporter(self, _info): return session.query(Reporter).first() - def resolve_reporters(self, *args, **kwargs): + def resolve_reporters(self, _info): return session.query(Reporter) + def resolve_pets(self, _info, kind): + query = session.query(Pet) + if kind: + query = query.filter_by(pet_kind=kind) + return query + query = """ query ReporterQuery { reporter { firstName, lastName, - email + email, + favoritePetKind, + pets { + name + petKind + } } reporters { firstName } + pets(kind: DOG) { + name + petKind + } } """ expected = { - "reporter": {"firstName": "ABA", "lastName": "X", "email": None}, - "reporters": [{"firstName": "ABA"}, {"firstName": "ABO"}], + 'reporter': { + 'firstName': 'John', + 'lastName': 'Doe', + 'email': None, + 'favoritePetKind': 'CAT', + 'pets': [{ + 'name': 'Garfield', + 'petKind': 'CAT' + }] + }, + 'reporters': [{ + 'firstName': 'John', + }, { + 'firstName': 'Jane', + }], + 'pets': [{ + 'name': 'Lassie', + 'petKind': 'DOG' + }] } schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors - assert result.data == expected + result = to_std_dicts(result.data) + assert result == expected def test_should_query_enums(session): @@ -97,7 +142,7 @@ class Meta: class Query(graphene.ObjectType): pet = graphene.Field(PetType) - def resolve_pet(self, *args, **kwargs): + def resolve_pet(self, _info): return session.query(Pet).first() query = """ @@ -109,11 +154,12 @@ def resolve_pet(self, *args, **kwargs): } } """ - expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} + expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors - assert result.data == expected, result.data + result = to_std_dicts(result.data) + assert result == expected def test_enum_parameter(session): @@ -124,16 +170,18 @@ class Meta: model = Pet class Query(graphene.ObjectType): - pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['pet_kind'].type.of_type)) + pet = graphene.Field( + PetType, + kind=graphene.Argument(PetType._meta.fields['pet_kind'].type.of_type)) - def resolve_pet(self, info, kind=None, *args, **kwargs): + def resolve_pet(self, info, kind=None): query = session.query(Pet) if kind: query = query.filter(Pet.pet_kind == kind) return query.first() query = """ - query PetQuery($kind: pet_kind) { + query PetQuery($kind: PetKind) { pet(kind: $kind) { name, petKind @@ -141,14 +189,15 @@ def resolve_pet(self, info, kind=None, *args, **kwargs): } } """ - expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "cat"}) + result = schema.execute(query, variables={"kind": "CAT"}) assert not result.errors - assert result.data == {"pet": None} - result = schema.execute(query, variables={"kind": "dog"}) + expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} + assert result.data == expected + result = schema.execute(query, variables={"kind": "DOG"}) assert not result.errors - assert result.data == expected, result.data + expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} + assert result.data == expected def test_py_enum_parameter(session): @@ -161,15 +210,15 @@ class Meta: class Query(graphene.ObjectType): pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['hair_kind'].type.of_type)) - def resolve_pet(self, info, kind=None, *args, **kwargs): + def resolve_pet(self, _info, kind=None): query = session.query(Pet) if kind: # XXX Why kind passed in as a str instead of a Hairkind instance? - query = query.filter(Pet.hair_kind == Hairkind(kind)) + query = query.filter(Pet.hair_kind == HairKind(kind)) return query.first() query = """ - query PetQuery($kind: Hairkind) { + query PetQuery($kind: HairKind) { pet(kind: $kind) { name, petKind @@ -177,14 +226,15 @@ def resolve_pet(self, info, kind=None, *args, **kwargs): } } """ - expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} schema = graphene.Schema(query=Query) result = schema.execute(query, variables={"kind": "SHORT"}) assert not result.errors - assert result.data == {"pet": None} + expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} + assert result.data == expected result = schema.execute(query, variables={"kind": "LONG"}) assert not result.errors - assert result.data == expected, result.data + expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} + assert result.data == expected def test_should_node(session): @@ -218,10 +268,10 @@ class Query(graphene.ObjectType): article = graphene.Field(ArticleNode) all_articles = SQLAlchemyConnectionField(ArticleConnection) - def resolve_reporter(self, *args, **kwargs): + def resolve_reporter(self, _info): return session.query(Reporter).first() - def resolve_article(self, *args, **kwargs): + def resolve_article(self, _info): return session.query(Article).first() query = """ @@ -260,8 +310,8 @@ def resolve_article(self, *args, **kwargs): expected = { "reporter": { "id": "UmVwb3J0ZXJOb2RlOjE=", - "firstName": "ABA", - "lastName": "X", + "firstName": "John", + "lastName": "Doe", "email": None, "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, }, @@ -271,7 +321,8 @@ def resolve_article(self, *args, **kwargs): schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) assert not result.errors - assert result.data == expected + result = to_std_dicts(result.data) + assert result == expected def test_should_custom_identifier(session): @@ -308,14 +359,15 @@ class Query(graphene.ObjectType): } """ expected = { - "allEditors": {"edges": [{"node": {"id": "RWRpdG9yTm9kZTox", "name": "John"}}]}, - "node": {"name": "John"}, + "allEditors": {"edges": [{"node": {"id": "RWRpdG9yTm9kZTox", "name": "Jack"}}]}, + "node": {"name": "Jack"}, } schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) assert not result.errors - assert result.data == expected + result = to_std_dicts(result.data) + assert result == expected def test_should_mutate_well(session): @@ -385,7 +437,7 @@ class Mutation(graphene.ObjectType): "ok": True, "article": { "headline": "My Article", - "reporter": {"id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "ABA"}, + "reporter": {"id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "John"}, }, } } @@ -393,14 +445,15 @@ class Mutation(graphene.ObjectType): schema = graphene.Schema(query=Query, mutation=Mutation) result = schema.execute(query, context_value={"session": session}) assert not result.errors - assert result.data == expected + result = to_std_dicts(result.data) + assert result == expected def sort_setup(session): pets = [ - Pet(id=2, name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG), - Pet(id=22, name="Alf", pet_kind="cat", hair_kind=Hairkind.LONG), - Pet(id=3, name="Barf", pet_kind="dog", hair_kind=Hairkind.LONG), + Pet(id=2, name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG), + Pet(id=22, name="Alf", pet_kind="cat", hair_kind=HairKind.LONG), + Pet(id=3, name="Barf", pet_kind="dog", hair_kind=HairKind.LONG), ] session.add_all(pets) session.commit() @@ -424,10 +477,10 @@ class Query(graphene.ObjectType): multipleSort = SQLAlchemyConnectionField(PetConnection) descSort = SQLAlchemyConnectionField(PetConnection) singleColumnSort = SQLAlchemyConnectionField( - PetConnection, sort=graphene.Argument(sort_enum_for_model(Pet)) + PetConnection, sort=graphene.Argument(get_sort_enum_for_model(Pet)) ) noDefaultSort = SQLAlchemyConnectionField( - PetConnection, sort=sort_argument_for_model(Pet, False) + PetConnection, sort=get_sort_argument_for_model(Pet, False) ) noSort = SQLAlchemyConnectionField(PetConnection, sort=None) @@ -440,14 +493,14 @@ class Query(graphene.ObjectType): } } } - nameSort(sort: name_asc){ + nameSort(sort: NAME_ASC){ edges{ node{ name } } } - multipleSort(sort: [pet_kind_asc, name_desc]){ + multipleSort(sort: [PET_KIND_ASC, NAME_DESC]){ edges{ node{ name @@ -455,21 +508,21 @@ class Query(graphene.ObjectType): } } } - descSort(sort: [name_desc]){ + descSort(sort: [NAME_DESC]){ edges{ node{ name } } } - singleColumnSort(sort: name_desc){ + singleColumnSort(sort: NAME_DESC){ edges{ node{ name } } } - noDefaultSort(sort: name_asc){ + noDefaultSort(sort: NAME_ASC){ edges{ node{ name @@ -493,9 +546,9 @@ def makeNodes(nodeList): ), "multipleSort": makeNodes( [ - {"name": "Alf", "petKind": "cat"}, - {"name": "Lassie", "petKind": "dog"}, - {"name": "Barf", "petKind": "dog"}, + {"name": "Alf", "petKind": "CAT"}, + {"name": "Lassie", "petKind": "DOG"}, + {"name": "Barf", "petKind": "DOG"}, ] ), "descSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]), @@ -507,7 +560,8 @@ def makeNodes(nodeList): schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) assert not result.errors - assert result.data == expected + result = to_std_dicts(result.data) + assert result == expected queryError = """ query sortTest { @@ -555,3 +609,13 @@ def makeNodes(nodeList): assert set(node["node"]["name"] for node in value["edges"]) == set( node["node"]["name"] for node in expectedNoSort[key]["edges"] ) + + +def to_std_dicts(value): + """Convert nested ordered dicts to normal dicts for better comparison.""" + if isinstance(value, dict): + return {k: to_std_dicts(v) for k, v in value.items()} + elif isinstance(value, list): + return [to_std_dicts(v) for v in value] + else: + return value diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index 1945af6d..e94cc468 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -1,4 +1,7 @@ +from enum import Enum as PyEnum + import pytest +from sqlalchemy import Enum as SQLAlchemyEnum from ..registry import Registry from ..types import SQLAlchemyObjectType @@ -11,15 +14,15 @@ def test_register_incorrect_objecttype(): class Spam: pass - with pytest.raises(AssertionError) as excinfo: + with pytest.raises(TypeError) as exc_info: reg.register(Spam) assert "Only classes of type SQLAlchemyObjectType can be registered" in str( - excinfo.value + exc_info.value ) -def test_register_objecttype(): +def test_register_objecttype_twice(): reg = Registry() class PetType(SQLAlchemyObjectType): @@ -29,5 +32,87 @@ class Meta: try: reg.register(PetType) - except AssertionError: - pytest.fail("expected no AssertionError") + + class PetType2(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + reg.register(PetType2) + except TypeError: + pytest.fail("check not enabled, expected no TypeError") + + assert reg.get_type_for_model(Pet) is PetType2 + + +def test_register_objecttype_twice_with_check(): + reg = Registry(check_duplicate_registration=True) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + try: + reg.register(PetType) + except TypeError: + pytest.fail("same object type, expected no TypeError") + + assert reg.get_type_for_model(Pet) is PetType + + with pytest.raises(TypeError) as exc_info: + + # noinspection PyUnusedLocal + class PetType2(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + assert "Different object types registered for the same model" in str(exc_info.value) + + +def test_register_composite_converter(): + reg = Registry() + composite = object() + converter = len + reg.register_composite_converter(composite, converter) + reg.get_converter_for_composite(composite) is converter + + +def test_get_type_for_enum_from_list(): + reg = Registry() + sa_enum = SQLAlchemyEnum('red', 'blue', name='color_enum') + graphene_enum = reg.get_type_for_enum(sa_enum) + assert graphene_enum._meta.name == 'ColorEnum' + assert graphene_enum._meta.enum.__members__['RED'].value == 'red' + assert graphene_enum._meta.enum.__members__['BLUE'].value == 'blue' + try: + assert reg.get_type_for_enum(sa_enum) == graphene_enum + except TypeError: + pytest.fail("same enum, expected no TypeError") + sa_enum = SQLAlchemyEnum('red', 'green', name='color_enum') + with pytest.raises(TypeError) as exc_info: # different keys + reg.get_type_for_enum(sa_enum) + assert 'Different enums with the same name "ColorEnum"' in str(exc_info.value) + + +def test_get_type_for_enum_from_py_enum(): + reg = Registry() + py_enum = PyEnum('ColorEnum', 'red blue') + sa_enum = SQLAlchemyEnum(py_enum) + graphene_enum = reg.get_type_for_enum(sa_enum) + assert graphene_enum._meta.name == 'ColorEnum' + assert graphene_enum._meta.enum.__members__['RED'].value == 1 + assert graphene_enum._meta.enum.__members__['BLUE'].value == 2 + sa_enum = SQLAlchemyEnum('red', 'blue', name='color_enum') + with pytest.raises(TypeError) as exc_info: # different values + reg.get_type_for_enum(sa_enum) + assert 'Different enums with the same name "ColorEnum"' in str(exc_info.value) + + +def test_sort_params_for_model(): + reg = Registry() + model = object + sort_params = object() + reg.register_sort_params(model, sort_params) + assert reg.get_sort_params_for_model(model) is sort_params diff --git a/graphene_sqlalchemy/tests/test_schema.py b/graphene_sqlalchemy/tests/test_schema.py index 628da185..87739bdb 100644 --- a/graphene_sqlalchemy/tests/test_schema.py +++ b/graphene_sqlalchemy/tests/test_schema.py @@ -35,6 +35,7 @@ class Meta: "first_name", "last_name", "email", + "favorite_pet_kind", "pets", "articles", "favorite_article", diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 0360a644..b76136fb 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -57,6 +57,7 @@ def test_objecttype_registered(): "first_name", "last_name", "email", + "favorite_pet_kind", "pets", "articles", "favorite_article", @@ -124,6 +125,7 @@ def test_custom_objecttype_registered(): "first_name", "last_name", "email", + "favorite_pet_kind", "pets", "articles", "favorite_article", @@ -168,6 +170,7 @@ def test_objecttype_with_custom_options(): "first_name", "last_name", "email", + "favorite_pet_kind", "pets", "articles", "favorite_article", @@ -181,7 +184,7 @@ class TestConnection(Connection): class Meta: node = ReporterWithCustomOptions - def resolver(*args, **kwargs): + def resolver(_obj, _info): return Promise.resolve([]) result = SQLAlchemyConnectionField.connection_resolver( diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index a7b902fe..53e1f4d0 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -2,7 +2,9 @@ from graphene import Enum, List, ObjectType, Schema, String -from ..utils import get_session, sort_argument_for_model, sort_enum_for_model +from ..utils import (create_sort_enum_for_model, get_session, + get_sort_argument_for_model, get_sort_enum_for_model, + to_enum_value_name, to_type_name) from .models import Editor, Pet @@ -27,37 +29,65 @@ def resolve_x(self, info): assert result.data["x"] == session -def test_sort_enum_for_model(): - enum = sort_enum_for_model(Pet) +def test_to_type_name(): + assert to_type_name("make_camel_case") == "MakeCamelCase" + assert to_type_name("AlreadyCamelCase") == "AlreadyCamelCase" + assert to_type_name("A_Snake_and_a_Camel") == "ASnakeAndACamel" + + +def test_to_enum_value_name(): + assert to_enum_value_name("make_enum_value_name") == "MAKE_ENUM_VALUE_NAME" + assert to_enum_value_name("makeEnumValueName") == "MAKE_ENUM_VALUE_NAME" + assert to_enum_value_name("HTTPStatus400Message") == "HTTP_STATUS400_MESSAGE" + assert to_enum_value_name("ALREADY_ENUM_VALUE_NAME") == "ALREADY_ENUM_VALUE_NAME" + + +def test_get_sort_enum_for_model(): + enum = get_sort_enum_for_model(Pet) assert isinstance(enum, type(Enum)) assert str(enum) == "PetSortEnum" - for col in sa.inspect(Pet).columns: - assert hasattr(enum, col.name + "_asc") - assert hasattr(enum, col.name + "_desc") + expect_symbols = [] + for name in sa.inspect(Pet).columns.keys(): + name_asc = name.upper() + "_ASC" + name_desc = name.upper() + "_DESC" + expect_symbols.extend([name_asc, name_desc]) + # the order of enums is not preserved for Python < 3.6 + assert sorted(enum._meta.enum.__members__) == sorted(expect_symbols) def test_sort_enum_for_model_custom_naming(): - enum = sort_enum_for_model(Pet, "Foo", lambda n, d: n.upper() + ("A" if d else "D")) + enum, default = create_sort_enum_for_model( + Pet, "Foo", lambda n, d: ("a_" if d else "d_") + n + ) assert str(enum) == "Foo" - for col in sa.inspect(Pet).columns: - assert hasattr(enum, col.name.upper() + "A") - assert hasattr(enum, col.name.upper() + "D") + expect_symbols = [] + expect_default = [] + for col in sa.inspect(Pet).columns.values(): + name = col.name + name_asc = "a_" + name + name_desc = "d_" + name + expect_symbols.extend([name_asc, name_desc]) + if col.primary_key: + expect_default.append(name_asc) + # the order of enums is not preserved for Python < 3.6 + assert sorted(enum._meta.enum.__members__) == sorted(expect_symbols) + assert default == expect_default def test_enum_cache(): - assert sort_enum_for_model(Editor) is sort_enum_for_model(Editor) + assert get_sort_enum_for_model(Editor) is get_sort_enum_for_model(Editor) def test_sort_argument_for_model(): - arg = sort_argument_for_model(Pet) + arg = get_sort_argument_for_model(Pet) assert isinstance(arg.type, List) - assert arg.default_value == [Pet.id.name + "_asc"] - assert arg.type.of_type == sort_enum_for_model(Pet) + assert arg.default_value == [Pet.id.name.upper() + "_ASC"] + assert arg.type.of_type == get_sort_enum_for_model(Pet) def test_sort_argument_for_model_no_default(): - arg = sort_argument_for_model(Pet, False) + arg = get_sort_argument_for_model(Pet, False) assert arg.default_value is None @@ -70,7 +100,7 @@ class MultiplePK(Base): bar = sa.Column(sa.Integer, primary_key=True) __tablename__ = "MultiplePK" - arg = sort_argument_for_model(MultiplePK) + arg = get_sort_argument_for_model(MultiplePK) assert set(arg.default_value) == set( - (MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc") + (MultiplePK.foo.name.upper() + "_ASC", MultiplePK.bar.name.upper() + "_ASC") ) diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 276a8075..df96f043 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,9 +1,13 @@ +import re +import warnings +from collections import OrderedDict + from sqlalchemy.exc import ArgumentError from sqlalchemy.inspection import inspect from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError -from graphene import Argument, Enum, List +from graphene import Argument, List def get_session(context): @@ -41,7 +45,27 @@ def is_mapped_instance(cls): return True -def _symbol_name(column_name, is_asc): +def to_type_name(name): + """Convert the given name to a GraphQL type name.""" + return "".join(part[:1].upper() + part[1:] for part in name.split("_")) + + +_re_enum_value_name_1 = re.compile("(.)([A-Z][a-z]+)") +_re_enum_value_name_2 = re.compile("([a-z0-9])([A-Z])") + + +def to_enum_value_name(name): + """Convert the given name to a GraphQL enum value name.""" + return _re_enum_value_name_2.sub( + r"\1_\2", _re_enum_value_name_1.sub(r"\1_\2", name) + ).upper() + + +def default_symbol_name(column_name, is_asc): + return to_enum_value_name(column_name) + ("_ASC" if is_asc else "_DESC") + + +def plain_symbol_name(column_name, is_asc): # pragma: no cover return column_name + ("_asc" if is_asc else "_desc") @@ -56,55 +80,107 @@ def __init__(self, str_value, value): self.value = value -# Cache for the generated enums, to avoid name clash -_ENUM_CACHE = {} +def create_sort_enum_for_model( + cls, name=None, symbol_name=default_symbol_name, registry=None +): + """Create a Graphene Enum type for defining a sort order for the given model class. + The created Enum type and sort order will then be registered for that class. -def _sort_enum_for_model(cls, name=None, symbol_name=_symbol_name): - name = name or cls.__name__ + "SortEnum" - if name in _ENUM_CACHE: - return _ENUM_CACHE[name] - items = [] - default = [] + Parameters + - cls : SQLAlchemy model class + Model used to create the sort enumerator type + - name : str, optional, default None + Name to use for the enumerator. If not provided it will be set to the name + of the class with a 'SortEnum' postfix + - symbol_name : function, optional, default `default_symbol_name` + Function which takes the column name and a boolean indicating if the sort + direction is ascending, and returns the enum symbol name for the current column + and sort direction. The default function will create, for a column named 'foo', + the symbols 'FOO_ASC' and 'FOO_DESC'. + - registry: if not specified, the global registry will be used + Returns + - tuple with the Graphene Enum type and the default sort argument for the model + """ + if not name: + name = cls.__name__ + "SortEnum" + if registry is None: + from .registry import get_global_registry + registry = get_global_registry() + members = OrderedDict() + default_sort = [] for column in inspect(cls).columns.values(): asc_name = symbol_name(column.name, True) asc_value = EnumValue(asc_name, column.asc()) + members[asc_name] = asc_value + if column.primary_key: + default_sort.append(asc_value) desc_name = symbol_name(column.name, False) desc_value = EnumValue(desc_name, column.desc()) - if column.primary_key: - default.append(asc_value) - items.extend(((asc_name, asc_value), (desc_name, desc_value))) - enum = Enum(name, items) - _ENUM_CACHE[name] = (enum, default) - return enum, default + members[desc_name] = desc_value + graphene_enum = registry.register_enum(name, members) + registry.register_sort_params(graphene_enum, default_sort) + return graphene_enum, default_sort -def sort_enum_for_model(cls, name=None, symbol_name=_symbol_name): - """Create Graphene Enum for sorting a SQLAlchemy class query +def get_sort_enum_for_model(cls, registry=None): + """Get the Graphene Enum type for defining a sort order for the given model class. - Parameters - - cls : Sqlalchemy model class - Model used to create the sort enumerator - - name : str, optional, default None - Name to use for the enumerator. If not provided it will be set to `cls.__name__ + 'SortEnum'` - - symbol_name : function, optional, default `_symbol_name` - Function which takes the column name and a boolean indicating if the sort direction is ascending, - and returns the symbol name for the current column and sort direction. - The default function will create, for a column named 'foo', the symbols 'foo_asc' and 'foo_desc' + If no Enum type has been registered, create a default one and register it. + Parameters + - cls : SQLAlchemy model class + - registry: if not specified, the global registry will be used Returns - - Enum - The Graphene enumerator + - The Graphene Enum type """ - enum, _ = _sort_enum_for_model(cls, name, symbol_name) - return enum - - -def sort_argument_for_model(cls, has_default=True): - """Returns a Graphene argument for the sort field that accepts a list of sorting directions for a model. - If `has_default` is True (the default) it will sort the result by the primary key(s) + if registry is None: + from .registry import get_global_registry + registry = get_global_registry() + sort_params = registry.get_sort_params_for_model(cls) + if not sort_params: + sort_params = create_sort_enum_for_model(cls, registry=registry) + return sort_params[0] + + +def sort_enum_for_model( + cls, name=None, symbol_name=plain_symbol_name +): # pragma: no cover + warnings.warn( + "sort_argument_for_model() is deprecated;" + " use get_sort_argument_for_model() and create_sort_argument_for_model()", + DeprecationWarning, + stacklevel=2, + ) + if not name and not symbol_name: + return get_sort_enum_for_model(cls) + sort_params = create_sort_enum_for_model(cls, name, symbol_name) + return sort_params[0] + + +def get_sort_argument_for_model(cls, has_default=True, registry=None): + """Returns a Graphene Argument for defining a sort order for the given model class. + + The Argument that is returned accepts a list of sorting directions for the model. + If `has_default` is set to False, no sorting will happen when this argument is not + passed. Otherwise results will be sortied by the primary key(s) of the model. """ - enum, default = _sort_enum_for_model(cls) + if registry is None: + from .registry import get_global_registry + registry = get_global_registry() + sort_params = registry.get_sort_params_for_model(cls) + if not sort_params: + sort_params = create_sort_enum_for_model(cls, registry=registry) + enum, default = sort_params if not has_default: default = None return Argument(List(enum), default_value=default) + + +def sort_argument_for_model(cls, has_default=True): # pragma: no cover + warnings.warn( + "sort_argument_for_model() is deprecated; use get_sort_argument_for_model().", + DeprecationWarning, + stacklevel=2, + ) + return get_sort_argument_for_model(cls, has_default=has_default) diff --git a/setup.cfg b/setup.cfg index 7fd23df6..39a48fd2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ max-line-length = 120 [isort] known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme known_first_party=graphene_sqlalchemy -known_third_party=flask,nameko,promise,py,pytest,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils +known_third_party=database,flask,models,nameko,promise,py,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER no_lines_before=FIRSTPARTY