Add logic for block-scaled tensors with GEMM swizzled scales #2486
base: main
Conversation
Greptile Overview

Greptile Summary

This PR adds explicit tracking of scale-ordering formats for block-scaled tensors (MXFP8, NVFP4, DSv3 FP8), distinguishing between "compact" ordering for quantization/communication and "swizzled" ordering for GEMM operations.

Key changes:

- A `with_gemm_swizzled_scales` field in the C++ tensor class and the PyTorch quantized tensor classes
- An `optimize_for_gemm` option in the PyTorch quantizer
- MXFP8 quantize kernels that can write GEMM-swizzled scales directly
- Removal of the DSv3 FP8 compact-format infrastructure

Benefits:

- GEMMs can skip the swizzle kernel when scales are already in GEMM order
- Fused kernels can bypass the separate swizzle step
Confidence Score: 5/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant Quantizer as Quantizer
    participant QTensor as QuantizedTensor
    participant Swizzle as swizzle_scales_for_gemm
    participant GEMM as cuBLAS GEMM
    User->>Quantizer: set optimize_for_gemm=True
    Quantizer->>QTensor: create_tensor(with_gemm_swizzled_scales=True)
    Note over QTensor: MXFP8 quantize kernel<br/>writes swizzled scales directly
    User->>GEMM: gemm(A, B)
    GEMM->>QTensor: check with_gemm_swizzled_scales
    alt Scales already swizzled
        GEMM->>GEMM: Use scales directly
    else Scales not swizzled
        GEMM->>Swizzle: swizzle_scales_for_gemm(tensor)
        Swizzle-->>GEMM: Return swizzled scales buffer
        Note over GEMM: Keep buffer alive during GEMM
    end
    GEMM->>GEMM: Execute cuBLAS GEMM
    GEMM-->>User: Return result
```
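To make the dispatch in the diagram concrete, here is a minimal Python sketch of the GEMM-side control flow. The `with_gemm_swizzled_scales` flag and `swizzle_scales_for_gemm` helper are names taken from this PR; the `BlockScaledTensor` stub and the rest of the scaffolding are hypothetical, not Transformer Engine's actual implementation.

```python
from dataclasses import dataclass

import torch


@dataclass
class BlockScaledTensor:
    """Hypothetical stand-in for a quantized tensor with block scales."""

    data: torch.Tensor               # packed FP8/FP4 values
    scales: torch.Tensor             # block scaling factors
    with_gemm_swizzled_scales: bool  # scale-ordering flag added by this PR


def swizzle_scales_for_gemm(t: BlockScaledTensor) -> torch.Tensor:
    """Stub for the out-of-place swizzle kernel. The real kernel permutes
    compact scales into the tiled layout that cuBLAS expects."""
    return t.scales.clone()  # placeholder only; no real reordering here


def gemm(a: BlockScaledTensor, b: BlockScaledTensor) -> None:
    """Sketch of the GEMM-side dispatch shown in the sequence diagram."""
    swizzle_buffers = []  # keep temporaries alive until the GEMM has launched
    scale_args = []
    for t in (a, b):
        if t.with_gemm_swizzled_scales:
            scale_args.append(t.scales)  # already in GEMM order, use directly
        else:
            swizzled = swizzle_scales_for_gemm(t)  # extra kernel launch
            swizzle_buffers.append(swizzled)
            scale_args.append(swizzled)
    # ... launch cuBLAS block-scaled GEMM with a.data, b.data, scale_args ...
```

Keeping `swizzle_buffers` around mirrors the note in the diagram: the out-of-place swizzle produces temporary buffers that must stay alive until the GEMM has consumed them.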
65 files reviewed, no comments
Description
All of the supported block-scaled tensor formats (MXFP8, NVFP4, DSv3 FP8) have two ways of ordering their scaling factors: a "compact" ordering that is convenient for quantization and communication, and a "swizzled" ordering that cuBLAS GEMMs require.
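As a concrete illustration of what "swizzled" means, the following sketch reorders a compact MXFP8 scale matrix into the 128×4-tile layout that cuBLAS block-scaled GEMMs consume. It assumes the input is already padded to multiples of 128 rows and 4 columns; the function name is hypothetical, and Transformer Engine's actual swizzle is a CUDA kernel rather than this PyTorch reshuffle.

```python
import torch


def swizzle_mxfp8_scales(compact: torch.Tensor) -> torch.Tensor:
    """Reorder compact (row-major) MXFP8 scales into the tiled layout
    consumed by cuBLAS block-scaled GEMMs. Assumes `compact` is already
    padded to a multiple of 128 rows and 4 columns."""
    rows, cols = compact.shape
    assert rows % 128 == 0 and cols % 4 == 0
    # Break the matrix into 128x4 tiles: (row_tiles, col_tiles, 128, 4)
    tiles = compact.view(rows // 128, 128, cols // 4, 4).permute(0, 2, 1, 3)
    # Within each tile, element (r, c) moves to (r % 32)*16 + (r // 32)*4 + c
    swizzled = tiles.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
    return swizzled.flatten()
```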
The core infrastructure handles this in an ad hoc way, blindly assuming that the "right" scale ordering is used for each operation. The PyTorch infrastructure only supports MXFP8 and NVFP4 scales in compact order, although DSv3 FP8 does distinguish between "compact" and "GEMM-ready" formats. This situation makes it hard to implement fused kernels that can bypass the swizzle kernel.
This PR adds a `with_gemm_swizzled_scales` field in the C++ tensor class so that the core infrastructure can distinguish between the different scale orderings. It also adds this field to the PyTorch quantized tensor classes and exposes an `optimize_for_gemm` option in the quantizer, so that we can create GEMM-ready tensors when they do not need communication or checkpointing. Finally, it rips out all of the DSv3 FP8 infrastructure for the compact format, which is no longer necessary.
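At the user level, the new option might be used like the following sketch. The quantizer class, its import path, and its constructor arguments are assumptions based on the current Transformer Engine layout; only `optimize_for_gemm` comes from this PR.

```python
import torch

# Assumed import path and class; only `optimize_for_gemm` is from this PR.
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

quantizer = MXFP8Quantizer()        # constructor arguments elided
quantizer.optimize_for_gemm = True  # quantize directly into GEMM-swizzled scales

weight = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
qweight = quantizer(weight)         # quantize kernel writes GEMM-ready scales

# The tensor now carries GEMM-ready scales, so GEMMs skip the swizzle kernel.
# Suitable for weights that are not communicated or checkpointed in compact form.
```

Whether `optimize_for_gemm` is a constructor argument or a mutable attribute is a guess; the PR description only states that the quantizer exposes the option.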
Progress

- Add option to pre-swizzle weights

Closes #2446.
Type of change
Changes
Please list the changes introduced in this PR:
- `optimize_for_gemm` option in PyTorch quantizer

Checklist: