Skip to content

Commit ae845e8

Browse files
authored
Merge branch 'main' into dsikka/examples
2 parents fb72a6e + a3538f6 commit ae845e8

29 files changed

+4777
-8
lines changed
File renamed without changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
speculators convert --eagle3 yuhuili/EAGLE3-LLaMA3.1-Instruct-8B eagle3-llama-3.1-8b-instruct-converted meta-llama/Meta-Llama-3.1-8B-Instruct --validate

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ select = [
228228
"E722", # allow bare exception catching
229229
"C90", # ignore code complexity errors
230230
"PGH004", # allow general style ignores
231+
"ERA001", # allow commented-out code
232+
"N80", # allow upper case names
231233
]
232234

233235
[tool.ruff.lint.isort]

research/hass/README.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Train Time Test HASS
2+
#### This code uses the HASS method (https://github.com/HArmonizedSS/HASS) to train models that are a variation on the Eagle 1 architecture. The original Eagle code can be found here: https://github.com/SafeAILab/EAGLE
3+
4+
## To Run:
5+
6+
The training process is split into two steps: first you generate data from the large model, and then you train the drafter model. It works for Llama 3.1.8B-Instruct, Llama 3.3.70B-Instruct, and Mistral Small 24B 3503.
7+
8+
9+
### Data Generation step:
10+
11+
1. Modify the directory names and arguments in `gen_data.sh`
12+
2. You can get the data for ShareGPT at: Aeala/ShareGPT_Vicuna_unfiltered. Ultrachat will be automatically downloaded.
13+
3. Make sure that the system prompts and chat template demarcation in the desired file (ultrachat.py or sharegpt.py, ultrachatMistral.py or sharegptMistral.py) are correct
14+
4. Run the script: `./gen_data.sh`
15+
5. Run for each of: sharegpt, and ultrachat sft and gen splits.
16+
Notes: For llama 3.1.8B this will generate ~4TB of data on your system. The script for training searches your data directory recursively, so the internal structure of your data directory does not matter.
17+
18+
### Run training
19+
1. Modify the directory names and arguments in `train.sh`
20+
2. Run `./train.sh`
21+
22+
### Serve the model with vllm:
23+
1. Convert your saved model with: `convert.sh`
24+
2. Serve the model with: ` VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --seed 42 -tp 1 --speculative-config '{"model": "llama_eagle", "num_speculative_tokens": 3, "method":"eagle3", "draft_tensor_parallel_size":1}'`
25+
26+
27+
28+
### TODO:
29+
1. Throw an error if you attempt to create a model that will not be supported in vllm - with the wrong configuration of heads etc.

research/hass/__init__.py

Whitespace-only changes.

research/hass/data/mt_bench/question.jsonl

Lines changed: 80 additions & 0 deletions
Large diffs are not rendered by default.

research/hass/ge_data/__init__.py

Whitespace-only changes.
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# This file is adapted from https://github.com/HArmonizedSS/HASS (arxiv: https://arxiv.org/abs/2408.15766)
2+
# Which is a fork of the Eagle repository: https://github.com/SafeAILab/EAGLE (arxiv: https://arxiv.org/abs/2401.15077)
3+
4+
5+
import argparse
6+
import os
7+
from concurrent.futures import ThreadPoolExecutor
8+
9+
# Command-line interface of the multi-GPU data-generation launcher.
parser = argparse.ArgumentParser(description="sp")
parser.add_argument("--outdir", type=str, default="0")
parser.add_argument("--data_path", type=str, default="0")
parser.add_argument("--model_path", type=str, default="0")
parser.add_argument("--dataset", type=str, default="ultrachat")
parser.add_argument("--total_gpus", type=int, default=8)
parser.add_argument("--gpus_per_model", type=int, default=1)
parser.add_argument("--samples", type=int, default=68000)
parser.add_argument("--split", type=str, default="sft")
parser.add_argument("--chat_template", type=str, default="llama")

args = parser.parse_args()

# Reject GPU layouts that cannot be tiled evenly: otherwise the last group
# would contain device ids >= total_gpus, and a non-positive group size
# fails later with an unrelated error.
if args.gpus_per_model < 1:
    parser.error("--gpus_per_model must be at least 1")
if args.total_gpus % args.gpus_per_model != 0:
    parser.error("--total_gpus must be a multiple of --gpus_per_model")

# Bounds handed to split_range when the work is divided between workers.
s = 0
e = args.samples

# One list of consecutive GPU ids per model replica,
# e.g. total_gpus=4, gpus_per_model=2 -> [[0, 1], [2, 3]].
gpus = [
    list(range(first, first + args.gpus_per_model))
    for first in range(0, args.total_gpus, args.gpus_per_model)
]

# Number of worker processes (one per GPU group).
num_p = len(gpus)
outdir = args.outdir
34+
35+
36+
def split_range(start, end, n, over=False):
    """Cut the inclusive range ``[start, end]`` into ``n`` contiguous chunks.

    The ``end - start + 1`` elements are shared out as evenly as possible;
    when they do not divide evenly, the earliest chunks each take one extra
    element.

    Args:
        start: First index of the range.
        end: Last index of the range (counted as part of the range).
        n: Number of chunks to produce.
        over: If true, each chunk's upper bound is one past its last element;
            otherwise the upper bound is the last element itself.

    Returns:
        A list of ``n`` ``(lower, upper)`` tuples in ascending order.
    """
    base, extra = divmod(end - start + 1, n)
    # Inclusive upper bounds sit one below the exclusive ones.
    inclusive_shift = 0 if over else 1

    chunks = []
    lower = start
    for position in range(n):
        size = base + 1 if position < extra else base
        chunks.append((lower, lower + size - inclusive_shift))
        lower += size
    return chunks
54+
55+
56+
def run_command(cmd):
    """Run ``cmd`` through the system shell and report failures.

    Args:
        cmd: Complete shell command line to execute.

    Returns:
        The raw status value returned by ``os.system`` (0 on success).
    """
    status = os.system(cmd)  # noqa: S605
    if status != 0:
        # Surface failures; otherwise a crashed worker goes unnoticed because
        # nothing inspects the result of the submitted task.
        print(f"Command exited with status {status}: {cmd}")
    return status
58+
59+
60+
# Create the shared output directory; exist_ok avoids the check-then-create race.
os.makedirs(outdir, exist_ok=True)

# One (start, end) pair per worker; with over=True each end is one past the
# chunk's last index.
# NOTE(review): split_range counts `e` as part of the range, so the chunks
# cover args.samples + 1 indices in total - confirm this is intended.
data_a = split_range(s, e, num_p, over=True)

# Suffix appended to the dataset name to pick the worker script for each
# supported chat template.
script_suffix = {"llama": "", "mistral": "Mistral"}
if args.chat_template not in script_suffix:
    raise NotImplementedError(
        "Only llama and mistral chat templates are supported."
    )

commands = []
for index, ((start, end), gpu_index) in enumerate(zip(data_a, gpus)):
    # GPU ids are passed as separate space-delimited values of --gpu_index.
    gpu_index_str = " ".join(map(str, gpu_index))
    script = f"ge_data/{args.dataset}{script_suffix[args.chat_template]}.py"
    command = (
        f"python {script} --start={start} --end={end} "
        f"--index={index} --gpu_index {gpu_index_str} --outdir {outdir} "
        f"--data_path {args.data_path} --model_path {args.model_path} "
        f"--split {args.split}"
    )
    commands.append(command)

# Run every worker command concurrently, one thread per command; leaving the
# `with` block waits for all of them to finish.
with ThreadPoolExecutor(max_workers=len(commands)) as executor:
    for command in commands:
        executor.submit(run_command, command)
        print(command)
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# This file is adapted from https://github.com/HArmonizedSS/HASS (arxiv: https://arxiv.org/abs/2408.15766)
2+
# Which is a fork of the Eagle repository: https://github.com/SafeAILab/EAGLE (arxiv: https://arxiv.org/abs/2401.15077)
3+
4+
import argparse
5+
import os
6+
7+
import torch
8+
from datasets import load_dataset
9+
from transformers import AutoModelForCausalLM, AutoTokenizer
10+
11+
# Command-line interface of a single data-generation worker.
parser = argparse.ArgumentParser(description="sp")
parser.add_argument("--start", type=int, default=0)
parser.add_argument("--end", type=int, default=100)
parser.add_argument("--index", type=int, default=1)
parser.add_argument("--gpu_index", type=int, nargs="+", default=[0])
parser.add_argument("--outdir", type=str, default="outdir0")
parser.add_argument("--data_path", type=str, default="0")
parser.add_argument("--model_path", type=str, default="0")
# Accepted but not used by this script: without it, argparse aborts with
# "unrecognized arguments" whenever the caller passes --split.
# NOTE(review): assumes this worker is started by the multi-GPU launcher,
# which appends --split to every worker command - confirm.
parser.add_argument("--split", type=str, default="sft")
args = parser.parse_args()

# Restrict this process to its assigned GPUs, written as a comma-separated
# list without spaces (e.g. "0,1").
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, args.gpu_index))

# Path or hub id of the large model whose hidden states are extracted.
bigname = args.model_path
24+
25+
26+
def longest_common_prefix(list1, list2):
    """Find the leading run of elements on which two sequences agree.

    Args:
        list1: First sequence; the returned prefix is sliced from it.
        list2: Second sequence.

    Returns:
        A ``(prefix, length)`` tuple: the shared leading slice of ``list1``
        and the number of elements it contains.
    """
    shared = 0
    # zip stops at the shorter sequence, so no explicit length check is needed.
    for left, right in zip(list1, list2):
        if left != right:
            break
        shared += 1
    return list1[:shared], shared
38+
39+
40+
def build_dataset_rank(
    tokenizer,
    split="train",  # noqa: ARG001
    select=None,  # noqa: ARG001
):
    """Load the conversation file, take this worker's slice, and tokenize it.

    Reads the JSON file at ``args.data_path``, shuffles its ``train`` split
    with a fixed seed, keeps rows ``args.start`` .. ``args.end - 1``, and maps
    every conversation to a chat-formatted string, its token ids, and a loss
    mask.

    Args:
        tokenizer: Tokenizer used both to render the chat template and to
            tokenize. Its ``pad_token_id`` is overwritten with
            ``unk_token_id`` when it is falsy.
        split: Unused.
        select: Unused.

    Returns:
        The mapped dataset with the columns ``conversation``, ``input_ids``
        and ``loss_mask`` (the original columns are removed), formatted as
        torch tensors.
    """
    ds = load_dataset("json", data_files=args.data_path)
    ds = ds["train"]
    # Fixed seed: every worker sees the same order, so the [start, end)
    # slices taken by different workers refer to the same shuffled sequence.
    ds = ds.shuffle(seed=42)
    ds1 = ds.select(range(args.start, args.end))

    original_columns1 = ds1.column_names
    # original_columns2 = ds2.column_names
    num_proc = 4  # noqa: F841

    def preprocess_function(examples):
        """Turn a batch of raw conversations into text, token ids and masks."""
        new_examples = {"conversation": [], "input_ids": [], "loss_mask": []}
        for i in range(len(examples["id"])):
            # Every conversation starts with the same fixed system message.
            messages = [
                {
                    "role": "system",
                    "content": (
                        "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024"
                    ),
                },
            ]
            # Expected alternation of roles after the optional leading skip.
            convroles = ["user", "assistant"]
            # Maps the dataset's speaker tags to chat-template role names.
            roles = {"human": "user", "gpt": "assistant"}
            source = examples["conversations"][i]
            if roles[source[0]["from"]] != "user":
                # Skip the first one if it is not from human
                source = source[1:]
            for j, sentence in enumerate(source):
                role = roles[sentence["from"]]
                assert role == convroles[j % 2], f"{i}"  # noqa: S101
                if sentence["from"] == "gpt":
                    # Assistant replies get a leading space (mutates the
                    # sentence dict in place).
                    sentence["value"] = " " + sentence["value"]
                messages.append({"role": role, "content": sentence["value"]})
            conversation = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )

            if not tokenizer.pad_token_id:
                tokenizer.pad_token_id = tokenizer.unk_token_id

            # NOTE(review): max_length is given without truncation=True, so
            # presumably no truncation actually happens - confirm.
            input_ids = tokenizer(
                conversation,
                return_tensors="pt",
                max_length=4096,
                add_special_tokens=False,
            ).input_ids[0]
            # Start with every position enabled; positions are zeroed below.
            loss_mask = torch.ones_like(input_ids)
            # print(i)

            # Text that separates a user message from the assistant reply.
            sep = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

            total_len = len(input_ids)  # noqa: F841

            # Text that starts each user message; used to split into turns.
            sep2 = "<|eot_id|><|start_header_id|>user<|end_header_id|>"
            turns = conversation.split(sep2)

            # Re-attach the text before the first user header (the system
            # part) to the first turn, so each element of `turns` is one
            # user message followed by the assistant reply.
            turns[1] = turns[0] + sep2 + turns[1]
            turns = turns[1:]

            # The first token is always masked out.
            cur_len = 1
            loss_mask[:cur_len] = 0
            # NOTE: this loop variable reuses the name `i` of the outer loop.
            for i, turn in enumerate(turns):  # noqa: PLW2901
                if turn == "":
                    break
                turn_len = len(tokenizer(turn).input_ids)

                parts = turn.split(sep)
                # Stop at a turn that does not contain exactly one
                # user/assistant boundary.
                if len(parts) != 2:
                    break
                parts[0] += sep
                # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
                # NOTE(review): the code below subtracts 1 here (and a further
                # 2 in the i == 0 branch), not 2 - confirm which offset the
                # comment above refers to.
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1

                # Ignore the user instructions
                if i == 0:
                    loss_mask[cur_len : cur_len + instruction_len - 2] = 0
                else:
                    loss_mask[cur_len - 3 : cur_len + instruction_len + 1] = 0
                cur_len += turn_len
                if i != 0:
                    cur_len += 3
                # cur_len+=2

                # if i != 0 and not tokenizer.legacy:
                #     # The legacy and non-legacy modes handle special toks differently
                #     cur_len -= 1

            # Mask everything after the last processed turn.
            loss_mask[cur_len:] = 0

            # Token ids and mask are stored with an added leading dimension.
            new_examples["conversation"].append(conversation)
            new_examples["input_ids"].append(input_ids[None, :])
            new_examples["loss_mask"].append(loss_mask[None, :])

        return new_examples

    ds1 = ds1.map(
        preprocess_function,
        batched=True,
        # num_proc=num_proc,
        remove_columns=original_columns1,
        load_from_cache_file=False,
    )

    ds1.set_format(type="torch")
    return ds1
151+
152+
153+
# Tokenizer of the large model; the slow (non-fast) implementation is requested.
bigtokenizer = AutoTokenizer.from_pretrained(bigname, use_fast=False)
# Tokenize this worker's slice of the data before the model is loaded.
ds = build_dataset_rank(bigtokenizer)
print(ds)
# Load the large model in float16, letting device_map="auto" place it on the
# visible devices.
bigmodel = AutoModelForCausalLM.from_pretrained(
    bigname, device_map="auto", torch_dtype=torch.float16
)
# Inference only: switch to evaluation mode.
bigmodel.eval()
160+
161+
162+
@torch.no_grad()
def ge(data):
    """Run the large model on one sample and collect its hidden states.

    Three intermediate hidden states (index 3, index ``num_layers // 2 + 1``
    and index -3 of the model's ``hidden_states`` output) are concatenated
    along the last dimension; the final hidden state is returned separately
    as the target.

    Args:
        data: Mapping with ``"input_ids"`` and ``"loss_mask"`` tensors.
            Assumes a leading dimension of size 1 on both (only element 0
            is kept) - TODO confirm.

    Returns:
        A dict of CPU tensors with the keys ``"input_ids"``,
        ``"hidden_state"`` (the concatenated intermediate states),
        ``"loss_mask"`` and ``"target"`` (the last hidden state), each with
        the leading dimension dropped.
    """
    input_ids = data["input_ids"]
    num_layers = len(bigmodel.model.layers)
    outs_big = bigmodel(input_ids.cuda(), output_hidden_states=True)
    hidden_states = outs_big.hidden_states

    # The logits are intentionally not post-processed: the previous argmax /
    # softmax / max results were never used and only cost time and memory.
    fused = torch.cat(
        [
            hidden_states[3],
            hidden_states[num_layers // 2 + 1],
            hidden_states[-3],
        ],
        dim=-1,
    )
    return {
        "input_ids": input_ids.cpu()[0],
        "hidden_state": fused.cpu()[0],
        "loss_mask": data["loss_mask"].cpu()[0],
        "target": hidden_states[-1].cpu()[0],
    }
184+
185+
186+
# Each worker writes into its own sub-directory named after its --index.
outdir = f"{args.outdir}/{args.index}"
# exist_ok avoids the race between checking for and creating the directory.
os.makedirs(outdir, exist_ok=True)
189+
190+
191+
def writedata(name, data_point):
    """Save ``data_point`` as the next numbered checkpoint file in ``name``.

    The file is called ``data_<k>.ckpt`` where ``k`` is the number of entries
    already present in the directory, so consecutive calls on a directory
    holding only these files produce ``data_0.ckpt``, ``data_1.ckpt``, ...

    Args:
        name: Target directory; created (with parents) if it does not exist.
        data_point: Object to serialize with ``torch.save``.
    """
    # Race-free replacement for the exists()-then-makedirs() pattern.
    os.makedirs(name, exist_ok=True)
    next_index = len(os.listdir(name))
    torch.save(data_point, f"{name}/data_{next_index}.ckpt")
197+
198+
199+
# Process every sample of this worker's slice and store the result on disk.
for id_, data in enumerate(ds):
    # Progress output: a tab-separated counter every 100 samples, with a line
    # break after every 1000th (every multiple of 1000 is also one of 100).
    if id_ % 100 == 0:
        print(id_, end="\t")
        if id_ % 1000 == 0:
            print("")
    outdata = ge(data)
    writedata(outdir, outdata)

0 commit comments

Comments
 (0)