Skip to content

Commit 3edac10

Browse files
committed
[Add] Target Confidence Keep Generation Mechanism
1 parent e771c60 commit 3edac10

4 files changed

Lines changed: 355 additions & 8 deletions

File tree

experiments/run_confidence.sh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/bin/bash
2+
3+
4+
python3 speculative_decoding/run_speculative_decoding_target_confidence.py \
5+
--target_model_path ./models/HuggingFaceTB--SmolLM2-1.7B-Instruct \
6+
--draft_model_path ./models/HuggingFaceTB--SmolLM2-135M-Instruct \
7+
--device cuda:0 \
8+
--question 'What is the capital of Taiwan. And why?' \
9+
--gamma 5 \
10+
--test_token_num 100
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
from typing import Dict, List, Optional, Tuple
2+
3+
import os
4+
import sys
5+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
6+
7+
import argparse
8+
import copy
9+
import time
10+
11+
import torch
12+
from torch.nn.utils.rnn import pad_sequence
13+
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
14+
15+
from sampling.sampling import sample_next_token
16+
17+
18+
"""
19+
python speculative_decoding/run_speculative_decoding.py \
20+
--target_model_path HuggingFaceTB/SmolLM2-1.7B-Instruct \
21+
--draft_model_path HuggingFaceTB/SmolLM2-135M-Instruct \
22+
--device cuda:0 \
23+
--question 'What is the capital of Taiwan. And why?' \
24+
--gamma 5 \
25+
--test_token_num 100
26+
"""
27+
28+
29+
def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
30+
continuous_acceptance = 0
31+
for accepted in acceptance_mask.long().squeeze(0):
32+
if accepted == 1:
33+
continuous_acceptance += 1
34+
else:
35+
break
36+
return continuous_acceptance
37+
38+
39+
def drafter_speculative_decode(
40+
draft_model: torch.nn.Module,
41+
draft_tokenizer: PreTrainedTokenizerBase,
42+
inputs: Dict[str, torch.Tensor],
43+
gamma: int = 10,
44+
temperature: float = 1.0,
45+
top_k: Optional[int] = 0, # Default is 0, it means do not select top-k tokens
46+
top_p: Optional[float] = 1.0,
47+
repetition_penalty: Optional[float] = 1.0,
48+
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
49+
draft_probs = []
50+
51+
for idx in range(gamma):
52+
with torch.no_grad():
53+
outputs = draft_model(**inputs)
54+
55+
next_tokens, probs = sample_next_token(
56+
logits=outputs.logits,
57+
prefix_token_ids=inputs["input_ids"],
58+
temperature=temperature,
59+
top_k=top_k,
60+
top_p=top_p,
61+
repetition_penalty=repetition_penalty,
62+
)
63+
64+
draft_probs.append(probs)
65+
input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
66+
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
67+
68+
inputs["input_ids"] = input_ids
69+
inputs["attention_mask"] = attention_mask
70+
71+
return inputs, torch.cat(draft_probs, dim=1)
72+
73+
74+
def target_speculative_decode(
75+
target_model: torch.nn.Module,
76+
target_tokenizer: PreTrainedTokenizerBase,
77+
inputs: Dict[str, torch.Tensor],
78+
draft_probs: torch.FloatTensor,
79+
temperature: float = 1.0,
80+
top_k: Optional[int] = 0, # Default is 0, it means do not select top-k tokens
81+
top_p: Optional[float] = 1.0,
82+
repetition_penalty: Optional[float] = 1.0,
83+
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
84+
with torch.no_grad():
85+
outputs = target_model(**inputs)
86+
87+
next_tokens, target_probs = sample_next_token(
88+
logits=outputs.logits,
89+
prefix_token_ids=inputs["input_ids"],
90+
temperature=temperature,
91+
top_k=top_k,
92+
top_p=top_p,
93+
repetition_penalty=repetition_penalty,
94+
probs_num=draft_probs.shape[1] + 1,
95+
)
96+
97+
next_token = next_tokens[:, -1:]
98+
99+
# Evaluation
100+
indices = inputs["input_ids"][:, -draft_probs.shape[1]:]
101+
102+
eval_probs = target_probs[:, :-1, :]
103+
104+
expanded_indices = indices.unsqueeze(-1)
105+
selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
106+
selected_draft_probs = selected_draft_probs.squeeze(-1)
107+
108+
selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
109+
selected_eval_probs = selected_eval_probs.squeeze(-1)
110+
111+
# Compare draft_prob and eval_prob, and check the reject_mask
112+
mask_to_reject = selected_draft_probs > selected_eval_probs
113+
114+
# Calculate reject probabilty 1 - (eval_prob / draft_prob)
115+
rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)
116+
117+
# Generate random values to determined accept or reject
118+
random_values = torch.rand_like(rejection_probs)
119+
rejection_decisions = random_values < rejection_probs
120+
121+
# Get the final reject masks
122+
rejection_masks = mask_to_reject & rejection_decisions
123+
acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
124+
acceptance_mask[rejection_masks] = False
125+
126+
is_end = False
127+
128+
# Concat `input_ids`
129+
confidence_score = 0
130+
131+
if torch.all(acceptance_mask):
132+
input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
133+
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
134+
confidence_score = target_probs[:, -1, next_token[0][0]].item()
135+
print(f"Confidence for next token: {confidence_score:.4f}")
136+
else:
137+
new_input_ids = []
138+
new_attention_mask = []
139+
140+
for batch_idx in range(next_tokens.shape[0]):
141+
gamma = next_tokens.shape[1] - 1
142+
start_idx = inputs["input_ids"].shape[1] - gamma
143+
144+
for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
145+
if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
146+
inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
147+
confidence_score = target_probs[batch_idx, pos_idx, next_tokens[batch_idx, pos_idx]].max().item()
148+
print(f"Replacement Confidence for next token: {confidence_score:.4f}")
149+
150+
new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
151+
new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
152+
153+
is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
154+
break
155+
156+
input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
157+
attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)
158+
159+
inputs["input_ids"] = input_ids
160+
inputs["attention_mask"] = attention_mask
161+
162+
# Keep generating if confidence_score is less than confidence threshold
163+
while confidence_score < 0.5:
164+
with torch.no_grad():
165+
outputs = target_model(**inputs)
166+
167+
next_tokens, target_probs = sample_next_token(
168+
logits=outputs.logits,
169+
prefix_token_ids=inputs["input_ids"],
170+
temperature=temperature,
171+
top_k=top_k,
172+
top_p=top_p,
173+
repetition_penalty=repetition_penalty,
174+
probs_num=1,
175+
)
176+
177+
# Update `confidence_score`
178+
next_token = next_tokens[:, -1:]
179+
confidence_score = target_probs[0, -1, next_token[0][0]].item()
180+
print(f"keep generate confidence_score: {confidence_score:.4f}")
181+
182+
input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
183+
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
184+
185+
inputs["input_ids"] = input_ids
186+
inputs["attention_mask"] = attention_mask
187+
188+
is_end = inputs["input_ids"][0][-1].item() == target_tokenizer.eos_token_id
189+
if is_end:
190+
break
191+
192+
return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)
193+
194+
195+
def run_test(args) -> None:
196+
# Device
197+
device = torch.device(args.device if args.device != "cpu" and torch.cuda.is_available() else "cpu")
198+
print(device)
199+
200+
# Model path
201+
target_model_path = args.target_model_path
202+
draft_model_path = args.draft_model_path
203+
204+
# Load Tokenizer
205+
draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_path)
206+
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
207+
208+
# Load Model
209+
draft_model = AutoModelForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
210+
target_model = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)
211+
212+
# Tokenize
213+
messages = [
214+
[
215+
{
216+
"role": "user",
217+
"content": args.question,
218+
},
219+
],
220+
]
221+
222+
input_text=draft_tokenizer.apply_chat_template(messages, tokenize=False)
223+
inputs = draft_tokenizer(
224+
input_text,
225+
return_tensors="pt",
226+
max_length=512,
227+
truncation=True,
228+
padding=True,
229+
).to(device)
230+
231+
# Warm up the model (CUDA)
232+
inputs_dummy = {k: v.clone() for k, v in inputs.items()}
233+
with torch.no_grad():
234+
draft_model(**inputs_dummy)
235+
target_model(**inputs_dummy)
236+
torch.cuda.synchronize()
237+
238+
is_end = False
239+
240+
# Record
241+
raw_inputs = copy.deepcopy(inputs)
242+
raw_token_num = raw_inputs["input_ids"].shape[1]
243+
start_time = time.time()
244+
245+
total_draft_tokens = 0
246+
total_accept_tokens = 0
247+
gamma = args.gamma
248+
max_new_tokens = args.test_token_num
249+
250+
while not is_end:
251+
# Draft model
252+
target_inputs, draft_probs = drafter_speculative_decode(
253+
draft_model=draft_model,
254+
draft_tokenizer=draft_tokenizer,
255+
inputs=inputs,
256+
gamma=gamma,
257+
)
258+
259+
total_draft_tokens += gamma
260+
261+
# Target model
262+
outputs, is_end, accept_tokens = target_speculative_decode(
263+
target_model=target_model,
264+
target_tokenizer=target_tokenizer,
265+
inputs=target_inputs,
266+
draft_probs=draft_probs,
267+
)
268+
269+
total_accept_tokens += accept_tokens
270+
271+
inputs = outputs
272+
273+
if inputs["input_ids"].shape[1] - raw_token_num >= max_new_tokens:
274+
break
275+
276+
generate_token_num = outputs["input_ids"].shape[1] - raw_token_num
277+
spent_time = time.time() - start_time
278+
279+
print(f"Generate token number: {generate_token_num}")
280+
print(f"Generate speed: {generate_token_num / spent_time} tokens/sec")
281+
print(f"Speculative Decoding Spent Time: {spent_time} seconds.")
282+
print(f"Accept Rate: {total_accept_tokens / total_draft_tokens}\n")
283+
284+
# Normal Target Model Speed
285+
raw_inputs = copy.deepcopy(inputs)
286+
start_time = time.time()
287+
target_inputs, draft_probs = drafter_speculative_decode(
288+
draft_model=target_model,
289+
draft_tokenizer=draft_tokenizer,
290+
inputs=raw_inputs,
291+
gamma=args.test_token_num,
292+
)
293+
294+
spent_time = time.time() - start_time
295+
296+
print(f"Generate token number: {max_new_tokens}")
297+
print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
298+
print(f"Normal Target Model Decoding Spent Time: {spent_time} seconds.\n")
299+
300+
# Normal Draft Model Speed
301+
raw_inputs = copy.deepcopy(inputs)
302+
start_time = time.time()
303+
target_inputs, draft_probs = drafter_speculative_decode(
304+
draft_model=draft_model,
305+
draft_tokenizer=draft_tokenizer,
306+
inputs=raw_inputs,
307+
gamma=args.test_token_num,
308+
)
309+
310+
spent_time = time.time() - start_time
311+
312+
print(f"Generate token number: {max_new_tokens}")
313+
print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
314+
print(f"Normal Draft Model Decoding Spent Time: {spent_time} seconds.\n")
315+
316+
317+
if __name__ == "__main__":
318+
parser = argparse.ArgumentParser()
319+
parser.add_argument("--target_model_path", type=str, default="HuggingFaceTB/SmolLM2-1.7B-Instruct")
320+
parser.add_argument("--draft_model_path", type=str, default="HuggingFaceTB/SmolLM2-135M-Instruct")
321+
parser.add_argument("--device", type=str, default="cpu")
322+
parser.add_argument("--question", type=str, default="What is the capital of Taiwan. And why?")
323+
parser.add_argument("--gamma", type=int, default=5)
324+
parser.add_argument("--test_token_num", type=int, default=100)
325+
args = parser.parse_args()
326+
327+
run_test(args)

self_speculative_decoding/run_self_sepculative_decoding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def target_speculative_decode(
247247
spent_time = time.time() - start_time
248248

249249
print(f"Generate token number: {generate_token_num}")
250-
print(f"Generate speed: {generate_token_num / spent_time} token/sec")
250+
print(f"Generate speed: {generate_token_num / spent_time} tokens/sec")
251251
print(f"Speculative Decoding Spent Time: {spent_time} seconds.")
252252
print(f"Accept Rate: {total_accept_tokens / total_draft_tokens}\n")
253253

@@ -265,7 +265,7 @@ def target_speculative_decode(
265265
spent_time = time.time() - start_time
266266

267267
print(f"Generate token number: {max_new_tokens}")
268-
print(f"Generate speed: {max_new_tokens / spent_time} token/sec")
268+
print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
269269
print(f"Normal Target Model Decoding Spent Time: {spent_time} seconds.\n")
270270

271271
# Normal Draft Model Speed

0 commit comments

Comments
 (0)