Skip to content

Commit 9a57828

Browse files
haocizhang and pytorchmergebot
authored and committed
Implemented flexible PP schedule (pytorch#129597)
Enabled some cases to work where num_microbatches % pp_size != 0. Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and it works as long as n_microbatches % num_rounds is 0. As a few examples, support pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0. Moved over from PiPPy (pytorch/PiPPy#1129) Tested using the config in (1), schedule looks like the following graph: ``` =========== ALL_RANK_ACTIONS =========== Rank 0 Rank 1 Rank 2 Rank 3 Step 00: F0_s0 None None None Step 01: F1_s0 F0_s1 None None Step 02: F2_s0 F1_s1 F0_s2 None Step 03: F3_s0 F2_s1 F1_s2 F0_s3 Step 04: F4_s0 F3_s1 F2_s2 F1_s3 Step 05: F0_s4 F4_s1 F3_s2 F2_s3 Step 06: F1_s4 F0_s5 F4_s2 F3_s3 Step 07: F2_s4 F1_s5 F0_s6 F4_s3 Step 08: F3_s4 F2_s5 F1_s6 F0_s7 Step 09: F4_s4 F3_s5 None B0_s7 Step 10: F5_s0 None F2_s6 F1_s7 Step 11: None None B0_s6 B1_s7 Step 12: None F4_s5 F3_s6 F2_s7 Step 13: None B0_s5 B1_s6 B2_s7 Step 14: F6_s0 F5_s1 F4_s6 F3_s7 Step 15: B0_s4 B1_s5 B2_s6 B3_s7 Step 16: F7_s0 F6_s1 F5_s2 F4_s7 Step 17: B1_s4 B2_s5 B3_s6 B4_s7 Step 18: F8_s0 F7_s1 F6_s2 F5_s3 Step 19: B2_s4 B3_s5 B4_s6 B0_s3 Step 20: F9_s0 F8_s1 F7_s2 F6_s3 Step 21: B3_s4 B4_s5 B0_s2 B1_s3 Step 22: F5_s4 F9_s1 F8_s2 F7_s3 Step 23: B4_s4 B0_s1 B1_s2 B2_s3 Step 24: F6_s4 F5_s5 F9_s2 F8_s3 Step 25: B0_s0 B1_s1 B2_s2 B3_s3 Step 26: F7_s4 F6_s5 F5_s6 F9_s3 Step 27: B1_s0 B2_s1 B3_s2 B4_s3 Step 28: F8_s4 F7_s5 F6_s6 F5_s7 Step 29: B2_s0 B3_s1 B4_s2 B5_s7 Step 30: F9_s4 F8_s5 F7_s6 F6_s7 Step 31: B3_s0 B4_s1 B5_s6 B6_s7 Step 32: None F9_s5 F8_s6 F7_s7 Step 33: B4_s0 B5_s5 B6_s6 B7_s7 Step 34: None None F9_s6 F8_s7 Step 35: B5_s4 B6_s5 B7_s6 B8_s7 Step 36: None None None F9_s7 Step 37: B6_s4 B7_s5 B8_s6 B9_s7 Step 38: None None None None Step 39: B7_s4 B8_s5 B9_s6 B5_s3 Step 40: None None None None Step 41: B8_s4 B9_s5 B5_s2 B6_s3 Step 42: 
None None None None Step 43: B9_s4 B5_s1 B6_s2 B7_s3 Step 44: None None None None Step 45: B5_s0 B6_s1 B7_s2 B8_s3 Step 46: None None None None Step 47: B6_s0 B7_s1 B8_s2 B9_s3 Step 48: None None None Step 49: B7_s0 B8_s1 B9_s2 Step 50: None None Step 51: B8_s0 B9_s1 Step 52: None ``` Pull Request resolved: pytorch#129597 Approved by: https://github.com/H-Huang
1 parent ac5f655 commit 9a57828

File tree

4 files changed

+211
-69
lines changed

4 files changed

+211
-69
lines changed

docs/source/distributed.pipelining.rst

+4-1
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,8 @@ You can implement your own pipeline schedule by extending one of the following t
414414
``PipelineScheduleMulti`` is for schedules that assigns multiple stages per rank.
415415

416416
For example, ``ScheduleGPipe`` and ``Schedule1F1B`` are subclasses of ``PipelineScheduleSingle``.
417-
Whereas, ``ScheduleInterleaved1F1B`` and ``ScheduleLoopedBFS`` are subclasses of ``PipelineScheduleMulti``.
417+
Whereas, ``ScheduleFlexibleInterleaved1F1B``, ``ScheduleInterleaved1F1B`` and ``ScheduleLoopedBFS``
418+
are subclasses of ``PipelineScheduleMulti``.
418419

419420

420421
API Reference
@@ -472,6 +473,8 @@ Pipeline Schedules
472473

473474
.. autoclass:: Schedule1F1B
474475

476+
.. autoclass:: ScheduleFlexibleInterleaved1F1B
477+
475478
.. autoclass:: ScheduleInterleaved1F1B
476479

477480
.. autoclass:: ScheduleLoopedBFS

test/distributed/pipelining/test_schedule.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
pipeline,
1818
PipelineStage,
1919
Schedule1F1B,
20+
ScheduleFlexibleInterleaved1F1B,
2021
ScheduleGPipe,
2122
ScheduleInterleaved1F1B,
2223
ScheduleLoopedBFS,
@@ -754,7 +755,10 @@ def _validate_pipeline_order(
754755
if len(error_msg) != 0:
755756
self.fail(f"Error at timestep {timestep}: " + ",".join(error_msg))
756757

757-
@parametrize("ScheduleClass", [ScheduleInterleaved1F1B, ScheduleLoopedBFS])
758+
@parametrize(
759+
"ScheduleClass",
760+
[ScheduleFlexibleInterleaved1F1B, ScheduleInterleaved1F1B, ScheduleLoopedBFS],
761+
)
758762
def test_pipeline_order(self, ScheduleClass):
759763
# Define a list of test cases with varying num_local_stages, num_microbatches, and group_size
760764
# These should succeed since num_microbatches % group_size == 0
@@ -783,13 +787,24 @@ def test_pipeline_order(self, ScheduleClass):
783787
# odd group_sizes
784788
(4, 6, 3),
785789
(4, 10, 5),
790+
# n_mb non divisible by group_size
791+
(2, 3, 4),
792+
(2, 4, 4),
793+
(2, 10, 4),
794+
(2, 15, 4),
786795
]
787796
for num_local_stages, num_microbatches, group_size in test_cases:
788797
with self.subTest(
789798
num_local_stages=num_local_stages,
790799
num_microbatches=num_microbatches,
791800
group_size=group_size,
792801
):
802+
only_run_in_flex_pp = num_microbatches % group_size != 0
803+
if only_run_in_flex_pp and not issubclass(
804+
ScheduleClass, ScheduleFlexibleInterleaved1F1B
805+
):
806+
continue
807+
793808
print(f"{num_local_stages=} {num_microbatches=} {group_size=}")
794809
num_stages = num_local_stages * group_size
795810
stages = [

torch/distributed/pipelining/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from ._IR import Pipe, pipe_split, pipeline, SplitPoint
33
from .schedules import (
44
Schedule1F1B,
5+
ScheduleFlexibleInterleaved1F1B,
56
ScheduleGPipe,
67
ScheduleInterleaved1F1B,
78
ScheduleLoopedBFS,
@@ -17,6 +18,7 @@
1718
"PipelineStage",
1819
"build_stage",
1920
"Schedule1F1B",
21+
"ScheduleFlexibleInterleaved1F1B",
2022
"ScheduleGPipe",
2123
"ScheduleInterleaved1F1B",
2224
"ScheduleLoopedBFS",

torch/distributed/pipelining/schedules.py

+189-67
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"PipelineScheduleSingle",
2222
"PipelineScheduleMulti",
2323
"Schedule1F1B",
24+
"ScheduleFlexibleInterleaved1F1B",
2425
"ScheduleGPipe",
2526
"ScheduleInterleaved1F1B",
2627
"ScheduleLoopedBFS",
@@ -955,6 +956,82 @@ def _calculate_single_rank_operations(self, rank):
955956
return rank_ops
956957

957958

959+
def _get_1f1b_rank_ops(
960+
n_local_stages,
961+
pp_group_size,
962+
warmup_ops,
963+
fwd_bwd_ops,
964+
cooldown_ops,
965+
rank,
966+
forward_stage_index,
967+
backward_stage_index,
968+
):
969+
# All stages start with handling microbatch 0
970+
fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
971+
bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
972+
# Store the list of operations used for that rank
973+
rank_ops: List[Optional[_Action]] = []
974+
# Pre-padding, rank starts with no-ops based on the warmup.
975+
for _ in range(rank):
976+
rank_ops.append(None)
977+
# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
978+
# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
979+
# Formula:
980+
# pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
981+
# post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
982+
# earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
983+
# warmup_ops = calculated above
984+
post_warmup_ops = (
985+
n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
986+
) - (warmup_ops + rank)
987+
988+
total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
989+
990+
for op in range(total_ops):
991+
# Warmup phase
992+
if op < warmup_ops:
993+
fwd_stage_index = forward_stage_index(op)
994+
# This will assign the current microbatch index and update it as well
995+
fwd_stage_mb_index[fwd_stage_index] = (
996+
mb_index := fwd_stage_mb_index[fwd_stage_index]
997+
) + 1
998+
rank_ops.append(
999+
_Action(_ComputationType.FORWARD, mb_index, fwd_stage_index)
1000+
)
1001+
if op == warmup_ops - 1:
1002+
# This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
1003+
rank_ops.extend([None] * post_warmup_ops)
1004+
# 1F1B Phase (forward and backward)
1005+
elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
1006+
fwd_stage_index = forward_stage_index(op)
1007+
fwd_stage_mb_index[fwd_stage_index] = (
1008+
fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
1009+
) + 1
1010+
rank_ops.append(
1011+
_Action(_ComputationType.FORWARD, fwd_mb_index, fwd_stage_index)
1012+
)
1013+
bwd_stage_index = backward_stage_index(op)
1014+
bwd_stage_mb_index[bwd_stage_index] = (
1015+
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
1016+
) + 1
1017+
rank_ops.append(
1018+
_Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index)
1019+
)
1020+
# Cooldown phase
1021+
else:
1022+
# During cooldown phase, we need steps to align with 1f1b happening in other ranks
1023+
# TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
1024+
rank_ops.append(None)
1025+
bwd_stage_index = backward_stage_index(op)
1026+
bwd_stage_mb_index[bwd_stage_index] = (
1027+
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
1028+
) + 1
1029+
rank_ops.append(
1030+
_Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index)
1031+
)
1032+
return rank_ops
1033+
1034+
9581035
class ScheduleInterleaved1F1B(PipelineScheduleMulti):
9591036
"""
9601037
The Interleaved 1F1B schedule.
@@ -1046,74 +1123,119 @@ def backward_stage_index(step):
10461123
)
10471124
return (local_index * self.pp_group_size) + rank
10481125

1049-
# Dictionary for tracking {stage index : current microbatch index}
1050-
# All stages start with handling microbatch 0
1051-
fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
1052-
bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
1126+
return _get_1f1b_rank_ops(
1127+
self.n_local_stages,
1128+
self.pp_group_size,
1129+
warmup_ops,
1130+
fwd_bwd_ops,
1131+
cooldown_ops,
1132+
rank,
1133+
forward_stage_index,
1134+
backward_stage_index,
1135+
)
10531136

1054-
# Store the list of operations used for that rank
1055-
rank_ops: List[Optional[_Action]] = []
1056-
# Pre-padding, rank starts with no-ops based on the warmup.
1057-
for _ in range(rank):
1058-
rank_ops.append(None)
10591137

1060-
# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
1061-
# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
1062-
# Formula:
1063-
# pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
1064-
# post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
1065-
# earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
1066-
# warmup_ops = calculated above
1067-
post_warmup_ops = (
1068-
self.n_local_stages * self.pp_group_size
1069-
+ 2 * (self.pp_group_size - 1 - rank)
1070-
) - (warmup_ops + rank)
1071-
1072-
for op in range(total_ops):
1073-
# Warmup phase
1074-
if op < warmup_ops:
1075-
fwd_stage_index = forward_stage_index(op)
1076-
# This will assign the current microbatch index and update it as well
1077-
fwd_stage_mb_index[fwd_stage_index] = (
1078-
mb_index := fwd_stage_mb_index[fwd_stage_index]
1079-
) + 1
1080-
rank_ops.append(
1081-
_Action(_ComputationType.FORWARD, mb_index, fwd_stage_index)
1082-
)
1083-
if op == warmup_ops - 1:
1084-
# This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
1085-
rank_ops.extend([None] * post_warmup_ops)
1086-
# 1F1B Phase (forward and backward)
1087-
elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
1088-
fwd_stage_index = forward_stage_index(op)
1089-
fwd_stage_mb_index[fwd_stage_index] = (
1090-
fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
1091-
) + 1
1092-
rank_ops.append(
1093-
_Action(_ComputationType.FORWARD, fwd_mb_index, fwd_stage_index)
1094-
)
1138+
class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
1139+
"""
1140+
The Flexible Interleaved 1F1B schedule.
10951141
1096-
bwd_stage_index = backward_stage_index(op)
1097-
bwd_stage_mb_index[bwd_stage_index] = (
1098-
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
1099-
) + 1
1100-
rank_ops.append(
1101-
_Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index)
1102-
)
1103-
# Cooldown phase
1104-
else:
1105-
# During cooldown phase, we need steps to align with 1f1b happening in other ranks
1106-
# TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
1107-
rank_ops.append(None)
1108-
bwd_stage_index = backward_stage_index(op)
1109-
bwd_stage_mb_index[bwd_stage_index] = (
1110-
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
1111-
) + 1
1112-
rank_ops.append(
1113-
_Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index)
1114-
)
1142+
This schedule is mostly similar to the interleaved 1F1B schedule.
1143+
It differs by being relaxing the requirement of num_microbatch % pp_size == 0.
1144+
Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and
1145+
it works as long as n_microbatches % num_rounds is 0. As a few examples, support
11151146
1116-
# Post padding
1117-
for _ in range(self.pp_group_size - rank - 1):
1118-
rank_ops.append(None)
1119-
return rank_ops
1147+
1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
1148+
2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
1149+
"""
1150+
1151+
def __init__(
1152+
self,
1153+
stages: List[_PipelineStageBase],
1154+
n_microbatches: int,
1155+
loss_fn: Optional[Callable] = None,
1156+
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
1157+
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
1158+
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
1159+
):
1160+
self.pp_group_size = stages[0].group_size
1161+
super().__init__(
1162+
stages=stages,
1163+
n_microbatches=n_microbatches,
1164+
loss_fn=loss_fn,
1165+
args_chunk_spec=args_chunk_spec,
1166+
kwargs_chunk_spec=kwargs_chunk_spec,
1167+
output_merge_spec=output_merge_spec,
1168+
)
1169+
self.n_local_stages = len(stages)
1170+
self.rank = stages[0].group_rank
1171+
self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
1172+
self.microbatches_per_round = n_microbatches // self.number_of_rounds
1173+
if n_microbatches % self.number_of_rounds != 0:
1174+
raise ValueError(
1175+
"Flexible Interleaved 1F1B requires the number of microbatches to be a "
1176+
f"multiple of the number of rounds ({self.number_of_rounds}), "
1177+
f"but got {n_microbatches}."
1178+
)
1179+
# 1. Create the pipeline_order (all ranks do this calculation)
1180+
# This will be used to keep track of the current state of the entire pipeline
1181+
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
1182+
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
1183+
for rank in range(self.pp_group_size):
1184+
rank_ops = self._calculate_single_rank_operations(rank)
1185+
self.pipeline_order[rank] = rank_ops
1186+
1187+
def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
1188+
def get_rank_warmup_ops(rank):
1189+
# Warms up operations for last stage
1190+
warmups_ops_last_stage = (
1191+
self.n_local_stages - 1
1192+
) * self.microbatches_per_round
1193+
# Increment warmup operations by 2 for each hop away from the last stage
1194+
warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank)
1195+
# We cannot have more warmup operations than there are number of microbatches, so cap it there
1196+
return min(warmup_ops, self._n_microbatches * self.n_local_stages)
1197+
1198+
warmup_ops = get_rank_warmup_ops(rank)
1199+
microbatch_ops = self.n_local_stages * self._n_microbatches
1200+
# fwd_bwd_ops should encompass the remaining forwards
1201+
fwd_bwd_ops = microbatch_ops - warmup_ops
1202+
# cooldown_ops should encompass the remaining backwards
1203+
cooldown_ops = microbatch_ops - fwd_bwd_ops
1204+
# total ops encompass both forward and backward ops
1205+
total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
1206+
# warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
1207+
logger.debug(
1208+
"rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
1209+
rank,
1210+
warmup_ops,
1211+
fwd_bwd_ops,
1212+
cooldown_ops,
1213+
total_ops,
1214+
)
1215+
1216+
# Calculates the stage index based on step and pp_group_size
1217+
1218+
def forward_stage_index(step):
1219+
# Get the local index from 0 to n_local_stages-1
1220+
local_index = (step // self.microbatches_per_round) % self.n_local_stages
1221+
return (local_index * self.pp_group_size) + rank
1222+
1223+
def backward_stage_index(step):
1224+
local_index = (
1225+
self.n_local_stages
1226+
- 1
1227+
- ((step - warmup_ops) // self.microbatches_per_round)
1228+
% self.n_local_stages
1229+
)
1230+
return (local_index * self.pp_group_size) + rank
1231+
1232+
return _get_1f1b_rank_ops(
1233+
self.n_local_stages,
1234+
self.pp_group_size,
1235+
warmup_ops,
1236+
fwd_bwd_ops,
1237+
cooldown_ops,
1238+
rank,
1239+
forward_stage_index,
1240+
backward_stage_index,
1241+
)

0 commit comments

Comments
 (0)