From 6afefbc0ff644c5072a8b78b8429220f80e707b8 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 13 Aug 2025 16:04:53 -0700 Subject: [PATCH] [moe training] update bench script to compare fp8 dynamic quant scaled_grouped_mm fwd+bwd against bf16 stack-info: PR: https://github.com/pytorch/ao/pull/2765, branch: danielvegamyhre/stack/40 --- .../benchmark_rowwise_3d_quant_kernels.py | 3 +- .../benchmark_scaled_grouped_mm.py | 80 +++++++++++-------- benchmarks/prototype/moe_training/utils.py | 21 +++++ .../moe_training/kernels/float8_rowwise.py | 25 +++--- .../moe_training/scaled_grouped_mm.py | 27 +++---- 5 files changed, 91 insertions(+), 65 deletions(-) create mode 100644 benchmarks/prototype/moe_training/utils.py diff --git a/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py b/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py index 66a7c91f53..53518ba491 100644 --- a/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py +++ b/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py @@ -87,12 +87,13 @@ def run_torch(input_tensor: torch.Tensor): return out def run_triton(input_tensor: torch.Tensor): - _ = triton_fp8_rowwise_3d_transpose_rhs( + out = triton_fp8_rowwise_3d_transpose_rhs( input_tensor, output_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=True, ) torch.cuda.synchronize() + return out # bench torch compiled_run_torch = torch.compile(run_torch) diff --git a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py index c229eaeb71..9b615e5b8d 100644 --- a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py +++ b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py @@ -6,15 +6,17 @@ # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py import argparse import itertools -import time from dataclasses import dataclass from typing import List import torch from tabulate import tabulate from tqdm import tqdm +from utils import bench_fwd_bwd_microseconds from torchao.prototype.moe_training import _scaled_grouped_mm +from torchao.prototype.moe_training.conversion_utils import MoEScalingType +from torchao.prototype.moe_training.utils import generate_jagged_offs device = torch.device("cuda") @@ -27,11 +29,14 @@ class ExperimentConfig: high_precision_dtype: torch.dtype A_shape: tuple[int] B_shape: tuple[int] + recipe: MoEScalingType @dataclass(frozen=True) class ExperimentResult: - time_us: float + bf16_us: float + fp8_us: float + fp8_speedup: float @dataclass(frozen=True) @@ -41,19 +46,22 @@ class Experiment: def get_configs() -> List[ExperimentConfig]: - A_shapes = [(2**8, 8192), (2**12, 8192), (2**16, 8192)] - B_shapes = [(4, 8192, 8192), (8, 8192, 8192), (16, 8192, 8192)] + A_shapes = [(16640, 5120)] + B_shapes = [(16, 8192, 5120), (128, 8192, 5120)] + recipes = [MoEScalingType.FP8_ROWWISE] high_precision_dtypes = [torch.bfloat16] configs = [] - for A_shape, B_shape, high_precision_dtype in itertools.product( + for A_shape, B_shape, recipe, high_precision_dtype in itertools.product( A_shapes, B_shapes, + recipes, high_precision_dtypes, ): configs.append( ExperimentConfig( A_shape=A_shape, B_shape=B_shape, + recipe=recipe, high_precision_dtype=high_precision_dtype, ) ) @@ -83,39 +91,37 @@ def run_experiment( # - the transposed tensor in col-major format with groups along the row 
dimension, # which represents the right operand. n_groups = config.B_shape[0] - group_size = A.shape[0] // n_groups - offs = torch.arange( - group_size, - group_size * n_groups + 1, - group_size, - device=device, - dtype=torch.int32, - ) + offs = generate_jagged_offs(n_groups, A.shape[0], multiple_of=16) - def warmup(func, *args, **kwargs): - for _ in range(10): - func(*args, **kwargs) + labels = torch.ones( + (A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16 + ) - def forward_backward(A, B_t, offs): - out = _scaled_grouped_mm( - A, - B_t, - offs=offs, - out_dtype=torch.bfloat16, - ) - out.sum().backward() - torch.cuda.synchronize() + # benchmark bf16 grouped mm + bf16_us = bench_fwd_bwd_microseconds( + torch._grouped_mm, + A, + B_t, + offs, + labels=labels, + use_compile=args.compile, + ) - # benchmark torch - torch_func = torch.compile(forward_backward) if args.compile else forward_backward - warmup(torch_func, A, B_t, offs) - start_time_ns = time.perf_counter_ns() - torch_func(A, B_t, offs) - torch_time_ns = time.perf_counter_ns() - start_time_ns - time_us = torch_time_ns / 1e3 + # benchmark scaled grouped mm with dynamic fp8 rowwise quant + fp8_us = bench_fwd_bwd_microseconds( + _scaled_grouped_mm, + A, + B_t, + offs, + scaling_type=config.recipe, + labels=labels, + use_compile=args.compile, + ) return ExperimentResult( - time_us=round(time_us, 3), + bf16_us=round(bf16_us, 3), + fp8_us=round(fp8_us, 3), + fp8_speedup=round(bf16_us / fp8_us, 3), ) @@ -123,7 +129,9 @@ def print_results(experiments: List[Experiment]): headers = [ "A_shape", "B_shape", - "time_us", + "bf16_time_us", + "scaled_time_us", + "fp8_speedup", ] rows = [] for experiment in experiments: @@ -133,7 +141,9 @@ def print_results(experiments: List[Experiment]): [ A_shape, B_shape, - experiment.result.time_us, + experiment.result.bf16_us, + experiment.result.fp8_us, + f"{experiment.result.fp8_speedup}x", ] ) print(tabulate(rows, headers=headers)) diff --git a/benchmarks/prototype/moe_training/utils.py b/benchmarks/prototype/moe_training/utils.py new file mode 100644 index 0000000000..d6c5e7e82f --- /dev/null +++ b/benchmarks/prototype/moe_training/utils.py @@ -0,0 +1,21 @@ +import statistics +from time import perf_counter_ns + +import torch +from torch.nn import functional as F + + +def bench_fwd_bwd_microseconds(fn, *args, labels=None, use_compile=False, **kwargs): + assert labels is not None + fn = torch.compile(fn, fullgraph=False) if use_compile else fn + times = [] + for _ in range(10): + start_ns = perf_counter_ns() + out = fn(*args, **kwargs) + loss = F.mse_loss(out, labels) + loss.backward() + torch.cuda.synchronize() + end_ns = perf_counter_ns() + duration_us = (end_ns - start_ns) / 1000 + times.append(duration_us) + return statistics.median(times) diff --git a/torchao/prototype/moe_training/kernels/float8_rowwise.py b/torchao/prototype/moe_training/kernels/float8_rowwise.py index 3449b89336..5c084ca1b5 100644 --- a/torchao/prototype/moe_training/kernels/float8_rowwise.py +++ b/torchao/prototype/moe_training/kernels/float8_rowwise.py @@ -51,7 +51,6 @@ def triton_fp8_rowwise_3d_transpose_rhs( ) -> Tuple[torch.Tensor, torch.Tensor]: assert hp_tensor.ndim == 3, "input tensor must be 3D" - num_elements = hp_tensor.numel() tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] tl_output_dtype = FP8_DTYPE_MAP[output_dtype] @@ -89,7 +88,6 @@ def triton_fp8_rowwise_3d_transpose_rhs( e, n, k, - num_elements, fp8_dtype_min, fp8_dtype_max, tl_input_dtype, @@ -113,7 +111,6 @@ def triton_fp8_rowwise_3d_transpose_rhs( 
e, n, k, - num_elements, fp8_dtype_min, fp8_dtype_max, tl_input_dtype, @@ -138,20 +135,19 @@ def _fake_triton_fp8_rowwise_3d_transpose_rhs( return output_buffer, scales_buffer -@triton.autotune(configs=kernel_configs_2D, key=["num_elements"]) +@triton.autotune(configs=kernel_configs_2D, key=["K", "N"]) @triton.jit def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel( input_ptr, - stride_input_dim0: int, - stride_input_dim1: int, - stride_input_dim2: int, + stride_input_dim0: tl.int64, + stride_input_dim1: tl.int64, + stride_input_dim2: tl.int64, scales_ptr, stride_scales_dim0: int, stride_scales_dim1: int, E: int, N: int, K: int, - num_elements: int, fp8_dtype_min: tl.constexpr, fp8_dtype_max: tl.constexpr, input_dtype: tl.constexpr, @@ -202,20 +198,19 @@ def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel( @triton.jit def _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel( input_ptr, - stride_input_dim0: int, - stride_input_dim1: int, - stride_input_dim2: int, + stride_input_dim0: tl.int64, + stride_input_dim1: tl.int64, + stride_input_dim2: tl.int64, output_ptr, - stride_output_dim0: int, - stride_output_dim1: int, - stride_output_dim2: int, + stride_output_dim0: tl.int64, + stride_output_dim1: tl.int64, + stride_output_dim2: tl.int64, scales_ptr, stride_scales_dim0: int, stride_scales_dim1: int, E: int, N: int, K: int, - num_elements: int, fp8_dtype_min: tl.constexpr, fp8_dtype_max: tl.constexpr, input_dtype: tl.constexpr, diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index 58d7aa71d8..0ee72ea35b 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -48,7 +48,7 @@ def _scaled_grouped_mm( """ # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging. if scaling_type == MoEScalingType.FP8_ROWWISE: - print("Using fp8 rowwise scaled_grouped_mm") + # print("Using fp8 rowwise scaled_grouped_mm") return _Float8GroupedMM.apply( A, B_t, @@ -140,17 +140,8 @@ def forward( B_t_scaled = B_t.to(torch.float32) * B_t_scales B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn) - # Precompute non-transposed B column-major for backward, to save memory by storing the - # low precision B tensor instead of the high precision B tensor. - # In the backward this is needed for grad_A: grad_output @ B. - B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs( - B_t._data, - output_dtype=torch.float8_e4m3fn, - round_scales_to_power_of_2=True, - ) - # Store what we need for backward. - ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs) + ctx.save_for_backward(A, B_t, offs) ctx.out_dtype = out_dtype # Perform scaled grouped GEMM and return result. @@ -179,7 +170,7 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor): - A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors + A, B_t, offs = ctx.saved_tensors out_dtype = ctx.out_dtype # Convert grad_output to float8, row-major for left operand of grouped GEMM @@ -199,6 +190,14 @@ def backward(ctx, grad_output: torch.Tensor): grad_output_scaled, torch.float8_e4m3fn ) + # Compute B fp8 column-major for right operand of grouped GEMM: + # grad_A = grad_output @ B. + B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs( + B_t._data if hasattr(B_t, "_data") else B_t, + output_dtype=torch.float8_e4m3fn, + round_scales_to_power_of_2=True, + ) + # Compute grad_A. 
# grad_A = grad_output @ B # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K) @@ -217,8 +216,8 @@ def backward(ctx, grad_output: torch.Tensor): grad_A = torch._scaled_grouped_mm( grad_output_fp8_row_major, B_fp8_col_major, - grad_output_scales.squeeze().reciprocal(), - B_scales.squeeze().reciprocal(), + grad_output_scales.reciprocal(), + B_scales.reciprocal(), offs, out_dtype=out_dtype, use_fast_accum=True,
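
Usage sketch (not part of the patch): the snippet below shows how the new
bench_fwd_bwd_microseconds helper can be driven directly for a single bf16 vs.
fp8-rowwise comparison, mirroring one config from get_configs(). It assumes a
CUDA device and that it is run from benchmarks/prototype/moe_training/ so the
local utils module resolves; the tensor construction (requires_grad, the B
transpose) is illustrative rather than copied verbatim from the benchmark
script.

import torch

from utils import bench_fwd_bwd_microseconds  # benchmarks/prototype/moe_training/utils.py

from torchao.prototype.moe_training import _scaled_grouped_mm
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
from torchao.prototype.moe_training.utils import generate_jagged_offs

device = torch.device("cuda")
num_experts, total_tokens, dim, hidden = 16, 16640, 5120, 8192

# Left operand: token activations (total_tokens, dim). Right operand: per-expert
# weights transposed to (num_experts, dim, hidden), so each group computes
# (tokens_in_group, dim) @ (dim, hidden).
A = torch.randn(
    total_tokens, dim, dtype=torch.bfloat16, device=device, requires_grad=True
)
B = torch.randn(
    num_experts, hidden, dim, dtype=torch.bfloat16, device=device, requires_grad=True
)
B_t = B.transpose(-2, -1)

# Jagged per-expert offsets along the token dimension, aligned to multiples of 16.
offs = generate_jagged_offs(num_experts, total_tokens, multiple_of=16)

# Dummy targets for the MSE loss used inside the timing helper.
labels = torch.ones(total_tokens, hidden, dtype=torch.bfloat16, device=device)

bf16_us = bench_fwd_bwd_microseconds(torch._grouped_mm, A, B_t, offs, labels=labels)
fp8_us = bench_fwd_bwd_microseconds(
    _scaled_grouped_mm,
    A,
    B_t,
    offs,
    scaling_type=MoEScalingType.FP8_ROWWISE,
    labels=labels,
)
print(f"bf16: {bf16_us:.1f} us | fp8 rowwise: {fp8_us:.1f} us | speedup: {bf16_us / fp8_us:.2f}x")

The helper times ten forward+backward iterations (MSE loss against the provided
labels) and returns the median in microseconds; passing use_compile=True times
the torch.compile'd path, which is what the benchmark script's args.compile
option toggles.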