Commit e1da1b2

Fix dynamic block-sparse attention kernel performance
2 parents f8db33a + 02ec723 commit e1da1b2

File tree

3 files changed: +404, -512 lines

csrc/flash_api.cpp

Lines changed: 14 additions & 1 deletion
@@ -229,9 +229,22 @@ void set_params_dgrad(
 }
 
 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
+    int device;
+    cudaGetDevice(&device);
+    int max_smem_per_block;
+    cudaError status_ = cudaDeviceGetAttribute(
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
+    if (status_ != cudaSuccess) {
+        C10_CUDA_CHECK(status_);
+    }
+
     FP16_SWITCH(!params.is_bf16, [&] {
         HEADDIM_SWITCH(params.d, [&] {
             BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+                // splitkv kernel is not supported for head_dim >= 128 in sm89 due to smem limits
+                bool splitkv_forbidden = (kHeadDim >= 128) && (max_smem_per_block < 112 * 1024);
+                params.num_splits = splitkv_forbidden ? 1 : params.num_splits;
                 if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                     run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
                 } else {
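
The added preamble asks the CUDA runtime for the opt-in shared-memory limit per block and falls back to the non-split kernel when a head dimension of 128 or more would not fit. Below is a minimal standalone sketch of that same device-attribute check; it is not part of the commit, the helper name splitkv_allowed is made up for illustration, and the 112 KB threshold simply mirrors the diff.

// Sketch only (not from this commit): query the opt-in shared-memory limit and gate
// the split-KV path the same way the diff does. splitkv_allowed is a hypothetical helper.
#include <cstdio>
#include <cuda_runtime.h>

static bool splitkv_allowed(int head_dim) {
    int device = 0;
    cudaGetDevice(&device);
    int max_smem_per_block = 0;
    // Opt-in per-block shared memory; roughly 99 KB on sm89, more on sm80/sm90 parts.
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    // Same condition as the diff: head_dim >= 128 needs at least 112 KB.
    return !(head_dim >= 128 && max_smem_per_block < 112 * 1024);
}

int main() {
    std::printf("split-KV allowed for head_dim=128: %s\n",
                splitkv_allowed(128) ? "yes" : "no");
    return 0;
}

Built as a .cu file (e.g. with nvcc), this prints whether the split-KV path would be taken on the current device under the diff's threshold.
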
@@ -298,7 +311,7 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
 ) {
 
     // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = head_size <= 64 ? 64 : (head_size <= 128 ? 64 : 32);
+    const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64);
     const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
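
The updated block_n choice feeds directly into the ceiling division on the following line. The tiny sketch below is not from the commit; head_size and max_seqlen_k are arbitrary example values, and it only evaluates the two expressions shown in the hunk.

// Sketch only: evaluate the updated block_n ternary and the ceiling division that
// yields num_n_blocks. head_size and max_seqlen_k are arbitrary example values.
#include <cstdio>

int main() {
    const int head_size = 128;
    const int max_seqlen_k = 4096;
    // Mirrors the '+' line of the diff.
    const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64);
    // Round up so a partial trailing block is still counted.
    const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
    std::printf("block_n=%d, num_n_blocks=%d\n", block_n, num_n_blocks);  // 128, 32
    return 0;
}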
