@@ -77,16 +77,7 @@ void DisPatchW4AFp8Gemm(
7777 max_tokens,
7878 stream)
7979 } else {
80- GEMM_SWITCH_FP16 (
81- M, K, batch_size, token_padding_size, kBlockN , TailN,
82- weight,
83- input,
84- out,
85- weight_scale,
86- input_row_sum,
87- tokens,
88- max_tokens,
89- stream)
80+ PD_THROW (" Only supported dtype in ['BFLOAT16']." );
9081 }
9182}
9283
@@ -128,22 +119,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
128119 input.stream ());
129120 return {out};
130121 } else {
131- paddle::Tensor out = paddle::empty ({all_tokens, M}, paddle::DataType::FLOAT16, input.place ());
132- phi::dtype::float16 *out_data = out.data <phi::dtype::float16>();
133- DisPatchW4AFp8Gemm (
134- reinterpret_cast <const cutlass::float_e4m3_t *>(input.data <phi::dtype::float8_e4m3fn>()),
135- reinterpret_cast <const cutlass::float_e4m3_t *>(weight.data <uint8_t >()),
136- tokens.data <int >(),
137- input_row_sum.data <float >(),
138- weight_scale.data <float >(),
139- reinterpret_cast <cutlass::half_t *>(out_data),
140- token_padding_size,
141- max_tokens,
142- batch_size,
143- M,
144- K,
145- input.stream ());
146- return {out};
122+ PD_THROW (" Only supported dtype in ['BFLOAT16']." );
147123 }
148124 } else {
149125 if (is_bflot16) {
@@ -164,23 +140,7 @@ std::vector<paddle::Tensor> W4AFp8Gemm(
164140 input.stream ());
165141 return {out};
166142 } else {
167- paddle::Tensor out = paddle::empty ({batch_size, token_padding_size, M}, paddle::DataType::FLOAT16, input.place ());
168- phi::dtype::float16 * out_data = out.data <phi::dtype::float16>();
169-
170- DisPatchW4AFp8Gemm (
171- reinterpret_cast <const cutlass::float_e4m3_t *>(input.data <phi::dtype::float8_e4m3fn>()),
172- reinterpret_cast <const cutlass::float_e4m3_t *>(weight.data <uint8_t >()),
173- tokens.data <int >(),
174- input_row_sum.data <float >(),
175- weight_scale.data <float >(),
176- reinterpret_cast <cutlass::half_t *>(out_data),
177- token_padding_size,
178- max_tokens,
179- batch_size,
180- M,
181- K,
182- input.stream ());
183- return {out};
143+ PD_THROW (" Only supported dtype in ['BFLOAT16']." );
184144 }
185145 }
186146}
0 commit comments