66 changes: 31 additions & 35 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -438,6 +438,7 @@ async def _async_model_forward(
):
"""Model forward."""
max_prefill_token_num = self.cache_config.max_prefill_token_num
strategy = self.agent_strategy

class _OutputGather:
"""Output gather."""
@@ -469,7 +470,11 @@ def gather(self, output):
def get_output(self):
"""Get tmp_output."""
if not return_logits:
return self._output[:, -1:]
seqlen = torch.full((1, ),
self._output.numel() // self._output.size(-1),
device=self._output.device,
dtype=self._output.dtype)
return strategy.slice_outputs(self._output, seqlen)
torch.cuda.synchronize()
return self._output.to(self._device)

@@ -562,17 +567,14 @@ def _push_output(self, output: BatchedOutputs):
self._out_que.put_nowait((output, event))

@contextmanager
def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistContext = None, enable: bool = True):
def _broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, enable: bool = True):
if not enable:
yield
return

if dist_ctx is None:
dist_ctx = get_dist_manager().current_context()
tp_gpu_group = dist_ctx.tp_gpu_group
handle = dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True)
yield
handle.wait()
dist_ctx = self.dist_ctx
with self.agent_strategy.broadcast_next_token(next_token_ids, extra_inputs, dist_ctx) as handle:
yield handle

async def _async_step_background(
self,
@@ -698,6 +700,7 @@ async def __prepare_dp():
seq_length = output.get('seq_length', inputs.seq_length)
last_logits = self._slice_outs(logits, seq_length) # [bs, 1, prob] -> [bs, prob]
extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, seq_length)
model_metas = output.get('model_metas')

# output empty for dummy inputs
if is_dummy:
@@ -711,47 +714,40 @@ async def __prepare_dp():
# sampling
next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs, inputs)

with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next):
logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids [{idx}]')
# post sampling
next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids,
extra_inputs)

# post sampling
next_token_ids, extra_inputs = self.agent_strategy.post_sampling(
inputs, last_logits, next_token_ids, extra_inputs)
with self._broadcast_next_token(next_token_ids, extra_inputs, enable=need_broadcast_next):
logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids [{idx}]')

# stopping criteria
stopped, stop_pos, stopping_criteria = stopping_criteria.step(next_token_ids,
sampling_inputs.stop_words,
inputs=inputs,
extra_inputs=extra_inputs)

# send output
logger.debug(f'<ForwardTask> rank[{rank}]: Output [{idx}]')
extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs)
self._push_output(
BatchedOutputs(next_token_ids=next_token_ids,
logits=logits if return_logits else None,
stopped=stopped,
stop_pos=stop_pos,
model_metas=model_metas,
logprobs=logprobs,
extra_outputs=extra_outputs))
else:
# Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`,
# as it can trigger recompilation on different ranks when using torch.compile.
with torch.inference_mode():
next_token_ids = inputs.input_ids.new_zeros(last_logits.size(0))
logprobs = None
next_token_ids, extra_inputs = self.agent_strategy.make_dummy_next_token(
inputs, last_logits, extra_inputs)

# broadcast next token for TP > 1
with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next):
with self._broadcast_next_token(next_token_ids, extra_inputs, enable=need_broadcast_next):
logger.debug(f'<ForwardTask> rank[{rank}]: synchronize token ids [{idx}]')

# post sampling
next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids,
extra_inputs)

# send output
model_metas = output.get('model_metas')
if need_output:
logger.debug(f'<ForwardTask> rank[{rank}]: Output [{idx}]')
extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs)
self._push_output(
BatchedOutputs(next_token_ids=next_token_ids,
logits=logits if return_logits else None,
stopped=stopped,
stop_pos=stop_pos,
model_metas=model_metas,
logprobs=logprobs,
extra_outputs=extra_outputs))

# update for next loop
if is_decoding and idx < loop_count - 1:
inputs, extra_inputs = __update_inputs(next_token_ids, model_metas, extra_inputs)
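The engine-side change above stops issuing the TP broadcast directly and instead asks the active strategy for a context manager that launches the broadcast asynchronously and waits when the block exits. A minimal, self-contained sketch of that pattern (assuming an already-initialized `torch.distributed` process group; `SimpleARStrategy` and `tp_group` are illustrative names, not lmdeploy APIs):

```python
# Sketch of the strategy-owned broadcast pattern introduced above.
# Illustrative names only; assumes dist.init_process_group has already run.
from contextlib import contextmanager

import torch
import torch.distributed as dist


class SimpleARStrategy:
    """Autoregressive case: only the sampled token ids need to be synchronized."""

    @contextmanager
    def broadcast_next_token(self, next_token_ids: torch.Tensor, tp_group):
        # Launch the broadcast asynchronously so the caller can overlap work with it.
        handle = dist.broadcast(next_token_ids, src=0, group=tp_group, async_op=True)
        yield
        # Leaving the context guarantees every rank holds rank 0's token ids.
        handle.wait()


# Engine-side usage, mirroring the rewritten _broadcast_next_token wrapper:
# with strategy.broadcast_next_token(next_token_ids, tp_group):
#     ...  # work that may overlap with the broadcast, e.g. logging
```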
11 changes: 11 additions & 0 deletions lmdeploy/pytorch/strategies/ar/model_agent.py
@@ -1,10 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, List, Optional

import torch
from torch.profiler import record_function

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.distributed import DistContext
from lmdeploy.pytorch.engine.logits_process import SamplingInputs
from lmdeploy.pytorch.messages import SchedulerSequence
from lmdeploy.pytorch.model_inputs import ModelInputs
@@ -106,3 +109,11 @@ def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_
extra_inputs: ARExtraInputs):
"""Post sampling."""
return next_token_ids, extra_inputs

@contextmanager
def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, dist_ctx: DistContext):
"""Broadcast next token ids and extra inputs."""
tp_gpu_group = dist_ctx.tp_gpu_group
handle = dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True)
yield
handle.wait()
17 changes: 17 additions & 0 deletions lmdeploy/pytorch/strategies/base/model_agent.py
@@ -1,12 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Any, List, Optional

import numpy as np
import torch

if TYPE_CHECKING:
from lmdeploy.pytorch.distributed import DistContext
from lmdeploy.pytorch.engine.logits_process import SamplingInputs
from lmdeploy.pytorch.messages import SchedulerSequence
from lmdeploy.pytorch.model_inputs import ModelInputs
@@ -33,6 +35,10 @@ def to_device(self, device: str, non_blocking: bool = False):
"""To device."""
return to_device(self, device, non_blocking)

def broadcast(self, src: int, group, async_op=False):
"""Broadcast extra inputs."""
pass


@dataclass
class ExtraOutputs(ABC):
@@ -130,3 +136,14 @@ def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_
extra_inputs: ExtraInputs):
"""Post sampling."""
pass

def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, extra_inputs: ExtraInputs):
"""Make dummy next token for broadcast."""
with torch.inference_mode():
next_token_ids = inputs.input_ids.new_zeros(logits.size(0))
return next_token_ids, extra_inputs

@abstractmethod
@contextmanager
def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, dist_ctx: 'DistContext'):
"""Broadcast next token ids and extra inputs."""
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/strategies/base/model_inputs.py
@@ -36,7 +36,7 @@ def make_dummy_inputs(batch_size: int,
num_ignored_history=num_ignored_history,
max_q_seqlen=max_q_seqlen,
max_kv_seqlen=max_kv_seqlen,
sum_kv_seqlen=batch_size,
sum_kv_seqlen=num_tokens,
local_adapter_ids=local_adapter_ids,
)

21 changes: 21 additions & 0 deletions lmdeploy/pytorch/strategies/dllm/model_agent.py
@@ -1,13 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, List, Optional

import numpy as np
import torch
from torch.profiler import record_function

import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch import consts
from lmdeploy.pytorch.config import DLLMConfig
from lmdeploy.pytorch.distributed import DistContext
from lmdeploy.pytorch.engine.logits_process import SamplingInputs
from lmdeploy.pytorch.messages import SchedulerSequence
from lmdeploy.pytorch.model_inputs import ModelInputs
@@ -23,6 +26,9 @@ class DLLMExtraInputs(ExtraInputs):
"""DLLM extra inputs."""
dllm_mask: torch.Tensor

def broadcast(self, src: int, group, async_op=False):
return dist.broadcast(self.dllm_mask, src=src, group=group, async_op=async_op)


@dataclass
class DLLMExtraOutputs(ExtraOutputs):
@@ -216,3 +222,18 @@ def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_

extra_inputs.dllm_mask = dllm_mask
return next_token_ids, extra_inputs

def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, extra_inputs: DLLMExtraInputs):
"""Make dummy next token for broadcast."""
with torch.inference_mode():
next_token_ids = inputs.input_ids.new_zeros(logits.size(0))
return next_token_ids, extra_inputs

@contextmanager
def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: DLLMExtraInputs, dist_ctx: DistContext):
"""Broadcast next token ids and extra inputs."""
tp_gpu_group = dist_ctx.tp_gpu_group
dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True)
handle = extra_inputs.broadcast(src=0, group=tp_gpu_group, async_op=True)
yield
handle.wait()
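The DLLM strategy broadcasts the mask right after the token ids on the same TP group by delegating to `DLLMExtraInputs.broadcast`, and the context manager waits on the handle of the second broadcast before leaving the block. A condensed, self-contained sketch of that behaviour (assumes an initialized process group; the function name is illustrative, not from lmdeploy):

```python
# Illustrative sketch of the DLLM variant: the mask rides along with the token ids.
from contextlib import contextmanager

import torch
import torch.distributed as dist


@contextmanager
def broadcast_tokens_and_mask(next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, tp_group):
    # Both collectives are issued asynchronously on the same group, in order; as in the
    # diff above, only the handle of the second (mask) broadcast is awaited on exit.
    dist.broadcast(next_token_ids, src=0, group=tp_group, async_op=True)
    handle = dist.broadcast(dllm_mask, src=0, group=tp_group, async_op=True)
    yield
    handle.wait()
```

Keeping the extra-tensor broadcast inside `ExtraInputs.broadcast` means the engine never has to know which strategy-specific tensors must stay consistent across TP ranks.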