11 changes: 7 additions & 4 deletions tests/pytorch/debug/test_perf.py
@@ -28,13 +28,15 @@ def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs)
         model = torch.nn.Sequential(
             te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
         ).cuda()
-        NUM_ITERS = 18000
+        NUM_ITERS = 1800
     elif layer == "transformer":
         model = torch.nn.Sequential(
             te.TransformerLayer(1, 1, 1, name="transformer1"),
             te.TransformerLayer(1, 1, 1, name="transformer2"),
         ).cuda()
-        NUM_ITERS = 2000
+        NUM_ITERS = 200
+
+    NUM_INVOCATIONS_PER_ITER = 10
 
     x = torch.randn(1, 1, 1).cuda()

@@ -45,8 +47,9 @@ def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs)

     time_start = time.time()
     for i in range(NUM_ITERS):
-        y = model(x)
-        y.sum().backward()
+        for _ in range(NUM_INVOCATIONS_PER_ITER):
+            y = model(x)
+            y.sum().backward()
         if debug_tools_initialized:
             debug_api.step()
     torch.cuda.synchronize()
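Note (not part of the diff): the timed loop now performs NUM_INVOCATIONS_PER_ITER forward/backward calls of the same modules before each debug_api.step(), while NUM_ITERS drops tenfold (18000 to 1800, 2000 to 200), so the total number of invocations stays the same but every debug iteration now contains repeated invocations of each module. A minimal sketch of this measurement pattern is below; the helper name time_loop, its arguments, and the per-invocation metric are illustrative, not part of the test.

    import time
    import torch

    def time_loop(model, x, num_iters, invocations_per_iter, step_fn=None):
        """Time the CPU-side cost of many tiny forward/backward passes."""
        start = time.time()
        for _ in range(num_iters):
            for _ in range(invocations_per_iter):
                y = model(x)
                y.sum().backward()
            if step_fn is not None:
                step_fn()  # e.g. debug_api.step(), which advances the debug iteration counter
        torch.cuda.synchronize()
        elapsed = time.time() - start
        # Per-invocation CPU overhead: useful when comparing runs with and without debug tools.
        return elapsed, elapsed / (num_iters * invocations_per_iter)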
8 changes: 7 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -1514,7 +1514,13 @@ def is_debug_iter(self) -> bool:
                 debug = False
             else:
                 debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
-        self.debug_last_iteration = TEDebugState.get_iteration()
+            self.debug_last_iteration = TEDebugState.get_iteration()
+            self.debug_enabled_in_this_iteration = debug
+        else:
+            # If this is the same iteration as the previous invocation of the module,
+            # we use the debug value from the first invocation in the iteration.
+            debug = self.debug_enabled_in_this_iteration
 
         return debug
 
     def no_debug_features_active(self, quantizers):
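Note (not part of the diff): with this change, a module computes whether debug features should run only on its first invocation in a debug iteration, caches that decision, and reuses it for every further invocation until the iteration counter advances (i.e. until debug_api.step() is called). A self-contained sketch of that caching pattern is below; FakeState and FakeModule are hypothetical stand-ins for TEDebugState and a TE module, the decision rule and the schedule value are arbitrary examples, and none of this reproduces the actual TE implementation.

    class FakeState:
        # Stand-in for TEDebugState; the real counter advances on debug_api.step().
        iteration = 0

    class FakeModule:
        def __init__(self):
            self.debug_last_iteration = None
            self.debug_enabled_in_this_iteration = False
            self.next_iter_when_debug_should_be_run = 2  # arbitrary example schedule

        def is_debug_iter(self):
            if FakeState.iteration != self.debug_last_iteration:
                # First invocation in this iteration: compute and cache the decision.
                debug = FakeState.iteration >= self.next_iter_when_debug_should_be_run
                self.debug_last_iteration = FakeState.iteration
                self.debug_enabled_in_this_iteration = debug
            else:
                # Later invocations in the same iteration reuse the cached decision.
                debug = self.debug_enabled_in_this_iteration
            return debug

    m = FakeModule()
    for it in range(4):
        FakeState.iteration = it
        # Three invocations per iteration all return the same cached value.
        print(it, [m.is_debug_iter() for _ in range(3)])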