[Pytorch] Add Cutlass GroupGEMM Support for fine-grained MoE Model #2045
base: main
Conversation
Force-pushed from d2a9a55 to b42385d.
Hi @cassiewilliam, thank you for a great PR - it’s good to see such a clear performance improvement! I have one suggestion: I think we should refactor the change slightly to minimize modifications in the TE framework extensions. Currently, we have two separate C APIs: Within …
I fully agree with your suggestion — keeping the code architecture clean is very important. Will you be handling the refactor on your side, or should I go ahead and make the changes directly in the current PR?
Agree with @phu0ngng. We could unify the API and do the dispatch (based on GPU arch/data type/env variable) on the TE/common side.
Please go ahead in this PR.
Got it — I’ll refactor the code to meet the requirements described above.
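For illustration, a rough sketch of the unified entry point being discussed, with dispatch on the TE/common side; every name here (GroupedGemmArgs, nvte_grouped_gemm_dispatch, the commented-out backend calls) is an assumption for the sketch, not the PR's actual symbols:

#include <cuda_runtime.h>

// Sketch only: one entry point exposed to the framework extensions; the
// backend choice (CUTLASS grouped GEMM vs. multi-stream cuBLAS) is made on
// the TE/common side from GPU arch, data type, and an environment variable.
struct GroupedGemmArgs;  // per-expert A/B/D pointers, shapes, dtypes, ...

void nvte_grouped_gemm_dispatch(const GroupedGemmArgs &args, bool use_cutlass,
                                cudaStream_t stream) {
  if (use_cutlass) {
    // run_cutlass_grouped_gemm(args, stream);      // SM90, FP16/BF16 path
  } else {
    // run_multi_stream_cublas_gemm(args, stream);  // existing cuBLAS path
  }
}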
Force-pushed from 6f01bc8 to e832972.
Force-pushed from a023c5f to a76e1cd.
Force-pushed from 76c1e5e to e4b5e60.
To pass the DCO check, you have to sign off on your commits. See here for more details.
Force-pushed from 75d075a to eb3f462.
Appreciate the feedback — the changes have been made.
Please add a test here: https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/test_numerics.py#L2520.
Force-pushed from c3010d9 to 15d750d.
@cassiewilliam I enabled some tests in e928062. Please make sure the following tests pass.
I tested it locally and found tests with …
struct GemmGivenSchedule {
  using ElementA = typename ScheduleConfig::DataType;  // Element type for A matrix operand
  using ElementB = typename ScheduleConfig::DataType;  // Element type for B matrix operand
  using ElementC = typename ScheduleConfig::DataType;  // Element type for C and D matrix operands
When enabling fuse_wgrad_accumulation, C/D could have a different data type from A/B.
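For illustration, a minimal sketch of the kind of change being asked for; the template header, the OutputType alias, and ElementD/ElementAccumulator are assumptions, not the PR's code. The point is that C/D take their element type from a separate alias so they can differ from A/B (e.g. FP32 wgrad accumulation with BF16/FP16 inputs):

// Sketch only: split the input and output element types.
template <typename ScheduleConfig>
struct GemmGivenSchedule {
  using ElementA = typename ScheduleConfig::DataType;    // A matrix operand
  using ElementB = typename ScheduleConfig::DataType;    // B matrix operand
  using ElementC = typename ScheduleConfig::OutputType;  // C operand, may differ from A/B
  using ElementD = typename ScheduleConfig::OutputType;  // D operand, may differ from A/B
  using ElementAccumulator = float;                      // accumulate in FP32
};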
Thank you for taking the time to review this. I’ll make the necessary changes and push the fixes as soon as possible.
Great PR! Thank you so much. CC @timmoon10
NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
}

std::free(host_workspace);
Nice work! I would like to point out that this free is dangerous if the kernel launch on the host is much faster than the kernel execution on the GPU, because I don't think gemm.run implies stream synchronization. In that case, host_workspace may already have been destroyed before cudaMemcpyAsync copies it to all_workspace.
So I suggest adding a CUDA stream sync before the free. Otherwise, you can try a different way to maintain the pointers and shapes in device memory.
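For illustration, a minimal sketch of the first suggestion; the wrapper function and the workspace_bytes parameter are assumptions, while host_workspace, all_workspace, and the staged copy come from the discussion above. The idea is simply that the stream carrying the async copy (and the grouped GEMM launch) is synchronized before the host buffer is released:

#include <cstdlib>
#include <cuda_runtime.h>

// Sketch only: stage arguments in a host buffer, copy them to device memory
// asynchronously, and synchronize before freeing so the async copy cannot
// read memory that has already been released.
void launch_with_staged_workspace(void *all_workspace, size_t workspace_bytes,
                                  cudaStream_t stream) {
  void *host_workspace = std::malloc(workspace_bytes);
  // ... fill host_workspace with per-group pointers and problem shapes ...
  cudaMemcpyAsync(all_workspace, host_workspace, workspace_bytes,
                  cudaMemcpyHostToDevice, stream);
  // ... enqueue the grouped GEMM (gemm.run) on the same stream ...
  cudaStreamSynchronize(stream);  // device is done with the staged data
  std::free(host_workspace);      // now safe to release
}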
/te-ci L0
Description
Add CUTLASS grouped GEMM support for H100 (SM90), which provides a clear performance advantage over the current multi-stream cuBLAS implementation in fine-grained MoE models. Currently, this PR only supports FP16 and BF16; FP8 support is not yet available. The implementation is limited to the standard MoE module (bias and other related features have not been validated yet). Please take note.
Initial performance test results are as follows; the testing method can be found in the file test_group_gemm.py.
Run the test script with:
python tests/pytorch/test_group_gemm.py
Add the environment variable NVTE_USE_CUTLASS_GROUPGEMM to toggle between the two GEMM implementations: export NVTE_USE_CUTLASS_GROUPGEMM=0 selects the original multi-stream cuBLAS GEMM, while export NVTE_USE_CUTLASS_GROUPGEMM=1 enables the newly added CUTLASS grouped GEMM. The default value is 0.
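As a minimal sketch of the documented default-off behavior (the helper name is an assumption; only the variable name NVTE_USE_CUTLASS_GROUPGEMM and its default of 0 come from the description above):

#include <cstdlib>
#include <cstring>

// Sketch only: returns true only when NVTE_USE_CUTLASS_GROUPGEMM=1; an unset
// variable behaves like the documented default of 0 (multi-stream cuBLAS).
static bool cutlass_grouped_gemm_enabled() {
  const char *env = std::getenv("NVTE_USE_CUTLASS_GROUPGEMM");
  return env != nullptr && std::strcmp(env, "1") == 0;
}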
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: