diff --git a/.gitmodules b/.gitmodules index 21492db5ef..4b188d6bb1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend url = https://github.com/NVIDIA/cudnn-frontend.git +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 0000000000..e51efbfe18 --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit e51efbfe18fe4f4cbb66ab814c55bf4aa0185491 diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index b76f3d2b21..17f26b4390 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -120,6 +120,11 @@ fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) +use_cutlass_grouped_gemm = [False] +# Only enable cutlass grouped gemm on Hopper +if torch.cuda.get_device_capability() == (9, 0): + use_cutlass_grouped_gemm.append(True) + def is_fused_attn_available( config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True @@ -1791,6 +1796,8 @@ def test_grouped_linear_accuracy( bias, delay_wgrad_compute, parallel_mode=None, + bitwise_match=True, + use_cutlass=False, ): fp8 = recipe is not None if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: @@ -1862,9 +1869,50 @@ def test_grouped_linear_accuracy( delay_wgrad_compute, ) - # Shoule be bit-wise match - for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + for o, o_ref in zip(outputs, outputs_ref): + if bitwise_match: + # cuBLAS implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + elif use_cutlass: + torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + else: + torch.testing.assert_close(o, o_ref) + + +@pytest.mark.skipif( + torch.cuda.get_device_capability() != (9, 0), + reason="Only enable CUTLASS grouped gemm on Hopper", +) +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) +def test_grouped_linear_accuracy_cutlass( + dtype, + num_gemms, + bs, + model, + fuse_wgrad_accumulation, + delay_wgrad_compute, +): + os.environ["NVTE_USE_CUTLASS_GROUPGEMM"] = "1" + test_grouped_linear_accuracy( + dtype, + num_gemms, + bs, + model, + None, + False, + fuse_wgrad_accumulation, + False, + delay_wgrad_compute, + None, + bitwise_match=False, + use_cutlass=True, + ) + os.environ.pop("NVTE_USE_CUTLASS_GROUPGEMM", None) @pytest.mark.parametrize("dtype", param_types, ids=str) @@ -2528,10 +2576,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): (16, 10027, 128, 512), ], ) -@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) @pytest.mark.parametrize("accumulate", [False, True]) -def test_grouped_gemm(shape, dtype, layout, accumulate): +@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) +def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): torch.manual_seed(0) z, m, k, n = shape @@ -2566,6 +2615,9 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): grad = True single_output = False + if use_cutlass: + os.environ["NVTE_USE_CUTLASS_GROUPGEMM"] = "1" + for i in range(z): general_gemm( A[i], @@ -2593,9 +2645,15 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): single_output=single_output, ) - # should be bit-wise match for o, o_ref in zip(out, out_ref): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + if not use_cutlass: + # cublas implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + else: + torch.testing.assert_close(o, o_ref) + + if use_cutlass: + os.environ.pop("NVTE_USE_CUTLASS_GROUPGEMM", None) @pytest.mark.parametrize( diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 183a7a72ec..b05862ae25 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -45,6 +45,11 @@ if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") endif() include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +set(CUTLASS_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/include") +set(CUTLASS_TOOLS_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include") + # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) @@ -80,6 +85,7 @@ list(APPEND transformer_engine_SOURCES fused_attn/fused_attn.cpp fused_attn/utils.cu gemm/cublaslt_gemm.cu + gemm/cutlass_groupgemm.cu normalization/common.cpp normalization/layernorm/ln_api.cpp normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -120,18 +126,30 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") - +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + set_source_files_properties( + "gemm/cutlass_groupgemm.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") +else() + message(FATAL_ERROR "cutlass gemm/cutlass_groupgemm.cu kernel required sm 90a") +endif() # Configure dependencies target_link_libraries(transformer_engine PUBLIC CUDA::cublas CUDA::cudart CUDNN::cudnn_all) + target_include_directories(transformer_engine PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine SYSTEM PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") +target_include_directories(transformer_engine PRIVATE + ${CUTLASS_INCLUDE_DIR} + ${CUTLASS_TOOLS_INCLUDE_DIR}) # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9e6c5417bc..cc0018b1d4 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -19,6 +19,7 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "common/util/cuda_runtime.h" +#include "cutlass_groupgemm.cuh" namespace { @@ -650,9 +651,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor CUBLAS_VERSION); #endif NVTE_CHECK( - cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, + transformer_engine::cuda::cudart_version() >= 12020 && + transformer_engine::cuda::cudart_version() < 13000, "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ", - cuda::cudart_version()); + transformer_engine::cuda::cudart_version()); NVTE_CHECK( cublas_version() >= 120205 && cublas_version() < 130000, "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", @@ -675,13 +677,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor n_split, gemm_producer, inputCounter, stream); } -void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, - const NVTETensor *bias, NVTETensor *pre_gelu_out, - const int num_gemms, bool transa, bool transb, bool grad, - NVTETensor *workspace, bool accumulate, - bool use_split_accumulator, int math_sm_count, - cudaStream_t stream) { - NVTE_API_CALL(nvte_multi_stream_cublas_gemm); +void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor *workspace, + bool accumulate, bool use_split_accumulator, int math_sm_count, + cudaStream_t stream) { using namespace transformer_engine; int num_streams = nvte_get_num_compute_streams(); @@ -711,6 +711,25 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT } } +void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, + const int num_gemms, bool transa, bool transb, bool grad, + NVTETensor *workspace, bool accumulate, + bool use_split_accumulator, int math_sm_count, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_stream_cublas_gemm); + using namespace transformer_engine; + + // Deprecation warning + NVTE_WARN( + "nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. " + "Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support where " + "applicable)."); + + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace, + accumulate, use_split_accumulator, math_sm_count, stream); +} + namespace transformer_engine { using cublasHandleManager = detail::HandleManager; @@ -718,3 +737,81 @@ using cublasHandleManager = detail::HandleManager("NVTE_USE_CUTLASS_GROUPGEMM", false); + + auto cublas_path = [&]() { + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, + workspace, accumulate, use_split_accumulator, math_sm_count, stream); + }; + + // Currently only support cutlass group gemm on Hopper Arch + if (!(is_hopper && use_cutlass)) { + cublas_path(); + return; + } + + auto is_empty_arr = [&](const NVTETensor *p) -> bool { + if (p == nullptr) return true; + for (int i = 0; i < num_gemms; ++i) { + if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false; + } + return true; + }; + + auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool { + int64_t ref_k = -1; + for (size_t i = 0; i < num_gemms; i++) { + const auto tensor = transformer_engine::convertNVTETensorCheck(p[i]); + const int k = trans ? tensor->data.shape[0] : tensor->data.shape[1]; + + if ((k & 127) != 0) return false; + + if (ref_k < 0) + ref_k = k; + else if (k != ref_k) + return false; + } + + return true; + }; + + auto is_supported_dtype = [&]() -> bool { + auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); + auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); + auto *OutputB = transformer_engine::convertNVTETensorCheck(D[0]); + auto A_type = get_cuda_dtype(inputA->data.dtype); + auto B_type = get_cuda_dtype(inputB->data.dtype); + auto D_type = get_cuda_dtype(OutputB->data.dtype); + + return (A_type == B_type) && (A_type == D_type) && + ((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F)); + }; + + // CUTLASS Grouped GEMM fast path (SM90/TMA) + // Conditions: + // - No fused epilogue: both bias and pre_gelu_out are empty. + // - Supported dtypes only: FP16/BF16 (FP32 accumulate). + // - Uniform K across groups and K % 128 == 0. + // - use_split_accumulator is ignored for FP16/BF16. + // - grad is irrelevant when bias/pre_gelu_out are empty. + // + // Otherwise, fall back to cuBLAS. + if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && + all_groups_uniform_k128(B, transb)) { + cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate, + current_device, math_sm_count, stream); + } else { + NVTE_WARN("cuBLAS fallback."); + cublas_path(); + } +} diff --git a/transformer_engine/common/gemm/cutlass_groupgemm.cu b/transformer_engine/common/gemm/cutlass_groupgemm.cu new file mode 100644 index 0000000000..d8556311f0 --- /dev/null +++ b/transformer_engine/common/gemm/cutlass_groupgemm.cu @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + **************************************************************************************************/ + +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass_groupgemm.cuh" + +namespace grouped_gemm { + +cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { + using namespace transformer_engine; + switch (t) { + case DType::kFloat16: + return CUDA_R_16F; + case DType::kFloat32: + return CUDA_R_32F; + case DType::kBFloat16: + return CUDA_R_16BF; + case DType::kFloat8E4M3: + return CUDA_R_8F_E4M3; + case DType::kFloat8E5M2: + return CUDA_R_8F_E5M2; + default: + NVTE_ERROR("Invalid type"); + } +} + +// Explicit template instantiation to match the template declarations in the .cuh +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, const NVTETensor*, + NVTETensor*, NVTETensor*, float, + float, int, cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, const NVTETensor*, + NVTETensor*, NVTETensor*, float, + float, int, cudaStream_t, int, int); + +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); + +} // namespace grouped_gemm + +void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, int device, int math_sm_count, cudaStream_t stream) { + auto* inputA = transformer_engine::convertNVTETensorCheck(A[0]); + auto* inputB = transformer_engine::convertNVTETensorCheck(B[0]); + + auto A_type = grouped_gemm::get_cuda_dtype(inputA->data.dtype); + auto B_type = grouped_gemm::get_cuda_dtype(inputB->data.dtype); + + float one = 1.0; + float zero = 0.0; + float alpha = one; + float beta = (accumulate) ? one : zero; + + auto dispatch = [&](auto tag) { + using T = decltype(tag); + if (!transa && !transb) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else if (!transb && transa) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else if (transb && !transa) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else { + NVTE_ERROR("Layout 'TT' is not supported by cutlass_grouped_gemm."); + } + }; + + if (A_type == CUDA_R_16BF) { + dispatch(cutlass::bfloat16_t{}); + } else if (A_type == CUDA_R_16F) { + dispatch(cutlass::half_t{}); + } else { + NVTE_ERROR("Unsupported dtype: only BF16(FP16) are supported."); + } +} diff --git a/transformer_engine/common/gemm/cutlass_groupgemm.cuh b/transformer_engine/common/gemm/cutlass_groupgemm.cuh new file mode 100644 index 0000000000..233a28d666 --- /dev/null +++ b/transformer_engine/common/gemm/cutlass_groupgemm.cuh @@ -0,0 +1,346 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + **************************************************************************************************/ + +// +// Copyright (c) 2025 Shopee Inc. All Rights Reserved. +// + +/** + * @file: cutlass_groupgemm.cuh + * @author: min.yang@shopee.com, yangfan.bai@shopee.com, finch.li@shopee.com + * @date: 2025-08-08 16:20:00 + * @brief: cutlass group gemm kernel. + **/ + +#pragma once + +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "common/util/system.h" +#include "cute/tensor.hpp" +#include "cutlass/bfloat16.h" +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" + +namespace grouped_gemm { + +template +using GroupedGemmInputALayout = + std::conditional_t; + +template +using GroupedGemmInputBLayout = + std::conditional_t; + +using ProblemShapeType = cute::Shape; +using ProblemShape = cutlass::gemm::GroupProblemShape; // per group +template +struct GemmGivenSchedule { + using ElementA = typename ScheduleConfig::DataType; // Element type for A matrix operand + using ElementB = typename ScheduleConfig::DataType; // Element type for B matrix operand + using ElementC = typename ScheduleConfig::DataType; // Element type for C and D matrix operands + + // A matrix configuration + using LayoutA = typename ScheduleConfig::LayoutA; // Layout type for A matrix operand + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits< + ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using LayoutB = typename ScheduleConfig::LayoutB; // Layout type for B matrix operand + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using LayoutC = typename ScheduleConfig::LayoutC; // Layout type for C and D matrix operands + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes) + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = + cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + + using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size + using ClusterShape = + typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster + using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +struct ScheduleConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + // TODO(Alan): Add tuning for different scenarios to select the optimal configuration, + // as the current configuration may not be the best. + + // using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + // using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + // using TileShape = Shape; + // using ClusterShape = Shape; + + using LayoutA = GroupedGemmInputALayout; + using LayoutB = GroupedGemmInputBLayout; + using LayoutC = cutlass::layout::RowMajor; + using DataType = DataType_; +}; + +template +using GemmGrouped = typename GemmGivenSchedule>::Gemm; + +template +typename GemmT::Arguments MakeArguments(int num_experts, void* problem_sizes_host, + void* problem_sizes, const ElementA** ptr_A, + StrideA* stride_A, const ElementB** ptr_B, + StrideB* stride_B, ElementC** ptr_C, StrideC* stride_C, + float alpha, float beta, int device, int math_sm_count) { + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + + cutlass::KernelHardwareInfo kernel_hw_info = + cutlass::KernelHardwareInfo::make_kernel_hardware_info( + device, math_sm_count); + + typename GemmT::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + fusion_args.alpha = alpha; + fusion_args.beta = beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + + arguments = + typename GemmT::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, reinterpret_cast(problem_sizes), + reinterpret_cast(problem_sizes_host)}, + {ptr_A, stride_A, ptr_B, stride_B}, + { + fusion_args, + (beta > 0.0) ? (const ElementC**)ptr_C : nullptr, // NOLINT(*) + stride_C, + ptr_C, + stride_C, + }, + kernel_hw_info}; + + return arguments; +} + +template +inline __device__ __host__ T ROUND_UP(T m, T n) { + return (m + n - 1) / n * n; +} + +template +void debug_type() { + std::cout << typeid(T).name() << std::endl; +} + +int64_t inline getGemmCoordSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(ProblemShapeType), 128UL)); +} + +int64_t inline getPtrSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(half*), 128UL)); +} + +int64_t inline getLddSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(int64_t), 128UL)); +} + +// cpu workspace size is 4MB +static constexpr size_t kCPUWorkSpaceSize = 4 * 1024 * 1024; + +static char* getHostWorkspace() { + static std::once_flag flag; + static std::shared_ptr workspace; + + std::call_once(flag, [&]() { + workspace = + std::shared_ptr(reinterpret_cast(std::malloc(kCPUWorkSpaceSize)), [](char* p) { + if (p) std::free(p); + }); + + if (!workspace) { + throw std::bad_alloc(); + } + }); + + return workspace.get(); +} + +template +void CutlassGroupedGemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + NVTETensor* workspace, float alpha, float beta, int num_gemms, + cudaStream_t stream, int device, int math_sm_count) { + using Gemm = GemmGrouped; + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + + typename Gemm::Arguments arguments; + size_t kernel_workspace_size = Gemm::get_workspace_size(arguments); + auto gemm_coord_size = getGemmCoordSize(num_gemms); + auto ptr_size = getPtrSize(num_gemms); + auto ldd_size = getLddSize(num_gemms); + auto param_workspace_size = 3 * ptr_size + 3 * ldd_size + gemm_coord_size; + + NVTE_CHECK( + param_workspace_size < kCPUWorkSpaceSize, + "Insufficient kCPUWorkSpaceSize size: required=", static_cast(param_workspace_size), + ", available=", static_cast(kCPUWorkSpaceSize), " for CUTLASS grouped GEMM."); + + auto total_workspace_size = param_workspace_size + kernel_workspace_size; + transformer_engine::Tensor* wspace = transformer_engine::convertNVTETensor(workspace[0]); + + NVTE_CHECK(total_workspace_size < wspace->numel(), "Insufficient workspace[0] size: required=", + static_cast(total_workspace_size), + ", available=", static_cast(wspace->numel()), " for CUTLASS grouped GEMM."); + + char* workspace_ptr = reinterpret_cast(wspace->data.dptr); + + char* kernel_workspace_ptr = nullptr; + + char* host_workspace = getHostWorkspace(); + + ProblemShapeType* problem_sizes_host = reinterpret_cast(host_workspace); + + ElementA** ptr_A_host = reinterpret_cast(host_workspace + gemm_coord_size); + ElementB** ptr_B_host = reinterpret_cast(host_workspace + gemm_coord_size + ptr_size); + ElementC** ptr_C_host = + reinterpret_cast(host_workspace + gemm_coord_size + 2 * ptr_size); + int64_t* lda_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 0 * ldd_size); + int64_t* ldb_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 1 * ldd_size); + int64_t* ldc_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 2 * ldd_size); + + for (size_t i = 0; i < num_gemms; i++) { + const transformer_engine::Tensor* inputA = transformer_engine::convertNVTETensorCheck(A[i]); + const transformer_engine::Tensor* inputB = transformer_engine::convertNVTETensorCheck(B[i]); + transformer_engine::Tensor* outputD = transformer_engine::convertNVTETensor(D[i]); + + const int m = trans_a ? inputA->data.shape[1] : inputA->data.shape[0]; + const int k = trans_a ? inputA->data.shape[0] : inputA->data.shape[1]; + const int n = trans_b ? inputB->data.shape[0] : inputB->data.shape[1]; + + auto problem = ProblemShapeType(m, n, k); + problem_sizes_host[i] = problem; + + ptr_A_host[i] = reinterpret_cast(inputA->data.dptr); + ptr_B_host[i] = reinterpret_cast(inputB->data.dptr); + ptr_C_host[i] = reinterpret_cast(outputD->data.dptr); + + lda_host[i] = LayoutA::packed({m, k}).stride(0); + ldb_host[i] = LayoutB::packed({k, n}).stride(0); + ldc_host[i] = LayoutC::packed({m, n}).stride(0); + } + + cudaMemcpyAsync(workspace_ptr, host_workspace, param_workspace_size, cudaMemcpyHostToDevice, + stream); + + char* param_workspace_ptr = workspace_ptr; + ProblemShapeType* problem_sizes_device = reinterpret_cast(param_workspace_ptr); + const ElementA** ptr_A = reinterpret_cast( + reinterpret_cast(param_workspace_ptr) + gemm_coord_size); + const ElementB** ptr_B = reinterpret_cast( + reinterpret_cast(param_workspace_ptr) + gemm_coord_size + 1 * ptr_size); + ElementC** ptr_C = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 2 * ptr_size); + + StrideA* lda = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 0 * ldd_size); + StrideB* ldb = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 1 * ldd_size); + StrideC* ldc = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 2 * ldd_size); + + kernel_workspace_ptr = workspace_ptr + param_workspace_size; + + arguments = MakeArguments( + num_gemms, problem_sizes_host, problem_sizes_device, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + alpha, beta, device, math_sm_count); + + Gemm gemm; + + // Check can implement the kernel. + if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to implement CUTLASS Grouped GEMM"); + } + + // Initialize the kernel. + if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM"); + } + + // Execute the kernel in the current stream. + if (gemm.run(stream) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM"); + } +} + +} // namespace grouped_gemm + +void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, int device, int math_sm_count, cudaStream_t stream); diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 50b33909fb..0c358328b6 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -133,12 +133,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) * \param[in] stream CUDA stream to wait on. */ -void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, - const NVTETensor* bias, NVTETensor* pre_gelu_out, - const int num_gemms, bool transa, bool transb, bool grad, - NVTETensor* workspace, bool accumulate, - bool use_split_accumulator, int math_sm_count, - cudaStream_t stream); +void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, bool use_split_accumulator, int math_sm_count, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 032ac9eb70..330c9019d2 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -564,10 +564,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); } - nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, - lhs_is_trans, grad, workspace_list.data(), accumulate, - use_split_accumulator, num_math_sm, stream); + nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), + pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, + grad, workspace_list.data(), accumulate, use_split_accumulator, + num_math_sm, stream); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index f4768bb9ba..e95ce455d1 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -448,11 +448,10 @@ std::optional> te_general_grouped_gemm( // For now, we only have multi-stream cublas backend. NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), - te_bias_vector.data(), te_pre_gelu_out_vector.data(), - te_A_vector.size(), transa, transb, grad, - te_workspace_vector.data(), accumulate, use_split_accumulator, - math_sm_count, at::cuda::getCurrentCUDAStream()); + nvte_multi_tensor_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), + te_bias_vector.data(), te_pre_gelu_out_vector.data(), te_A_vector.size(), + transa, transb, grad, te_workspace_vector.data(), accumulate, + use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); }); return bias; }