Commit ee99408

Update sharded_moe.py

1 parent 9881404 commit ee99408

File tree

1 file changed: +1 −1 lines changed


deepspeed/moe/sharded_moe.py

Lines changed: 1 addition & 1 deletion
@@ -329,7 +329,7 @@ def top2gating(logits: Tensor,
     l_aux = torch.mean(me * ce) * num_experts * num_experts

     # gating decisions
-    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)
+    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)

     if drop_tokens:
         # Calculate configured capacity and remove locations outside capacity from mask
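For context, the changed line counts how many tokens each expert receives in top-2 gating: `mask1` and `mask2` are one-hot (tokens × experts) assignment masks for each token's first and second expert choice, and summing them over the token dimension yields per-expert token counts. The sketch below is a simplified, hypothetical reconstruction of that step (the real `top2gating` in `deepspeed/moe/sharded_moe.py` also handles noise, capacity, and dropping); the function name `expert_counts` is invented for illustration.

```python
import torch
import torch.nn.functional as F

def expert_counts(logits: torch.Tensor) -> torch.Tensor:
    """Per-expert token counts under top-2 gating (illustrative sketch)."""
    num_experts = logits.shape[1]
    # Indices of each token's top-2 experts: shape (tokens, 2)
    top2 = torch.topk(logits, k=2, dim=1).indices
    # One-hot assignment masks for the 1st and 2nd choice, shape (tokens, experts)
    mask1 = F.one_hot(top2[:, 0], num_classes=num_experts)
    mask2 = F.one_hot(top2[:, 1], num_classes=num_experts)
    # The line touched by this commit: sum over tokens gives per-expert counts,
    # detached from the graph and kept on the same device as the logits.
    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)
    return exp_counts
```

With two tokens and three experts, each token contributes one count to each of its two chosen experts, so the counts always sum to `2 * num_tokens`.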
