
Commit 4cf9b53

Implement straggler detection function, generating report periodically.

1 parent 0b534e5

8 files changed: +644 −3 lines

.github/workflows/functional-tests-nvidia.yml

Lines changed: 4 additions & 0 deletions
@@ -49,22 +49,26 @@ jobs:
           git config --global --add safe.directory /__w/FlagScale/FlagScale
           if [ "${{ inputs.type }}" = "train" ] || [ "${{ inputs.type }}" = "hetero_train" ]; then
+            source /root/miniconda3/bin/activate flagscale-train
             PYTHONPATH=./:$PYTHONPATH pip install . --no-build-isolation --verbose --config-settings=device="gpu" --config-settings=backend="Megatron-LM"
             if [ "${{ inputs.task }}" = "llava_onevision" ]; then
               PYTHONPATH=./:$PYTHONPATH pip install . --no-build-isolation --verbose --config-settings=device="gpu" --config-settings=backend="Megatron-Energon"
               cp -r third_party/Megatron-Energon/src/megatron/energon third_party/Megatron-LM/megatron
             fi
+            conda deactivate
           elif [ "${{ inputs.type }}" = "inference" ] || [ "${{ inputs.type }}" = "serve" ]; then
             source /root/miniconda3/bin/activate flagscale-inference
             pip install scikit-build scikit-build-core
             pip install git+https://github.com/FlagOpen/[email protected]
             PYTHONPATH=./:$PYTHONPATH pip install . --config-settings=backend="vllm" --verbose --no-build-isolation
             conda deactivate
           elif [ "${{ inputs.type }}" = "rl" ]; then
+            source /root/miniconda3/bin/activate flagscale-RL
             python tools/patch/unpatch.py --backend verl
             cd third_party/verl
             pip install --no-deps -e .
             cd ../..
+            conda deactivate
           else
             echo "Unknown backend type: ${{ inputs.type }}"
             exit 1

.github/workflows/unit-tests-nvidia.yml

Lines changed: 6 additions & 0 deletions
@@ -55,6 +55,12 @@ jobs:
           git config --global --add safe.directory /__w/FlagScale/FlagScale
           if [ "${{ inputs.backend }}" = "megatron" ] || [ "${{ inputs.backend }}" = "flagscale" ]; then
             echo ""
+            source /root/miniconda3/bin/activate flagscale-train
+            git clone https://github.com/NVIDIA/nvidia-resiliency-ext
+            cd nvidia-resiliency-ext
+            pip install .
+            cd ..
+            conda deactivate
             # PYTHONPATH=./:$PYTHONPATH pip install . --config-settings=backend="Megatron-LM" --verbose --no-build-isolation
           elif [ "${{ inputs.backend }}" = "vllm" ]; then
             source /root/miniconda3/bin/activate flagscale-inference

flagscale/backends/Megatron-LM/megatron/core/pipeline_parallel/schedules.py

Lines changed: 22 additions & 2 deletions
@@ -22,6 +22,7 @@
     nvtx_range_pop,
     nvtx_range_push,
 )
+from flagscale.train.straggler_detection import StragglerDetectionWrapper

 # Types
 Shape = Union[List[int], torch.Size]

@@ -184,6 +185,7 @@ def set_current_microbatch(model, microbatch_id):
             layer.current_microbatch = microbatch_id


+@StragglerDetectionWrapper(level=2, section_name="microbatch_forward")
 def forward_step(
     forward_step_func,
     data_iterator,

@@ -368,6 +370,7 @@ def forward_step(
     return [output_tensor], num_tokens


+@StragglerDetectionWrapper(level=2, section_name="microbatch_backward")
 def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
     """Backward step through passed-in output tensor.

@@ -485,6 +488,9 @@ def forward_backward_no_pipelining(
         adjust_tensor_shapes_fn is None
     ), "adjust_tensor_shapes_fn is not supported for non-pipeline-parallel schedule"

+    from megatron.training.global_vars import get_args
+    args = get_args()
+
     config = get_model_config(model)
     if config.timers is not None:
         config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

@@ -511,10 +517,16 @@ def forward_backward_no_pipelining(
                 collect_non_loss_data,
                 is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
                 current_microbatch=i,
+                user_specified_level=args.straggler_detection_level,
+                passed_warmup_stage=args.curr_iteration > args.straggler_detection_warmup_iterations,
             )
             total_num_tokens += num_tokens
             if not forward_only:
-                backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
+                backward_step(
+                    input_tensor, output_tensor, output_tensor_grad, model_type, config,
+                    user_specified_level=args.straggler_detection_level,
+                    passed_warmup_stage=args.curr_iteration > args.straggler_detection_warmup_iterations,
+                )

     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).

@@ -531,11 +543,19 @@ def forward_backward_no_pipelining(
             first_val_step, forward_only, num_microbatches == 1
         ),
         current_microbatch=num_microbatches - 1,
+        user_specified_level=args.straggler_detection_level,
+        passed_warmup_stage=args.curr_iteration > args.straggler_detection_warmup_iterations,
+        generate_report=forward_only and (args.curr_iteration % args.straggler_detection_interval) == 0,
     )
     total_num_tokens += num_tokens

     if not forward_only:
-        backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
+        backward_step(
+            input_tensor, output_tensor, output_tensor_grad, model_type, config,
+            user_specified_level=args.straggler_detection_level,
+            passed_warmup_stage=args.curr_iteration > args.straggler_detection_warmup_iterations,
+            generate_report=not forward_only and (args.curr_iteration % args.straggler_detection_interval) == 0,
+        )

     if config.finalize_model_grads_func is not None and not forward_only:
         # Finalize model grads (perform full grad all-reduce / reduce-scatter for
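Note: the StragglerDetectionWrapper itself is defined in flagscale/train/straggler_detection.py, one of the eight files in this commit but not shown in this excerpt. From the call sites above it appears to take level and section_name at construction, strip the extra user_specified_level, passed_warmup_stage, and generate_report keyword arguments before calling the wrapped function, and emit a report periodically. The following is only a minimal sketch under those assumptions; the timing and gathering logic is illustrative, not the committed implementation.

import functools
import time

import torch


class StragglerDetectionWrapper:
    """Illustrative sketch only: time the wrapped section on every rank and,
    when a report is requested, gather the averages so rank 0 can flag outliers."""

    def __init__(self, level, section_name):
        self.level = level              # granularity this section belongs to (2 = per train section)
        self.section_name = section_name
        self._elapsed = []

    def __call__(self, fn):
        @functools.wraps(fn)
        def wrapped(*args, user_specified_level=0, passed_warmup_stage=False,
                    generate_report=False, **kwargs):
            # Only measure if the user enabled at least this granularity and warmup is over.
            if not (passed_warmup_stage and user_specified_level >= self.level):
                return fn(*args, **kwargs)

            # Wall-clock timing is a simplification; real code would use CUDA events.
            start = time.perf_counter()
            result = fn(*args, **kwargs)
            self._elapsed.append(time.perf_counter() - start)

            if generate_report and torch.distributed.is_initialized():
                world_size = torch.distributed.get_world_size()
                avg = sum(self._elapsed) / len(self._elapsed)
                per_rank = [None] * world_size
                torch.distributed.all_gather_object(per_rank, avg)
                if torch.distributed.get_rank() == 0:
                    mean = sum(per_rank) / world_size
                    slow = [(r, t) for r, t in enumerate(per_rank) if t > 1.5 * mean]
                    print(f"[straggler] {self.section_name}: mean {mean:.4f}s, slow ranks {slow}")
                self._elapsed.clear()
            return result

        return wrapped

A production implementation would more plausibly build on CUDA events or the nvidia-resiliency-ext package installed by the workflow change above, rather than wall-clock timing.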

flagscale/backends/Megatron-LM/megatron/training/arguments.py

Lines changed: 10 additions & 0 deletions
@@ -1666,6 +1666,16 @@ def _add_ft_package_args(parser):
     group.add_argument('--calc-ft-timeouts', action='store_true',
                        help='If set, FT package will try to automatically compute the timeouts. '
                             'Note: This feature is for Nvidia internal use only.')
+    group.add_argument('--straggler-detection-level', type=int,
+                       default=0, choices=range(0, 3),
+                       help='Granularity of straggler detection.'
+                            ' 0: off.'
+                            ' 1: per train step.'
+                            ' 2: per train section.')
+    group.add_argument('--straggler-detection-interval', type=int, default=10,
+                       help='Interval in iterations for generating the detection report.')
+    group.add_argument('--straggler-detection-warmup-iterations', type=int, default=50,
+                       help='Number of warmup iterations to skip before straggler statistics are collected.')
     return parser
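As a quick standalone illustration of how the three new flags interact (the sample values below are arbitrary, and Megatron's argument-group plumbing is omitted):

import argparse

# Mirror of the three options added in this diff, runnable outside Megatron.
parser = argparse.ArgumentParser()
parser.add_argument('--straggler-detection-level', type=int, default=0, choices=range(0, 3),
                    help='0: off, 1: per train step, 2: per train section.')
parser.add_argument('--straggler-detection-interval', type=int, default=10,
                    help='Iterations between detection reports.')
parser.add_argument('--straggler-detection-warmup-iterations', type=int, default=50,
                    help='Warmup iterations skipped before statistics are collected.')

args = parser.parse_args(['--straggler-detection-level', '2',
                          '--straggler-detection-interval', '20'])

# Per-section detection enabled, a report every 20 iterations,
# and the first 50 iterations (the default warmup) ignored.
print(args.straggler_detection_level,
      args.straggler_detection_interval,
      args.straggler_detection_warmup_iterations)  # -> 2 20 50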
