
Conversation


@timmoon10 commented on Dec 6, 2025

Description

All of the supported block-scaled tensor formats (MXFP8, NVFP4, DSv3 FP8) have two ways of ordering their scaling factors:

  • "Compact" ordering for quantization, dequantization, and communication
  • "Swizzled" ordering for GEMM

The core infrastructure handles this in an ad hoc way, blindly assuming that the "right" scale ordering is used for each operation. The PyTorch infrastructure only supports MXFP8 and NVFP4 scales in compact order, although DSv3 FP8 does distinguish between "compact" and "GEMM-ready" formats. This situation makes it hard to implement fused kernels that can bypass the swizzle kernel.

This PR adds a with_gemm_swizzled_scales field to the C++ tensor class so that the core infrastructure can distinguish between the different scale orderings. It also adds this field to the PyTorch quantized tensor classes and exposes an optimize_for_gemm option in the quantizer, so that tensors which do not need communication or checkpointing can be created directly with GEMM-swizzled scales. Finally, it rips out all of the DSv3 FP8 infrastructure for the compact format, which is no longer necessary.
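
For a sense of how this would be used from PyTorch, here is a minimal sketch. Only optimize_for_gemm and _with_gemm_swizzled_scales come from this PR; the MXFP8Quantizer constructor arguments and the quantize call are assumptions and may not match the actual API.

```python
# Minimal usage sketch (not the exact TE API): quantize a tensor so that its
# scales land directly in the GEMM-swizzled layout, skipping the standalone
# swizzle kernel. Constructor arguments and the quantize call are assumptions.
import torch
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

quantizer = MXFP8Quantizer()          # real constructor likely takes dtype/usage args
quantizer.optimize_for_gemm = True    # new option from this PR

x = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
x_mxfp8 = quantizer(x)                # or quantizer.quantize(x); exact entry point may differ

# The tensor records its scale ordering, so downstream GEMMs can consume the
# scales as-is instead of launching a swizzle kernel first.
assert x_mxfp8._with_gemm_swizzled_scales
```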

Progress

  • MXFP8
  • DSv3 FP8
  • NVFP4
  • Add option to pre-swizzle weights
  • Pre-swizzle activations
  • Fused MXFP8 quantize + swizzle

Closes #2446 (Support MXFP8/NVFP4 tensors with pre-swizzled scales).

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Support GEMM swizzled scales in C++ tensor class
  • Support GEMM swizzled scales in PyTorch quantized tensor classes
  • Support optimize_for_gemm option in PyTorch quantizer
  • Expose PyTorch function to swizzle scales (see the sketch after this list)
  • Support MXFP8 quantization with pre-swizzled scales
  • Enable fused quantize+swizzle kernels in linear module and related
  • Remove DSv3 FP8 compact data format. It was used to avoid all-gather interleaving, which we can now fix with the swap-first-dims kernel.
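
When a tensor was produced with compact scales, the new swizzle entry point can also be called explicitly, as sketched below. The Python-side binding name (tex.inplace_swizzle_scale_for_gemm) and its signature are assumptions based on the functions added in swizzle.cpp; the real binding may differ.

```python
# Hypothetical sketch: make sure a quantized tensor's scales are in the
# GEMM-swizzled layout before a GEMM consumes them.
import transformer_engine_torch as tex  # TE's C++ extension module (name assumed)

def ensure_gemm_ready(qtensor):
    """Swizzle scales in place only if the tensor still uses compact ordering."""
    if not getattr(qtensor, "_with_gemm_swizzled_scales", False):
        # Assumed binding of the new swizzle.cpp entry point; exact name and
        # argument types on the Python side are not confirmed by this PR text.
        tex.inplace_swizzle_scale_for_gemm(qtensor)
    return qtensor
```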

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 force-pushed the tmoon/pre-swizzled-scales branch from d274220 to 52ce3a4 on December 6, 2025 02:53
@timmoon10 added the enhancement and refactor labels on Dec 6, 2025
@timmoon10 force-pushed the tmoon/pre-swizzled-scales branch from 4925b63 to 1de4b5e on December 10, 2025 07:19
@timmoon10
Collaborator Author

/te-ci

@timmoon10
Collaborator Author

/te-ci

@timmoon10 marked this pull request as ready for review on December 12, 2025 08:21

greptile-apps bot commented Dec 12, 2025

Greptile Overview

Greptile Summary

This PR adds explicit tracking of scale ordering formats for block-scaled tensors (MXFP8, NVFP4, DSv3 FP8), distinguishing between "compact" ordering for quantization/communication and "swizzled" ordering for GEMM operations.

Key changes:

  • Added with_gemm_swizzled_scales field to C++ Tensor class and all PyTorch quantized tensor classes (MXFP8, NVFP4, Float8Blockwise)
  • Added optimize_for_gemm option to Quantizer base class, enabling fused quantize+swizzle kernels
  • Implemented new swizzle_scales_for_gemm and multi_tensor_swizzle_scales_for_gemm functions in swizzle.cpp
  • Extended MXFP8 quantization kernel to support direct swizzled scale output via WITH_GEMM_SWIZZLED_SCALES template parameter
  • Added runtime validation in cuBLAS GEMM that scales are in expected format
  • Removed deprecated Float8BlockScaleTensorFormat::COMPACT and all_gather_usage fields
  • Refactored FP8 blockwise all-gather to use swap-first-dims kernel instead of compact format

Benefits:

  • Enables bypassing the swizzle kernel when fused kernels can write swizzled scales directly
  • Cleaner API with explicit scale format tracking vs implicit assumptions
  • Simplifies distributed communication by removing compact format complexity

Confidence Score: 5/5

  • This PR is safe to merge - it's a well-structured refactoring that adds explicit tracking for scale formats with proper validation at API boundaries.
  • The changes are comprehensive and consistent across the C++ and Python layers. The PR adds runtime validation in cuBLAS GEMM to ensure scales are in the expected format, providing a safety net. All tensor classes properly propagate the new with_gemm_swizzled_scales field through view, reshape, serialization, and copy operations. The deprecated compact format removal is clean and the new swap-first-dims approach for distributed communication is simpler.
  • No files require special attention - the implementation is consistent across all modified files.

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/common/include/transformer_engine/transformer_engine.h | 5/5 | Added kNVTEWithGEMMSwizzledScales tensor parameter and nvte_set_tensor_param_v2/nvte_get_tensor_param_v2 APIs for handling non-NVTEBasicTensor parameters. Deprecated Float8BlockScaleTensorFormat::COMPACT. |
| transformer_engine/pytorch/csrc/extensions/swizzle.cpp | 5/5 | New file with swizzle_scales_for_gemm, multi_tensor_swizzle_scales_for_gemm, convert_block_scaling_to_mxfp8_tensor, and inplace_swizzle_scale_for_gemm functions for scale format conversion. |
| transformer_engine/pytorch/quantized_tensor.py | 5/5 | Added optimize_for_gemm field to Quantizer base class, enabling tensors to be created with pre-swizzled scales for GEMM optimization. |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | 5/5 | Added _with_gemm_swizzled_scales tracking throughout MXFP8Tensor class, including new, view, reshape, FSDP2 operations, serialization, and copy operations. |
| transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | 5/5 | Removed all_gather_usage and data_format fields from Float8BlockQuantizer and Float8BlockwiseQTensor, simplifying the interface to always use GEMM-ready format. |
| transformer_engine/pytorch/distributed.py | 5/5 | Refactored FP8 blockwise all-gather to use an _AsyncHandle class with post-processing. Removed compact format handling; now uses the swap-first-dims kernel for the interleaving fix. |
| transformer_engine/pytorch/csrc/quantizer.cpp | 5/5 | Added optimize_for_gemm field handling. Removed all_gather_usage and compact format logic from Float8BlockQuantizer. MXFP8Quantizer now uses optimize_for_gemm to set swizzled scales. |
| transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 5/5 | Added WITH_GEMM_SWIZZLED_SCALES template parameter and gemm_swizzled_scale_idx function to support fused quantize+swizzle in a single kernel. |
| transformer_engine/common/gemm/cublaslt_gemm.cu | 5/5 | Added runtime checks that MXFP8 and NVFP4 scales are in GEMM-swizzled format before GEMM execution. Ensures correct format at the API boundary. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant Quantizer as Quantizer
    participant QTensor as QuantizedTensor
    participant Swizzle as swizzle_scales_for_gemm
    participant GEMM as cuBLAS GEMM

    User->>Quantizer: set optimize_for_gemm=True
    Quantizer->>QTensor: create_tensor(with_gemm_swizzled_scales=True)
    Note over QTensor: MXFP8 quantize kernel<br/>writes swizzled scales directly

    User->>GEMM: gemm(A, B)
    GEMM->>QTensor: check with_gemm_swizzled_scales
    alt Scales already swizzled
        GEMM->>GEMM: Use scales directly
    else Scales not swizzled
        GEMM->>Swizzle: swizzle_scales_for_gemm(tensor)
        Swizzle-->>GEMM: Return swizzled scales buffer
        Note over GEMM: Keep buffer alive during GEMM
    end
    GEMM->>GEMM: Execute cuBLAS GEMM
    GEMM-->>User: Return result
```
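
In code terms, the dispatch shown in the diagram boils down to roughly the following. This is illustrative Python pseudocode with hypothetical gemm_kernel and swizzle_fn callables, not the actual C++ implementation in cublaslt_gemm.cu.

```python
# Illustrative sketch of the GEMM-side dispatch above (hypothetical helpers).
def gemm_with_block_scales(A, B, gemm_kernel, swizzle_fn):
    """Run a block-scaled GEMM, swizzling scales on the fly only when needed."""
    temp_scale_buffers = []
    for operand in (A, B):
        if not operand._with_gemm_swizzled_scales:
            # Fall back to the standalone swizzle; keep the swizzled scale
            # buffer referenced so it stays alive while the GEMM consumes it.
            temp_scale_buffers.append(swizzle_fn(operand))
    out = gemm_kernel(A, B)  # cuBLAS requires GEMM-swizzled scales
    return out
```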


@timmoon10
Collaborator Author

/te-ci L1


@timmoon10
Collaborator Author

/te-ci L1


@timmoon10 added the performance and MoE labels on Dec 15, 2025

greptile-apps bot left a comment


65 files reviewed, no comments

