Commit 9881404

Update sharded_moe.py to support top2 gate with Tutel
Since routing multiple experts per token is very common, and the gather and scatter operations without Tutel are quite inefficient, I added Tutel support to the top-2 gate and tested it on the pipeline engine. This can actually be done for any k; I'll push that later once I have time to test it.
1 parent 66d3d3e commit 9881404
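
For context on the inefficiency mentioned in the commit message: without Tutel, the MoE layer materializes a dense one-hot (token, expert, slot) dispatch mask and gathers/scatters with an einsum over it, touching every (token, expert, slot) combination even though almost all entries are zero. The toy sketch below is not DeepSpeed code; the sizes are made up and it uses top-1 routing for brevity, but it shows the dense einsum dispatch next to the equivalent sparse scatter that a Tutel-style dispatcher performs.

    import torch

    S, E, C, M = 8, 4, 2, 16                     # tokens, experts, capacity, model dim (toy sizes)
    tokens = torch.randn(S, M)
    expert_of = torch.randint(0, E, (S,))        # toy top-1 expert assignment per token
    slot_of = torch.zeros(S, dtype=torch.long)   # slot inside the chosen expert's buffer
    counts = torch.zeros(E, dtype=torch.long)
    keep = torch.zeros(S, dtype=torch.bool)
    for s in range(S):                           # plays the role of cumsum(mask) - 1 with capacity drop
        e = expert_of[s]
        if counts[e] < C:
            slot_of[s], keep[s] = counts[e], True
            counts[e] += 1

    idx = keep.nonzero(as_tuple=True)[0]

    # Dense path (schematically, what the non-Tutel layer does): build a one-hot
    # (token, expert, slot) mask and dispatch with an einsum over the whole mask.
    dispatch_mask = torch.zeros(S, E, C)
    dispatch_mask[idx, expert_of[idx], slot_of[idx]] = 1.0
    dense = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)

    # Sparse path (schematically, what a Tutel-style dispatcher does): scatter each
    # kept token straight into its (expert, slot) row, never building the dense mask.
    sparse = torch.zeros(E, C, M)
    sparse[expert_of[idx], slot_of[idx]] = tokens[idx]

    assert torch.allclose(dense, sparse)

The same cost argument applies per routed expert, so with k=2 the savings from skipping the dense mask roughly double.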

deepspeed/moe/sharded_moe.py

Lines changed: 24 additions & 6 deletions
@@ -292,7 +292,8 @@ def top2gating(logits: Tensor,
                min_capacity: int,
                drop_tokens: bool = True,
                ep_group: Union[torch.distributed.ProcessGroup, None] = None,
-               top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+               top2_2nd_expert_sampling: bool = True,
+               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
     # everything is in fp32 in this function
     gates = F.softmax(logits, dim=1)
@@ -313,8 +314,12 @@ def top2gating(logits: Tensor,
     mask2 = F.one_hot(indices2_s, num_classes=num_experts)

     # Compute locations in capacity buffer
-    locations1 = torch.cumsum(mask1, dim=0) - 1
-    locations2 = torch.cumsum(mask2, dim=0) - 1
+    if not use_tutel:
+        locations1 = torch.cumsum(mask1, dim=0) - 1
+        locations2 = torch.cumsum(mask2, dim=0) - 1
+    else:
+        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
+        locations2 = tutel_moe.fast_cumsum_sub_one(mask2)
     # Update 2nd's location by accounting for locations of 1st
     locations2 += torch.sum(mask1, dim=0, keepdim=True)

@@ -324,7 +329,7 @@ def top2gating(logits: Tensor,
     l_aux = torch.mean(me * ce) * num_experts * num_experts

     # gating decisions
-    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)
+    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)

     if drop_tokens:
         # Calculate configured capacity and remove locations outside capacity from mask
@@ -358,6 +363,19 @@ def top2gating(logits: Tensor,
     gates1_s /= denom_s
     gates2_s /= denom_s

+    if use_tutel:
+        # return critical information for tutel
+        return l_aux, capacity, num_experts, [
+            indices1_s,
+            indices2_s,
+        ], [
+            locations1_s,
+            locations2_s,
+        ], [
+            gates1_s,
+            gates2_s,
+        ], exp_counts
+
     # Calculate combine_weights and dispatch_mask
     gates1 = einsum("s,se->se", gates1_s, mask1_float)
     gates2 = einsum("s,se->se", gates2_s, mask2_float)
@@ -517,7 +535,7 @@ def forward(self,

         elif self.k == 2:
             gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
-                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)
+                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling, use_tutel)
         else:
             gate_output = topkgating(logits, self.k,
                                      self.capacity_factor if self.training else self.eval_capacity_factor,
@@ -568,7 +586,7 @@ def __init__(self,
         self.timers = SynchronizedWallClockTimer()
         self.wall_clock_breakdown = False

-        self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1
+        self.use_tutel = use_tutel and TUTEL_INSTALLED and (gate.k == 1 or gate.k == 2)

        if self.use_tutel:
            logger.info('Using Tutel optimizations.')
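
On the MOELayer side, the existing Tutel branch already consumes this same tuple shape for k=1 and hands the indices/locations/gates to Tutel's fast dispatcher, so the lists returned above let k=2 flow through that unchanged path. A hypothetical usage sketch follows; the layer sizes and the stand-in expert network are made up, and it assumes Tutel is installed and the model runs under DeepSpeed:

    import torch
    from deepspeed.moe.layer import MoE

    hidden = 1024
    expert = torch.nn.Sequential(                # stand-in expert FFN for illustration
        torch.nn.Linear(hidden, 4 * hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(4 * hidden, hidden),
    )
    # With this commit, k=2 no longer forces the dense einsum dispatch when use_tutel=True.
    moe_layer = MoE(hidden_size=hidden,
                    expert=expert,
                    num_experts=8,
                    k=2,
                    use_tutel=True)              # falls back silently if Tutel is not installed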
