diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index fa134ba4bd..5522911e4b 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -402,7 +402,7 @@ def profile_topk_softmax( test_topk_softmax( dtype=torch.float32, num_tokens=1024, - num_experts=128, + num_experts=3000, topk=4, use_pre_softmax=False, group_topk=None, diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 03d22942b5..197c662d7b 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -147,6 +147,8 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // logits + topk * num_token_per_block * sizeof(DataType) // topk_logits + topk * num_token_per_block * sizeof(int); // topk_indices + cudaFuncSetAttribute(fused_score_for_moe_aux_loss_forward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size); fused_score_for_moe_aux_loss_forward_kernel <<>>( logits, num_tokens, num_experts, topk, score_function, scores, routing_map, @@ -283,6 +285,8 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher( + num_experts * num_token_per_block * sizeof(DataType) // act_from_fwd + num_experts * num_token_per_block * sizeof(DataType); // comp_buf + cudaFuncSetAttribute(fused_score_for_moe_aux_loss_backward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size); fused_score_for_moe_aux_loss_backward_kernel <<>>( intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function, diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 03e972332a..2ec497188a 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -253,6 +253,8 @@ void fused_topk_with_score_function_forward_kernel_launcher( shared_memory_size += num_groups * num_token_per_block * sizeof(DataType); // group_scores shared_memory_size += num_experts * num_token_per_block * sizeof(DataType); // maksed_scores } + cudaFuncSetAttribute(fused_topk_with_score_function_forward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size); fused_topk_with_score_function_forward_kernel <<>>( logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, @@ -444,6 +446,8 @@ void fused_topk_with_score_function_backward_kernel_launcher( num_experts * num_token_per_block * sizeof(DataType) // act_from_fwd + num_experts * num_token_per_block * sizeof(DataType) // comp_buf + num_experts * num_token_per_block * sizeof(bool); // routing_map + cudaFuncSetAttribute(fused_topk_with_score_function_backward_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_memory_size); fused_topk_with_score_function_backward_kernel <<>>( routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk,