-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathdataloaders.py
More file actions
38 lines (28 loc) · 1.43 KB
/
dataloaders.py
File metadata and controls
38 lines (28 loc) · 1.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from collections import defaultdict
from promise import Promise
from django.utils.functional import cached_property
from assisted_tagging.models import AssistedTaggingPrediction, LLMAssistedTaggingPredication
from utils.graphene.dataloaders import DataLoaderWithContext, WithContextMixin
class DraftEntryPredicationsLoader(DataLoaderWithContext):
    """Batch-load the selected assisted-tagging predictions of draft entries.

    Each key is a draft entry id. It resolves to the list of
    ``AssistedTaggingPrediction`` rows for that draft entry whose
    ``is_selected`` flag is true. Keys with no such rows resolve to an
    empty list.
    """

    def batch_load_fn(self, keys):
        # One query for the whole batch, not one query per draft entry.
        selected_predictions = AssistedTaggingPrediction.objects.filter(
            draft_entry_id__in=keys,
            is_selected=True,
        )

        predictions_by_draft_entry = defaultdict(list)
        for prediction in selected_predictions:
            predictions_by_draft_entry[prediction.draft_entry_id].append(prediction)

        # Exactly one result per key, in the same order as ``keys``.
        return Promise.resolve([
            predictions_by_draft_entry.get(draft_entry_id, [])
            for draft_entry_id in keys
        ])
class LLMDraftEntryPredicationsLoader(DataLoaderWithContext):
    """Batch-load the LLM assisted-tagging prediction of draft entries.

    Each key is a draft entry id. It resolves to a single
    ``LLMAssistedTaggingPredication`` instance, or ``None`` when no row
    exists for that draft entry.
    """

    def batch_load_fn(self, keys):
        prediction_by_draft_entry = {}
        for llm_prediction in LLMAssistedTaggingPredication.objects.filter(draft_entry_id__in=keys):
            # NOTE(review): if several rows share a draft_entry_id, the one
            # yielded last overwrites the earlier ones. Presumably there is at
            # most one per draft entry; confirm against the model definition.
            prediction_by_draft_entry[llm_prediction.draft_entry_id] = llm_prediction

        # Exactly one result per key, in the same order as ``keys``.
        return Promise.resolve(
            [prediction_by_draft_entry.get(draft_entry_id) for draft_entry_id in keys]
        )
class DataLoaders(WithContextMixin):
    """Container for the assisted-tagging dataloaders.

    Each loader is built on first attribute access with ``self.context``.
    ``cached_property`` then stores it on the instance, so later accesses
    return the same loader object.
    """

    @cached_property
    def draft_entry_predications(self):
        """Loader of the selected ``AssistedTaggingPrediction`` rows per draft entry."""
        loader_class = DraftEntryPredicationsLoader
        return loader_class(context=self.context)

    @cached_property
    def llm_draft_entry_predications(self):
        """Loader of the ``LLMAssistedTaggingPredication`` per draft entry."""
        loader_class = LLMDraftEntryPredicationsLoader
        return loader_class(context=self.context)