16 changes: 11 additions & 5 deletions lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
@@ -52,27 +52,33 @@ def get_total_slots():

kv_start_indices, attention_mask = [], []
block_num, block_size, _, _ = step_context.kv_caches[0][1].shape
device = step_context.block_offsets.device

is_unpaged_prefill = False
if not step_context.is_decoding:
is_unpaged_prefill = \
all((step_context.q_seqlens ==
step_context.kv_seqlens).tolist())
q_start_loc = torch.cat((torch.tensor([0], device=device), step_context.q_seqlens.cumsum(0))).int()
q_start_loc = step_context.q_start_loc
cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int()

q_seqlens = step_context.q_seqlens.int()
kv_seqlens = step_context.kv_seqlens.int()
max_q_seq_len = torch.max(q_seqlens).item()
max_kv_seq_len = torch.max(kv_seqlens).item()

if step_context.is_decoding:
# max_q_seq_len, max_kv_seq_len is not used in decoding stage
max_q_seq_len = -1
max_kv_seq_len = -1

# collect kv_start_indices without using a for-loop,
# (fill kv-cache for just ONE token during the decoding phase)
idx = (step_context.kv_seqlens - 1) % block_size
b_num = (step_context.kv_seqlens - 1) // block_size
last_block = step_context.block_offsets.gather(1, b_num.view(-1, 1)).view(-1)
kv_start_indices = (last_block * block_size + idx).reshape((-1, 1))
else:
max_q_seq_len = torch.max(q_seqlens).cpu().item()
max_kv_seq_len = torch.max(kv_seqlens).cpu().item()

for i in range(step_context.q_start_loc.size(0)):
q_seq_len = int(step_context.q_seqlens[i])
kv_seq_len = int(step_context.kv_seqlens[i])
@@ -88,7 +94,7 @@ def get_total_slots():
attn_metadata = attn_meta_cls(
step_context.is_decoding,
step_context.block_offsets.int(),
q_start_loc=q_start_loc,
q_start_loc=cu_seqlens,
q_seqlens=q_seqlens,
kv_seqlens=kv_seqlens,
kv_start_indices=kv_start_indices,
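A minimal toy sketch of the two pieces of tensor arithmetic touched by this op_backend.py change, using hypothetical shapes and values (this is not LMDeploy code, just an illustration of the same operations): the prefill path builds cu_seqlens by appending the total token count to q_start_loc, and the decode path locates the cache slot of the single token being decoded without a per-sequence Python loop.

import torch

block_size = 4

# prefill path: cu_seqlens appends the total token count to q_start_loc
q_start_loc = torch.tensor([0, 3])                 # hypothetical per-sequence start offsets
q_seqlens = torch.tensor([3, 4])
cu_seqlens = torch.cat((q_start_loc, q_seqlens.sum().unsqueeze(0))).int()
print(cu_seqlens)                                  # tensor([0, 3, 7], dtype=torch.int32)

# decode path: find the slot inside the last occupied block of each block table
kv_seqlens = torch.tensor([5, 9])                  # current kv length per sequence
block_offsets = torch.tensor([[10, 11, 0],         # per-sequence block tables (physical block ids)
                              [20, 21, 22]])
idx = (kv_seqlens - 1) % block_size                # offset inside the last block -> [0, 0]
b_num = (kv_seqlens - 1) // block_size             # index of the last block      -> [1, 2]
last_block = block_offsets.gather(1, b_num.view(-1, 1)).view(-1)   # -> [11, 22]
kv_start_indices = (last_block * block_size + idx).reshape((-1, 1))
print(kv_start_indices)                            # tensor([[44], [88]]) -- one cache slot per sequence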
9 changes: 8 additions & 1 deletion lmdeploy/pytorch/engine/executor/ray_executor.py
@@ -561,13 +561,15 @@ def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict
def _init_distributed_environment_by_device(self, device_str: str):
"""Init distributed environment."""
driver_ip = _get_master_addr()
if device_str in ['cuda', 'maca']:
if device_str == 'cuda':
self.workers = self._sort_workers(driver_ip, self.workers)

elif device_str == 'ascend':
self._init_ascend_distributed_environment(driver_ip)
elif device_str == 'camb':
self._init_camb_distributed_environment(driver_ip)
elif device_str == 'maca':
self._init_maca_distributed_environment(driver_ip)
else:
raise ValueError(f'Unsupported device type: {device_str}')

@@ -594,6 +596,11 @@ def _init_camb_distributed_environment(self, driver_ip):
self.workers = self._sort_workers(driver_ip, self.workers)
ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)])

def _init_maca_distributed_environment(self, driver_ip):
"""Init maca distributed environment."""
self.workers = self._sort_workers(driver_ip, self.workers)
ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)])

Copilot AI Dec 4, 2025

The _init_maca_distributed_environment method is identical to _init_camb_distributed_environment. Consider extracting this common logic into a shared helper method to reduce code duplication. For example:

def _init_generic_distributed_environment(self, driver_ip):
    """Init generic distributed environment for device types."""
    self.workers = self._sort_workers(driver_ip, self.workers)
    ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(self.workers)])

def _init_camb_distributed_environment(self, driver_ip):
    self._init_generic_distributed_environment(driver_ip)

def _init_maca_distributed_environment(self, driver_ip):
    """Init maca distributed environment."""
    self._init_generic_distributed_environment(driver_ip)

Collaborator

nice catch


""" PD Disaggregation API Begin """

def p2p_initialize(self, init_request: DistServeInitRequest):
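For context on the pattern the new _init_maca_distributed_environment relies on, here is a self-contained sketch of pinning each Ray actor to a device index and waiting on the remote calls with ray.get. The Worker class below is a hypothetical stand-in, not the LMDeploy worker actor:

import ray


@ray.remote
class Worker:
    """Toy stand-in for a worker actor."""

    def __init__(self):
        self.device_idx = None

    def set_device(self, idx: int) -> int:
        # a real worker would bind its runtime to the given device here
        self.device_idx = idx
        return idx


if __name__ == '__main__':
    ray.init()
    workers = [Worker.remote() for _ in range(4)]
    # one remote call per worker; ray.get blocks until every call has finished
    assigned = ray.get([worker.set_device.remote(idx) for idx, worker in enumerate(workers)])
    print(assigned)  # [0, 1, 2, 3]
    ray.shutdown()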