mla ps support paged 64 and 3buffer layout for ds3.2 #1917
Pull request overview
This pull request adds support for page size 64 paging and a 3-buffer KV cache layout for MLA (Multi-Head Latent Attention) operations in DeepSeek V3.2. The PR introduces new parameters (`page_size` and `nhead_kv`) across the entire call stack to enable flexible paging strategies and to support a specialized 3-buffer KV cache layout with FP8 quantization.
Changes:
- Added `page_size` and `nhead_kv` parameters throughout the MLA API stack (Python, C++, and CUDA); see the sketch after this list
- Introduced support for a "byte" datatype to handle the 3-buffer layout with separate nope/scale/rope buffers
- Added new kernel assembly files (`mla.co`, `mla_page64.co`) for optimized page size 64 support
- Updated metadata generation logic to properly calculate KV offsets for paged layouts
- Added test infrastructure for the 3-buffer layout with helper functions for initialization and validation
Reviewed changes
Copilot reviewed 13 out of 15 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| op_tests/test_mla_sparse.py | Added kv_last_page_lens and page_size/nhead_kv parameters to function calls |
| op_tests/test_mla_persistent.py | Major additions: 3-buffer KV cache helper functions, new test cases for 3-buffer layout |
| op_tests/test_mla.py | Updated function signatures to include new page_size and nhead_kv parameters |
| hsa/gfx942/mla/mla_page64.co | New assembly kernel binary for page size 64 support |
| hsa/gfx942/mla/mla_asm.csv | Added kernel configuration entry for "byte" datatype |
| hsa/gfx942/mla/mla.co | New assembly kernel binary for 3-buffer layout support |
| csrc/py_itfs_cu/asm_mla.cu | Extended to handle "byte" datatype and removed page_size extraction from KV tensor |
| csrc/kernels/mla/metadata/v1_comm.cuh | Code formatting improvements and copyright update |
| csrc/kernels/mla/metadata/v1_2_device.cuh | Added paged layout support with proper kv_offset calculation |
| csrc/kernels/mla/metadata.cu | Added kv_last_page_lens and page_size parameters to metadata generation |
| csrc/include/rocm_ops.hpp | Updated Python bindings with new parameters |
| csrc/include/mla.h | Updated function signatures with new parameters |
| csrc/include/attention_asm_mla.h | Updated function signatures with new parameters |
| aiter/ops/attention.py | Added default values for backward compatibility |
| aiter/mla.py | Added page_size and nhead_kv parameters with default values |
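The 3-buffer helper functions added to op_tests/test_mla_persistent.py are not reproduced here, but a minimal sketch of the layout they set up might look as follows. The buffer shapes, dtypes, and the nope/rope widths are assumptions based on the PR description (separate nope/scale/rope buffers with FP8 quantization), not the test's actual code.

```python
import torch

def init_3buffer_kv_cache(
    num_pages: int,
    page_size: int = 64,
    kv_lora_rank: int = 512,   # assumed width of the "nope" latent portion
    rope_dim: int = 64,        # assumed width of the "rope" portion
    device: str = "cuda",
):
    """Allocate the three separate buffers of the 3-buffer KV layout (sketch)."""
    # "nope" latent values, quantized to FP8 (e4m3fnuz is the gfx942 variant)
    kv_nope = torch.empty(
        num_pages, page_size, kv_lora_rank,
        dtype=torch.float8_e4m3fnuz, device=device,
    )
    # dequantization scales for the FP8 buffer (scale granularity is an assumption)
    kv_scale = torch.empty(
        num_pages, page_size, 1, dtype=torch.float32, device=device
    )
    # rotary ("rope") portion kept in higher precision
    kv_rope = torch.empty(
        num_pages, page_size, rope_dim, dtype=torch.bfloat16, device=device
    )
    return kv_nope, kv_scale, kv_rope
```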
Motivation
Add support for page size 64 paging and a 3-buffer KV cache layout for MLA (Multi-Head Latent Attention) operations in DeepSeek V3.2.
Technical Details
- Added `page_size` and `nhead_kv` parameters throughout the MLA API stack (Python, C++, and CUDA)
- Introduced support for a "byte" datatype to handle the 3-buffer layout with separate nope/scale/rope buffers
- Added a new kernel assembly file (`mla_a16w8_qh16_m16x4_n16x1_coex0_mask1_ps_page64_ds32.co`) for optimized page size 64 support
- Updated metadata generation logic to properly calculate KV offsets for paged layouts; see the sketch after this list
- Added test infrastructure for the 3-buffer layout with helper functions for initialization and validation
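For reference, the sketch below shows the standard paged-attention indexing arithmetic that a KV-offset calculation of this kind performs. The actual device-side logic lives in csrc/kernels/mla/metadata/v1_2_device.cuh and may organize the computation differently; this is only an illustration of the idea.

```python
import torch

def kv_offset(
    kv_indptr: torch.Tensor,   # [batch + 1] offsets into kv_indices
    kv_indices: torch.Tensor,  # flattened page table across all batches
    batch_idx: int,
    token_idx: int,
    page_size: int = 64,
) -> int:
    # Which logical page of this sequence holds the token, ...
    logical_page = token_idx // page_size
    # ... which physical page that logical page maps to, ...
    physical_page = int(kv_indices[kv_indptr[batch_idx] + logical_page])
    # ... and the token's flat offset: page base plus slot within the page.
    return physical_page * page_size + token_idx % page_size
```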
Test Plan
```
python3 op_tests/test_mla_persistent.py -blk=64 -d=bf16 -kvd=fp8 -pl=3BUFFER
```
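Here, `-blk=64` presumably selects the page (block) size of 64, `-d=bf16` the query/output dtype, `-kvd=fp8` the KV cache dtype, and `-pl=3BUFFER` the new 3-buffer KV cache layout; these flag meanings are inferred from the PR description rather than from the test's argument parser.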
Test Result
Submission Checklist