diff --git a/test/test_pipeline_schedule.py b/test/test_pipeline_schedule.py
index 427cfb008..e24a05096 100644
--- a/test/test_pipeline_schedule.py
+++ b/test/test_pipeline_schedule.py
@@ -299,9 +299,6 @@ def test_1f1b(self):
 
     @skip_if_lt_x_gpu(4)
     def test_interleaved_1f1b(self):
-        # TODO: not working
-        return
-
         device = torch.device(f"cuda:{self.rank}")
         dist.init_process_group(
             init_method=self.init_method,
@@ -320,15 +317,20 @@ def test_interleaved_1f1b(self):
         microbatches = [
             (torch.randn_like(microbatch),) for _ in range(num_microbatches)
         ]
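+        # random targets, one per microbatch, so the schedule exercises loss_fn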
+        target_mbs = [
+            torch.randn_like(microbatch) for _ in range(num_microbatches)
+        ]
 
+        loss_fn = torch.nn.MSELoss()
         schedule = ScheduleInterleaved1F1B(
             stages,
             num_microbatches,
+            loss_fn=loss_fn,
         )
-        schedule.step_microbatches(microbatches)
+        schedule.step_microbatches(microbatches, target_mbs=target_mbs)
 
         # num local pipeline stages == world_size
-        num_microbatches = 8
         stages = self._create_virtual_pipeline_stages(
             model,
             microbatch,
@@ -336,22 +338,14 @@ def test_interleaved_1f1b(self):
             self.world_size,
             num_microbatches=num_microbatches,
         )
-        microbatches = [
-            torch.randn_like(microbatch) for _ in range(num_microbatches)
-        ]
 
         schedule = ScheduleInterleaved1F1B(
             stages,
             num_microbatches,
+            loss_fn=loss_fn,
         )
-        schedule.step_microbatches(microbatches)
-
-        # differing microbatch size
-        num_microbatches = 64
-        microbatches = [
-            torch.randn_like(microbatch) for _ in range(num_microbatches)
-        ]
-        schedule.step_microbatches(microbatches)
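+        # reuse the same microbatches and targets from the run above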
+        schedule.step_microbatches(microbatches, target_mbs=target_mbs)
 
     def test_interleaved_1f1b_negative(self):
         device = torch.device("cpu")