diff --git a/tests/pytorch/debug/test_perf.py b/tests/pytorch/debug/test_perf.py
index 2d4b62b23f..ad40c31c02 100644
--- a/tests/pytorch/debug/test_perf.py
+++ b/tests/pytorch/debug/test_perf.py
@@ -28,13 +28,15 @@ def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs)
         model = torch.nn.Sequential(
             te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
         ).cuda()
-        NUM_ITERS = 18000
+        NUM_ITERS = 1800
     elif layer == "transformer":
         model = torch.nn.Sequential(
             te.TransformerLayer(1, 1, 1, name="transformer1"),
             te.TransformerLayer(1, 1, 1, name="transformer2"),
         ).cuda()
-        NUM_ITERS = 2000
+        NUM_ITERS = 200
+
+    NUM_INVOCATIONS_PER_ITER = 10
 
     x = torch.randn(1, 1, 1).cuda()
 
@@ -45,8 +47,9 @@ def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs)
 
     time_start = time.time()
     for i in range(NUM_ITERS):
-        y = model(x)
-        y.sum().backward()
+        for _ in range(NUM_INVOCATIONS_PER_ITER):
+            y = model(x)
+            y.sum().backward()
         if debug_tools_initialized:
             debug_api.step()
     torch.cuda.synchronize()
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 3bbfaacdf5..6a0766562f 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -1514,7 +1514,13 @@ def is_debug_iter(self) -> bool:
 
                 debug = False
             else:
                 debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
-        self.debug_last_iteration = TEDebugState.get_iteration()
+            self.debug_last_iteration = TEDebugState.get_iteration()
+            self.debug_enabled_in_this_iteration = debug
+        else:
+            # If this is the same iteration as previous invocation of the module,
+            # we use the debug value from the first invocation in the iteration.
+            debug = self.debug_enabled_in_this_iteration
+        return debug
 
     def no_debug_features_active(self, quantizers):