diff --git a/django_enumfield/enum.py b/django_enumfield/enum.py index 48b0218..d63b51a 100644 --- a/django_enumfield/enum.py +++ b/django_enumfield/enum.py @@ -15,7 +15,7 @@ class EnumType(type): def __new__(mcs, *args): """ Create enum values from all uppercase class attributes and store them in a dict on the Enum class.""" enum = super(EnumType, mcs).__new__(mcs, *args) - attributes = [k_v for k_v in list(enum.__dict__.items()) if k_v[0].isupper()] + attributes = [(k, getattr(enum, k)) for k in dir(enum) if k.isupper()] labels = enum.__dict__.get('labels', {}) enum.values = {} diff --git a/django_enumfield/tests/models.py b/django_enumfield/tests/models.py index 9cbded3..7d1ef90 100644 --- a/django_enumfield/tests/models.py +++ b/django_enumfield/tests/models.py @@ -9,6 +9,20 @@ class LampState(Enum): OFF = 0 ON = 1 + labels = { + OFF: 'Off', + ON: 'On' + } + + +class DimmableLampState(LampState): + DIMMED = 2 + + labels = LampState.labels.copy() + labels.update({ + DIMMED: 'Dimmed' + }) + class Lamp(models.Model): state = EnumField(LampState) diff --git a/django_enumfield/tests/test_enum.py b/django_enumfield/tests/test_enum.py index 1ffb56f..0aaf0b4 100644 --- a/django_enumfield/tests/test_enum.py +++ b/django_enumfield/tests/test_enum.py @@ -7,7 +7,7 @@ from django_enumfield.db.fields import EnumField from django_enumfield.enum import Enum from django_enumfield.exceptions import InvalidStatusOperationError -from django_enumfield.tests.models import Person, PersonStatus, Lamp, LampState, Beer, BeerStyle, BeerState, LabelBeer +from django_enumfield.tests.models import Person, PersonStatus, Lamp, LampState, DimmableLampState, Beer, BeerStyle, BeerState, LabelBeer class EnumFieldTest(TestCase): @@ -137,6 +137,14 @@ def test_choices(self): self.assertEqual(len(PersonStatus.choices()), len(list(PersonStatus.items()))) self.assertTrue(all(key in PersonStatus.__dict__ for key in dict(list(PersonStatus.items())))) + def test_subclass_choices(self): + self.assertEqual(len(DimmableLampState.choices()), len(list(DimmableLampState.items()))) + self.assertTrue(all(key in dir(DimmableLampState) for key in dict(list(DimmableLampState.items())))) + self.assertTrue(all(key in dir(DimmableLampState) for key in dict(list(LampState.items())))) + for key in dict(list(DimmableLampState.items())): + lamp = DimmableLampState.get(key) + self.assertTrue(lamp.value in DimmableLampState.labels, 'key %s not found' % lamp) + def test_default(self): self.assertEqual(PersonStatus.default(), PersonStatus.UNBORN)