From 98814048f2037e7a25a24e8e1d89f0dce9a252e7 Mon Sep 17 00:00:00 2001 From: Xueshen Liu Date: Tue, 14 Jan 2025 15:09:16 -0500 Subject: [PATCH 1/3] Update sharded_moe.py to support top2 gate with Tutel Given the fact that multiple experts per token is very common, and the gather and scatter operation without Tutel is so inefficient, I added support of tutel to top2 gate and tested on pipeline engine. This can be done for any k actually, I'll push that later when I have time to test, --- deepspeed/moe/sharded_moe.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py index 340bc82de508..191807012270 100644 --- a/deepspeed/moe/sharded_moe.py +++ b/deepspeed/moe/sharded_moe.py @@ -292,7 +292,8 @@ def top2gating(logits: Tensor, min_capacity: int, drop_tokens: bool = True, ep_group: Union[torch.distributed.ProcessGroup, None] = None, - top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + top2_2nd_expert_sampling: bool = True, + use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Implements Top2Gating on logits.""" # everything is in fp32 in this function gates = F.softmax(logits, dim=1) @@ -313,8 +314,12 @@ def top2gating(logits: Tensor, mask2 = F.one_hot(indices2_s, num_classes=num_experts) # Compute locations in capacity buffer - locations1 = torch.cumsum(mask1, dim=0) - 1 - locations2 = torch.cumsum(mask2, dim=0) - 1 + if not use_tutel: + locations1 = torch.cumsum(mask1, dim=0) - 1 + locations2 = torch.cumsum(mask2, dim=0) - 1 + else: + locations1 = tutel_moe.fast_cumsum_sub_one(mask1) + locations2 = tutel_moe.fast_cumsum_sub_one(mask2) # Update 2nd's location by accounting for locations of 1st locations2 += torch.sum(mask1, dim=0, keepdim=True) @@ -324,7 +329,7 @@ def top2gating(logits: Tensor, l_aux = torch.mean(me * ce) * num_experts * num_experts # gating decisions - exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device) + exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device) if drop_tokens: # Calculate configured capacity and remove locations outside capacity from mask @@ -358,6 +363,19 @@ def top2gating(logits: Tensor, gates1_s /= denom_s gates2_s /= denom_s + if use_tutel: + # return critical information for tutel + return l_aux, capacity, num_experts, [ + indices1_s, + indices2_s, + ], [ + locations1_s, + locations2_s, + ], [ + gates1_s, + gates2_s, + ], exp_counts + # Calculate combine_weights and dispatch_mask gates1 = einsum("s,se->se", gates1_s, mask1_float) gates2 = einsum("s,se->se", gates2_s, mask2_float) @@ -517,7 +535,7 @@ def forward(self, elif self.k == 2: gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, - self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling) + self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling, use_tutel) else: gate_output = topkgating(logits, self.k, self.capacity_factor if self.training else self.eval_capacity_factor, @@ -568,7 +586,7 @@ def __init__(self, self.timers = SynchronizedWallClockTimer() self.wall_clock_breakdown = False - self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1 + self.use_tutel = use_tutel and TUTEL_INSTALLED and (gate.k == 1 or gate.k == 2) if self.use_tutel: logger.info('Using Tutel optimizations.') From ee994086a42976c4ec76d445b3689ffd15c08056 Mon Sep 17 00:00:00 2001 From: Xueshen Liu Date: Tue, 14 Jan 2025 15:11:46 -0500 Subject: [PATCH 2/3] Update sharded_moe.py --- deepspeed/moe/sharded_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py index 191807012270..122193116417 100644 --- a/deepspeed/moe/sharded_moe.py +++ b/deepspeed/moe/sharded_moe.py @@ -329,7 +329,7 @@ def top2gating(logits: Tensor, l_aux = torch.mean(me * ce) * num_experts * num_experts # gating decisions - exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device) + exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device) if drop_tokens: # Calculate configured capacity and remove locations outside capacity from mask From 07cc866e68360379e4049bb5414f3b5e4703a138 Mon Sep 17 00:00:00 2001 From: Xueshen Liu Date: Wed, 15 Jan 2025 14:40:43 -0500 Subject: [PATCH 3/3] Update sharded_moe.py for formatting --- deepspeed/moe/sharded_moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py index 122193116417..0beb9b872acd 100644 --- a/deepspeed/moe/sharded_moe.py +++ b/deepspeed/moe/sharded_moe.py @@ -535,7 +535,8 @@ def forward(self, elif self.k == 2: gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, - self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling, use_tutel) + self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling, + use_tutel) else: gate_output = topkgating(logits, self.k, self.capacity_factor if self.training else self.eval_capacity_factor,