@@ -44,9 +44,9 @@ def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
44
44
f"[{ stage .stage_index } ] Loss of microbatch { mb_index } : { loss } "
45
45
)
46
46
47
- def _maybe_get_loss (self , mb_index ):
47
+ def _maybe_get_loss (self , stage , mb_index ):
48
48
valid_index = 0 <= mb_index < len (self ._internal_losses )
49
- if self ._has_backward and valid_index :
49
+ if stage . is_last and self ._has_backward and valid_index :
50
50
return self ._internal_losses [mb_index ]
51
51
elif len (self ._internal_losses ) != 0 and not valid_index :
52
52
raise RuntimeError (
@@ -56,12 +56,17 @@ def _maybe_get_loss(self, mb_index):
56
56
else :
57
57
return None
58
58
59
- def _update_losses (self , losses ):
59
+ def _update_losses (self , stages , losses ):
60
60
"""
61
61
Update the losses to those in the internal state
62
62
"""
63
+ # if stages not a list turn into a list
64
+ if not isinstance (stages , list ):
65
+ stages = [stages ]
66
+ contains_last_stage = any ([stage .is_last for stage in stages ])
67
+
63
68
# Return losses if there is a container passed in
64
- if losses is not None :
69
+ if contains_last_stage and losses is not None :
65
70
if len (self ._internal_losses ) != self ._n_microbatches :
66
71
raise RuntimeError (
67
72
f"Expecting { self ._n_microbatches } losses but got { len (self ._internal_losses )} "
@@ -330,7 +335,7 @@ def step_microbatches(
330
335
for work in works .values ():
331
336
work .wait ()
332
337
333
- loss = self ._maybe_get_loss (i )
338
+ loss = self ._maybe_get_loss (self . _stage , i )
334
339
self ._stage .backward_one_chunk (loss = loss )
335
340
336
341
ops = self ._stage .get_bwd_send_ops ()
@@ -342,7 +347,7 @@ def step_microbatches(
342
347
)
343
348
344
349
# Return losses if there is a container passed in
345
- self ._update_losses (losses )
350
+ self ._update_losses (self . _stage , losses )
346
351
347
352
# Wait for all backward sends to finish
348
353
for work in bwd_sends_to_wait :
@@ -423,7 +428,7 @@ def step_microbatches(
423
428
for work in works .values ():
424
429
work .wait ()
425
430
426
- loss = self ._maybe_get_loss (bwd_mb_index )
431
+ loss = self ._maybe_get_loss (self . _stage , bwd_mb_index )
427
432
self ._stage .backward_one_chunk (loss = loss )
428
433
429
434
ops = self ._stage .get_bwd_send_ops ()
@@ -440,7 +445,7 @@ def step_microbatches(
440
445
work .wait ()
441
446
442
447
# Return losses if there is a container passed in
443
- self ._update_losses (losses )
448
+ self ._update_losses (self . _stage , losses )
444
449
445
450
446
451
class PipelineScheduleMulti (PipelineSchedule ):
@@ -553,14 +558,14 @@ def step_microbatches(
553
558
if ops :
554
559
dist .batch_isend_irecv (ops ).pop ().wait ()
555
560
556
- loss = self ._maybe_get_loss (i )
561
+ loss = self ._maybe_get_loss (stage , i )
557
562
stage .backward_one_chunk (loss = loss )
558
563
559
564
ops = stage .get_bwd_send_ops ()
560
565
if ops :
561
566
dist .batch_isend_irecv (ops )
562
567
563
- self ._update_losses (losses )
568
+ self ._update_losses (self . _stages , losses )
564
569
565
570
566
571
class ScheduleInterleaved1F1B (PipelineScheduleMulti ):
@@ -739,7 +744,7 @@ def backward_stage_local_index(step):
739
744
)
740
745
741
746
# bwd
742
- loss = self ._maybe_get_loss (bwd_mb_index )
747
+ loss = self ._maybe_get_loss (bwd_stage , bwd_mb_index )
743
748
bwd_stage .backward_one_chunk (loss = loss )
744
749
ops .extend (bwd_stage .get_bwd_send_ops ())
745
750
@@ -764,7 +769,7 @@ def backward_stage_local_index(step):
764
769
for work in works .values ():
765
770
work .wait ()
766
771
767
- loss = self ._maybe_get_loss (bwd_mb_index )
772
+ loss = self ._maybe_get_loss (bwd_stage , bwd_mb_index )
768
773
bwd_stage .backward_one_chunk (loss = loss )
769
774
770
775
ops = bwd_stage .get_bwd_send_ops ()
@@ -776,4 +781,4 @@ def backward_stage_local_index(step):
776
781
work .wait ()
777
782
778
783
# Return losses if there is a container passed in
779
- self ._update_losses (losses )
784
+ self ._update_losses (self . _stages , losses )
0 commit comments