
Conversation

@janekb04 commented Sep 5, 2025

Description

This PR adds support for the FP8 block scaling (i.e. DeepSeek) recipe on Blackwell. The Blackwell implementation exhibits some differences in behavior compared to Hopper.

Addresses this discussion from #1513.

Motivation

Currently, the FP8 block scaling recipe works only on Hopper. If you try to use it on Blackwell, the check_fp8_block_scaling_support function in fp8.py will report that it is not supported and an exception will prevent further execution. If check_fp8_block_scaling_support is changed to instead check that the architecture is Hopper or newer, the failure occurs in cublas_gemm instead: cuBLASLt does not implement cublasLtMatmul with a CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F or CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F input type on Blackwell.

A possible workaround is to simply switch from the Float8BlockScaling recipe to the MXFP8BlockScaling recipe. However, this can result in numerical discrepancies, both because the two recipes quantize tensors differently and because the MXFP8 recipe performs more operations in low precision than the block scaling recipe.
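
For context, the snippet below sketches what enabling the recipe looks like from the PyTorch API. It is a minimal illustration, not code from this PR: the layer size, shapes, and dtype are made up, and te.Linear / te.fp8_autocast are ordinary Transformer Engine usage rather than anything introduced here; only Float8BlockScaling and check_fp8_block_scaling_support are named in the description above.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling

# DeepSeek-style 1x128 / 128x128 block scaling recipe.
recipe = Float8BlockScaling()

# Illustrative module and input shapes.
linear = te.Linear(1024, 1024, bias=False, params_dtype=torch.bfloat16)
x = torch.randn(128, 1024, device="cuda", dtype=torch.bfloat16)

# Before this PR, check_fp8_block_scaling_support() rejects this configuration
# on Blackwell; with this PR, the underlying GEMMs run through MXFP8 emulation.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = linear(x)
```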

Implementation

This PR emulates only the GEMMs with MXFP8. This is done by converting input NVTE_BLOCK_SCALING_1D and NVTE_BLOCK_SCALING_2D tensors to NVTE_MXFP8_1D_SCALING just before a GEMM. The tensors' main data is not touched at all; only the format of the scaling factors is changed to be compatible with MXFP8.

The FP8 block scaling tensors are created (quantized from higher precision) using quantize_transpose_vector_blockwise or quantize_transpose_square_blockwise - the same as on Hopper. This means that, unlike simply switching to the MXFP8 recipe, the 1x128 and 128x128 quantization block sizes are preserved (if MXFP8 were used, the scaling factors could differ, as they would correspond to 1x32 blocks).

To make the tensors valid inputs to the MXFP8 GEMM, their scaling factors are converted from the FP8 block scaling format to the MXFP8 format. I take advantage of the fact that, when entering the GEMM, the FP8 block scaling tensors are guaranteed to already use the GEMM_READY scaling factor format. In the case of 2D (128x128) block scaling, the scaling factors are simply "unsqueezed" by a factor of 512 - a single scaling factor for a 128x128 block becomes 512 scaling factors for the 512 1x32 blocks constituting the 128x128 block. In the case of 1D (1x128) block scaling, every 128 scaling factors corresponding to a 128x128 block are swizzled to match the cuBLASLt MXFP8 GEMM format and "unsqueezed" by a factor of 4 to correspond to the 512 1x32 blocks.
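
To make the layout change concrete, here is a rough NumPy sketch of the "unsqueeze" step described above. It is illustrative only: the function names are invented, the actual conversion is a CUDA kernel operating on the GEMM_READY layout, and the cuBLASLt-specific swizzle applied to the 1D scaling factors is omitted.

```python
import numpy as np

def fp32_pow2_to_e8m0(scales: np.ndarray) -> np.ndarray:
    """Reinterpret power-of-two FP32 scales as E8M0 by keeping only the 8 exponent bits."""
    bits = np.ascontiguousarray(scales, dtype=np.float32).view(np.uint32)
    return ((bits >> 23) & 0xFF).astype(np.uint8)

def expand_2d_block_scales(scales: np.ndarray) -> np.ndarray:
    """One scale per 128x128 block -> 512 scales, one per 1x32 sub-block.

    scales: (M/128, N/128) FP32; returns (M, N/32) E8M0.
    """
    e8m0 = fp32_pow2_to_e8m0(scales)
    # Replicate each scale across 128 rows and 4 column-blocks of 32 (512 copies total).
    return np.repeat(np.repeat(e8m0, 128, axis=0), 4, axis=1)

def expand_1d_block_scales(scales: np.ndarray) -> np.ndarray:
    """One scale per 1x128 block -> 4 scales, one per 1x32 sub-block.

    scales: (M, N/128) FP32; returns (M, N/32) E8M0 (cuBLASLt swizzle not shown).
    """
    return np.repeat(fp32_pow2_to_e8m0(scales), 4, axis=1)
```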

Limitations

  1. The conversion from FP8 block scaling scaling factors to MXFP8 scaling factors is lossless if and only if the original scaling factors are powers of 2. This is because the MXFP8 scaling factors are not FP32 (like the FP8 block scaling scaling factors), but rather FP8E8M0. The conversion kernels assume this requirement is met and simply extract the FP32 exponent bits and treat them as FP8E8M0. If either the sign bit or any mantissa bits are set, the results will be incorrect, as those bits are not masked off during the bit shifts. Masking them off would not help, since discarding them would introduce numerical discrepancies in the output anyway. (A small illustration follows this list.)

  2. Even though the tensor conversion is lossless, the GEMM outputs are not identical between Hopper and Blackwell, because the Blackwell MXFP8 cuBLASLt GEMM is implemented differently from the Hopper FP8 block scaling cuBLASLt GEMM. However, the overall numerical error should still be smaller than with simply switching to the MXFP8 recipe.

  3. Unlike FP8 block scaling on Hopper, GEMM+GELU fusion is not currently supported, as cuBLASLt does not support it for MXFP8 (with BF16 output, which the FP8 block scaling recipe uses).
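
As a concrete illustration of limitation 1, the check below (purely illustrative, not code from this PR) shows when an FP32 scaling factor has an exact FP8E8M0 counterpart, i.e. when the conversion is lossless:

```python
import numpy as np

def is_exact_e8m0(scale: float) -> bool:
    """True if the FP32 scale is a positive power of two (only exponent bits set)."""
    bits = np.array(scale, dtype=np.float32).view(np.uint32).item()
    sign = bits >> 31
    mantissa = bits & 0x7FFFFF
    return sign == 0 and mantissa == 0

print(is_exact_e8m0(0.25))   # True  -> the exponent bits alone encode the scale
print(is_exact_e8m0(0.3))    # False -> mantissa bits would be discarded
print(is_exact_e8m0(-0.5))   # False -> the sign bit corrupts the shifted result
```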

Future optimizations to pursue

  • Take advantage of Blackwell support for FP8 non-TN GEMMs. Unlike Hopper, Blackwell supports non-TN GEMMs for FP8. This enables at least the following two optimizations:

    1. Don't create columnwise data for weights. Because weights are 2D-block-scaled, their quantized columnwise data is simply the transpose of the rowwise data. As such, the columnwise data doesn't have to be created in the forward pass, since the rowwise data can be used in the dgrad GEMM.
    2. Don't transpose data after all-gather. Currently, _post_process_fp8_blockwise_gather in distributed.py transposes columnwise data after the all-gather. This could be avoided.
  • Use a multi-tensor kernel for scaling factor swizzling. Currently, the support for te_general_grouped_gemm, which is needed for MoE, is naive: the tensor conversion function convert_block_scaling_to_mxfp8_tensor is simply called in a loop for every tensor. This can result in many small kernel launches, which is inefficient. Instead, a single kernel could process all tensors at once, similar to how MXFP8 scaling factor swizzling is handled.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

WIP

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@janekb04 changed the title from "Deepseek blackwell" to "Add support for the FP8 Block Scaling (ie. Deepseek) recipe on Blackwell" on Sep 5, 2025
@janekb04 force-pushed the deepseek-blackwell branch 4 times, most recently from d7e794a to 7e7bf91, on September 9, 2025 at 21:59
@janekb04

/te-ci pytorch L0

@janekb04

/te-ci pytorch L0

@janekb04

/te-ci L0

@janekb04

/te-ci L0 L1
