Skip to content

Commit c4a3ff2

Browse files
TennyWang1223 and root authored
[fix]: support dim(-1) allgather (#2162)
* [fix]: support dim(-1) allgather

Signed-off-by: root <root@hjbog-srdc-24.amd.com>

* [fix] test script format

Signed-off-by: root <root@hjbog-srdc-24.amd.com>

---------

Signed-off-by: root <root@hjbog-srdc-24.amd.com>
Co-authored-by: root <root@hjbog-srdc-24.amd.com>
1 parent 87006ab commit c4a3ff2

File tree

8 files changed

+239
-88
lines changed

8 files changed

+239
-88
lines changed

aiter/dist/device_communicators/custom_all_reduce.py

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ def __init__(
5454
self,
5555
group: ProcessGroup,
5656
device: Union[int, str, torch.device],
57-
max_size=8192 * 1024 * 8 * 2, # In allreduce 2stage writemode, use 2x tmp buffer
57+
max_size=8192
58+
* 1024
59+
* 8
60+
* 2, # In allreduce 2stage writemode, use 2x tmp buffer
5861
enable_register_for_capturing: bool = True,
5962
) -> None:
6063
"""
@@ -160,7 +163,9 @@ def __init__(
160163
self.input_buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
161164
# This is a pre-registered IPC buffer for output. In eager mode, kernel
162165
# writes results to this buffer, then it's copied to the actual output
163-
self.output_buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
166+
self.output_buffer = torch.empty(
167+
max_size, dtype=torch.uint8, device=self.device
168+
)
164169
# This is a buffer for storing the tuples of pointers pointing to
165170
# IPC buffers from all ranks. Each registered tuple has size of
166171
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
@@ -247,7 +252,7 @@ def _gather_ipc_meta(self, shard_data):
247252
def register_input_buffer(self, inp: torch.Tensor):
248253
handles, offsets = self._get_ipc_meta(inp)
249254
ops.register_input_buffer(self._ptr, inp, handles, offsets)
250-
255+
251256
def register_output_buffer(self, out: torch.Tensor):
252257
handles, offsets = self._get_ipc_meta(out)
253258
ops.register_output_buffer(self._ptr, out, handles, offsets)
@@ -316,7 +321,7 @@ def custom_all_reduce(
316321
use_new=use_new,
317322
open_fp8_quant=open_fp8_quant,
318323
registered_input=self.enable_register_for_capturing,
319-
registered_output=self.enable_register_for_capturing
324+
registered_output=self.enable_register_for_capturing,
320325
)
321326
else:
322327
# if warm up, mimic the allocation pattern
@@ -332,7 +337,7 @@ def custom_all_reduce(
332337
use_new=use_new,
333338
open_fp8_quant=open_fp8_quant,
334339
registered_input=False,
335-
registered_output=False
340+
registered_output=False,
336341
)
337342

338343
def reduce_scatter(
@@ -361,31 +366,50 @@ def custom_reduce_scatter(
361366
else:
362367
return self.reduce_scatter(input, output, registered=False)
363368

364-
def all_gather_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
369+
def _allgather_out_shape(self, inp: torch.Tensor, dim: int):
370+
ndim = inp.dim()
371+
if dim == 0:
372+
return (inp.shape[0] * self.world_size,) + inp.shape[1:]
373+
if dim == -1 or dim == ndim - 1:
374+
return inp.shape[:-1] + (inp.shape[-1] * self.world_size,)
375+
print(
376+
f"[aiter] allgather does not support dim={dim}, falling back to 1-D output"
377+
)
378+
return (inp.numel() * self.world_size,)
379+
380+
def all_gather_reg(self, inp: torch.Tensor, out: torch.Tensor = None, dim: int = 0):
365381
if out is None:
366382
out = torch.empty(
367-
inp.numel() * self.world_size, dtype=inp.dtype, device=inp.device
383+
self._allgather_out_shape(inp, dim),
384+
dtype=inp.dtype,
385+
device=inp.device,
368386
)
369-
ops.all_gather_reg(self._ptr, inp, out)
387+
ops.all_gather_reg(self._ptr, inp, out, inp.shape[-1], dim)
370388
return out
371389

372-
def all_gather_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
390+
def all_gather_unreg(
391+
self, inp: torch.Tensor, out: torch.Tensor = None, dim: int = 0
392+
):
373393
if out is None:
374394
out = torch.empty(
375-
inp.numel() * self.world_size, dtype=inp.dtype, device=inp.device
395+
self._allgather_out_shape(inp, dim),
396+
dtype=inp.dtype,
397+
device=inp.device,
376398
)
377-
ops.all_gather_unreg(self._ptr, inp, self.input_buffer, out)
399+
ops.all_gather_unreg(self._ptr, inp, self.input_buffer, out, inp.shape[-1], dim)
378400
return out
379401

380-
def custom_all_gather(self, inp: torch.Tensor) -> Optional[torch.Tensor]:
402+
def custom_all_gather(
403+
self, inp: torch.Tensor, dim: int = 0
404+
) -> Optional[torch.Tensor]:
381405
if self._IS_CAPTURING:
382406
if torch.cuda.is_current_stream_capturing():
383-
return self.all_gather_reg(inp)
407+
return self.all_gather_reg(inp, dim=dim)
384408
else:
385409
print("allgather capture hipgraph error")
386410
return torch.zeros_like(inp)
387411
else:
388-
return self.all_gather_unreg(inp)
412+
return self.all_gather_unreg(inp, dim=dim)
389413

390414
def fused_ar_rms(
391415
self,
@@ -422,7 +446,9 @@ def fused_ar_rms(
422446
if out is None:
423447
out = torch.empty(inp.shape, dtype=fp8, device=inp.device)
424448
if scale_out is None:
425-
scale_out = torch.empty(inp.shape[:-1] + (1,), dtype=torch.float32, device=inp.device)
449+
scale_out = torch.empty(
450+
inp.shape[:-1] + (1,), dtype=torch.float32, device=inp.device
451+
)
426452
ops.fused_allreduce_rmsnorm_quant(
427453
self._ptr,
428454
inp,
@@ -451,15 +477,25 @@ def custom_fused_ar_rms(
451477
if self._IS_CAPTURING:
452478
if torch.cuda.is_current_stream_capturing():
453479
return self.fused_ar_rms(
454-
input, residual_inp, w=weight, eps=eps, registered=True, use_1stage=use_1stage,
480+
input,
481+
residual_inp,
482+
w=weight,
483+
eps=eps,
484+
registered=True,
485+
use_1stage=use_1stage,
455486
)
456487
else:
457488
return torch.zeros_like(input), torch.zeros_like(input)
458489
else:
459490
return self.fused_ar_rms(
460-
input, residual_inp, w=weight, eps=eps, registered=False, use_1stage=use_1stage,
491+
input,
492+
residual_inp,
493+
w=weight,
494+
eps=eps,
495+
registered=False,
496+
use_1stage=use_1stage,
461497
)
462-
498+
463499
def custom_fused_ar_rms_quant(
464500
self,
465501
input: torch.Tensor,
@@ -474,15 +510,29 @@ def custom_fused_ar_rms_quant(
474510
if self._IS_CAPTURING:
475511
if torch.cuda.is_current_stream_capturing():
476512
return self.fused_ar_rms(
477-
input, residual_inp, w=weight, eps=eps, registered=True, use_1stage=use_1stage, post_per_token_quant=True,
513+
input,
514+
residual_inp,
515+
w=weight,
516+
eps=eps,
517+
registered=True,
518+
use_1stage=use_1stage,
519+
post_per_token_quant=True,
478520
)
479521
else:
480522
dummy_out = torch.zeros(input.shape, dtype=fp8, device=input.device)
481-
dummy_scale_out = torch.zeros(input.shape[:-1] + (1,), dtype=torch.float32, device=input.device)
523+
dummy_scale_out = torch.zeros(
524+
input.shape[:-1] + (1,), dtype=torch.float32, device=input.device
525+
)
482526
return dummy_out, torch.zeros_like(input), dummy_scale_out
483527
else:
484528
return self.fused_ar_rms(
485-
input, residual_inp, w=weight, eps=eps, registered=False, use_1stage=use_1stage, post_per_token_quant=True,
529+
input,
530+
residual_inp,
531+
w=weight,
532+
eps=eps,
533+
registered=False,
534+
use_1stage=use_1stage,
535+
post_per_token_quant=True,
486536
)
487537

488538
def close(self):

aiter/dist/parallel_state.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,14 @@ def fused_allreduce_rmsnorm_quant_(
175175
if supports_custom_op():
176176

177177
# @torch.library.custom_op("aiter::outplace_all_gather", mutates_args=[])
178-
def outplace_all_gather(input: torch.Tensor, group_name: str) -> torch.Tensor:
178+
def outplace_all_gather(
179+
input: torch.Tensor, group_name: str, dim: int = 0
180+
) -> torch.Tensor:
179181
assert group_name in _groups, f"Group {group_name} is not found."
180182
group = _groups[group_name]()
181183
if group is None:
182184
raise ValueError(f"Group {group_name} is destroyed.")
183-
return group._all_gather_out_place(input)
185+
return group._all_gather_out_place(input, dim)
184186

185187
def outplace_reduce_scatter(
186188
input: torch.Tensor, output: torch.Tensor, group_name: str, dim: int
@@ -442,11 +444,11 @@ def _fused_allreduce_rmsnorm_quant_out_place(
442444
input_, residual_inp_, weight_, eps
443445
)
444446

445-
def _all_gather_out_place(self, input_: torch.Tensor) -> torch.Tensor:
447+
def _all_gather_out_place(self, input_: torch.Tensor, dim: int = 0) -> torch.Tensor:
446448
ca_comm = self.device_communicator.ca_comm
447449
assert ca_comm is not None
448450
assert not ca_comm.disabled
449-
out = ca_comm.custom_all_gather(input_)
451+
out = ca_comm.custom_all_gather(input_, dim)
450452
assert out is not None
451453
return out
452454

@@ -491,30 +493,32 @@ def all_gather(
491493
self, input_: torch.Tensor, use_custom: bool = False, dim: int = -1
492494
) -> torch.Tensor:
493495
world_size = self.world_size
494-
# Bypass the function if we are using only 1 GPU.
495496
if world_size == 1:
496497
return input_
497498
assert (
498499
-input_.dim() <= dim < input_.dim()
499500
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
500501

501502
if dim < 0:
502-
# Convert negative dim to positive.
503503
dim += input_.dim()
504504
input_size = input_.size()
505-
if use_custom:
506-
output_tensor = outplace_all_gather(input_, group_name=self.unique_name)
507-
output_tensor = output_tensor.reshape((world_size,) + input_size)
508-
else:
509-
# Allocate output tensor.
510-
output_tensor = torch.empty(
511-
(world_size,) + input_size, dtype=input_.dtype, device=input_.device
512-
)
513-
# All-gather.
514-
torch.distributed.all_gather_into_tensor(
515-
output_tensor, input_, group=self.device_group
516-
)
517-
# Reshape
505+
506+
is_last_dim = dim == input_.dim() - 1
507+
can_use_custom = use_custom and (
508+
dim == 0
509+
or (is_last_dim and input_size[-1] * input_.element_size() % 16 == 0)
510+
)
511+
512+
if can_use_custom:
513+
return outplace_all_gather(input_, group_name=self.unique_name, dim=dim)
514+
515+
# NCCL path
516+
output_tensor = torch.empty(
517+
(world_size,) + input_size, dtype=input_.dtype, device=input_.device
518+
)
519+
torch.distributed.all_gather_into_tensor(
520+
output_tensor, input_, group=self.device_group
521+
)
518522
output_tensor = output_tensor.movedim(0, dim)
519523
output_tensor = output_tensor.reshape(
520524
input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]

aiter/ops/custom_all_reduce.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,19 @@ def reduce_scatter(
4343

4444

4545
@compile_ops("module_custom_all_reduce")
46-
def all_gather_reg(_fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: ...
46+
def all_gather_reg(
47+
_fa: int, inp: torch.Tensor, out: torch.Tensor, last_dim_size: int, dim: int
48+
) -> None: ...
4749

4850

4951
@compile_ops("module_custom_all_reduce")
5052
def all_gather_unreg(
51-
_fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
53+
_fa: int,
54+
inp: torch.Tensor,
55+
reg_buffer: torch.Tensor,
56+
out: torch.Tensor,
57+
last_dim_size: int,
58+
dim: int,
5259
) -> None: ...
5360

5461

0 commit comments

Comments (0)