@@ -155,7 +155,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
155155template <typename T, int Headdim, bool Is_causal>
156156void run_mha_fwd_splitkv_dispatch (Flash_fwd_params &params, cudaStream_t stream) {
157157 constexpr static int kBlockM = 64 ; // Fixed for all head dimensions
158- constexpr static int kBlockN = Headdim <= 64 ? 64 : (Headdim <= 128 ? 64 : 32 );
158+ constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64 );
159159 run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM , kBlockN , 4 , false , false , T>, Is_causal>(params, stream);
160160}
161161
@@ -164,11 +164,10 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
164164 constexpr static int Headdim = 32 ;
165165 int device;
166166 cudaGetDevice (&device);
167- int max_smem_per_sm, max_smem_per_block;
167+ int max_smem_per_block;
168168 cudaError status_ = cudaDeviceGetAttribute (
169- &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
170- status_ = cudaDeviceGetAttribute (
171- &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
169+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
170+ );
172171 if (status_ != cudaSuccess) {
173172 C10_CUDA_CHECK (status_);
174173 }
@@ -184,11 +183,10 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
184183 constexpr static int Headdim = 64 ;
185184 int device;
186185 cudaGetDevice (&device);
187- int max_smem_per_sm, max_smem_per_block;
186+ int max_smem_per_block;
188187 cudaError status_ = cudaDeviceGetAttribute (
189- &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
190- status_ = cudaDeviceGetAttribute (
191- &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
188+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
189+ );
192190 if (status_ != cudaSuccess) {
193191 C10_CUDA_CHECK (status_);
194192 }
@@ -204,11 +202,10 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
204202 constexpr static int Headdim = 96 ;
205203 int device;
206204 cudaGetDevice (&device);
207- int max_smem_per_sm, max_smem_per_block;
205+ int max_smem_per_block;
208206 cudaError status_ = cudaDeviceGetAttribute (
209- &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
210- status_ = cudaDeviceGetAttribute (
211- &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
207+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
208+ );
212209 if (status_ != cudaSuccess) {
213210 C10_CUDA_CHECK (status_);
214211 }
@@ -224,11 +221,10 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
224221 constexpr static int Headdim = 128 ;
225222 int device;
226223 cudaGetDevice (&device);
227- int max_smem_per_sm, max_smem_per_block;
224+ int max_smem_per_block;
228225 cudaError status_ = cudaDeviceGetAttribute (
229- &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
230- status_ = cudaDeviceGetAttribute (
231- &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
226+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
227+ );
232228 if (status_ != cudaSuccess) {
233229 C10_CUDA_CHECK (status_);
234230 }
@@ -252,18 +248,17 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
252248 constexpr static int Headdim = 256 ;
253249 int device;
254250 cudaGetDevice (&device);
255- int max_smem_per_sm, max_smem_per_block;
251+ int max_smem_per_block;
256252 cudaError status_ = cudaDeviceGetAttribute (
257- &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
258- status_ = cudaDeviceGetAttribute (
259- &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
253+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
254+ );
260255 if (status_ != cudaSuccess) {
261256 C10_CUDA_CHECK (status_);
262257 }
263258 if (max_smem_per_block >= 224 * 1024 ) {
264259 run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64 , 64 , 4 , false , false , T>, Is_causal>(params, stream);
265260 } else {
266- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64 , 32 , 4 , false , false , T>, Is_causal>(params, stream);
261+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64 , 64 , 4 , true , true , T>, Is_causal>(params, stream);
267262 }
268263}
269264
0 commit comments