diff --git a/benchmarks/benchmark_blockwise_scaled_linear_triton.py b/benchmarks/benchmark_blockwise_scaled_linear_triton.py
index ffdd63ec8d..1550a29b2e 100644
--- a/benchmarks/benchmark_blockwise_scaled_linear_triton.py
+++ b/benchmarks/benchmark_blockwise_scaled_linear_triton.py
@@ -18,6 +18,18 @@
         fp8_blockwise_act_quant,
         fp8_blockwise_weight_quant,
     )
+    # Import training kernels for comparison
+    from torchao.prototype.blockwise_fp8_training.kernels import (
+        blockwise_fp8_gemm_1x128_128x128,
+        fp8_blockwise_act_quant_lhs,
+        fp8_blockwise_weight_quant_transposed_rhs,
+    )
+    from torchao.prototype.blockwise_fp8_training.scaled_mm_kernels import (
+        blockwise_fp8_gemm_scaled_mm_1x128_128x128,
+    )
+    from torchao.prototype.blockwise_fp8_training.linear import (
+        Float8BlockwiseLinear,
+    )
     from torchao.utils import is_sm_at_least_89
 else:
     raise RuntimeError("This benchmark is only avaible on CUDA hardware")
@@ -74,6 +86,54 @@ def benchmark_latency(
     }
 
 
+def benchmark_training_kernels_latency(
+    m: int, k: int, n: int, block_size: int, dtype: torch.dtype, device
+):
+    """Benchmark training kernels: Triton vs torch._scaled_mm implementations."""
+    # Create reference tensors (B_ref is an (n, k) weight, as for nn.Linear)
+    A_ref = torch.randn((m, k), dtype=torch.bfloat16, device=device)
+    B_ref = torch.randn((n, k), dtype=torch.bfloat16, device=device)
+    bf16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)
+
+    # Create quantized inputs for training kernels
+    A_fp8, A_scale = fp8_blockwise_act_quant_lhs(A_ref, block_size)
+    B_fp8, B_scale = fp8_blockwise_weight_quant_transposed_rhs(B_ref, block_size)
+
+    # Benchmark Triton training kernel
+    try:
+        triton_time = benchmark_microseconds(
+            blockwise_fp8_gemm_1x128_128x128,
+            A_fp8, 1.0 / A_scale, B_fp8, 1.0 / B_scale
+        )
+    except Exception as e:
+        print(f"Triton kernel failed: {e}")
+        triton_time = float('inf')
+
+    # Benchmark torch._scaled_mm training kernel
+    try:
+        scaled_mm_time = benchmark_microseconds(
+            blockwise_fp8_gemm_scaled_mm_1x128_128x128,
+            A_fp8, 1.0 / A_scale, B_fp8, 1.0 / B_scale, block_size
+        )
+    except Exception as e:
+        print(f"Scaled MM kernel failed: {e}")
+        scaled_mm_time = float('inf')
+
+    return {
+        "m": m,
+        "k": k,
+        "n": n,
+        "block_size": block_size,
+        "dtype": dtype,
+        "bf16_latency (us)": bf16_time,
+        "triton_training_latency (us)": triton_time,
+        "scaled_mm_training_latency (us)": scaled_mm_time,
+        "triton_training_speedup": bf16_time / triton_time if triton_time != float('inf') else 0,
+        "scaled_mm_training_speedup": bf16_time / scaled_mm_time if scaled_mm_time != float('inf') else 0,
+        "scaled_mm_vs_triton_speedup": triton_time / scaled_mm_time if triton_time != float('inf') and scaled_mm_time != float('inf') else 0,
+    }
+
+
 def benchmark_precision(
     m: int, k: int, n: int, block_size: int, dtype: torch.dtype, device
 ):
@@ -96,20 +156,90 @@ def benchmark_precision(
     }
 
 
+def benchmark_training_kernels_precision(
+    m: int, k: int, n: int, block_size: int, dtype: torch.dtype, device
+):
+    """Benchmark precision of training kernels: Triton vs torch._scaled_mm."""
+    # Create high precision reference (B_ref is an (n, k) weight, as for nn.Linear)
+    A_ref = torch.randn((m, k), dtype=torch.bfloat16, device=device)
+    B_ref = torch.randn((n, k), dtype=torch.bfloat16, device=device)
+    ref_output = torch.nn.functional.linear(A_ref, B_ref)
+
+    # Create quantized inputs
+    A_fp8, A_scale = fp8_blockwise_act_quant_lhs(A_ref, block_size)
+    B_fp8, B_scale = fp8_blockwise_weight_quant_transposed_rhs(B_ref, block_size)
+
+    results = {
+        "m": m, "k": k, "n": n, "block_size": block_size, "dtype": dtype
+    }
+
# Test Triton kernel + try: + triton_output = blockwise_fp8_gemm_1x128_128x128( + A_fp8, 1.0 / A_scale, B_fp8, 1.0 / B_scale + ) + results["triton_error_db"] = compute_error(ref_output, triton_output) + except Exception as e: + print(f"Triton precision test failed: {e}") + results["triton_error_db"] = float('inf') + + # Test torch._scaled_mm kernel + try: + scaled_mm_output = blockwise_fp8_gemm_scaled_mm_1x128_128x128( + A_fp8, 1.0 / A_scale, B_fp8, 1.0 / B_scale, block_size + ) + results["scaled_mm_error_db"] = compute_error(ref_output, scaled_mm_output) + except Exception as e: + print(f"Scaled MM precision test failed: {e}") + results["scaled_mm_error_db"] = float('inf') + + # Compare the two implementations + if results["triton_error_db"] != float('inf') and results["scaled_mm_error_db"] != float('inf'): + try: + triton_output = blockwise_fp8_gemm_1x128_128x128( + A_fp8, 1.0 / A_scale, B_fp8, 1.0 / B_scale + ) + scaled_mm_output = blockwise_fp8_gemm_scaled_mm_1x128_128x128( + A_fp8, 1.0 / A_scale, B_fp8, 1.0 / B_scale, block_size + ) + results["triton_vs_scaled_mm_error_db"] = compute_error(triton_output, scaled_mm_output) + except Exception: + results["triton_vs_scaled_mm_error_db"] = float('inf') + else: + results["triton_vs_scaled_mm_error_db"] = float('inf') + + return results + + if __name__ == "__main__" and torch.cuda.is_available(): device = torch.device("cuda") + + # Original inference benchmark configurations k_vals = (8192, 8192, 8192, 28672) n_vals = (8192, 10240, 57344, 8192) block_size_vals = (128, 128, 128, 128) + + # Training kernel benchmark configurations (smaller set for faster testing) + training_configs = [ + (1, 4096, 4096), # Single token + (32, 4096, 4096), # Small batch + (8, 4096, 11008), # MLP up projection + (8, 11008, 4096), # MLP down projection + (1, 4096, 128256), # Vocab projection (if memory allows) + ] latency_results = [] precision_results = [] + training_latency_results = [] + training_precision_results = [] available_dtypes = ( [torch.float8_e4m3fn, torch.float8_e5m2] if is_sm_at_least_89() else [torch.float8_e5m2] ) + + print("Running original inference benchmarks...") for m in tqdm([1 << i for i in range(14)]): for dtype in available_dtypes: for n, k, block_size in zip(n_vals, k_vals, block_size_vals): @@ -119,12 +249,42 @@ def benchmark_precision( precision_results.append( benchmark_precision(m, k, n, block_size, dtype, device) ) + + print("Running training kernel benchmarks...") + for m, k, n in tqdm(training_configs): + # Only test on fp8_e4m3fn for training (most common) + if k % 128 == 0 and n % 128 == 0: # Ensure divisibility + try: + training_latency_results.append( + benchmark_training_kernels_latency(m, k, n, 128, torch.float8_e4m3fn, device) + ) + training_precision_results.append( + benchmark_training_kernels_precision(m, k, n, 128, torch.float8_e4m3fn, device) + ) + except Exception as e: + print(f"Skipping training config ({m}, {k}, {n}): {e}") - df_latency = pd.DataFrame(latency_results) - df_precision = pd.DataFrame(precision_results) - - df_latency.to_csv("blockwise_triton_latency_results.csv", index=False) - df_precision.to_csv("blockwise_triton_precision_results.csv", index=False) + # Save results + if latency_results: + df_latency = pd.DataFrame(latency_results) + df_latency.to_csv("blockwise_triton_inference_latency_results.csv", index=False) + print("\nInference Latency Results:") + print(df_latency.to_markdown(index=False)) - print(df_latency.to_markdown(index=False)) - print(df_precision.to_markdown(index=False)) + if 
precision_results: + df_precision = pd.DataFrame(precision_results) + df_precision.to_csv("blockwise_triton_inference_precision_results.csv", index=False) + print("\nInference Precision Results:") + print(df_precision.to_markdown(index=False)) + + if training_latency_results: + df_training_latency = pd.DataFrame(training_latency_results) + df_training_latency.to_csv("blockwise_training_kernels_latency_results.csv", index=False) + print("\nTraining Kernels Latency Results:") + print(df_training_latency.to_markdown(index=False)) + + if training_precision_results: + df_training_precision = pd.DataFrame(training_precision_results) + df_training_precision.to_csv("blockwise_training_kernels_precision_results.csv", index=False) + print("\nTraining Kernels Precision Results:") + print(df_training_precision.to_markdown(index=False)) diff --git a/benchmarks/prototype/blockwise_fp8_training/benchmark_triton_vs_scaled_mm.py b/benchmarks/prototype/blockwise_fp8_training/benchmark_triton_vs_scaled_mm.py new file mode 100644 index 0000000000..333c2ae084 --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/benchmark_triton_vs_scaled_mm.py @@ -0,0 +1,349 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Microbenchmark script to compare Triton kernels vs torch._scaled_mm native blockwise scaling +for blockwise fp8 GEMM operations. + +This provides a proper 1:1 comparison between the Triton blockwise implementation +and the torch._scaled_mm native blockwise scaling with CUDA 12.9+, as recommended +by danielvegamyhre to avoid uncoalesced memory access issues in Triton kernels. +""" + +import torch +import pandas as pd +from typing import Tuple, Optional +from tqdm import tqdm + +if torch.cuda.is_available(): + from triton.testing import do_bench + from torchao.float8.float8_utils import compute_error + from torchao.prototype.blockwise_fp8_training.kernels import ( + blockwise_fp8_gemm_1x128_128x128, + blockwise_fp8_gemm_1x128_128x1, + fp8_blockwise_act_quant_lhs, + fp8_blockwise_act_quant_rhs, + fp8_blockwise_act_quant_transposed_lhs, + fp8_blockwise_weight_quant_rhs, + fp8_blockwise_weight_quant_transposed_rhs, + ) + from torchao.utils import is_sm_at_least_90 +else: + raise RuntimeError("This benchmark is only available on CUDA hardware") + + +def benchmark_microseconds(f, *args, warmup=25, rep=100): + """Benchmark function in microseconds""" + return ( + do_bench(lambda: f(*args), warmup=warmup, rep=rep, return_mode="median") * 1e3 + ) + + + + +def blockwise_fp8_scaled_mm_1x128_128x128_reference( + a_fp8: torch.Tensor, + a_scale: torch.Tensor, + b_fp8: torch.Tensor, + b_scale: torch.Tensor, +) -> torch.Tensor: + """ + Reference implementation using native torch._scaled_mm with blockwise scaling. + This uses the CUDA 12.9+ native blockwise scaling support to provide optimal + performance through direct CUTLASS kernel usage, avoiding the uncoalesced + memory access issues present in Triton kernels. 
+ """ + from torchao.prototype.blockwise_fp8_training.scaled_mm_kernels import ( + blockwise_fp8_gemm_scaled_mm_1x128_128x128 + ) + + return blockwise_fp8_gemm_scaled_mm_1x128_128x128( + a_fp8, + 1.0 / a_scale, + b_fp8, + 1.0 / b_scale, + block_size=128 + ) + + +def create_test_tensors( + m: int, k: int, n: int, block_size: int = 128, device="cuda" +) -> Tuple: + """Create test tensors for benchmarking""" + # Create high precision reference tensors + a_ref = torch.randn(m, k, device=device, dtype=torch.bfloat16) + b_ref = torch.randn(k, n, device=device, dtype=torch.bfloat16) + + # Quantize activation (A) with 1x128 blockwise scaling + a_fp8, a_scale = fp8_blockwise_act_quant_lhs(a_ref, block_size) + + # Quantize weight (B) with 128x128 blockwise scaling, transposed dims in column major + b_fp8, b_scale = fp8_blockwise_weight_quant_transposed_rhs(b_ref, block_size) + + return a_ref, b_ref, a_fp8, a_scale, b_fp8, b_scale + + +def create_test_tensors_128x1( + m: int, k: int, n: int, block_size: int = 128, device="cuda" +): + """Create test tensors for 1x128 (LHS) x 128x1 (RHS) blockwise GEMM.""" + # High-precision reference tensors + a_ref = torch.randn(m, k, device=device, dtype=torch.bfloat16) + b_ref = torch.randn(k, n, device=device, dtype=torch.bfloat16) + + # LHS: use transposed-lhs quantization. Input to that kernel should be KxM + a_t = a_ref.t().contiguous() + a_fp8, a_scale = fp8_blockwise_act_quant_transposed_lhs(a_t, block_size) + + # RHS: 128x1 scaling along K + b_fp8, b_scale = fp8_blockwise_act_quant_rhs(b_ref, block_size) + + return a_ref, b_ref, a_fp8, a_scale, b_fp8, b_scale + + +def benchmark_gemm_variants( + m: int, k: int, n: int, block_size: int = 128, device="cuda" +) -> dict: + """Benchmark different GEMM implementations""" + + # Create test tensors + a_ref, b_ref, a_fp8, a_scale, b_fp8, b_scale = create_test_tensors( + m, k, n, block_size, device + ) + + results = { + "m": m, "k": k, "n": n, "block_size": block_size + } + + # Benchmark reference bf16 GEMM + bf16_time = benchmark_microseconds(torch.nn.functional.linear, a_ref, b_ref) + results["bf16_time_us"] = bf16_time + + # Benchmark Triton blockwise fp8 GEMM + triton_time = benchmark_microseconds( + blockwise_fp8_gemm_1x128_128x128, + a_fp8, 1.0 / a_scale, b_fp8, 1.0 / b_scale, block_size + ) + results["triton_time_us"] = triton_time + + # Benchmark torch._scaled_mm (native blockwise scaling with CUDA 12.9+) + try: + scaled_mm_time = benchmark_microseconds( + blockwise_fp8_scaled_mm_1x128_128x128_reference, + a_fp8, a_scale, b_fp8, b_scale + ) + results["scaled_mm_time_us"] = scaled_mm_time + except Exception as e: + print(f"Warning: torch._scaled_mm native blockwise benchmark failed: {e}") + print(f"Note: Requires CUDA 12.9+ for native blockwise scaling support") + results["scaled_mm_time_us"] = float('inf') + + # Calculate speedups + results["triton_speedup"] = bf16_time / triton_time if triton_time > 0 else 0 + results["scaled_mm_speedup"] = ( + bf16_time / results["scaled_mm_time_us"] + if results["scaled_mm_time_us"] > 0 and results["scaled_mm_time_us"] != float('inf') + else 0 + ) + + return results + + +def blockwise_fp8_scaled_mm_1x128_128x1_reference( + a_fp8: torch.Tensor, + a_scale: torch.Tensor, + b_fp8: torch.Tensor, + b_scale: torch.Tensor, + block_size: int = 128, +) -> torch.Tensor: + from torchao.prototype.blockwise_fp8_training.scaled_mm_kernels import ( + blockwise_fp8_gemm_scaled_mm_1x128_128x1, + ) + return blockwise_fp8_gemm_scaled_mm_1x128_128x1( + a_fp8, + 1.0 / a_scale, + b_fp8, + 1.0 / 
b_scale, + block_size, + ) + + +def benchmark_gemm_variants_128x1( + m: int, k: int, n: int, block_size: int = 128, device="cuda" +) -> dict: + """Benchmark 1x128 (LHS) x 128x1 (RHS) blockwise GEMM variants.""" + a_ref, b_ref, a_fp8, a_scale, b_fp8, b_scale = create_test_tensors_128x1( + m, k, n, block_size, device + ) + + results = {"m": m, "k": k, "n": n, "block_size": block_size, "case": "1x128_128x1"} + + # Reference bf16 GEMM + bf16_time = benchmark_microseconds(torch.nn.functional.linear, a_ref, b_ref) + results["bf16_time_us"] = bf16_time + + # Triton + triton_time = benchmark_microseconds( + blockwise_fp8_gemm_1x128_128x1, + a_fp8, 1.0 / a_scale, b_fp8, 1.0 / b_scale, block_size + ) + results["triton_time_us"] = triton_time + + # Native torch._scaled_mm + try: + scaled_mm_time = benchmark_microseconds( + blockwise_fp8_scaled_mm_1x128_128x1_reference, + a_fp8, a_scale, b_fp8, b_scale, block_size + ) + results["scaled_mm_time_us"] = scaled_mm_time + except Exception as e: + print(f"Warning: torch._scaled_mm native blockwise 128x1 benchmark failed: {e}") + results["scaled_mm_time_us"] = float('inf') + + results["triton_speedup"] = bf16_time / triton_time if triton_time > 0 else 0 + sm_time = results["scaled_mm_time_us"] + results["scaled_mm_speedup"] = bf16_time / sm_time if sm_time not in (0, float('inf')) else 0 + return results + + +def benchmark_precision( + m: int, k: int, n: int, block_size: int = 128, device="cuda" +) -> dict: + """Benchmark numerical precision of different implementations""" + + # Create test tensors + a_ref, b_ref, a_fp8, a_scale, b_fp8, b_scale = create_test_tensors( + m, k, n, block_size, device + ) + + # Reference computation + ref_output = torch.nn.functional.linear(a_ref, b_ref) + + # Triton blockwise fp8 computation + triton_output = blockwise_fp8_gemm_1x128_128x128( + a_fp8, 1.0 / a_scale, b_fp8, 1.0 / b_scale, block_size + ) + + results = { + "m": m, "k": k, "n": n, "block_size": block_size, + "triton_error_db": compute_error(ref_output, triton_output), + } + + # torch._scaled_mm precision (native blockwise scaling) + try: + scaled_mm_output = blockwise_fp8_scaled_mm_1x128_128x128_reference( + a_fp8, a_scale, b_fp8, b_scale + ) + results["scaled_mm_error_db"] = compute_error(ref_output, scaled_mm_output) + except Exception as e: + print(f"Warning: torch._scaled_mm native blockwise precision test failed: {e}") + print(f"Note: Requires CUDA 12.9+ for native blockwise scaling support") + results["scaled_mm_error_db"] = float('inf') + + return results + + +def benchmark_precision_128x1( + m: int, k: int, n: int, block_size: int = 128, device="cuda" +) -> dict: + """Precision benchmark for 1x128 x 128x1.""" + a_ref, b_ref, a_fp8, a_scale, b_fp8, b_scale = create_test_tensors_128x1( + m, k, n, block_size, device + ) + + ref_output = torch.nn.functional.linear(a_ref, b_ref) + + results = {"m": m, "k": k, "n": n, "block_size": block_size, "case": "1x128_128x1"} + + # Triton + triton_output = blockwise_fp8_gemm_1x128_128x1( + a_fp8, 1.0 / a_scale, b_fp8, 1.0 / b_scale, block_size + ) + + from torchao.float8.float8_utils import compute_error + results["triton_error_db"] = compute_error(ref_output, triton_output) + + # Native torch._scaled_mm + try: + scaled_mm_output = blockwise_fp8_scaled_mm_1x128_128x1_reference( + a_fp8, a_scale, b_fp8, b_scale, block_size + ) + results["scaled_mm_error_db"] = compute_error(ref_output, scaled_mm_output) + except Exception as e: + print(f"Warning: torch._scaled_mm native blockwise 128x1 precision failed: {e}") + 
results["scaled_mm_error_db"] = float('inf') + + return results + + +def run_benchmarks(): + """Run comprehensive benchmarks""" + if not is_sm_at_least_90(): + print("Warning: This benchmark requires SM90 or higher for optimal performance") + + # Test configurations - various matrix sizes commonly used in LLMs + test_configs = [ + # (M, K, N) - batch_size x hidden_dim x output_dim + (1, 4096, 4096), # Single token + (32, 4096, 4096), # Small batch + (128, 4096, 4096), # Medium batch + (1, 4096, 11008), # MLP up projection + (32, 4096, 11008), # MLP up projection, batched + (1, 11008, 4096), # MLP down projection + (32, 11008, 4096), # MLP down projection, batched + (1, 4096, 128256), # Vocab projection + (32, 4096, 128256), # Vocab projection, batched + ] + + print("Running performance benchmarks...") + perf_results = [] + for m, k, n in tqdm(test_configs): + if k % 128 == 0 and n % 128 == 0: # Ensure divisibility by block size + try: + result = benchmark_gemm_variants(m, k, n) + perf_results.append(result) + except Exception as e: + print(f"Error benchmarking {m}x{k}x{n}: {e}") + try: + result_128x1 = benchmark_gemm_variants_128x1(m, k, n) + perf_results.append(result_128x1) + except Exception as e: + print(f"Error benchmarking 128x1 {m}x{k}x{n}: {e}") + + print("Running precision benchmarks...") + precision_results = [] + for m, k, n in tqdm(test_configs): + if k % 128 == 0 and n % 128 == 0: + try: + result = benchmark_precision(m, k, n) + precision_results.append(result) + except Exception as e: + print(f"Error in precision test {m}x{k}x{n}: {e}") + try: + result_128x1 = benchmark_precision_128x1(m, k, n) + precision_results.append(result_128x1) + except Exception as e: + print(f"Error in 128x1 precision test {m}x{k}x{n}: {e}") + + # Save and display results + if perf_results: + perf_df = pd.DataFrame(perf_results) + perf_df.to_csv("triton_vs_scaled_mm_performance.csv", index=False) + print("\nPerformance Results:") + print(perf_df.to_markdown(index=False)) + + if precision_results: + precision_df = pd.DataFrame(precision_results) + precision_df.to_csv("triton_vs_scaled_mm_precision.csv", index=False) + print("\nPrecision Results:") + print(precision_df.to_markdown(index=False)) + + +if __name__ == "__main__": + if torch.cuda.is_available(): + run_benchmarks() + else: + print("CUDA not available. Skipping benchmarks.") diff --git a/test/prototype/blockwise_fp8_training/test_scaled_mm_kernels.py b/test/prototype/blockwise_fp8_training/test_scaled_mm_kernels.py new file mode 100644 index 0000000000..1690fc5764 --- /dev/null +++ b/test/prototype/blockwise_fp8_training/test_scaled_mm_kernels.py @@ -0,0 +1,280 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest +import torch +from packaging import version + +from torchao.float8.float8_utils import compute_error +from torchao.utils import is_sm_at_least_90 + +triton = pytest.importorskip("triton", reason="Triton required to run this test") +if not is_sm_at_least_90(): + pytest.skip("This test requires SM90 or higher", allow_module_level=True) + +from torchao.prototype.blockwise_fp8_training.kernels import ( + blockwise_fp8_gemm_1x128_128x128, + blockwise_fp8_gemm_1x128_128x1, + fp8_blockwise_act_quant_lhs, + fp8_blockwise_act_quant_rhs, + fp8_blockwise_act_quant_transposed_lhs, + fp8_blockwise_weight_quant_rhs, + fp8_blockwise_weight_quant_transposed_rhs, +) +from torchao.prototype.blockwise_fp8_training.scaled_mm_kernels import ( + blockwise_fp8_gemm_scaled_mm_1x128_128x128, + blockwise_fp8_gemm_scaled_mm_1x128_128x1, + blockwise_fp8_scaled_mm_1x128_128x128, + blockwise_fp8_scaled_mm_1x128_128x1, +) +from torchao.prototype.blockwise_fp8_training.linear import ( + Float8BlockwiseLinear, + Float8BlockwiseLinearConfig, +) + +# Test matrix sizes covering various common LLM dimensions +SCALED_MM_TEST_SIZES = [ + (128, 128, 128), + (2, 512, 128), + (4, 4096, 4096), + (8, 4096, 11008), + (16, 11008, 4096), + (1, 4096, 128256), +] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + version.parse(triton.__version__) < version.parse("3.3.0"), + reason="Triton version < 3.3.0, test skipped", +) +@pytest.mark.parametrize("M, N, K", SCALED_MM_TEST_SIZES) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +def test_blockwise_fp8_scaled_mm_1x128_128x128_correctness(M, N, K, dtype): + """Test correctness of native torch._scaled_mm blockwise scaling vs Triton kernel.""" + if K % 128 != 0 or N % 128 != 0: + pytest.skip(f"Dimensions K={K}, N={N} must be divisible by 128") + + device = torch.device("cuda") + block_size = 128 + + # Create high-precision reference tensors + a_ref = torch.randn(M, K, device=device, dtype=torch.bfloat16) + b_ref = torch.randn(K, N, device=device, dtype=torch.bfloat16) + + # Quantize inputs using the same quantization functions + a_fp8, a_scale = fp8_blockwise_act_quant_lhs(a_ref, block_size) + b_fp8, b_scale = fp8_blockwise_weight_quant_transposed_rhs(b_ref, block_size) + + # Compute using Triton kernel + triton_output = blockwise_fp8_gemm_1x128_128x128( + a_fp8, + 1.0 / a_scale, + b_fp8, + 1.0 / b_scale, + ) + + # Compute using native torch._scaled_mm with blockwise scaling + scaled_mm_output = blockwise_fp8_gemm_scaled_mm_1x128_128x128( + a_fp8, + 1.0 / a_scale, + b_fp8, + 1.0 / b_scale, + block_size, + ) + + # Compare results - native blockwise scaling should be close to Triton + error_db = compute_error(triton_output, scaled_mm_output) + print(f"Error between Triton and native torch._scaled_mm (dB): {error_db}") + + # With native blockwise scaling, should have similar accuracy to Triton + assert error_db > -60, f"Error too large: {error_db} dB (expected reasonable accuracy with native blockwise scaling)" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + version.parse(triton.__version__) < version.parse("3.3.0"), + reason="Triton version < 3.3.0, test skipped", +) +@pytest.mark.parametrize("M, N, K", SCALED_MM_TEST_SIZES) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +def test_blockwise_fp8_scaled_mm_1x128_128x1_correctness(M, N, K, dtype): + """Test correctness of native torch._scaled_mm blockwise scaling vs Triton kernel for 128x1 
scaling.""" + if K % 128 != 0: + pytest.skip(f"Dimension K={K} must be divisible by 128") + + device = torch.device("cuda") + block_size = 128 + + # Create high-precision reference tensors + a_ref = torch.randn(M, K, device=device, dtype=torch.bfloat16) + b_ref = torch.randn(K, N, device=device, dtype=torch.bfloat16) + + # Quantize inputs - note different scaling pattern for this variant + a_fp8, a_scale = fp8_blockwise_act_quant_transposed_lhs(a_ref, block_size) + b_fp8, b_scale = fp8_blockwise_act_quant_rhs(b_ref, block_size) + + # Compute using Triton kernel + triton_output = blockwise_fp8_gemm_1x128_128x1( + a_fp8, + 1.0 / a_scale, + b_fp8, + 1.0 / b_scale, + block_size, + ) + + # Compute using native torch._scaled_mm with blockwise scaling + scaled_mm_output = blockwise_fp8_gemm_scaled_mm_1x128_128x1( + a_fp8, + 1.0 / a_scale, + b_fp8, + 1.0 / b_scale, + block_size, + ) + + # Compare results - native blockwise scaling should be close to Triton + error_db = compute_error(triton_output, scaled_mm_output) + print(f"Error between Triton and native torch._scaled_mm 128x1 (dB): {error_db}") + + # With native blockwise scaling, should have similar accuracy to Triton + assert error_db > -60, f"Error too large: {error_db} dB (expected reasonable accuracy with native blockwise scaling)" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("use_scaled_mm", [False, True]) +@pytest.mark.parametrize("M, N, K", [(4, 4096, 4096), (8, 4096, 11008)]) +def test_float8_blockwise_linear_forward_backward(use_scaled_mm, M, N, K): + """Test forward and backward passes with both Triton and scaled_mm backends.""" + if K % 128 != 0 or N % 128 != 0: + pytest.skip(f"Dimensions K={K}, N={N} must be divisible by 128") + + device = torch.device("cuda") + + # Create reference linear layer + ref_layer = torch.nn.Linear(K, N, bias=False, device=device, dtype=torch.bfloat16) + + # Create blockwise fp8 layer + test_layer = Float8BlockwiseLinear.from_float(ref_layer, use_scaled_mm=use_scaled_mm) + + # Create input + x = torch.randn(M, K, device=device, dtype=torch.bfloat16, requires_grad=True) + x_ref = x.clone().detach().requires_grad_(True) + + # Forward pass + y_ref = ref_layer(x_ref) + y_test = test_layer(x) + + # Check forward pass shapes + assert y_test.shape == y_ref.shape + + # Backward pass + grad_output = torch.randn_like(y_test) + + y_ref.backward(grad_output) + y_test.backward(grad_output.clone()) + + # Check gradient shapes + assert x.grad.shape == x_ref.grad.shape + assert test_layer.weight.grad.shape == ref_layer.weight.grad.shape + + print(f"Forward error (dB): {compute_error(y_ref, y_test)}") + print(f"Input gradient error (dB): {compute_error(x_ref.grad, x.grad)}") + print(f"Weight gradient error (dB): {compute_error(ref_layer.weight.grad, test_layer.weight.grad)}") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_native_scaled_mm_vs_triton_accuracy(): + """Test that native torch._scaled_mm blockwise scaling matches Triton kernel accuracy.""" + device = torch.device("cuda") + M, K, N = 256, 1024, 512 # Divisible by 128 + block_size = 128 + + # Create test tensors + a_ref = torch.randn(M, K, device=device, dtype=torch.bfloat16) + b_ref = torch.randn(K, N, device=device, dtype=torch.bfloat16) + + # Quantize + a_fp8, a_scale = fp8_blockwise_act_quant_lhs(a_ref, block_size) + b_fp8, b_scale = fp8_blockwise_weight_quant_transposed_rhs(b_ref, block_size) + + # Native torch._scaled_mm implementation with 
blockwise scaling
+    scaled_mm_output = blockwise_fp8_scaled_mm_1x128_128x128(
+        a_fp8, 1.0 / a_scale, b_fp8, 1.0 / b_scale, block_size
+    )
+
+    # Triton reference
+    triton_output = blockwise_fp8_gemm_1x128_128x128(
+        a_fp8, 1.0 / a_scale, b_fp8, 1.0 / b_scale
+    )
+
+    # Check shapes
+    assert scaled_mm_output.shape == triton_output.shape
+
+    # Compare accuracy - native blockwise scaling should be very close to Triton
+    # The main difference will be due to different computation order, not algorithmic differences
+    triton_error = compute_error(triton_output, scaled_mm_output)
+    print(f"Triton vs native torch._scaled_mm blockwise error (dB): {triton_error}")
+
+    # With native blockwise scaling, should have similar accuracy to Triton
+    # Allow some difference due to different kernel implementations but should be close
+    assert triton_error > -60, f"Error too large: {triton_error} dB (expected reasonable accuracy with native blockwise scaling)"
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_config_integration():
+    """Test integration with configuration system."""
+    device = torch.device("cuda")
+
+    # Test both backends via config
+    ref_layer = torch.nn.Linear(512, 1024, bias=False, device=device, dtype=torch.bfloat16)
+
+    # Test Triton backend
+    triton_config = Float8BlockwiseLinearConfig(use_scaled_mm=False)
+    triton_layer = Float8BlockwiseLinear.from_float(ref_layer, use_scaled_mm=triton_config.use_scaled_mm)
+
+    # Test scaled_mm backend
+    scaled_mm_config = Float8BlockwiseLinearConfig(use_scaled_mm=True)
+    scaled_mm_layer = Float8BlockwiseLinear.from_float(ref_layer, use_scaled_mm=scaled_mm_config.use_scaled_mm)
+
+    # Test forward passes
+    x = torch.randn(4, 512, device=device, dtype=torch.bfloat16)
+
+    y_triton = triton_layer(x)
+    y_scaled_mm = scaled_mm_layer(x)
+
+    assert y_triton.shape == y_scaled_mm.shape
+    assert not triton_layer.use_scaled_mm
+    assert scaled_mm_layer.use_scaled_mm
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_error_conditions():
+    """Test various error conditions and edge cases."""
+    device = torch.device("cuda")
+
+    # Test unsupported block sizes (fp8 tensors are created via .to(), since
+    # torch.randn does not support float8 dtypes directly)
+    with pytest.raises(AssertionError, match="Only block_size=128 is supported"):
+        blockwise_fp8_scaled_mm_1x128_128x128(
+            torch.randn(128, 256, device=device).to(torch.float8_e4m3fn),
+            torch.randn(128, 2, device=device, dtype=torch.float32),
+            torch.randn(256, 128, device=device).to(torch.float8_e4m3fn),
+            torch.randn(2, 1, device=device, dtype=torch.float32),
+            block_size=64,  # Unsupported
+        )
+
+    # Test tensor shape mismatches
+    with pytest.raises((RuntimeError, AssertionError)):
+        blockwise_fp8_scaled_mm_1x128_128x128(
+            torch.randn(128, 256, device=device).to(torch.float8_e4m3fn),
+            torch.randn(128, 2, device=device, dtype=torch.float32),
+            torch.randn(512, 128, device=device).to(torch.float8_e4m3fn),  # Wrong K dim
+            torch.randn(4, 1, device=device, dtype=torch.float32),
+            block_size=128,
+        )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
\ No newline at end of file
diff --git a/torchao/prototype/blockwise_fp8_training/linear.py b/torchao/prototype/blockwise_fp8_training/linear.py
index b32f3c0073..16d7e932d1 100644
--- a/torchao/prototype/blockwise_fp8_training/linear.py
+++ b/torchao/prototype/blockwise_fp8_training/linear.py
@@ -17,6 +17,10 @@
     fp8_blockwise_weight_quant_rhs,
     fp8_blockwise_weight_quant_transposed_rhs,
 )
+from torchao.prototype.blockwise_fp8_training.scaled_mm_kernels import (
+
blockwise_fp8_gemm_scaled_mm_1x128_128x1, + blockwise_fp8_gemm_scaled_mm_1x128_128x128, +) from torchao.quantization.transform_module import ( register_quantize_module_handler, ) @@ -25,7 +29,7 @@ class fp8_blockwise_mm(torch.autograd.Function): @staticmethod - def forward(ctx, x, weight, block_size): + def forward(ctx, x, weight, block_size, use_scaled_mm=False): assert block_size == 128, "Only support block_size=128" # Temporarily reshape x to 2D tensor @@ -42,21 +46,32 @@ def forward(ctx, x, weight, block_size): ) # out = input @ weight.T - out = blockwise_fp8_gemm_1x128_128x128( - x_fp8, - 1.0 / x_scale, - weight_t_fp8, - 1.0 / weight_t_scale, - ) + if use_scaled_mm: + out = blockwise_fp8_gemm_scaled_mm_1x128_128x128( + x_fp8, + 1.0 / x_scale, + weight_t_fp8, + 1.0 / weight_t_scale, + block_size, + ) + else: + out = blockwise_fp8_gemm_1x128_128x128( + x_fp8, + 1.0 / x_scale, + weight_t_fp8, + 1.0 / weight_t_scale, + ) out = out.reshape(*x_orig_shape[:-1], out.shape[-1]) ctx.save_for_backward(x, weight) ctx.block_size = block_size + ctx.use_scaled_mm = use_scaled_mm return out @staticmethod def backward(ctx, grad_output): x, weight = ctx.saved_tensors block_size = ctx.block_size + use_scaled_mm = ctx.use_scaled_mm # Reshape input to 2D x_orig_shape = x.shape @@ -80,12 +95,21 @@ def backward(ctx, grad_output): ) # grad_x = grad_output @ weight - grad_x = blockwise_fp8_gemm_1x128_128x128( - grad_output_fp8, - 1.0 / grad_output_scale, - weight_fp8, - 1.0 / weight_scale, - ) + if use_scaled_mm: + grad_x = blockwise_fp8_gemm_scaled_mm_1x128_128x128( + grad_output_fp8, + 1.0 / grad_output_scale, + weight_fp8, + 1.0 / weight_scale, + block_size, + ) + else: + grad_x = blockwise_fp8_gemm_1x128_128x128( + grad_output_fp8, + 1.0 / grad_output_scale, + weight_fp8, + 1.0 / weight_scale, + ) # Cast grad_output_t to fp8 blockwise with (1 x block_size) scaling groups, since it is # the grad of the output activation. @@ -101,12 +125,21 @@ def backward(ctx, grad_output): x_fp8, x_scale = fp8_blockwise_act_quant_rhs(x, block_size) # grad_weight = grad_output.T @ x - grad_weight = blockwise_fp8_gemm_1x128_128x1( - grad_output_t_fp8, - 1.0 / grad_output_t_scale, - x_fp8, - 1.0 / x_scale, - ) + if use_scaled_mm: + grad_weight = blockwise_fp8_gemm_scaled_mm_1x128_128x1( + grad_output_t_fp8, + 1.0 / grad_output_t_scale, + x_fp8, + 1.0 / x_scale, + block_size, + ) + else: + grad_weight = blockwise_fp8_gemm_1x128_128x1( + grad_output_t_fp8, + 1.0 / grad_output_t_scale, + x_fp8, + 1.0 / x_scale, + ) # Reshape grad_x to expected potentially 3D+ shape grad_x = grad_x.reshape(*grad_output_orig_shape[:-1], grad_x.shape[-1]) @@ -122,7 +155,9 @@ class Float8BlockwiseLinear(nn.Linear): out_features (int): Number of output features. bias (bool): Whether to include a bias term. Defaults to False. block_size (int): Block size for quantization. Defaults to 128. - dtype (torch.dtype): Data type for the weights. Defaults to torch.float8_e4m3fn. + dtype (torch.dtype): Data type for the weights. Defaults to torch.bfloat16. + use_scaled_mm (bool): Whether to use torch._scaled_mm instead of Triton kernels. + Defaults to False. 
""" supported_dtypes = [ @@ -134,6 +169,7 @@ def __init__( *args, block_size: int = 128, dtype=torch.bfloat16, + use_scaled_mm: bool = False, **kwargs, ): super().__init__(*args, **kwargs) @@ -144,6 +180,7 @@ def __init__( assert is_sm_at_least_90(), "Only support SM90" self.block_size = block_size self.dtype = dtype + self.use_scaled_mm = use_scaled_mm def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -155,12 +192,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Transformed tensor after linear computation. """ - return fp8_blockwise_mm.apply(x, self.weight, self.block_size) + return fp8_blockwise_mm.apply(x, self.weight, self.block_size, self.use_scaled_mm) @classmethod def from_float( cls, mod, + use_scaled_mm: bool = False, ): assert mod.bias is None, "unsupported" assert mod.in_features % 128 == 0, "unsupported" @@ -170,6 +208,7 @@ def from_float( mod.in_features, mod.out_features, bias=False, + use_scaled_mm=use_scaled_mm, ) new_mod.weight = mod.weight new_mod.bias = mod.bias @@ -177,9 +216,14 @@ def from_float( class Float8BlockwiseLinearConfig(AOBaseConfig): - pass + """Configuration for Float8BlockwiseLinear quantization.""" + + def __init__(self, use_scaled_mm: bool = False): + self.use_scaled_mm = use_scaled_mm @register_quantize_module_handler(Float8BlockwiseLinearConfig) def _float8_blockwise_transform(module, config): - return Float8BlockwiseLinear.from_float(module) + return Float8BlockwiseLinear.from_float( + module, use_scaled_mm=config.use_scaled_mm + ) diff --git a/torchao/prototype/blockwise_fp8_training/scaled_mm_kernels.py b/torchao/prototype/blockwise_fp8_training/scaled_mm_kernels.py new file mode 100644 index 0000000000..2a828744c2 --- /dev/null +++ b/torchao/prototype/blockwise_fp8_training/scaled_mm_kernels.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Implementation of blockwise fp8 GEMM operations using torch._scaled_mm native blockwise scaling. + +This implementation leverages the native blockwise scaling support in torch._scaled_mm +available with CUDA 12.9+, providing optimal performance through direct CUTLASS kernel usage. + +Based on PyTorch's native support for ScalingType.BlockWise128x128 and other blockwise modes, +this avoids the uncoalesced memory access issues present in custom Triton kernels. +""" + +from typing import Tuple + +import torch +import warnings + + + +def _check_cuda_version_for_native_blockwise(): + """Check if CUDA version supports native blockwise scaling in torch._scaled_mm.""" + try: + # Check if we're running with CUDA 12.9+ + cuda_version = torch.version.cuda + if cuda_version is None: + return False + + major, minor = map(int, cuda_version.split(".")[:2]) + return major > 12 or (major == 12 and minor >= 9) + except: + return False + + +def _outer_dim_major(t: torch.Tensor) -> torch.Tensor: + """Ensure a 2D scale tensor is outer-dim-major (stride(0) == 1). + + PyTorch's native blockwise scaled GEMM expects 1x128 scales to be + outer-dim-major. The idiom `t.t().contiguous().t()` preserves shape + while flipping strides to make the outer dimension contiguous. 
+ """ + if t.ndim != 2: + return t + # Already outer-dim-major if stride(0) == 1 + if t.stride(0) == 1: + return t + return t.t().contiguous().t() + + +def blockwise_fp8_scaled_mm_1x128_128x128( + a: torch.Tensor, # (M, K) in fp8 + a_s: torch.Tensor, # (M, K // block_size) reciprocals of scales + b: torch.Tensor, # (K, N) in fp8, column-major + b_s: torch.Tensor, # (K // block_size, N // block_size) reciprocals of scales + block_size: int = 128, +) -> torch.Tensor: + """ + Blockwise fp8 GEMM using torch._scaled_mm with native blockwise scaling when available. + + This implementation attempts to use native blockwise scaling support in torch._scaled_mm + with CUDA 12.9+. Falls back to block-by-block processing if native support is unavailable. + + Args: + a: Input tensor (M, K) in fp8, row-major + a_s: Input scales (M, K // block_size), reciprocals (will be inverted) + b: Weight tensor (K, N) in fp8, column-major layout + b_s: Weight scales (K // block_size, N // block_size), reciprocals (will be inverted) + block_size: Block size for quantization (must be 128) + + Returns: + Output tensor (M, N) in bfloat16 + """ + assert block_size == 128, "Only block_size=128 is supported" + assert a.dtype == torch.float8_e4m3fn, f"Input a must be fp8_e4m3fn, got {a.dtype}" + assert b.dtype == torch.float8_e4m3fn, f"Input b must be fp8_e4m3fn, got {b.dtype}" + + # Convert reciprocal scales back to regular scales for torch._scaled_mm + scale_a = 1.0 / a_s + scale_b = 1.0 / b_s + + # For 1x128 on LHS, scales must be outer-dim-major (see PyTorch test_matmul_cuda.py) + scale_a = _outer_dim_major(scale_a) + + # Try native blockwise scaling first (requires CUDA 12.9+) + if _check_cuda_version_for_native_blockwise(): + try: + # Use native blockwise scaling with torch._scaled_mm + # This should dispatch to the CUTLASS kernel with native blockwise support + return torch._scaled_mm( + a, # (M, K) fp8, row-major + b, # (K, N) fp8, column-major - torch._scaled_mm should handle layout + scale_a=scale_a, # (M, K // 128) blockwise scales for input + scale_b=scale_b, # (K // 128, N // 128) blockwise scales for weight + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + except Exception as e: + warnings.warn( + f"Native blockwise scaling failed: {e}. Falling back to block-by-block processing. " + f"For optimal performance, ensure CUDA 12.9+ and compatible PyTorch version.", + RuntimeWarning + ) + + # Fallback: block-by-block processing to emulate blockwise scaling + # This preserves the blockwise precision but may be slower than native implementation + return _blockwise_fp8_scaled_mm_fallback_1x128_128x128(a, scale_a, b, scale_b, block_size) + + +def _blockwise_fp8_scaled_mm_fallback_1x128_128x128( + a: torch.Tensor, + scale_a: torch.Tensor, + b: torch.Tensor, + scale_b: torch.Tensor, + block_size: int = 128, +) -> torch.Tensor: + """ + Fallback implementation using block-by-block torch._scaled_mm calls. + + This emulates blockwise scaling by processing the computation in blocks, + preserving the precision benefits while remaining compatible with older CUDA versions. 
+ """ + M, K = a.size() + N = b.size(1) + + k_blocks = K // block_size + n_blocks = N // block_size + + # Initialize output + output = torch.zeros(M, N, dtype=torch.bfloat16, device=a.device) + + # Process each (K_block, N_block) tile separately to preserve blockwise scaling + for k_idx in range(k_blocks): + k_start = k_idx * block_size + k_end = k_start + block_size + + # Extract K-block from inputs + a_block = a[:, k_start:k_end].contiguous() # (M, block_size) + a_scale_block = scale_a[:, k_idx : k_idx + 1] # (M, 1) + + for n_idx in range(n_blocks): + n_start = n_idx * block_size + n_end = n_start + block_size + + # Extract (K_block, N_block) from b + b_block = b[k_start:k_end, n_start:n_end].contiguous() # (block_size, block_size) + b_scale_block = scale_b[k_idx : k_idx + 1, n_idx : n_idx + 1] # (1, 1) + + # Compute this block's contribution using torch._scaled_mm + block_output = torch._scaled_mm( + a_block, # (M, block_size) + b_block, # (block_size, block_size) + scale_a=a_scale_block, # (M, 1) + scale_b=b_scale_block, # (1, 1) + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + + # Accumulate into output + output[:, n_start:n_end] += block_output + + return output + + +def blockwise_fp8_scaled_mm_1x128_128x1( + a: torch.Tensor, # (M, K) in fp8 + a_s: torch.Tensor, # (M, K // block_size) reciprocals of scales + b: torch.Tensor, # (K, N) in fp8, column-major + b_s: torch.Tensor, # (K // block_size, N) reciprocals of scales + block_size: int = 128, +) -> torch.Tensor: + """ + Blockwise fp8 GEMM for backward pass using torch._scaled_mm with native scaling when available. + + This variant is used when B has (128 x 1) scaling granularity, corresponding + to PyTorch's native ScalingType.BlockWise1x128 support. + + Args: + a: Input tensor (M, K) in fp8, row-major + a_s: Input scales (M, K // block_size), reciprocals (will be inverted) + b: Weight tensor (K, N) in fp8, column-major layout + b_s: Weight scales (K // block_size, N), reciprocals (will be inverted) + block_size: Block size for quantization (must be 128) + + Returns: + Output tensor (M, N) in bfloat16 + """ + assert block_size == 128, "Only block_size=128 is supported" + assert a.dtype == torch.float8_e4m3fn, f"Input a must be fp8_e4m3fn, got {a.dtype}" + assert b.dtype == torch.float8_e4m3fn, f"Input b must be fp8_e4m3fn, got {b.dtype}" + + # Convert reciprocal scales back to regular scales for torch._scaled_mm + scale_a = 1.0 / a_s + scale_b = 1.0 / b_s + + # For 1x128 on LHS and 128x1 on RHS, scales must be outer-dim-major + # Ref: PyTorch test_matmul_cuda.py::test_scaled_mm_vs_emulated_block_wise + scale_a = _outer_dim_major(scale_a) + scale_b = _outer_dim_major(scale_b) + + # Try native blockwise scaling first (requires CUDA 12.9+) + if _check_cuda_version_for_native_blockwise(): + try: + # Use native blockwise scaling with torch._scaled_mm + # This uses BlockWise1x128 scaling for the weight tensor + return torch._scaled_mm( + a, # (M, K) fp8, row-major + b, # (K, N) fp8, column-major - torch._scaled_mm should handle layout + scale_a=scale_a, # (M, K // 128) blockwise scales for input + scale_b=scale_b, # (K // 128, N) blockwise scales for weight (128x1 scaling) + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + except Exception as e: + warnings.warn( + f"Native blockwise scaling failed: {e}. Falling back to block-by-block processing. 
" + f"For optimal performance, ensure CUDA 12.9+ and compatible PyTorch version.", + RuntimeWarning + ) + + # Fallback: block-by-block processing to emulate blockwise scaling + return _blockwise_fp8_scaled_mm_fallback_1x128_128x1(a, scale_a, b, scale_b, block_size) + + +def _blockwise_fp8_scaled_mm_fallback_1x128_128x1( + a: torch.Tensor, + scale_a: torch.Tensor, + b: torch.Tensor, + scale_b: torch.Tensor, + block_size: int = 128, +) -> torch.Tensor: + """ + Fallback implementation for 128x1 scaling using block-by-block processing. + """ + M, K = a.size() + N = b.size(1) + k_blocks = K // block_size + + # Initialize output + output = torch.zeros(M, N, dtype=torch.bfloat16, device=a.device) + + # Process each K-block separately to preserve blockwise scaling + for k_idx in range(k_blocks): + k_start = k_idx * block_size + k_end = k_start + block_size + + # Extract K-block from inputs + a_block = a[:, k_start:k_end].contiguous() # (M, block_size) + a_scale_block = scale_a[:, k_idx : k_idx + 1] # (M, 1) + + b_block = b[k_start:k_end, :].contiguous() # (block_size, N) + b_scale_block = scale_b[k_idx : k_idx + 1, :] # (1, N) + + # Compute this block's contribution using torch._scaled_mm + block_output = torch._scaled_mm( + a_block, # (M, block_size) + b_block, # (block_size, N) + scale_a=a_scale_block, # (M, 1) + scale_b=b_scale_block, # (1, N) + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + + # Accumulate into output + output += block_output + + return output + + +# Convenience wrapper functions to match the Triton kernel interface +def blockwise_fp8_gemm_scaled_mm_1x128_128x128( + a: torch.Tensor, + a_s: torch.Tensor, + b: torch.Tensor, + b_s: torch.Tensor, + block_size: int = 128, +) -> torch.Tensor: + """ + Wrapper function that matches the Triton kernel interface. + + Uses native torch._scaled_mm with blockwise scaling for optimal performance. + """ + return blockwise_fp8_scaled_mm_1x128_128x128(a, a_s, b, b_s, block_size) + + +def blockwise_fp8_gemm_scaled_mm_1x128_128x1( + a: torch.Tensor, + a_s: torch.Tensor, + b: torch.Tensor, + b_s: torch.Tensor, + block_size: int = 128, +) -> torch.Tensor: + """ + Wrapper function that matches the Triton kernel interface. + + Uses native torch._scaled_mm with blockwise scaling for optimal performance. + """ + return blockwise_fp8_scaled_mm_1x128_128x1(a, a_s, b, b_s, block_size)