12 changes: 12 additions & 0 deletions tests/cpp/operator/test_cast_float8blockwise.cu
@@ -501,6 +501,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
q_opts.amax_epsilon = eps;
q_opts.block_scaling_dim = 2u;

// On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
// which requires using power of two scaling factors. Skip unsupported tests.
if (getDeviceComputeCapability() > hopperComputeCapability && !force_pow_2) {
GTEST_SKIP();
}

if (colwise && matrix_size.size() < 2) {
// test_common Tensor initialization code does not
// handle this case.
@@ -552,6 +558,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) {
q_opts.amax_epsilon = eps;
q_opts.block_scaling_dim = 1u;

// On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
// which requires using power of two scaling factors. Skip unsupported tests.
if (getDeviceComputeCapability() > hopperComputeCapability && !force_pow_2) {
GTEST_SKIP();
}

if (colwise && matrix_size.size() < 2) {
// test_common Tensor initialization code does not
// handle this case.
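Note (illustrative, not part of the diff): MXFP8 stores each block's scale factor as an E8M0 value, i.e. a pure power of two, which is why the emulated path only accepts power-of-two scales. A minimal sketch of how such a scale could be derived from a block's amax, assuming a hypothetical helper compute_pow2_scale and standard <cmath>:

#include <cmath>

// Hypothetical helper: largest power-of-two scale s such that amax * s stays
// within the representable FP8 range (fp8_max). Mirrors the force_pow_2 /
// pow_2_scales behaviour exercised by the tests above.
float compute_pow2_scale(float amax, float fp8_max) {
  if (amax == 0.0f) return 1.0f;
  return std::exp2(std::floor(std::log2(fp8_max / amax)));
}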
4 changes: 3 additions & 1 deletion tests/pytorch/test_float8_blockwise_gemm_exact.py
@@ -8,6 +8,7 @@
import transformer_engine_torch as tex

from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
@@ -19,7 +20,8 @@

def fp8_blockwise_gemm_supported() -> bool:
supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
return supported
emulated = get_device_compute_capability() > (9, 0)
return supported and not emulated


def cublas_gemm_fp8_blockwise_case(
14 changes: 14 additions & 0 deletions tests/pytorch/test_float8_blockwise_scaling_exact.py
@@ -12,6 +12,7 @@
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
@@ -32,6 +33,7 @@
if tensor_dump_dir_env is not None:
TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
recipe_emulated = get_device_compute_capability() > (9, 0)


class GetRecipes:
@@ -218,6 +220,12 @@ def check_quantization_block_tiling_versus_reference(
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
if recipe_emulated and not pow_2_scales:
pytest.skip(
"On Blackwell and newer, the FP8 block scaling recipe is emulated "
"with MXFP8, which requires using power of two scaling factors."
)

te_dtype = TE_DType[quant_dtype]
if tile_size == (1, 128):
block_scaling_dim = 1
@@ -409,6 +417,12 @@ def test_quantization_block_tiling_extrema_versus_reference(
tile_size: Tuple[int, int],
extrema_high: bool,
) -> None:
if recipe_emulated and not pow_2_scales:
pytest.skip(
"On Blackwell and newer, the FP8 block scaling recipe is emulated "
"with MXFP8, which requires using power of two scaling factors."
)

# This test runs a single tile through a quantizer as a way to test
# branch coverage of scale computation.
te_dtype = TE_DType[quant_dtype]
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -97,6 +97,7 @@ list(APPEND transformer_engine_SOURCES
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
20 changes: 20 additions & 0 deletions transformer_engine/common/include/transformer_engine/swizzle.h
@@ -44,6 +44,26 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
const size_t num_tensors, cudaStream_t stream);

/*! \brief Swizzle the scaling factors of an FP8 block scaling tensor into the MXFP8 interleaved layout for GEMM
*
* \param[in] input Input FP8 block scaling tensor with GEMM_READY scale_inv.
*  \param[in,out] output    Output MXFP8 tensor which hosts the swizzled scale_inv.
* \param[in] stream CUDA stream used for the operation.
*
* This function is used for emulating the FP8 block scaling recipe on Blackwell and newer, as the
* recipe is not natively supported by cublasLt on architectures other than Hopper.
*
* Requirements:
* - input is an FP8 block scaling tensor
* - input has rowwise usage
* - input.scale_inv is in GEMM_READY format
* - output is an MXFP8 tensor
* - output has rowwise usage
* - output.scale_inv has appropriate shape
* */
void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output,
cudaStream_t stream);

#ifdef __cplusplus
} // extern "C"
#endif
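
As a rough usage sketch (not part of the diff): the caller is expected to supply already-constructed NVTETensor handles that satisfy the requirements listed above; tensor creation and allocation are omitted because they depend on the surrounding code, and the wrapper function name is hypothetical.

#include <cuda_runtime.h>
#include <transformer_engine/swizzle.h>

// `fp8_block_input` is assumed to be an FP8 block scaling tensor with rowwise
// usage and a GEMM_READY scale_inv; `mxfp8_output` an MXFP8 tensor with rowwise
// usage and a scale_inv buffer shaped for the interleaved layout.
void swizzle_for_emulated_block_scaling_gemm(NVTETensor fp8_block_input,
                                             NVTETensor mxfp8_output,
                                             cudaStream_t stream) {
  nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(fp8_block_input, mxfp8_output, stream);
}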