From f2fc15c80394aa61231f0bacc15ee9b9173dae06 Mon Sep 17 00:00:00 2001
From: Lequn Chen
Date: Wed, 22 Nov 2023 21:34:54 +0000
Subject: [PATCH] update flashinfer to 25f7c03dcf577e0824382c47fed9d6d308dbbd69

---
 csrc/flashinfer_adapter/flashinfer_all.cu |  26 ++--
 csrc/sgmv_flashinfer/sgmv_flashinfer.cuh  | 173 +++++++++++++---------
 setup.py                                  |   4 +
 third_party/flashinfer                    |   2 +-
 4 files changed, 121 insertions(+), 84 deletions(-)

diff --git a/csrc/flashinfer_adapter/flashinfer_all.cu b/csrc/flashinfer_adapter/flashinfer_all.cu
index 5cbcf72..e127268 100644
--- a/csrc/flashinfer_adapter/flashinfer_all.cu
+++ b/csrc/flashinfer_adapter/flashinfer_all.cu
@@ -6,6 +6,10 @@
 #include "flashinfer/page.cuh"
 #include "flashinfer_config.h"
 
+using flashinfer::paged_kv_t;
+using flashinfer::PageStorage;
+using flashinfer::RotaryMode;
+
 template
 void FlashInferBatchDecodeKernel(T* o, T* q, T* kv_data, int32_t* kv_indptr,
                                  int32_t* kv_indicies,
@@ -13,11 +17,11 @@ void FlashInferBatchDecodeKernel(T* o, T* q, T* kv_data, int32_t* kv_indptr,
                                  int num_layers, int layer_idx,
                                  int num_qo_heads, int num_kv_heads,
                                  int page_size, int batch_size) {
-  flashinfer::paged_kv_t paged_kv(
+  paged_kv_t paged_kv(
       num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size,
-      kv_data, kv_indptr, kv_indicies, last_page_offset);
+      kv_data, kv_indicies, kv_indptr, last_page_offset);
   flashinfer::BatchDecodeWithPagedKVCache(q, paged_kv, o, nullptr, num_qo_heads,
-                                          flashinfer::RotaryMode::kLlama);
+                                          RotaryMode::kLlama);
 }
 
 template
@@ -26,9 +30,9 @@ void FlashInferInitKvKernel(T* kv_data, int32_t* kv_indptr,
                             T* key, T* value, int32_t* seqlen_indptr,
                             int num_layers, int layer_idx, int num_kv_heads,
                             int page_size, int batch_size) {
-  flashinfer::paged_kv_t paged_kv(
+  paged_kv_t paged_kv(
       num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size,
-      kv_data, kv_indptr, kv_indicies, last_page_offset);
+      kv_data, kv_indicies, kv_indptr, last_page_offset);
 
   constexpr size_t vec_size = std::max(16 / sizeof(T), static_cast(head_dim / 32));
@@ -36,8 +40,8 @@ void FlashInferInitKvKernel(T* kv_data, int32_t* kv_indptr,
   constexpr size_t bdy = 128 / bdx;
   dim3 nblks(paged_kv.batch_size * paged_kv.num_heads / bdy);
   dim3 nthrs(bdx, bdy);
-  flashinfer::AppendPagedKVCachePrefillKernel
+  flashinfer::AppendPagedKVCachePrefillKernel
       <<>>(paged_kv, key, value, seqlen_indptr);
 }
 
 template
@@ -46,9 +50,9 @@ void FlashInferAppendKvKernel(T* kv_data, int32_t* kv_indptr,
                               int32_t* kv_indicies, int32_t* last_page_offset,
                               T* key, T* value, int num_layers, int layer_idx,
                               int num_kv_heads, int page_size, int batch_size) {
-  flashinfer::paged_kv_t paged_kv(
+  paged_kv_t paged_kv(
      num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size,
-      kv_data, kv_indptr, kv_indicies, last_page_offset);
+      kv_data, kv_indicies, kv_indptr, last_page_offset);
 
   constexpr size_t vec_size = std::max(16 / sizeof(T), static_cast(head_dim / 32));
@@ -56,8 +60,8 @@ void FlashInferAppendKvKernel(T* kv_data, int32_t* kv_indptr,
   constexpr size_t bdy = 128 / bdx;
   dim3 nblks(paged_kv.batch_size * paged_kv.num_heads / bdy);
   dim3 nthrs(bdx, bdy);
-  flashinfer::AppendPagedKVCacheDecodeKernel
+  flashinfer::AppendPagedKVCacheDecodeKernel
       <<>>(paged_kv, key, value);
 }

diff --git a/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh b/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh
index 25c93a7..45afffa 100644
--- a/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh
+++ b/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh
@@ -10,11 +10,14 @@ namespace flashinfer {
 
 namespace sgmv {
 
-template
-__global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t num_problems,
-                            uint32_t d_in, uint32_t layer_idx, uint32_t chunk_size) {
+template
+__global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp,
+                            uint32_t num_problems, uint32_t d_in,
+                            uint32_t layer_idx, uint32_t chunk_size) {
   auto block = cooperative_groups::this_thread_block();
   auto grid = cooperative_groups::this_grid();
+  constexpr auto fill_mode = cp_async::SharedMemFillMode::kFillZero;
   const uint32_t problem_id = blockIdx.y;
   const uint32_t bx = blockIdx.x;
   const uint32_t s_start = s[problem_id], s_end = s[problem_id + 1];
@@ -24,57 +27,64 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t n
   constexpr uint32_t num_blocks_n = d_out / 16;
   const uint32_t num_chunks = gridDim.x;
   const uint32_t chunk_start = chunk_size * bx;
-  const uint32_t num_iterations = (chunk_size + (num_k_frags * 16 - 1)) / (num_k_frags * 16);
-  constexpr uint32_t num_cells_n = (d_out < 32 ? 32 : d_out) / cell_capacity();
+  const uint32_t num_iterations =
+      (chunk_size + (num_k_frags * 16 - 1)) / (num_k_frags * 16);
+  constexpr uint32_t num_cells_n =
+      (d_out < 32 ? 32 : d_out) / cell_capacity();
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   extern __shared__ uint8_t smem[];
   smem_t x_smem[2]{smem, smem + sizeof(T) * num_warps * 16 * 16 * num_k_frags};
   smem_t w_smem[2]{smem + sizeof(T) * 2 * num_warps * 16 * 16 * num_k_frags,
-                   smem + sizeof(T) * 16 * 16 * num_k_frags * (2 * num_warps + num_blocks_n)};
+                   smem + sizeof(T) * 16 * 16 * num_k_frags *
+                       (2 * num_warps + num_blocks_n)};
   smem_t y_smem(smem);
   uint32_t x_frag[num_k_frags][4];
   uint32_t w_frag[num_k_frags][num_blocks_n][4];
   float y_frag[num_blocks_n][8];
-  for (uint32_t i = 0; i < (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16); ++i) {
+  for (uint32_t i = 0;
+       i < (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16); ++i) {
     // init y_frag
     if (bx == 0) {
      if constexpr (num_blocks_n == 1) {
        uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 2;
        T* y_ptr = y + row_idx * d_out + (tx % 2) * cell_capacity();
-        y_smem.offset = smem_t::get_permuted_offset(ty * 16 + tx / 2, tx % 2);
-        y_smem.load_128b_async(y_ptr, row_idx < s_end);
+        auto offset =
+            smem_t::get_permuted_offset(ty * 16 + tx / 2, tx % 2);
+        y_smem.load_128b_async(offset, y_ptr, row_idx < s_end);
      } else {
        uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4;
        T* y_ptr = y + row_idx * d_out + (tx % 4) * cell_capacity();
-        y_smem.offset = smem_t::get_permuted_offset(ty * 16 + tx / 4, tx % 4);
+        auto offset =
+            smem_t::get_permuted_offset(ty * 16 + tx / 4, tx % 4);
        #pragma unroll
        for (uint32_t j = 0; j < 2; ++j) {
          #pragma unroll
          for (uint32_t fno = 0; fno < num_blocks_n / 2; ++fno) {
-            y_smem.load_128b_async(y_ptr, row_idx < s_end);
+            y_smem.load_128b_async(offset, y_ptr, row_idx < s_end);
            y_ptr += 4 * cell_capacity();
-            y_smem.offset += 8;
+            offset += 8;
          }
          row_idx += 8;
          y_ptr += 8 * d_out;
-          y_smem.offset += 8 * num_cells_n - 4 * num_blocks_n;
+          offset += 8 * num_cells_n - 4 * num_blocks_n;
        }
      }
      cp_async::commit_group();
      cp_async::wait_group<0>();
      block.sync();
-      y_smem.offset = smem_t::get_permuted_offset(ty * 16 + tx % 16, tx / 16);
+      auto offset =
+          smem_t::get_permuted_offset(ty * 16 + tx % 16, tx / 16);
      #pragma unroll
      for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
        uint32_t tmp[4];
-        y_smem.ldmatrix_m8n8x4(tmp);
+        y_smem.ldmatrix_m8n8x4(offset, tmp);
        vec_cast(y_frag[fn], (T*)tmp);
-        y_smem.offset = (y_smem.offset ^ 0x2) + (fn & 0x1) * 8;
+        offset = (offset ^ 0x2) + (fn & 0x1) * 8;
      }
    } else {
      #pragma unroll
@@ -90,23 +100,25 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t n
    #pragma unroll
    for (uint32_t iter = 0; iter < num_stages; ++iter) {
      uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4;
-      T* x_ptr =
-          x + row_idx * d_in + chunk_start + (2 * num_k_frags * iter + tx % 4) * cell_capacity();
+      T* x_ptr = x + row_idx * d_in + chunk_start +
+                 (2 * num_k_frags * iter + tx % 4) * cell_capacity();
      T* x_ptr_max = x + row_idx * d_in + min(d_in, chunk_start + chunk_size);
-      x_smem[iter].offset = smem_t::get_permuted_offset(ty * 16 + tx / 4, tx % 4);
+      auto offset =
+          smem_t::get_permuted_offset(ty * 16 + tx / 4, tx % 4);
      // pre-load x_smem, w_smem
      #pragma unroll
      for (uint32_t j = 0; j < 2; ++j) {
        #pragma unroll
        for (uint32_t fko = 0; fko < num_k_frags / 2; ++fko) {
-          x_smem[iter].load_128b_async(x_ptr, row_idx < s_end && x_ptr < x_ptr_max);
+          x_smem[iter].load_128b_async(
+              offset, x_ptr, row_idx < s_end && x_ptr < x_ptr_max);
          x_ptr += 4 * cell_capacity();
-          x_smem[iter].offset += 8;
+          offset += 8;
        }
        row_idx += 8;
        x_ptr += 8 * d_in - 2 * cell_capacity() * num_k_frags;
        x_ptr_max += 8 * d_in;
-        x_smem[iter].offset += 8 * num_cells_k - 4 * num_k_frags;
+        offset += 8 * num_cells_k - 4 * num_k_frags;
      }
      row_idx -= 8;
@@ -114,26 +126,29 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t n
      constexpr uint32_t num_fko_iters_per_warp = num_k_frags / (num_warps * 2);
      #pragma unroll
      for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
-        T* w_ptr = w[problem_id] + layer_idx * d_in * d_out + (fn * 16 + tx / 4) * d_in +
-                   chunk_start +
-                   (2 * num_k_frags * iter + ty * num_fko_iters_per_warp * 4 + tx % 4) *
+        T* w_ptr = w[problem_id] + layer_idx * d_in * d_out +
+                   (fn * 16 + tx / 4) * d_in + chunk_start +
+                   (2 * num_k_frags * iter + ty * num_fko_iters_per_warp * 4 +
+                    tx % 4) *
                       cell_capacity();
-        T* w_ptr_max = w[problem_id] + layer_idx * d_in * d_out +
-                       min((fn * 16 + tx / 4 + 1) * d_in,
-                           (fn * 16 + tx / 4) * d_in + chunk_start + chunk_size);
-        w_smem[iter].offset = smem_t::get_permuted_offset(
+        T* w_ptr_max =
+            w[problem_id] + layer_idx * d_in * d_out +
+            min((fn * 16 + tx / 4 + 1) * d_in,
+                (fn * 16 + tx / 4) * d_in + chunk_start + chunk_size);
+        auto offset = smem_t::get_permuted_offset(
            fn * 16 + tx / 4, ty * num_fko_iters_per_warp * 4 + tx % 4);
        #pragma unroll
        for (uint32_t j = 0; j < 2; ++j) {
          #pragma unroll
          for (uint32_t fko = 0; fko < num_fko_iters_per_warp; ++fko) {
-            w_smem[iter].load_128b_async(w_ptr, w_ptr < w_ptr_max);
+            w_smem[iter].load_128b_async(offset, w_ptr,
+                                         w_ptr < w_ptr_max);
            w_ptr += 4 * cell_capacity();
-            w_smem[iter].offset += 8;
+            offset += 8;
          }
          w_ptr += 8 * d_in - 4 * cell_capacity() * num_fko_iters_per_warp;
          w_ptr_max += 8 * d_in;
-          w_smem[iter].offset += 8 * num_cells_k - 8 * num_fko_iters_per_warp;
+          offset += 8 * num_cells_k - 8 * num_fko_iters_per_warp;
        }
      }
      cp_async::commit_group();
@@ -145,24 +160,24 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t n
      cp_async::wait_group<1>();
      block.sync();
-      x_smem[stage_idx].offset =
+      auto offset =
          smem_t::get_permuted_offset(ty * 16 + tx % 16, tx / 16);
      #pragma unroll
      for (uint32_t fk = 0; fk < num_k_frags; ++fk) {
-        x_smem[stage_idx].ldmatrix_m8n8x4(x_frag[fk]);
-        x_smem[stage_idx].offset = (x_smem[stage_idx].offset ^ 0x2) + (fk & 0x1) * 8;
+        x_smem[stage_idx].ldmatrix_m8n8x4(offset, x_frag[fk]);
+        offset = (offset ^ 0x2) + (fk & 0x1) * 8;
      }
      #pragma unroll
      for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
-        w_smem[stage_idx].offset = smem_t::get_permuted_offset(
+        auto offset = smem_t::get_permuted_offset(
            fn * 16 + 8 * (tx / 16) + tx % 8, (tx % 16) / 8);
        #pragma unroll
        for (uint32_t fk = 0; fk < num_k_frags; ++fk) {
-          w_smem[stage_idx].ldmatrix_m8n8x4(w_frag[fk][fn]);
-          w_smem[stage_idx].offset = (w_smem[stage_idx].offset ^ 0x2) + (fk & 0x1) * 8;
+          w_smem[stage_idx].ldmatrix_m8n8x4(offset, w_frag[fk][fn]);
+          offset = (offset ^ 0x2) + (fk & 0x1) * 8;
        }
-        w_smem[stage_idx].offset += 16 * num_cells_k - 4 * num_k_frags;
+        offset += 16 * num_cells_k - 4 * num_k_frags;
      }
 
      // compute y_frag
@@ -170,7 +185,8 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t n
      for (uint32_t fk = 0; fk < num_k_frags; ++fk) {
        #pragma unroll
        for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
-          mma::mma_sync_m16n16k16_row_col_f16f16f32(y_frag[fn], x_frag[fk], w_frag[fk][fn]);
+          mma::mma_sync_m16n16k16_row_col_f16f16f32(y_frag[fn], x_frag[fk],
+                                                     w_frag[fk][fn]);
        }
      }
      block.sync();
@@ -179,49 +195,55 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t n
      if (iter + num_stages < num_iterations) {
        uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4;
        T* x_ptr = x + row_idx * d_in + chunk_start +
-                   (2 * num_k_frags * (iter + num_stages) + tx % 4) * cell_capacity();
+                   (2 * num_k_frags * (iter + num_stages) + tx % 4) *
+                       cell_capacity();
        T* x_ptr_max = x + row_idx * d_in + min(d_in, chunk_start + chunk_size);
-        x_smem[stage_idx].offset =
+        auto offset =
            smem_t::get_permuted_offset(ty * 16 + tx / 4, tx % 4);
        // pre-load x_smem, w_smem
        #pragma unroll
        for (uint32_t j = 0; j < 2; ++j) {
          #pragma unroll
          for (uint32_t fko = 0; fko < num_k_frags / 2; ++fko) {
-            x_smem[stage_idx].load_128b_async(x_ptr, row_idx < s_end && x_ptr < x_ptr_max);
+            x_smem[stage_idx].load_128b_async(
+                offset, x_ptr, row_idx < s_end && x_ptr < x_ptr_max);
            x_ptr += 4 * cell_capacity();
-            x_smem[stage_idx].offset += 8;
+            offset += 8;
          }
          row_idx += 8;
          x_ptr += 8 * d_in - 2 * cell_capacity() * num_k_frags;
          x_ptr_max += 8 * d_in;
-          x_smem[stage_idx].offset += 8 * num_cells_k - 4 * num_k_frags;
+          offset += 8 * num_cells_k - 4 * num_k_frags;
        }
        row_idx -= 8;
-        constexpr uint32_t num_fko_iters_per_warp = num_k_frags / (num_warps * 2);
+        constexpr uint32_t num_fko_iters_per_warp =
+            num_k_frags / (num_warps * 2);
        #pragma unroll
        for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
-          T* w_ptr =
-              w[problem_id] + layer_idx * d_in * d_out + (fn * 16 + tx / 4) * d_in + chunk_start +
-              (2 * num_k_frags * (iter + num_stages) + ty * num_fko_iters_per_warp * 4 + tx % 4) *
-                  cell_capacity();
-          T* w_ptr_max = w[problem_id] + layer_idx * d_in * d_out +
-                         min((fn * 16 + tx / 4 + 1) * d_in,
-                             (fn * 16 + tx / 4) * d_in + chunk_start + chunk_size);
-          w_smem[stage_idx].offset = smem_t::get_permuted_offset(
+          T* w_ptr = w[problem_id] + layer_idx * d_in * d_out +
+                     (fn * 16 + tx / 4) * d_in + chunk_start +
+                     (2 * num_k_frags * (iter + num_stages) +
+                      ty * num_fko_iters_per_warp * 4 + tx % 4) *
+                         cell_capacity();
+          T* w_ptr_max =
+              w[problem_id] + layer_idx * d_in * d_out +
+              min((fn * 16 + tx / 4 + 1) * d_in,
+                  (fn * 16 + tx / 4) * d_in + chunk_start + chunk_size);
+          auto offset = smem_t::get_permuted_offset(
+              fn * 16 + tx / 4, ty * num_fko_iters_per_warp * 4 + tx % 4);
          #pragma unroll
          for (uint32_t j = 0; j < 2; ++j) {
            #pragma unroll
            for (uint32_t fko = 0; fko < num_fko_iters_per_warp; ++fko) {
-              w_smem[stage_idx].load_128b_async(w_ptr, w_ptr < w_ptr_max);
+              w_smem[stage_idx].load_128b_async(offset, w_ptr,
+                                                w_ptr < w_ptr_max);
              w_ptr += 4 * cell_capacity();
-              w_smem[stage_idx].offset += 8;
+              offset += 8;
            }
            w_ptr += 8 * d_in - 4 * cell_capacity() * num_fko_iters_per_warp;
            w_ptr_max += 8 * d_in;
-            w_smem[stage_idx].offset += 8 * num_cells_k - 8 * num_fko_iters_per_warp;
+            offset += 8 * num_cells_k - 8 * num_fko_iters_per_warp;
          }
        }
      }
@@ -234,7 +256,8 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t n
    #pragma unroll
    for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
      vec_t::memcpy(
-          tmp + (fn * grid.size() + (problem_id * num_chunks + bx) * block.num_threads() +
+          tmp + (fn * grid.size() +
+                 (problem_id * num_chunks + bx) * block.num_threads() +
                 block.thread_rank()) *
                    8,
          y_frag[fn]);
@@ -250,7 +273,8 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t n
      for (uint32_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
        vec_t y_other;
        y_other.load(tmp + (fn * grid.size() +
-                            (problem_id * num_chunks + chunk_idx) * block.num_threads() +
+                            (problem_id * num_chunks + chunk_idx) *
+                                block.num_threads() +
                            block.thread_rank()) *
                               8);
        #pragma unroll
@@ -263,45 +287,50 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t n
    if (bx == 0) {
      // store y_frag
-      y_smem.offset = smem_t::get_permuted_offset(ty * 16 + tx / 4, 0);
+      auto offset =
+          smem_t::get_permuted_offset(ty * 16 + tx / 4, 0);
      #pragma unroll
      for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
-        vec_cast((T*)(y_smem.base + y_smem.offset) + (tx % 4) * 2, &y_frag[fn][0]);
-        vec_cast((T*)(y_smem.base + y_smem.offset + 8 * num_cells_n) + (tx % 4) * 2,
-                 &y_frag[fn][2]);
-        vec_cast((T*)(y_smem.base + (y_smem.offset ^ 0x1)) + (tx % 4) * 2,
+        vec_cast((T*)(y_smem.base + offset) + (tx % 4) * 2,
+                 &y_frag[fn][0]);
+        vec_cast(
+            (T*)(y_smem.base + offset + 8 * num_cells_n) + (tx % 4) * 2,
+            &y_frag[fn][2]);
+        vec_cast((T*)(y_smem.base + (offset ^ 0x1)) + (tx % 4) * 2,
                 &y_frag[fn][4]);
        vec_cast(
-            (T*)(y_smem.base + (y_smem.offset ^ 0x1) + 8 * num_cells_n) + (tx % 4) * 2,
+            (T*)(y_smem.base + (offset ^ 0x1) + 8 * num_cells_n) + (tx % 4) * 2,
            &y_frag[fn][6]);
-        y_smem.offset = (y_smem.offset ^ 0x2) + (fn & 0x1) * 8;
+        offset = (offset ^ 0x2) + (fn & 0x1) * 8;
      }
      // store y
      if constexpr (num_blocks_n == 1) {
        uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 2;
        T* y_ptr = y + row_idx * d_out + (tx % 2) * cell_capacity();
-        y_smem.offset = smem_t::get_permuted_offset(ty * 16 + tx / 2, tx % 2);
+        auto offset =
+            smem_t::get_permuted_offset(ty * 16 + tx / 2, tx % 2);
        if (row_idx < s_end) {
-          y_smem.store_128b(y_ptr);
+          y_smem.store_128b(offset, y_ptr);
        }
      } else {
        uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4;
        T* y_ptr = y + row_idx * d_out + (tx % 4) * cell_capacity();
-        y_smem.offset = smem_t::get_permuted_offset(ty * 16 + tx / 4, tx % 4);
+        auto offset =
+            smem_t::get_permuted_offset(ty * 16 + tx / 4, tx % 4);
        #pragma unroll
        for (uint32_t j = 0; j < 2; ++j) {
          #pragma unroll
          for (uint32_t fno = 0; fno < num_blocks_n / 2; ++fno) {
            if (row_idx < s_end) {
-              y_smem.store_128b(y_ptr);
+              y_smem.store_128b(offset, y_ptr);
            }
            y_ptr += 4 * cell_capacity();
-              y_smem.offset += 8;
+              offset += 8;
          }
          row_idx += 8;
          y_ptr += 8 * d_out;
-          y_smem.offset += 8 * num_cells_n - 4 * num_blocks_n;
+          offset += 8 * num_cells_n - 4 * num_blocks_n;
        }
      }
    }

diff --git a/setup.py b/setup.py
index e3c029b..4102bb5 100644
--- a/setup.py
+++ b/setup.py
@@ -48,6 +48,10 @@ def remove_unwanted_pytorch_nvcc_flags():
             str(root.resolve() / "third_party/cutlass/include"),
             str(root.resolve() / "third_party/flashinfer/include"),
         ],
+        extra_compile_args={
+            "cxx": ["-O3"],
+            "nvcc": ["-O3"],
+        },
     )
 )

diff --git a/third_party/flashinfer b/third_party/flashinfer
index 5834b34..25f7c03 160000
--- a/third_party/flashinfer
+++ b/third_party/flashinfer
@@ -1 +1 @@
-Subproject commit 5834b34e6b1f4c835abad33bd961471c12eae272
+Subproject commit 25f7c03dcf577e0824382c47fed9d6d308dbbd69