-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathcollators.py
67 lines (61 loc) · 2.56 KB
/
collators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import transformers
from typing import Dict, Optional, Sequence
from data.utils import IGNORE_INDEX
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Pads per-example ``input_ids`` (and, when present, ``labels``) to a common
    batch length, derives an attention mask from the pad token, and can also
    gather one scalar field (``index``) from every example into a tensor.
    Batches of *nested* example dicts (no top-level ``input_ids``) are collated
    recursively, key by key.
    """

    def __init__(
        self,
        tokenizer: "transformers.PreTrainedTokenizer",
        padding_side: str = "right",
        index: Optional[str] = None,
    ):
        """
        Args:
            tokenizer: tokenizer whose ``pad_token_id`` is used for padding.
            padding_side: ``"right"`` or ``"left"`` — which side to pad on.
            index: optional key naming a scalar field in each example (e.g. a
                dataset row id) to collect into a batch tensor.

        Raises:
            ValueError: if ``padding_side`` is not ``"right"`` or ``"left"``.
        """
        # Fail fast: any other value would previously fall through to the
        # left-padding branch silently.
        if padding_side not in ("right", "left"):
            raise ValueError(
                f'padding_side must be "right" or "left", got {padding_side!r}'
            )
        self.tokenizer = tokenizer
        self.padding_side = padding_side
        self.index = index

    def get_instances_from_key(self, instances: Sequence[Dict], key: str):
        """Return ``instance[key]`` for every instance, preserving order."""
        return [instance[key] for instance in instances]

    def _pad_tokens(self, input_ids, padding_value):
        """Pad a list of 1-D tensors to a common length on ``self.padding_side``."""
        if self.padding_side == "right":
            return torch.nn.utils.rnn.pad_sequence(
                input_ids, batch_first=True, padding_value=padding_value
            )
        # pad_sequence only pads on the right, so for left padding we reverse
        # each sequence, pad, then reverse the padded batch back.
        return torch.nn.utils.rnn.pad_sequence(
            [torch.flip(i, dims=[0]) for i in input_ids],
            batch_first=True,
            padding_value=padding_value,
        ).flip(dims=[1])

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Collate a batch of example dicts into a dict of padded tensors."""
        assert isinstance(instances[0], dict)
        return_dct = {}
        if "input_ids" not in instances[0]:
            # Nested structure: each key holds a sub-example dict — collate
            # each key's list of sub-examples recursively.
            for key in instances[0].keys():
                key_instances = self.get_instances_from_key(
                    instances=instances, key=key
                )
                return_dct[key] = self(key_instances)
        else:
            input_ids = [instance["input_ids"] for instance in instances]
            input_ids = self._pad_tokens(input_ids, self.tokenizer.pad_token_id)
            # NOTE(review): this masks *every* occurrence of pad_token_id,
            # including genuine tokens when pad == eos — confirm acceptable
            # for the tokenizers used with this collator.
            attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
            return_dct.update({"input_ids": input_ids})
            return_dct.update({"attention_mask": attention_mask})
            if "labels" in instances[0]:
                labels = [instance["labels"] for instance in instances]
                # Labels are padded with IGNORE_INDEX so the loss skips padding.
                labels = self._pad_tokens(labels, IGNORE_INDEX)
                return_dct.update({"labels": labels})
            if self.index:
                if self.index in instances[0]:
                    return_dct.update(
                        {
                            self.index: torch.tensor(
                                [example[self.index] for example in instances]
                            )
                        }
                    )
                else:
                    # Preserved as-is: raising the Warning class (callers may
                    # catch Warning); a KeyError would arguably fit better.
                    raise Warning(f"{self.index} not found in dataset")
        return return_dct