66import transformer_engine .pytorch as te
77
88
9- def run_once (module : torch .nn .Module , args : List [torch .Tensor ], iters = 2000 , use_te : bool = True , is_first_microbatch : bool = True ):
9+ def run_once (
10+ module : torch .nn .Module ,
11+ args : List [torch .Tensor ],
12+ iters = 2000 ,
13+ use_te : bool = True ,
14+ is_first_microbatch : bool = True ,
15+ ):
1016 if use_te :
1117 for _ in range (iters ):
1218 module (* args , is_first_microbatch = is_first_microbatch )
@@ -45,36 +51,64 @@ def speedometer(
4551 gpu_times .append (gpu_elapsed )
4652 cpu_times .append (cpu_elapsed )
4753 print (
48- f"Round { round_idx + 1 } /{ num_rounds } : GPU { gpu_elapsed / timing_iters * 1000 :.2f} µs, CPU { cpu_elapsed / timing_iters * 1000 :.2f} µs"
54+ f"Round { round_idx + 1 } /{ num_rounds } : GPU { gpu_elapsed / timing_iters * 1000 :.2f} µs, CPU"
55+ f" { cpu_elapsed / timing_iters * 1000 :.2f} µs"
4956 )
50- print (f"Average GPU time over { num_rounds } rounds: { sum (gpu_times )/ (num_rounds * timing_iters )* 1000 :.2f} µs" )
51- print (f"Average CPU time over { num_rounds } rounds: { sum (cpu_times )/ (num_rounds * timing_iters )* 1000 :.2f} µs" )
57+ print (
58+ f"Average GPU time over { num_rounds } rounds:"
59+ f" { sum (gpu_times )/ (num_rounds * timing_iters )* 1000 :.2f} µs"
60+ )
61+ print (
62+ f"Average CPU time over { num_rounds } rounds:"
63+ f" { sum (cpu_times )/ (num_rounds * timing_iters )* 1000 :.2f} µs"
64+ )
5265
5366 return sum (gpu_times ) / num_rounds
5467
5568
5669def main ():
57- parser = argparse .ArgumentParser (description = "Benchmark torch.nn.Linear performance and CPU overhead." )
70+ parser = argparse .ArgumentParser (
71+ description = "Benchmark torch.nn.Linear performance and CPU overhead."
72+ )
5873 parser .add_argument ("--hidden_size" , type = int , default = 3072 , help = "Hidden size" )
5974 parser .add_argument ("--seq_length" , type = int , default = 2048 , help = "Sequence length" )
6075 parser .add_argument ("--warmup" , type = int , default = 500 , help = "Number of warmup iterations" )
61- parser .add_argument ("--timing_iters" , type = int , default = 2000 , help = "Number of timing iterations per round" )
76+ parser .add_argument (
77+ "--timing_iters" , type = int , default = 2000 , help = "Number of timing iterations per round"
78+ )
6279 parser .add_argument ("--num_rounds" , type = int , default = 5 , help = "Number of timing rounds" )
6380 parser .add_argument (
64- "--backend" , type = str , choices = ["torch" , "te" ], default = "te" , help = "Linear backend: torch or te"
81+ "--backend" ,
82+ type = str ,
83+ choices = ["torch" , "te" ],
84+ default = "te" ,
85+ help = "Linear backend: torch or te" ,
6586 )
6687 args = parser .parse_args ()
6788
68- x = torch .randn ((args .seq_length , args .hidden_size ), dtype = torch .bfloat16 , device = "cuda" , requires_grad = True )
89+ x = torch .randn (
90+ (args .seq_length , args .hidden_size ), dtype = torch .bfloat16 , device = "cuda" , requires_grad = True
91+ )
6992 use_te = True
7093 if args .backend == "torch" :
71- model = torch .nn .Linear (args .hidden_size , args .hidden_size , bias = False ).to (torch .bfloat16 ).cuda ()
94+ model = (
95+ torch .nn .Linear (args .hidden_size , args .hidden_size , bias = False )
96+ .to (torch .bfloat16 )
97+ .cuda ()
98+ )
7299 use_te = False
73100 else :
74- model = te .Linear (args .hidden_size , args .hidden_size , bias = False , device = "cuda" ).to (torch .bfloat16 )
101+ model = te .Linear (args .hidden_size , args .hidden_size , bias = False , device = "cuda" ).to (
102+ torch .bfloat16
103+ )
75104 with torch .no_grad ():
76105 avg_gpu_time_per_round = speedometer (
77- model , [x ], timing_iters = args .timing_iters , warmup_iters = args .warmup , num_rounds = args .num_rounds , use_te = use_te
106+ model ,
107+ [x ],
108+ timing_iters = args .timing_iters ,
109+ warmup_iters = args .warmup ,
110+ num_rounds = args .num_rounds ,
111+ use_te = use_te ,
78112 )
79113
80114 total_ops = 2 * args .hidden_size * args .hidden_size * args .seq_length * args .timing_iters
@@ -84,4 +118,4 @@ def main():
84118
85119
# Script entry point: run the benchmark only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
0 commit comments