From ff99248c8bf831277ca24bcc7e992e6308287fdd Mon Sep 17 00:00:00 2001 From: Sergey Fursov Date: Thu, 26 Dec 2024 11:47:49 +0300 Subject: [PATCH] fix typed choices, make working with different Django 5x choices options --- graphene_django/compat.py | 24 ++++++- graphene_django/converter.py | 33 +++++---- graphene_django/forms/types.py | 7 +- graphene_django/tests/models.py | 44 ++++++++++++ graphene_django/tests/test_converter.py | 95 +++++++++++++++++++++---- graphene_django/tests/test_schema.py | 3 + graphene_django/tests/test_types.py | 33 +++++++++ 7 files changed, 208 insertions(+), 31 deletions(-) diff --git a/graphene_django/compat.py b/graphene_django/compat.py index b3d160a13..a2fc1f022 100644 --- a/graphene_django/compat.py +++ b/graphene_django/compat.py @@ -1,10 +1,12 @@ import sys +from collections.abc import Callable from pathlib import PurePath # For backwards compatibility, we import JSONField to have it available for import via # this compat module (https://github.com/graphql-python/graphene-django/issues/1428). # Django's JSONField is available in Django 3.2+ (the minimum version we support) -from django.db.models import JSONField +import django +from django.db.models import Choices, JSONField class MissingType: @@ -42,3 +44,23 @@ def __init__(self, *args, **kwargs): else: ArrayField = MissingType + + +try: + from django.utils.choices import normalize_choices +except ImportError: + + def normalize_choices(choices): + if isinstance(choices, type) and issubclass(choices, Choices): + choices = choices.choices + + if isinstance(choices, Callable): + choices = choices() + + # In restframework==3.15.0, choices are not passed + # as OrderedDict anymore, so it's safer to check + # for a dict + if isinstance(choices, dict): + choices = choices.items() + + return choices diff --git a/graphene_django/converter.py b/graphene_django/converter.py index 7eba22a1d..4e458f185 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -1,5 +1,4 @@ import inspect -from collections.abc import Callable from functools import partial, singledispatch, wraps from django.db import models @@ -37,7 +36,7 @@ from graphql import assert_valid_name as assert_name from graphql.pyutils import register_description -from .compat import ArrayField, HStoreField, RangeField +from .compat import ArrayField, HStoreField, RangeField, normalize_choices from .fields import DjangoConnectionField, DjangoListField from .settings import graphene_settings from .utils.str_converters import to_const @@ -61,6 +60,24 @@ def wrapped_resolver(*args, **kwargs): return blank_field_wrapper(resolver) +class EnumValueField(BlankValueField): + def wrap_resolve(self, parent_resolver): + resolver = super().wrap_resolve(parent_resolver) + + # create custom resolver + def enum_field_wrapper(func): + @wraps(func) + def wrapped_resolver(*args, **kwargs): + return_value = func(*args, **kwargs) + if isinstance(return_value, models.Choices): + return_value = return_value.value + return return_value + + return wrapped_resolver + + return enum_field_wrapper(resolver) + + def convert_choice_name(name): name = to_const(force_str(name)) try: @@ -72,15 +89,7 @@ def convert_choice_name(name): def get_choices(choices): converted_names = [] - if isinstance(choices, Callable): - choices = choices() - - # In restframework==3.15.0, choices are not passed - # as OrderedDict anymore, so it's safer to check - # for a dict - if isinstance(choices, dict): - choices = choices.items() - + choices = normalize_choices(choices) for value, help_text in choices: if isinstance(help_text, (tuple, list)): yield from get_choices(help_text) @@ -157,7 +166,7 @@ def convert_django_field_with_choices( converted = EnumCls( description=get_django_field_description(field), required=required - ).mount_as(BlankValueField) + ).mount_as(EnumValueField) else: converted = convert_django_field(field, registry) if registry is not None: diff --git a/graphene_django/forms/types.py b/graphene_django/forms/types.py index 0e311e5d6..68ffa6635 100644 --- a/graphene_django/forms/types.py +++ b/graphene_django/forms/types.py @@ -3,7 +3,7 @@ from graphene.types.inputobjecttype import InputObjectType from graphene.utils.str_converters import to_camel_case -from ..converter import BlankValueField +from ..converter import EnumValueField from ..types import ErrorType # noqa Import ErrorType for backwards compatibility from .mutation import fields_for_form @@ -57,11 +57,10 @@ def mutate(_root, _info, data): if ( object_type and name in object_type._meta.fields - and isinstance(object_type._meta.fields[name], BlankValueField) + and isinstance(object_type._meta.fields[name], EnumValueField) ): - # Field type BlankValueField here means that field + # Field type EnumValueField here means that field # with choices have been converted to enum - # (BlankValueField is using only for that task ?) setattr(cls, name, cls.get_enum_cnv_cls_instance(name, object_type)) elif ( object_type diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index ece1bb6db..d190a9b94 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -1,9 +1,38 @@ +import django from django.db import models from django.utils.translation import gettext_lazy as _ CHOICES = ((1, "this"), (2, _("that"))) +def get_choices_as_class(choices_class): + if django.VERSION >= (5, 0): + return choices_class + else: + return choices_class.choices + + +def get_choices_as_callable(choices_class): + if django.VERSION >= (5, 0): + + def choices(): + return choices_class.choices + + return choices + else: + return choices_class.choices + + +class TypedIntChoice(models.IntegerChoices): + CHOICE_THIS = 1 + CHOICE_THAT = 2 + + +class TypedStrChoice(models.TextChoices): + CHOICE_THIS = "this" + CHOICE_THAT = "that" + + class Person(models.Model): name = models.CharField(max_length=30) parent = models.ForeignKey( @@ -51,6 +80,21 @@ class Reporter(models.Model): email = models.EmailField() pets = models.ManyToManyField("self") a_choice = models.IntegerField(choices=CHOICES, null=True, blank=True) + typed_choice = models.IntegerField( + choices=TypedIntChoice.choices, + null=True, + blank=True, + ) + class_choice = models.IntegerField( + choices=get_choices_as_class(TypedIntChoice), + null=True, + blank=True, + ) + callable_choice = models.IntegerField( + choices=get_choices_as_callable(TypedStrChoice), + null=True, + blank=True, + ) objects = models.Manager() doe_objects = DoeReporterManager() fans = models.ManyToManyField(Person) diff --git a/graphene_django/tests/test_converter.py b/graphene_django/tests/test_converter.py index 2f8b1d515..1499348fe 100644 --- a/graphene_django/tests/test_converter.py +++ b/graphene_django/tests/test_converter.py @@ -25,7 +25,7 @@ ) from ..registry import Registry from ..types import DjangoObjectType -from .models import Article, Film, FilmDetails, Reporter +from .models import Article, Film, FilmDetails, Reporter, TypedIntChoice, TypedStrChoice # from graphene.core.types.custom_scalars import DateTime, Time, JSONString @@ -443,35 +443,102 @@ def test_choice_enum_blank_value(): class ReporterType(DjangoObjectType): class Meta: model = Reporter - fields = ( - "first_name", - "a_choice", - ) + fields = ("callable_choice",) class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) def resolve_reporter(root, info): - return Reporter.objects.first() + # return a model instance with blank choice field value + return Reporter(callable_choice="") schema = graphene.Schema(query=Query) - # Create model with empty choice option - Reporter.objects.create( - first_name="Bridget", last_name="Jones", email="bridget@example.com" - ) - result = schema.execute( """ query { reporter { - firstName - aChoice + callableChoice } } """ ) assert not result.errors assert result.data == { - "reporter": {"firstName": "Bridget", "aChoice": None}, + "reporter": {"callableChoice": None}, + } + + +def test_typed_choice_value(): + """Test that typed choices fields are resolved correctly to the enum values""" + + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + fields = ("typed_choice", "class_choice", "callable_choice") + + class Query(graphene.ObjectType): + reporter = graphene.Field(ReporterType) + + def resolve_reporter(root, info): + # assign choice values to the fields instead of their str or int values + return Reporter( + typed_choice=TypedIntChoice.CHOICE_THIS, + class_choice=TypedIntChoice.CHOICE_THAT, + callable_choice=TypedStrChoice.CHOICE_THIS, + ) + + class CreateReporter(graphene.Mutation): + reporter = graphene.Field(ReporterType) + + def mutate(root, info, **kwargs): + return CreateReporter( + reporter=Reporter( + typed_choice=TypedIntChoice.CHOICE_THIS, + class_choice=TypedIntChoice.CHOICE_THAT, + callable_choice=TypedStrChoice.CHOICE_THIS, + ), + ) + + class Mutation(graphene.ObjectType): + create_reporter = CreateReporter.Field() + + schema = graphene.Schema(query=Query, mutation=Mutation) + + reporter_fragment = """ + fragment reporter on ReporterType { + typedChoice + classChoice + callableChoice + } + """ + + expected_reporter = { + "typedChoice": "A_1", + "classChoice": "A_2", + "callableChoice": "THIS", } + + result = schema.execute( + reporter_fragment + + """ + query { + reporter { ...reporter } + } + """ + ) + assert not result.errors + assert result.data["reporter"] == expected_reporter + + result = schema.execute( + reporter_fragment + + """ + mutation { + createReporter { + reporter { ...reporter } + } + } + """ + ) + assert not result.errors + assert result.data["createReporter"]["reporter"] == expected_reporter diff --git a/graphene_django/tests/test_schema.py b/graphene_django/tests/test_schema.py index 93cbd9f05..009211294 100644 --- a/graphene_django/tests/test_schema.py +++ b/graphene_django/tests/test_schema.py @@ -40,6 +40,9 @@ class Meta: "email", "pets", "a_choice", + "typed_choice", + "class_choice", + "callable_choice", "fans", "reporter_type", ] diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index 72514d23b..0491bcd30 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -77,6 +77,9 @@ def test_django_objecttype_map_correct_fields(): "email", "pets", "a_choice", + "typed_choice", + "class_choice", + "callable_choice", "fans", "reporter_type", ] @@ -186,6 +189,9 @@ def test_schema_representation(): email: String! pets: [Reporter!]! aChoice: TestsReporterAChoiceChoices + typedChoice: TestsReporterTypedChoiceChoices + classChoice: TestsReporterClassChoiceChoices + callableChoice: TestsReporterCallableChoiceChoices reporterType: TestsReporterReporterTypeChoices articles(offset: Int, before: String, after: String, first: Int, last: Int): ArticleConnection! } @@ -199,6 +205,33 @@ def test_schema_representation(): A_2 } + \"""An enumeration.\""" + enum TestsReporterTypedChoiceChoices { + \"""Choice This\""" + A_1 + + \"""Choice That\""" + A_2 + } + + \"""An enumeration.\""" + enum TestsReporterClassChoiceChoices { + \"""Choice This\""" + A_1 + + \"""Choice That\""" + A_2 + } + + \"""An enumeration.\""" + enum TestsReporterCallableChoiceChoices { + \"""Choice This\""" + THIS + + \"""Choice That\""" + THAT + } + \"""An enumeration.\""" enum TestsReporterReporterTypeChoices { \"""Regular\"""