diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py b/benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py index e8a4785624..a46a8d3060 100644 --- a/benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py +++ b/benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py +import argparse import itertools from dataclasses import dataclass from typing import List @@ -15,7 +16,7 @@ from tqdm import tqdm from triton.testing import do_bench -from benchmarks.utils import bench_fwd_bwd_microseconds +from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear device = torch.device("cuda") @@ -71,7 +72,7 @@ def get_configs() -> List[ExperimentConfig]: return configs -def run_experiment(config: ExperimentConfig) -> ExperimentResult: +def run_experiment(config: ExperimentConfig, profile=False, use_compile=False) -> ExperimentResult: M, N, K = config.m, config.n, config.k inputs = torch.randn(M, K, dtype=config.out_dtype, device="cuda") bf16_linear = torch.nn.Linear(K, N, dtype=config.out_dtype, device="cuda") @@ -83,49 +84,59 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: ) def warmup(func, *args, **kwargs): - for _ in range(10): + for _ in range(3): func(*args, **kwargs) - def fwd_bwd(func, inputs, labels, *args, **kwargs): - out = func(inputs, *args, **kwargs) - loss = F.mse_loss(out, labels) - loss.backward() - torch.cuda.synchronize() - # Warmup then run bf16 torch.mm + # bfloat16 bench and profile labels = inputs.new_empty(M, N).fill_(1.0) - warmup(fwd_bwd, bf16_linear, inputs, labels) - - bf16_linear_us = benchmark_cuda_function_in_microseconds( - fwd_bwd, bf16_linear, inputs, labels + bf16_linear_us = bench_fwd_bwd_microseconds( + bf16_linear, + inputs, + labels=labels, + use_compile=use_compile, ) - - # Warm up then run triton bench - warmup( - fwd_bwd, - fp8_triton_linear, - inputs, - labels, + if profile: + print("Profiling bf16_linear") + profile_fwd_bwd( + bf16_linear, + inputs, + labels=labels, + profile_name="bf16_linear_profile", + use_compile=use_compile, ) + # FP8 triton bench and profile fp8_triton_linear_us = bench_fwd_bwd_microseconds( fp8_triton_linear, inputs, labels=labels, ) + if profile: + print("Profiling fp8_triton_linear") + profile_fwd_bwd( + fp8_triton_linear, + inputs, + labels=labels, + profile_name="fp8_triton_linear_profile", + ) - warmup( - fwd_bwd, - fp8_scaled_mm_linear, - inputs, - labels, - ) - + # FP8 torch._scaled_mm bench and profile fp8_scaled_mm_linear_us = bench_fwd_bwd_microseconds( fp8_scaled_mm_linear, inputs, labels=labels, + use_compile=use_compile, ) + if profile: + print("Profiling fp8_scaled_mm_linear") + profile_fwd_bwd( + fp8_scaled_mm_linear, + inputs, + labels=labels, + profile_name="fp8_scaled_mm_linear_profile", + use_compile=use_compile, + ) return ExperimentResult( bf16_linear_us=bf16_linear_us, @@ -165,12 +176,12 @@ def benchmark_cuda_function_in_microseconds(f, *args, **kwargs): return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3 -def main(): +def main(args: argparse.Namespace): torch.random.manual_seed(123) configs = get_configs() results = [] for config in tqdm(configs): - result = run_experiment(config) + result = run_experiment(config, profile=args.profile, use_compile=args.compile) results.append(Experiment(config=config, result=result)) # Use Tabulate to print results @@ -178,4 +189,8 @@ def main(): if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable profiling") + parser.add_argument("--compile", action="store_true", help="Enable compilation") + args = parser.parse_args() + main(args)