Skip to content

Commit 130c5a0

Browse files
authored
Merge pull request #1540 from the-deep/feature/assisted-tagging-with-llm
LLM Assisted Tagging
2 parents 2fdf67c + f773779 commit 130c5a0

File tree

17 files changed

+438
-153
lines changed

17 files changed

+438
-153
lines changed

apps/assisted_tagging/admin.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,16 @@
22
from admin_auto_filters.filters import AutocompleteFilterFactory
33
from django.contrib import admin
44

5-
from assisted_tagging.models import AssistedTaggingModelPredictionTag, AssistedTaggingPrediction, DraftEntry
5+
from assisted_tagging.models import (
6+
AssistedTaggingModelPredictionTag,
7+
AssistedTaggingPrediction,
8+
DraftEntry,
9+
LLMAssistedTaggingPredication
10+
)
611
from deep.admin import VersionAdmin
712

13+
admin.site.register(LLMAssistedTaggingPredication)
14+
815

916
@admin.register(DraftEntry)
1017
class DraftEntryAdmin(VersionAdmin):

apps/assisted_tagging/dataloaders.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from django.utils.functional import cached_property
55

6-
from assisted_tagging.models import AssistedTaggingPrediction
6+
from assisted_tagging.models import AssistedTaggingPrediction, LLMAssistedTaggingPredication
77

88
from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin
99

@@ -18,7 +18,21 @@ def batch_load_fn(self, keys):
1818
return Promise.resolve([_map.get(key, []) for key in keys])
1919

2020

21+
class LLMDraftEntryPredicationsLoader(DataLoaderWithContext):
22+
def batch_load_fn(self, keys):
23+
llm_assisted_tagging_qs = LLMAssistedTaggingPredication.objects.filter(draft_entry_id__in=keys)
24+
_map = {
25+
assisted_tagging.draft_entry_id: assisted_tagging
26+
for assisted_tagging in llm_assisted_tagging_qs
27+
}
28+
return Promise.resolve([_map.get(key) for key in keys])
29+
30+
2131
class DataLoaders(WithContextMixin):
2232
@cached_property
2333
def draft_entry_predications(self):
2434
return DraftEntryPredicationsLoader(context=self.context)
35+
36+
@cached_property
37+
def llm_draft_entry_predications(self):
38+
return LLMDraftEntryPredicationsLoader(context=self.context)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Generated by Django 3.2.25 on 2024-11-28 05:00
2+
3+
from django.db import migrations, models
4+
import django.db.models.deletion
5+
6+
7+
class Migration(migrations.Migration):
8+
9+
dependencies = [
10+
('assisted_tagging', '0012_auto_20231222_0554'),
11+
]
12+
13+
operations = [
14+
migrations.CreateModel(
15+
name='LLMAssistedTaggingPredication',
16+
fields=[
17+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
18+
('value', models.CharField(blank=True, max_length=255)),
19+
('model_tags', models.JSONField(blank=True, null=True)),
20+
('draft_entry', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='llmpredictions', to='assisted_tagging.draftentry')),
21+
('model_version', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='+', to='assisted_tagging.assistedtaggingmodelversion')),
22+
],
23+
),
24+
]

apps/assisted_tagging/models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,18 @@ def __str__(self):
189189
return str(self.id)
190190

191191

192+
class LLMAssistedTaggingPredication(models.Model):
193+
model_version = models.ForeignKey(AssistedTaggingModelVersion, on_delete=models.CASCADE, related_name='+')
194+
draft_entry = models.ForeignKey(DraftEntry, on_delete=models.CASCADE, related_name='llmpredictions')
195+
value = models.CharField(max_length=255, blank=True)
196+
model_tags = models.JSONField(null=True, blank=True)
197+
198+
id: int
199+
200+
def __str__(self):
201+
return str(self.id)
202+
203+
192204
class WrongPredictionReview(UserResource):
193205
prediction = models.ForeignKey(
194206
AssistedTaggingPrediction,

apps/assisted_tagging/schema.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from utils.graphene.enums import EnumDescription
88
from user_resource.schema import UserResourceMixin
99
from deep.permissions import ProjectPermissions as PP
10+
from graphene.types.generic import GenericScalar
1011

1112
from geo.schema import (
1213
ProjectGeoAreaType,
@@ -20,6 +21,7 @@
2021
AssistedTaggingModelVersion,
2122
AssistedTaggingModelPredictionTag,
2223
AssistedTaggingPrediction,
24+
LLMAssistedTaggingPredication,
2325
MissingPredictionReview,
2426
WrongPredictionReview,
2527
)
@@ -145,6 +147,22 @@ class Meta:
145147
'''
146148

147149

150+
class LLMAssistedTaggingPredictionType(DjangoObjectType):
151+
model_version = graphene.ID(source='model_version_id', required=True)
152+
draft_entry = graphene.ID(source='draft_entry_id', required=True)
153+
model_tags = GenericScalar()
154+
155+
class Meta:
156+
model = LLMAssistedTaggingPredication
157+
only_fields = (
158+
'id',
159+
'model_tags'
160+
)
161+
'''
162+
NOTE: model_version_deepl_model_id and wrong_prediction_review are not included here because they are not used in client
163+
'''
164+
165+
148166
class MissingPredictionReviewType(UserResourceMixin, DjangoObjectType):
149167
category = graphene.ID(source='category_id', required=True)
150168
tag = graphene.ID(source='tag_id', required=True)
@@ -160,9 +178,7 @@ class Meta:
160178
class DraftEntryType(DjangoObjectType):
161179
prediction_status = graphene.Field(DraftEntryPredictionStatusEnum, required=True)
162180
prediction_status_display = EnumDescription(source='get_prediction_status_display', required=True)
163-
prediction_tags = graphene.List(
164-
graphene.NonNull(AssistedTaggingPredictionType)
165-
)
181+
tags = graphene.Field(LLMAssistedTaggingPredictionType)
166182
geo_areas = graphene.List(
167183
graphene.NonNull(ProjectGeoAreaType)
168184
)
@@ -187,6 +203,10 @@ def resolve_prediction_tags(root, info, **kwargs):
187203
def resolve_geo_areas(root, info, **_):
188204
return info.context.dl.geo.draft_entry_geo_area.load(root.pk)
189205

206+
@staticmethod
207+
def resolve_tags(root, info, **_):
208+
return info.context.dl.assisted_tagging.llm_draft_entry_predications.load(root.pk)
209+
190210

191211
class DraftEntryListType(CustomDjangoListObjectType):
192212
class Meta:

apps/assisted_tagging/serializers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def validate_lead(self, lead):
2323
if lead.project != self.project:
2424
raise serializers.ValidationError('Only lead from current project are allowed.')
2525
af = lead.project.analysis_framework
26-
if af is None or not af.assisted_tagging_enabled:
26+
if af is None:
2727
raise serializers.ValidationError('Assisted tagging is disabled for the Framework used by this project.')
2828
return lead
2929

apps/assisted_tagging/tasks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from utils.common import redis_lock
88
from deep.deepl import DeeplServiceEndpoint
99
from deepl_integration.handlers import (
10-
AssistedTaggingDraftEntryHandler,
11-
AutoAssistedTaggingDraftEntryHandler,
10+
LlmAssistedTaggingDraftEntryHandler,
11+
LLMAutoAssistedTaggingDraftEntryHandler,
1212
BaseHandler as DeepHandler
1313
)
1414

@@ -95,14 +95,14 @@ def sync_models_with_deepl():
9595
@redis_lock('trigger_request_for_draft_entry_task_{0}', 60 * 60 * 0.5)
9696
def trigger_request_for_draft_entry_task(draft_entry_id):
9797
draft_entry = DraftEntry.objects.get(pk=draft_entry_id)
98-
return AssistedTaggingDraftEntryHandler.send_trigger_request_to_extractor(draft_entry)
98+
return LlmAssistedTaggingDraftEntryHandler.send_trigger_request_to_extractor(draft_entry)
9999

100100

101101
@shared_task
102102
@redis_lock('trigger_request_for_auto_draft_entry_task_{0}', 60 * 60 * 0.5)
103103
def trigger_request_for_auto_draft_entry_task(lead_id):
104104
lead = Lead.objects.get(id=lead_id)
105-
return AutoAssistedTaggingDraftEntryHandler.auto_trigger_request_to_extractor(lead)
105+
return LLMAutoAssistedTaggingDraftEntryHandler.auto_trigger_request_to_extractor(lead)
106106

107107

108108
@shared_task

apps/assisted_tagging/tests/test_query.py

Lines changed: 2 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,8 @@ class TestAssistedTaggingQuery(GraphQLTestCase):
3636
ENABLE_NOW_PATCHER = True
3737

3838
ASSISTED_TAGGING_NLP_DATA = '''
39-
query MyQuery ($taggingModelId: ID!, $predictionTag: ID!) {
39+
query MyQuery ($taggingModelId: ID! ) {
4040
assistedTagging {
41-
predictionTags {
42-
id
43-
group
44-
isCategory
45-
isDeprecated
46-
hideInAnalysisFrameworkMapping
47-
parentTag
48-
tagId
49-
}
5041
taggingModels {
5142
id
5243
modelId
@@ -65,15 +56,6 @@ class TestAssistedTaggingQuery(GraphQLTestCase):
6556
version
6657
}
6758
}
68-
predictionTag(id: $predictionTag) {
69-
id
70-
group
71-
isCategory
72-
isDeprecated
73-
hideInAnalysisFrameworkMapping
74-
parentTag
75-
tagId
76-
}
7759
}
7860
}
7961
'''
@@ -88,15 +70,6 @@ class TestAssistedTaggingQuery(GraphQLTestCase):
8870
predictionStatus
8971
predictionStatusDisplay
9072
predictionReceivedAt
91-
predictionTags {
92-
id
93-
modelVersion
94-
dataType
95-
dataTypeDisplay
96-
value
97-
category
98-
tag
99-
}
10073
geoAreas {
10174
title
10275
}
@@ -111,14 +84,12 @@ def test_unified_connector_nlp_data(self):
11184

11285
model1, *other_models = AssistedTaggingModelFactory.create_batch(2)
11386
AssistedTaggingModelVersionFactory.create_batch(2, model=model1)
114-
tag1, *other_tags = AssistedTaggingModelPredictionTagFactory.create_batch(5)
11587

11688
# -- without login
11789
content = self.query_check(
11890
self.ASSISTED_TAGGING_NLP_DATA,
11991
variables=dict(
12092
taggingModelId=model1.id,
121-
predictionTag=tag1.id,
12293
),
12394
assert_for_error=True,
12495
)
@@ -129,31 +100,8 @@ def test_unified_connector_nlp_data(self):
129100
self.ASSISTED_TAGGING_NLP_DATA,
130101
variables=dict(
131102
taggingModelId=model1.id,
132-
predictionTag=tag1.id,
133103
)
134104
)['data']['assistedTagging']
135-
self.assertEqual(content['predictionTags'], [
136-
dict(
137-
id=str(tag.id),
138-
tagId=tag.tag_id,
139-
isDeprecated=tag.is_deprecated,
140-
isCategory=tag.is_category,
141-
group=tag.group,
142-
hideInAnalysisFrameworkMapping=tag.hide_in_analysis_framework_mapping,
143-
parentTag=tag.parent_tag_id and str(tag.parent_tag_id),
144-
)
145-
for tag in [tag1, *other_tags]
146-
])
147-
self.assertEqual(content['predictionTag'], dict(
148-
id=str(tag1.id),
149-
tagId=tag1.tag_id,
150-
isDeprecated=tag1.is_deprecated,
151-
isCategory=tag1.is_category,
152-
group=tag1.group,
153-
hideInAnalysisFrameworkMapping=tag1.hide_in_analysis_framework_mapping,
154-
parentTag=tag1.parent_tag_id and str(tag1.parent_tag_id),
155-
))
156-
157105
self.assertEqual(content['taggingModels'], [
158106
dict(
159107
id=str(_model.id),
@@ -196,38 +144,8 @@ def test_unified_connector_draft_entry(self):
196144
GeoAreaFactory.create(admin_level=admin_level, title='Nepal')
197145
GeoAreaFactory.create(admin_level=admin_level, title='Bagmati')
198146
GeoAreaFactory.create(admin_level=admin_level, title='Kathmandu')
199-
model1 = AssistedTaggingModelFactory.create()
200-
geo_model = AssistedTaggingModelFactory.create(model_id=AssistedTaggingModel.ModelID.GEO)
201-
latest_model1_version = AssistedTaggingModelVersionFactory.create_batch(2, model=model1)[0]
202-
latest_geo_model_version = AssistedTaggingModelVersionFactory.create(model=geo_model)
203-
category1, tag1, *other_tags = AssistedTaggingModelPredictionTagFactory.create_batch(5)
204-
205147
draft_entry1 = DraftEntryFactory.create(project=project, lead=lead, excerpt='sample excerpt')
206148

207-
prediction1 = AssistedTaggingPredictionFactory.create(
208-
data_type=AssistedTaggingPrediction.DataType.TAG,
209-
model_version=latest_model1_version,
210-
draft_entry=draft_entry1,
211-
category=category1,
212-
tag=tag1,
213-
prediction=0.1,
214-
threshold=0.05,
215-
is_selected=True,
216-
)
217-
prediction2 = AssistedTaggingPredictionFactory.create(
218-
data_type=AssistedTaggingPrediction.DataType.RAW,
219-
model_version=latest_geo_model_version,
220-
draft_entry=draft_entry1,
221-
value='Nepal',
222-
is_selected=True,
223-
)
224-
prediction3 = AssistedTaggingPredictionFactory.create(
225-
data_type=AssistedTaggingPrediction.DataType.RAW,
226-
model_version=latest_geo_model_version,
227-
draft_entry=draft_entry1,
228-
value='Kathmandu',
229-
is_selected=True,
230-
)
231149
draft_entry1.save_geo_data()
232150

233151
def _query_check(**kwargs):
@@ -257,44 +175,7 @@ def _query_check(**kwargs):
257175
predictionReceivedAt=None,
258176
predictionStatus=self.genum(draft_entry1.prediction_status),
259177
predictionStatusDisplay=draft_entry1.get_prediction_status_display(),
260-
predictionTags=[
261-
dict(
262-
id=str(prediction1.pk),
263-
modelVersion=str(prediction1.model_version_id),
264-
dataType=self.genum(prediction1.data_type),
265-
dataTypeDisplay=prediction1.get_data_type_display(),
266-
value='',
267-
category=str(prediction1.category_id),
268-
tag=str(prediction1.tag_id),
269-
),
270-
dict(
271-
id=str(prediction2.id),
272-
modelVersion=str(prediction2.model_version.id),
273-
dataType=self.genum(prediction2.data_type),
274-
dataTypeDisplay=prediction2.get_data_type_display(),
275-
value=prediction2.value,
276-
category=None,
277-
tag=None,
278-
),
279-
dict(
280-
id=str(prediction3.id),
281-
modelVersion=str(prediction3.model_version.id),
282-
dataType=self.genum(prediction3.data_type),
283-
dataTypeDisplay=prediction3.get_data_type_display(),
284-
value=prediction3.value,
285-
category=None,
286-
tag=None,
287-
)
288-
],
289-
geoAreas=[
290-
dict(
291-
title='Nepal',
292-
),
293-
dict(
294-
title='Kathmandu',
295-
)
296-
297-
],
178+
geoAreas=[]
298179
))
299180

300181

0 commit comments

Comments
 (0)