
[Pytorch] Add Cutlass GroupGEMM Support for fine-grained MoE Model #2045


Open · cassiewilliam wants to merge 14 commits into main from feature/cutlass_group_gemm_support

Conversation

@cassiewilliam commented Aug 8, 2025

Description

Add CUTLASS Grouped GEMM support for H100 (SM90), which offers a clear performance advantage over the current multi-stream cuBLAS implementation in fine-grained MoE models. Currently, this PR only supports the FP16 and BF16 cases; FP8 support is not yet available. The implementation is limited to the standard MoE module (bias and other related features have not been validated yet). Please take note.

Initial performance test results are listed below; the testing method can be found in tests/pytorch/test_group_gemm.py.

Run the test script with:

python tests/pytorch/test_group_gemm.py

| Shape (g, m, n, k) | TE V2.2 (TFLOPs) | Cutlass-Opt-V1 (TFLOPs) | Speed-Up |
|---|---|---|---|
| (8, 4096, 768, 2048) | 508.77 | 568.63 | 11.77% |
| (16, 2048, 768, 2048) | 398.81 | 534.75 | 34.08% |

The environment variable NVTE_USE_CUTLASS_GROUPGEMM toggles between the two GEMM implementations: setting export NVTE_USE_CUTLASS_GROUPGEMM=0 selects the original multi-stream cuBLAS GEMM, while export NVTE_USE_CUTLASS_GROUPGEMM=1 enables the newly added CUTLASS Grouped GEMM. The default value is 0.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@cassiewilliam cassiewilliam force-pushed the feature/cutlass_group_gemm_support branch 2 times, most recently from d2a9a55 to b42385d Compare August 8, 2025 09:14
@phu0ngng (Collaborator) commented Aug 11, 2025

Hi @cassiewilliam ,

Thank you for a great PR - it’s good to see such a clear performance improvement!

I have one suggestion - I think we should refactor the change slightly to minimize modifications in the TE framework extensions.

Currently, we have two separate C APIs: nvte_multi_stream_cublas_gemm and nvte_cutlass_grouped_gemm. The PyTorch extensions call these individually, and we would need to do the same on the JAX side. Since they share the same function signature, we could unify them into a single API - nvte_multi_tensor_gemm - and deprecate nvte_multi_stream_cublas_gemm.

Within nvte_multi_tensor_gemm, we can determine the GPU architecture and enable CUTLASS GroupedGEMM for FP16/BF16 on Hopper. This way, future changes to the GroupedGEMM implementation or backend would not require modifications to the PyTorch/JAX extensions.
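
For illustration, a minimal sketch of the dispatch described above, assuming a simplified signature (the real nvte_multi_tensor_gemm takes the full NVTE tensor/workspace/stream argument list, and the dtype check is abstracted into a boolean here):

```cpp
// Sketch only: simplified signature and placeholder dtype flag.
#include <cstdlib>
#include <cuda_runtime.h>

static bool use_cutlass_grouped_gemm(int sm_arch, bool is_fp16_or_bf16) {
  // Env toggle from the PR description; default is 0 (multi-stream cuBLAS).
  const char* env = std::getenv("NVTE_USE_CUTLASS_GROUPGEMM");
  const bool enabled = (env != nullptr && env[0] == '1');
  // The CUTLASS Grouped GEMM path is only wired up for FP16/BF16 on Hopper (SM90).
  return enabled && sm_arch == 90 && is_fp16_or_bf16;
}

void nvte_multi_tensor_gemm(/* A, B, D, workspace, ..., */
                            bool inputs_are_fp16_or_bf16, cudaStream_t stream) {
  (void)stream;  // the real implementation launches the chosen backend on this stream

  int device = 0, major = 0, minor = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
  const int sm_arch = major * 10 + minor;

  if (use_cutlass_grouped_gemm(sm_arch, inputs_are_fp16_or_bf16)) {
    // nvte_cutlass_grouped_gemm(...);        // CUTLASS Grouped GEMM backend
  } else {
    // nvte_multi_stream_cublas_gemm(...);    // existing multi-stream cuBLAS backend
  }
}
```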

@cassiewilliam (Author)

> Hi @cassiewilliam ,
>
> Thank you for a great PR - it’s good to see such a clear performance improvement!
>
> I have one suggestion - I think we should refactor the change slightly to minimize modifications in the TE framework extensions.
>
> Currently, we have two separate C APIs: nvte_multi_stream_cublas_gemm and nvte_cutlass_grouped_gemm. The PyTorch extensions call these individually, and we would need to do the same on the JAX side. Since they share the same function signature, we could unify them into a single API - nvte_multi_tensor_gemm - and deprecate nvte_multi_stream_cublas_gemm.
>
> Within nvte_multi_tensor_gemm, we can determine the GPU architecture and enable CUTLASS GroupedGEMM for FP16/BF16 on Hopper. This way, future changes to the GroupedGEMM implementation or backend would not require modifications to the PyTorch/JAX extensions.

I fully agree with your suggestion — keeping the code architecture clean is very important. Will you be handling the refactor on your side, or should I go ahead and make the changes directly in the current PR?

@yaox12 (Member) commented Aug 12, 2025

Agree with @phu0ngng. We could unify the API and do the dispatch (based on GPU arch/data type/env variable) on the TE/common side.

> Will you be handling the refactor on your side, or should I go ahead and make the changes directly in the current PR?

Please go ahead in this PR.

@cassiewilliam (Author)

> Agree with @phu0ngng. We could unify the API and do the dispatch (based on GPU arch/data type/env variable) on the TE/common side.
>
> > Will you be handling the refactor on your side, or should I go ahead and make the changes directly in the current PR?
>
> Please go ahead in this PR.

Got it — I’ll refactor the code to meet the requirements described above.

@cassiewilliam cassiewilliam force-pushed the feature/cutlass_group_gemm_support branch 12 times, most recently from 6f01bc8 to e832972 Compare August 13, 2025 04:24
@cassiewilliam (Author)

Hello @phu0ngng @yaox12, the nvte_multi_tensor_gemm interface has been fully refactored. Please review the implementation for correctness and compliance with the updated design.

@cassiewilliam cassiewilliam force-pushed the feature/cutlass_group_gemm_support branch 7 times, most recently from a023c5f to a76e1cd Compare August 18, 2025 03:58
@cassiewilliam cassiewilliam force-pushed the feature/cutlass_group_gemm_support branch 3 times, most recently from 76c1e5e to e4b5e60 Compare August 18, 2025 04:08
@yaox12 (Member) commented Aug 18, 2025

To pass the DCO check, you have to sign your commits. See here for more details.

@cassiewilliam cassiewilliam force-pushed the feature/cutlass_group_gemm_support branch from 75d075a to eb3f462 Compare August 18, 2025 06:43
@cassiewilliam (Author)

> To pass the DCO check, you have to sign your commits. See here for more details.

Appreciate the feedback — the changes have been made.

@yaox12 (Member) commented Aug 18, 2025

Please add a test here: https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/test_numerics.py#L2520.
I think you can add an argument use_cutlass_group_gemm and set the environment variable if it's true (remember to unset it at the end of the test). I don't expect the CUTLASS implementation to be a bit-wise match with the baseline, so you need to set tolerances.

@cassiewilliam cassiewilliam force-pushed the feature/cutlass_group_gemm_support branch from c3010d9 to 15d750d Compare August 18, 2025 11:44
@yaox12 (Member) commented Aug 19, 2025

@cassiewilliam I enabled some tests in e928062. Please make sure the following tests pass.

pytest -v -s tests/pytorch/test_numerics.py::test_grouped_linear_accuracy_cutlass
pytest -v -s tests/pytorch/test_numerics.py::test_grouped_gemm

I tested it locally and found that tests with fuse_wgrad_accumulation = True fail. When fuse_wgrad_accumulation is enabled, the output is usually in FP32, but you're using the same data type for C as for A/B. This may explain the failures in test_grouped_linear_accuracy_cutlass. In test_grouped_gemm, however, the output has the same dtype as the inputs, so I'm not sure why those fail. Maybe the tolerances are not set correctly.

struct GemmGivenSchedule {
using ElementA = typename ScheduleConfig::DataType; // Element type for A matrix operand
using ElementB = typename ScheduleConfig::DataType; // Element type for B matrix operand
using ElementC = typename ScheduleConfig::DataType; // Element type for C and D matrix operands
Review comment (Member):

When enabling fuse_wgrad_accumulation, C/D could have a different data type from A/B.
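
For illustration only, one way the type aliases could be decoupled (GemmGivenScheduleSketch and ScheduleConfig::OutputType are hypothetical names, not the PR's actual fix):

```cpp
// Sketch: decouple the C/D element type from the A/B element type so that
// fuse_wgrad_accumulation can write FP32 outputs from FP16/BF16 inputs.
// "OutputType" is a hypothetical extra member of ScheduleConfig that would
// default to DataType when no FP32 accumulation is requested.
template <typename ScheduleConfig>
struct GemmGivenScheduleSketch {
  using ElementA = typename ScheduleConfig::DataType;    // A matrix operand
  using ElementB = typename ScheduleConfig::DataType;    // B matrix operand
  using ElementC = typename ScheduleConfig::OutputType;  // C/D operands (e.g. float)
};
```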

@cassiewilliam (Author):

Thank you for taking the time to review this. I’ll make the necessary changes and push the fixes as soon as possible.

@cassiewilliam (Author)

> @cassiewilliam I enabled some tests in e928062. Please make sure the following tests pass.
>
> pytest -v -s tests/pytorch/test_numerics.py::test_grouped_linear_accuracy_cutlass
> pytest -v -s tests/pytorch/test_numerics.py::test_grouped_gemm
>
> I tested it locally and found that tests with fuse_wgrad_accumulation = True fail. When fuse_wgrad_accumulation is enabled, the output is usually in FP32, but you're using the same data type for C as for A/B. This may explain the failures in test_grouped_linear_accuracy_cutlass. In test_grouped_gemm, however, the output has the same dtype as the inputs, so I'm not sure why those fail. Maybe the tolerances are not set correctly.

Thank you for taking the time to review this. I’ll make the necessary changes and push the fixes as soon as possible.

@zhongbozhu (Collaborator)

Great PR! Thank you so much. CC @timmoon10

NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
}

std::free(host_workspace);
Review comment:

Nice work! I'd like to point out that this free can be dangerous if the kernel launch on the host returns much faster than kernel execution on the GPU, because I don't think gemm.run implies stream synchronization. In that case, host_workspace may already have been destroyed before cudaMemcpyAsync has finished copying it to all_workspace.

So I suggest adding a CUDA stream synchronization before the free. Alternatively, you could maintain the ptrs and shapes in device memory in a different way.
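
For illustration, a minimal sketch of the suggested synchronize-before-free, continuing the snippet above (host_workspace, stream, and the NVTE_CHECK pattern come from the PR's code; the exact placement is an assumption):

```cpp
// Sketch: make sure the device has consumed host_workspace before freeing it.
// host_workspace was staged to device memory with cudaMemcpyAsync on `stream`,
// and the grouped GEMM was launched on the same stream, so one synchronize
// covers both operations.
cudaError_t err = cudaStreamSynchronize(stream);
NVTE_CHECK(err == cudaSuccess, "Failed to synchronize stream before freeing host workspace");
std::free(host_workspace);
```

Keeping the pointer/shape metadata in persistent device-side buffers, as suggested above, would avoid the blocking synchronization on the host.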

@phu0ngng (Collaborator)

/te-ci L0
