38 changes: 27 additions & 11 deletions aiter/dist/device_communicators/custom_all_reduce.py
@@ -154,7 +154,10 @@ def __init__(
         self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
         # This is a pre-registered IPC buffer. In eager mode, input tensors
         # are first copied into this buffer before allreduce is performed
-        self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
+        self.input_buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
+        # This is a pre-registered IPC buffer for output. In eager mode, the kernel
+        # writes results to this buffer, then it is copied to the actual output
+        self.output_buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
         # This is a buffer for storing the tuples of pointers pointing to
         # IPC buffers from all ranks. Each registered tuple has size of
         # 8*world_size bytes where world_size is at most 8. Allocating 8MB
@@ -177,7 +180,9 @@ def __init__(
         self._ptr = ops.init_custom_ar(
             self.meta, self.rank_data, handles, offsets, rank, self.fully_connected
         )
-        self.register_buffer(self.buffer)
+        # Register both input and output buffers
+        self.register_input_buffer(self.input_buffer)
+        self.register_output_buffer(self.output_buffer)

     @contextmanager
     def capture(self):
@@ -236,9 +241,13 @@ def _gather_ipc_meta(self, shard_data):
             offsets.append(all_data[i][0][1])  # type: ignore
         return handles, offsets

-    def register_buffer(self, inp: torch.Tensor):
+    def register_input_buffer(self, inp: torch.Tensor):
         handles, offsets = self._get_ipc_meta(inp)
-        ops.register_buffer(self._ptr, inp, handles, offsets)
+        ops.register_input_buffer(self._ptr, inp, handles, offsets)
+
+    def register_output_buffer(self, out: torch.Tensor):
+        handles, offsets = self._get_ipc_meta(out)
+        ops.register_output_buffer(self._ptr, out, handles, offsets)

     def register_graph_buffers(self):
         handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
Expand Down Expand Up @@ -268,7 +277,8 @@ def all_reduce(
out: Optional[torch.Tensor] = None,
use_new: bool = True,
open_fp8_quant: bool = False,
registered: bool = False,
registered_input: bool = False,
registered_output: bool = False,
):
"""Performs an out-of-place all reduce.

@@ -284,7 +294,8 @@
             out,
             use_new,
             open_fp8_quant,
-            None if registered else self.buffer,
+            None if registered_input else self.input_buffer,
+            None if registered_output else self.output_buffer,
         )
         return out

@@ -300,7 +311,8 @@ def custom_all_reduce(
                 input,
                 use_new=use_new,
                 open_fp8_quant=open_fp8_quant,
-                registered=True,
+                registered_input=True,
+                registered_output=True
             )
         else:
             # if warm up, mimic the allocation pattern
@@ -312,7 +324,11 @@
             # be small(<=1% of overall latency) compared to the performance
             # gains of using custom kernels
             return self.all_reduce(
-                input, use_new=use_new, open_fp8_quant=open_fp8_quant, registered=False
+                input,
+                use_new=use_new,
+                open_fp8_quant=open_fp8_quant,
+                registered_input=False,
+                registered_output=False
             )

     def reduce_scatter(
@@ -326,7 +342,7 @@ def reduce_scatter(
             self._ptr,
             inp,
             out,
-            None if registered else self.buffer,
+            None if registered else self.input_buffer,
         )

     def custom_reduce_scatter(
@@ -354,7 +370,7 @@ def all_gather_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
            out = torch.empty(
                inp.numel() * self.world_size, dtype=inp.dtype, device=inp.device
            )
-        ops.all_gather_unreg(self._ptr, inp, self.buffer, out)
+        ops.all_gather_unreg(self._ptr, inp, self.input_buffer, out)
         return out

     def custom_all_gather(self, inp: torch.Tensor) -> Optional[torch.Tensor]:
@@ -390,7 +406,7 @@ def fused_ar_rms(
             out,
             w,
             eps,
-            None if registered else self.buffer,
+            None if registered else self.input_buffer,
         )
         return out, res_out

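Below is a minimal usage sketch (not part of the diff) of the split `registered_input`/`registered_output` flags added to `all_reduce`: under CUDA graph capture the caller's tensors are already IPC-registered, so both flags are True and no staging copies occur; in eager mode both are False and the pre-registered `input_buffer`/`output_buffer` are used. The wrapper class name and its construction are assumptions here.

# Minimal sketch; assumes `ca` is an already-constructed custom all-reduce
# wrapper instance for the current rank (construction not shown in the diff).
import torch

def run_allreduce(ca, inp: torch.Tensor, capturing: bool) -> torch.Tensor:
    if capturing:
        # Graph capture: tensors are IPC-registered, so both staging copies are skipped.
        return ca.all_reduce(inp, registered_input=True, registered_output=True)
    # Eager mode: inp is staged into ca.input_buffer, the kernel writes into
    # ca.output_buffer, and the result is copied back into the returned tensor.
    return ca.all_reduce(inp, registered_input=False, registered_output=False)

In practice `custom_all_reduce` already performs this dispatch based on graph-capture state, as shown in the hunks above.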
11 changes: 9 additions & 2 deletions aiter/ops/custom_all_reduce.py
@@ -28,7 +28,8 @@ def all_reduce(
     out: torch.Tensor,
     use_new: bool,
     open_fp8_quant: bool,
-    reg_buffer: Optional[torch.Tensor] = None,
+    reg_input_buffer: Optional[torch.Tensor] = None,
+    reg_output_buffer: Optional[torch.Tensor] = None,
 ) -> None: ...


@@ -179,7 +180,13 @@ def meta_size() -> int: ...


 @compile_ops("module_custom_all_reduce")
-def register_buffer(
+def register_input_buffer(
     _fa: int, t: torch.Tensor, handles: List[torch.Tensor], offsets: List[int]
 ) -> None: ...
+
+
+@compile_ops("module_custom_all_reduce")
+def register_output_buffer(
+    _fa: int, t: torch.Tensor, handles: List[torch.Tensor], offsets: List[int]
+) -> None: ...

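To make the new `reg_input_buffer`/`reg_output_buffer` parameters concrete, here is a purely conceptual sketch of the eager (unregistered) data path described by the buffer comments in the first file. The byte-level copies below are illustrative assumptions; in the PR they happen inside the compiled op, not in Python.

import torch

def eager_allreduce_flow(ca, inp: torch.Tensor) -> torch.Tensor:
    # Conceptual staging only; ca.input_buffer / ca.output_buffer are the
    # pre-registered uint8 IPC buffers created in __init__.
    nbytes = inp.numel() * inp.element_size()
    # 1. Stage the input into the registered input buffer.
    ca.input_buffer[:nbytes].copy_(inp.contiguous().view(torch.uint8).reshape(-1))
    # 2. The custom all-reduce kernel runs over the registered IPC buffers and
    #    writes its result into ca.output_buffer (not shown here).
    # 3. Copy the result out of the registered output buffer.
    out = torch.empty_like(inp)
    out.view(torch.uint8).reshape(-1).copy_(ca.output_buffer[:nbytes])
    return out

With the split, either staging copy can be skipped independently, since `None` is passed to the op for whichever buffer is already registered; that is what the separate `registered_input`/`registered_output` flags express.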