pretraining.py
# import torch
from torch.utils.data import Dataset
from data.utils import (
    load_hf_dataset,
    add_dataset_index,
    preprocess_pretraining_instance,
)


class CompletionDataset(Dataset):
    """Dataset of prefix/completion pairs tokenized for completion-style training."""

    def __init__(
        self,
        hf_args,
        template_args,
        tokenizer,
        prefix_key="prompt",
        text_key="text",
        max_length=2048,
        predict_with_generate=False,
        insert_space=False,
    ):
        super(CompletionDataset, self).__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = load_hf_dataset(**hf_args)
        self.data = add_dataset_index(self.data)
        # If either key is missing from a row, its value is treated as "".
        self.prefix_key = prefix_key
        self.text_key = text_key
        self.predict_with_generate = predict_with_generate
        self.insert_space = insert_space

    def __len__(self):
        return len(self.data)

    def _process_sample(self, prefix, text_content, index=-1):
        # Tokenize the prefix and completion into input_ids, labels, and attention_mask.
        tokenized_data = preprocess_pretraining_instance(
            self.tokenizer,
            prefix,
            text_content,
            self.max_length,
            self.predict_with_generate,
            self.insert_space,
        )
        item_dct = {
            "input_ids": tokenized_data["input_ids"],
            "labels": tokenized_data["labels"],
            "attention_mask": tokenized_data["attention_mask"],
        }
        if index != -1:
            item_dct["index"] = index
        return item_dct

    def __getitem__(self, idx):
        pref = self.data[idx].get(self.prefix_key, "")
        text_content = self.data[idx].get(self.text_key, "")
        index = self.data[idx]["index"]
        item = self._process_sample(pref, text_content, index)
        return item
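
# A minimal usage sketch, assuming an HF dataset whose rows carry "prompt" and
# "text" fields and a HuggingFace tokenizer; the hf_args keys below are
# placeholders for whatever load_hf_dataset expects, not verified defaults.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   dataset = CompletionDataset(
#       hf_args={"path": "some/dataset", "split": "train"},
#       template_args=None,
#       tokenizer=tokenizer,
#   )
#   sample = dataset[0]  # dict with "input_ids", "labels", "attention_mask", "index"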


class PretrainingDataset(Dataset):
    """Packs a raw text corpus into fixed-length token chunks for pretraining."""

    def __init__(
        self, hf_args, template_args, tokenizer, text_key="text", max_length=2048
    ):
        super(PretrainingDataset, self).__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.chunks = self._chunk_raw_text(load_hf_dataset(**hf_args)[text_key])

    def _chunk_raw_text(self, raw_text):
        # Join all documents, tokenize the full corpus once, then split the token
        # stream into max_length-sized chunks and decode each chunk back to text.
        raw_text = "\n\n".join(raw_text)
        full_token_sequence = self.tokenizer(raw_text, add_special_tokens=False)[
            "input_ids"
        ]
        # Ceiling division keeps a trailing partial chunk without adding an empty one.
        num_chunks = (len(full_token_sequence) + self.max_length - 1) // self.max_length
        chunks = []
        for i in range(num_chunks):
            chunks.append(
                self.tokenizer.decode(
                    full_token_sequence[i * self.max_length : (i + 1) * self.max_length]
                )
            )
        return chunks

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        return preprocess_pretraining_instance(
            self.tokenizer, "", self.chunks[idx], self.max_length
        )
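
# A minimal usage sketch, assuming a corpus dataset with a "text" column and a
# HuggingFace tokenizer; the hf_args keys are again placeholders for load_hf_dataset.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   corpus = PretrainingDataset(
#       hf_args={"path": "some/corpus", "split": "train"},
#       template_args=None,
#       tokenizer=tokenizer,
#       max_length=2048,
#   )
#   chunk = corpus[0]  # tokenized chunk of at most max_length tokens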