diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py b/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py index 000b6d3326..9343cf2d5c 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py @@ -15,9 +15,9 @@ from triton.testing import do_bench from torchao.prototype.blockwise_fp8_training.kernels import ( - blockwise_fp8_gemm_1x128_128x128, fp8_blockwise_act_quant_lhs, fp8_blockwise_weight_quant_transposed_rhs, + triton_fp8_gemm_1x128_128x128, ) device = torch.device("cuda") @@ -58,7 +58,7 @@ def get_configs() -> List[ExperimentConfig]: (16640, 5120, 8192), (16640, 8192, 5120), ] - out_dtypes = [torch.float32, torch.bfloat16] + out_dtypes = [torch.bfloat16] configs = [] for mnk, out_dtype in itertools.product(mnk_list, out_dtypes): m, n, k = mnk @@ -94,19 +94,21 @@ def warmup(func, *args, **kwargs): # Warm up then run triton bench warmup( - blockwise_fp8_gemm_1x128_128x128, + triton_fp8_gemm_1x128_128x128, A_q, - 1.0 / A_s, B_t_q, + 1.0 / A_s, 1.0 / B_t_s, + out_dtype=config.out_dtype, ) fp8_triton_us = benchmark_cuda_function_in_microseconds( - blockwise_fp8_gemm_1x128_128x128, + triton_fp8_gemm_1x128_128x128, A_q, - 1.0 / A_s, B_t_q, + 1.0 / A_s, 1.0 / B_t_s, + out_dtype=config.out_dtype, ) # Warm up then run torch bench diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py b/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py index 6873ee2eae..d708c58856 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py @@ -15,9 +15,9 @@ from triton.testing import do_bench from torchao.prototype.blockwise_fp8_training.kernels import ( - blockwise_fp8_gemm_1x128_128x1, fp8_blockwise_act_quant_rhs, fp8_blockwise_act_quant_transposed_lhs, + triton_fp8_gemm_1x128_128x1, ) device = torch.device("cuda") @@ -58,7 +58,7 @@ def get_configs() -> List[ExperimentConfig]: (16640, 5120, 8192), (16640, 8192, 5120), ] - out_dtypes = [torch.float32, torch.bfloat16] + out_dtypes = [torch.bfloat16] configs = [] for mnk, out_dtype in itertools.product(mnk_list, out_dtypes): m, n, k = mnk @@ -92,24 +92,23 @@ def warmup(func, *args, **kwargs): # Warm up then run triton bench warmup( - blockwise_fp8_gemm_1x128_128x1, + triton_fp8_gemm_1x128_128x1, A_t_q, - 1.0 / A_t_s, B_q, + 1.0 / A_t_s, 1.0 / B_s, + out_dtype=config.out_dtype, ) fp8_triton_us = benchmark_cuda_function_in_microseconds( - blockwise_fp8_gemm_1x128_128x1, + triton_fp8_gemm_1x128_128x1, A_t_q, - 1.0 / A_t_s, B_q, + 1.0 / A_t_s, 1.0 / B_s, + out_dtype=config.out_dtype, ) - # torch._scaled_mm requires A_s and B_t_s be in column-major format - A_t_s = A_t_s.t().contiguous().t() - # Warm up then run torch bench warmup( torch._scaled_mm, diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py b/benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py new file mode 100644 index 0000000000..e8a4785624 --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py + +import itertools +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from torch.nn import functional as F +from tqdm import tqdm +from triton.testing import do_bench + +from benchmarks.utils import bench_fwd_bwd_microseconds +from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear + +device = torch.device("cuda") + +# This benchmark requires CUDA 12.9+ +assert torch.version.cuda is not None, "CUDA is not available" +cuda_major, cuda_minor = map(int, torch.version.cuda.split(".")) +assert cuda_major >= 12 and cuda_minor >= 9, "CUDA 12.9+ is required" + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + out_dtype: torch.dtype + m: int + n: int + k: int + + +@dataclass(frozen=True) +class ExperimentResult: + bf16_linear_us: float + fp8_triton_linear_us: float + fp8_scaled_mm_linear_us: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + mnk_list = [ + # Llama4 shapes + (16640, 5120, 8192), + (16640, 8192, 5120), + ] + out_dtypes = [torch.bfloat16] + configs = [] + for mnk, out_dtype in itertools.product(mnk_list, out_dtypes): + m, n, k = mnk + configs.append( + ExperimentConfig( + out_dtype=out_dtype, + m=m, + n=n, + k=k, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + M, N, K = config.m, config.n, config.k + inputs = torch.randn(M, K, dtype=config.out_dtype, device="cuda") + bf16_linear = torch.nn.Linear(K, N, dtype=config.out_dtype, device="cuda") + fp8_triton_linear = Float8BlockwiseLinear( + K, N, dtype=config.out_dtype, device="cuda", use_triton=True + ) + fp8_scaled_mm_linear = Float8BlockwiseLinear( + K, N, dtype=config.out_dtype, device="cuda", use_triton=False + ) + + def warmup(func, *args, **kwargs): + for _ in range(10): + func(*args, **kwargs) + + def fwd_bwd(func, inputs, labels, *args, **kwargs): + out = func(inputs, *args, **kwargs) + loss = F.mse_loss(out, labels) + loss.backward() + torch.cuda.synchronize() + + # Warmup then run bf16 torch.mm + labels = inputs.new_empty(M, N).fill_(1.0) + warmup(fwd_bwd, bf16_linear, inputs, labels) + + bf16_linear_us = benchmark_cuda_function_in_microseconds( + fwd_bwd, bf16_linear, inputs, labels + ) + + # Warm up then run triton bench + warmup( + fwd_bwd, + fp8_triton_linear, + inputs, + labels, + ) + + fp8_triton_linear_us = bench_fwd_bwd_microseconds( + fp8_triton_linear, + inputs, + labels=labels, + ) + + warmup( + fwd_bwd, + fp8_scaled_mm_linear, + inputs, + labels, + ) + + fp8_scaled_mm_linear_us = bench_fwd_bwd_microseconds( + fp8_scaled_mm_linear, + inputs, + labels=labels, + ) + + return ExperimentResult( + bf16_linear_us=bf16_linear_us, + fp8_triton_linear_us=fp8_triton_linear_us, + fp8_scaled_mm_linear_us=fp8_scaled_mm_linear_us, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "M", + "N", + "K", + "out_dtype", + "bf16_mm_linear_us", + "fp8_triton_linear_us", + "fp8_scaled_mm_linear_us", + ] + rows = [] + for experiment in experiments: + m, n, k = experiment.config.m, experiment.config.n, experiment.config.k + rows.append( + [ + m, + n, + k, + experiment.config.out_dtype, + 
experiment.result.bf16_linear_us, + experiment.result.fp8_triton_linear_us, + experiment.result.fp8_scaled_mm_linear_us, + ] + ) + print(tabulate(rows, headers=headers)) + + +def benchmark_cuda_function_in_microseconds(f, *args, **kwargs): + return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3 + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/prototype/moe_training/benchmark_moe_fsdp.py b/benchmarks/prototype/moe_training/benchmark_moe_fsdp.py index 1011d2609b..e9fbbdcd86 100644 --- a/benchmarks/prototype/moe_training/benchmark_moe_fsdp.py +++ b/benchmarks/prototype/moe_training/benchmark_moe_fsdp.py @@ -22,10 +22,7 @@ from torch.distributed._composable.fsdp import fully_shard from torch.nn import functional as F -from benchmarks.prototype.moe_training.utils import ( - bench_fwd_bwd_microseconds, - profile_fwd_bwd, -) +from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd # this feature requires CUDA and SM89+ if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9): diff --git a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py index c2d2b998f6..d365330bf2 100644 --- a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py +++ b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py @@ -12,8 +12,8 @@ import torch from tabulate import tabulate from tqdm import tqdm -from utils import bench_fwd_bwd_microseconds, profile_fwd_bwd +from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd from torchao.prototype.moe_training import _scaled_grouped_mm from torchao.prototype.moe_training.conversion_utils import MoEScalingType from torchao.prototype.moe_training.utils import generate_jagged_offs diff --git a/benchmarks/prototype/moe_training/utils.py b/benchmarks/utils.py similarity index 100% rename from benchmarks/prototype/moe_training/utils.py rename to benchmarks/utils.py diff --git a/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py b/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py index e8e855232c..63799aaaf7 100644 --- a/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py +++ b/test/prototype/blockwise_fp8_training/test_blockwise_kernels.py @@ -12,8 +12,6 @@ from packaging import version from torchao.float8.float8_utils import compute_error from torchao.prototype.blockwise_fp8_training.kernels import ( - blockwise_fp8_gemm_1x128_128x1, - blockwise_fp8_gemm_1x128_128x128, fp8_blockwise_act_quant_lhs, fp8_blockwise_act_quant_rhs, fp8_blockwise_act_quant_transposed_lhs, @@ -22,12 +20,14 @@ torch_blockwise_scale_act_quant_lhs, torch_blockwise_scale_act_quant_rhs, torch_blockwise_scale_weight_quant, + triton_fp8_gemm_1x128_128x1, + triton_fp8_gemm_1x128_128x128, ) from torchao.testing.utils import skip_if_rocm from torchao.utils import is_sm_at_least_90 BLOCKWISE_SIZE_MNK = [ - (128, 128, 128), + # (128, 128, 128), (2, 512, 128), (2, 5120, 1280), (3, 2048, 2048), @@ -46,14 +46,16 @@ ) @pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype): +def 
test_triton_fp8_gemm_1x128_128x128(M, N, K, dtype): # Simulate output = input @ weight.T A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") C = A @ B.T A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=dtype) B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=dtype) - C_q = blockwise_fp8_gemm_1x128_128x128(A_q, 1.0 / A_s, B_t_q, 1.0 / B_t_s) + C_q = triton_fp8_gemm_1x128_128x128( + A_q, B_t_q, A_s, B_t_s, out_dtype=torch.bfloat16 + ) assert not C_q.isnan().any(), "C_q must not contain NaNs" sqnr = compute_error(C, C_q) @@ -69,14 +71,14 @@ def test_blockwise_fp8_gemm_1x128_128x128(M, N, K, dtype): ) @pytest.mark.parametrize("M, N, K", BLOCKWISE_SIZE_MNK) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) -def test_blockwise_fp8_gemm_1x128_128x1(M, N, K, dtype): +def test_triton_fp8_gemm_1x128_128x1(M, N, K, dtype): # Simulate grad_weight = grad_output_t @ input A = torch.randn(K, M, dtype=torch.bfloat16, device="cuda") B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda") C = A.T @ B A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=dtype) B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=dtype) - C_q = blockwise_fp8_gemm_1x128_128x1(A_t_q, 1.0 / A_t_s, B_q, 1.0 / B_s) + C_q = triton_fp8_gemm_1x128_128x1(A_t_q, B_q, A_t_s, B_s, out_dtype=torch.bfloat16) assert not C_q.isnan().any(), "C_q must not contain NaNs" assert C.dtype == torch.bfloat16 @@ -99,13 +101,13 @@ def test_triton_quantize_fp8_act_quant_lhs(block_size): # quantized tensor will have NaNs due to division by 0 x[0, :block_size] = 0.0 - # Get the quantized tensor and scales using triton implementation + # Get the quantized tensor and reciprocal scales using triton implementation triton_fp8, triton_scale = fp8_blockwise_act_quant_lhs( x, block_size=block_size, ) - # Get the quantized tensor and scales using reference implementation + # Get the quantized tensor and reciprocal scales using reference implementation ref_fp8, ref_scale = torch_blockwise_scale_act_quant_lhs(x, tile_size=block_size) assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs" @@ -124,7 +126,7 @@ def test_triton_quantize_fp8_act_quant_lhs(block_size): msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}", ) - # Compare scales + # Compare reciprocal scales torch.testing.assert_close( triton_scale, ref_scale, @@ -146,13 +148,13 @@ def test_triton_quantize_fp8_act_quant_rhs(block_size: int): # quantized tensor will have NaNs due to division by 0 x[:block_size, :block_size] = 0.0 - # Get the quantized tensor and scales using triton implementation + # Get the quantized tensor and reciprocal scales using triton implementation triton_fp8, triton_scale = fp8_blockwise_act_quant_rhs( x, block_size=block_size, ) - # Get the quantized tensor and scales using reference implementation + # Get the quantized tensor and reciprocal scales using reference implementation ref_fp8, ref_scale = torch_blockwise_scale_act_quant_rhs(x, block_size=block_size) assert not triton_fp8.isnan().any(), "fp8 output must not contain NaNs" @@ -171,7 +173,7 @@ def test_triton_quantize_fp8_act_quant_rhs(block_size: int): msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}", ) - # Compare scales + # Compare reciprocal scales torch.testing.assert_close( triton_scale, ref_scale, @@ -193,13 +195,13 @@ def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int): # quantized tensor will have NaNs 
due to division by 0 x[0, :block_size] = 0.0 - # Get the quantized tensor and scales using triton implementation + # Get the quantized tensor and reciprocal scales using triton implementation triton_fp8, triton_scale = fp8_blockwise_act_quant_transposed_lhs( x, block_size=block_size, ) - # Get the quantized tensor and scales using reference implementation + # Get the quantized tensor and reciprocal scales using reference implementation ref_fp8, ref_scale = torch_blockwise_scale_act_quant_lhs( x.t().contiguous(), tile_size=block_size ) @@ -220,7 +222,7 @@ def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int): msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}", ) - # Compare scales + # Compare reciprocal scales torch.testing.assert_close( triton_scale, ref_scale, @@ -242,12 +244,12 @@ def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int): # quantized tensor will have NaNs due to division by 0 x[:block_size, :block_size] = 0.0 - # Get the quantized tensor and scales using triton implementation + # Get the quantized tensor and reciprocal scales using triton implementation triton_fp8, triton_scale = fp8_blockwise_weight_quant_rhs( x, block_size=block_size, ) - # Get the quantized tensor and scales using reference implementation + # Get the quantized tensor and reciprocal scales using reference implementation ref_fp8, ref_scale = torch_blockwise_scale_weight_quant(x, tile_size=block_size) assert not ref_fp8.isnan().any(), "fp8 output must not contain NaNs" @@ -266,7 +268,7 @@ def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int): msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}", ) - # Compare scales + # Compare reciprocal scales torch.testing.assert_close( triton_scale, ref_scale, @@ -289,12 +291,12 @@ def test_triton_quantize_fp8_weight_quant_transposed_rhs(block_size: int): # quantized tensor will have NaNs due to division by 0 x[:block_size, :block_size] = 0.0 - # Get the quantized tensor and scales using triton implementation + # Get the quantized tensor and reciprocal scales using triton implementation triton_fp8, triton_scale = fp8_blockwise_weight_quant_transposed_rhs( x, block_size=block_size, ) - # Get the quantized tensor and scales using reference implementation + # Get the quantized tensor and reciprocal scales using reference implementation ref_fp8, ref_scale = torch_blockwise_scale_weight_quant( x.t().contiguous(), tile_size=block_size ) @@ -315,7 +317,7 @@ def test_triton_quantize_fp8_weight_quant_transposed_rhs(block_size: int): msg=f"Quantized tensors differ: max diff = {(triton_fp32 - ref_fp32).abs().max().item()}", ) - # Compare scales + # Compare reciprocal scales torch.testing.assert_close( triton_scale, ref_scale, diff --git a/torchao/prototype/blockwise_fp8_training/kernels.py b/torchao/prototype/blockwise_fp8_training/kernels.py index 515886ec1d..0ff0ace146 100644 --- a/torchao/prototype/blockwise_fp8_training/kernels.py +++ b/torchao/prototype/blockwise_fp8_training/kernels.py @@ -23,14 +23,7 @@ ) for block_size in [64, 128, 256] for num_warps in [4, 8] - for num_stages in [2, 4] -] - -# For fast compile times during development. 
-dev_fp8_gemm_configs = [ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_warps=4, num_stages=3 - ), + for num_stages in [2] ] EPS = 1e-12 @@ -38,7 +31,7 @@ @triton.autotune(configs=fp8_gemm_configs_max_autotune, key=["N", "K", "BLOCK_SIZE_K"]) @triton.jit -def blockwise_fp8_gemm_1x128_128x128_kernel( +def triton_fp8_gemm_1x128_128x128_kernel( a_ptr, # (M, K) a_stride_dim_0, a_stride_dim_1, @@ -102,23 +95,21 @@ def blockwise_fp8_gemm_1x128_128x128_kernel( tl.store(c_ptrs, c, mask=c_mask) -def blockwise_fp8_gemm_1x128_128x128( +def triton_fp8_gemm_1x128_128x128( a: torch.Tensor, # (M, K) - a_s: torch.Tensor, # (M, K // block_size) b: torch.Tensor, # (K, N) + a_s: torch.Tensor, # (M, K // block_size) b_s: torch.Tensor, # (K // block_size, N // block_size) block_size: int = 128, out_dtype: torch.dtype = torch.float32, ): # 'a' must be in row-major layout, 'b' must be in column-major layout - assert _is_row_major(a) and _is_column_major(b), ( - "a must be row-major, b must be column-major" - ) + assert _is_row_major(a), "a must be row-major" + assert _is_column_major(b), "b must be column-major" - # a_scales must be row-major, b_scales must be column-major - assert _is_row_major(a_s) and _is_column_major(b_s), ( - "a_s must be row-major, b_s must be column-major" - ) + # a_scales must be col-major, b_scales must be row-major + assert _is_column_major(a_s), "a_s must be column-major" + assert _is_column_major(b_s), "b_s must be column-major" M = a.size(0) K = a.size(1) @@ -128,7 +119,7 @@ def blockwise_fp8_gemm_1x128_128x128( triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - blockwise_fp8_gemm_1x128_128x128_kernel[grid]( + triton_fp8_gemm_1x128_128x128_kernel[grid]( a, a.stride(0), a.stride(1), @@ -157,7 +148,7 @@ def blockwise_fp8_gemm_1x128_128x128( configs=fp8_gemm_configs_max_autotune, key=["M", "N", "K", "BLOCK_SIZE_K"] ) @triton.jit -def blockwise_fp8_gemm_1x128_128x1_kernel( +def triton_fp8_gemm_1x128_128x1_kernel( a_ptr, # (M, K) a_stride_dim_0, a_stride_dim_1, @@ -219,17 +210,22 @@ def blockwise_fp8_gemm_1x128_128x1_kernel( tl.store(c_ptrs, c, mask=c_mask) -def blockwise_fp8_gemm_1x128_128x1( +def triton_fp8_gemm_1x128_128x1( a: torch.Tensor, # (M, K) - a_s: torch.Tensor, # (M, K // block_size) reciprocals of scales b: torch.Tensor, # (K, N) + a_s: torch.Tensor, # (M, K // block_size) reciprocals of scales b_s: torch.Tensor, # (K // block_size, N) reciprocals of scales block_size: int = 128, out_dtype: torch.dtype = torch.float32, ): # 'a' must be in row-major layout, 'b' must be in column-major layout - assert a.is_contiguous() and not b.is_contiguous() - assert a_s.is_contiguous() and b_s.is_contiguous() + assert _is_row_major(a), "a must be row-major" + assert _is_column_major(b), "b must be column-major" + + # a_scales must be col-major, b_scales must be row-major + assert _is_column_major(a_s), "a_s must be column-major" + assert _is_row_major(b_s), "b_s must be row-major" + M = a.size(0) K = a.size(1) N = b.size(1) @@ -238,7 +234,7 @@ def blockwise_fp8_gemm_1x128_128x1( triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - blockwise_fp8_gemm_1x128_128x1_kernel[grid]( + triton_fp8_gemm_1x128_128x1_kernel[grid]( a, a.stride(0), a.stride(1), @@ -260,6 +256,19 @@ def blockwise_fp8_gemm_1x128_128x1( return c +# Quantization kernels autotuner configs +quant_kernel_configs = [ + triton.Config( + {}, + num_warps=warps, + num_stages=stages, + ) + for warps in [4, 8] + for stages in [2, 4, 6] +] + + 
+@triton.autotune(configs=quant_kernel_configs, key=["K"]) @triton.jit def fp8_blockwise_act_quant_lhs_kernel( x_ptr, @@ -299,9 +308,9 @@ def fp8_blockwise_act_quant_lhs_kernel( y_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K) tl.store(y_ptr + y_offs, y, mask=y_mask) - # Write scales + # Write reciprocal scales scale_offs = pid_m * s_stride_dim_0 + pid_k * s_stride_dim_1 - tl.store(s_ptr + scale_offs, scale) + tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale)) def fp8_blockwise_act_quant_lhs( @@ -309,7 +318,7 @@ def fp8_blockwise_act_quant_lhs( ) -> Tuple[torch.Tensor, torch.Tensor]: """ Input: row-major high-precision tensor - Output: row-major, with scales for (1 x block_size) groups stored in row-major. + Output: row-major, with reciprocal scales for (1 x block_size) groups stored in col-major. """ assert x.is_contiguous(), "Input tensor must be contiguous" assert x.size(-1) % block_size == 0, ( @@ -320,7 +329,11 @@ def fp8_blockwise_act_quant_lhs( ], "dtype must be torch.float8_e4m3fn" M, K = x.size() y = torch.empty_like(x, dtype=dtype) - s = x.new_empty(M, K // block_size, dtype=torch.float32) + # Write scales to column-major format to align with torch._scaled_mm requirements. + s = x.new_empty(M, K // block_size, dtype=torch.float32).as_strided( + (M, K // block_size), + (1, M), + ) grid = lambda meta: (M, triton.cdiv(K, meta["BLOCK_SIZE"])) fp8_blockwise_act_quant_lhs_kernel[grid]( x, @@ -340,6 +353,7 @@ def fp8_blockwise_act_quant_lhs( return y, s +@triton.autotune(configs=quant_kernel_configs, key=["K"]) @triton.jit def fp8_blockwise_act_quant_rhs_kernel( x_ptr, @@ -381,7 +395,7 @@ def fp8_blockwise_act_quant_rhs_kernel( # Write scales scale_offs = pid_m * s_stride_dim_0 + pid_k * s_stride_dim_1 - tl.store(s_ptr + scale_offs, scale) + tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale)) def fp8_blockwise_act_quant_rhs( @@ -399,9 +413,11 @@ def fp8_blockwise_act_quant_rhs( torch.float8_e4m3fn, ], "dtype must be torch.float8_e4m3fn" M, K = x.size() + M_blocks = triton.cdiv(M, block_size) y = torch.empty_like(x, dtype=dtype) y = y.as_strided(y.size(), (1, y.size(0))) - s = x.new_empty(triton.cdiv(M, block_size), K, dtype=torch.float32) + s = x.new_empty(M_blocks, K, dtype=torch.float32) + grid = lambda meta: ( triton.cdiv(M, meta["BLOCK_SIZE"]), K, @@ -424,6 +440,7 @@ def fp8_blockwise_act_quant_rhs( return y, s +@triton.autotune(configs=quant_kernel_configs, key=["K"]) @triton.jit def fp8_blockwise_act_quant_transposed_lhs_kernel( x_ptr, @@ -480,7 +497,9 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel( # Scale tensor size is (K, M // SCALE_BLOCK_SIZE) scale_offs = scale_k_offs * s_stride_dim_0 + scale_m_off * s_stride_dim_1 scale_mask = (scale_k_offs < K) & (scale_m_off < M // SCALE_BLOCK_SIZE) - tl.store(s_ptr + scale_offs, scale, mask=scale_mask) + + # Write out reciprocal scales + tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask) def fp8_blockwise_act_quant_transposed_lhs( @@ -497,7 +516,13 @@ def fp8_blockwise_act_quant_transposed_lhs( # Output should have transposed dims and be in row major format M, K = x.shape y = torch.empty(K, M, dtype=dtype, device=x.device) - s = x.new_empty(K, triton.cdiv(M, block_size), dtype=torch.float32) + M_blocks = triton.cdiv(M, block_size) + + # Column major scales required for torch._scaled_mm + s = x.new_empty(K, M_blocks, dtype=torch.float32).as_strided( + (K, M_blocks), # shape + (1, K), # stride + ) grid = lambda meta: ( triton.cdiv(M, meta["SCALE_BLOCK_SIZE"]), triton.cdiv(K, meta["BLOCK_SIZE_K"]), @@ -522,6 
+547,7 @@ def fp8_blockwise_act_quant_transposed_lhs( return y, s +@triton.autotune(configs=quant_kernel_configs, key=["M", "N"]) @triton.jit def fp8_blockwise_weight_quant_rhs_kernel( x_ptr, @@ -562,10 +588,10 @@ def fp8_blockwise_weight_quant_rhs_kernel( y_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) tl.store(y_ptr + y_offs, y, mask=y_mask) - # Write scale (scalar value) + # Write reciprocal scale (scalar value) scale_m_off = pid_m * s_stride_dim_0 scale_n_off = pid_n * s_stride_dim_1 - tl.store(s_ptr + scale_m_off + scale_n_off, scale) + tl.store(s_ptr + scale_m_off + scale_n_off, tl.div_rn(1.0, scale)) def fp8_blockwise_weight_quant_rhs( @@ -582,8 +608,10 @@ def fp8_blockwise_weight_quant_rhs( M, N = x.size() y = torch.empty_like(x, dtype=dtype) y = y.as_strided(y.size(), (1, y.size(0))) # Column major - s = x.new_empty( - triton.cdiv(M, block_size), triton.cdiv(N, block_size), dtype=torch.float32 + M_blocks, N_blocks = triton.cdiv(M, block_size), triton.cdiv(N, block_size) + s = x.new_empty(M_blocks, N_blocks, dtype=torch.float32).as_strided( + (M_blocks, N_blocks), # shape + (1, M_blocks), # stride ) grid = lambda meta: ( triton.cdiv(M, meta["BLOCK_SIZE"]), @@ -607,6 +635,7 @@ def fp8_blockwise_weight_quant_rhs( return y, s +@triton.autotune(configs=quant_kernel_configs, key=["M", "N"]) @triton.jit def fp8_blockwise_weight_quant_transposed_rhs_kernel( x_ptr, @@ -659,14 +688,14 @@ def fp8_blockwise_weight_quant_transposed_rhs_kernel( y_mask = (n_offs[:, None] < N) & (m_offs[None, :] < M) tl.store(y_ptr + y_offs, y.trans(1, 0), mask=y_mask) - # Write scales + # Write reciprocal scales scale_m = pid_m scale_k = pid_n scale_offs = scale_k[:, None] * s_stride_dim_0 + scale_m[None, :] * s_stride_dim_1 scale_mask = (scale_k[:, None] < N // BLOCK_SIZE) & ( scale_m[None, :] < M // BLOCK_SIZE ) - tl.store(s_ptr + scale_offs, scale, mask=scale_mask) + tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask) def fp8_blockwise_weight_quant_transposed_rhs( @@ -738,7 +767,9 @@ def torch_blockwise_scale_act_quant_lhs(x, tile_size=128): # Reshape quantized output back to original shape and reshape scales accordingly x = x.reshape(*orig_shape) s = s.reshape(orig_shape[0], -1).to(torch.float) - return x, s + + # Return output tensor and reciprocal scale + return x, 1.0 / s def torch_blockwise_scale_act_quant_rhs( @@ -797,7 +828,8 @@ def torch_blockwise_scale_act_quant_rhs( # Convert to column-major format y = y.t().contiguous().t() - return y, scales + # Return output tensor and reciprocal scales + return y, 1.0 / scales def torch_blockwise_scale_weight_quant(x, tile_size=128): @@ -837,4 +869,6 @@ def torch_blockwise_scale_weight_quant(x, tile_size=128): x = x.permute(0, 2, 1, 3) x = x.reshape(height, width) s = s.reshape(t_h, t_w).to(torch.float) - return x, s + + # Return output tensor and reciprocal scale + return x, 1.0 / s diff --git a/torchao/prototype/blockwise_fp8_training/linear.py b/torchao/prototype/blockwise_fp8_training/linear.py index b32f3c0073..69bb2c3a9a 100644 --- a/torchao/prototype/blockwise_fp8_training/linear.py +++ b/torchao/prototype/blockwise_fp8_training/linear.py @@ -9,13 +9,13 @@ from torchao.core.config import AOBaseConfig from torchao.prototype.blockwise_fp8_training.kernels import ( - blockwise_fp8_gemm_1x128_128x1, - blockwise_fp8_gemm_1x128_128x128, fp8_blockwise_act_quant_lhs, fp8_blockwise_act_quant_rhs, fp8_blockwise_act_quant_transposed_lhs, fp8_blockwise_weight_quant_rhs, fp8_blockwise_weight_quant_transposed_rhs, + triton_fp8_gemm_1x128_128x1, 
+ triton_fp8_gemm_1x128_128x128, ) from torchao.quantization.transform_module import ( register_quantize_module_handler, @@ -25,7 +25,7 @@ class fp8_blockwise_mm(torch.autograd.Function): @staticmethod - def forward(ctx, x, weight, block_size): + def forward(ctx, x, weight, block_size, out_dtype=torch.bfloat16, use_triton=False): assert block_size == 128, "Only support block_size=128" # Temporarily reshape x to 2D tensor @@ -42,21 +42,27 @@ def forward(ctx, x, weight, block_size): ) # out = input @ weight.T - out = blockwise_fp8_gemm_1x128_128x128( + fp8_gemm = triton_fp8_gemm_1x128_128x128 if use_triton else torch._scaled_mm + out = fp8_gemm( x_fp8, - 1.0 / x_scale, weight_t_fp8, - 1.0 / weight_t_scale, + x_scale, + weight_t_scale, + out_dtype=out_dtype, ) out = out.reshape(*x_orig_shape[:-1], out.shape[-1]) ctx.save_for_backward(x, weight) ctx.block_size = block_size + ctx.out_dtype = out_dtype + ctx.use_triton = use_triton return out @staticmethod def backward(ctx, grad_output): x, weight = ctx.saved_tensors block_size = ctx.block_size + out_dtype = ctx.out_dtype + use_triton = ctx.use_triton # Reshape input to 2D x_orig_shape = x.shape @@ -80,11 +86,15 @@ def backward(ctx, grad_output): ) # grad_x = grad_output @ weight - grad_x = blockwise_fp8_gemm_1x128_128x128( + fp8_gemm_1x128_128x128 = ( + triton_fp8_gemm_1x128_128x128 if use_triton else torch._scaled_mm + ) + grad_x = fp8_gemm_1x128_128x128( grad_output_fp8, - 1.0 / grad_output_scale, weight_fp8, - 1.0 / weight_scale, + grad_output_scale, + weight_scale, + out_dtype=out_dtype, ) # Cast grad_output_t to fp8 blockwise with (1 x block_size) scaling groups, since it is @@ -101,16 +111,20 @@ def backward(ctx, grad_output): x_fp8, x_scale = fp8_blockwise_act_quant_rhs(x, block_size) # grad_weight = grad_output.T @ x - grad_weight = blockwise_fp8_gemm_1x128_128x1( + fp8_gemm_1x128_128x1 = ( + triton_fp8_gemm_1x128_128x1 if use_triton else torch._scaled_mm + ) + grad_weight = fp8_gemm_1x128_128x1( grad_output_t_fp8, - 1.0 / grad_output_t_scale, x_fp8, - 1.0 / x_scale, + grad_output_t_scale, + x_scale, + out_dtype=out_dtype, ) # Reshape grad_x to expected potentially 3D+ shape grad_x = grad_x.reshape(*grad_output_orig_shape[:-1], grad_x.shape[-1]) - return grad_x, grad_weight, None, None + return grad_x, grad_weight, None, None, None class Float8BlockwiseLinear(nn.Linear): @@ -134,6 +148,7 @@ def __init__( *args, block_size: int = 128, dtype=torch.bfloat16, + use_triton=False, **kwargs, ): super().__init__(*args, **kwargs) @@ -144,6 +159,7 @@ def __init__( assert is_sm_at_least_90(), "Only support SM90" self.block_size = block_size self.dtype = dtype + self.use_triton = use_triton def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -155,7 +171,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Transformed tensor after linear computation. """ - return fp8_blockwise_mm.apply(x, self.weight, self.block_size) + return fp8_blockwise_mm.apply( + x, self.weight, self.block_size, self.dtype, self.use_triton + ) @classmethod def from_float(
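For reference, a minimal end-to-end sketch of the renamed APIs after this change (not part of the diff itself). It mirrors the tests and benchmarks above; the shapes are illustrative, and it assumes an SM90+ GPU with the CUDA version required by these prototype kernels.

import torch

from torchao.prototype.blockwise_fp8_training.kernels import (
    fp8_blockwise_act_quant_lhs,
    fp8_blockwise_weight_quant_transposed_rhs,
    triton_fp8_gemm_1x128_128x128,
)
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear

M, N, K = 256, 512, 1024  # illustrative shapes, multiples of the 128 block size
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

# Both quantization helpers now return reciprocal scales, so they are passed to
# the GEMM directly (callers no longer compute 1.0 / scale themselves).
A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn)
B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=torch.float8_e4m3fn)

# New argument order is (a, b, a_s, b_s), with an explicit out_dtype.
C = triton_fp8_gemm_1x128_128x128(A_q, B_t_q, A_s, B_t_s, out_dtype=torch.bfloat16)

# Module-level API: use_triton selects the Triton GEMMs instead of torch._scaled_mm.
linear = Float8BlockwiseLinear(K, N, device="cuda", use_triton=True)
out = linear(A)
out.sum().backward()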