@@ -229,9 +229,22 @@ void set_params_dgrad(
229229}
230230
231231void run_mha_fwd (Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false ) {
232+ int device;
233+ cudaGetDevice (&device);
234+ int max_smem_per_block;
235+ cudaError status_ = cudaDeviceGetAttribute (
236+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
237+ );
238+ if (status_ != cudaSuccess) {
239+ C10_CUDA_CHECK (status_);
240+ }
241+
232242 FP16_SWITCH (!params.is_bf16 , [&] {
233243 HEADDIM_SWITCH (params.d , [&] {
234244 BOOL_SWITCH (params.is_causal , Is_causal, [&] {
245+ // splitkv kernel is not supported for head_dim >= 128 in sm89 due to smem limits
246+ bool splitkv_forbidden = (kHeadDim >= 128 ) && (max_smem_per_block < 112 * 1024 );
247+ params.num_splits = splitkv_forbidden ? 1 : params.num_splits ;
235248 if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
236249 run_mha_fwd_<elem_type, kHeadDim , Is_causal>(params, stream);
237250 } else {
@@ -298,7 +311,7 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
298311) {
299312
300313 // This needs to match with run_mha_fwd_splitkv_dispatch
301- const int block_n = head_size <= 64 ? 64 : (head_size <= 128 ? 64 : 32 );
314+ const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64 );
302315 const int num_n_blocks = (max_seqlen_k + block_n - 1 ) / block_n;
303316 // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
304317 // In any case we don't expect seqlen_q to be larger than 64 for inference.
0 commit comments