Skip to content

Commit

Permalink
update flashinfer to 25f7c03dcf577e0824382c47fed9d6d308dbbd69
Browse files Browse the repository at this point in the history
  • Loading branch information
abcdabcd987 committed Nov 22, 2023
1 parent 03be351 commit f2fc15c
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 84 deletions.
26 changes: 15 additions & 11 deletions csrc/flashinfer_adapter/flashinfer_all.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@
#include "flashinfer/page.cuh"
#include "flashinfer_config.h"

using flashinfer::paged_kv_t;
using flashinfer::PageStorage;
using flashinfer::RotaryMode;

template <typename T>
void FlashInferBatchDecodeKernel(T* o, T* q, T* kv_data, int32_t* kv_indptr,
int32_t* kv_indicies,
int32_t* last_page_offset, int head_dim,
int num_layers, int layer_idx,
int num_qo_heads, int num_kv_heads,
int page_size, int batch_size) {
flashinfer::paged_kv_t<T, int32_t> paged_kv(
paged_kv_t<PageStorage::kIndices, T, int32_t> paged_kv(
num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size,
kv_data, kv_indptr, kv_indicies, last_page_offset);
kv_data, kv_indicies, kv_indptr, last_page_offset);
flashinfer::BatchDecodeWithPagedKVCache(q, paged_kv, o, nullptr, num_qo_heads,
flashinfer::RotaryMode::kLlama);
RotaryMode::kLlama);
}

template <int head_dim, typename T>
Expand All @@ -26,18 +30,18 @@ void FlashInferInitKvKernel(T* kv_data, int32_t* kv_indptr,
T* key, T* value, int32_t* seqlen_indptr,
int num_layers, int layer_idx, int num_kv_heads,
int page_size, int batch_size) {
flashinfer::paged_kv_t<T, int32_t> paged_kv(
paged_kv_t<PageStorage::kIndices, T, int32_t> paged_kv(
num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size,
kv_data, kv_indptr, kv_indicies, last_page_offset);
kv_data, kv_indicies, kv_indptr, last_page_offset);

constexpr size_t vec_size =
std::max(16 / sizeof(T), static_cast<size_t>(head_dim / 32));
constexpr size_t bdx = head_dim / vec_size;
constexpr size_t bdy = 128 / bdx;
dim3 nblks(paged_kv.batch_size * paged_kv.num_heads / bdy);
dim3 nthrs(bdx, bdy);
flashinfer::AppendPagedKVCachePrefillKernel<head_dim, vec_size, bdx, bdy, T,
int32_t>
flashinfer::AppendPagedKVCachePrefillKernel<head_dim, vec_size, bdx, bdy,
PageStorage::kIndices, T, int32_t>
<<<nblks, nthrs>>>(paged_kv, key, value, seqlen_indptr);
}

Expand All @@ -46,18 +50,18 @@ void FlashInferAppendKvKernel(T* kv_data, int32_t* kv_indptr,
int32_t* kv_indicies, int32_t* last_page_offset,
T* key, T* value, int num_layers, int layer_idx,
int num_kv_heads, int page_size, int batch_size) {
flashinfer::paged_kv_t<T, int32_t> paged_kv(
paged_kv_t<PageStorage::kIndices, T, int32_t> paged_kv(
num_layers, layer_idx, num_kv_heads, page_size, head_dim, batch_size,
kv_data, kv_indptr, kv_indicies, last_page_offset);
kv_data, kv_indicies, kv_indptr, last_page_offset);

constexpr size_t vec_size =
std::max(16 / sizeof(T), static_cast<size_t>(head_dim / 32));
constexpr size_t bdx = head_dim / vec_size;
constexpr size_t bdy = 128 / bdx;
dim3 nblks(paged_kv.batch_size * paged_kv.num_heads / bdy);
dim3 nthrs(bdx, bdy);
flashinfer::AppendPagedKVCacheDecodeKernel<head_dim, vec_size, bdx, bdy, T,
int32_t>
flashinfer::AppendPagedKVCacheDecodeKernel<head_dim, vec_size, bdx, bdy,
PageStorage::kIndices, T, int32_t>
<<<nblks, nthrs>>>(paged_kv, key, value);
}

Expand Down
Loading

0 comments on commit f2fc15c

Please sign in to comment.