diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
index 4ff79ad854..349e5e7d4c 100644
--- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
+++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
@@ -77,16 +77,7 @@ void DisPatchW4AFp8Gemm(
             max_tokens,
             stream)
     } else {
-        GEMM_SWITCH_FP16(
-            M, K, batch_size, token_padding_size, kBlockN, TailN,
-            weight,
-            input,
-            out,
-            weight_scale,
-            input_row_sum,
-            tokens,
-            max_tokens,
-            stream)
+        PD_THROW("Only supported dtype in ['BFLOAT16'].");
     }
 }
 
@@ -128,22 +119,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
                 input.stream());
             return {out};
         } else {
-            paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::FLOAT16, input.place());
-            phi::dtype::float16 *out_data = out.data<phi::dtype::float16>();
-            DisPatchW4AFp8Gemm(
-                reinterpret_cast(input.data()),
-                reinterpret_cast(weight.data()),
-                tokens.data(),
-                input_row_sum.data(),
-                weight_scale.data(),
-                reinterpret_cast(out_data),
-                token_padding_size,
-                max_tokens,
-                batch_size,
-                M,
-                K,
-                input.stream());
-            return {out};
+            PD_THROW("Only supported dtype in ['BFLOAT16'].");
         }
     } else {
         if (is_bflot16) {
@@ -164,23 +140,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
                 input.stream());
             return {out};
         } else {
-            paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::FLOAT16, input.place());
-            phi::dtype::float16 * out_data = out.data<phi::dtype::float16>();
-
-            DisPatchW4AFp8Gemm(
-                reinterpret_cast(input.data()),
-                reinterpret_cast(weight.data()),
-                tokens.data(),
-                input_row_sum.data(),
-                weight_scale.data(),
-                reinterpret_cast(out_data),
-                token_padding_size,
-                max_tokens,
-                batch_size,
-                M,
-                K,
-                input.stream());
-            return {out};
+            PD_THROW("Only supported dtype in ['BFLOAT16'].");
         }
     }
 }
diff --git a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
index 87b06fa747..d7e8ad6b6e 100644
--- a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
+++ b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
@@ -83,14 +83,9 @@
 }}
 """
 
-gemm_case = [
-    [8192, 3584, 8, 0],  # eb45T ffn1
-    [8192, 3584, 8, 2048],  # eb45T ffn1
-    [7168, 8192, 8, 0],  # eb45T ffn2
-    [7168, 8192, 8, 2048],  # eb45T ffn2
-]
-
-dtype = ["BF16", "FP16"]
+gemm_case = [[256, 256, 1, 0]]
+
+dtype = ["BF16"]
 
 
 def get_cutlass_type(type):
diff --git a/tests/operators/test_w4afp8_gemm.py b/tests/operators/test_w4afp8_gemm.py
index 1cd1bd6ea8..f6e38d4883 100644
--- a/tests/operators/test_w4afp8_gemm.py
+++ b/tests/operators/test_w4afp8_gemm.py
@@ -44,10 +44,10 @@ def peruate_scale(weight_scale):
 
 paddle.seed(0)
 
-tokens_per_group = 32
-N = 8192
-K = 3584
-BATCH = 8
+tokens_per_group = 256
+N = 256
+K = 256
+BATCH = 1
 TokenPadding = 0
 tokens = [tokens_per_group] * BATCH
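
Note on the shrunk generator config above: each `gemm_case` entry appears to encode [N, K, batch_count, token_padding_size], since the retired [8192, 3584, 8, 0] case lines up with the old constants N = 8192, K = 3584, BATCH = 8, TokenPadding = 0 in tests/operators/test_w4afp8_gemm.py. Below is a minimal sketch of that reading; the field names are an inference from that correspondence, not anything the generator itself documents.

    # Hypothetical field names: auto_gen_w4afp8_gemm_kernel.py indexes the
    # entries positionally, so this NamedTuple is illustrative only.
    from typing import NamedTuple

    class GemmCase(NamedTuple):
        n: int              # output dim (N in the test, M inside the C++ op)
        k: int              # reduction dim
        batch: int          # number of grouped GEMMs (BATCH in the test)
        token_padding: int  # 0 = contiguous tokens, >0 = padded per-batch layout

    # The single shape kept by this patch, small enough for a quick CI run.
    case = GemmCase(*[256, 256, 1, 0])
    print(f"N={case.n} K={case.k} batch={case.batch} padding={case.token_padding}")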