Commit dc2532a

apply to more modules
Signed-off-by: zhongboz <[email protected]>
1 parent: a831adb

File tree: 5 files changed, +75 -23 lines


benchmarks/linear/benchmark_linear_cpu_overhead.py

Lines changed: 46 additions & 12 deletions
@@ -6,7 +6,13 @@
 import transformer_engine.pytorch as te
 
 
-def run_once(module: torch.nn.Module, args: List[torch.Tensor], iters = 2000, use_te: bool = True, is_first_microbatch: bool = True):
+def run_once(
+    module: torch.nn.Module,
+    args: List[torch.Tensor],
+    iters=2000,
+    use_te: bool = True,
+    is_first_microbatch: bool = True,
+):
     if use_te:
         for _ in range(iters):
             module(*args, is_first_microbatch=is_first_microbatch)
@@ -45,36 +51,64 @@ def speedometer(
         gpu_times.append(gpu_elapsed)
         cpu_times.append(cpu_elapsed)
         print(
-            f"Round {round_idx+1}/{num_rounds}: GPU {gpu_elapsed/timing_iters*1000:.2f} µs, CPU {cpu_elapsed/timing_iters*1000:.2f} µs"
+            f"Round {round_idx+1}/{num_rounds}: GPU {gpu_elapsed/timing_iters*1000:.2f} µs, CPU"
+            f" {cpu_elapsed/timing_iters*1000:.2f} µs"
         )
-    print(f"Average GPU time over {num_rounds} rounds: {sum(gpu_times)/(num_rounds*timing_iters)*1000:.2f} µs")
-    print(f"Average CPU time over {num_rounds} rounds: {sum(cpu_times)/(num_rounds*timing_iters)*1000:.2f} µs")
+    print(
+        f"Average GPU time over {num_rounds} rounds:"
+        f" {sum(gpu_times)/(num_rounds*timing_iters)*1000:.2f} µs"
+    )
+    print(
+        f"Average CPU time over {num_rounds} rounds:"
+        f" {sum(cpu_times)/(num_rounds*timing_iters)*1000:.2f} µs"
+    )
 
     return sum(gpu_times) / num_rounds
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Benchmark torch.nn.Linear performance and CPU overhead.")
+    parser = argparse.ArgumentParser(
+        description="Benchmark torch.nn.Linear performance and CPU overhead."
+    )
     parser.add_argument("--hidden_size", type=int, default=3072, help="Hidden size")
     parser.add_argument("--seq_length", type=int, default=2048, help="Sequence length")
     parser.add_argument("--warmup", type=int, default=500, help="Number of warmup iterations")
-    parser.add_argument("--timing_iters", type=int, default=2000, help="Number of timing iterations per round")
+    parser.add_argument(
+        "--timing_iters", type=int, default=2000, help="Number of timing iterations per round"
+    )
    parser.add_argument("--num_rounds", type=int, default=5, help="Number of timing rounds")
     parser.add_argument(
-        "--backend", type=str, choices=["torch", "te"], default="te", help="Linear backend: torch or te"
+        "--backend",
+        type=str,
+        choices=["torch", "te"],
+        default="te",
+        help="Linear backend: torch or te",
    )
     args = parser.parse_args()
 
-    x = torch.randn((args.seq_length, args.hidden_size), dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    x = torch.randn(
+        (args.seq_length, args.hidden_size), dtype=torch.bfloat16, device="cuda", requires_grad=True
+    )
     use_te = True
     if args.backend == "torch":
-        model = torch.nn.Linear(args.hidden_size, args.hidden_size, bias=False).to(torch.bfloat16).cuda()
+        model = (
+            torch.nn.Linear(args.hidden_size, args.hidden_size, bias=False)
+            .to(torch.bfloat16)
+            .cuda()
+        )
         use_te = False
     else:
-        model = te.Linear(args.hidden_size, args.hidden_size, bias=False, device="cuda").to(torch.bfloat16)
+        model = te.Linear(args.hidden_size, args.hidden_size, bias=False, device="cuda").to(
+            torch.bfloat16
+        )
     with torch.no_grad():
         avg_gpu_time_per_round = speedometer(
-            model, [x], timing_iters=args.timing_iters, warmup_iters=args.warmup, num_rounds=args.num_rounds, use_te=use_te
+            model,
+            [x],
+            timing_iters=args.timing_iters,
+            warmup_iters=args.warmup,
+            num_rounds=args.num_rounds,
+            use_te=use_te,
         )
 
     total_ops = 2 * args.hidden_size * args.hidden_size * args.seq_length * args.timing_iters
@@ -84,4 +118,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
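
The diff shows run_once and speedometer's reporting, but not the timing loop itself. As a rough illustration of how gpu_elapsed and cpu_elapsed could be produced per round (a hypothetical sketch, not the file's actual implementation: it assumes CUDA events for device time and time.perf_counter for host launch time, both returning milliseconds so the µs conversions in the prints above work out):

    import time
    import torch

    def time_one_round(module, args, timing_iters=2000):
        # Hypothetical helper returning (gpu_elapsed, cpu_elapsed) in ms.
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start_evt.record()
        cpu_start = time.perf_counter()
        for _ in range(timing_iters):
            module(*args)
        # Host time spent launching the work, measured before synchronizing.
        cpu_elapsed = (time.perf_counter() - cpu_start) * 1000
        end_evt.record()
        torch.cuda.synchronize()
        gpu_elapsed = start_evt.elapsed_time(end_evt)  # CUDA events report ms
        return gpu_elapsed, cpu_elapsed

Based on the argument parser above, a typical invocation would be:

    python benchmarks/linear/benchmark_linear_cpu_overhead.py --backend te --timing_iters 2000 --num_rounds 5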

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 8 additions & 3 deletions
@@ -748,9 +748,14 @@ def forward(
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False
 
-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
+        if is_first_microbatch is None or is_first_microbatch:
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
+        else:
+            device_ctx = contextlib.nullcontext()
+
+        with device_ctx, self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
             weight_tensors = self._get_weight_tensors()
             bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
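
This is the same guard that linear.py already uses (see the last file below): entering torch.cuda.device requires looking up a parameter's device via named_parameters() on every forward call, which is pure CPU overhead once the module is warmed up. The commit therefore enters the context only when is_first_microbatch is None (no microbatching) or True, and substitutes a no-op contextlib.nullcontext() on later microbatches. A minimal standalone sketch of the pattern (a toy module for illustration, not Transformer Engine code):

    import contextlib
    import torch

    class TinyLinear(torch.nn.Module):
        # Toy module mimicking the device-context handling in this commit.
        def __init__(self, hidden_size):
            super().__init__()
            # The parameter is a direct attribute, so the getattr lookup
            # below resolves the name returned by named_parameters().
            self.weight = torch.nn.Parameter(
                torch.randn(hidden_size, hidden_size, device="cuda")
            )

        def forward(self, inp, is_first_microbatch=None):
            if is_first_microbatch is None or is_first_microbatch:
                # First microbatch (or no microbatching): pay the lookup
                # cost and pin the CUDA device context.
                device_ctx = torch.cuda.device(
                    getattr(self, list(self.named_parameters())[0][0]).device
                )
            else:
                # Later microbatches: a no-op context, skipping the
                # per-call CPU work.
                device_ctx = contextlib.nullcontext()
            with device_ctx:
                return inp @ self.weight.t()

The underlying assumption is that parameters do not move between devices after the first microbatch, so skipping the context later changes only host-side cost, not behavior.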

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 10 additions & 4 deletions
@@ -1502,10 +1502,16 @@ def forward(
         ).is_fp8_ubuf():
             fp8_grad = True
 
-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(
-            inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
+        if is_first_microbatch is None or is_first_microbatch:
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
+        else:
+            device_ctx = contextlib.nullcontext()
+
+        with device_ctx, self.prepare_forward(
+            inp,
+            allow_non_contiguous=False,  # removed .contiguous from inside the layer
         ) as inp:
 
             # Get concatenated weight and bias tensors

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 8 additions & 3 deletions
@@ -1764,9 +1764,14 @@ def forward(
         if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
             fp8_output = True
 
-        with torch.cuda.device(
-            getattr(self, list(self.named_parameters())[0][0]).device
-        ), self.prepare_forward(inp, num_gemms=2) as inp:
+        if is_first_microbatch is None or is_first_microbatch:
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
+        else:
+            device_ctx = contextlib.nullcontext()
+
+        with device_ctx, self.prepare_forward(inp, num_gemms=2) as inp:
 
             quantizers = (
                 self._get_quantizers(fp8_output)

transformer_engine/pytorch/module/linear.py

Lines changed: 3 additions & 1 deletion
@@ -1389,7 +1389,9 @@ def forward(
             fp8_grad = True
 
         if is_first_microbatch is None or is_first_microbatch:
-            device_ctx = torch.cuda.device(getattr(self, list(self.named_parameters())[0][0]).device)
+            device_ctx = torch.cuda.device(
+                getattr(self, list(self.named_parameters())[0][0]).device
+            )
         else:
             device_ctx = contextlib.nullcontext()
 
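
linear.py already had the conditional device context (the if/else appears as unchanged context above); this hunk only reflows the torch.cuda.device(...) call across lines. For completeness, a hedged sketch of how a caller might drive the flag during microbatched execution (the loop and sizes are hypothetical; the te.Linear usage follows the benchmark above):

    import torch
    import transformer_engine.pytorch as te

    model = te.Linear(1024, 1024, bias=False, device="cuda")
    microbatches = [torch.randn(128, 1024, device="cuda") for _ in range(4)]

    for mb_idx, mb in enumerate(microbatches):
        # Only the first microbatch of a step pays the device-context
        # lookup; with FP8 enabled, it is also when cached weight copies
        # are refreshed.
        out = model(mb, is_first_microbatch=(mb_idx == 0))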
