utils.py

import torch
import datasets
import numpy as np
from typing import List, Dict, Any, Union

IGNORE_INDEX = -100  # TODO put in common constants


def load_hf_dataset(path, **kwargs):
    dataset = datasets.load_dataset(path, **kwargs)
    return dataset
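

# Usage sketch for `load_hf_dataset` (the dataset path and split below are
# illustrative, not prescribed by this module; any Hugging Face hub path or
# local dataset script is passed straight through to `datasets.load_dataset`):
#
#   ds = load_hf_dataset("imdb", split="train")
#   print(ds[0]["text"][:80])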


def preprocess_chat_instance(
    tokenizer,
    template_config: Dict[str, Any],
    prompt_msgs: Union[List[str], str],
    response_msgs: Union[List[str], str],
    max_length: int,
    predict_with_generate: bool = False,
) -> Dict[str, torch.Tensor]:
    """Preprocesses a chat instance for training or generation.

    During training, both the returned `input_ids` and `labels` cover the
    entire conversation. `input_ids` has no padding, and `labels` assigns
    `IGNORE_INDEX` to tokens where loss is not computed (i.e. all tokens
    except the final response message).

    During generation, `input_ids` is returned only up to the last user
    prompt, excluding the assistant's response. The returned `labels` are the
    same as during training.

    `attention_mask` is always 1 over the full `input_ids` token sequence.

    `prompt_msgs` and `response_msgs` are lists where all corresponding pairs
    except the last are in-context examples. When they are strings rather
    than lists, there are no in-context examples.

    Args:
        tokenizer: Tokenizer to apply on text.
        template_config (Dict[str, Any]): Configuration for the chat template
            (comes from model-specific config).
        prompt_msgs (Union[List[str], str]): List of prompt messages or a
            single prompt message string.
        response_msgs (Union[List[str], str]): List of response messages or a
            single response message string.
        max_length (int): Maximum sequence length after tokenization.
        predict_with_generate (bool, optional): Whether to prepare inputs for
            generation.

    Returns:
        Dict[str, torch.Tensor]: A dictionary containing 'input_ids',
        'labels', and 'attention_mask' tensors for model input.
    """
    # Convert single strings to one-element lists before comparing lengths;
    # comparing raw string lengths here would spuriously fail.
    if isinstance(prompt_msgs, str):
        assert isinstance(response_msgs, str)
        prompt_msgs, response_msgs = [prompt_msgs], [response_msgs]
    assert len(prompt_msgs) == len(response_msgs)

    if template_config["apply_chat_template"]:
        chat = []
        system_prompt = template_config.get("system_prompt", None)
        if system_prompt:
            chat += [{"role": "system", "content": system_prompt}]
        for prompt, response in zip(prompt_msgs, response_msgs):
            chat += [{"role": "user", "content": prompt}]
            chat += [{"role": "assistant", "content": response}]
        chat_ids = tokenizer.apply_chat_template(
            chat, tokenize=True, add_generation_prompt=False
        )
        # all except the last response are in-context examples
        wrapped_prompt = tokenizer.apply_chat_template(
            chat[:-1], tokenize=False, add_generation_prompt=True
        )
        prompt_ids = tokenizer.apply_chat_template(
            chat[:-1], tokenize=True, add_generation_prompt=True
        )
    else:
        wrapped_prompt = ""
        system_prompt_with_special_tokens = template_config.get(
            "system_prompt_with_special_tokens", None
        )
        if system_prompt_with_special_tokens:
            wrapped_prompt += system_prompt_with_special_tokens
        # add in-context examples
        n_few_shot = len(prompt_msgs) - 1
        for i in range(n_few_shot):
            fs_prompt, fs_response = prompt_msgs[i], response_msgs[i]
            wrapped_prompt += (
                template_config["user_start_tag"]
                + fs_prompt
                + template_config["user_end_tag"]
                + template_config["asst_start_tag"]
                + fs_response
                + template_config["asst_end_tag"]
            )
        # add the actual example
        final_prompt, final_response = prompt_msgs[-1], response_msgs[-1]
        wrapped_prompt += (
            template_config["user_start_tag"]
            + final_prompt
            + template_config["user_end_tag"]
            + template_config["asst_start_tag"]
        )
        chat_ids = tokenizer(
            wrapped_prompt + final_response,
            add_special_tokens=True,
            max_length=max_length,
            truncation=True,
        )["input_ids"]
        prompt_ids = tokenizer(
            wrapped_prompt,
            add_special_tokens=True,
            max_length=max_length,
            truncation=True,
        )["input_ids"]
        # truncation may have cut off the end-of-sequence token; restore it
        if chat_ids[-1] != tokenizer.eos_token_id:
            chat_ids += [tokenizer.eos_token_id]

    len_matched = len(prompt_ids)
    item = {}
    if predict_with_generate:
        item["input_ids"] = prompt_ids
        labels = chat_ids  # contains the entire conversation
    else:
        item["input_ids"] = chat_ids
        labels = [IGNORE_INDEX] * len_matched + chat_ids[len_matched:]
    item["labels"] = labels
    item["attention_mask"] = [1] * len(item["input_ids"])
    for attr in item:
        item[attr] = torch.tensor(item[attr])
    return item
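

# Usage sketch for `preprocess_chat_instance`. The model name and template
# values below are illustrative assumptions, not prescribed by this module:
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
#   template_config = {
#       "apply_chat_template": True,
#       "system_prompt": "You are a helpful assistant.",
#   }
#   item = preprocess_chat_instance(
#       tokenizer,
#       template_config,
#       prompt_msgs="What is the capital of France?",
#       response_msgs="Paris.",
#       max_length=512,
#   )
#   # Loss is computed only where item["labels"] != IGNORE_INDEX, i.e. on the
#   # final assistant response.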


def preprocess_pretraining_instance(
    tokenizer,
    prefix: str,
    text_content: str,
    max_length: int,
    predict_with_generate: bool = False,
    insert_space: bool = False,
) -> Dict[str, torch.Tensor]:
    """Preprocesses a pretraining instance for training or generation.

    During training, both the returned `input_ids` and `labels` cover the
    entire token sequence. `input_ids` has no padding, and `labels` assigns
    `IGNORE_INDEX` to all tokens that loss is not computed over (i.e. the
    token at index 0 and all prefix tokens).

    During generation, `input_ids` is returned only up to the end of the
    prefix. The returned `labels` are the same as during training.

    `attention_mask` is always 1 over the full input token sequence.

    Args:
        tokenizer: Tokenizer to apply on text.
        prefix (str): The prefix string to prepend to the content.
        text_content (str): The main text content (following the prefix) to
            be tokenized.
        max_length (int): Maximum text content length after tokenization.
        predict_with_generate (bool, optional): Whether to prepare inputs for
            generation.
        insert_space (bool, optional): Whether to insert a space between
            prefix and content.

    Returns:
        Dict[str, torch.Tensor]: A dictionary containing 'input_ids',
        'labels', and 'attention_mask' tensors for model input.
    """
    full_seq_ids = tokenizer(
        prefix + (" " if insert_space else "") + text_content,
        add_special_tokens=True,
    )["input_ids"]
    prefix_ids = tokenizer(prefix, add_special_tokens=True)["input_ids"]
    prefix_len = len(prefix_ids)
    full_seq_ids = full_seq_ids[: prefix_len + max_length]  # manual truncation
    len_matched = prefix_len
    if len_matched == 0:  # never compute loss on index 0, even with an empty prefix
        len_matched = 1
    labels = [IGNORE_INDEX] * len_matched + full_seq_ids[len_matched:]
    item = {}
    if predict_with_generate:
        item["input_ids"] = prefix_ids
    else:
        item["input_ids"] = full_seq_ids
    item["labels"] = labels
    item["attention_mask"] = [1] * len(item["input_ids"])
    for attr in item:
        item[attr] = torch.tensor(item[attr])
    return item
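

# Usage sketch for `preprocess_pretraining_instance` (the tokenizer choice
# and strings below are illustrative assumptions):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   item = preprocess_pretraining_instance(
#       tokenizer,
#       prefix="Harry Potter was",
#       text_content="a wizard who lived at number four, Privet Drive.",
#       max_length=128,
#       insert_space=True,
#   )
#   # labels mask out the prefix tokens, so loss falls only on the
#   # continuation that follows the prefix.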


def add_dataset_index(dataset):
    """Adds an 'index' column recording each example's original position."""
    indexing = np.arange(len(dataset))
    dataset = dataset.add_column("index", indexing)
    return dataset
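

# Usage sketch for `add_dataset_index` (dataset choice is illustrative):
#
#   ds = load_hf_dataset("imdb", split="train")
#   ds = add_dataset_index(ds)
#   assert ds[0]["index"] == 0  # each row now carries its original position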