@@ -292,7 +292,8 @@ def top2gating(logits: Tensor,
                min_capacity: int,
                drop_tokens: bool = True,
                ep_group: Union[torch.distributed.ProcessGroup, None] = None,
-               top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+               top2_2nd_expert_sampling: bool = True,
+               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
     # everything is in fp32 in this function
     gates = F.softmax(logits, dim=1)
@@ -313,8 +314,12 @@ def top2gating(logits: Tensor,
     mask2 = F.one_hot(indices2_s, num_classes=num_experts)

     # Compute locations in capacity buffer
-    locations1 = torch.cumsum(mask1, dim=0) - 1
-    locations2 = torch.cumsum(mask2, dim=0) - 1
+    if not use_tutel:
+        locations1 = torch.cumsum(mask1, dim=0) - 1
+        locations2 = torch.cumsum(mask2, dim=0) - 1
+    else:
+        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
+        locations2 = tutel_moe.fast_cumsum_sub_one(mask2)
     # Update 2nd's location by accounting for locations of 1st
     locations2 += torch.sum(mask1, dim=0, keepdim=True)
@@ -324,7 +329,7 @@ def top2gating(logits: Tensor,
     l_aux = torch.mean(me * ce) * num_experts * num_experts

     # gating decisions
-    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)
+    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)

     if drop_tokens:
         # Calculate configured capacity and remove locations outside capacity from mask
@@ -358,6 +363,19 @@ def top2gating(logits: Tensor,
     gates1_s /= denom_s
     gates2_s /= denom_s

+    if use_tutel:
+        # return critical information for tutel
+        return l_aux, capacity, num_experts, [
+            indices1_s,
+            indices2_s,
+        ], [
+            locations1_s,
+            locations2_s,
+        ], [
+            gates1_s,
+            gates2_s,
+        ], exp_counts
+
     # Calculate combine_weights and dispatch_mask
     gates1 = einsum("s,se->se", gates1_s, mask1_float)
     gates2 = einsum("s,se->se", gates2_s, mask2_float)
@@ -517,7 +535,7 @@ def forward(self,

         elif self.k == 2:
             gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
-                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)
+                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling,
+                                     use_tutel)
         else:
             gate_output = topkgating(logits, self.k,
                                      self.capacity_factor if self.training else self.eval_capacity_factor,
@@ -568,7 +586,7 @@ def __init__(self,
         self.timers = SynchronizedWallClockTimer()
         self.wall_clock_breakdown = False

-        self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1
+        self.use_tutel = use_tutel and TUTEL_INSTALLED and (gate.k == 1 or gate.k == 2)

         if self.use_tutel:
             logger.info('Using Tutel optimizations.')
0 commit comments