Skip to content

Commit c3010d9

Browse files
author
Min Yang
committed
feat: add cutlass support type protect
Signed-off-by: Min Yang <[email protected]>
1 parent 22f5c47 commit c3010d9

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

transformer_engine/common/gemm/cublaslt_gemm.cu

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -784,8 +784,17 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
784784
return true;
785785
};
786786

787+
auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
788+
auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]);
789+
auto A_type = grouped_gemm::get_cuda_dtype(inputA->data.dtype);
790+
auto B_type = grouped_gemm::get_cuda_dtype(inputB->data.dtype);
791+
792+
NVTE_CHECK(A_type == B_type, "A/B dtype mismatch in cutlass_grouped_gemm.");
793+
bool supported_data_type_flag = (A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F);
794+
787795
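// The CUTLASS grouped GEMM path is only taken when bias and pre_gelu_out are empty, split
// accumulation is disabled, and the first A/B tensors share a BF16 or FP16 dtype; since bias is
// null in this case, the grad flag can be ignored.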
788-
if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && !use_split_accumulator) {
796+
if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && !use_split_accumulator &&
797+
supported_data_type_flag) {
789798
cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate,
790799
current_device, math_sm_count, stream);
791800
} else {

transformer_engine/common/gemm/cutlass_groupgemm.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor*
6565

6666
auto A_type = grouped_gemm::get_cuda_dtype(inputA->data.dtype);
6767
auto B_type = grouped_gemm::get_cuda_dtype(inputB->data.dtype);
68-
NVTE_CHECK(A_type == B_type, "A/B dtype mismatch in cutlass_grouped_gemm.");
6968

7069
float one = 1.0;
7170
float zero = 0.0;

0 commit comments

Comments
 (0)