Skip to content

Commit 10956e4

Browse files
committed
update loss utilities to take stage
ghstack-source-id: 3b39366 Pull Request resolved: #1077
1 parent 73d7f48 commit 10956e4

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

pippy/PipelineSchedule.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
4444
f"[{stage.stage_index}] Loss of microbatch {mb_index}: {loss}"
4545
)
4646

47-
def _maybe_get_loss(self, mb_index):
47+
def _maybe_get_loss(self, stage, mb_index):
4848
valid_index = 0 <= mb_index < len(self._internal_losses)
49-
if self._has_backward and valid_index:
49+
if stage.is_last and self._has_backward and valid_index:
5050
return self._internal_losses[mb_index]
5151
elif len(self._internal_losses) != 0 and not valid_index:
5252
raise RuntimeError(
@@ -56,12 +56,17 @@ def _maybe_get_loss(self, mb_index):
5656
else:
5757
return None
5858

59-
def _update_losses(self, losses):
59+
def _update_losses(self, stages, losses):
6060
"""
6161
Update the losses to those in the internal state
6262
"""
63+
# if stages not a list turn into a list
64+
if not isinstance(stages, list):
65+
stages = [stages]
66+
contains_last_stage = any([stage.is_last for stage in stages])
67+
6368
# Return losses if there is a container passed in
64-
if losses is not None:
69+
if contains_last_stage and losses is not None:
6570
if len(self._internal_losses) != self._n_microbatches:
6671
raise RuntimeError(
6772
f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
@@ -330,7 +335,7 @@ def step_microbatches(
330335
for work in works.values():
331336
work.wait()
332337

333-
loss = self._maybe_get_loss(i)
338+
loss = self._maybe_get_loss(self._stage, i)
334339
self._stage.backward_one_chunk(loss=loss)
335340

336341
ops = self._stage.get_bwd_send_ops()
@@ -342,7 +347,7 @@ def step_microbatches(
342347
)
343348

344349
# Return losses if there is a container passed in
345-
self._update_losses(losses)
350+
self._update_losses(self._stage, losses)
346351

347352
# Wait for all backward sends to finish
348353
for work in bwd_sends_to_wait:
@@ -423,7 +428,7 @@ def step_microbatches(
423428
for work in works.values():
424429
work.wait()
425430

426-
loss = self._maybe_get_loss(bwd_mb_index)
431+
loss = self._maybe_get_loss(self._stage, bwd_mb_index)
427432
self._stage.backward_one_chunk(loss=loss)
428433

429434
ops = self._stage.get_bwd_send_ops()
@@ -440,7 +445,7 @@ def step_microbatches(
440445
work.wait()
441446

442447
# Return losses if there is a container passed in
443-
self._update_losses(losses)
448+
self._update_losses(self._stage, losses)
444449

445450

446451
class PipelineScheduleMulti(PipelineSchedule):
@@ -553,14 +558,14 @@ def step_microbatches(
553558
if ops:
554559
dist.batch_isend_irecv(ops).pop().wait()
555560

556-
loss = self._maybe_get_loss(i)
561+
loss = self._maybe_get_loss(stage, i)
557562
stage.backward_one_chunk(loss=loss)
558563

559564
ops = stage.get_bwd_send_ops()
560565
if ops:
561566
dist.batch_isend_irecv(ops)
562567

563-
self._update_losses(losses)
568+
self._update_losses(self._stages, losses)
564569

565570

566571
class ScheduleInterleaved1F1B(PipelineScheduleMulti):
@@ -739,7 +744,7 @@ def backward_stage_local_index(step):
739744
)
740745

741746
# bwd
742-
loss = self._maybe_get_loss(bwd_mb_index)
747+
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
743748
bwd_stage.backward_one_chunk(loss=loss)
744749
ops.extend(bwd_stage.get_bwd_send_ops())
745750

@@ -764,7 +769,7 @@ def backward_stage_local_index(step):
764769
for work in works.values():
765770
work.wait()
766771

767-
loss = self._maybe_get_loss(bwd_mb_index)
772+
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
768773
bwd_stage.backward_one_chunk(loss=loss)
769774

770775
ops = bwd_stage.get_bwd_send_ops()
@@ -776,4 +781,4 @@ def backward_stage_local_index(step):
776781
work.wait()
777782

778783
# Return losses if there is a container passed in
779-
self._update_losses(losses)
784+
self._update_losses(self._stages, losses)

0 commit comments

Comments
 (0)