diff --git a/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py index b454b995b4..3be4ab6f24 100644 --- a/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py @@ -52,20 +52,23 @@ def get_total_slots(): kv_start_indices, attention_mask = [], [] block_num, block_size, _, _ = step_context.kv_caches[0][1].shape - device = step_context.block_offsets.device is_unpaged_prefill = False if not step_context.is_decoding: is_unpaged_prefill = \ all((step_context.q_seqlens == step_context.kv_seqlens).tolist()) - q_start_loc = torch.cat((torch.tensor([0], device=device), step_context.q_seqlens.cumsum(0))).int() + q_start_loc = step_context.q_start_loc + cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int() + q_seqlens = step_context.q_seqlens.int() kv_seqlens = step_context.kv_seqlens.int() - max_q_seq_len = torch.max(q_seqlens).item() - max_kv_seq_len = torch.max(kv_seqlens).item() if step_context.is_decoding: + # max_q_seq_len, max_kv_seq_len are not used in the decoding stage + max_q_seq_len = -1 + max_kv_seq_len = -1 + # collect kv_start_indices without using a for-loop, # (fill kv-cache for just ONE token during the decoding phase) idx = (step_context.kv_seqlens - 1) % block_size @@ -73,6 +76,9 @@ def get_total_slots(): last_block = step_context.block_offsets.gather(1, b_num.view(-1, 1)).view(-1) kv_start_indices = (last_block * block_size + idx).reshape((-1, 1)) else: + max_q_seq_len = torch.max(q_seqlens).cpu().item() + max_kv_seq_len = torch.max(kv_seqlens).cpu().item() + for i in range(step_context.q_start_loc.size(0)): q_seq_len = int(step_context.q_seqlens[i]) kv_seq_len = int(step_context.kv_seqlens[i]) @@ -88,7 +94,7 @@ def get_total_slots(): attn_metadata = attn_meta_cls( step_context.is_decoding, step_context.block_offsets.int(), - q_start_loc=q_start_loc, + q_start_loc=cu_seqlens, q_seqlens=q_seqlens, 
kv_seqlens=kv_seqlens, kv_start_indices=kv_start_indices, diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py index c128f28587..4ae018a46e 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -561,13 +561,14 @@ def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict def _init_distributed_environment_by_device(self, device_str: str): """Init distributed environment.""" driver_ip = _get_master_addr() - if device_str in ['cuda', 'maca']: + if device_str == 'cuda': self.workers = self._sort_workers(driver_ip, self.workers) elif device_str == 'ascend': self._init_ascend_distributed_environment(driver_ip) - elif device_str == 'camb': - self._init_camb_distributed_environment(driver_ip) + elif device_str in ['camb', 'maca']: + self.workers = self._sort_workers(driver_ip, self.workers) + ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)]) else: raise ValueError(f'Unsupported device type: {device_str}') @@ -590,10 +591,6 @@ def _init_ascend_distributed_environment(self, driver_ip): else: self.workers = self._sort_workers(driver_ip, self.workers) - def _init_camb_distributed_environment(self, driver_ip): - self.workers = self._sort_workers(driver_ip, self.workers) - ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)]) - """ PD Disaggregation API Begin """ def p2p_initialize(self, init_request: DistServeInitRequest):