-
Notifications
You must be signed in to change notification settings - Fork 190
Add gfx950 mla a8w8 qh32 kernel #1912
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR adds support for the gfx950 MLA (Multi-Head Latent Attention) a8w8 (8-bit activation, 8-bit weight) qh32 (32 query heads) kernel. The changes enable a new kernel configuration for fp8 data types with 32 heads in persistent mode with decode sequence length of 4.
Changes:
- Added new kernel entry to the assembly CSV configuration for fp8/fp8 with 32 heads
- Updated metadata generation logic to support 32-head configurations alongside existing 16 and 128-head support
- Modified test harness to disable causal masking and reduce iteration count for testing
- Added debug logging to trace tensor information and kernel arguments
Reviewed changes
Copilot reviewed 8 out of 9 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| hsa/gfx950/mla/mla_asm.csv | Registers new fp8/fp8 kernel with 32-head configuration |
| csrc/py_itfs_cu/asm_mla.cu | Adds configuration branch for fp8/fp8 32-head decode with qlen=4 and enables debug logging |
| csrc/kernels/mla/metadata/v1_comm.cuh | Introduces NUM_HEADS dispatcher macro for compile-time head count specialization |
| csrc/kernels/mla/metadata/v1_2_device.cuh | Extends metadata generation to support 32-head fp8 configurations |
| aiter/ops/attention.py | Updates metadata calculation to include 32-head case for fp8 data types |
| aiter/mla.py | Extends native support check and adds tensor debugging utilities |
| op_tests/test_mla_persistent.py | Disables causal masking for testing new kernel configuration |
| aiter/test_common.py | Reduces performance test iterations for faster debugging |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
f98d992 to
2d8cb5b
Compare
2d8cb5b to
92790a5
Compare
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist