
Conversation


@jiayyu jiayyu commented Dec 9, 2025

Motivation

Technical Details

  1. Block manager may append multiple blocks per step.
  2. postprocess returns a list of tokens per sequence.
  3. Append multiple placeholders for both deferred output and MTP.
  4. set_mla_persistent_worker_buffers takes max_q_len as an argument.
  5. aiter MLA decode-metadata preparation supports max_q_len > 1.
  6. Forward context gains an optional SpecDecodeMetadata.
  7. ops adds a rejection sampler (see the sketch below this list).
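A minimal sketch of the greedy acceptance rule such a sampler implements (names, shapes, and the greedy rule are illustrative assumptions, not this repo's actual kernel API):

```python
import torch

def greedy_reject_sample(draft_token_ids: torch.Tensor,
                         target_logits: torch.Tensor) -> list[int]:
    """Accept draft tokens while they match the target model's argmax.

    draft_token_ids: [num_draft] tokens proposed by the draft/MTP head.
    target_logits:   [num_draft + 1, vocab] logits from the verify pass.
    """
    target_ids = target_logits.argmax(dim=-1)  # [num_draft + 1]
    num_accepted = 0
    for i, draft in enumerate(draft_token_ids.tolist()):
        if draft != target_ids[i].item():
            break  # first mismatch ends acceptance
        num_accepted += 1
    # The target model always contributes one more token: the correction at
    # the first mismatch, or a bonus token if every draft was accepted.
    return target_ids[: num_accepted + 1].tolist()
```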

Test Plan

Test Result

Submission Checklist

Copilot AI review requested due to automatic review settings December 9, 2025 08:34

Copilot AI left a comment


Pull request overview

This PR implements stage 1 of multi-token prediction (MTP) support by enabling the model engine to handle multiple decode tokens per step. The changes modify the core engine to track and process a variable number of tokens per sequence rather than assuming exactly one token per decode step.

Key changes:

  • Modified the postprocessing pipeline to return multiple tokens per sequence (dict[int, list[int]] instead of dict[int, int]; see the toy sketch after this list)
  • Enhanced block allocation to reserve space for lookahead slots when in deferred output mode
  • Updated position and context length calculations to account for previously accepted tokens in deferred output mode
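A toy illustration of that return-type change (values are made up):

```python
# Before: exactly one sampled token per sequence per decode step.
old_output: dict[int, int] = {0: 1234, 1: 5678}

# After: a variable-length list of accepted tokens per sequence,
# e.g. the verified token plus any accepted speculative tokens.
new_output: dict[int, list[int]] = {0: [1234], 1: [5678, 91, 11]}

for seq_id, token_ids in new_output.items():
    assert len(token_ids) >= 1  # callers can no longer assume exactly one
```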

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| atom/model_engine/scheduler.py | Updated postprocess to handle multiple tokens per sequence, modified block allocation to reserve lookahead slots, and removed placeholder token logic |
| atom/model_engine/model_runner.py | Changed return types from single token to token lists, added tracking of accepted token counts per sequence |
| atom/model_engine/block_manager.py | Added num_lookahead_slots parameter to allocation methods to reserve space for multiple tokens |
| atom/model_ops/attentions/backends.py | Updated block table preparation to use actual context lengths when in deferred output mode |
| atom/model_ops/attentions/aiter_mla.py | Added logic to calculate actual context lengths and positions based on deferred tokens from previous batch |
| atom/model_ops/attentions/aiter_attention.py | Similar context length calculation updates for standard attention mechanism |


Copilot AI review requested due to automatic review settings December 10, 2025 06:27

Copilot AI left a comment


Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.



@jiayyu jiayyu changed the title Support mtp stage 1: model engine supports multiple decode tokens Support mtp stage 1: model runner postprocess returns multiple tokens Dec 10, 2025
@jiayyu jiayyu changed the title Support mtp stage 1: model runner postprocess returns multiple tokens support mtp stage 2: model runner postprocess returns multiple tokens Dec 10, 2025
@jiayyu jiayyu marked this pull request as draft December 10, 2025 07:40
@jiayyu jiayyu changed the title support mtp stage 2: model runner postprocess returns multiple tokens support mtp stage 2: support deepseek mtp=1 Dec 12, 2025
@jiayyu jiayyu marked this pull request as ready for review December 12, 2025 05:13
Copilot AI review requested due to automatic review settings December 12, 2025 05:13

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 9 changed files in this pull request and generated 7 comments.



)
]
if max_q_len > 1:
    context_lens = np.array(batch.context_lens) - var["mtp_k"]

Copilot AI Dec 12, 2025


The computation 'batch.context_lens - var["mtp_k"]' is performed but the result overwrites context_lens which is also used in the else block at line 195. This could be confusing. Consider using a more descriptive variable name like 'adjusted_context_lens' for clarity.
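A sketch of the suggested rename (the surrounding shape is assumed from the diff hunk above):

```python
if max_q_len > 1:
    # Don't overwrite context_lens, which the else branch still reads.
    adjusted_context_lens = np.array(batch.context_lens) - var["mtp_k"]
```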

Comment on lines 259 to 264
num_placeholder = (
    2 * self.mtp_k if is_deferred_out and self.use_spec else
    1 if is_deferred_out else
    self.mtp_k if self.use_spec else
    0
)

Copilot AI Dec 12, 2025


The nested ternary expression for num_placeholder is difficult to read. Consider using an if-elif-else structure or extracting this logic into a separate method with a descriptive name for better maintainability.

Copilot AI review requested due to automatic review settings December 22, 2025 02:11

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 9 changed files in this pull request and generated 8 comments.



runner=None,
):
self.config = atom_config
self.speculative_config = self.config.speculative_config

Copilot AI Dec 22, 2025


The variable mtp_k is renamed but its meaning remains ambiguous. Based on the context, this represents the number of speculative tokens, not k in the traditional MTP sense. Consider renaming to num_speculative_tokens for clarity, or add a comment explaining the naming convention.

Suggested change
- self.speculative_config = self.config.speculative_config
+ self.speculative_config = self.config.speculative_config
+ # Number of speculative tokens used in MTP; kept as `mtp_k` for backward compatibility.

Comment on lines +170 to +179
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]

for _ in range(self.mtp_k - 1):
    pass  # TODO: support multiple num_speculative_tokens

# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids


Copilot AI Dec 22, 2025


The TODO comment is misleading - the function already handles mtp_k speculative tokens via the loop. Either implement the remaining logic or update the comment to clarify what specific functionality is missing.

Suggested change
- # Generate the remaining draft tokens.
- draft_token_ids_list = [draft_token_ids]
- for _ in range(self.mtp_k - 1):
-     pass  # TODO: support multiple num_speculative_tokens
- # [batch_size, num_speculative_tokens]
- draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
- return draft_token_ids
+ # Multiple draft tokens (mtp_k > 1) are not yet supported. Fail fast instead of
+ # returning an incorrectly shaped tensor or silently misbehaving.
+ raise NotImplementedError(
+     f"Multiple speculative tokens (mtp_k={self.mtp_k}) are not supported in EagleSpecDecoder yet."
+ )

metadata.num_draft_tokens,
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
None,

Copilot AI Dec 22, 2025


The draft_probs parameter is always passed as None but still has validation logic and is used in assertions. Either implement probability-based rejection sampling or remove this unused parameter and its associated validation code.
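For reference, a minimal sketch of the standard probability-based acceptance rule that a real draft_probs would enable (tensor names and shapes are assumptions, and this is the textbook speculative-sampling rule rather than this repo's kernel):

```python
import torch

def prob_reject_sample(draft_ids: torch.Tensor,     # [k]
                       draft_probs: torch.Tensor,   # [k, vocab], q
                       target_probs: torch.Tensor   # [k, vocab], p
                       ) -> list[int]:
    out: list[int] = []
    for i, tok in enumerate(draft_ids.tolist()):
        p, q = target_probs[i, tok], draft_probs[i, tok]
        if torch.rand(()) < torch.clamp(p / q, max=1.0):
            out.append(tok)  # accept with probability min(1, p/q)
            continue
        # On rejection, resample from the residual distribution max(0, p - q)
        # and stop; later draft tokens are discarded.
        residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
        out.append(torch.multinomial(residual / residual.sum(), 1).item())
        break
    return out
```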

Comment on lines +197 to +202
num_placeholder = (
    2 * self.mtp_k if is_deferred_out and self.use_spec else
    1 if is_deferred_out else
    self.mtp_k if self.use_spec else
    0
)

Copilot AI Dec 22, 2025


The multi-line conditional expression for num_placeholder is difficult to read and error-prone. Consider using a traditional if-elif-else structure or extracting this logic into a helper method with clear comments explaining each case.

Suggested change
- num_placeholder = (
-     2 * self.mtp_k if is_deferred_out and self.use_spec else
-     1 if is_deferred_out else
-     self.mtp_k if self.use_spec else
-     0
- )
+ # Determine how many placeholder tokens to use based on deferred output
+ # and speculative decoding configuration.
+ if is_deferred_out and self.use_spec:
+     num_placeholder = 2 * self.mtp_k
+ elif is_deferred_out:
+     num_placeholder = 1
+ elif self.use_spec:
+     num_placeholder = self.mtp_k
+ else:
+     num_placeholder = 0

else:
-     if not seq.ignore_eos and token_id == self.eos_token_id:
+     # Check the last token in the list for EOS
+     if token_ids and not seq.ignore_eos and self.eos_token_id in token_ids:

Copilot AI Dec 22, 2025


The EOS check now looks for eos_token_id anywhere in the token_ids list, but the comment on line 236 says 'Check the last token in the list for EOS'. This could cause premature sequence termination if EOS appears in draft tokens. The check should verify only the last accepted token: token_ids[-1] == self.eos_token_id.

Suggested change
- if token_ids and not seq.ignore_eos and self.eos_token_id in token_ids:
+ if token_ids and not seq.ignore_eos and token_ids[-1] == self.eos_token_id:

"""for decode: input ids are from prev_sampled_token_ids"""
locations, is_all_alive = self.get_prev_alive_locations(batch)
num_deferred_tokens = len(locations)
alive_seq_indices, num_deferred_tokens, is_all_alive = self.get_prev_alive_locations(batch)

Copilot AI Dec 22, 2025


The variable alive_seq_indices is misleading - it contains indices into the previous batch, not sequence IDs. Consider renaming to prev_batch_indices or deferred_seq_indices to clarify what these indices reference.
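A one-line sketch of the rename (assuming the same call site as above):

```python
# Positions in the previous batch, not sequence IDs.
prev_batch_indices, num_deferred_tokens, is_all_alive = self.get_prev_alive_locations(batch)
```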

Comment on lines +236 to +243
# self.input_ids_loc.np[:num_deferred_tokens] = alive_seq_indices
# self.input_ids_loc.copy_to_gpu(num_deferred_tokens)
# torch.gather(
#     self.prev_token_ids,
#     0,
#     self.input_ids_loc.gpu[:num_deferred_tokens],
#     out=self.input_ids.gpu[:num_deferred_tokens],
# )

Copilot AI Dec 22, 2025


Remove the commented-out code block or replace it with a concise comment explaining why the alternative implementation was chosen. Large blocks of commented code reduce maintainability.

self.forward_vars["draft_tokens"].gpu[:bs, :self.drafter.mtp_k] = draft_token
self.forward_vars["draft_tokens"].copy_to_cpu()
self.tokenID_processor.draft_token_ids = draft_token
self.tokenID_processor.pre_num_decode_token_per_seq = 2

Copilot AI Dec 22, 2025


The hardcoded value 2 should be derived from self.drafter.mtp_k + 1 to maintain consistency with the speculative decoding configuration. Magic numbers make the code fragile to configuration changes.

Suggested change
- self.tokenID_processor.pre_num_decode_token_per_seq = 2
+ self.tokenID_processor.pre_num_decode_token_per_seq = self.drafter.mtp_k + 1
