@@ -54,7 +54,10 @@ def __init__(
5454 self ,
5555 group : ProcessGroup ,
5656 device : Union [int , str , torch .device ],
57- max_size = 8192 * 1024 * 8 * 2 , # In allreduce 2stage writemode, use 2x tmp buffer
57+ max_size = 8192
58+ * 1024
59+ * 8
60+ * 2 , # In allreduce 2stage writemode, use 2x tmp buffer
5861 enable_register_for_capturing : bool = True ,
5962 ) -> None :
6063 """
@@ -160,7 +163,9 @@ def __init__(
160163 self .input_buffer = torch .empty (max_size , dtype = torch .uint8 , device = self .device )
161164 # This is a pre-registered IPC buffer for output. In eager mode, kernel
162165 # writes results to this buffer, then it's copied to the actual output
163- self .output_buffer = torch .empty (max_size , dtype = torch .uint8 , device = self .device )
166+ self .output_buffer = torch .empty (
167+ max_size , dtype = torch .uint8 , device = self .device
168+ )
164169 # This is a buffer for storing the tuples of pointers pointing to
165170 # IPC buffers from all ranks. Each registered tuple has size of
166171 # 8*world_size bytes where world_size is at most 8. Allocating 8MB
@@ -247,7 +252,7 @@ def _gather_ipc_meta(self, shard_data):
def register_input_buffer(self, inp: torch.Tensor):
    """Register *inp* as a pre-shared IPC input buffer on every rank."""
    # Exchange this tensor's IPC handles/offsets across ranks, then hand
    # them to the native registration op with the communicator pointer.
    ipc_meta = self._get_ipc_meta(inp)
    ops.register_input_buffer(self._ptr, inp, *ipc_meta)
def register_output_buffer(self, out: torch.Tensor):
    """Register *out* as a pre-shared IPC output buffer on every rank."""
    # Same exchange as for input buffers: gather handles/offsets from all
    # ranks and register them with the native communicator.
    ipc_meta = self._get_ipc_meta(out)
    ops.register_output_buffer(self._ptr, out, *ipc_meta)
@@ -316,7 +321,7 @@ def custom_all_reduce(
316321 use_new = use_new ,
317322 open_fp8_quant = open_fp8_quant ,
318323 registered_input = self .enable_register_for_capturing ,
319- registered_output = self .enable_register_for_capturing
324+ registered_output = self .enable_register_for_capturing ,
320325 )
321326 else :
322327 # if warm up, mimic the allocation pattern
@@ -332,7 +337,7 @@ def custom_all_reduce(
332337 use_new = use_new ,
333338 open_fp8_quant = open_fp8_quant ,
334339 registered_input = False ,
335- registered_output = False
340+ registered_output = False ,
336341 )
337342
338343 def reduce_scatter (
@@ -361,31 +366,50 @@ def custom_reduce_scatter(
361366 else :
362367 return self .reduce_scatter (input , output , registered = False )
363368
364- def all_gather_reg (self , inp : torch .Tensor , out : torch .Tensor = None ):
369+ def _allgather_out_shape (self , inp : torch .Tensor , dim : int ):
370+ ndim = inp .dim ()
371+ if dim == 0 :
372+ return (inp .shape [0 ] * self .world_size ,) + inp .shape [1 :]
373+ if dim == - 1 or dim == ndim - 1 :
374+ return inp .shape [:- 1 ] + (inp .shape [- 1 ] * self .world_size ,)
375+ print (
376+ f"[aiter] allgather does not support dim={ dim } , falling back to 1-D output"
377+ )
378+ return (inp .numel () * self .world_size ,)
379+
def all_gather_reg(self, inp: torch.Tensor, out: torch.Tensor = None, dim: int = 0):
    """All-gather *inp* along *dim* using pre-registered IPC buffers.

    Allocates *out* with the gathered shape when the caller does not
    supply one, and returns the gathered tensor.
    """
    if out is None:
        gathered_shape = self._allgather_out_shape(inp, dim)
        out = inp.new_empty(gathered_shape)
    # The last-dim extent and the gather dim are forwarded to the native
    # kernel — presumably to derive per-rank strides; confirm against the
    # ops.all_gather_reg signature.
    ops.all_gather_reg(self._ptr, inp, out, inp.shape[-1], dim)
    return out
def all_gather_unreg(self, inp: torch.Tensor, out: torch.Tensor = None, dim: int = 0):
    """All-gather an unregistered *inp* along *dim*.

    The pre-registered ``self.input_buffer`` is handed to the native op
    as staging space, so *inp* itself does not need to be registered.
    Allocates *out* when not supplied and returns the gathered tensor.
    """
    if out is None:
        gathered_shape = self._allgather_out_shape(inp, dim)
        out = inp.new_empty(gathered_shape)
    ops.all_gather_unreg(self._ptr, inp, self.input_buffer, out, inp.shape[-1], dim)
    return out
379401
def custom_all_gather(self, inp: torch.Tensor, dim: int = 0) -> Optional[torch.Tensor]:
    """Dispatch an all-gather to the registered or unregistered path.

    Inside HIP-graph capture the registered-buffer kernel is used;
    outside capture the staging-buffer path is used. Returns the
    gathered tensor (world_size x larger along *dim*).
    """
    if self._IS_CAPTURING:
        if torch.cuda.is_current_stream_capturing():
            return self.all_gather_reg(inp, dim=dim)
        # Capture flag set but the stream is not actually capturing
        # (warm-up / error path). Mimic the real allocation pattern:
        # the dummy must have the *gathered* shape — the previous
        # zeros_like(inp) was world_size x too small along dim and
        # broke downstream shape expectations.
        print("allgather capture hipgraph error")
        return torch.zeros(
            self._allgather_out_shape(inp, dim), dtype=inp.dtype, device=inp.device
        )
    return self.all_gather_unreg(inp, dim=dim)
389413
390414 def fused_ar_rms (
391415 self ,
@@ -422,7 +446,9 @@ def fused_ar_rms(
422446 if out is None :
423447 out = torch .empty (inp .shape , dtype = fp8 , device = inp .device )
424448 if scale_out is None :
425- scale_out = torch .empty (inp .shape [:- 1 ] + (1 ,), dtype = torch .float32 , device = inp .device )
449+ scale_out = torch .empty (
450+ inp .shape [:- 1 ] + (1 ,), dtype = torch .float32 , device = inp .device
451+ )
426452 ops .fused_allreduce_rmsnorm_quant (
427453 self ._ptr ,
428454 inp ,
@@ -451,15 +477,25 @@ def custom_fused_ar_rms(
451477 if self ._IS_CAPTURING :
452478 if torch .cuda .is_current_stream_capturing ():
453479 return self .fused_ar_rms (
454- input , residual_inp , w = weight , eps = eps , registered = True , use_1stage = use_1stage ,
480+ input ,
481+ residual_inp ,
482+ w = weight ,
483+ eps = eps ,
484+ registered = True ,
485+ use_1stage = use_1stage ,
455486 )
456487 else :
457488 return torch .zeros_like (input ), torch .zeros_like (input )
458489 else :
459490 return self .fused_ar_rms (
460- input , residual_inp , w = weight , eps = eps , registered = False , use_1stage = use_1stage ,
491+ input ,
492+ residual_inp ,
493+ w = weight ,
494+ eps = eps ,
495+ registered = False ,
496+ use_1stage = use_1stage ,
461497 )
462-
498+
463499 def custom_fused_ar_rms_quant (
464500 self ,
465501 input : torch .Tensor ,
@@ -474,15 +510,29 @@ def custom_fused_ar_rms_quant(
474510 if self ._IS_CAPTURING :
475511 if torch .cuda .is_current_stream_capturing ():
476512 return self .fused_ar_rms (
477- input , residual_inp , w = weight , eps = eps , registered = True , use_1stage = use_1stage , post_per_token_quant = True ,
513+ input ,
514+ residual_inp ,
515+ w = weight ,
516+ eps = eps ,
517+ registered = True ,
518+ use_1stage = use_1stage ,
519+ post_per_token_quant = True ,
478520 )
479521 else :
480522 dummy_out = torch .zeros (input .shape , dtype = fp8 , device = input .device )
481- dummy_scale_out = torch .zeros (input .shape [:- 1 ] + (1 ,), dtype = torch .float32 , device = input .device )
523+ dummy_scale_out = torch .zeros (
524+ input .shape [:- 1 ] + (1 ,), dtype = torch .float32 , device = input .device
525+ )
482526 return dummy_out , torch .zeros_like (input ), dummy_scale_out
483527 else :
484528 return self .fused_ar_rms (
485- input , residual_inp , w = weight , eps = eps , registered = False , use_1stage = use_1stage , post_per_token_quant = True ,
529+ input ,
530+ residual_inp ,
531+ w = weight ,
532+ eps = eps ,
533+ registered = False ,
534+ use_1stage = use_1stage ,
535+ post_per_token_quant = True ,
486536 )
487537
488538 def close (self ):
0 commit comments