Skip to content

Commit 510f78e

Browse files
authored
Django documentation updates and swappable model improvements (#2422)
* Allow use of non-default database * Allow custom models to be specified * Allow swappable models * Switch to in-memory db for django tests * Test database parameter * Improve introspection * Update documentation * Update documentation
1 parent 2cf0b20 commit 510f78e

File tree

10 files changed

+880
-39
lines changed

10 files changed

+880
-39
lines changed

chatterbot/ext/django_chatterbot/abstract_models.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,15 @@
33
from django.db import models
44
from django.utils import timezone
55
from django.conf import settings
6+
from django.apps import apps
67

78

89
DJANGO_APP_NAME = constants.DEFAULT_DJANGO_APP_NAME
9-
STATEMENT_MODEL = 'Statement'
10-
TAG_MODEL = 'Tag'
1110

12-
if hasattr(settings, 'CHATTERBOT'):
13-
"""
14-
Allow related models to be overridden in the project settings.
15-
Default to the original settings if one is not defined.
16-
"""
17-
DJANGO_APP_NAME = settings.CHATTERBOT.get(
18-
'django_app_name',
19-
DJANGO_APP_NAME
20-
)
21-
STATEMENT_MODEL = settings.CHATTERBOT.get(
22-
'statement_model',
23-
STATEMENT_MODEL
24-
)
11+
# Default model paths for swappable models
12+
# These can be overridden via CHATTERBOT_STATEMENT_MODEL and CHATTERBOT_TAG_MODEL settings
13+
DEFAULT_STATEMENT_MODEL = f'{DJANGO_APP_NAME}.Statement'
14+
DEFAULT_TAG_MODEL = f'{DJANGO_APP_NAME}.Tag'
2515

2616

2717
class AbstractBaseTag(models.Model):
@@ -88,7 +78,9 @@ class AbstractBaseStatement(models.Model, StatementMixin):
8878
)
8979

9080
tags = models.ManyToManyField(
91-
TAG_MODEL,
81+
settings.CHATTERBOT_TAG_MODEL if hasattr(
82+
settings, 'CHATTERBOT_TAG_MODEL'
83+
) else DEFAULT_TAG_MODEL,
9284
related_name='statements',
9385
help_text='The tags that are associated with this statement.'
9486
)
@@ -117,17 +109,51 @@ def __str__(self):
117109
return self.text
118110
return '<empty>'
119111

112+
@classmethod
113+
def get_tag_model(cls):
114+
"""
115+
Return the Tag model class, respecting the swappable setting.
116+
117+
This method checks:
118+
1. Django settings (CHATTERBOT_TAG_MODEL) - project-wide configuration
119+
2. The model referenced by the 'tags' field - handles custom models via kwargs
120+
3. Falls back to DEFAULT_TAG_MODEL if introspection fails
121+
122+
This ensures the correct Tag model is used even when custom models
123+
are specified via storage adapter kwargs rather than Django settings.
124+
"""
125+
tag_model_path = getattr(settings, 'CHATTERBOT_TAG_MODEL', None)
126+
127+
if tag_model_path:
128+
return apps.get_model(tag_model_path)
129+
130+
# If no setting, infer from the ManyToManyField relationship for
131+
# cases where custom models are specified via kwargs
132+
try:
133+
# Get the model that this class's 'tags' field points to
134+
tags_field = cls._meta.get_field('tags')
135+
related_model = tags_field.related_model
136+
137+
# Resolve strings (lazy references)
138+
if isinstance(related_model, str):
139+
return apps.get_model(related_model)
140+
return related_model
141+
except Exception:
142+
# Fallback to default if introspection fails
143+
return apps.get_model(DEFAULT_TAG_MODEL)
144+
120145
def get_tags(self) -> list[str]:
121146
"""
122147
Return the list of tags for this statement.
123-
(Overrides the method from StatementMixin)
124148
"""
125149
return list(self.tags.values_list('name', flat=True))
126150

127151
def add_tags(self, *tags):
128152
"""
129153
Add a list of strings to the statement as tags.
130-
(Overrides the method from StatementMixin)
131154
"""
132-
for _tag in tags:
133-
self.tags.get_or_create(name=_tag)
155+
TagModel = self.get_tag_model()
156+
157+
for tag_name in tags:
158+
tag_obj, _created = TagModel.objects.get_or_create(name=tag_name)
159+
self.tags.add(tag_obj)

chatterbot/ext/django_chatterbot/models.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,22 @@ class Statement(AbstractBaseStatement):
55
"""
66
A statement represents a single spoken entity, sentence or
77
phrase that someone can say.
8+
9+
This model can be swapped for a custom model by setting
10+
CHATTERBOT_STATEMENT_MODEL in your Django settings.
811
"""
9-
pass
12+
13+
class Meta:
14+
swappable = 'CHATTERBOT_STATEMENT_MODEL'
1015

1116

1217
class Tag(AbstractBaseTag):
1318
"""
1419
A label that categorizes a statement.
20+
21+
This model can be swapped for a custom model by setting
22+
CHATTERBOT_TAG_MODEL in your Django settings.
1523
"""
16-
pass
24+
25+
class Meta:
26+
swappable = 'CHATTERBOT_TAG_MODEL'

chatterbot/storage/django_storage.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,56 @@ class DjangoStorageAdapter(StorageAdapter):
66
"""
77
Storage adapter that allows ChatterBot to interact with
88
Django storage backends.
9+
10+
:param database: The Django database alias to use (default: 'default')
11+
:type database: str
12+
:param statement_model: The Statement model to use (default: reads from CHATTERBOT_STATEMENT_MODEL setting)
13+
:type statement_model: str
14+
:param tag_model: The Tag model to use (default: reads from CHATTERBOT_TAG_MODEL setting)
15+
:type tag_model: str
916
"""
1017

1118
def __init__(self, **kwargs):
1219
super().__init__(**kwargs)
20+
from django.conf import settings
1321

1422
self.django_app_name = kwargs.get(
1523
'django_app_name',
1624
constants.DEFAULT_DJANGO_APP_NAME
1725
)
1826

27+
self.database = kwargs.get('database', 'default')
28+
29+
# Support custom models via kwargs or Django settings
30+
self.statement_model = kwargs.get(
31+
'statement_model',
32+
getattr(
33+
settings,
34+
'CHATTERBOT_STATEMENT_MODEL',
35+
f'{self.django_app_name}.Statement'
36+
)
37+
)
38+
39+
self.tag_model = kwargs.get(
40+
'tag_model',
41+
getattr(
42+
settings,
43+
'CHATTERBOT_TAG_MODEL',
44+
f'{self.django_app_name}.Tag'
45+
)
46+
)
47+
1948
def get_statement_model(self):
2049
from django.apps import apps
21-
return apps.get_model(self.django_app_name, 'Statement')
50+
return apps.get_model(self.statement_model)
2251

2352
def get_tag_model(self):
2453
from django.apps import apps
25-
return apps.get_model(self.django_app_name, 'Tag')
54+
return apps.get_model(self.tag_model)
2655

2756
def count(self) -> int:
2857
Statement = self.get_model('statement')
29-
return Statement.objects.count()
58+
return Statement.objects.using(self.database).count()
3059

3160
def filter(self, **kwargs):
3261
"""
@@ -53,7 +82,7 @@ def filter(self, **kwargs):
5382
if tags:
5483
kwargs['tags__name__in'] = tags
5584

56-
statements = Statement.objects.filter(**kwargs)
85+
statements = Statement.objects.using(self.database).filter(**kwargs)
5786

5887
if exclude_text:
5988
statements = statements.exclude(
@@ -115,12 +144,12 @@ def create(self, **kwargs):
115144

116145
statement = Statement(**kwargs)
117146

118-
statement.save()
147+
statement.save(using=self.database)
119148

120149
tags_to_add = []
121150

122151
for _tag in tags:
123-
tag, _ = Tag.objects.get_or_create(name=_tag)
152+
tag, _ = Tag.objects.using(self.database).get_or_create(name=_tag)
124153
tags_to_add.append(tag)
125154

126155
statement.tags.add(*tags_to_add)
@@ -143,15 +172,15 @@ def create_many(self, statements):
143172

144173
statement_model_object = Statement(**statement_data)
145174

146-
statement_model_object.save()
175+
statement_model_object.save(using=self.database)
147176

148177
tags_to_add = []
149178

150179
for tag_name in tag_data:
151180
if tag_name in tag_cache:
152181
tag = tag_cache[tag_name]
153182
else:
154-
tag, _ = Tag.objects.get_or_create(name=tag_name)
183+
tag, _ = Tag.objects.using(self.database).get_or_create(name=tag_name)
155184
tag_cache[tag_name] = tag
156185
tags_to_add.append(tag)
157186

@@ -165,9 +194,9 @@ def update(self, statement):
165194
Tag = self.get_model('tag')
166195

167196
if hasattr(statement, 'id'):
168-
statement.save()
197+
statement.save(using=self.database)
169198
else:
170-
statement = Statement.objects.create(
199+
statement = Statement.objects.using(self.database).create(
171200
text=statement.text,
172201
search_text=statement.search_text,
173202
conversation=statement.conversation,
@@ -177,7 +206,7 @@ def update(self, statement):
177206
)
178207

179208
for _tag in statement.tags.all():
180-
tag, _ = Tag.objects.get_or_create(name=_tag)
209+
tag, _ = Tag.objects.using(self.database).get_or_create(name=_tag)
181210

182211
statement.tags.add(tag)
183212

@@ -189,7 +218,7 @@ def get_random(self):
189218
"""
190219
Statement = self.get_model('statement')
191220

192-
statement = Statement.objects.order_by('?').first()
221+
statement = Statement.objects.using(self.database).order_by('?').first()
193222

194223
if statement is None:
195224
raise self.EmptyDatabaseException()
@@ -204,7 +233,7 @@ def remove(self, statement_text):
204233
"""
205234
Statement = self.get_model('statement')
206235

207-
statements = Statement.objects.filter(text=statement_text)
236+
statements = Statement.objects.using(self.database).filter(text=statement_text)
208237

209238
statements.delete()
210239

@@ -215,5 +244,5 @@ def drop(self):
215244
Statement = self.get_model('statement')
216245
Tag = self.get_model('tag')
217246

218-
Statement.objects.all().delete()
219-
Tag.objects.all().delete()
247+
Statement.objects.using(self.database).all().delete()
248+
Tag.objects.using(self.database).all().delete()

0 commit comments

Comments
 (0)