benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py (75 changes: 45 additions & 30 deletions)
```diff
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
 
+import argparse
 import itertools
 from dataclasses import dataclass
 from typing import List
```
```diff
@@ -15,7 +16,7 @@
 from tqdm import tqdm
 from triton.testing import do_bench
 
-from benchmarks.utils import bench_fwd_bwd_microseconds
+from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
 from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear
 
 device = torch.device("cuda")
```
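The retimed benchmarks route through `bench_fwd_bwd_microseconds`, imported from `benchmarks.utils` but not shown in this diff. A minimal sketch of what it plausibly does, inferred from the call sites below; the exact signature, the internal MSE loss, and the `torch.compile` handling are assumptions:

```python
import torch
import torch.nn.functional as F
from triton.testing import do_bench


def bench_fwd_bwd_microseconds(model, *inputs, labels=None, use_compile=False):
    # Optionally compile, then time one full forward + backward step.
    fn = torch.compile(model) if use_compile else model

    def fwd_bwd():
        out = fn(*inputs)
        loss = F.mse_loss(out, labels)  # assumed loss, mirroring fwd_bwd() below
        loss.backward()

    # do_bench reports milliseconds; scale to microseconds.
    return do_bench(fwd_bwd, return_mode="median") * 1e3
```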
```diff
@@ -71,7 +72,7 @@ def get_configs() -> List[ExperimentConfig]:
     return configs
 
 
-def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+def run_experiment(config: ExperimentConfig, profile=False, use_compile=False) -> ExperimentResult:
     M, N, K = config.m, config.n, config.k
     inputs = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
     bf16_linear = torch.nn.Linear(K, N, dtype=config.out_dtype, device="cuda")
```
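The widened signature also makes one-off runs easy to script. A hypothetical single-shape invocation, assuming `ExperimentConfig` is a dataclass over the fields read above (`m`, `n`, `k`, `out_dtype`) and that the result dataclass mirrors the timing variables:

```python
# Hypothetical one-off run; ExperimentConfig field names beyond m, n, k,
# and out_dtype are assumptions based on how run_experiment reads them.
cfg = ExperimentConfig(m=4096, n=8192, k=8192, out_dtype=torch.bfloat16)
res = run_experiment(cfg, profile=True, use_compile=True)
print(f"bf16: {res.bf16_linear_us:.1f} us")
```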
```diff
@@ -83,49 +84,59 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     )
 
     def warmup(func, *args, **kwargs):
-        for _ in range(10):
+        for _ in range(3):
             func(*args, **kwargs)
 
     def fwd_bwd(func, inputs, labels, *args, **kwargs):
         out = func(inputs, *args, **kwargs)
         loss = F.mse_loss(out, labels)
         loss.backward()
         torch.cuda.synchronize()
 
-    # Warmup then run bf16 torch.mm
+    # bfloat16 bench and profile
     labels = inputs.new_empty(M, N).fill_(1.0)
     warmup(fwd_bwd, bf16_linear, inputs, labels)
 
-    bf16_linear_us = benchmark_cuda_function_in_microseconds(
-        fwd_bwd, bf16_linear, inputs, labels
+    bf16_linear_us = bench_fwd_bwd_microseconds(
+        bf16_linear,
+        inputs,
+        labels=labels,
+        use_compile=use_compile,
     )
 
-    # Warm up then run triton bench
-    warmup(
-        fwd_bwd,
-        fp8_triton_linear,
-        inputs,
-        labels,
+    if profile:
+        print("Profiling bf16_linear")
+        profile_fwd_bwd(
+            bf16_linear,
+            inputs,
+            labels=labels,
+            profile_name="bf16_linear_profile",
+            use_compile=use_compile,
+        )
 
+    # FP8 triton bench and profile
+    fp8_triton_linear_us = bench_fwd_bwd_microseconds(
+        fp8_triton_linear,
+        inputs,
+        labels=labels,
+    )
+    if profile:
+        print("Profiling fp8_triton_linear")
+        profile_fwd_bwd(
+            fp8_triton_linear,
+            inputs,
+            labels=labels,
+            profile_name="fp8_triton_linear_profile",
+        )
 
-    warmup(
-        fwd_bwd,
-        fp8_scaled_mm_linear,
-        inputs,
-        labels,
-    )
 
+    # FP8 torch._scaled_mm bench and profile
+    fp8_scaled_mm_linear_us = bench_fwd_bwd_microseconds(
+        fp8_scaled_mm_linear,
+        inputs,
+        labels=labels,
+        use_compile=use_compile,
+    )
+    if profile:
+        print("Profiling fp8_scaled_mm_linear")
+        profile_fwd_bwd(
+            fp8_scaled_mm_linear,
+            inputs,
+            labels=labels,
+            profile_name="fp8_scaled_mm_linear_profile",
+            use_compile=use_compile,
+        )
 
     return ExperimentResult(
         bf16_linear_us=bf16_linear_us,
```
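`profile_fwd_bwd`, the second helper imported from `benchmarks.utils`, is likewise not shown in this diff. A sketch of the expected behavior, assuming it records one forward + backward step with `torch.profiler` and exports a trace named after `profile_name` (the output path and profiler options are assumptions):

```python
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile


def profile_fwd_bwd(model, *inputs, labels=None, profile_name="profile", use_compile=False):
    # Record CPU and CUDA activity for a single forward + backward step.
    fn = torch.compile(model) if use_compile else model
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        out = fn(*inputs)
        loss = F.mse_loss(out, labels)  # assumed loss, matching the bench helper
        loss.backward()
    # Viewable in chrome://tracing or Perfetto.
    prof.export_chrome_trace(f"{profile_name}.json")
```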
```diff
@@ -165,17 +176,21 @@ def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
     return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3
 
 
-def main():
+def main(args: argparse.Namespace):
     torch.random.manual_seed(123)
     configs = get_configs()
     results = []
     for config in tqdm(configs):
-        result = run_experiment(config)
+        result = run_experiment(config, profile=args.profile, use_compile=args.compile)
         results.append(Experiment(config=config, result=result))
 
     # Use Tabulate to print results
     print_results(results)
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--profile", action="store_true", help="Enable profiling")
+    parser.add_argument("--compile", action="store_true", help="Enable compilation")
+    args = parser.parse_args()
+    main(args)
```
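With the argparse wiring in place, profiling and compilation become opt-in from the command line: `python benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py` runs the plain timing sweep, while adding `--profile --compile` also emits profiler traces and runs the benches that take `use_compile` under `torch.compile`.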