Add support for the FP8 Block Scaling (i.e. DeepSeek) recipe on Blackwell #2157
Draft: janekb04 wants to merge 17 commits into NVIDIA:main from janekb04:deepseek-blackwell (+546 −35)
Conversation
Description
This PR adds support for the FP8 block scaling (i.e. DeepSeek) recipe on Blackwell. Its behavior differs in some respects from the Hopper implementation, as described below.
Addresses this discussion from #1513.
Motivation
Currently, the FP8 block scaling recipe works only on Hopper. If you try to use it on Blackwell, the `check_fp8_block_scaling_support` function in `fp8.py` will report that it is not supported and an exception will prevent further execution. If `check_fp8_block_scaling_support` is changed to instead check that the architecture is Hopper or newer, the failure occurs in `cublas_gemm` instead: cuBLASLt does not implement `cublasLtMatmul` with a `CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F` or `CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F` input type on Blackwell.
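For illustration, here is a minimal sketch of what a relaxed architecture gate could look like. The function name mirrors `check_fp8_block_scaling_support`, but the real signature and conditions in `fp8.py` may differ; `torch.cuda.get_device_capability` is used here only to express "Hopper (sm90) or newer".

```python
import torch

def check_fp8_block_scaling_support_relaxed() -> tuple[bool, str]:
    """Illustrative sketch only: allow the FP8 block scaling recipe on Hopper and newer.

    The actual check in transformer_engine/pytorch/fp8.py may differ in name,
    signature, and the set of conditions it verifies.
    """
    if not torch.cuda.is_available():
        return False, "CUDA device not available."
    major, _minor = torch.cuda.get_device_capability()
    if major >= 9:  # sm90 (Hopper), sm100 (Blackwell), ...
        return True, ""
    return False, "FP8 block scaling requires Hopper (sm90) or newer."
```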
A possible workaround is to simply switch from the `Float8BlockScaling` recipe to the `MXFP8BlockScaling` recipe. However, this can result in numerical discrepancies, because the two recipes quantize tensors differently and because the MXFP8 recipe performs more operations in low precision than the block scaling recipe.
Implementation
This PR emulates only the GEMMs with MXFP8. This is done by converting input `NVTE_BLOCK_SCALING_1D` and `NVTE_BLOCK_SCALING_2D` tensors to `NVTE_MXFP8_1D_SCALING` just before a GEMM. The tensors' main data is not touched at all; only the format of the scaling factors is changed to be compatible with MXFP8.
The FP8 block scaling tensors are created (quantized from higher precision) using `quantize_transpose_vector_blockwise` or `quantize_transpose_square_blockwise`, the same as on Hopper. This means that, contrary to simply switching to the MXFP8 recipe, the 1x128 and 128x128 quantization block sizes are preserved (if MXFP8 were used, the scaling factors could differ, as they would correspond to 1x32 blocks).
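To make the parenthetical point concrete, here is a small NumPy illustration (not Transformer Engine code) of why the per-block scaling factors would differ: a 1x128 block derives one scale from the amax of 128 values, whereas 1x32 (MXFP8-sized) blocks derive four scales from four separate amaxes. The `FP8_E4M3_MAX / amax` convention is only assumed for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
row = rng.standard_normal(128).astype(np.float32)

FP8_E4M3_MAX = 448.0  # largest representable FP8E4M3 magnitude

# FP8 block scaling (1x128): one scale per 128-element block.
scale_1x128 = FP8_E4M3_MAX / np.abs(row).max()

# MXFP8-sized (1x32) blocks: four scales per 128 values.
# (MXFP8 additionally restricts scales to powers of two; omitted here for brevity.)
scales_1x32 = FP8_E4M3_MAX / np.abs(row.reshape(4, 32)).max(axis=1)

print(scale_1x128)   # one scale for the whole 128-element block
print(scales_1x32)   # four generally different scales
```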
To make the tensors valid inputs to the MXFP8 GEMM, their scaling factors are converted from the FP8 block scaling format to the MXFP8 format. I take advantage of the fact that, when entering the GEMM, the FP8 block scaling tensors are guaranteed to already use the `GEMM_READY` scaling factor format. In the case of 2D (128x128) block scaling, the scaling factors are simply "unsqueezed" by a factor of 512: a single scaling factor for a 128x128 block becomes 512 scaling factors for the 512 1x32 blocks constituting the 128x128 block. In the case of 1D (1x128) block scaling, every 128 scaling factors corresponding to a 128x128 block are swizzled to match the cuBLASLt MXFP8 GEMM format and "unsqueezed" by a factor of 4 to correspond to the 512 1x32 blocks.
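The following is a minimal NumPy sketch of this conversion for the 2D (128x128) case, assuming power-of-two FP32 scales and ignoring the cuBLASLt swizzle required in the 1D case. The PR's actual kernels are CUDA and operate on the `GEMM_READY` buffers, so this only illustrates the E8M0 reinterpretation and the factor-512 "unsqueeze".

```python
import numpy as np

def fp32_pow2_scale_to_e8m0(scales: np.ndarray) -> np.ndarray:
    """Reinterpret FP32 power-of-two scales as E8M0 by keeping only the biased exponent.

    Assumes every scale is a positive power of two; sign/mantissa bits are not checked.
    """
    bits = scales.astype(np.float32).view(np.uint32)
    return ((bits >> 23) & 0xFF).astype(np.uint8)

def unsqueeze_2d_block_scales(scales_2d: np.ndarray) -> np.ndarray:
    """Expand one scale per 128x128 block into 512 scales for its 1x32 sub-blocks.

    scales_2d has shape (M // 128, K // 128); the result has shape (M, K // 32),
    i.e. one E8M0 scale per 1x32 block. The layout is only illustrative: the real
    GEMM-ready buffer is additionally swizzled for cuBLASLt.
    """
    e8m0 = fp32_pow2_scale_to_e8m0(scales_2d)
    # Each 128x128 block covers 128 rows and 4 groups of 32 columns -> 512 copies.
    return np.repeat(np.repeat(e8m0, 128, axis=0), 4, axis=1)

# Example: a 256x256 tensor has a 2x2 grid of 128x128 blocks.
scales = np.array([[2.0**-3, 2.0**5], [2.0**0, 2.0**-7]], dtype=np.float32)
mx_scales = unsqueeze_2d_block_scales(scales)
print(mx_scales.shape)  # (256, 8): one E8M0 scale per 1x32 block
```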
Limitations
The conversion of the FP8 block scaling recipe's scaling factors to MXFP8 scaling factors is lossless if and only if the original scaling factors are powers of 2. This is because the MXFP8 scaling factors are not FP32 (like the FP8 block scaling ones), but FP8E8M0. The conversion kernels assume this requirement is met and simply extract the FP32 exponent bits and treat them as FP8E8M0. If either the sign bit or any mantissa bits are set, the results will be incorrect, as those bits are not masked out when performing the bit shifts. Masking them out would result in numerical discrepancies in the output anyway, since their information would be discarded.
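As a quick plain-Python illustration (not the PR's CUDA kernel) of this condition: extracting only the exponent field round-trips powers of two exactly, while any mantissa bits are silently discarded.

```python
import struct

def e8m0_from_fp32(scale: float) -> int:
    # Extract the biased exponent field of the FP32 representation.
    # Exact for positive powers of two; otherwise the mantissa is simply dropped.
    bits = struct.unpack("<I", struct.pack("<f", scale))[0]
    return (bits >> 23) & 0xFF

def e8m0_to_fp32(e: int) -> float:
    return 2.0 ** (e - 127)

print(e8m0_to_fp32(e8m0_from_fp32(0.125)))  # 0.125 -> exact (power of two)
print(e8m0_to_fp32(e8m0_from_fp32(3.0)))    # 2.0   -> mantissa bits discarded
```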
Even though the conversion itself is lossless, the GEMM outputs are not identical between Hopper and Blackwell, because the Blackwell MXFP8 cuBLASLt GEMM is implemented differently from the Hopper FP8 block scaling cuBLASLt GEMM. However, the overall numerical error should be smaller than with simply switching to the MXFP8 recipe.
Contrary to FP8 block scaling on Hopper, GEMM+GELU fusion is not currently supported, because cuBLASLt does not support it for MXFP8 with BF16 output (which the FP8 block scaling recipe uses).
Future optimizations to pursue
- Take advantage of Blackwell support for FP8 non-TN GEMMs. Contrary to Hopper, Blackwell supports non-TN GEMMs for FP8. This enables at least two optimizations; one of them is that `_post_process_fp8_blockwise_gather` in `distributed.py` transposes columnwise data after the all-gather, which could be avoided.
- Use a multi-kernel for scaling factor swizzling. Currently, the support for `te_general_grouped_gemm`, which is needed for MoE, is naive: the tensor conversion function `convert_block_scaling_to_mxfp8_tensor` is simply called in a loop for every tensor. This can result in many small kernel launches, which is inefficient. Instead, a multi-kernel approach could be used, similar to how MXFP8 scaling factor swizzling is handled (see the sketch after this list).
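For illustration only, here is a schematic contrast between the current per-tensor loop and a hypothetical multi-tensor launch. The callables stand in for the real (C++/CUDA) entry points; `convert_block_scaling_to_mxfp8_tensor` is only referenced by analogy, and the "multi" variant is not an existing Transformer Engine API.

```python
from typing import Any, Callable, Sequence

# Hypothetical stand-ins for the real C++/CUDA entry points.
ConvertOneFn = Callable[[Any], Any]
ConvertManyFn = Callable[[Sequence[Any]], list[Any]]

def convert_all_naive(tensors: Sequence[Any], convert_one: ConvertOneFn) -> list[Any]:
    # Current pattern in the grouped-GEMM path: one conversion call, and thus at
    # least one small kernel launch, per tensor.
    return [convert_one(t) for t in tensors]

def convert_all_batched(tensors: Sequence[Any], convert_many: ConvertManyFn) -> list[Any]:
    # Possible optimization: a single "multi" kernel that rewrites every tensor's
    # scale factors in one launch, analogous to how MXFP8 scale swizzling is batched.
    return convert_many(tensors)
```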
Type of change
Changes
WIP
Checklist: