Update sharded_moe.py to support top2 gate with Tutel #6948

Open · wants to merge 4 commits into base: master
deepspeed/moe/sharded_moe.py: 29 changes (24 additions, 5 deletions)
@@ -292,7 +292,8 @@ def top2gating(logits: Tensor,
                min_capacity: int,
                drop_tokens: bool = True,
                ep_group: Union[torch.distributed.ProcessGroup, None] = None,
-               top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+               top2_2nd_expert_sampling: bool = True,
+               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
     # everything is in fp32 in this function
     gates = F.softmax(logits, dim=1)
@@ -313,8 +314,12 @@ def top2gating(logits: Tensor,
     mask2 = F.one_hot(indices2_s, num_classes=num_experts)

     # Compute locations in capacity buffer
-    locations1 = torch.cumsum(mask1, dim=0) - 1
-    locations2 = torch.cumsum(mask2, dim=0) - 1
+    if not use_tutel:
+        locations1 = torch.cumsum(mask1, dim=0) - 1
+        locations2 = torch.cumsum(mask2, dim=0) - 1
+    else:
+        locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
+        locations2 = tutel_moe.fast_cumsum_sub_one(mask2)
     # Update 2nd's location by accounting for locations of 1st
     locations2 += torch.sum(mask1, dim=0, keepdim=True)

@@ -358,6 +363,19 @@ def top2gating(logits: Tensor,
     gates1_s /= denom_s
     gates2_s /= denom_s

+    if use_tutel:
+        # return critical information for tutel
+        return l_aux, capacity, num_experts, [
+            indices1_s,
+            indices2_s,
+        ], [
+            locations1_s,
+            locations2_s,
+        ], [
+            gates1_s,
+            gates2_s,
+        ], exp_counts
+
     # Calculate combine_weights and dispatch_mask
     gates1 = einsum("s,se->se", gates1_s, mask1_float)
     gates2 = einsum("s,se->se", gates2_s, mask2_float)
@@ -517,7 +535,8 @@ def forward(self,

         elif self.k == 2:
             gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
-                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling)
+                                     self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling,
+                                     use_tutel)
         else:
             gate_output = topkgating(logits, self.k,
                                      self.capacity_factor if self.training else self.eval_capacity_factor,
@@ -568,7 +587,7 @@ def __init__(self,
         self.timers = SynchronizedWallClockTimer()
         self.wall_clock_breakdown = False

-        self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1
+        self.use_tutel = use_tutel and TUTEL_INSTALLED and (gate.k == 1 or gate.k == 2)

         if self.use_tutel:
             logger.info('Using Tutel optimizations.')
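
Note on the new cumsum branch above: Tutel's fast_cumsum_sub_one is a fused replacement for the dense torch.cumsum(mask, dim=0) - 1 used on the non-Tutel path, so both branches produce the same per-expert position counters. A minimal sanity-check sketch, assuming Tutel is installed and a CUDA device is available (token/expert sizes here are illustrative):

import torch
import torch.nn.functional as F
from tutel import moe as tutel_moe

# One-hot expert assignment per token, shape [num_tokens, num_experts].
mask = F.one_hot(torch.tensor([0, 1, 0, 2, 1], device='cuda'), num_classes=4)

# Dense path: running count of earlier tokens routed to the same expert.
locations_dense = torch.cumsum(mask, dim=0) - 1
# Tutel path: the fused kernel used by the new use_tutel branch.
locations_tutel = tutel_moe.fast_cumsum_sub_one(mask)

assert bool((locations_dense == locations_tutel).all())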
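
For context on the new early return in top2gating: the existing Tutel branch in MOELayer.forward already consumes a tuple of this shape for the k=1 gate and feeds it to Tutel's fast dispatcher; with this change the indices/locations/gates lists simply carry two entries, one per selected expert. A hedged sketch of that consumption, written as a standalone helper rather than the actual MOELayer method, with names such as reshaped_input and used_token following the existing code:

from tutel import moe as tutel_moe

def tutel_dispatch(gate, reshaped_input, used_token=None):
    # The gate called with use_tutel=True returns the tuple produced above:
    # (l_aux, capacity, num_experts, indices, locations, gate_weights, exp_counts);
    # for the top-2 gate each of the three lists now holds two tensors.
    l_aux, C, E, indices_, locations_, gates_, exp_counts = gate(reshaped_input, used_token, use_tutel=True)
    M = reshaped_input.size(1)

    # Tutel's fast dispatcher scatters tokens into the [E, C, M] capacity buffer;
    # dispatcher.decode(...) later combines expert outputs with the gate weights.
    dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
    dispatcher.update(indices_, locations_, gates_, capacity=C)
    dispatched_input = dispatcher.encode(reshaped_input)
    return l_aux, dispatched_input, dispatcher, exp_counts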
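
Finally, a sketch of how a user would exercise the widened check above once this lands: constructing DeepSpeed's MoE layer with k=2 and use_tutel=True (constructor arguments per the existing deepspeed.moe.layer.MoE API; the expert module and sizes are illustrative, and a DeepSpeed distributed setup is assumed so expert-parallel groups exist):

import torch
from deepspeed.moe.layer import MoE

hidden = 512
# Illustrative expert: any nn.Module mapping hidden -> hidden works.
expert = torch.nn.Sequential(torch.nn.Linear(hidden, 4 * hidden), torch.nn.ReLU(),
                             torch.nn.Linear(4 * hidden, hidden))

# With this PR, k=2 combined with use_tutel=True now routes through the
# Tutel-accelerated top-2 gating path (previously Tutel was applied for k=1 only).
# Assumes deepspeed.init_distributed() / a DeepSpeed engine has set up the groups.
moe_layer = MoE(hidden_size=hidden, expert=expert, num_experts=8, k=2, use_tutel=True)

output, l_aux, exp_counts = moe_layer(torch.randn(4, 16, hidden))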