We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9881404, commit ee99408 · Copy full SHA for ee99408
deepspeed/moe/sharded_moe.py
@@ -329,7 +329,7 @@ def top2gating(logits: Tensor,
329
l_aux = torch.mean(me * ce) * num_experts * num_experts
330
331
# gating decisions
332
- exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)
+ exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)
333
334
if drop_tokens:
335
# Calculate configured capacity and remove locations outside capacity from mask
0 commit comments