diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index 340bc82de508..0beb9b872acd 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -292,7 +292,8 @@ def top2gating(logits: Tensor,
                min_capacity: int,
                drop_tokens: bool = True,
                ep_group: Union[torch.distributed.ProcessGroup, None] = None,
-               top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+               top2_2nd_expert_sampling: bool = True,
+               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
     # everything is in fp32 in this function
     gates = F.softmax(logits, dim=1)
@@ -313,8 +314,12 @@ def top2gating(logits: Tensor,
     mask2 = F.one_hot(indices2_s, num_classes=num_experts)
 
     # Compute locations in capacity buffer
-    locations1 = torch.cumsum(mask1, dim=0) - 1
-    locations2 = torch.cumsum(mask2, dim=0) - 1
+    if not use_tutel:
+        locations1 = torch.cumsum(mask1, dim=0) - 1
+        locations2 = torch.cumsum(mask2, dim=0) - 1
+    else:
+        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
+        locations2 = tutel_moe.fast_cumsum_sub_one(mask2)
     # Update 2nd's location by accounting for locations of 1st
     locations2 += torch.sum(mask1, dim=0, keepdim=True)
 
@@ -358,6 +363,19 @@ def top2gating(logits: Tensor,
     gates1_s /= denom_s
     gates2_s /= denom_s
 
+    if use_tutel:
+        # return critical information for tutel
+        return l_aux, capacity, num_experts, [
+            indices1_s,
+            indices2_s,
+        ], [
+            locations1_s,
+            locations2_s,
+        ], [
+            gates1_s,
+            gates2_s,
+        ], exp_counts
+
     # Calculate combine_weights and dispatch_mask
     gates1 = einsum("s,se->se", gates1_s, mask1_float)
     gates2 = einsum("s,se->se", gates2_s, mask2_float)
@@ -517,7 +535,8 @@ def forward(self,
         elif self.k == 2:
             gate_output = top2gating(logits,
                                      self.capacity_factor if self.training else self.eval_capacity_factor,
-                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)
+                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling,
+                                     use_tutel)
         else:
             gate_output = topkgating(logits, self.k,
                                      self.capacity_factor if self.training else self.eval_capacity_factor,
@@ -568,7 +587,7 @@ def __init__(self,
         self.timers = SynchronizedWallClockTimer()
         self.wall_clock_breakdown = False
 
-        self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1
+        self.use_tutel = use_tutel and TUTEL_INSTALLED and (gate.k == 1 or gate.k == 2)
 
         if self.use_tutel:
             logger.info('Using Tutel optimizations.')
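
For reference, a minimal sanity-check sketch (not part of the diff): it exercises the tutel fast path added above, confirming that `tutel_moe.fast_cumsum_sub_one(mask)` is a drop-in for the eager `torch.cumsum(mask, dim=0) - 1` it replaces. It assumes tutel is installed and a CUDA device is available, since the fused cumsum runs as a GPU kernel; the tensor sizes are illustrative.

```python
import torch
import torch.nn.functional as F

# Same optional-import pattern sharded_moe.py uses to set TUTEL_INSTALLED.
try:
    from tutel import moe as tutel_moe
    TUTEL_INSTALLED = True
except ImportError:
    TUTEL_INSTALLED = False

if TUTEL_INSTALLED and torch.cuda.is_available():
    num_tokens, num_experts = 4096, 8  # illustrative sizes
    indices1_s = torch.randint(0, num_experts, (num_tokens,), device="cuda")
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)  # (tokens, experts)

    eager = torch.cumsum(mask1, dim=0) - 1        # baseline path (use_tutel=False)
    fused = tutel_moe.fast_cumsum_sub_one(mask1)  # tutel fast path (use_tutel=True)

    assert torch.equal(eager, fused.to(eager.dtype)), "tutel cumsum diverged from eager path"
```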