diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index 581d114d2525..30455e179a6d 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -558,3 +558,68 @@ def fd_inter_tensor(self) -> None: def __repr__(self) -> str: return f"(sequences_dict={self._sequences_dict}, is_prompts={self.is_prompts})" + + +class RPCBatchBucket(BatchBucket): + def __init__(self, *args, **argv): + self.is_rpc = True + self.device = "cpu" + super().__init__(*args, **argv) + + # For compatibility + def get_1D_inputs(self) -> List[int]: + assert len(self._sequences_dict) > 0, "No sequence in the batch" + first_seq = next(iter(self._sequences_dict.values())) # not exactly the first sequence + if first_seq.output_len == 0: + # Assume prefill stage + assert all( + seq.output_len == 0 for seq in self._sequences_dict.values() + ), "Sequence stage (Prefill/Decoding) must be the same in the batch" + out_li = [] + seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x]) + for seq_id in seq_ids: + seq: Sequence = self._sequences_dict[seq_id] + out_li.extend(seq.input_token_id) + return out_li + else: + # Assume decoding stage + if self.use_spec_dec: + # For Speculative Decoding + # the number of tokens to be verified in parallel plus the correct token in the last step + return self.get_1D_inputs_spec_dec(self.num_tokens_to_verify + 1) + assert all( + seq.output_len > 0 for seq in self._sequences_dict.values() + ), "Sequence stage (Prefill/Decoding) must be the same in the batch" + assert self.is_compact, "BatchBucket is not compact" + out = [0] * self.current_batch_size + for seq_id, index_in_b in self._sequences_indexes.items(): + seq: Sequence = self._sequences_dict[seq_id] + out[index_in_b] = seq.output_token_id[-1] + return out + + # For compatibility + def get_sequence_lengths(self) -> List[int]: + assert self.is_compact # Debug usage + sequence_lengths = self.seq_lengths[: self.current_batch_size] + return sequence_lengths + + def get_1D_inputs_spec_dec(self, n: int) -> List[int]: + # Used for main model verification in **Decoding Stage** + # `n` is the number of tokens to be verified, + # and so that prepare the last `n` tokens of each sequence as the inputs + assert len(self._sequences_dict) > 0, "No sequence in the batch" + assert all( + seq.output_len >= n for seq in self._sequences_dict.values() + ), "Sequence output tokens must be greater than or equal to the number of tokens to be verified." + out_li = [] + seq_ids = sorted(self._sequences_indexes.keys(), key=lambda x: self._sequences_indexes[x]) + for seq_id in seq_ids: + seq: Sequence = self._sequences_dict[seq_id] + out_li.extend(seq.output_token_id[-n:]) + return out_li + + # For compatibility + def get_block_table_tensor(self) -> torch.Tensor: + assert self.is_compact # Debug usage + block_table = self.block_tables[: self.current_batch_size] + return block_table diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 072ddbcfd298..75edac2172aa 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -89,8 +89,14 @@ class InputMetaData(RPC_PARAM): def to_rpc_param(self) -> Dict[str, any]: return { - "block_tables": self.block_tables.tolist(), - "sequence_lengths": self.sequence_lengths.tolist(), + "block_tables": self.block_tables, + # "block_tables": self.block_tables.tolist() + # if isinstance(self.block_tables, torch.Tensor) + # else self.block_tables, + "sequence_lengths": self.sequence_lengths, + # "sequence_lengths": self.sequence_lengths.tolist() + # if isinstance(self.block_tables, torch.Tensor) + # else self.sequence_lengths, "batch_size": self.batch_size, "is_prompts": self.is_prompts, "use_cuda_kernel": self.use_cuda_kernel, @@ -112,12 +118,17 @@ def from_rpc_param(rpc_dict: Dict[str, any]) -> "InputMetaData": from colossalai.accelerator import get_accelerator dtype = getattr(torch, rpc_dict["dtype"]) + device = get_accelerator().get_current_device() return InputMetaData( - block_tables=torch.tensor( - rpc_dict["block_tables"], dtype=torch.int, device=get_accelerator().get_current_device() + block_tables=( + torch.tensor(rpc_dict["block_tables"], dtype=torch.int, device=device) + if isinstance(rpc_dict["block_tables"], list) + else rpc_dict["block_tables"].to(device) ), - sequence_lengths=torch.tensor( - rpc_dict["sequence_lengths"], dtype=torch.int, device=get_accelerator().get_current_device() + sequence_lengths=( + torch.tensor(rpc_dict["sequence_lengths"], dtype=torch.int, device=device) + if isinstance(rpc_dict["sequence_lengths"], list) + else rpc_dict["sequence_lengths"].to(device) ), batch_size=rpc_dict["batch_size"], is_prompts=rpc_dict["is_prompts"], diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 5c9bdc3214e9..17a6dc57e85a 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -78,7 +78,6 @@ def generate( Args: request_ids (List[int], optional): The request ID. Defaults to None. - prompts (Union[List[str], optional): Input prompts. Defaults to None. """ assert self.engine is not None, "Please init Engine first" diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 393347c31e16..bfe433250b7a 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -4,7 +4,7 @@ from transformers.configuration_utils import PretrainedConfig from transformers.generation import GenerationConfig -from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.batch_bucket import BatchBucket, RPCBatchBucket from colossalai.inference.config import InferenceConfig from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.inference.kv_cache import KVCacheManager, RPCKVCacheManager @@ -427,7 +427,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size, # which may cause bugs and this issue should be fixed later. - self.running_bb = BatchBucket( + self.running_bb = RPCBatchBucket( num_heads=model_config.num_attention_heads // inference_config.tp_size, head_dim=head_dim, max_batch_size=self.max_batch_size, @@ -437,7 +437,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo fd_interm_tensor=None, dtype=self.dtype, ) - self.prefill_bb = BatchBucket( + self.prefill_bb = RPCBatchBucket( num_heads=model_config.num_attention_heads // inference_config.tp_size, head_dim=head_dim, max_batch_size=self.max_batch_size, diff --git a/colossalai/inference/core/rpc_engine.py b/colossalai/inference/core/rpc_engine.py index 7493608727ed..4677418a350e 100644 --- a/colossalai/inference/core/rpc_engine.py +++ b/colossalai/inference/core/rpc_engine.py @@ -1,4 +1,5 @@ import asyncio +import pickle from itertools import count from time import sleep from typing import List, Tuple, Union @@ -11,7 +12,7 @@ from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.configuration_utils import PretrainedConfig -from colossalai.inference.batch_bucket import BatchBucket +from colossalai.inference.batch_bucket import RPCBatchBucket from colossalai.inference.config import InferenceConfig, InputMetaData from colossalai.inference.executor.rpc_worker import rpcWorkerService from colossalai.inference.utils import find_available_ports @@ -120,6 +121,9 @@ def __init__( self.counter = count() self._verify_args() + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.logger.info("engine init over ") def _verify_args(self) -> None: @@ -162,8 +166,16 @@ def init_workers(self): raise Exception("conn error!") self.logger.info(f"Build RPC Connection Success! Begin to load model...") asyncio.run(self.init_worker_env()) + self._init_worker_forward() self.logger.info(f"init dist env over") + def _init_worker_forward(self): + """ + Async wrappers for forward, because it will be invoked many times. + """ + assert len(self.workers) == self.tp_size, "init workers first" + self.worker_forwards = [rpyc.async_(worker.execute_model_forward) for worker in self.workers] + async def async_parallel_wrapper(self, f, *args, **kwargs): async_res = rpyc.async_(f)(*args, **kwargs) await asyncio.to_thread(async_res.wait) @@ -210,7 +222,8 @@ async def _init_device_cache(self, alloc_shape: Tuple[int, int, int, int]): def init_device_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...]]): asyncio.run(self._init_device_cache(alloc_shape)) - def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]: + def prepare_input(self, batch: RPCBatchBucket) -> Tuple[List[int], InputMetaData]: + assert batch.is_rpc, "the batch must be RPCBatchBucket" input_ids = batch.get_1D_inputs() sequence_lengths = batch.get_sequence_lengths() @@ -220,7 +233,7 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]: n_tokens = batch.current_batch_size if batch.use_spec_dec: n_tokens = batch.num_tokens_to_verify + 1 - assert n_tokens == input_ids.size(0) + assert n_tokens == len(input_ids) n_tokens = n_tokens * batch.current_batch_size batch_token_ids = None @@ -252,40 +265,60 @@ def prepare_input(self, batch: BatchBucket) -> Tuple[List[int], InputMetaData]: batch_token_ids=batch_token_ids, ) - return input_ids.tolist(), input_meta_data + return input_ids, input_meta_data + + async def async_parallel_forward(self, async_f, *args, **kwargs): + async_res = async_f(*args, **kwargs) + await asyncio.to_thread(async_res.wait) + assert async_res.ready + return async_res.value - async def step_(self, input_token_ids, input_meta_data: InputMetaData): + async def step_async(self, input_token_ids, input_meta_data: InputMetaData): assert len(self.workers) == self.tp_size, "init workers first" - init_tasks = [ - self.async_parallel_wrapper( - worker.execute_model_forward, - input_token_ids, - input_meta_data.to_rpc_param(), - self.generation_config_dict, - ) - for worker in self.workers - ] + init_tasks = [] + for rank, async_forward in enumerate(self.worker_forwards): + if rank == 0: + init_tasks.append( + self.async_parallel_forward( + async_forward, + pickle.dumps(input_token_ids), + pickle.dumps(input_meta_data.to_rpc_param()), + pickle.dumps(self.generation_config_dict), + ) + ) + else: + init_tasks.append( + self.async_parallel_forward( + async_forward, + None, + None, + None, + ) + ) + ret = await asyncio.gather(*init_tasks) return ret[0] def step(self) -> List[str]: - batch = self.request_handler.schedule() + with self.t_prepare: + batch = self.request_handler.schedule() + + input_token_ids, input_meta_data = self.prepare_input(batch) - input_token_ids, input_meta_data = self.prepare_input(batch) - # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. - next_tokens = asyncio.run(self.step_(input_token_ids, input_meta_data)) + with self.t_exe: + # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. + next_tokens = self.loop.run_until_complete(self.step_async(input_token_ids, input_meta_data)) # update the request_handler - next_tokens = torch.tensor(next_tokens, dtype=torch.int) self.request_handler.append_next_tokens(next_tokens) finished_sequences = self.request_handler.update() return finished_sequences def kill_workers(self): """ - I don't find a good way to implicit invoke self.kill_workers + NOTE(@lry89757) Don't find a good way to implicit invoke self.kill_workers """ assert len(self.workers) != 0 for proc in self.worker_processes: diff --git a/colossalai/inference/executor/rpc_worker.py b/colossalai/inference/executor/rpc_worker.py index a4fd20a693b2..85f5758aebb1 100644 --- a/colossalai/inference/executor/rpc_worker.py +++ b/colossalai/inference/executor/rpc_worker.py @@ -1,4 +1,6 @@ -from typing import List, Tuple, Union +import pickle +from contextlib import nullcontext +from typing import List, Optional, Tuple, Union import rpyc import torch @@ -51,6 +53,25 @@ class rpcWorkerService(rpyc.Service): def exposed_init_dist_env(self, rank, world_size, master_address, master_port): logger.info(f"init process group for rank {rank}") colossalai.launch(rank=rank, world_size=world_size, port=master_port, host=master_address) + self.rank = rank + + self.profiling = False + self.profiler = ( + torch.profiler.profile( + record_shapes=True, + with_stack=True, + with_modules=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + # schedule=torch.profiler.schedule(wait=0, repeat=1, active=1), + # on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./tb_log_{args.batch_size}_" + args.mode), + ) + if self.profiling + else nullcontext() + ) + logger.info(f"init process group done for rank {rank}") def exposed_init_model( @@ -98,38 +119,53 @@ def exposed_init_cache(self, alloc_shape: Tuple[Tuple[int, ...], Tuple[int, ...] logger.info("physical cache init over") def exposed_execute_model_forward( - self, input_token_ids_param: List[int], input_meta_data_param: dict, generation_config_param: dict + self, + input_token_ids_param: Optional[List[int]] = None, + input_meta_data_param: Optional[dict] = None, + generation_config_param: Optional[dict] = None, ): - # prepare the data for model forward - input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param) - input_meta_data.fd_inter_tensor = self.fd_inter_tensor - if input_meta_data.is_prompts: - n_tokens = input_meta_data.sequence_lengths.sum().item() - else: - n_tokens = input_meta_data.batch_size - input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device) - - # execute the model - logits = self.model( - input_token_ids, - self.output_tensor[:n_tokens], - input_meta_data, - self.k_cache, - self.v_cache, - ) + with self.profiler: + # prepare the data for model forward + input_token_ids, input_meta_data, generation_config = self._broadcast_param_to_all_workers( + input_token_ids_param=input_token_ids_param, + input_meta_data_param=input_meta_data_param, + generation_config_param=generation_config_param, + ) - # sampler - if self.inference_config.pad_input: - logits = logits[:, -1, :] - next_tokens = search_tokens( - generation_config_param, - logits, - input_meta_data.is_prompts, - input_meta_data.batch_token_ids, - ) + if input_meta_data.is_prompts: + n_tokens = input_meta_data.sequence_lengths.sum().item() + else: + n_tokens = input_meta_data.batch_size + + # execute the model + logits = self.model( + input_token_ids, + self.output_tensor[:n_tokens], + input_meta_data, + self.k_cache, + self.v_cache, + ) + + if self.profiling: + self.profiler.step() + + self.record() + + if self.rank == 0: + # sampler + if self.inference_config.pad_input: + logits = logits[:, -1, :] + next_tokens = search_tokens( + generation_config, + logits, + input_meta_data.is_prompts, + input_meta_data.batch_token_ids, + ) - # return the tokens generated to scheduler - return next_tokens.tolist() + # return the tokens generated to scheduler + # only rank 0 need to pass the data back + # to reduce the overhead of rpc param passing + return next_tokens.cpu() def _init_output_tensor(self): alloc_shape = ( @@ -166,6 +202,85 @@ def _init_fd_tensor(self): self.fd_inter_tensor = fd_inter_tensor + def _broadcast_param_to_all_workers( + self, + input_token_ids_param: Optional[List[int]] = None, + input_meta_data_param: Optional[dict] = None, + generation_config_param: Optional[dict] = None, + ): + if self.rank == 0: + input_token_ids_param = pickle.loads(input_token_ids_param) + input_meta_data_param = pickle.loads(input_meta_data_param) + generation_config_param = pickle.loads(generation_config_param) + + input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param) + input_meta_data.fd_inter_tensor = self.fd_inter_tensor + input_token_ids = torch.tensor(input_token_ids_param, dtype=torch.int, device=self.device) + generation_config = generation_config_param + + if dist.get_world_size() > 1: + broadcast_list = {} + for k, v in input_meta_data_param.items(): + if not isinstance(v, torch.Tensor): + broadcast_list[k] = v + + # Pass the tensor shape and type in advance for + # other workers to prepare the empty tensor and async transport tensors + broadcast_list["block_tables"] = ( + input_meta_data.block_tables.size(), + input_meta_data.block_tables.dtype, + ) + broadcast_list["sequence_lengths"] = ( + input_meta_data.sequence_lengths.size(), + input_meta_data.sequence_lengths.dtype, + ) + broadcast_list["input_token_ids"] = (input_token_ids.size(), input_token_ids.dtype) + + # Generation Config Param + broadcast_list["generation_config"] = generation_config_param + + # send some meta data and some tensor shape + torch.distributed.broadcast_object_list([broadcast_list], src=self.rank) + + # send the real tensor + torch.distributed.broadcast(input_meta_data.block_tables, src=self.rank) + torch.distributed.broadcast(input_meta_data.sequence_lengths, src=self.rank) + torch.distributed.broadcast(input_token_ids, src=self.rank) + + else: + assert input_meta_data_param is None, "Input Must Be None" + + # recv the meta data + recv_list = [None] + torch.distributed.broadcast_object_list(recv_list, src=0) + input_meta_data_param = recv_list[0] + + generation_config = input_meta_data_param["generation_config"] + + blocktable_shape, blocktable_type = input_meta_data_param["block_tables"] + blocktables = torch.empty(blocktable_shape, dtype=blocktable_type, device=self.device) + sequence_lengths_shape, sequence_lengths_type = input_meta_data_param["sequence_lengths"] + sequence_lengths = torch.empty(sequence_lengths_shape, dtype=sequence_lengths_type, device=self.device) + input_token_ids_shape, input_token_ids_type = input_meta_data_param["input_token_ids"] + input_token_ids = torch.empty(input_token_ids_shape, dtype=input_token_ids_type, device=self.device) + + # recv the real tensor + async1 = torch.distributed.broadcast(blocktables, src=0, async_op=True) + async2 = torch.distributed.broadcast(sequence_lengths, src=0, async_op=True) + async3 = torch.distributed.broadcast(input_token_ids, src=0, async_op=True) + + input_meta_data_param["sequence_lengths"] = sequence_lengths + input_meta_data_param["block_tables"] = blocktables + + input_meta_data = InputMetaData.from_rpc_param(input_meta_data_param) + input_meta_data.fd_inter_tensor = self.fd_inter_tensor + + async1.wait() + async2.wait() + async3.wait() + + return input_token_ids, input_meta_data, generation_config + def _init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): """ Shard model or/and Load weight @@ -306,3 +421,9 @@ def exposed_compute_only_for_test(self): logger.info(f"Worker rank {dist_rank}: Sum after all_reduce: {data.item()}") return data.item() + + def record(self): + if self.profiling: + file = "/home/lurunyu/projects/ColossalAI/test_trace_rpc.json" + self.profiler.export_chrome_trace(file) + logger.info(f"trace has been saved into {file}")