Skip to content

Commit 02ec723

Browse files
committed
Optimizes block sizes and removes unused variables
Adjusts block size parameters for better performance across different head dimensions and removes unused shared memory per multiprocessor variables. Updates block size logic to use larger blocks for smaller head dimensions and enables kernel optimizations for the 256 head dimension case when shared memory is limited.
1 parent 7362c09 commit 02ec723

File tree

1 file changed

+17
-22
lines changed

1 file changed

+17
-22
lines changed

csrc/src/flash_fwd_launch_template.h

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
155155
template<typename T, int Headdim, bool Is_causal>
156156
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
157157
constexpr static int kBlockM = 64; // Fixed for all head dimensions
158-
constexpr static int kBlockN = Headdim <= 64 ? 64 : (Headdim <= 128 ? 64 : 32);
158+
constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64);
159159
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
160160
}
161161

@@ -164,11 +164,10 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
164164
constexpr static int Headdim = 32;
165165
int device;
166166
cudaGetDevice(&device);
167-
int max_smem_per_sm, max_smem_per_block;
167+
int max_smem_per_block;
168168
cudaError status_ = cudaDeviceGetAttribute(
169-
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
170-
status_ = cudaDeviceGetAttribute(
171-
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
169+
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
170+
);
172171
if (status_ != cudaSuccess) {
173172
C10_CUDA_CHECK(status_);
174173
}
@@ -184,11 +183,10 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
184183
constexpr static int Headdim = 64;
185184
int device;
186185
cudaGetDevice(&device);
187-
int max_smem_per_sm, max_smem_per_block;
186+
int max_smem_per_block;
188187
cudaError status_ = cudaDeviceGetAttribute(
189-
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
190-
status_ = cudaDeviceGetAttribute(
191-
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
188+
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
189+
);
192190
if (status_ != cudaSuccess) {
193191
C10_CUDA_CHECK(status_);
194192
}
@@ -204,11 +202,10 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
204202
constexpr static int Headdim = 96;
205203
int device;
206204
cudaGetDevice(&device);
207-
int max_smem_per_sm, max_smem_per_block;
205+
int max_smem_per_block;
208206
cudaError status_ = cudaDeviceGetAttribute(
209-
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
210-
status_ = cudaDeviceGetAttribute(
211-
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
207+
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
208+
);
212209
if (status_ != cudaSuccess) {
213210
C10_CUDA_CHECK(status_);
214211
}
@@ -224,11 +221,10 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
224221
constexpr static int Headdim = 128;
225222
int device;
226223
cudaGetDevice(&device);
227-
int max_smem_per_sm, max_smem_per_block;
224+
int max_smem_per_block;
228225
cudaError status_ = cudaDeviceGetAttribute(
229-
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
230-
status_ = cudaDeviceGetAttribute(
231-
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
226+
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
227+
);
232228
if (status_ != cudaSuccess) {
233229
C10_CUDA_CHECK(status_);
234230
}
@@ -252,18 +248,17 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
252248
constexpr static int Headdim = 256;
253249
int device;
254250
cudaGetDevice(&device);
255-
int max_smem_per_sm, max_smem_per_block;
251+
int max_smem_per_block;
256252
cudaError status_ = cudaDeviceGetAttribute(
257-
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
258-
status_ = cudaDeviceGetAttribute(
259-
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
253+
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
254+
);
260255
if (status_ != cudaSuccess) {
261256
C10_CUDA_CHECK(status_);
262257
}
263258
if (max_smem_per_block >= 224 * 1024) {
264259
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
265260
} else {
266-
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
261+
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, true, true, T>, Is_causal>(params, stream);
267262
}
268263
}
269264

0 commit comments

Comments
 (0)