Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmarks/prototype/moe_training/benchmark_moe_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from benchmarks.prototype.moe_training.utils import (
bench_fwd_bwd_microseconds,
profile_fn,
profile_fwd_bwd,
)

# this feature requires CUDA and SM89+
Expand Down Expand Up @@ -128,7 +128,7 @@ def warmup(model, input):
print(f"BF16 time: {bf16_us} us")
if enable_profile:
print("Profiling bf16 training")
profile_fn(ref_model, ref_x, labels=labels, profile_name="bf16_profile")
profile_fwd_bwd(ref_model, ref_x, labels=labels, profile_name="bf16_profile")

scaled_us = bench_fwd_bwd_microseconds(
model,
Expand All @@ -140,7 +140,7 @@ def warmup(model, input):
print(f"Scaled time: {scaled_us} us")
if enable_profile:
print("Profiling quantized training")
profile_fn(model, x, labels=labels, profile_name=f"{recipe_name}_profile")
profile_fwd_bwd(model, x, labels=labels, profile_name=f"{recipe_name}_profile")

print(f"Speedup: {bf16_us / scaled_us:.3f}x")
dist.destroy_process_group()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
triton_fp8_per_group_rowwise_scales,
)
from torchao.prototype.moe_training.utils import (
generate_jagged_offs,
torch_to_float8_per_group_colwise,
torch_to_float8_per_group_rowwise,
)
Expand All @@ -40,6 +41,8 @@ class ExperimentConfig:
class ExperimentResult:
torch_time_us: float
triton_time_us: float
torch_mem_bw_gbps: float
triton_mem_bw_gbps: float


@dataclass(frozen=True)
Expand All @@ -50,7 +53,7 @@ class Experiment:

def get_configs() -> List[ExperimentConfig]:
input_shapes = [(16640, 5120)] # (Mg, K)
n_groups_list = [16, 128]
n_groups_list = [1, 16, 128]
high_precision_dtypes = [torch.bfloat16]
configs = []
for input_shape, n_groups, high_precision_dtype in itertools.product(
Expand Down Expand Up @@ -81,15 +84,9 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
# that occurs in the backward pass of the differentiable scaled grouped mm.
# - the transposed tensor in col-major format with groups along the row dimension,
# which represents the right operand.
group_size = input_row_major.shape[1] // config.n_groups
n_groups = config.n_groups
offs = torch.arange(
group_size,
group_size * n_groups + 1,
group_size,
device=device,
dtype=torch.int32,
)
Mg = input_row_major.shape[0]
offs = generate_jagged_offs(n_groups, Mg, multiple_of=16)

def warmup(func, *args, **kwargs):
for _ in range(10):
Expand Down Expand Up @@ -140,9 +137,21 @@ def run_triton(
run_triton, input_row_major, input_col_major, offs
)

# Memory bandwidth calculation: reads/writes of the scales are excluded to keep
# the formula simple; the result is still an accurate estimate.
bytes_per_input_el = torch.finfo(config.high_precision_dtype).bits / 8
num_elements = input_tensor.numel()
read_bytes = num_elements * bytes_per_input_el
write_bytes = num_elements # 1 byte per element in float8_e4m3fn
read_write_bytes = read_bytes + write_bytes
torch_mem_bw_gbps = (read_write_bytes) / (torch_time_us / 1e6) / 1e9
triton_mem_bw_gbps = (read_write_bytes) / (triton_time_us / 1e6) / 1e9

return ExperimentResult(
torch_time_us=torch_time_us,
triton_time_us=triton_time_us,
torch_mem_bw_gbps=torch_mem_bw_gbps,
triton_mem_bw_gbps=triton_mem_bw_gbps,
)


Expand All @@ -153,6 +162,8 @@ def print_results(experiments: List[Experiment]):
"high_precision_dtype",
"torch_time_us",
"triton_time_us",
"torch_mem_bw_gbps",
"triton_mem_bw_gbps",
"triton_speedup",
]
rows = []
Expand All @@ -167,6 +178,8 @@ def print_results(experiments: List[Experiment]):
experiment.config.high_precision_dtype,
experiment.result.torch_time_us,
experiment.result.triton_time_us,
round(experiment.result.torch_mem_bw_gbps, 3),
round(experiment.result.triton_mem_bw_gbps, 3),
f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class ExperimentConfig:
class ExperimentResult:
torch_time_us: float
triton_time_us: float
torch_mem_bw_gbps: float
triton_mem_bw_gbps: float


@dataclass(frozen=True)
Expand All @@ -48,8 +50,12 @@ class Experiment:
def get_configs() -> List[ExperimentConfig]:
# Llama4 shapes
input_shapes = [
(1, 8192, 5120), # w1, w3
(1, 5120, 8192), # w2
(16, 8192, 5120), # w1, w3
(16, 5120, 8192), # w2
(128, 8192, 5120), # w1, w3
(128, 5120, 8192), # w2
]
high_precision_dtypes = [torch.bfloat16]
configs = []
Expand Down Expand Up @@ -110,9 +116,25 @@ def run_triton(input_tensor: torch.Tensor):
input_tensor,
)

# Memory bandwidth calculation: reads/writes of the scales are excluded to keep
# the formula simple; the result is still an accurate estimate.
bytes_per_input_el = torch.finfo(config.high_precision_dtype).bits / 8
bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
num_elements = input_tensor.numel()

read_bytes = num_elements * bytes_per_input_el
write_bytes = num_elements * bytes_per_output_el

# Both torch.compile codegen and the triton kernel read the input tensor twice
# (once for scale calculations, once for scaling + casting).
torch_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (torch_time_us / 1e6)
triton_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (triton_time_us / 1e6)

return ExperimentResult(
torch_time_us=torch_time_us,
triton_time_us=triton_time_us,
torch_mem_bw_gbps=torch_mem_bw_gbps,
triton_mem_bw_gbps=triton_mem_bw_gbps,
)


Expand All @@ -121,6 +143,8 @@ def print_results(experiments: List[Experiment]):
"input_shape",
"torch_time_us",
"triton_time_us",
"torch_mem_bw_gbps",
"triton_mem_bw_gbps",
"triton_speedup",
]
rows = []
Expand All @@ -131,6 +155,8 @@ def print_results(experiments: List[Experiment]):
input_shape,
experiment.result.torch_time_us,
experiment.result.triton_time_us,
round(experiment.result.torch_mem_bw_gbps, 3),
round(experiment.result.triton_mem_bw_gbps, 3),
f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch
from tabulate import tabulate
from tqdm import tqdm
from utils import bench_fwd_bwd_microseconds, profile_fn
from utils import bench_fwd_bwd_microseconds, profile_fwd_bwd

from torchao.prototype.moe_training import _scaled_grouped_mm
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
Expand Down Expand Up @@ -46,8 +46,9 @@ class Experiment:


def get_configs() -> List[ExperimentConfig]:
# Llama4 shapes
A_shapes = [(16640, 5120)]
B_shapes = [(16, 8192, 5120)]
B_shapes = [(1, 8192, 5120), (16, 8192, 5120), (128, 8192, 5120)]
recipes = [MoEScalingType.FP8_ROWWISE]
high_precision_dtypes = [torch.bfloat16]
configs = []
Expand Down Expand Up @@ -91,7 +92,8 @@ def run_experiment(
# - the transposed tensor in col-major format with groups along the row dimension,
# which represents the right operand.
n_groups = config.B_shape[0]
offs = generate_jagged_offs(n_groups, A.shape[0], multiple_of=16)
Mg = A.shape[0]
offs = generate_jagged_offs(n_groups, Mg, multiple_of=16)

labels = torch.ones(
(A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16
Expand All @@ -107,7 +109,7 @@ def run_experiment(
use_compile=args.compile,
)
if args.profile:
profile_fn(
profile_fwd_bwd(
torch._grouped_mm,
A,
B_t,
Expand All @@ -128,7 +130,7 @@ def run_experiment(
use_compile=args.compile,
)
if args.profile:
profile_fn(
profile_fwd_bwd(
_scaled_grouped_mm,
A,
B_t,
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/prototype/moe_training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def bench_fwd_bwd_microseconds(
return statistics.median(times)


def profile_fn(
def profile_fwd_bwd(
fn,
*args,
labels=None,
Expand Down
23 changes: 16 additions & 7 deletions torchao/prototype/moe_training/kernels/float8_rowwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
torch.float64: tl.float64,
}

block_sizes_n = [32, 128, 512] # large dim (output_features)
block_sizes_k = [32, 128, 512] # small dim (input_features)
num_warps = [8]
num_stages = [2, 4]
block_sizes_n = [32, 128, 256] # large dim (output_features)
block_sizes_k = [32, 128, 256] # small dim (input_features)
num_warps = [2, 4]
num_stages = [2, 3, 4, 5, 6]
kernel_configs_2D = [
triton.Config(
{"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k},
Expand Down Expand Up @@ -176,9 +176,18 @@ def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
input_dtype
)

# compute scales with local amax, using axis=0 because for each expert,
# we are reading the non-transposed input, and want to compute the scales
# along axis=1 for the transposed input.
# In a plain torch implementation, we would transpose the tensor and then compute the amax
# along dim 1 (N) to get colwise scales for the RHS operand of a scaled grouped gemm:
# input_data = input_data.transpose(-2, -1) # (E, K, N) -> (E, N, K)
# amaxes = input_data.abs().amax(dim=1, keepdim=True) # (E, N, K) -> (E, 1, K)
#
# Here, we read a (K, N) chunk for a given E and compute the amax along dim=1 (N),
# which yields an equivalent scale of shape (K,) for this chunk of the expert.
# We then use atomic min to compute the final scale for these logical columns of the transposed tensor.
#
# Later, we use this scale to cast the same (K, N) input chunk to fp8 and transpose it to (N, K) before
# writing it to the output tensor:
# ((K, N) * (K, 1))^T = (N, K)
amaxes = tl.max(tl.abs(input_data), axis=1).to(tl.float64) # (K,)
scales = (fp8_dtype_max / tl.clamp(amaxes, min=EPS, max=float("inf"))).to(
tl.float32
Expand Down
Loading