21 | 21 |     "PipelineScheduleSingle",
22 | 22 |     "PipelineScheduleMulti",
23 | 23 |     "Schedule1F1B",
   | 24 | +    "ScheduleFlexibleInterleaved1F1B",
24 | 25 |     "ScheduleGPipe",
25 | 26 |     "ScheduleInterleaved1F1B",
26 | 27 |     "ScheduleLoopedBFS",

@@ -955,6 +956,82 @@ def _calculate_single_rank_operations(self, rank):

955 | 956 |         return rank_ops
956 | 957 |
957 | 958 |
| 959 | +def _get_1f1b_rank_ops(
| 960 | +    n_local_stages,
| 961 | +    pp_group_size,
| 962 | +    warmup_ops,
| 963 | +    fwd_bwd_ops,
| 964 | +    cooldown_ops,
| 965 | +    rank,
| 966 | +    forward_stage_index,
| 967 | +    backward_stage_index,
| 968 | +):
| 969 | +    # All stages start with handling microbatch 0
| 970 | +    fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
| 971 | +    bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
| 972 | +    # Store the list of operations used for that rank
| 973 | +    rank_ops: List[Optional[_Action]] = []
| 974 | +    # Pre-padding, rank starts with no-ops based on the warmup.
| 975 | +    for _ in range(rank):
| 976 | +        rank_ops.append(None)
| 977 | +    # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
| 978 | +    # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
| 979 | +    # Formula:
| 980 | +    # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
| 981 | +    # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
| 982 | +    # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
| 983 | +    # warmup_ops = calculated above
| 984 | +    post_warmup_ops = (
| 985 | +        n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
| 986 | +    ) - (warmup_ops + rank)
| 987 | +
| 988 | +    total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
| 989 | +
| 990 | +    for op in range(total_ops):
| 991 | +        # Warmup phase
| 992 | +        if op < warmup_ops:
| 993 | +            fwd_stage_index = forward_stage_index(op)
| 994 | +            # This will assign the current microbatch index and update it as well
| 995 | +            fwd_stage_mb_index[fwd_stage_index] = (
| 996 | +                mb_index := fwd_stage_mb_index[fwd_stage_index]
| 997 | +            ) + 1
| 998 | +            rank_ops.append(
| 999 | +                _Action(_ComputationType.FORWARD, mb_index, fwd_stage_index)
| 1000 | +            )
| 1001 | +            if op == warmup_ops - 1:
| 1002 | +                # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
| 1003 | +                rank_ops.extend([None] * post_warmup_ops)
| 1004 | +        # 1F1B Phase (forward and backward)
| 1005 | +        elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
| 1006 | +            fwd_stage_index = forward_stage_index(op)
| 1007 | +            fwd_stage_mb_index[fwd_stage_index] = (
| 1008 | +                fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
| 1009 | +            ) + 1
| 1010 | +            rank_ops.append(
| 1011 | +                _Action(_ComputationType.FORWARD, fwd_mb_index, fwd_stage_index)
| 1012 | +            )
| 1013 | +            bwd_stage_index = backward_stage_index(op)
| 1014 | +            bwd_stage_mb_index[bwd_stage_index] = (
| 1015 | +                bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
| 1016 | +            ) + 1
| 1017 | +            rank_ops.append(
| 1018 | +                _Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index)
| 1019 | +            )
| 1020 | +        # Cooldown phase
| 1021 | +        else:
| 1022 | +            # During cooldown phase, we need steps to align with 1f1b happening in other ranks
| 1023 | +            # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
| 1024 | +            rank_ops.append(None)
| 1025 | +            bwd_stage_index = backward_stage_index(op)
| 1026 | +            bwd_stage_mb_index[bwd_stage_index] = (
| 1027 | +                bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
| 1028 | +            ) + 1
| 1029 | +            rank_ops.append(
| 1030 | +                _Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index)
| 1031 | +            )
| 1032 | +    return rank_ops
| 1033 | +
| 1034 | +
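
To make the padding arithmetic above concrete, here is a small editorial sketch (not part of the patch) that evaluates the same formulas for an assumed configuration; the `warmup_ops` expression mirrors what the schedule classes below compute before calling `_get_1f1b_rank_ops`:

```python
# Editorial sketch: evaluate the warmup/padding formulas for an assumed config
# (n_local_stages=2, pp_group_size=4, n_microbatches=4, so one round and
# microbatches_per_round == pp_group_size).
n_local_stages, pp_group_size, n_microbatches = 2, 4, 4
microbatch_ops = n_local_stages * n_microbatches  # total forward ops per rank

for rank in range(pp_group_size):
    # warmup_ops as computed by the callers: forwards for all-but-last local
    # stage, plus 2 per hop away from the last rank, capped at total forwards.
    warmup_ops = min(
        (n_local_stages - 1) * pp_group_size + 2 * (pp_group_size - 1 - rank),
        microbatch_ops,
    )
    # post_warmup_ops as computed inside _get_1f1b_rank_ops
    post_warmup_ops = (
        n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
    ) - (warmup_ops + rank)
    print(f"rank {rank}: warmup_ops={warmup_ops}, post_warmup_ops={post_warmup_ops}")

# Expected output:
# rank 0: warmup_ops=8, post_warmup_ops=6
# rank 1: warmup_ops=8, post_warmup_ops=3
# rank 2: warmup_ops=6, post_warmup_ops=2
# rank 3: warmup_ops=4, post_warmup_ops=1
```

Earlier ranks start their warmup sooner (less pre-padding) but wait longer for the first backward to trickle back, which is exactly what the larger `post_warmup_ops` values reflect.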
958 | 1035 | class ScheduleInterleaved1F1B(PipelineScheduleMulti):
959 | 1036 |     """
960 | 1037 |     The Interleaved 1F1B schedule.

@@ -1046,74 +1123,119 @@ def backward_stage_index(step):

1046 | 1123 |             )
1047 | 1124 |             return (local_index * self.pp_group_size) + rank
1048 | 1125 |
1049 | | -        # Dictionary for tracking {stage index : current microbatch index}
1050 | | -        # All stages start with handling microbatch 0
1051 | | -        fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
1052 | | -        bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
| 1126 | +        return _get_1f1b_rank_ops(
| 1127 | +            self.n_local_stages,
| 1128 | +            self.pp_group_size,
| 1129 | +            warmup_ops,
| 1130 | +            fwd_bwd_ops,
| 1131 | +            cooldown_ops,
| 1132 | +            rank,
| 1133 | +            forward_stage_index,
| 1134 | +            backward_stage_index,
| 1135 | +        )
1053 | 1136 |
1054 | | -        # Store the list of operations used for that rank
1055 | | -        rank_ops: List[Optional[_Action]] = []
1056 | | -        # Pre-padding, rank starts with no-ops based on the warmup.
1057 | | -        for _ in range(rank):
1058 | | -            rank_ops.append(None)
1059 | 1137 |
1060 | | -        # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
1061 | | -        # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
1062 | | -        # Formula:
1063 | | -        # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
1064 | | -        # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
1065 | | -        # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
1066 | | -        # warmup_ops = calculated above
1067 | | -        post_warmup_ops = (
1068 | | -            self.n_local_stages * self.pp_group_size
1069 | | -            + 2 * (self.pp_group_size - 1 - rank)
1070 | | -        ) - (warmup_ops + rank)
1071 | | -
1072 | | -        for op in range(total_ops):
1073 | | -            # Warmup phase
1074 | | -            if op < warmup_ops:
1075 | | -                fwd_stage_index = forward_stage_index(op)
1076 | | -                # This will assign the current microbatch index and update it as well
1077 | | -                fwd_stage_mb_index[fwd_stage_index] = (
1078 | | -                    mb_index := fwd_stage_mb_index[fwd_stage_index]
1079 | | -                ) + 1
1080 | | -                rank_ops.append(
1081 | | -                    _Action(_ComputationType.FORWARD, mb_index, fwd_stage_index)
1082 | | -                )
1083 | | -                if op == warmup_ops - 1:
1084 | | -                    # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
1085 | | -                    rank_ops.extend([None] * post_warmup_ops)
1086 | | -            # 1F1B Phase (forward and backward)
1087 | | -            elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
1088 | | -                fwd_stage_index = forward_stage_index(op)
1089 | | -                fwd_stage_mb_index[fwd_stage_index] = (
1090 | | -                    fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
1091 | | -                ) + 1
1092 | | -                rank_ops.append(
1093 | | -                    _Action(_ComputationType.FORWARD, fwd_mb_index, fwd_stage_index)
1094 | | -                )
| 1138 | +class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
| 1139 | +    """
| 1140 | +    The Flexible Interleaved 1F1B schedule.
1095 | 1141 |
1096 | | -                bwd_stage_index = backward_stage_index(op)
1097 | | -                bwd_stage_mb_index[bwd_stage_index] = (
1098 | | -                    bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
1099 | | -                ) + 1
1100 | | -                rank_ops.append(
1101 | | -                    _Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index)
1102 | | -                )
1103 | | -            # Cooldown phase
1104 | | -            else:
1105 | | -                # During cooldown phase, we need steps to align with 1f1b happening in other ranks
1106 | | -                # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
1107 | | -                rank_ops.append(None)
1108 | | -                bwd_stage_index = backward_stage_index(op)
1109 | | -                bwd_stage_mb_index[bwd_stage_index] = (
1110 | | -                    bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
1111 | | -                ) + 1
1112 | | -                rank_ops.append(
1113 | | -                    _Action(_ComputationType.BACKWARD, bwd_mb_index, bwd_stage_index)
1114 | | -                )
| 1142 | +    This schedule is mostly similar to the interleaved 1F1B schedule; it differs
| 1143 | +    by relaxing the requirement that n_microbatches % pp_group_size == 0.
| 1144 | +    Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size),
| 1145 | +    and the schedule works as long as n_microbatches % num_rounds == 0. A few supported examples:
1115 | 1146 |
1116 | | -        # Post padding
1117 | | -        for _ in range(self.pp_group_size - rank - 1):
1118 | | -            rank_ops.append(None)
1119 | | -        return rank_ops
| 1147 | +    1. pp_group_size = 4, n_microbatches = 10: num_rounds = 2 and 10 % 2 == 0.
| 1148 | +    2. pp_group_size = 4, n_microbatches = 3: num_rounds = 1 and 3 % 1 == 0.
| 1149 | +    """
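
A quick editorial check of the rounds arithmetic described in the docstring (the `rounds` helper is illustrative, not library API):

```python
# Illustrative check of the docstring's rounds arithmetic.
def rounds(pp_group_size: int, n_microbatches: int):
    number_of_rounds = max(1, n_microbatches // pp_group_size)
    if n_microbatches % number_of_rounds != 0:
        raise ValueError("n_microbatches must be a multiple of number_of_rounds")
    return number_of_rounds, n_microbatches // number_of_rounds

print(rounds(4, 10))  # (2, 5): two rounds of five microbatches
print(rounds(4, 3))   # (1, 3): one round of three microbatches
```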
| 1150 | +
| 1151 | +    def __init__(
| 1152 | +        self,
| 1153 | +        stages: List[_PipelineStageBase],
| 1154 | +        n_microbatches: int,
| 1155 | +        loss_fn: Optional[Callable] = None,
| 1156 | +        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
| 1157 | +        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
| 1158 | +        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
| 1159 | +    ):
| 1160 | +        self.pp_group_size = stages[0].group_size
| 1161 | +        super().__init__(
| 1162 | +            stages=stages,
| 1163 | +            n_microbatches=n_microbatches,
| 1164 | +            loss_fn=loss_fn,
| 1165 | +            args_chunk_spec=args_chunk_spec,
| 1166 | +            kwargs_chunk_spec=kwargs_chunk_spec,
| 1167 | +            output_merge_spec=output_merge_spec,
| 1168 | +        )
| 1169 | +        self.n_local_stages = len(stages)
| 1170 | +        self.rank = stages[0].group_rank
| 1171 | +        self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
| 1172 | +        self.microbatches_per_round = n_microbatches // self.number_of_rounds
| 1173 | +        if n_microbatches % self.number_of_rounds != 0:
| 1174 | +            raise ValueError(
| 1175 | +                "Flexible Interleaved 1F1B requires the number of microbatches to be a "
| 1176 | +                f"multiple of the number of rounds ({self.number_of_rounds}), "
| 1177 | +                f"but got {n_microbatches}."
| 1178 | +            )
| 1179 | +        # 1. Create the pipeline_order (all ranks do this calculation)
| 1180 | +        # This will be used to keep track of the current state of the entire pipeline
| 1181 | +        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
| 1182 | +        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
| 1183 | +        for rank in range(self.pp_group_size):
| 1184 | +            rank_ops = self._calculate_single_rank_operations(rank)
| 1185 | +            self.pipeline_order[rank] = rank_ops
| 1186 | +
| 1187 | +    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
| 1188 | +        def get_rank_warmup_ops(rank):
| 1189 | +            # Warmup operations for the last stage
| 1190 | +            warmups_ops_last_stage = (
| 1191 | +                self.n_local_stages - 1
| 1192 | +            ) * self.microbatches_per_round
| 1193 | +            # Increment warmup operations by 2 for each hop away from the last stage
| 1194 | +            warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank)
| 1195 | +            # We cannot have more warmup ops than total forward ops (microbatches * local stages), so cap it there
| 1196 | +            return min(warmup_ops, self._n_microbatches * self.n_local_stages)
| 1197 | +
| 1198 | +        warmup_ops = get_rank_warmup_ops(rank)
| 1199 | +        microbatch_ops = self.n_local_stages * self._n_microbatches
| 1200 | +        # fwd_bwd_ops should encompass the remaining forwards
| 1201 | +        fwd_bwd_ops = microbatch_ops - warmup_ops
| 1202 | +        # cooldown_ops should encompass the remaining backwards
| 1203 | +        cooldown_ops = microbatch_ops - fwd_bwd_ops
| 1204 | +        # total ops encompass both forward and backward ops
| 1205 | +        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
| 1206 | +        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
| 1207 | +        logger.debug(
| 1208 | +            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
| 1209 | +            rank,
| 1210 | +            warmup_ops,
| 1211 | +            fwd_bwd_ops,
| 1212 | +            cooldown_ops,
| 1213 | +            total_ops,
| 1214 | +        )
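
To see that the accounting identity in the comment above holds, here is a worked editorial check with assumed values (rank 0, pp_group_size = 4, n_local_stages = 2, n_microbatches = 10, so microbatches_per_round = 5):

```python
# Editorial check of: warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2.
# Every forward op appears once (in warmup or 1F1B) and every backward op
# appears once (in 1F1B or cooldown), so both sides count 2 * microbatch_ops.
n_local_stages, n_microbatches, microbatches_per_round = 2, 10, 5
pp_group_size, rank = 4, 0

microbatch_ops = n_local_stages * n_microbatches              # 20
warmup_ops = min(
    (n_local_stages - 1) * microbatches_per_round
    + 2 * ((pp_group_size - 1) - rank),
    microbatch_ops,
)                                                             # 11
fwd_bwd_ops = microbatch_ops - warmup_ops                     # 9
cooldown_ops = microbatch_ops - fwd_bwd_ops                   # 11
assert warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2  # 40 == 40
```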
| 1215 | +
| 1216 | +        # Calculates the stage index based on step and pp_group_size
| 1217 | +
| 1218 | +        def forward_stage_index(step):
| 1219 | +            # Get the local index from 0 to n_local_stages-1
| 1220 | +            local_index = (step // self.microbatches_per_round) % self.n_local_stages
| 1221 | +            return (local_index * self.pp_group_size) + rank
| 1222 | +
| 1223 | +        def backward_stage_index(step):
| 1224 | +            local_index = (
| 1225 | +                self.n_local_stages
| 1226 | +                - 1
| 1227 | +                - ((step - warmup_ops) // self.microbatches_per_round)
| 1228 | +                % self.n_local_stages
| 1229 | +            )
| 1230 | +            return (local_index * self.pp_group_size) + rank
| 1231 | +
| 1232 | +        return _get_1f1b_rank_ops(
| 1233 | +            self.n_local_stages,
| 1234 | +            self.pp_group_size,
| 1235 | +            warmup_ops,
| 1236 | +            fwd_bwd_ops,
| 1237 | +            cooldown_ops,
| 1238 | +            rank,
| 1239 | +            forward_stage_index,
| 1240 | +            backward_stage_index,
| 1241 | +        )
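
As a final editorial sanity check, the round-robin stage mapping can be exercised in isolation. The sketch below is a standalone re-implementation of the two nested helpers with assumed values (same config as the accounting example above), not the class methods themselves:

```python
# Editorial sketch of the stage-index mapping with assumed values:
# n_local_stages = 2, pp_group_size = 4, microbatches_per_round = 5, rank = 0,
# warmup_ops = 11 (from the example above).
n_local_stages, pp_group_size, microbatches_per_round, rank = 2, 4, 5, 0
warmup_ops = 11

def forward_stage_index(step: int) -> int:
    local_index = (step // microbatches_per_round) % n_local_stages
    return (local_index * pp_group_size) + rank

def backward_stage_index(step: int) -> int:
    local_index = (
        n_local_stages
        - 1
        - ((step - warmup_ops) // microbatches_per_round) % n_local_stages
    )
    return (local_index * pp_group_size) + rank

print([forward_stage_index(s) for s in range(10)])
# [0, 0, 0, 0, 0, 4, 4, 4, 4, 4]: five forwards on local stage 0 (global 0),
# then five on local stage 1 (global 4)
print([backward_stage_index(s) for s in range(warmup_ops, warmup_ops + 10)])
# [4, 4, 4, 4, 4, 0, 0, 0, 0, 0]: backwards drain in reverse stage order
```

Forwards walk the local stages in ascending order, one round of microbatches at a time, while backwards walk them in descending order, which is what makes the shared `_get_1f1b_rank_ops` helper work for both schedule classes.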