Skip to content

Commit 200af1a

Browse files
committed
Allow subclassing of Enums
1 parent 661eb63 commit 200af1a

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

django_enumfield/enum.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class EnumType(type):
1515
def __new__(mcs, *args):
1616
""" Create enum values from all uppercase class attributes and store them in a dict on the Enum class."""
1717
enum = super(EnumType, mcs).__new__(mcs, *args)
18-
attributes = [k_v for k_v in list(enum.__dict__.items()) if k_v[0].isupper()]
18+
attributes = [(k, getattr(enum, k)) for k in dir(enum) if k.isupper()]
1919
labels = enum.__dict__.get('labels', {})
2020

2121
enum.values = {}

django_enumfield/tests/models.py

+14
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,20 @@ class LampState(Enum):
99
OFF = 0
1010
ON = 1
1111

12+
labels = {
13+
OFF: 'Off',
14+
ON: 'On'
15+
}
16+
17+
18+
class DimmableLampState(LampState):
19+
DIMMED = 2
20+
21+
labels = LampState.labels.copy()
22+
labels.update({
23+
DIMMED: 'Dimmed'
24+
})
25+
1226

1327
class Lamp(models.Model):
1428
state = EnumField(LampState)

django_enumfield/tests/test_enum.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django_enumfield.db.fields import EnumField
88
from django_enumfield.enum import Enum
99
from django_enumfield.exceptions import InvalidStatusOperationError
10-
from django_enumfield.tests.models import Person, PersonStatus, Lamp, LampState, Beer, BeerStyle, BeerState, LabelBeer
10+
from django_enumfield.tests.models import Person, PersonStatus, Lamp, LampState, DimmableLampState, Beer, BeerStyle, BeerState, LabelBeer
1111

1212

1313
class EnumFieldTest(TestCase):
@@ -137,6 +137,14 @@ def test_choices(self):
137137
self.assertEqual(len(PersonStatus.choices()), len(list(PersonStatus.items())))
138138
self.assertTrue(all(key in PersonStatus.__dict__ for key in dict(list(PersonStatus.items()))))
139139

140+
def test_subclass_choices(self):
141+
self.assertEqual(len(DimmableLampState.choices()), len(list(DimmableLampState.items())))
142+
self.assertTrue(all(key in dir(DimmableLampState) for key in dict(list(DimmableLampState.items()))))
143+
self.assertTrue(all(key in dir(DimmableLampState) for key in dict(list(LampState.items()))))
144+
for key in dict(list(DimmableLampState.items())):
145+
lamp = DimmableLampState.get(key)
146+
self.assertTrue(lamp.value in DimmableLampState.labels, 'key %s not found' % lamp)
147+
140148
def test_default(self):
141149
self.assertEqual(PersonStatus.default(), PersonStatus.UNBORN)
142150

0 commit comments

Comments
 (0)