14 changes: 10 additions & 4 deletions cpp/tensorrt_llm/thop/allgatherOp.cpp
@@ -55,7 +55,7 @@ class AllgatherOp
return 0;
}

torch::Tensor run(torch::Tensor input, torch::optional<torch::List<int64_t>> sizes)
torch::Tensor run(torch::Tensor input, torch::optional<torch::List<int64_t>> sizes, bool called_by_run_list = false)
{
TLLM_CHECK_WITH_INFO(mNcclComm.get() != nullptr, "mNcclComm should be initialized before used");
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
@@ -74,7 +74,10 @@
{
size_t numel_base = std::accumulate(outputShape.cbegin() + 1, outputShape.cend(), 1, std::multiplies<>{});
int64_t split_offset = 0;
ncclGroupStart();
if (!called_by_run_list)
{
ncclGroupStart();
}
for (int root = 0; root < static_cast<int>(mGroup.size()); ++root)
{
auto split_size = sizes.value()[root];
@@ -83,7 +86,10 @@
numel_base * split_size, (*getDtypeMap())[type], root, *mNcclComm, stream));
split_offset += split_size;
}
ncclGroupEnd();
if (!called_by_run_list)
{
ncclGroupEnd();
}
}
else
{
@@ -100,7 +106,7 @@
ncclGroupStart();
for (auto const& input : input_list)
{
auto output = run(input, sizes);
auto output = run(input, sizes, true);
output_list.push_back(output);
}
ncclGroupEnd();
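The new called_by_run_list flag exists so that run_list can open a single NCCL group around the whole batch of inputs without run opening a nested group per input. Below is a minimal, Python-only sketch of that control flow, under the assumption that a plain depth counter stands in for NCCL group nesting; the class and method names are illustrative, and the real code brackets ncclBroadcast calls with ncclGroupStart/ncclGroupEnd.

# Sketch of the guard-flag pattern from allgatherOp.cpp (illustrative only).
class AllgatherSketch:
    def __init__(self):
        self.group_depth = 0  # stands in for NCCL's group nesting

    def _group_start(self):
        self.group_depth += 1

    def _group_end(self):
        self.group_depth -= 1

    def run(self, x, sizes, called_by_run_list=False):
        # Only manage the group when called directly; run_list already holds
        # one open group around the whole batch of inputs.
        if not called_by_run_list:
            self._group_start()
        output = [x for _ in sizes]  # placeholder for the per-root broadcasts
        if not called_by_run_list:
            self._group_end()
        return output

    def run_list(self, inputs, sizes):
        outputs = []
        self._group_start()  # one group for the entire list
        for x in inputs:
            outputs.append(self.run(x, sizes, called_by_run_list=True))
        self._group_end()
        return outputs

op = AllgatherSketch()
op.run_list(["a", "b"], sizes=[1, 2])
assert op.group_depth == 0  # groups stay balanced and are never nested by run()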
61 changes: 27 additions & 34 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1197,16 +1197,16 @@ def _prepare_tp_inputs(
new_tokens_lens_device = new_tensors_device.new_tokens_lens # [batch]
next_draft_tokens_device = new_tensors_device.next_draft_tokens # [batch, draft_len]

# Requests with draft tokens are treated like extend requests. CUDA graph dummy extend
# requests should be at the end of extend_requests.
# Requests with draft tokens are treated like extend requests. Dummy extend requests should be
# at the end of extend_requests.
extend_requests = []
extend_cuda_graph_dummy_requests = []
extend_dummy_requests = []
generation_requests = []
for request in scheduled_requests.generation_requests:
if len(request.py_draft_tokens
) > 0 or next_draft_tokens_device is not None:
if request.is_cuda_graph_dummy:
extend_cuda_graph_dummy_requests.append(request)
if request.is_dummy:
extend_dummy_requests.append(request)
else:
extend_requests.append(request)
else:
@@ -1219,8 +1219,8 @@ def _prepare_tp_inputs(
pin_memory=True)
mrope_config['mrope_position_deltas'].append(
mrope_position_deltas.to('cuda', non_blocking=True))
extend_requests += extend_dummy_requests

extend_requests = extend_cuda_graph_dummy_requests + extend_requests
if not self._disable_overlap_scheduler and self.is_spec_decode:
spec_dec_mode = self.spec_config.spec_dec_mode
assert spec_dec_mode.support_overlap_scheduler(
@@ -1229,18 +1229,18 @@
# will contain previous batch indices of generation requests
previous_batch_indices = []
previous_pos_indices = []
request_ids_with_previous_batch = []
num_extend_reqs_wo_previous_batch = 0
for request in extend_requests:
# the request has no previous tensor:
# (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
# (2) a dummy request; or
# (3) the first step in the generation server of disaggregated serving
if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
# get token ids, including input token ids and draft token ids
input_ids.append(request.get_last_tokens(0))
input_ids.extend(request.py_draft_tokens)
draft_tokens.extend(request.py_draft_tokens)
# get token ids, including input token ids and draft token ids. For these dummy requests,
# no need to copy the token ids.
if not request.is_dummy:
input_ids.append(request.get_last_tokens(0))
input_ids.extend(request.py_draft_tokens)
draft_tokens.extend(request.py_draft_tokens)
# get other ids and lengths
num_draft_tokens = len(request.py_draft_tokens)
past_seen_token_num = request.max_beam_num_tokens - 1
@@ -1268,7 +1268,6 @@
# update batch index
request.py_batch_idx = batch_idx
batch_idx += 1
num_extend_reqs_wo_previous_batch += 1
else:
# update batch index
previous_batch_idx = request.py_batch_idx
@@ -1295,10 +1294,7 @@
num_cached_tokens_per_seq.append(past_seen_token_num +
self.max_draft_len + 1)
prompt_lengths.append(request.py_prompt_len)
request_ids_with_previous_batch.append(request.py_request_id)

# move requests with previous batch to the end of the list
request_ids.extend(request_ids_with_previous_batch)
request_ids.append(request.py_request_id)

sequence_lengths.extend([1] * len(generation_requests))
gather_ids.extend(
@@ -1333,6 +1329,7 @@
num_tokens = len(input_ids)
num_draft_tokens = len(draft_tokens)
previous_batchs = len(previous_batch_indices)
num_requests = len(request_ids)
total_num_tokens = len(position_ids)
assert total_num_tokens <= self.max_num_tokens, (
"total_num_tokens should be less than or equal to max_num_tokens")
@@ -1374,31 +1371,27 @@
non_blocking=True)
# prepare data for the preprocess inputs
kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (
1 + self.max_draft_len)
pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens
pre_batch_start_idx = num_extend_reqs_wo_previous_batch
pre_batch_end_idx = pre_batch_start_idx + previous_batchs
previous_pos_indices = torch.tensor(previous_pos_indices,
dtype=torch.int,
pin_memory=True)
self.previous_pos_indices_cuda[
pre_tokens_start_idx:pre_tokens_end_idx].copy_(
previous_pos_indices, non_blocking=True)
self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
previous_pos_indices, non_blocking=True)
self.previous_pos_id_offsets_cuda[
pre_tokens_start_idx:pre_tokens_end_idx].copy_(
0:previous_batch_tokens].copy_(
new_tokens_lens_device[self.previous_pos_indices_cuda[
pre_tokens_start_idx:pre_tokens_end_idx]],
non_blocking=True)
self.previous_kv_lens_offsets_cuda[
pre_batch_start_idx:pre_batch_end_idx].copy_(
kv_len_offsets_device[
self.previous_batch_indices_cuda[:previous_batchs]],
0:previous_batch_tokens]],
non_blocking=True)
self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_(
kv_len_offsets_device[
self.previous_batch_indices_cuda[:previous_batchs]],
non_blocking=True)
# for the requests that do not have previous batch, set the previous_pos_id_offsets and
# previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0
self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0
self.previous_pos_id_offsets_cuda[
previous_batch_tokens:num_requests *
(1 + self.max_draft_len)] *= 0
self.previous_kv_lens_offsets_cuda[
previous_batchs:num_requests] *= 0
else:
# change the data to zeros to skip the value changes in _preprocess_inputs
self.previous_pos_id_offsets_cuda *= 0
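A rough sketch of the slicing the overlap-scheduler hunk above sets up, with assumed sizes (the real buffers are pre-allocated CUDA tensors and the counts come from the scheduled batch): entries for requests that do have a previous batch are packed at the front of previous_pos_id_offsets_cuda and previous_kv_lens_offsets_cuda, and the tails up to num_requests * (1 + max_draft_len) and num_requests are zeroed so _preprocess_inputs leaves the remaining requests untouched.

import torch

# Hypothetical sizes, purely to illustrate the slicing after this change.
max_draft_len = 3
num_requests = 4                  # all extend requests in the batch
previous_batchs = 2               # requests that have a previous batch
previous_batch_tokens = previous_batchs * (1 + max_draft_len)

previous_pos_id_offsets = torch.zeros(num_requests * (1 + max_draft_len),
                                      dtype=torch.int)
previous_kv_lens_offsets = torch.zeros(num_requests, dtype=torch.int)

# Front of the buffers: values gathered from the previous batch
# (constants stand in for new_tokens_lens_device / kv_len_offsets_device).
previous_pos_id_offsets[0:previous_batch_tokens] = 7
previous_kv_lens_offsets[0:previous_batchs] = 1

# Tail of the buffers: zeroed, mirroring the *= 0 lines above, so requests
# without a previous batch see no offsets in _preprocess_inputs.
previous_pos_id_offsets[previous_batch_tokens:num_requests *
                        (1 + max_draft_len)] *= 0
previous_kv_lens_offsets[previous_batchs:num_requests] *= 0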
33 changes: 16 additions & 17 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1492,7 +1492,7 @@ def _check_disagg_gen_transfer_status(self):
@nvtx_range("_pad_attention_dp_dummy_request")
def _pad_attention_dp_dummy_request(self):
"""
Pad with dummy requests, if required, to avoid empty attention_dp rank.
Pad with a dummy request, if required, to ensure every attention_dp rank has at least one active request.
"""
if not self.enable_attention_dp:
return
@@ -1506,22 +1506,20 @@ def _pad_attention_dp_dummy_request(self):
or req.is_disagg_generation_transmission_in_progress else 1
for req in self.active_requests
])
num_dummy_request = self.expected_num_active_requests - num_active_request
if num_dummy_request > 0:
llm_request_list = self.kv_cache_manager.add_dummy_requests(
request_ids=list(range(num_dummy_request)),

if self.expected_num_active_requests - num_active_request > 0 and num_active_request == 0:
llm_request = self.kv_cache_manager.add_dummy_requests(
request_ids=[0],
is_gen=not self.has_context_request,
prepare_resource=not self.has_context_request,
max_num_draft_tokens=self.max_draft_tokens,
)
for llm_request in llm_request_list:
llm_request.is_attention_dp_dummy = True
)[0]
llm_request.is_attention_dp_dummy = True
spec_resource_manager = self.resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests(
list(range(num_dummy_request)))
self.active_requests += llm_request_list
spec_resource_manager.add_dummy_requests([0])
self.active_requests.append(llm_request)

@nvtx_range("_prepare_disagg_gen_init")
def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
@@ -1645,12 +1643,13 @@ def forward(scheduled_requests, resource_manager, new_tensors_device,

def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
# handle potential attention dp dummy request
for request in self.active_requests[:]:
if request.is_attention_dp_dummy:
request.state = LlmRequestState.GENERATION_COMPLETE
self.inflight_req_ids.erase(request.py_request_id)
self._terminate_request(request)
self.active_requests.remove(request)
if self.active_requests and self.active_requests[
-1].is_attention_dp_dummy:
request = self.active_requests[-1]
request.state = LlmRequestState.GENERATION_COMPLETE
self.inflight_req_ids.erase(request.py_request_id)
self._terminate_request(request)
self.active_requests.remove(request)

for request in scheduled_requests.context_requests:
if request.state != LlmRequestState.GENERATION_COMPLETE: # skip failed requests
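A small sketch of the new padding condition, with an assumed helper name (only the two counters appear in the diff): a rank now pads with a single dummy request, and only when it would otherwise have no active request at all. Because at most one dummy is added and it is appended last, _update_request_states_tp only needs to inspect the tail of self.active_requests.

# Hypothetical helper illustrating the guard in _pad_attention_dp_dummy_request.
def needs_single_dummy(expected_num_active_requests: int,
                       num_active_request: int) -> bool:
    # Pad only a fully idle rank, and with exactly one dummy request.
    return (expected_num_active_requests - num_active_request > 0
            and num_active_request == 0)

assert needs_single_dummy(expected_num_active_requests=3, num_active_request=0)
assert not needs_single_dummy(expected_num_active_requests=3, num_active_request=1)
assert not needs_single_dummy(expected_num_active_requests=2, num_active_request=2)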
6 changes: 0 additions & 6 deletions tensorrt_llm/_torch/speculative/mtp.py
@@ -267,12 +267,6 @@ def update_requests(self, state: SampleStateMTP) -> None:
request.py_decoding_iter += 1
idx += 1

# skip the results of cuda graph dummy requests
if idx == 0:
num_cuda_graph_dummy_requests = len(new_tokens_list) - len(
state.scheduled_requests.generation_requests)
idx += num_cuda_graph_dummy_requests

for request in state.scheduled_requests.generation_requests:
assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
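With dummy extend requests now placed at the end of the batch (see the model_engine.py change above), the sampler appears to no longer need to skip leading CUDA-graph dummy rows before reading generation results. The sketch below shows the pairing it can then rely on; the request names and token values are invented.

# Illustrative only (no context requests in this toy batch): rows of
# new_tokens_list line up one-to-one with the scheduled generation requests,
# so no extra index offset for dummy requests is needed.
new_tokens_list = [[101], [102], [103]]        # one row per generation request
generation_requests = ["req_a", "req_b", "req_c"]

assert len(new_tokens_list) >= len(generation_requests)
for request, new_tokens in zip(generation_requests, new_tokens_list):
    pass  # update_requests would append new_tokens to this request here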
3 changes: 1 addition & 2 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -646,8 +646,7 @@ def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph,
@pytest.mark.skip_device_not_contain(["H100"])
@parametrize_with_ids("mtp_nextn", [0, 2])
def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
# OOM on H100 with default free_gpu_memory_fraction=0.9
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
mtp_config = None
if mtp_nextn > 0:
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)