support mtp stage 2: support deepseek mtp=1 #35
base: main
Conversation
Pull request overview
This PR implements stage 1 support for multi-token prediction (MTP) by enabling the model engine to handle multiple decode tokens per step. The changes modify the core engine to track and process a variable number of tokens per sequence rather than assuming exactly one token per decode step.
Key changes:
- Modified the postprocessing pipeline to return multiple tokens per sequence (`dict[int, list[int]]` instead of `dict[int, int]`); a minimal sketch of this new contract follows after this list
- Enhanced block allocation to reserve space for lookahead slots when in deferred output mode
- Updated position and context length calculations to account for previously accepted tokens in deferred output mode
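For illustration, here is a minimal sketch of the return-shape change described in the first bullet. The `Sequence`, `seq_id`, and `postprocess_*` names are hypothetical stand-ins, not the actual scheduler/model_runner API:

```python
from dataclasses import dataclass, field

@dataclass
class Sequence:
    seq_id: int
    token_ids: list[int] = field(default_factory=list)

# Before: exactly one sampled token per sequence per decode step.
def postprocess_single(seqs: list[Sequence], sampled: dict[int, int]) -> None:
    for seq in seqs:
        seq.token_ids.append(sampled[seq.seq_id])

# After: a variable number of accepted tokens per sequence per step,
# e.g. the bonus token plus any accepted speculative tokens when MTP is enabled.
def postprocess_multi(seqs: list[Sequence], sampled: dict[int, list[int]]) -> None:
    for seq in seqs:
        seq.token_ids.extend(sampled[seq.seq_id])
```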
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| atom/model_engine/scheduler.py | Updated postprocess to handle multiple tokens per sequence, modified block allocation to reserve lookahead slots, and removed placeholder token logic |
| atom/model_engine/model_runner.py | Changed return types from single token to token lists, added tracking of accepted token counts per sequence |
| atom/model_engine/block_manager.py | Added num_lookahead_slots parameter to allocation methods to reserve space for multiple tokens (a sketch follows after this table) |
| atom/model_ops/attentions/backends.py | Updated block table preparation to use actual context lengths when in deferred output mode |
| atom/model_ops/attentions/aiter_mla.py | Added logic to calculate actual context lengths and positions based on deferred tokens from previous batch |
| atom/model_ops/attentions/aiter_attention.py | Similar context length calculation updates for standard attention mechanism |
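As a rough illustration of the block_manager.py row above, the sketch below shows how an allocator can reserve extra KV-cache blocks for lookahead slots. `SimpleBlockManager` and its methods are hypothetical, not the real block_manager.py interface:

```python
import math

class SimpleBlockManager:
    """Hypothetical sketch: reserve KV-cache blocks for a sequence's existing
    tokens plus extra lookahead slots for deferred/speculative tokens."""

    def __init__(self, block_size: int, num_free_blocks: int):
        self.block_size = block_size
        self.num_free_blocks = num_free_blocks

    def blocks_needed(self, num_tokens: int, num_lookahead_slots: int = 0) -> int:
        # Round up so a partially filled block is still reserved.
        return math.ceil((num_tokens + num_lookahead_slots) / self.block_size)

    def can_allocate(self, num_tokens: int, num_lookahead_slots: int = 0) -> bool:
        return self.blocks_needed(num_tokens, num_lookahead_slots) <= self.num_free_blocks

# Example: 16-token blocks, a 33-token sequence, 2 lookahead slots -> 3 blocks.
mgr = SimpleBlockManager(block_size=16, num_free_blocks=8)
assert mgr.blocks_needed(33, num_lookahead_slots=2) == 3
```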
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
Pull request overview
Copilot reviewed 8 out of 9 changed files in this pull request and generated 7 comments.
```python
    )
]
if max_q_len > 1:
    context_lens = np.array(batch.context_lens) - var["mtp_k"]
```
Copilot AI · Dec 12, 2025
The computation `batch.context_lens - var["mtp_k"]` is performed, but the result overwrites context_lens, which is also used in the else block at line 195. This could be confusing. Consider using a more descriptive variable name like `adjusted_context_lens` for clarity.
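One way to apply that rename, keeping the surrounding names from the snippet above (`adjusted_context_lens` is the reviewer's proposed name, not code from the PR):

```python
if max_q_len > 1:
    # Work on an adjusted copy so batch.context_lens keeps its original values
    # for the else branch that still reads them.
    adjusted_context_lens = np.array(batch.context_lens) - var["mtp_k"]
```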
atom/model_engine/scheduler.py (outdated)
```python
num_placeholder = (
    2 * self.mtp_k if is_deferred_out and self.use_spec else
    1 if is_deferred_out else
    self.mtp_k if self.use_spec else
    0
)
```
Copilot AI · Dec 12, 2025
The nested ternary expression for num_placeholder is difficult to read. Consider using an if-elif-else structure or extracting this logic into a separate method with a descriptive name for better maintainability.
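One way to follow that suggestion is to pull the branching out into a small helper; `_num_placeholder_tokens` is a hypothetical name, not something introduced by the PR:

```python
def _num_placeholder_tokens(self, is_deferred_out: bool) -> int:
    """Return how many placeholder tokens to reserve for this step."""
    if is_deferred_out and self.use_spec:
        # Deferred output plus speculative decoding: room for the deferred
        # tokens and the draft tokens of the next step.
        return 2 * self.mtp_k
    if is_deferred_out:
        return 1
    if self.use_spec:
        return self.mtp_k
    return 0
```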
Pull request overview
Copilot reviewed 8 out of 9 changed files in this pull request and generated 8 comments.
```python
    runner=None,
):
    self.config = atom_config
    self.speculative_config = self.config.speculative_config
```
Copilot AI · Dec 22, 2025
The variable mtp_k is renamed but its meaning remains ambiguous. Based on the context, this represents the number of speculative tokens, not k in the traditional MTP sense. Consider renaming to num_speculative_tokens for clarity, or add a comment explaining the naming convention.
Suggested change:

```diff
 self.speculative_config = self.config.speculative_config
+# Number of speculative tokens used in MTP; kept as `mtp_k` for backward compatibility.
```
```python
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]

for _ in range(self.mtp_k - 1):
    pass  # TODO: support multiple num_speculative_tokens

# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
```
Copilot AI · Dec 22, 2025
The TODO comment is misleading - the function already handles mtp_k speculative tokens via the loop. Either implement the remaining logic or update the comment to clarify what specific functionality is missing.
Suggested change:

```diff
-# Generate the remaining draft tokens.
-draft_token_ids_list = [draft_token_ids]
-for _ in range(self.mtp_k - 1):
-    pass  # TODO: support multiple num_speculative_tokens
-# [batch_size, num_speculative_tokens]
-draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-return draft_token_ids
+# Multiple draft tokens (mtp_k > 1) are not yet supported. Fail fast instead of
+# returning an incorrectly shaped tensor or silently misbehaving.
+raise NotImplementedError(
+    f"Multiple speculative tokens (mtp_k={self.mtp_k}) are not supported in EagleSpecDecoder yet."
+)
```
```python
metadata.num_draft_tokens,
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
None,
```
Copilot AI · Dec 22, 2025
The draft_probs parameter is always passed as None but still has validation logic and is used in assertions. Either implement probability-based rejection sampling or remove this unused parameter and its associated validation code.
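If the probability-based path is ever implemented, the standard speculative-decoding acceptance rule looks roughly like the sketch below. This is a generic illustration, not the PR's rejection sampler; `accept_draft_tokens` and its arguments are hypothetical:

```python
import torch

def accept_draft_tokens(draft_token_ids: torch.Tensor,  # [batch, k] proposed token ids
                        draft_probs: torch.Tensor,      # [batch, k, vocab] draft-model probs
                        target_probs: torch.Tensor      # [batch, k, vocab] target-model probs
                        ) -> torch.Tensor:
    """Return a [batch, k] bool mask of draft tokens kept before the first rejection."""
    # Probability each model assigned to the proposed token.
    p_target = target_probs.gather(-1, draft_token_ids.unsqueeze(-1)).squeeze(-1)
    p_draft = draft_probs.gather(-1, draft_token_ids.unsqueeze(-1)).squeeze(-1)
    # Accept token i with probability min(1, p_target / p_draft).
    accept = torch.rand_like(p_target) < (p_target / p_draft).clamp(max=1.0)
    # Once a token is rejected, all later draft tokens are dropped as well.
    return accept.int().cumprod(dim=-1).bool()
```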
```python
num_placeholder = (
    2 * self.mtp_k if is_deferred_out and self.use_spec else
    1 if is_deferred_out else
    self.mtp_k if self.use_spec else
    0
)
```
Copilot AI · Dec 22, 2025
The multi-line conditional expression for num_placeholder is difficult to read and error-prone. Consider using a traditional if-elif-else structure or extracting this logic into a helper method with clear comments explaining each case.
Suggested change:

```diff
-num_placeholder = (
-    2 * self.mtp_k if is_deferred_out and self.use_spec else
-    1 if is_deferred_out else
-    self.mtp_k if self.use_spec else
-    0
-)
+# Determine how many placeholder tokens to use based on deferred output
+# and speculative decoding configuration.
+if is_deferred_out and self.use_spec:
+    num_placeholder = 2 * self.mtp_k
+elif is_deferred_out:
+    num_placeholder = 1
+elif self.use_spec:
+    num_placeholder = self.mtp_k
+else:
+    num_placeholder = 0
```
```diff
 else:
-    if not seq.ignore_eos and token_id == self.eos_token_id:
+    # Check the last token in the list for EOS
+    if token_ids and not seq.ignore_eos and self.eos_token_id in token_ids:
```
Copilot AI · Dec 22, 2025
The EOS check now looks for eos_token_id anywhere in the token_ids list, but the comment on line 236 says 'Check the last token in the list for EOS'. This could cause premature sequence termination if EOS appears in draft tokens. The check should verify only the last accepted token: token_ids[-1] == self.eos_token_id.
Suggested change:

```diff
-if token_ids and not seq.ignore_eos and self.eos_token_id in token_ids:
+if token_ids and not seq.ignore_eos and token_ids[-1] == self.eos_token_id:
```
| """for decode: input ids are from prev_sampled_token_ids""" | ||
| locations, is_all_alive = self.get_prev_alive_locations(batch) | ||
| num_deferred_tokens = len(locations) | ||
| alive_seq_indices, num_deferred_tokens, is_all_alive = self.get_prev_alive_locations(batch) |
Copilot AI · Dec 22, 2025
The variable alive_seq_indices is misleading - it contains indices into the previous batch, not sequence IDs. Consider renaming to prev_batch_indices or deferred_seq_indices to clarify what these indices reference.
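Applying that rename would look roughly like this (`prev_batch_indices` is the reviewer's proposed name, not code from the PR):

```python
# Positions within the previous batch (not sequence IDs) whose deferred
# tokens are still needed this step.
prev_batch_indices, num_deferred_tokens, is_all_alive = self.get_prev_alive_locations(batch)
```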
```python
# self.input_ids_loc.np[:num_deferred_tokens] = alive_seq_indices
# self.input_ids_loc.copy_to_gpu(num_deferred_tokens)
# torch.gather(
#     self.prev_token_ids,
#     0,
#     self.input_ids_loc.gpu[:num_deferred_tokens],
#     out=self.input_ids.gpu[:num_deferred_tokens],
# )
```
Copilot AI · Dec 22, 2025
Remove the commented-out code block or replace it with a concise comment explaining why the alternative implementation was chosen. Large blocks of commented code reduce maintainability.
```python
self.forward_vars["draft_tokens"].gpu[:bs, :self.drafter.mtp_k] = draft_token
self.forward_vars["draft_tokens"].copy_to_cpu()
self.tokenID_processor.draft_token_ids = draft_token
self.tokenID_processor.pre_num_decode_token_per_seq = 2
```
Copilot AI · Dec 22, 2025
The hardcoded value 2 should be derived from self.drafter.mtp_k + 1 to maintain consistency with the speculative decoding configuration. Magic numbers make the code fragile to configuration changes.
Suggested change:

```diff
-self.tokenID_processor.pre_num_decode_token_per_seq = 2
+self.tokenID_processor.pre_num_decode_token_per_seq = self.drafter.mtp_k + 1
```
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist