@@ -87,12 +87,13 @@ def run_torch(input_tensor: torch.Tensor):
return out

def run_triton(input_tensor: torch.Tensor):
_ = triton_fp8_rowwise_3d_transpose_rhs(
out = triton_fp8_rowwise_3d_transpose_rhs(
input_tensor,
output_dtype=torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
)
torch.cuda.synchronize()
return out

# bench torch
compiled_run_torch = torch.compile(run_torch)
80 changes: 45 additions & 35 deletions benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm.py
@@ -6,15 +6,17 @@
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
import argparse
import itertools
import time
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
from utils import bench_fwd_bwd_microseconds

from torchao.prototype.moe_training import _scaled_grouped_mm
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
from torchao.prototype.moe_training.utils import generate_jagged_offs

device = torch.device("cuda")

@@ -27,11 +29,14 @@ class ExperimentConfig:
high_precision_dtype: torch.dtype
A_shape: tuple[int]
B_shape: tuple[int]
recipe: MoEScalingType


@dataclass(frozen=True)
class ExperimentResult:
time_us: float
bf16_us: float
fp8_us: float
fp8_speedup: float


@dataclass(frozen=True)
@@ -41,19 +46,22 @@ class Experiment:


def get_configs() -> List[ExperimentConfig]:
A_shapes = [(2**8, 8192), (2**12, 8192), (2**16, 8192)]
B_shapes = [(4, 8192, 8192), (8, 8192, 8192), (16, 8192, 8192)]
A_shapes = [(16640, 5120)]
B_shapes = [(16, 8192, 5120), (128, 8192, 5120)]
recipes = [MoEScalingType.FP8_ROWWISE]
high_precision_dtypes = [torch.bfloat16]
configs = []
for A_shape, B_shape, high_precision_dtype in itertools.product(
for A_shape, B_shape, recipe, high_precision_dtype in itertools.product(
A_shapes,
B_shapes,
recipes,
high_precision_dtypes,
):
configs.append(
ExperimentConfig(
A_shape=A_shape,
B_shape=B_shape,
recipe=recipe,
high_precision_dtype=high_precision_dtype,
)
)
@@ -83,47 +91,47 @@ def run_experiment(
# - the transposed tensor in col-major format with groups along the row dimension,
# which represents the right operand.
n_groups = config.B_shape[0]
group_size = A.shape[0] // n_groups
offs = torch.arange(
group_size,
group_size * n_groups + 1,
group_size,
device=device,
dtype=torch.int32,
)
offs = generate_jagged_offs(n_groups, A.shape[0], multiple_of=16)

def warmup(func, *args, **kwargs):
for _ in range(10):
func(*args, **kwargs)
labels = torch.ones(
(A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16
)

def forward_backward(A, B_t, offs):
out = _scaled_grouped_mm(
A,
B_t,
offs=offs,
out_dtype=torch.bfloat16,
)
out.sum().backward()
torch.cuda.synchronize()
# benchmark bf16 grouped mm
bf16_us = bench_fwd_bwd_microseconds(
torch._grouped_mm,
A,
B_t,
offs,
labels=labels,
use_compile=args.compile,
)

# benchmark torch
torch_func = torch.compile(forward_backward) if args.compile else forward_backward
warmup(torch_func, A, B_t, offs)
start_time_ns = time.perf_counter_ns()
torch_func(A, B_t, offs)
torch_time_ns = time.perf_counter_ns() - start_time_ns
time_us = torch_time_ns / 1e3
# benchmark scaled grouped mm with dynamic fp8 rowwise quant
fp8_us = bench_fwd_bwd_microseconds(
_scaled_grouped_mm,
A,
B_t,
offs,
scaling_type=config.recipe,
labels=labels,
use_compile=args.compile,
)

return ExperimentResult(
time_us=round(time_us, 3),
bf16_us=round(bf16_us, 3),
fp8_us=round(fp8_us, 3),
fp8_speedup=round(bf16_us / fp8_us, 3),
)


def print_results(experiments: List[Experiment]):
headers = [
"A_shape",
"B_shape",
"time_us",
"bf16_time_us",
"scaled_time_us",
"fp8_speedup",
]
rows = []
for experiment in experiments:
@@ -133,7 +141,9 @@ def print_results(experiments: List[Experiment]):
[
A_shape,
B_shape,
experiment.result.time_us,
experiment.result.bf16_us,
experiment.result.fp8_us,
f"{experiment.result.fp8_speedup}x",
]
)
print(tabulate(rows, headers=headers))
21 changes: 21 additions & 0 deletions benchmarks/prototype/moe_training/utils.py
@@ -0,0 +1,21 @@
import statistics
from time import perf_counter_ns

import torch
from torch.nn import functional as F


def bench_fwd_bwd_microseconds(fn, *args, labels=None, use_compile=False, **kwargs):
assert labels is not None
fn = torch.compile(fn, fullgraph=False) if use_compile else fn
times = []
for _ in range(10):
start_ns = perf_counter_ns()
out = fn(*args, **kwargs)
loss = F.mse_loss(out, labels)
loss.backward()
torch.cuda.synchronize()
end_ns = perf_counter_ns()
duration_us = (end_ns - start_ns) / 1000
times.append(duration_us)
return statistics.median(times)
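
For reference, a minimal usage sketch of the new bench_fwd_bwd_microseconds helper. The shapes and the torch.matmul target below are illustrative only (not from the PR): the helper runs forward, an MSE loss against the supplied labels, and backward, and returns the median latency in microseconds over 10 iterations.

import torch

from utils import bench_fwd_bwd_microseconds  # local helper added in this PR

# Illustrative shapes/operands; any differentiable fn whose output matches
# the labels' shape works.
M, K, N = 1024, 4096, 2048
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(K, N, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.ones(M, N, device="cuda", dtype=torch.bfloat16)

# Median fwd+bwd latency in microseconds over 10 timed iterations.
median_us = bench_fwd_bwd_microseconds(
    torch.matmul, x, w, labels=labels, use_compile=False
)
print(f"median fwd+bwd latency: {median_us:.1f} us")
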
25 changes: 10 additions & 15 deletions torchao/prototype/moe_training/kernels/float8_rowwise.py
@@ -51,7 +51,6 @@ def triton_fp8_rowwise_3d_transpose_rhs(
) -> Tuple[torch.Tensor, torch.Tensor]:
assert hp_tensor.ndim == 3, "input tensor must be 3D"

num_elements = hp_tensor.numel()
tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
tl_output_dtype = FP8_DTYPE_MAP[output_dtype]

@@ -89,7 +88,6 @@ def triton_fp8_rowwise_3d_transpose_rhs(
e,
n,
k,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
tl_input_dtype,
@@ -113,7 +111,6 @@
e,
n,
k,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
tl_input_dtype,
@@ -138,20 +135,19 @@ def _fake_triton_fp8_rowwise_3d_transpose_rhs(
return output_buffer, scales_buffer


@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
@triton.autotune(configs=kernel_configs_2D, key=["K", "N"])
@triton.jit
def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
input_ptr,
stride_input_dim0: int,
stride_input_dim1: int,
stride_input_dim2: int,
stride_input_dim0: tl.int64,
stride_input_dim1: tl.int64,
stride_input_dim2: tl.int64,
scales_ptr,
stride_scales_dim0: int,
stride_scales_dim1: int,
E: int,
N: int,
K: int,
num_elements: int,
fp8_dtype_min: tl.constexpr,
fp8_dtype_max: tl.constexpr,
input_dtype: tl.constexpr,
@@ -202,20 +198,19 @@ def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
@triton.jit
def _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel(
input_ptr,
stride_input_dim0: int,
stride_input_dim1: int,
stride_input_dim2: int,
stride_input_dim0: tl.int64,
stride_input_dim1: tl.int64,
stride_input_dim2: tl.int64,
output_ptr,
stride_output_dim0: int,
stride_output_dim1: int,
stride_output_dim2: int,
stride_output_dim0: tl.int64,
stride_output_dim1: tl.int64,
stride_output_dim2: tl.int64,
scales_ptr,
stride_scales_dim0: int,
stride_scales_dim1: int,
E: int,
N: int,
K: int,
num_elements: int,
fp8_dtype_min: tl.constexpr,
fp8_dtype_max: tl.constexpr,
input_dtype: tl.constexpr,
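
A note on the kernel change above: the autotuner is now keyed on the logical dimensions K and N rather than on the total element count, so Triton re-selects a config whenever either dimension changes. The stride arguments are also annotated as tl.int64, presumably so that offset arithmetic on very large tensors cannot overflow 32-bit integers. A minimal sketch of the autotune-by-dims pattern follows, using a toy copy kernel and made-up configs rather than the PR's kernel_configs_2D:

import torch
import triton
import triton.language as tl

# Toy kernel, illustrative configs only; the real kernels use kernel_configs_2D.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 128}, num_warps=8),
    ],
    key=["K", "N"],  # re-tune whenever either logical dim changes
)
@triton.jit
def _copy_2d_kernel(x_ptr, y_ptr, K, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < K * N
    vals = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, vals, mask=mask)

x = torch.randn(256, 512, device="cuda")
y = torch.empty_like(x)
K, N = x.shape
grid = lambda meta: (triton.cdiv(K * N, meta["BLOCK_SIZE"]),)
_copy_2d_kernel[grid](x, y, K, N)
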
27 changes: 13 additions & 14 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -48,7 +48,7 @@ def _scaled_grouped_mm(
"""
# TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
if scaling_type == MoEScalingType.FP8_ROWWISE:
print("Using fp8 rowwise scaled_grouped_mm")
# print("Using fp8 rowwise scaled_grouped_mm")
return _Float8GroupedMM.apply(
A,
B_t,
@@ -140,17 +140,8 @@ def forward(
B_t_scaled = B_t.to(torch.float32) * B_t_scales
B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)

# Precompute non-transposed B column-major for backward, to save memory by storing the
# low precision B tensor instead of the high precision B tensor.
# In the backward this is needed for grad_A: grad_output @ B.
B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
B_t._data,
output_dtype=torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
)

# Store what we need for backward.
ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
ctx.save_for_backward(A, B_t, offs)
ctx.out_dtype = out_dtype

# Perform scaled grouped GEMM and return result.
@@ -179,7 +170,7 @@ def forward(

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors
A, B_t, offs = ctx.saved_tensors
out_dtype = ctx.out_dtype

# Convert grad_output to float8, row-major for left operand of grouped GEMM
@@ -199,6 +190,14 @@ def backward(ctx, grad_output: torch.Tensor):
grad_output_scaled, torch.float8_e4m3fn
)

# Compute B fp8 column-major for right operand of grouped GEMM:
# grad_A = grad_output @ B.
B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
B_t._data if hasattr(B_t, "_data") else B_t,
output_dtype=torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
)

# Compute grad_A.
# grad_A = grad_output @ B
# grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
@@ -217,8 +216,8 @@
grad_A = torch._scaled_grouped_mm(
grad_output_fp8_row_major,
B_fp8_col_major,
grad_output_scales.squeeze().reciprocal(),
B_scales.squeeze().reciprocal(),
grad_output_scales.reciprocal(),
B_scales.reciprocal(),
offs,
out_dtype=out_dtype,
use_fast_accum=True,
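
The scaled_grouped_mm.py change above moves the fp8 column-major conversion of B out of forward and into backward, so only A, the original high-precision B_t, and offs are saved between the passes and no extra fp8 copy of B stays alive across that boundary. A minimal sketch of this recompute-in-backward pattern, with a plain matmul and a stand-in quantize step (the names below are illustrative, not the PR's API):

import torch

def fake_quantize(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real fp8 rowwise conversion; illustrative only.
    return t.to(torch.bfloat16).to(t.dtype)

class RecomputeQuantInBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # Save only the original operands; do NOT save the quantized copy.
        ctx.save_for_backward(a, b)
        return a @ fake_quantize(b)

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # Re-derive the quantized operand here instead of holding it in
        # memory for the whole interval between forward and backward.
        b_q = fake_quantize(b)
        grad_a = grad_out @ b_q.t()
        grad_b = a.t() @ grad_out
        return grad_a, grad_b

# Usage sketch
a = torch.randn(64, 128, requires_grad=True)
b = torch.randn(128, 32, requires_grad=True)
RecomputeQuantInBackward.apply(a, b).sum().backward()
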