diff --git a/medusa/inference/api_client.py b/medusa/inference/api_client.py new file mode 100644 index 0000000..2344600 --- /dev/null +++ b/medusa/inference/api_client.py @@ -0,0 +1,79 @@ +"""Example Python client for vllm.entrypoints.api_server""" + +import argparse +import json +from typing import Iterable, List +import pdb +import requests + + +def clear_line(n: int = 1) -> None: + LINE_UP = '\033[1A' + LINE_CLEAR = '\x1b[2K' + for _ in range(n): + print(LINE_UP, end=LINE_CLEAR, flush=True) + + +def post_http_request(prompt, + api_url: str, + n: int = 1, + stream: bool = False) -> requests.Response: + headers = {"User-Agent": "Test Client"} + pload = { + "prompt":prompt, + "max_tokens":150 + } + print(pload) + response = requests.post(api_url, headers=headers, json=pload, stream=True) + return response + + +def get_streaming_response(response: requests.Response) -> Iterable[List[str]]: + for chunk in response.iter_lines(chunk_size=8192, + decode_unicode=False, + delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode("utf-8")) + output = data["text"] + yield output + + +def get_response(response: requests.Response) -> List[str]: + print(response.content) + data = json.loads(response.content) + output = data["text"] + return output + +def add_prefix(prompt): + prompt_ = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {prompt} ASSISTANT:" + return prompt_ + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--n", type=int, default=1) + parser.add_argument("--prompt", type=str, default="San Francisco is a") + parser.add_argument("--stream", action="store_true") + args = parser.parse_args() + prompt = args.prompt + api_url = f"http://{args.host}:{args.port}/generate" + n = args.n + stream = True + prompt = "你叫什么名字?" + prompt = add_prefix(prompt) + print(f"Prompt: {prompt!r}\n", flush=True) + response = post_http_request(prompt, api_url, n, stream) + + if stream: + num_printed_lines = 0 + for h in get_streaming_response(response): + clear_line(num_printed_lines) + num_printed_lines = 0 + for i, line in enumerate(h): + num_printed_lines += 1 + print(f"Beam candidate {i}: {line!r}", flush=True) + else: + output = get_response(response) + for i, line in enumerate(output): + print(f"Beam candidate {i}: {line!r}", flush=True) \ No newline at end of file diff --git a/medusa/inference/api_server.py b/medusa/inference/api_server.py new file mode 100644 index 0000000..dbc7512 --- /dev/null +++ b/medusa/inference/api_server.py @@ -0,0 +1,175 @@ +import argparse +import json +from typing import AsyncGenerator +import torch +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse +import uvicorn +from medusa.model.medusa_model import MedusaModel +import asyncio +from collections import deque +import uuid +from contextlib import asynccontextmanager + +TIMEOUT_KEEP_ALIVE = 5 # seconds. 
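# The client above and the server below agree on a simple framing: each streamed
# chunk is one JSON object terminated by a NUL byte (b"\0"), and every object
# carries the full text generated so far under the "text" key. A minimal consumer
# sketch (assuming the server below is running with its defaults on localhost:8000
# and that the prompt already carries the conversation prefix):

import json
import requests

response = requests.post(
    "http://localhost:8000/generate",
    headers={"User-Agent": "Test Client"},
    json={"prompt": "USER: Hello! ASSISTANT:", "max_tokens": 64},
    stream=True,
)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"][0], flush=True)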
+engine = None
+max_batch_size = 5
+request_queue = deque()
+id2result = {}
+
+async def handle_request(request_data):
+    request_queue.append(request_data)
+
+async def get_batch_from_queue():
+    prompts = []
+    ids = []
+    if args.origin_model:
+        request_dict_ = {"temperature":0.5, "max_tokens":150, "top_p": 0.85}
+    else:
+        request_dict_ = {"temperature":0.0, "max_tokens":150, "top_p": 0.85}
+    max_tokens = None
+    start_time = asyncio.get_event_loop().time()  # record the start of the batching window
+    while len(prompts) < max_batch_size:
+        # stop collecting once the 30 ms batching window has elapsed
+        if asyncio.get_event_loop().time() - start_time >= 0.03:
+            break
+        # if the queue is empty, wait 1 ms and try again
+        if not request_queue:
+            await asyncio.sleep(0.001)
+            continue
+        request_dict = request_queue.popleft()
+        if request_dict.get("max_tokens", None):
+            if max_tokens:
+                max_tokens = max(max_tokens, request_dict["max_tokens"])
+            else:
+                max_tokens = request_dict["max_tokens"]
+        prompts.append(request_dict.pop("prompt"))
+        ids.append(request_dict.pop("unique_id"))
+    if max_tokens:
+        request_dict_["max_tokens"] = max_tokens
+    if len(prompts) > 0 and request_dict.get("temperature", None):
+        request_dict_["temperature"] = request_dict["temperature"]
+    if len(prompts) > 0 and request_dict.get("top_p", None):
+        request_dict_["top_p"] = request_dict["top_p"]
+    return prompts, ids, request_dict_
+
+
+async def run_model():
+    while True:
+        prompt, ids, request_dict = await get_batch_from_queue()
+        if len(prompt) > 0:
+            print(f"batch size: {len(prompt)}")
+            encoded_inputs = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
+            input_ids = encoded_inputs['input_ids'].to(engine.base_model.device)
+            attention_mask = encoded_inputs['attention_mask'].to(engine.base_model.device)
+            for request_output in engine.medusa_generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                temperature=request_dict["temperature"],
+                max_steps=request_dict["max_tokens"],
+                top_p=request_dict["top_p"]
+            ):
+                await asyncio.sleep(0.001)
+                for index, id in enumerate(ids):
+                    if id2result[id] is None:
+                        id2result[id] = {'text':None, 'sign':None, 'finished':False}
+                    if id2result[id]['text'] != request_output["text"][index]:
+                        id2result[id]['text'] = request_output["text"][index]
+                        id2result[id]['sign'] = str(uuid.uuid4())
+
+            for index, id in enumerate(ids):
+                id2result[id]['finished'] = True
+        else:
+            pass
+
+app = FastAPI()
+
+@app.get("/health")
+async def health() -> Response:
+    """Health check."""
+    return Response(status_code=200)
+
+@app.on_event("startup")
+async def startup_event():
+    asyncio.create_task(run_model())
+
+@app.post("/generate")
+async def generate(request: Request) -> Response:
+    request_dict = await request.json()
+    unique_id = str(uuid.uuid4())
+    request_dict["unique_id"] = unique_id
+    id2result[unique_id] = None
+    await handle_request(request_dict)  ## enqueue the incoming request
+
+    async def stream_results():
+        previous_sign = None
+        while True:  ## poll for new output
+            result = id2result.get(unique_id, None)
+            if result is not None:
+                if result['sign'] != previous_sign:  ## has the text been updated?
+                    full_sentence = result['text']
+                    ret = {"text":[full_sentence]}
+                    previous_sign = result['sign']
+                    yield (json.dumps(ret) + "\0").encode("utf-8")
+                else:
+                    if result['finished']:  ## generation finished, remove the stored result
+                        print(f"{unique_id} finished, removing result")
+                        id2result.pop(unique_id)
+                        break
+                    await asyncio.sleep(0.001)
+            else:
+                await asyncio.sleep(0.001)
+
+    return StreamingResponse(stream_results())  ## return the streamed results
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default=None)
+
parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument("--model", type=str, required=True, help="Model name or path.") + parser.add_argument("--origin-model", action="store_true") + parser.add_argument( + "--load-in-8bit", action="store_true", help="Use 8-bit quantization" + ) + parser.add_argument( + "--load-in-4bit", action="store_true", help="Use 4-bit quantization" + ) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + + args = parser.parse_args() + if args.origin_model: + import types + from medusa.model.origin_model import Model,Tokenizer, medusa_generate + from transformers_stream_generator import init_stream_support + init_stream_support() + engine = Model.from_pretrained(args.model) + tokenizer = Tokenizer.from_pretrained(args.model) + engine.medusa_generate = types.MethodType(medusa_generate, engine) + engine.tokenizer = tokenizer + print("启动原始模型") + else: + engine = MedusaModel.from_pretrained( + args.model, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + device_map="auto", + load_in_8bit=args.load_in_8bit, + load_in_4bit=args.load_in_4bit, + ) + tokenizer = engine.get_tokenizer() + print("启动medusa模型") + app.root_path = args.root_path + uvicorn.run(app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile) \ No newline at end of file diff --git a/medusa/inference/inference_test.py b/medusa/inference/inference_test.py new file mode 100644 index 0000000..77342ed --- /dev/null +++ b/medusa/inference/inference_test.py @@ -0,0 +1,98 @@ +# Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py +""" +Chat with a model with command line interface. + +Usage: +python3 -m medusa.inference.cli --model +Other commands: +- Type "!!exit" or an empty line to exit. +- Type "!!reset" to start a new conversation. +- Type "!!remove" to remove the last prompt. +- Type "!!regen" to regenerate the last message. +- Type "!!save " to save the conversation history to a json file. +- Type "!!load " to load a conversation history from a json file. +""" +import argparse +import os +import re +import sys +import torch +from fastchat.serve.cli import SimpleChatIO, RichChatIO, ProgrammaticChatIO +from fastchat.model.model_adapter import get_conversation_template +from fastchat.conversation import get_conv_template +import json +from medusa.model.medusa_model import MedusaModel +import pdb + +def main(args): + prefix = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 
USER: {0} ASSISTANT:" + # prompt = ["你叫什么名字"] + # prompt = ["你叫什么名字", "中国的首都是哪里呢?"] + prompt = ["openai是家什么公司?", "2+2等于几?"] + prompt = [prefix.format(p) for p in prompt] + model = MedusaModel.from_pretrained( + args.model, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + device_map="auto", + load_in_8bit=args.load_in_8bit, + load_in_4bit=args.load_in_4bit, + ) + tokenizer = model.get_tokenizer() + # 使用tokenizer处理批量输入 + encoded_inputs = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt") + # 将编码后的输入移动到模型所在的设备 + input_ids = encoded_inputs['input_ids'].to(model.base_model.device) + attention_mask = encoded_inputs['attention_mask'].to(model.base_model.device) + for output in model.medusa_generate( + input_ids, + attention_mask=attention_mask, + temperature=args.temperature, + # temperature=0, + max_steps=args.max_steps + ): + print(output['text']) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, required=True, help="Model name or path.") + parser.add_argument( + "--load-in-8bit", action="store_true", help="Use 8-bit quantization" + ) + parser.add_argument( + "--load-in-4bit", action="store_true", help="Use 4-bit quantization" + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--conv-system-msg", type=str, default=None, help="Conversation system message." + ) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--max-steps", type=int, default=10) + parser.add_argument("--no-history", action="store_true") + parser.add_argument( + "--style", + type=str, + default="simple", + choices=["simple", "rich", "programmatic"], + help="Display style.", + ) + parser.add_argument( + "--multiline", + action="store_true", + help="Enable multiline input. Use ESC+Enter for newline.", + ) + parser.add_argument( + "--mouse", + action="store_true", + help="[Rich Style]: Enable mouse support for cursor positioning.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Print useful debug information (e.g., prompts)", + ) + args = parser.parse_args() + main(args) diff --git a/medusa/inference/origin_inference.py b/medusa/inference/origin_inference.py new file mode 100644 index 0000000..61daba9 --- /dev/null +++ b/medusa/inference/origin_inference.py @@ -0,0 +1,29 @@ +import torch +import pdb +import types +from medusa.model.origin_model import Model,Tokenizer, medusa_generate +from transformers_stream_generator import init_stream_support +init_stream_support() + + +prefix = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 
USER: {0} ASSISTANT:" +prompt = ["openai是家什么公司?", "2+2等于几?"] +prompt = [prefix.format(p) for p in prompt] +model_dir='/mnt/wx/.cache/huggingface/hub/models--FasterDecoding--medusa-vicuna-7b-v1.3/snapshots/82ac200bf7502419cb49a9e0adcbebe3d1d293f1/' +model = Model.from_pretrained(model_dir) +tokenizer = Tokenizer.from_pretrained(model_dir) +model_inputs = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt") +# 给实例对象添加方法 +model.tokenizer = tokenizer +model.medusa_generate = types.MethodType(medusa_generate, model) +input_ids = model_inputs['input_ids'].to(model.device) +attention_mask = model_inputs['attention_mask'].to(model.device) +generator = model.medusa_generate(input_ids=input_ids, + attention_mask=attention_mask, + temperature=0.1, + max_steps=20, + top_p=0.8) +for token in generator: + print(token['text']) + + diff --git a/medusa/model/kv_cache.py b/medusa/model/kv_cache.py index edb9956..dd5d908 100644 --- a/medusa/model/kv_cache.py +++ b/medusa/model/kv_cache.py @@ -1,5 +1,5 @@ import torch - +import copy class KVCache: """ @@ -41,32 +41,51 @@ def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2): Args: indices (torch.Tensor): Indices of the data tensor to be copied. - prev_length (int): Previous length before adding new data. + prev_length (int): Previous lengths before adding new data dim (int, optional): Dimension along which copying should be performed. Default is 2. """ + # 选取需要复制的数据 tgt = self.data.index_select(dim, indices) - dst = self.data.narrow(dim, prev_length, tgt.shape[dim]) - dst.copy_(tgt, non_blocking=True) - self.current_length.fill_(prev_length + tgt.shape[dim]) + prev_len = prev_length + start_index = prev_len + end_index = start_index + tgt.shape[dim] + # 根据维度选取目标区域并复制数据 + if dim == 2: + dst = self.data[:, :, :, start_index:end_index, :] + elif dim == 3: + dst = self.data[:, :, :, :, start_index:end_index] + else: + raise ValueError("Unsupported dimension for copying.") + dst.copy_(tgt[:, :], non_blocking=True) + self.current_length.fill_(prev_length + tgt.shape[dim]) def cat(self, tensor: torch.Tensor, dim: int = 2): """ - Concatenate the given tensor with the current data. + Concatenate the given tensor with the current data for batch_size > 1, and return the tensor + truncated to the maximum current length across all batches. Args: - tensor (torch.Tensor): The tensor to be concatenated. + tensor (torch.Tensor): The tensor to be concatenated, assuming the first dimension is the batch size. dim (int, optional): The dimension along which concatenation should be done. Default is 2. Returns: - torch.Tensor: The data tensor after concatenation up to the current length. + torch.Tensor: The data tensor after concatenation and truncation to the maximum current length. 
""" - dst = self.data.narrow(dim, self.current_length, tensor.shape[dim]) - dst.copy_(tensor) + cur_len = copy.deepcopy(self.current_length) + new_len = cur_len + tensor.size(dim) self.current_length.add_(tensor.shape[dim]) - return torch.narrow(self.data, 2, 0, self.current_length) - - -def initialize_past_key_values(model): + if dim == 2: + self.data[:, :, cur_len:new_len, :] = tensor[:,:,:,:] + truncated_data = self.data[:, :, :self.current_length, :] + elif dim == 3: + self.data[:, :, :, cur_len:new_len] = tensor[:,:,:,:] + truncated_data = self.data[:, :, :, :self.current_length] + else: + raise ValueError("Unsupported dimension for concatenation.") + return truncated_data + + +def initialize_past_key_values(model, batch_size=1): """ Initialize past key and value states for a given transformer model. @@ -84,8 +103,6 @@ def initialize_past_key_values(model): """ # Extracting configuration from the model config = model.config - # Initializing the batch size to 1, this can be modified if different batch sizes are required - batch_size = 1 # Initializing a tensor to store past keys and values for all layers past_key_values_data = torch.zeros( config.num_hidden_layers * 2, diff --git a/medusa/model/medusa_model.py b/medusa/model/medusa_model.py index 1a5f70c..6747952 100644 --- a/medusa/model/medusa_model.py +++ b/medusa/model/medusa_model.py @@ -3,11 +3,11 @@ from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM from .modeling_mistral_kv import MistralForCausalLM as KVMistralForCausalLM # import transformers - +import pdb # # monkey patch # transformers.models.llama.modeling_llama.LlamaForCausalLM = KVLlamaForCausalLM # transformers.models.mistral.modeling_mistral.MistralForCausalLM = KVMistralForCausalLM - +import copy from transformers import PreTrainedModel, PretrainedConfig from .utils import * from .kv_cache import initialize_past_key_values @@ -106,7 +106,7 @@ def __init__( self.medusa = medusa_num_heads self.medusa_num_layers = medusa_num_layers self.base_model_name_or_path = base_model_name_or_path - self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path) + self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path, use_fast=True) # Create a list of Medusa heads self.medusa_head = nn.ModuleList( [ @@ -121,6 +121,7 @@ def __init__( @property def base_model(self): return self + @classmethod def from_pretrained( cls, @@ -219,6 +220,7 @@ def forward( if output_orig: return torch.stack(medusa_logits, dim=0), outputs, orig return torch.stack(medusa_logits, dim=0) + def get_medusa_choice(self, model_name): if 'vicuna' in model_name: if '7b' in model_name: @@ -264,10 +266,11 @@ def medusa_generate( Warning: Only support batch size 1 for now!! """ - assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" + # assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" 
+ batch_size = input_ids.shape[0] + valid_length = attention_mask.sum(dim=1) # Avoid modifying the input_ids in-place input_ids = input_ids.clone() - # Cache medusa buffers (the fixed patterns for tree attention) if medusa_choices is None: medusa_choices = self.get_medusa_choice(self.base_model_name_or_path) @@ -284,7 +287,7 @@ def medusa_generate( self.medusa_choices = medusa_choices # Initialize the past key and value states - if hasattr(self, "past_key_values"): + if hasattr(self, "past_key_values") and batch_size==self.past_key_values_data.shape[1]: past_key_values = self.past_key_values past_key_values_data = self.past_key_values_data current_length_data = self.current_length_data @@ -295,23 +298,31 @@ def medusa_generate( past_key_values, past_key_values_data, current_length_data, - ) = initialize_past_key_values(self.base_model) + ) = initialize_past_key_values(self.base_model, batch_size) self.past_key_values = past_key_values self.past_key_values_data = past_key_values_data self.current_length_data = current_length_data - input_len = input_ids.shape[1] + # input_len = input_ids.shape[1] + input_len = (input_ids != self.tokenizer.pad_token_id).sum(dim=1) reset_medusa_mode(self) # Initialize tree attention mask and process prefill tokens medusa_logits, logits = initialize_medusa( - input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values + input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values, attention_mask ) - new_token = 0 last_round_token = 0 + if isinstance(input_len, int): + ends = [input_len] * batch_size + else: + ends = copy.deepcopy(input_len) - for idx in range(max_steps): + target_lenght = torch.ones(batch_size, dtype=torch.int, device=input_ids.device)*max_steps + any_finished = torch.any(target_lenght<=0) + all_finished = torch.all(target_lenght<=0) + # for idx in range(max_steps): + while not all_finished: # Generate candidates with topk predictions from Medusa heads candidates, tree_candidates = generate_candidates( medusa_logits, @@ -324,8 +335,8 @@ def medusa_generate( top_p=top_p, sampling=sampling, fast=fast, + valid_length=valid_length ) - # Use tree attention to verify the candidates and get predictions medusa_logits, logits, outputs = tree_decoding( self, @@ -334,15 +345,14 @@ def medusa_generate( medusa_buffers["medusa_position_ids"], input_ids, medusa_buffers["retrieve_indices"], + attention_mask=attention_mask ) - # Evaluate the posterior of the candidates to select the accepted candidate prefix best_candidate, accept_length = evaluate_posterior( logits, candidates, temperature, posterior_threshold, posterior_alpha, top_p=top_p, sampling=sampling, fast=fast ) - # Update the input_ids and logits - input_ids, logits, medusa_logits, new_token = update_inference_inputs( + input_ids, logits, medusa_logits, new_token, valid_length, attention_mask = update_inference_inputs( input_ids, candidates, best_candidate, @@ -354,18 +364,35 @@ def medusa_generate( new_token, past_key_values_data, current_length_data, + attention_mask=attention_mask, + padding_idx=self.tokenizer.pad_token_id ) - - yield { - "text": self.tokenizer.decode( - input_ids[0, input_len:], + decoded_texts = [] + eos_encountered = [False] * batch_size + target_lenght -= valid_length + any_finished = torch.any(target_lenght<=0) + all_finished = torch.all(target_lenght<=0) + for i in range(batch_size): + if isinstance(input_len, int): + input_len_ = input_len + else: + input_len_ = input_len[i] + # 检查当前批次的文本是否包含结束符 + if self.tokenizer.eos_token_id in input_ids[i, 
input_len_:]: + eos_encountered[i] = True + else: + ends[i] = len(input_ids[i]) + decoded_text = self.tokenizer.decode( + input_ids[i, input_len_:ends[i]], skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True, ) - } + decoded_texts.append(decoded_text) + yield{"text": decoded_texts} - if self.tokenizer.eos_token_id in input_ids[0, input_len:]: + # 如果所有批次都遇到了 EOS,则停止 + if all(eos_encountered): break diff --git a/medusa/model/modeling_llama_kv.py b/medusa/model/modeling_llama_kv.py index abf9382..d6d54eb 100644 --- a/medusa/model/modeling_llama_kv.py +++ b/medusa/model/modeling_llama_kv.py @@ -32,7 +32,7 @@ if is_flash_attn_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - +import pdb logger = logging.get_logger(__name__) @@ -315,7 +315,6 @@ def forward( padding_mask: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp query_slices = self.q_proj.weight.split( @@ -815,6 +814,8 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em # [MODIFIED] add medusa mask if hasattr(self, "medusa_mask") and self.medusa_mask is not None: medusa_mask = self.medusa_mask + bs = combined_attention_mask.shape[0] + medusa_mask = medusa_mask.repeat(bs,1,1,1) medusa_len = medusa_mask.size(-1) combined_attention_mask[:, :, -medusa_len:, -medusa_len:][ medusa_mask == 0 @@ -886,7 +887,6 @@ def forward( padding_mask = attention_mask else: padding_mask = None - attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) @@ -1038,7 +1038,6 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, diff --git a/medusa/model/origin_model.py b/medusa/model/origin_model.py new file mode 100644 index 0000000..c069789 --- /dev/null +++ b/medusa/model/origin_model.py @@ -0,0 +1,55 @@ +import torch +import pdb +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.generation.utils import GenerationConfig +from .medusa_model import MedusaConfig + + +class Tokenizer(): + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *args, + **kwargs, + ): + config = MedusaConfig.from_pretrained(pretrained_model_name_or_path) + model_dir=config.base_model_name_or_path + return AutoTokenizer.from_pretrained(model_dir, + use_fast=True, + trust_remote_code=True) + + +class Model(): + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *args, + **kwargs, + ): + config = MedusaConfig.from_pretrained(pretrained_model_name_or_path) + model_dir=config.base_model_name_or_path + model = AutoModelForCausalLM.from_pretrained(model_dir, + device_map="auto", + torch_dtype=torch.float16, + trust_remote_code=True) + + model.generation_config = GenerationConfig.from_pretrained(model_dir) + return model + +def medusa_generate(self, **kwargs): + output_ids = None + kwargs['max_length'] = kwargs['max_steps']+kwargs['input_ids'].shape[-1] + 
generator = self.generate(**kwargs, do_stream=True, do_sample=True) + for tokens in generator: + tokens=tokens.unsqueeze(-1) + if output_ids is None: + output_ids = tokens + else: + output_ids = torch.cat((output_ids, tokens), dim=-1) + decoded_texts = self.tokenizer.batch_decode(output_ids, + skip_special_tokens=True, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True,) + yield {"text": decoded_texts} \ No newline at end of file diff --git a/medusa/model/utils.py b/medusa/model/utils.py index 67fd4a5..dfec3ed 100644 --- a/medusa/model/utils.py +++ b/medusa/model/utils.py @@ -1,8 +1,30 @@ import torch import torch.nn.functional as F +import pdb TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient) +def extract_last_valid_logits(logits: torch.Tensor, valid_length: torch.Tensor): + """ + Extract logits of the last valid token for each sequence in the batch. + + Args: + logits (torch.Tensor): Logits tensor of shape [batch_size, sequence_length, vocab_size], on GPU. + valid_length (torch.Tensor): valid_length tensor of shape [batch_size], on GPU. + + Returns: + torch.Tensor: Tensor containing the logits of the last valid token for each sequence, on GPU. + """ + if logits.dim() == 3: + batch_indices = torch.arange(logits.size(0), device=logits.device) # Batch indices + # Extract the logits of the last valid token for each sequence + last_valid_logits = logits[batch_indices, valid_length - 1] + elif logits.dim() == 4: + batch_indices = torch.arange(logits.size(1), device=logits.device) # Batch indices + # Extract the logits of the last valid token for each sequence + last_valid_logits = logits[:, batch_indices, valid_length - 1].transpose(1,0) + return last_valid_logits + def pad_path(path, length, pad_value=-2): """ Pad the given path list with a specific value up to a specified length. @@ -78,8 +100,32 @@ def generate_medusa_buffers(medusa_choices, device="cuda"): for i in range(len(depth_counts)): for j in range(depth_counts[i]): cur_medusa_choice = sorted_medusa_choices[start + j] - medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1 + medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1 ##根据每组最后一个节点和所在深度计算所在位置 start += depth_counts[i] + """ + 逻辑上结构: + A (原始头预测token, 没在sorted_medusa_choices中) + + B C ... K (第一个头预测token, 预测topk个) + + banana cute ... key(第二个头预测token, 预测topk个) + + 铺平之后: A B C ... K banana cute ... key (一共1+topk*深度个=1+4*10=41个) + + A:0 + ---- + B:1 + C:2 + ... 
+ k:11 + ---- + banana:12 + cute:13 + key:22 + + 不是所有路径都选,节点有可能被多条路径选多次,事先设置选64个路径 + medusa_tree_indices: 所有路径经过的节点,根据从短到长,从小到大记录下平铺后最后一个节点序号 + """ # Generate position IDs for the Medusa structure medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long) @@ -92,7 +138,7 @@ def generate_medusa_buffers(medusa_choices, device="cuda"): retrieve_indices_nest = [] retrieve_paths = [] for i in range(len(sorted_medusa_choices)): - cur_medusa_choice = sorted_medusa_choices[-i-1] + cur_medusa_choice = sorted_medusa_choices[-i-1] ##倒着循环 retrieve_indice = [] if cur_medusa_choice in retrieve_paths: continue @@ -106,7 +152,6 @@ def generate_medusa_buffers(medusa_choices, device="cuda"): retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) retrieve_indices = retrieve_indices + 1 retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1) - # Aggregate the generated buffers into a dictionary medusa_buffers = { "medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0), @@ -125,7 +170,7 @@ def generate_medusa_buffers(medusa_choices, device="cuda"): return medusa_buffers -def initialize_medusa(input_ids, model, medusa_attn_mask, past_key_values): +def initialize_medusa(input_ids, model, medusa_attn_mask, past_key_values, attention_mask=None): """ Initializes the Medusa structure for a given model. @@ -144,7 +189,7 @@ def initialize_medusa(input_ids, model, medusa_attn_mask, past_key_values): - logits (torch.Tensor): Original logits from the base model. """ medusa_logits, outputs, logits = model( - input_ids, past_key_values=past_key_values, output_orig=True, medusa_forward=True + input_ids, attention_mask=attention_mask, past_key_values=past_key_values, output_orig=True, medusa_forward=True ) model.base_model.model.medusa_mask = medusa_attn_mask return medusa_logits, logits @@ -173,7 +218,6 @@ def reset_medusa_mode( model.base_model.model.medusa_mask = None model.base_model.model.medusa_mode = None - def reset_past_key_values(passed_key_values): """ Resets the current lengths in the passed key-values to zero. @@ -255,7 +299,7 @@ def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alp sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1) return sampled_tokens -def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices, temperature = 0, posterior_threshold=0.3, posterior_alpha = 0.09, top_p=0.8, sampling = 'typical', fast = False): +def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices, temperature = 0, posterior_threshold=0.3, posterior_alpha = 0.09, top_p=0.8, sampling = 'typical', fast = False, valid_length=None): """ Generate candidates based on provided logits and indices. @@ -276,34 +320,84 @@ def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices, t 1. Cartesian candidates derived from the combined original and Medusa logits. 2. Tree candidates mapped from the Cartesian candidates using tree indices. """ + + if valid_length is not None: + last_logits = extract_last_valid_logits(logits, valid_length) + medusa_last_logits = extract_last_valid_logits(medusa_logits, valid_length) + else: + last_logits = logits + medusa_last_logits = medusa_logits # Greedy decoding: Select the most probable candidate from the original logits. 
if temperature == 0 or fast: - candidates_logit = torch.argmax(logits[:, -1]).unsqueeze(0) + candidates_logit = torch.argmax(last_logits, dim=-1).unsqueeze(-1) ##logits: [bs,seq,vocab], candidates_logit:[bs,1] else: if sampling == 'typical': - candidates_logit = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0) + candidates_logit = get_typical_one_token(last_logits, temperature, posterior_threshold, posterior_alpha).squeeze(0) elif sampling == 'nucleus': - candidates_logit = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0) + candidates_logit = get_nucleus_one_token(last_logits, temperature, top_p).squeeze(0) else: raise NotImplementedError + # Extract the TOPK candidates from the medusa logits. - candidates_medusa_logits = torch.topk(medusa_logits[:, 0, -1], TOPK, dim = -1).indices - + candidates_medusa_logits = torch.topk(medusa_last_logits, TOPK, dim = -1).indices ##candidates_medusa_logits:[bs, medusa_head_num, TOPK] + batch_size = candidates_logit.shape[0] # Combine the selected candidate from the original logits with the topk medusa logits. - candidates = torch.cat([candidates_logit, candidates_medusa_logits.view(-1)], dim=-1) + candidates = torch.cat([candidates_logit, candidates_medusa_logits.view(batch_size, -1)], dim=-1) # Map the combined candidates to the tree indices to get tree candidates. - tree_candidates = candidates[tree_indices] + tree_candidates = candidates[:,tree_indices] # Extend the tree candidates by appending a zero. - tree_candidates_ext = torch.cat([tree_candidates, torch.zeros((1), dtype=torch.long, device=tree_candidates.device)], dim=0) + tree_candidates_ext = torch.cat([tree_candidates, torch.zeros((batch_size,1), dtype=torch.long, device=tree_candidates.device)], dim=-1) # Retrieve the cartesian candidates using the retrieve indices. - cart_candidates = tree_candidates_ext[retrieve_indices] + cart_candidates = tree_candidates_ext[:,retrieve_indices] # Unsqueeze the tree candidates for dimension consistency. 
-    tree_candidates = tree_candidates.unsqueeze(0)
+    # tree_candidates = tree_candidates.unsqueeze(0)
+    # Example shapes for batch_size = 2:
+    #   cart_candidates: torch.Size([2, 42, 5])
+    #   tree_candidates: torch.Size([2, 64])
     return cart_candidates, tree_candidates
+
+
+def update_position_id(medusa_position_ids, attention_mask, input_ids):
+    bs = input_ids.shape[0]
+    seqlen = medusa_position_ids.shape[0]
+    medusa_position_ids_unsqueezed = torch.unsqueeze(medusa_position_ids, dim=0)
+    medusa_position_ids_repeated = medusa_position_ids_unsqueezed.repeat(bs, 1)
+    valid_length = torch.unsqueeze(attention_mask.sum(dim=1), dim=-1)
+    valid_length_repeated = valid_length.repeat(1, seqlen)
+    position_ids = medusa_position_ids_repeated + valid_length_repeated
+    return position_ids
+
+def update_attention_mask(attention_mask, tree_candidates):
+    bs = tree_candidates.shape[0]
+    n = tree_candidates.shape[1]
+    # create a tensor of ones, one entry per appended tree-candidate token
+    new_tokens = torch.ones((bs, n), dtype=attention_mask.dtype, device=attention_mask.device)
+    # extend the attention mask with torch.cat
+    extended_attention_mask = torch.cat((attention_mask, new_tokens), dim=1)
+    return extended_attention_mask
 
 
 def tree_decoding(
@@ -313,6 +407,7 @@ def tree_decoding(
     medusa_position_ids,
     input_ids,
     retrieve_indices,
+    attention_mask
 ):
     """
     Decode the tree candidates using the provided model and reorganize the logits.
@@ -328,23 +423,23 @@ def tree_decoding(
     Returns:
     - tuple: Returns medusa logits, regular logits, and other outputs from the model.
     """
-    # Compute new position IDs by adding the Medusa position IDs to the length of the input sequence.
-    position_ids = medusa_position_ids + input_ids.shape[1]
-
+    # position_ids = medusa_position_ids + input_ids.shape[1]
+    position_ids = update_position_id(medusa_position_ids, attention_mask, input_ids)
+    attention_mask = update_attention_mask(attention_mask, tree_candidates)
     # Use the model to decode the tree candidates.
     # The model is expected to return logits for the Medusa structure, original logits, and possibly other outputs.
     tree_medusa_logits, outputs, tree_logits = model(
         tree_candidates,
+        attention_mask=attention_mask,
         output_orig=True,
         past_key_values=past_key_values,
         position_ids=position_ids,
         medusa_forward=True,
     )
-
     # Reorder the obtained logits based on the retrieve_indices to ensure consistency with some reference ordering.
- logits = tree_logits[0, retrieve_indices] - medusa_logits = tree_medusa_logits[:, 0, retrieve_indices] + logits = tree_logits[:, retrieve_indices] + medusa_logits = tree_medusa_logits[:, :, retrieve_indices] return medusa_logits, logits, outputs def get_nucleus_posterior_mask(logits, candidates, temperature, top_p): @@ -430,9 +525,7 @@ def get_typical_posterior_mask(logits, candidates, temperature, posterior_thresh sampled_tokens = sampled_tokens.view(n_samples, n_tokens) posterior_mask = (candidates[:, 1:] == sampled_tokens).int() return posterior_mask - - - + def evaluate_posterior( logits, candidates, temperature, posterior_threshold=0.3, posterior_alpha = 0.09, top_p=0.8, sampling = 'typical', fast = True ): @@ -459,23 +552,19 @@ def evaluate_posterior( if temperature == 0: # Find the tokens that match the maximum logits for each position in the sequence posterior_mask = ( - candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1) + candidates[:,:,1:] == torch.argmax(logits[:,:,:-1], dim=-1) ).int() - candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1) - accept_length = candidates_accept_length.max() - # Choose the best candidate - if accept_length == 0: - # Default to the first candidate if none are accepted - best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device) - else: - best_candidate = torch.argmax(candidates_accept_length).to(torch.long) + candidates_accept_length = (torch.cumprod(posterior_mask, dim=-1)).sum(dim=-1) + accept_length = candidates_accept_length.max(dim=1)[0] + best_candidate = torch.argmax(candidates_accept_length, dim=-1).to(torch.long) return best_candidate, accept_length if sampling == 'typical': if fast: - posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1) + ## logits 最后一个是新预测的,candidates第0个是原始头的输出,不用比较 + posterior_prob = torch.softmax(logits[:,:,:-1] / temperature, dim=-1) candidates_prob = torch.gather( - posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1) + posterior_prob, dim=-1, index=candidates[:,:,1:].unsqueeze(-1) ).squeeze(-1) posterior_entropy = -torch.sum( posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1 @@ -485,21 +574,20 @@ def evaluate_posterior( torch.exp(-posterior_entropy) * posterior_alpha, ) posterior_mask = candidates_prob > threshold - candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1) - + candidates_accept_length = (torch.cumprod(posterior_mask, dim=-1)).sum(dim=-1) + batch_size, num_path = candidates_accept_length.shape + best_candidate = torch.zeros(batch_size, dtype=torch.long, device=candidates_accept_length.device) # Choose the best candidate based on the evaluated posterior probabilities - accept_length = candidates_accept_length.max() - if accept_length == 0: - # If no candidates are accepted, just choose the first one - best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device) - else: - best_candidates = torch.where(candidates_accept_length == accept_length)[0] - # Accept the best one according to likelihood - likelihood = torch.sum( - torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1 - ) - best_candidate = best_candidates[torch.argmax(likelihood)] - return best_candidate, accept_length + accept_lengths = candidates_accept_length.max(dim=1)[0] + for i in range(batch_size): + if accept_lengths[i] != 0: + best_candidates = torch.where(candidates_accept_length[i] == accept_lengths[i])[0] + # Accept the best one according to likelihood + likelihood = torch.sum( + 
torch.log(candidates_prob[i, best_candidates, :accept_lengths[i]]), dim=-1 + ) + best_candidate[i] = best_candidates[torch.argmax(likelihood)] + return best_candidate, accept_lengths # Calculate posterior probabilities and thresholds for candidate selection posterior_mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha, fast) candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1) @@ -528,6 +616,115 @@ def evaluate_posterior( return best_candidate, accept_length else: raise NotImplementedError + + +def generate_gather_mask(accept_length, max_accept_length): + batch_size = accept_length.shape[0] + range_tensor = torch.arange(max_accept_length, device='cuda:0').expand(batch_size, -1) + gather_mask = (range_tensor < accept_length.unsqueeze(1)) + return gather_mask + + +def generate_gather_indices(gather_mask, max_accept_length, candidate_ids, prev_input_len): + batch_size, _ = candidate_ids.shape + output_indices = torch.full((batch_size, max_accept_length), -1, dtype=torch.long, device=candidate_ids.device) + candidate_ids_ = candidate_ids[:, :max_accept_length] + prev_input_len + output = torch.where(gather_mask, candidate_ids_, output_indices) + return output + + +def select_new_tokens(candidates, best_candidate, gather_mask, max_accept_length, padding_id): + batch_size, _, _ = candidates.shape + candidates = candidates[:, :, :max_accept_length] + best_paths = torch.gather(candidates, 1, best_candidate.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, candidates.size(2))).squeeze(1) + default_ids = torch.full((batch_size, max_accept_length), padding_id, dtype=best_paths.dtype, device=best_paths.device) + output = torch.where(gather_mask, best_paths, default_ids) + return output + + +def gather_from_past_key_values(past_key_values_data, select_indices): + layers, batch_size, head_num, _, hidden_size = past_key_values_data.shape + seqlen = select_indices.shape[1] + + # 初始化结果张量,用于存放选择的数据或全零填充 + result_data = torch.zeros(layers, batch_size, head_num, seqlen, hidden_size, device=past_key_values_data.device, dtype=past_key_values_data.dtype) + + # 扩展 select_indices 以匹配 past_key_values_data 的操作维度 + expanded_indices = select_indices.unsqueeze(0).unsqueeze(2).expand(layers, batch_size, head_num, seqlen) + + # 创建一个掩码,用于识别 select_indices 中的有效索引(非 -1 值) + valid_indices_mask = expanded_indices != -1 + + # 修正 -1 索引值以避免 gather 时的错误,将 -1 替换为一个有效的索引(如 0),后续再通过掩码处理 + corrected_indices = torch.where(valid_indices_mask, expanded_indices, torch.zeros_like(expanded_indices)) + + # 使用 gather 选择数据 + gathered_data = torch.gather(past_key_values_data, 3, corrected_indices.unsqueeze(-1).expand(-1, -1, -1, -1, hidden_size)) + + # 利用掩码将结果中对应 -1 索引的位置替换为全零 + result_data = torch.where(valid_indices_mask.unsqueeze(-1), gathered_data, result_data) + return result_data + +## pad every step +def update_ids(input_ids, new_ids): + input_ids = torch.cat([input_ids, new_ids], dim=-1) + return input_ids + +def update_mask(attention_mask, accept_length): + # 创建一个每行都是0到max_seqlen-1的范围张量 + range_tensor = torch.arange(accept_length.max().item(), device='cuda:0').expand(accept_length.shape[0], -1) + # 根据 accept_length 生成 mask,其中有效长度标记为1,其他为0 + new_attention_mask = (range_tensor < accept_length.unsqueeze(1)).to(int) + attention_mask = torch.cat((attention_mask, new_attention_mask), dim=-1) + return attention_mask + +def update_kvcache(tgt, past_key_values_data, prev_input_len): + dst = past_key_values_data[..., prev_input_len : prev_input_len + tgt.shape[-2], :] + 
dst.copy_(tgt, non_blocking=True) + +def update_current_length(current_length_data, prev_input_len, new_len): + current_length_data.fill_(prev_input_len + new_len) + +## avoid too much [PAD] +def update_ids_new(previous_ids, new_ids, padding_value=0): + batch_size = previous_ids.shape[0] + previous_seqlen = previous_ids.shape[1] + new_seqlen = new_ids.shape[1] + new_id_index = torch.arange(new_seqlen, device='cuda:0').expand(batch_size, -1) + + previous_mask = previous_ids != padding_value + new_id_mask = new_ids != padding_value + + previous_valid_lengths = previous_mask.sum(dim=1) + new_id_valid_lengths = new_id_mask.sum(dim=1) + broad_max_output_len = previous_valid_lengths.max() + new_id_valid_lengths.max() + tight_max_output_len = (previous_valid_lengths + new_id_valid_lengths).max() + + new_id_index = previous_valid_lengths.view(batch_size,-1) + new_id_index + output = torch.full((batch_size, broad_max_output_len), padding_value, dtype=previous_ids.dtype, device=previous_ids.device) + output[:, :previous_seqlen] = previous_ids + output = output.scatter(1,new_id_index, new_ids) + if tight_max_output_len < broad_max_output_len: + output = output[:,:tight_max_output_len] + return output, new_id_index + +def update_mask_new(attention_mask, accept_length): + batch_size = attention_mask.shape[0] + previous_valid_lengths = attention_mask.sum(dim=1) + new_valid_lengths = (previous_valid_lengths + accept_length).view(batch_size, -1) + max_new_valid_length = new_valid_lengths.max() + output = torch.arange(max_new_valid_length, dtype=attention_mask.dtype, device=attention_mask.device).expand(batch_size, -1) + output = (output < new_valid_lengths).to(int) + return output + +def update_kvcache_new(tgt, past_key_values_data, scatter_index): + n_layers, _, num_head, _, hidden_size = tgt.shape + expand_scatter_index = scatter_index.unsqueeze(0).unsqueeze(2).unsqueeze(4).expand(n_layers,-1,num_head,-1,hidden_size) + past_key_values_data.scatter_(3, expand_scatter_index, tgt) + +def update_current_length_new(current_length_data, new_lenght): + current_length_data.fill_(new_lenght) + def update_inference_inputs( input_ids, candidates, @@ -540,6 +737,8 @@ def update_inference_inputs( new_token, past_key_values_data, current_length_data, + attention_mask, + padding_idx=0 ): """ Update the input sequences and relevant tensors based on the selected best candidate from the inference results. @@ -552,7 +751,7 @@ def update_inference_inputs( - retrieve_indices (torch.Tensor): Indices to map tree to a cartesian product. - outputs, logits, medusa_logits (torch.Tensor): Model's outputs from the previous inference step. - new_token (int): Counter for the new tokens added during inference. - - past_key_values_data (torch.Tensor): Tensor containing past hidden states for the transformer model. + - past_key_values_data (torch.Tensor): Tensor containing past hidden states for the transformer model. [layers, batch_size, head_num, max_seqlen, hidden_size] - current_length_data (torch.Tensor): Tensor containing the current length of sequences in the batch. Returns: @@ -561,33 +760,47 @@ def update_inference_inputs( - medusa_logits (torch.Tensor): Updated medusa logits. - new_token (int): Updated counter for the new tokens added. 
""" + accept_length += 1 ## accept_length > 0 + max_accept_length = accept_length.max().item() + batch_indices = torch.arange(best_candidate.size(0), device=logits.device) # Calculate the starting position for new tokens based on the previous input length prev_input_len = input_ids.shape[1] # Map the best candidate indices to the original indices in the sequence - select_indices = ( - retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len - ) - # Append the tokens from the best candidate to the input sequence - input_ids = torch.cat( - [input_ids, candidates[None, best_candidate, : accept_length + 1]], dim=-1 - ) - # Update the past key values based on the selected tokens - # Source tensor that contains relevant past information based on the selected candidate - tgt = past_key_values_data[..., select_indices, :] - # Destination tensor where the relevant past information will be stored - dst = past_key_values_data[..., prev_input_len : prev_input_len + tgt.shape[-2], :] - # Copy relevant past information from the source to the destination - dst.copy_(tgt, non_blocking=True) - - # Update the current length tensor (currently only support batch size is 1) - current_length_data.fill_(prev_input_len + tgt.shape[-2]) - - # Extract logits and medusa logits for the accepted tokens - logits = logits[None, best_candidate, accept_length : accept_length + 1] - medusa_logits = medusa_logits[ - :, None, best_candidate, accept_length : accept_length + 1 - ] + candidate_ids = retrieve_indices[best_candidate] + gather_mask = generate_gather_mask(accept_length, max_accept_length) + select_indices = generate_gather_indices(gather_mask, max_accept_length, candidate_ids, prev_input_len) + new_ids = select_new_tokens(candidates, best_candidate, gather_mask, max_accept_length, padding_id=padding_idx) + if False: + # Append the tokens from the best candidate to the input sequence + input_ids = update_ids(input_ids, new_ids) + # Update the past key values based on the selected tokens + # Source tensor that contains relevant past information based on the selected candidate + # Destination tensor where the relevant past information will be stored + # Copy relevant past information from the source to the destination + tgt = gather_from_past_key_values(past_key_values_data, select_indices) + update_kvcache(tgt, past_key_values_data, prev_input_len) + # Update the current length tensor + update_current_length(current_length_data, prev_input_len, tgt.shape[-2]) + # Update the attention mask tensor + attention_mask = update_mask(attention_mask, accept_length) + else: + input_ids, scatter_index = update_ids_new(input_ids, new_ids, padding_value=padding_idx) + tgt = gather_from_past_key_values(past_key_values_data, select_indices) + update_kvcache_new(tgt, past_key_values_data, scatter_index) + update_current_length_new(current_length_data, input_ids.shape[-1]) + attention_mask = update_mask_new(attention_mask, accept_length) + + + if True: + # Extract logits and medusa logits for the accepted tokens + logits = logits[batch_indices, best_candidate, : max_accept_length] + medusa_logits = medusa_logits[:, batch_indices, best_candidate, : max_accept_length] + valid_length = accept_length + else: + # Extract logits and medusa logits for the last accepted tokens + logits = logits[batch_indices, best_candidate, accept_length-1] #最后一个logits + medusa_logits = medusa_logits[:, batch_indices, best_candidate, accept_length-1] #最后一个logits + valid_length = None # Update the new token counter - new_token += accept_length 
+ 1 - - return input_ids, logits, medusa_logits, new_token + new_token += max_accept_length + return input_ids, logits, medusa_logits, new_token, valid_length, attention_mask
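# A simplified, self-contained sketch of the right-packing idea used by
# update_ids_new above: each row's newly accepted tokens are scattered in directly
# after that row's last non-pad token, so rows that accepted different numbers of
# tokens stay left-aligned instead of accumulating [PAD] mid-sequence. The tensors
# below are toy values and pad id 0 is assumed; the real function also trims the
# output to the tightest valid length.

import torch

previous_ids = torch.tensor([[11, 12, 13,  0],
                             [21, 22,  0,  0]])       # two rows, pad id = 0
new_ids      = torch.tensor([[14, 15],
                             [23,  0]])               # row 1 accepted only one token

valid_len = (previous_ids != 0).sum(dim=1)            # tensor([3, 2])
scatter_index = valid_len.unsqueeze(1) + torch.arange(new_ids.shape[1])
output = torch.full((previous_ids.shape[0], previous_ids.shape[1] + new_ids.shape[1]),
                    0, dtype=previous_ids.dtype)
output[:, :previous_ids.shape[1]] = previous_ids
output = output.scatter(1, scatter_index, new_ids)
print(output)
# tensor([[11, 12, 13, 14, 15,  0],
#         [21, 22, 23,  0,  0,  0]])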