20 commits
4ac7be8
feat: add cutlass group gemm support
Aug 8, 2025
a5562d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 8, 2025
68948fb
refactor: refactor multi tensor gemm interface
Aug 13, 2025
b151755
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2025
5126889
refactor: refactor nvte_multi_stream_cublas_gemm func and add license…
Aug 18, 2025
eb3f462
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2025
a362a01
Merge branch 'main' into feature/cutlass_group_gemm_support
yaox12 Aug 18, 2025
ec44a6f
feat: add unit test for cutlass group gemm
Aug 18, 2025
22f5c47
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2025
15d750d
feat: add cutlass support type protect
Aug 18, 2025
e928062
add tests and fix lint
yaox12 Aug 19, 2025
d8697ea
Merge branch 'main' into feature/cutlass_group_gemm_support
yaox12 Aug 19, 2025
00e96c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2025
c568733
Merge branch 'main' into feature/cutlass_group_gemm_support
phu0ngng Aug 20, 2025
896d4b9
feat: fix unit tests error
Aug 22, 2025
305c7b4
Merge branch 'main' into feature/cutlass_group_gemm_support
cassiewilliam Aug 22, 2025
af70227
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2025
b3ef3c5
feat: refactor host workspace malloc
Aug 26, 2025
579c539
Merge branch 'main' into feature/cutlass_group_gemm_support
cassiewilliam Aug 27, 2025
8b9ffe7
Merge branch 'main' into feature/cutlass_group_gemm_support
yaox12 Aug 28, 2025
3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
[submodule "3rdparty/cudnn-frontend"]
path = 3rdparty/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
1 change: 1 addition & 0 deletions 3rdparty/cutlass
Submodule cutlass added at e51efb
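Since CUTLASS is pulled in as a new submodule, local builds need it checked out before configuring CMake, e.g. git submodule update --init --recursive 3rdparty/cutlass.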
72 changes: 65 additions & 7 deletions tests/pytorch/test_numerics.py
@@ -120,6 +120,11 @@
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())

use_cutlass_grouped_gemm = [False]
# Only enable cutlass grouped gemm on Hopper
if torch.cuda.get_device_capability() == (9, 0):
use_cutlass_grouped_gemm.append(True)


def is_fused_attn_available(
config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
@@ -1791,6 +1796,8 @@ def test_grouped_linear_accuracy(
bias,
delay_wgrad_compute,
parallel_mode=None,
bitwise_match=True,
use_cutlass=False,
):
fp8 = recipe is not None
if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
@@ -1862,9 +1869,50 @@ def test_grouped_linear_accuracy(
delay_wgrad_compute,
)

# Should be bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
for o, o_ref in zip(outputs, outputs_ref):
if bitwise_match:
# cuBLAS implementation should be a bit-wise match
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
elif use_cutlass:
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
else:
torch.testing.assert_close(o, o_ref)


@pytest.mark.skipif(
torch.cuda.get_device_capability() != (9, 0),
reason="Only enable CUTLASS grouped gemm on Hopper",
)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_grouped_linear_accuracy_cutlass(
dtype,
num_gemms,
bs,
model,
fuse_wgrad_accumulation,
delay_wgrad_compute,
):
os.environ["NVTE_USE_CUTLASS_GROUPGEMM"] = "1"
test_grouped_linear_accuracy(
dtype,
num_gemms,
bs,
model,
None,
False,
fuse_wgrad_accumulation,
False,
delay_wgrad_compute,
None,
bitwise_match=False,
use_cutlass=True,
)
os.environ.pop("NVTE_USE_CUTLASS_GROUPGEMM", None)
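Note: if the wrapped call above raises, NVTE_USE_CUTLASS_GROUPGEMM would stay set for later tests. A minimal sketch of a guard that restores the previous value either way (the helper name _cutlass_grouped_gemm_enabled is hypothetical, not part of this diff):

import os
from contextlib import contextmanager

@contextmanager
def _cutlass_grouped_gemm_enabled():
    """Opt in to the CUTLASS grouped GEMM path and restore the old value on exit."""
    old = os.environ.get("NVTE_USE_CUTLASS_GROUPGEMM")
    os.environ["NVTE_USE_CUTLASS_GROUPGEMM"] = "1"
    try:
        yield
    finally:
        if old is None:
            os.environ.pop("NVTE_USE_CUTLASS_GROUPGEMM", None)
        else:
            os.environ["NVTE_USE_CUTLASS_GROUPGEMM"] = old

The same guard could also wrap the use_cutlass branch in test_grouped_gemm further down.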


@pytest.mark.parametrize("dtype", param_types, ids=str)
@@ -2528,10 +2576,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
(16, 10027, 128, 512),
],
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False, True])
def test_grouped_gemm(shape, dtype, layout, accumulate):
@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm)
def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
torch.manual_seed(0)
z, m, k, n = shape

@@ -2566,6 +2615,9 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
grad = True
single_output = False

if use_cutlass:
os.environ["NVTE_USE_CUTLASS_GROUPGEMM"] = "1"

for i in range(z):
general_gemm(
A[i],
@@ -2593,9 +2645,15 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
single_output=single_output,
)

# should be bit-wise match
for o, o_ref in zip(out, out_ref):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
if not use_cutlass:
# cuBLAS implementation should be a bit-wise match
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
else:
torch.testing.assert_close(o, o_ref)

if use_cutlass:
os.environ.pop("NVTE_USE_CUTLASS_GROUPGEMM", None)


@pytest.mark.parametrize(
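One note on the tolerance split in these tests: the cuBLAS comparisons keep rtol=0/atol=0 because the reference path runs the same cuBLAS kernels, whereas the CUTLASS grouped GEMM is a different kernel (different tiling and accumulation order), so it is only expected to agree to roughly 1e-3 for FP16/BF16 inputs with FP32 accumulation.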
22 changes: 20 additions & 2 deletions transformer_engine/common/CMakeLists.txt
@@ -45,6 +45,11 @@ if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
endif()
include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)

set(CUTLASS_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/include")
set(CUTLASS_TOOLS_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include")

# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

@@ -80,6 +85,7 @@ list(APPEND transformer_engine_SOURCES
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
gemm/cutlass_groupgemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
@@ -120,18 +126,30 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")


if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(
"gemm/cutlass_groupgemm.cu"
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
else()
message(FATAL_ERROR "The CUTLASS kernel gemm/cutlass_groupgemm.cu requires sm_90a")
endif()

# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart
CUDNN::cudnn_all)

target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
target_include_directories(transformer_engine PRIVATE
${CUTLASS_INCLUDE_DIR}
${CUTLASS_TOOLS_INCLUDE_DIR})

# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
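Note that gemm/cutlass_groupgemm.cu is appended to transformer_engine_SOURCES unconditionally, so with this guard a CUDA toolkit at or below 12.0 fails at configure time rather than silently dropping the CUTLASS path.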
115 changes: 106 additions & 9 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -19,6 +19,7 @@
#include "../util/logging.h"
#include "../util/multi_stream.h"
#include "common/util/cuda_runtime.h"
#include "cutlass_groupgemm.cuh"

namespace {

@@ -650,9 +651,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
CUBLAS_VERSION);
#endif
NVTE_CHECK(
cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
transformer_engine::cuda::cudart_version() >= 12020 &&
transformer_engine::cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
transformer_engine::cuda::cudart_version());
NVTE_CHECK(
cublas_version() >= 120205 && cublas_version() < 130000,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
@@ -675,13 +677,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
n_split, gemm_producer, inputCounter, stream);
}

void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor *workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
bool transa, bool transb, bool grad, NVTETensor *workspace,
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
using namespace transformer_engine;

int num_streams = nvte_get_num_compute_streams();
@@ -711,10 +711,107 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
}
}

void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor *workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
using namespace transformer_engine;

// Deprecation warning
NVTE_WARN(
"nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. "
"Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support where "
"applicable).");

multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace,
accumulate, use_split_accumulator, math_sm_count, stream);
}

namespace transformer_engine {

using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;

void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }

} // namespace transformer_engine

void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
bool transa, bool transb, bool grad, NVTETensor *workspace,
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_gemm);

const int current_device = transformer_engine::cuda::current_device();
const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90);
const bool use_cutlass = transformer_engine::getenv<bool>("NVTE_USE_CUTLASS_GROUPGEMM", false);

auto cublas_path = [&]() {
multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
workspace, accumulate, use_split_accumulator, math_sm_count, stream);
};

// Currently, CUTLASS grouped GEMM is only supported on the Hopper arch (SM90)
if (!(is_hopper && use_cutlass)) {
cublas_path();
return;
}

auto is_empty_arr = [&](const NVTETensor *p) -> bool {
if (p == nullptr) return true;
for (int i = 0; i < num_gemms; ++i) {
if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false;
}
return true;
};

auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool {
int64_t ref_k = -1;
for (size_t i = 0; i < num_gemms; i++) {
const auto tensor = transformer_engine::convertNVTETensorCheck(p[i]);
const int k = trans ? tensor->data.shape[0] : tensor->data.shape[1];

if ((k & 127) != 0) return false;

if (ref_k < 0)
ref_k = k;
else if (k != ref_k)
return false;
}

return true;
};

auto is_supported_dtype = [&]() -> bool {
auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]);
auto *OutputB = transformer_engine::convertNVTETensorCheck(D[0]);
auto A_type = get_cuda_dtype(inputA->data.dtype);
auto B_type = get_cuda_dtype(inputB->data.dtype);
auto D_type = get_cuda_dtype(OutputB->data.dtype);

return (A_type == B_type) && (A_type == D_type) &&
((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F));
};

// CUTLASS Grouped GEMM fast path (SM90/TMA)
// Conditions:
// - No fused epilogue: both bias and pre_gelu_out are empty.
// - Supported dtypes only: FP16/BF16 (FP32 accumulate).
// - Uniform K across groups and K % 128 == 0.
// - use_split_accumulator is ignored for FP16/BF16.
// - grad is irrelevant when bias/pre_gelu_out are empty.
//
// Otherwise, fall back to cuBLAS.
if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() &&
all_groups_uniform_k128(B, transb)) {
cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate,
current_device, math_sm_count, stream);
} else {
NVTE_WARN("cuBLAS fallback.");
cublas_path();
}
}
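For readers of the dispatch logic above, a Python paraphrase of the eligibility checks may help. This is a sketch only: would_use_cutlass_grouped_gemm is a hypothetical helper operating on plain torch tensors, not an API added by this PR.

import os
import torch

def would_use_cutlass_grouped_gemm(a_list, b_list, transb, bias_list=None, gelu_list=None):
    """Mirror of the eligibility checks in nvte_multi_tensor_gemm (illustration only)."""
    # Explicit opt-in and Hopper (SM90) only.
    if os.environ.get("NVTE_USE_CUTLASS_GROUPGEMM", "0") != "1":
        return False
    if torch.cuda.get_device_capability() != (9, 0):
        return False
    # No fused epilogue: bias and pre_gelu_out must be empty.
    if bias_list or gelu_list:
        return False
    # Supported dtypes: FP16/BF16 only, with A, B (and D) sharing one dtype.
    dtypes = {t.dtype for t in a_list} | {t.dtype for t in b_list}
    if len(dtypes) != 1 or dtypes.pop() not in (torch.float16, torch.bfloat16):
        return False
    # Uniform K across groups and K % 128 == 0 (K read from B, honoring transb).
    ks = [b.shape[0] if transb else b.shape[1] for b in b_list]
    return len(set(ks)) == 1 and ks[0] % 128 == 0

Anything failing these checks takes the multi-stream cuBLAS path, which is what the "cuBLAS fallback." warning reports.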