     nvtx_range_pop,
     nvtx_range_push,
 )
+from flagscale.train.straggler_detection import StragglerDetectionWrapper
 
 # Types
 Shape = Union[List[int], torch.Size]
@@ -184,6 +185,7 @@ def set_current_microbatch(model, microbatch_id):
                 layer.current_microbatch = microbatch_id
 
 
+@StragglerDetectionWrapper(level=2, section_name="microbatch_forward")
 def forward_step(
     forward_step_func,
     data_iterator,
@@ -368,6 +370,7 @@ def forward_step(
     return [output_tensor], num_tokens
 
 
+@StragglerDetectionWrapper(level=2, section_name="microbatch_backward")
 def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
     """Backward step through passed-in output tensor.
 
@@ -485,6 +488,9 @@ def forward_backward_no_pipelining(
         adjust_tensor_shapes_fn is None
     ), "adjust_tensor_shapes_fn is not supported for non-pipeline-parallel schedule"
 
+    from megatron.training.global_vars import get_args
+    args = get_args()
+
     config = get_model_config(model)
     if config.timers is not None:
         config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
@@ -511,10 +517,16 @@ def forward_backward_no_pipelining(
                 collect_non_loss_data,
                 is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
                 current_microbatch=i,
+                user_specified_level=args.straggler_detection_level,
+                passed_warmup_stage=args.curr_iteration > args.straggler_detection_warmup_iterations,
             )
             total_num_tokens += num_tokens
             if not forward_only:
-                backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
+                backward_step(
+                    input_tensor, output_tensor, output_tensor_grad, model_type, config,
+                    user_specified_level=args.straggler_detection_level,
+                    passed_warmup_stage=args.curr_iteration > args.straggler_detection_warmup_iterations,
+                )
 
     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
@@ -531,11 +543,19 @@ def forward_backward_no_pipelining(
             first_val_step, forward_only, num_microbatches == 1
         ),
         current_microbatch=num_microbatches - 1,
+        user_specified_level=args.straggler_detection_level,
+        passed_warmup_stage=args.curr_iteration > args.straggler_detection_warmup_iterations,
+        generate_report=forward_only and (args.curr_iteration % args.straggler_detection_interval) == 0,
     )
     total_num_tokens += num_tokens
 
     if not forward_only:
-        backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
+        backward_step(
+            input_tensor, output_tensor, output_tensor_grad, model_type, config,
+            user_specified_level=args.straggler_detection_level,
+            passed_warmup_stage=args.curr_iteration > args.straggler_detection_warmup_iterations,
+            generate_report=not forward_only and (args.curr_iteration % args.straggler_detection_interval) == 0,
+        )
 
     if config.finalize_model_grads_func is not None and not forward_only:
         # Finalize model grads (perform full grad all-reduce / reduce-scatter for
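
The `StragglerDetectionWrapper` itself lives in `flagscale/train/straggler_detection.py` and is not part of this diff, so its exact behaviour is not visible here. The call sites above only imply a contract: the decorator takes `level` and `section_name` at decoration time, and the wrapped step must accept and consume the extra `user_specified_level`, `passed_warmup_stage`, and `generate_report` keyword arguments without forwarding them to the underlying `forward_step`/`backward_step` logic. The sketch below is a minimal decorator that satisfies that contract; the timing and reporting details are assumptions for illustration, not the actual FlagScale implementation.

```python
import functools
import time


class StragglerDetectionWrapper:
    """Times a named section and optionally emits a periodic report.

    Illustrative sketch only. The per-call keyword arguments injected at the
    call sites (user_specified_level, passed_warmup_stage, generate_report)
    are consumed here and are not forwarded to the wrapped function.
    """

    def __init__(self, level, section_name):
        self.level = level
        self.section_name = section_name
        self.durations = []

    def __call__(self, func):
        @functools.wraps(func)
        def wrapped(
            *args,
            user_specified_level=0,
            passed_warmup_stage=False,
            generate_report=False,
            **kwargs,
        ):
            # Only measure once warmup is over and the requested detection
            # level covers this section (assumed semantics of the level flag).
            if not (passed_warmup_stage and user_specified_level >= self.level):
                return func(*args, **kwargs)

            start = time.perf_counter()
            result = func(*args, **kwargs)
            self.durations.append(time.perf_counter() - start)

            if generate_report and self.durations:
                # A real implementation would aggregate these timings across
                # ranks (e.g. via torch.distributed) to single out slow workers.
                mean = sum(self.durations) / len(self.durations)
                print(f"[{self.section_name}] {len(self.durations)} calls, "
                      f"mean {mean:.6f}s")
                self.durations.clear()
            return result

        return wrapped
```

Under this reading, `generate_report` fires at most once per detection interval, on the last microbatch of the iteration: on the final `forward_step` when running forward-only (evaluation), and on the final `backward_step` otherwise, matching the `forward_only and ...` / `not forward_only and ...` conditions in the diff.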