
Commit e5aa708

[bug fix] Fix slow w4a8 compilation (#3510)

* Fix w4a8 compilation
* code style
* Fix TMA copy
1 parent a5692e8 commit e5aa708

File tree

3 files changed: +10 −55 lines changed


custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu

Lines changed: 3 additions & 43 deletions
@@ -77,16 +77,7 @@ void DisPatchW4AFp8Gemm(
         max_tokens,
         stream)
   } else {
-    GEMM_SWITCH_FP16(
-        M, K, batch_size, token_padding_size, kBlockN, TailN,
-        weight,
-        input,
-        out,
-        weight_scale,
-        input_row_sum,
-        tokens,
-        max_tokens,
-        stream)
+    PD_THROW("Only supported dtype in ['BFLOAT16'].");
   }
 }

@@ -128,22 +119,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
             input.stream());
         return {out};
     } else {
-        paddle::Tensor out = paddle::empty({all_tokens, M}, paddle::DataType::FLOAT16, input.place());
-        phi::dtype::float16 *out_data = out.data<phi::dtype::float16>();
-        DisPatchW4AFp8Gemm(
-            reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
-            reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
-            tokens.data<int>(),
-            input_row_sum.data<float>(),
-            weight_scale.data<float>(),
-            reinterpret_cast<cutlass::half_t*>(out_data),
-            token_padding_size,
-            max_tokens,
-            batch_size,
-            M,
-            K,
-            input.stream());
-        return {out};
+        PD_THROW("Only supported dtype in ['BFLOAT16'].");
     }
   } else {
     if (is_bflot16) {

@@ -164,23 +140,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
             input.stream());
         return {out};
     } else {
-        paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::FLOAT16, input.place());
-        phi::dtype::float16 * out_data = out.data<phi::dtype::float16>();
-
-        DisPatchW4AFp8Gemm(
-            reinterpret_cast<const cutlass::float_e4m3_t*>(input.data<phi::dtype::float8_e4m3fn>()),
-            reinterpret_cast<const cutlass::float_e4m3_t*>(weight.data<uint8_t>()),
-            tokens.data<int>(),
-            input_row_sum.data<float>(),
-            weight_scale.data<float>(),
-            reinterpret_cast<cutlass::half_t*>(out_data),
-            token_padding_size,
-            max_tokens,
-            batch_size,
-            M,
-            K,
-            input.stream());
-        return {out};
+        PD_THROW("Only supported dtype in ['BFLOAT16'].");
     }
   }
 }
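Net effect for callers: the FP16 branches no longer dispatch GEMM_SWITCH_FP16 or allocate an FP16 output, so requesting a float16 result now fails fast at runtime, and the compiler is spared all of the FP16 template instantiations. A minimal Python model of the new control flow (a sketch only; the real dispatch is the C++ above, and the surviving BF16 macro is not shown in this diff):

def dispatch_w4afp8_gemm(out_dtype: str) -> str:
    """Mirror of the post-commit dispatch: one kernel path, one error path."""
    if out_dtype == "bfloat16":
        return "bf16 CUTLASS kernel path"  # the branch this commit keeps
    # corresponds to: PD_THROW("Only supported dtype in ['BFLOAT16'].");
    raise TypeError("Only supported dtype in ['BFLOAT16'].")

print(dispatch_w4afp8_gemm("bfloat16"))
try:
    dispatch_w4afp8_gemm("float16")
except TypeError as err:
    print(err)  # Only supported dtype in ['BFLOAT16'].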

custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py

Lines changed: 3 additions & 8 deletions
@@ -83,14 +83,9 @@
 }}
 """

-gemm_case = [
-    [8192, 3584, 8, 0],  # eb45T ffn1
-    [8192, 3584, 8, 2048],  # eb45T ffn1
-    [7168, 8192, 8, 0],  # eb45T ffn2
-    [7168, 8192, 8, 2048],  # eb45T ffn2
-]
-
-dtype = ["BF16", "FP16"]
+gemm_case = [[256, 256, 1, 0]]
+
+dtype = ["BF16"]


 def get_cutlass_type(type):
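Shrinking these two lists is what fixes the compile time: the generator expands gemm_case against dtype into CUTLASS kernel sources, so the build drops from eight case/dtype combinations (four eb45T ffn1/ffn2 shapes times BF16/FP16) to a single BF16 one. A sketch of that expansion, assuming the usual render-per-combination structure (the real script's template rendering and file naming are not shown in this diff):

gemm_case = [[256, 256, 1, 0]]  # was the four eb45T ffn1/ffn2 shapes
dtype = ["BF16"]                # was ["BF16", "FP16"]

for n, k, batch, token_padding in gemm_case:
    for t in dtype:
        # one batch of generated CUTLASS instantiations per iteration,
        # so compile cost scales with len(gemm_case) * len(dtype)
        print(f"generate w4afp8 kernel: N={n} K={k} batch={batch} "
              f"pad={token_padding} dtype={t}")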

tests/operators/test_w4afp8_gemm.py

Lines changed: 4 additions & 4 deletions
@@ -44,10 +44,10 @@ def peruate_scale(weight_scale):


 paddle.seed(0)
-tokens_per_group = 32
-N = 8192
-K = 3584
-BATCH = 8
+tokens_per_group = 256
+N = 256
+K = 256
+BATCH = 1
 TokenPadding = 0

 tokens = [tokens_per_group] * BATCH
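The test constants track the single remaining generated case; the entry ordering is evidently [N, K, BATCH, TokenPadding], since the old test values (N=8192, K=3584, BATCH=8, TokenPadding=0) matched the old entry [8192, 3584, 8, 0] the same way. A quick consistency sketch (the ordering is inferred, not stated in the diff):

generated = [[256, 256, 1, 0]]  # gemm_case after this commit
requested = [256, 256, 1, 0]    # N, K, BATCH, TokenPadding from the updated test
assert requested in generated, "test would target a kernel that was never generated"
print("test configuration matches a generated kernel")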
