Commit 359c5a0

[Maca] fix ray and memory sync (#4164)

* [Maca] fix memory sync
* [Maca] fix ray blocking
* fix: resolve lint issues
* remove comma
* fix code
* fix code

1 parent 322b133 commit 359c5a0

2 files changed: +15 −12 lines changed

lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py

Lines changed: 11 additions & 5 deletions
@@ -52,27 +52,33 @@ def get_total_slots():
 
         kv_start_indices, attention_mask = [], []
         block_num, block_size, _, _ = step_context.kv_caches[0][1].shape
-        device = step_context.block_offsets.device
 
         is_unpaged_prefill = False
         if not step_context.is_decoding:
             is_unpaged_prefill = \
                 all((step_context.q_seqlens ==
                      step_context.kv_seqlens).tolist())
-        q_start_loc = torch.cat((torch.tensor([0], device=device), step_context.q_seqlens.cumsum(0))).int()
+        q_start_loc = step_context.q_start_loc
+        cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int()
+
         q_seqlens = step_context.q_seqlens.int()
         kv_seqlens = step_context.kv_seqlens.int()
-        max_q_seq_len = torch.max(q_seqlens).item()
-        max_kv_seq_len = torch.max(kv_seqlens).item()
 
         if step_context.is_decoding:
+            # max_q_seq_len, max_kv_seq_len is not used in decoding stage
+            max_q_seq_len = -1
+            max_kv_seq_len = -1
+
             # collect kv_start_indices without using a for-loop,
             # (fill kv-cache for just ONE token during the decoding phase)
             idx = (step_context.kv_seqlens - 1) % block_size
             b_num = (step_context.kv_seqlens - 1) // block_size
             last_block = step_context.block_offsets.gather(1, b_num.view(-1, 1)).view(-1)
             kv_start_indices = (last_block * block_size + idx).reshape((-1, 1))
         else:
+            max_q_seq_len = torch.max(q_seqlens).cpu().item()
+            max_kv_seq_len = torch.max(kv_seqlens).cpu().item()
+
             for i in range(step_context.q_start_loc.size(0)):
                 q_seq_len = int(step_context.q_seqlens[i])
                 kv_seq_len = int(step_context.kv_seqlens[i])
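
Review note: the decoding branch above computes every sequence's next kv-cache slot in one batched gather instead of a per-sequence Python loop. A minimal standalone sketch of that indexing trick, with an invented block table and lengths (none of these values come from the commit):

import torch

# Stand-ins for step_context fields; all values are made up for illustration.
block_size = 4
kv_seqlens = torch.tensor([5, 9])            # current kv length per sequence
block_offsets = torch.tensor([[10, 11, 0],   # per-sequence block tables of
                              [20, 21, 22]])  # physical block ids

# Offset of the newest token inside its last block, and the logical index of
# that block, computed for the whole batch at once.
idx = (kv_seqlens - 1) % block_size          # tensor([0, 0])
b_num = (kv_seqlens - 1) // block_size       # tensor([1, 2])

# gather picks each sequence's last physical block from its block table row.
last_block = block_offsets.gather(1, b_num.view(-1, 1)).view(-1)  # tensor([11, 22])

# Flat kv-cache slot of the single token written per sequence while decoding.
kv_start_indices = (last_block * block_size + idx).reshape((-1, 1))
print(kv_start_indices.tolist())  # [[44], [88]]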
@@ -88,7 +94,7 @@ def get_total_slots():
         attn_metadata = attn_meta_cls(
             step_context.is_decoding,
             step_context.block_offsets.int(),
-            q_start_loc=q_start_loc,
+            q_start_loc=cu_seqlens,
             q_seqlens=q_seqlens,
             kv_seqlens=kv_seqlens,
             kv_start_indices=kv_start_indices,
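
Review note: two details of the memory-sync fix are worth spelling out. The `cu_seqlens` passed as `q_start_loc` above is now derived from the precomputed `step_context.q_start_loc` rather than rebuilt with a device-side `cumsum`, and the max reductions are copied to host with `.cpu()` before `.item()` is called. A minimal sketch of both patterns, using invented sequence lengths:

import torch

# Stand-ins for step_context fields; values are invented for illustration.
q_seqlens = torch.tensor([3, 5, 2])      # query tokens per request
q_start_loc = torch.tensor([0, 3, 8])    # precomputed start offset per request

# Appending the total token count yields the cumulative layout [0, 3, 8, 10]
# that the attention metadata consumes as q_start_loc.
cu_seqlens = torch.cat((q_start_loc, q_seqlens.sum().unsqueeze(0))).int()
print(cu_seqlens.tolist())  # [0, 3, 8, 10]

# Copy the reduction result to host first; calling .item() directly on a
# device tensor forces a synchronization point on some backends.
max_q_seq_len = torch.max(q_seqlens).cpu().item()
print(max_q_seq_len)  # 5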

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 4 additions & 7 deletions
@@ -561,13 +561,14 @@ def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict
     def _init_distributed_environment_by_device(self, device_str: str):
         """Init distributed environment."""
         driver_ip = _get_master_addr()
-        if device_str in ['cuda', 'maca']:
+        if device_str == 'cuda':
             self.workers = self._sort_workers(driver_ip, self.workers)
 
         elif device_str == 'ascend':
             self._init_ascend_distributed_environment(driver_ip)
-        elif device_str == 'camb':
-            self._init_camb_distributed_environment(driver_ip)
+        elif device_str in ['camb', 'maca']:
+            self.workers = self._sort_workers(driver_ip, self.workers)
+            ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)])
         else:
             raise ValueError(f'Unsupported device type: {device_str}')
 
@@ -590,10 +591,6 @@ def _init_ascend_distributed_environment(self, driver_ip):
         else:
             self.workers = self._sort_workers(driver_ip, self.workers)
 
-    def _init_camb_distributed_environment(self, driver_ip):
-        self.workers = self._sort_workers(driver_ip, self.workers)
-        ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)])
-
     """ PD Disaggregation API Begin """
 
     def p2p_initialize(self, init_request: DistServeInitRequest):
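
Review note: the ray-blocking fix routes 'maca' through the same path as 'camb': sort the workers by the driver's IP, then assign each worker a device index and block until every assignment has completed. A minimal standalone sketch of that blocking pattern (the Worker actor here is a hypothetical stand-in for the executor's real worker class):

import ray

ray.init(num_cpus=2)

@ray.remote
class Worker:
    """Hypothetical stand-in for the executor's worker actor."""

    def set_device(self, device_idx: int) -> int:
        # A real worker would bind its process to accelerator `device_idx` here.
        self.device_idx = device_idx
        return device_idx

workers = [Worker.remote() for _ in range(2)]

# Each .remote() call returns immediately with an ObjectRef; ray.get on the
# list blocks until every set_device call has finished, so no worker can
# proceed with an unassigned device.
assigned = ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(workers)])
print(assigned)  # [0, 1]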
