From 5007c5944bdfb7368907d9b84998d13856ffaeda Mon Sep 17 00:00:00 2001
From: Sanshan Gao
Date: Mon, 31 Mar 2025 16:39:17 -0700
Subject: [PATCH] Sanshang/fix all2all (#202)

Summary:
Fixed replay support for all2all. Depends on PyTorch https://github.com/pytorch/pytorch/pull/149485.

Test Plan:
Constructed a 4-rank case that invokes torch.distributed.all_to_all() and torch.distributed.all_to_all_single(), then dumped the trace and replayed it.

Differential Revision: D72196033

Pulled By: shengfukevin
---
 .../comm/backend/pytorch_dist_backend.py  | 24 ++------
 et_replay/comm/commsTraceParser.py        | 23 ++++----
 et_replay/comm/comms_utils.py             | 56 +++++++------------
 et_replay/comm/profiler_trace_analysis.py | 15 +++--
 et_replay/tools/comm_replay.py            |  2 +-
 5 files changed, 48 insertions(+), 72 deletions(-)

diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py
index 6847005a..a8259daa 100644
--- a/et_replay/comm/backend/pytorch_dist_backend.py
+++ b/et_replay/comm/backend/pytorch_dist_backend.py
@@ -227,24 +227,12 @@ def all_to_all(
                 "Not using batched embedding tables because extend distributed package not in use"
             )
 
-        if isinstance(collectiveArgs.opTensor, list):
-            work = dist.all_to_all(
-                collectiveArgs.opTensor,
-                collectiveArgs.ipTensor,
-                group=self.get_collective_group(collectiveArgs),
-                async_op=collectiveArgs.asyncOp,
-            )
-        else:
-            work = dist.all_to_all_single(
-                collectiveArgs.opTensor
-                if not pair
-                else collectiveArgs.opTensor_pair,
-                collectiveArgs.ipTensor
-                if not pair
-                else collectiveArgs.ipTensor_pair,
-                group=self.get_collective_group(collectiveArgs),
-                async_op=collectiveArgs.asyncOp,
-            )
+        work = dist.all_to_all(
+            collectiveArgs.opTensor,
+            collectiveArgs.ipTensor,
+            group=self.get_collective_group(collectiveArgs),
+            async_op=collectiveArgs.asyncOp,
+        )
 
         if collectiveArgs.asyncOp:
             collectiveArgs.waitObj.append(work)
diff --git a/et_replay/comm/commsTraceParser.py b/et_replay/comm/commsTraceParser.py
index e1f65e24..d0cba6f2 100644
--- a/et_replay/comm/commsTraceParser.py
+++ b/et_replay/comm/commsTraceParser.py
@@ -5,6 +5,8 @@
 
 import logging
 
+import math
+
 from et_replay import ExecutionTrace
 from et_replay.comm import comms_utils
 from et_replay.comm.backend.base_backend import supportedP2pOps
@@ -189,22 +191,17 @@ def _parse_comms_op_node(  # noqa: C901
                 comm_args.root = comm_args.groupRanks[recorded_rank]
                 comm_args.groupRanks = comm_args.groupRanks
 
-            if comm_args.comms == "all_to_allv":
+            if comm_args.comms == "all_to_all":
+                # flatten each tensor and store the # of elements into split field
+                comm_args.inSplit = [math.prod(i) for i in node.input_shapes[0]]
+                comm_args.outSplit = [math.prod(i) for i in node.output_shapes[0]]
+            elif comm_args.comms == "all_to_allv":
                 if not comm_args.worldSize:
                     # if no pg info provided, use total ranks as world size
                     comm_args.worldSize = total_ranks
-                comm_args.inSplit = (
-                    json.loads(node.commArgs.in_split_size)
-                    if json.loads(node.commArgs.in_split_size)
-                    else [int(comm_args.inMsgSize / comm_args.worldSize)]
-                    * comm_args.worldSize
-                )
-                comm_args.outSplit = (
-                    json.loads(node.commArgs.out_split_size)
-                    if json.loads(node.commArgs.out_split_size)
-                    else [int(comm_args.outMsgSize / comm_args.worldSize)]
-                    * comm_args.worldSize
-                )
+                comm_args.inSplit = json.loads(node.commArgs.in_split_size)
+                comm_args.outSplit = json.loads(node.commArgs.out_split_size)
+
             comms_op_list.append(comm_args)
 
     return comms_op_list
diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py
index 55653f6b..27cc00d1 100644
--- a/et_replay/comm/comms_utils.py
+++ b/et_replay/comm/comms_utils.py
@@ -876,16 +876,17 @@ def _prep_all_to_allv(
         opTensor = self.backendFuncs.alloc_random(
             [numElementsOut], curDevice, dtype, scaleFactor
         )
-        # all_to_allv requires tensors to specify split
+        # recorded splits in trace is only for dim 0, but tensor in replay has been flattened.
+        # need to recalculate the splits for flattened 1D tensor
        self.collectiveArgs.opTensor_split = (
-            curComm.outSplit
-            if (curComm.outSplit is not None)
-            else [(numElementsOut // world_size) for _ in range(world_size)]
+            [numElementsOut // sum(curComm.outSplit) * i for i in curComm.outSplit]
+            if curComm.outSplit
+            else None
         )
         self.collectiveArgs.ipTensor_split = (
-            curComm.inSplit
-            if (curComm.inSplit is not None)
-            else [(numElementsIn // world_size) for _ in range(world_size)]
+            [numElementsIn // sum(curComm.inSplit) * i for i in curComm.inSplit]
+            if curComm.inSplit
+            else None
         )

         return (ipTensor, opTensor)
@@ -937,37 +938,22 @@ def _prep_all_to_all(
         scaleFactor: float,
         allocate: bool = True,
     ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:

-        # all_to_all requires two tensor lists, e.g., List[torch.Tensor]
-        ipTensor = []
         opTensor = []
         if allocate:
-            if commsParams.dcheck == 1:
-                for _ in range(world_size):
-                    ipTensor.append(
-                        self.backendFuncs.alloc_ones(
-                            [(numElementsIn // world_size)],
-                            curDevice,
-                            commsParams.dtype,
-                            self.initVal,
-                        )
-                    )
-            else:
-                for _ in range(world_size):
-                    ipTensor.append(
-                        self.backendFuncs.alloc_random(
-                            [(numElementsIn // world_size)],
-                            curDevice,
-                            commsParams.dtype,
-                            scaleFactor,
-                        )
-                    )
-            for _ in range(world_size):
-                opTensor.append(
-                    self.backendFuncs.alloc_random(
-                        [(numElementsOut // world_size)], curDevice, dtype, scaleFactor
-                    )
-                )
+            alloc_func = (
+                self.backendFuncs.alloc_ones
+                if commsParams.dcheck == 1
+                else self.backendFuncs.alloc_random
+            )
+            ipTensor = [
+                alloc_func(i, curDevice, commsParams.dtype, self.initVal)
+                for i in curComm.inSplit
+            ]
+            opTensor = [
+                alloc_func(i, curDevice, commsParams.dtype, self.initVal)
+                for i in curComm.outSplit
+            ]
         return (ipTensor, opTensor)

     def _prep_all_gather(
diff --git a/et_replay/comm/profiler_trace_analysis.py b/et_replay/comm/profiler_trace_analysis.py
index fcf94813..dd5170d2 100644
--- a/et_replay/comm/profiler_trace_analysis.py
+++ b/et_replay/comm/profiler_trace_analysis.py
@@ -241,6 +241,7 @@ def pick_comm_bw_(trace_data, comm_bw_data):
     ]
     for evt in nccl_events:
         knl_name = evt["name"][: evt["name"].index("(")]
+        coll_name = evt["args"]["Collective name"]
         data_size = _calculate_event_data_size(evt)

         ranks_count = evt["args"]["Group size"]
@@ -248,7 +249,9 @@
         pg_id = int(evt["args"]["Process Group Name"])
         pg = (*ranks, pg_id) if ranks and rank == min(ranks) else None

-        comm_bw_data[(knl_name, data_size, ranks_count)].append(
+        # TODO: calculation of unbalanced all2all bw needs to be improved
+        # all2all is implemented by single ncclDevKernel_SendRecv() in NCCL
+        comm_bw_data[(knl_name, coll_name, data_size, ranks_count)].append(
             [
                 evt["dur"],
                 evt["args"]["algbw (GB/sec)"],
@@ -331,25 +334,27 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
         )

         f.write(
-            f'\n{" ":>70s}|{" ":>5s}|{"AVG.":^19s}|{"p01":^8s}|{"p50":^8s}|{"p90":^8s}|{"p99":^8s}|\n'
+            f'\n{" ":>86s}|{" ":>5s}|{"AVG.":^19s}|{"p01":^8s}|{"p50":^8s}|{"p90":^8s}|{"p99":^8s}|\n'
         )
         f.write(
-            f'{"kernel":>50s} {"size":>12s} {"#rks":>6s}|{"#pgs":>5s}|{" dur":>10s} '
+            f'{"kernel":>50s} {"coll":>15s} {"size":>12s} {"#rks":>6s}|{"#pgs":>5s}|{" dur":>10s} '
{"#rks":>6s}|{"#pgs":>5s}|{" dur":>10s} ' ) for _ in range(5): # average, p01, p50, p90, p99 f.write(f'{" busbw":>8s}|') f.write("\n") f.write( - f'{" ":>50s} {" (B)":>12s} {" ":>6s}|{" ":>5s}|{" (ms)":>10s} ' + f'{" ":>66s} {" (B)":>12s} {" ":>6s}|{" ":>5s}|{" (ms)":>10s} ' ) for _ in range(5): # average, p50, p90, p99 f.write(f'{"(GB/s)":>8s}|') f.write("\n") for k, v in comm_bw_summary.items(): - f.write(f"{k[0]:>50s} {k[1]:>12d} {k[2]:>6d}|{v[0]:>5d}|{v[1]/1e3:>10.3f} ") + f.write( + f"{k[0]:>50s} {k[1]:>15s} {k[2]:>12d} {k[3]:>6d}|{v[0]:>5d}|{v[1]/1e3:>10.3f} " + ) for i in range(2, len(v)): f.write(f"{v[i]:>8.2f}|") f.write("\n") diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py index 76ea2974..30d0e78e 100644 --- a/et_replay/tools/comm_replay.py +++ b/et_replay/tools/comm_replay.py @@ -1099,7 +1099,7 @@ def replaySingle( if groupRank >= 0: commDesc = f"{str(curComm.comms)}: NumElemsIn={curComm.inMsgSize}, NumElemsOut={curComm.outMsgSize}, Dtype={curComm.dtype}" - if curComm.comms == "all_to_allv": + if curComm.comms in ("all_to_all", "all_to_allv"): commDesc += ( f", InSplit={curComm.inSplit}, OutSplit={curComm.outSplit}" )