[KVCache] PagedKVCache Quantization #17159

Open. Wants to merge 1 commit into base: main.
107 changes: 56 additions & 51 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -853,7 +853,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
/*!
* \brief The KV data managed by the KV cache.
* The array has `num_layers` NDArrays, each of them
- * has layout (num_pages, 2, num_heads, page_size, head_dim).
+ * has layout (num_pages, 2, num_heads, page_size, num_storage).
* Along on the "2" dimension, index 0 stands for K and 1 stands for V.
*/
Array<NDArray> pages_;
@@ -985,10 +985,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size,
bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta,
- DLDataType dtype, Device device, PackedFunc f_transpose_append, PackedFunc f_compact_copy,
- PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
- PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window,
- PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask,
+ int64_t num_storage, DLDataType dtype, DLDataType kv_storage_dtype, Device device,
+ PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc f_attention_prefill,
+ PackedFunc f_attention_decode, PackedFunc f_attention_prefill_sliding_window,
+ PackedFunc f_attention_decode_sliding_window, PackedFunc f_attention_prefill_ragged,
+ PackedFunc f_attention_prefill_with_tree_mask,
Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
Optional<PackedFunc> f_attention_prefill_begin_forward,
@@ -1030,8 +1031,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
device_(device) {
pages_.reserve(num_layers);
for (int i = 0; i < num_layers; ++i) {
- pages_.push_back(
-     NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, head_dim}, dtype, device));
+ pages_.push_back(NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, num_storage},
+                                 kv_storage_dtype, device));
}
// Allocate the host memory.
Device preferred_host_device = GetPreferredHostDevice(device);
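With this change each per-layer page pool is allocated with `num_storage` elements of `kv_storage_dtype` along the last axis instead of `head_dim` elements of the compute dtype. A minimal sketch of the resulting footprint, assuming (my reading of the diff, not stated in it) that `num_storage` counts the `kv_storage_dtype` elements kept per head per token:

```python
import numpy as np

def layer_page_bytes(num_pages, num_kv_heads, page_size, num_storage, storage_dtype):
    """Bytes of one layer's page pool with layout
    (num_pages, 2, num_kv_heads, page_size, num_storage); the "2" axis holds K and V."""
    elems = num_pages * 2 * num_kv_heads * page_size * num_storage
    return elems * np.dtype(storage_dtype).itemsize

# Example sizes; unquantized path as in the tests below, where num_storage == head_dim
# and kv_storage_dtype == the compute dtype.
print(layer_page_bytes(512, 4, 16, 128, "float16"))  # 16777216
# Hypothetical quantized storage dtype with the same per-head element count.
print(layer_page_bytes(512, 4, 16, 128, "uint8"))    # 8388608
```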
@@ -1673,8 +1674,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
NDArray o_data, double attn_score_scaling_factor) final {
// Part 1. Shape and dtype check.
NDArray pages = pages_[layer_id];
- CHECK(qkv_data.DataType() == pages.DataType());
- CHECK(o_data.DataType() == pages.DataType());

// qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, head_dim)
// o_data: (num_total_length, num_qo_heads, head_dim)
@@ -2433,7 +2432,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);

TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
- CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27)
+ CHECK(args.size() == 27 || args.size() == 28 || args.size() == 29)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
int64_t num_layers = args[1];
@@ -2443,31 +2442,33 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
int rope_mode = args[5];
double rotary_scale = args[6];
double rotary_theta = args[7];
- NDArray init = args[8];
- PackedFunc f_transpose_append = args[9];
- PackedFunc f_attention_prefill = args[10];
- PackedFunc f_attention_decode = args[11];
- PackedFunc f_attention_prefill_sliding_window = args[12];
- PackedFunc f_attention_decode_sliding_window = args[13];
- PackedFunc f_attention_prefill_ragged = args[14];
- PackedFunc f_attention_prefill_ragged_begin_forward = args[15];
- PackedFunc f_attention_prefill_ragged_end_forward = args[16];
- PackedFunc f_attention_prefill_begin_forward = args[17];
- PackedFunc f_attention_prefill_end_forward = args[18];
- PackedFunc f_attention_decode_begin_forward = args[19];
- PackedFunc f_attention_decode_end_forward = args[20];
- PackedFunc f_merge_inplace = args[21];
- PackedFunc f_split_rotary = args[22];
- PackedFunc f_copy_single_page = args[23];
- Optional<PackedFunc> f_debug_get_kv = args[24];
+ int64_t num_storage = args[8];
+ NDArray init = args[9];
+ NDArray kv_storage_init = args[10];
+ PackedFunc f_transpose_append = args[11];
+ PackedFunc f_attention_prefill = args[12];
+ PackedFunc f_attention_decode = args[13];
+ PackedFunc f_attention_prefill_sliding_window = args[14];
+ PackedFunc f_attention_decode_sliding_window = args[15];
+ PackedFunc f_attention_prefill_ragged = args[16];
+ PackedFunc f_attention_prefill_ragged_begin_forward = args[17];
+ PackedFunc f_attention_prefill_ragged_end_forward = args[18];
+ PackedFunc f_attention_prefill_begin_forward = args[19];
+ PackedFunc f_attention_prefill_end_forward = args[20];
+ PackedFunc f_attention_decode_begin_forward = args[21];
+ PackedFunc f_attention_decode_end_forward = args[22];
+ PackedFunc f_merge_inplace = args[23];
+ PackedFunc f_split_rotary = args[24];
+ PackedFunc f_copy_single_page = args[25];
+ Optional<PackedFunc> f_debug_get_kv = args[26];
PackedFunc f_compact_copy{nullptr};
PackedFunc f_attention_prefill_with_tree_mask{nullptr};

- if (args.size() >= 26) {
-   f_compact_copy = args[25].AsObjectRef<PackedFunc>();
+ if (args.size() >= 28) {
+   f_compact_copy = args[27].AsObjectRef<PackedFunc>();
}
- if (args.size() >= 27) {
-   f_attention_prefill_with_tree_mask = args[26].AsObjectRef<PackedFunc>();
+ if (args.size() >= 29) {
+   f_attention_prefill_with_tree_mask = args[28].AsObjectRef<PackedFunc>();
}

CHECK_EQ(cache_config.size(), 5);
@@ -2484,8 +2485,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
- rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
- std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
+ rotary_scale, rotary_theta, num_storage, init->dtype, kv_storage_init->dtype,
+ init->device, std::move(f_transpose_append), std::move(f_compact_copy),
+ std::move(f_attention_prefill), std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask),
@@ -2500,7 +2502,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")

TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
.set_body([](TVMArgs args, TVMRetValue* rv) {
- CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21)
+ CHECK(args.size() == 21 || args.size() == 22 || args.size() == 23)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
int64_t num_layers = args[1];
@@ -2510,25 +2512,27 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
int rope_mode = args[5];
double rotary_scale = args[6];
double rotary_theta = args[7];
- NDArray init = args[8];
- PackedFunc f_transpose_append = args[9];
- PackedFunc f_attention_prefill = args[10];
- PackedFunc f_attention_decode = args[11];
- PackedFunc f_attention_prefill_sliding_window = args[12];
- PackedFunc f_attention_decode_sliding_window = args[13];
- PackedFunc f_attention_prefill_ragged = args[14];
- PackedFunc f_merge_inplace = args[15];
- PackedFunc f_split_rotary = args[16];
- PackedFunc f_copy_single_page = args[17];
- Optional<PackedFunc> f_debug_get_kv = args[18];
+ int64_t num_storage = args[8];
+ NDArray init = args[9];
+ NDArray kv_storage_init = args[10];
+ PackedFunc f_transpose_append = args[11];
+ PackedFunc f_attention_prefill = args[12];
+ PackedFunc f_attention_decode = args[13];
+ PackedFunc f_attention_prefill_sliding_window = args[14];
+ PackedFunc f_attention_decode_sliding_window = args[15];
+ PackedFunc f_attention_prefill_ragged = args[16];
+ PackedFunc f_merge_inplace = args[17];
+ PackedFunc f_split_rotary = args[18];
+ PackedFunc f_copy_single_page = args[19];
+ Optional<PackedFunc> f_debug_get_kv = args[20];
PackedFunc f_compact_copy{nullptr};
PackedFunc f_attention_prefill_with_tree_mask{nullptr};

- if (args.size() >= 20) {
-   f_compact_copy = args[19].AsObjectRef<PackedFunc>();
+ if (args.size() >= 22) {
+   f_compact_copy = args[21].AsObjectRef<PackedFunc>();
}
- if (args.size() >= 21) {
-   f_attention_prefill_with_tree_mask = args[20].AsObjectRef<PackedFunc>();
+ if (args.size() >= 23) {
+   f_attention_prefill_with_tree_mask = args[22].AsObjectRef<PackedFunc>();
}

CHECK_EQ(cache_config.size(), 5);
@@ -2545,8 +2549,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
- rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
- std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
+ rotary_scale, rotary_theta, num_storage, init->dtype, kv_storage_init->dtype,
+ init->device, std::move(f_transpose_append), std::move(f_compact_copy),
+ std::move(f_attention_prefill), std::move(f_attention_decode),
std::move(f_attention_prefill_sliding_window),
std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
std::move(f_attention_prefill_with_tree_mask), //
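Taken together, both creation builtins now take two extra leading values: `num_storage` at index 8 and a second scalar NDArray at index 10 that carries the KV storage dtype, which is why the expected argument counts move from 25/26/27 to 27/28/29 (full variant) and from 19/20/21 to 21/22/23 (reduced variant). As a reference for callers, here is a minimal sketch of the new positional layout of the full variant, reconstructed from the argument parsing above; indices 2-4 are not shown in this excerpt and are assumed from the constructor call.

```python
# Positional argument layout of "vm.builtin.paged_attention_kv_cache_create" after this
# change, reconstructed from the args[...] parsing above. The last two entries stay optional.
KV_CACHE_CREATE_ARGS = [
    "cache_config",                              # 0: ShapeTuple with 5 values
    "num_layers",                                # 1
    "num_qo_heads",                              # 2  (2-4 not shown in this excerpt;
    "num_kv_heads",                              # 3   assumed from the constructor call)
    "head_dim",                                  # 4
    "rope_mode",                                 # 5
    "rotary_scale",                              # 6
    "rotary_theta",                              # 7
    "num_storage",                               # 8  new
    "init",                                      # 9  scalar NDArray carrying compute dtype/device
    "kv_storage_init",                           # 10 scalar NDArray carrying the KV storage dtype (new)
    "f_transpose_append",                        # 11
    "f_attention_prefill",                       # 12
    "f_attention_decode",                        # 13
    "f_attention_prefill_sliding_window",        # 14
    "f_attention_decode_sliding_window",         # 15
    "f_attention_prefill_ragged",                # 16
    "f_attention_prefill_ragged_begin_forward",  # 17
    "f_attention_prefill_ragged_end_forward",    # 18
    "f_attention_prefill_begin_forward",         # 19
    "f_attention_prefill_end_forward",           # 20
    "f_attention_decode_begin_forward",          # 21
    "f_attention_decode_end_forward",            # 22
    "f_merge_inplace",                           # 23
    "f_split_rotary",                            # 24
    "f_copy_single_page",                        # 25
    "f_debug_get_kv",                            # 26
    "f_compact_copy",                            # 27 optional
    "f_attention_prefill_with_tree_mask",        # 28 optional
]
```

The reduced variant follows the same layout through index 16, then continues with f_merge_inplace (17), f_split_rotary (18), f_copy_single_page (19), f_debug_get_kv (20), and the same two optional functions at 21 and 22.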
(Second changed file: a Python KV cache test; the per-file header was not captured in this view.)
@@ -344,6 +344,9 @@ def set_global_func():

def create_kv_cache(rope_mode):
support_sliding_window = 0
+ num_storage = head_dim
+ kv_storage_dtype = dtype
+
cache = fcreate(
tvm.runtime.ShapeTuple(
[
@@ -361,7 +364,9 @@ def create_kv_cache(rope_mode):
rope_mode,
rope_scale,
rope_theta,
+ num_storage,
tvm.nd.empty((), dtype, device=device),
+ tvm.nd.empty((), kv_storage_dtype, device=device),
ftranspose_append,
fattention_prefill,
fattention_decode,
(Third changed file: a Python KV cache test; the per-file header was not captured in this view.)
@@ -142,6 +142,9 @@ def set_global_func(head_dim, dtype):


def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
+ num_storage = head_dim
+ kv_storage_dtype = dtype
+
fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced")
cache = fcreate(
tvm.runtime.ShapeTuple(
@@ -160,7 +163,9 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
rope_mode,
rope_scale,
rope_theta,
+ num_storage,
tvm.nd.empty((), dtype, device=device),
+ tvm.nd.empty((), kv_storage_dtype, device=device),
ftranspose_append,
fattn_prefill,
fattn_decode,
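Both updated tests only exercise the unquantized path, where `num_storage` equals `head_dim` and `kv_storage_dtype` equals the compute dtype. How a quantized configuration would pick these two values is not shown in this excerpt; the sketch below is purely hypothetical and only illustrates why the storage axis and dtype are now decoupled from `head_dim` and `dtype`.

```python
def storage_layout(head_dim, compute_dtype, quant=None):
    """Return (num_storage, kv_storage_dtype) for the page pool.

    quant=None reproduces the unquantized tests above; the quantized branches are
    hypothetical examples only, not the scheme this PR implements.
    """
    if quant is None:
        return head_dim, compute_dtype
    if quant == "int8":         # one storage element per head element
        return head_dim, "int8"
    if quant == "int4_packed":  # eight 4-bit values packed into one uint32
        assert head_dim % 8 == 0
        return head_dim // 8, "uint32"
    raise ValueError(f"unknown quantization scheme: {quant}")

print(storage_layout(128, "float16"))                 # (128, 'float16')
print(storage_layout(128, "float16", "int4_packed"))  # (16, 'uint32')
```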