diff --git a/test/test_pipeline_schedule.py b/test/test_pipeline_schedule.py index 427cfb008..e24a05096 100644 --- a/test/test_pipeline_schedule.py +++ b/test/test_pipeline_schedule.py @@ -299,9 +299,6 @@ def test_1f1b(self): @skip_if_lt_x_gpu(4) def test_interleaved_1f1b(self): - # TODO: not working - return - device = torch.device(f"cuda:{self.rank}") dist.init_process_group( init_method=self.init_method, @@ -320,15 +317,19 @@ def test_interleaved_1f1b(self): microbatches = [ (torch.randn_like(microbatch),) for _ in range(num_microbatches) ] + target_mbs = [ + torch.randn_like(microbatch) for _ in range(num_microbatches) + ] + loss_fn = torch.nn.MSELoss() schedule = ScheduleInterleaved1F1B( stages, num_microbatches, + loss_fn=loss_fn, ) - schedule.step_microbatches(microbatches) + schedule.step_microbatches(microbatches, target_mbs=target_mbs) # num local pipeline stages == world_size - num_microbatches = 8 stages = self._create_virtual_pipeline_stages( model, microbatch, @@ -336,22 +337,13 @@ def test_interleaved_1f1b(self): self.world_size, num_microbatches=num_microbatches, ) - microbatches = [ - torch.randn_like(microbatch) for _ in range(num_microbatches) - ] schedule = ScheduleInterleaved1F1B( stages, num_microbatches, + loss_fn=loss_fn, ) - schedule.step_microbatches(microbatches) - - # differing microbatch size - num_microbatches = 64 - microbatches = [ - torch.randn_like(microbatch) for _ in range(num_microbatches) - ] - schedule.step_microbatches(microbatches) + schedule.step_microbatches(microbatches, target_mbs=target_mbs) def test_interleaved_1f1b_negative(self): device = torch.device("cpu")