diff --git a/benchmarks/prototype/moe_training/benchmark_moe_fsdp.py b/benchmarks/prototype/moe_training/benchmark_moe_fsdp.py
index 84453fa242..1011d2609b 100644
--- a/benchmarks/prototype/moe_training/benchmark_moe_fsdp.py
+++ b/benchmarks/prototype/moe_training/benchmark_moe_fsdp.py
@@ -24,7 +24,7 @@
 
 from benchmarks.prototype.moe_training.utils import (
     bench_fwd_bwd_microseconds,
-    profile_fn,
+    profile_fwd_bwd,
 )
 
 # this feature requires CUDA and SM89+
@@ -128,7 +128,7 @@ def warmup(model, input):
     print(f"BF16 time: {bf16_us} us")
     if enable_profile:
         print("Profiling bf16 training")
-        profile_fn(ref_model, ref_x, labels=labels, profile_name="bf16_profile")
+        profile_fwd_bwd(ref_model, ref_x, labels=labels, profile_name="bf16_profile")
 
     scaled_us = bench_fwd_bwd_microseconds(
         model,
@@ -140,7 +140,7 @@ def warmup(model, input):
     print(f"Scaled time: {scaled_us} us")
    if enable_profile:
         print("Profiling quantized training")
-        profile_fn(model, x, labels=labels, profile_name=f"{recipe_name}_profile")
+        profile_fwd_bwd(model, x, labels=labels, profile_name=f"{recipe_name}_profile")
 
     print(f"Speedup: {bf16_us / scaled_us:.3f}x")
     dist.destroy_process_group()
diff --git a/benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py b/benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py
index f180bb15ac..7fbf48c285 100644
--- a/benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py
+++ b/benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py
@@ -19,6 +19,7 @@
     triton_fp8_per_group_rowwise_scales,
 )
 from torchao.prototype.moe_training.utils import (
+    generate_jagged_offs,
     torch_to_float8_per_group_colwise,
     torch_to_float8_per_group_rowwise,
 )
@@ -40,6 +41,8 @@ class ExperimentConfig:
 class ExperimentResult:
     torch_time_us: float
     triton_time_us: float
+    torch_mem_bw_gbps: float
+    triton_mem_bw_gbps: float
 
 
 @dataclass(frozen=True)
@@ -50,7 +53,7 @@ class Experiment:
 
 def get_configs() -> List[ExperimentConfig]:
     input_shapes = [(16640, 5120)]  # (Mg, K)
-    n_groups_list = [16, 128]
+    n_groups_list = [1, 16, 128]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for input_shape, n_groups, high_precision_dtype in itertools.product(
@@ -81,15 +84,9 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     # that occurs in the backward pass of the differentiable scaled grouped mm.
     # - the transposed tensor in col-major format with groups along the row dimension,
     # which represents the right operand.
-    group_size = input_row_major.shape[1] // config.n_groups
     n_groups = config.n_groups
-    offs = torch.arange(
-        group_size,
-        group_size * n_groups + 1,
-        group_size,
-        device=device,
-        dtype=torch.int32,
-    )
+    Mg = input_row_major.shape[0]
+    offs = generate_jagged_offs(n_groups, Mg, multiple_of=16)
 
     def warmup(func, *args, **kwargs):
         for _ in range(10):
@@ -140,9 +137,21 @@ def run_triton(
         run_triton, input_row_major, input_col_major, offs
     )
 
+    # mem bw calculations - excluding scales to simplify the calculation
+    # while still getting an accurate estimate.
+    bytes_per_input_el = torch.finfo(config.high_precision_dtype).bits / 8
+    num_elements = input_tensor.numel()
+    read_bytes = num_elements * bytes_per_input_el
+    write_bytes = num_elements  # 1 byte per element in float8_e4m3fn
+    read_write_bytes = read_bytes + write_bytes
+    torch_mem_bw_gbps = read_write_bytes / (torch_time_us / 1e6) / 1e9
+    triton_mem_bw_gbps = read_write_bytes / (triton_time_us / 1e6) / 1e9
+
     return ExperimentResult(
         torch_time_us=torch_time_us,
         triton_time_us=triton_time_us,
+        torch_mem_bw_gbps=torch_mem_bw_gbps,
+        triton_mem_bw_gbps=triton_mem_bw_gbps,
     )
 
 
@@ -153,6 +162,8 @@ def print_results(experiments: List[Experiment]):
         "high_precision_dtype",
         "torch_time_us",
         "triton_time_us",
+        "torch_mem_bw_gbps",
+        "triton_mem_bw_gbps",
         "triton_speedup",
     ]
     rows = []
@@ -167,6 +178,8 @@
                 experiment.config.high_precision_dtype,
                 experiment.result.torch_time_us,
                 experiment.result.triton_time_us,
+                round(experiment.result.torch_mem_bw_gbps, 3),
+                round(experiment.result.triton_mem_bw_gbps, 3),
                 f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
             ]
         )
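Note: the move from `torch.arange` to `generate_jagged_offs` above swaps uniform group boundaries for randomized, MoE-realistic ones (and fixes the old offsets, which were derived from `shape[1]` rather than the token dimension `Mg`). The offsets are cumulative group-end indices along `Mg`, with each group size a multiple of 16. A rough sketch of the difference, with illustrative group sizes (the real helper randomizes them):

```python
import torch

Mg, n_groups = 160, 4

# Old behavior: uniform groups of Mg / n_groups tokens each.
group_size = Mg // n_groups
uniform_offs = torch.arange(group_size, Mg + 1, group_size, dtype=torch.int32)
# tensor([ 40,  80, 120, 160], dtype=torch.int32)

# Jagged behavior: uneven group sizes (multiples of 16) that still sum to Mg.
sizes = torch.tensor([16, 64, 32, 48], dtype=torch.int32)  # illustrative split
jagged_offs = torch.cumsum(sizes, dim=0).to(torch.int32)
# tensor([ 16,  80, 112, 160], dtype=torch.int32)
```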
diff --git a/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py b/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py
index 53518ba491..54bfab6764 100644
--- a/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py
+++ b/benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py
@@ -37,6 +37,8 @@ class ExperimentConfig:
 class ExperimentResult:
     torch_time_us: float
     triton_time_us: float
+    torch_mem_bw_gbps: float
+    triton_mem_bw_gbps: float
 
 
 @dataclass(frozen=True)
@@ -48,8 +50,12 @@ class Experiment:
 
 def get_configs() -> List[ExperimentConfig]:
     # Llama4 shapes
     input_shapes = [
+        (1, 8192, 5120),  # w1, w3
+        (1, 5120, 8192),  # w2
         (16, 8192, 5120),  # w1, w3
         (16, 5120, 8192),  # w2
+        (128, 8192, 5120),  # w1, w3
+        (128, 5120, 8192),  # w2
     ]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
@@ -110,9 +116,25 @@ def run_triton(input_tensor: torch.Tensor):
         input_tensor,
     )
 
+    # mem bw calculations - excluding scales to simplify the calculation
+    # while still getting an accurate estimate.
+    bytes_per_input_el = torch.finfo(config.high_precision_dtype).bits / 8
+    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
+    num_elements = input_tensor.numel()
+
+    read_bytes = num_elements * bytes_per_input_el
+    write_bytes = num_elements * bytes_per_output_el
+
+    # Both the torch.compile codegen and the triton kernel read the input tensor twice
+    # (once for the scale calculations, once for scaling + casting).
+    torch_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (torch_time_us / 1e6)
+    triton_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (triton_time_us / 1e6)
+
     return ExperimentResult(
         torch_time_us=torch_time_us,
         triton_time_us=triton_time_us,
+        torch_mem_bw_gbps=torch_mem_bw_gbps,
+        triton_mem_bw_gbps=triton_mem_bw_gbps,
     )
 
 
@@ -121,6 +143,8 @@ def print_results(experiments: List[Experiment]):
         "input_shape",
         "torch_time_us",
         "triton_time_us",
+        "torch_mem_bw_gbps",
+        "triton_mem_bw_gbps",
         "triton_speedup",
     ]
     rows = []
@@ -131,6 +155,8 @@
                 input_shape,
                 experiment.result.torch_time_us,
                 experiment.result.triton_time_us,
+                round(experiment.result.torch_mem_bw_gbps, 3),
+                round(experiment.result.triton_mem_bw_gbps, 3),
                 f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
             ]
         )
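As a sanity check on the effective-bandwidth convention above (input counted twice, scales ignored), here is the same arithmetic worked for the (16, 8192, 5120) w1/w3 shape; the timing is made up purely for illustration:

```python
import torch

E, N, K = 16, 8192, 5120  # w1/w3 shape from get_configs()
num_elements = E * N * K  # 671,088,640

read_bytes = num_elements * torch.finfo(torch.bfloat16).bits / 8        # ~1.342 GB
write_bytes = num_elements * torch.finfo(torch.float8_e4m3fn).bits / 8  # ~0.671 GB

# Input is read twice (amax pass + scale/cast pass), so total traffic
# is 2 * reads + writes.
traffic_gb = (read_bytes * 2 + write_bytes) / 1e9  # ~3.36 GB

example_time_us = 1500.0  # hypothetical kernel time, for illustration only
print(f"{traffic_gb / (example_time_us / 1e6):.0f} GB/s")  # ~2237 GB/s
```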
diff --git a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py
index e95f4293be..03e56d0e96 100644
--- a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py
+++ b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py
@@ -12,7 +12,7 @@
 import torch
 from tabulate import tabulate
 from tqdm import tqdm
-from utils import bench_fwd_bwd_microseconds, profile_fn
+from utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
@@ -46,8 +46,9 @@ class Experiment:
 
 
 def get_configs() -> List[ExperimentConfig]:
+    # Llama4 shapes
     A_shapes = [(16640, 5120)]
-    B_shapes = [(16, 8192, 5120)]
+    B_shapes = [(1, 8192, 5120), (16, 8192, 5120), (128, 8192, 5120)]
     recipes = [MoEScalingType.FP8_ROWWISE]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
@@ -91,7 +92,8 @@ def run_experiment(
     # - the transposed tensor in col-major format with groups along the row dimension,
     # which represents the right operand.
     n_groups = config.B_shape[0]
-    offs = generate_jagged_offs(n_groups, A.shape[0], multiple_of=16)
+    Mg = A.shape[0]
+    offs = generate_jagged_offs(n_groups, Mg, multiple_of=16)
 
     labels = torch.ones(
         (A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16
@@ -107,7 +109,7 @@ def run_experiment(
         use_compile=args.compile,
     )
     if args.profile:
-        profile_fn(
+        profile_fwd_bwd(
             torch._grouped_mm,
             A,
             B_t,
@@ -128,7 +130,7 @@ def run_experiment(
         use_compile=args.compile,
     )
     if args.profile:
-        profile_fn(
+        profile_fwd_bwd(
             _scaled_grouped_mm,
             A,
             B_t,
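For readers new to the `offs` convention used by these grouped GEMM benchmarks: each entry is the cumulative end row of one group in `A`, and group `i` is multiplied by expert weight `B_t[i]`. A minimal emulation of that contract (a sketch for intuition, not the benchmarked kernels):

```python
import torch

E, M_per_group, K, N = 4, 40, 64, 32
A = torch.randn(E * M_per_group, K, dtype=torch.bfloat16)
B_t = torch.randn(E, K, N, dtype=torch.bfloat16)
offs = torch.arange(M_per_group, E * M_per_group + 1, M_per_group, dtype=torch.int32)

# Slice A at the cumulative offsets; matmul each chunk with its expert's weights.
out_chunks, start = [], 0
for i, end in enumerate(offs.tolist()):
    out_chunks.append(A[start:end] @ B_t[i])
    start = end
out = torch.cat(out_chunks)  # (E * M_per_group, N)
```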
diff --git a/benchmarks/prototype/moe_training/utils.py b/benchmarks/prototype/moe_training/utils.py
index 13f0dc9c6e..b880db7b32 100644
--- a/benchmarks/prototype/moe_training/utils.py
+++ b/benchmarks/prototype/moe_training/utils.py
@@ -23,7 +23,7 @@ def bench_fwd_bwd_microseconds(
     return statistics.median(times)
 
 
-def profile_fn(
+def profile_fwd_bwd(
     fn,
     *args,
     labels=None,
diff --git a/torchao/prototype/moe_training/kernels/float8_rowwise.py b/torchao/prototype/moe_training/kernels/float8_rowwise.py
index 5c084ca1b5..3f72aecebe 100644
--- a/torchao/prototype/moe_training/kernels/float8_rowwise.py
+++ b/torchao/prototype/moe_training/kernels/float8_rowwise.py
@@ -26,10 +26,10 @@
     torch.float64: tl.float64,
 }
 
-block_sizes_n = [32, 128, 512]  # large dim (output_features)
-block_sizes_k = [32, 128, 512]  # small dim (input_features)
-num_warps = [8]
-num_stages = [2, 4]
+block_sizes_n = [32, 128, 256]  # large dim (output_features)
+block_sizes_k = [32, 128, 256]  # small dim (input_features)
+num_warps = [2, 4]
+num_stages = [2, 3, 4, 5, 6]
 kernel_configs_2D = [
     triton.Config(
         {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k},
@@ -176,9 +176,18 @@ def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
         input_dtype
     )
 
-    # compute scales with local amax, using axis=0 because for each expert,
-    # we are reading the non-transposed input, and want to compute the scales
-    # along axis=1 for the transposed input.
+    # In a plain torch implementation, we would transpose the tensor and then compute
+    # the amax along dim 1 (N) to get colwise scales for the RHS operand of a scaled grouped gemm:
+    #   input_data = input_data.transpose(-2, -1)            # (E, K, N) -> (E, N, K)
+    #   amaxes = input_data.abs().amax(dim=1, keepdim=True)  # (E, N, K) -> (E, 1, K)
+    #
+    # Here, we read a (K, N) chunk for a given E and compute the amax along dim 1 (N),
+    # giving an equivalent scale of shape (K,) for this chunk of the expert.
+    # We then use an atomic min to combine chunk scales into the final scale for these
+    # logical columns of the transposed tensor.
+    #
+    # Later, we will use this scale to cast the same (K, N) input chunk to fp8 and transpose
+    # it to (N, K) before writing it to the output tensor: ((K, N) * (K, 1))^T = (N, K)
     amaxes = tl.max(tl.abs(input_data), axis=1).to(tl.float64)  # (K,)
     scales = (fp8_dtype_max / tl.clamp(amaxes, min=EPS, max=float("inf"))).to(
         tl.float32
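To complement the rewritten kernel comment: a runnable torch rendering of the transpose-then-rowwise-amax scaling it describes, which the kernel computes chunk-by-chunk with atomics (a sketch for intuition only, not the kernel's actual code path; `EPS` is an assumed epsilon mirroring the kernel's clamp on amax):

```python
import torch

E, K, N = 2, 64, 128
x = torch.randn(E, K, N, dtype=torch.bfloat16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
EPS = 1e-12  # assumed epsilon, for illustration

xt = x.transpose(-2, -1)                     # (E, K, N) -> (E, N, K)
amaxes = xt.abs().amax(dim=1, keepdim=True)  # (E, 1, K): one amax per output column
scales = fp8_max / amaxes.to(torch.float64).clamp(min=EPS)

# Scale, clamp to the fp8 range, and cast: a colwise-scaled RHS operand.
y = (xt.to(torch.float32) * scales.to(torch.float32)).clamp(-fp8_max, fp8_max)
y = y.to(torch.float8_e4m3fn)                # (E, N, K)
```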