feat: First pass at llama_kv_cache_hybrid #13276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: wants to merge 1 commit into master
src/llama-kv-cache.cpp: 225 additions & 0 deletions

@@ -2384,6 +2384,231 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
return true;
}

//
// llama_kv_cache_hybrid
//
llama_kv_cache_hybrid::llama_kv_cache_hybrid(
const llama_hparams & hparams,
const std::vector<child_cache> & children) :
m_hparams(hparams),
m_layer_cache_map(
[](const std::vector<child_cache> & caches) -> std::unordered_map<size_t, llama_kv_cache*> {
std::unordered_map<size_t, llama_kv_cache*> map;
for (const auto & cache : caches) {
for (size_t layer_id : cache.layer_ids) {
map[layer_id] = cache.child;
}
}

return map;
}(children)
),
m_children(
[](std::vector<child_cache> caches) -> std::vector<llama_kv_cache*> {
// Sort the caches by the lowest layer ID so the order is repeatable
for (auto & cache : caches) {
GGML_ASSERT(cache.layer_ids.size() > 0);
std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
}
std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
return a.layer_ids[0] < b.layer_ids[0];
});
// Collect the unique caches while preserving the sorted order. Note
// that a std::set of pointers would order by address rather than by
// layer ID, which would make the state IO order non-deterministic.
std::vector<llama_kv_cache*> unique_caches;
for (const auto & cache : caches) {
if (std::find(unique_caches.begin(), unique_caches.end(), cache.child) == unique_caches.end()) {
unique_caches.push_back(cache.child);
}
}
return unique_caches;
}(children)
),
m_has_recurrent(
[](const std::vector<child_cache> & caches) -> bool {
for (const auto & cache : caches) {
if (dynamic_cast<llama_kv_cache_recurrent *>(cache.child)) {
return true;
}
}
return false;
}(children)
)
{
// Ensure at least one child
GGML_ASSERT(m_children.size() > 0);

// Ensure layers are not overlapping and are contiguous
std::set<size_t> seen_layers;
size_t max_layer = 0;
for (const auto & cache : children) {
for (const auto & layer_id : cache.layer_ids) {
GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
seen_layers.insert(layer_id);
if (layer_id > max_layer) {
max_layer = layer_id;
}
}
}
GGML_ASSERT(max_layer + 1 == seen_layers.size()); // layers are 0-indexed, so contiguous coverage means size == max + 1
}
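
For illustration, here is a minimal, self-contained sketch of the invariant this constructor enforces (the helper name validate_layer_coverage is invented for this example): every layer must be owned by exactly one child, and the children together must cover a contiguous range of layer IDs starting at 0.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Standalone sketch of the hybrid cache's layer validation: every layer ID
// must appear exactly once across all children, and together the children
// must cover the contiguous range 0..max_layer with no gaps.
static void validate_layer_coverage(const std::vector<std::vector<size_t>> & children) {
    std::set<size_t> seen_layers;
    size_t max_layer = 0;
    for (const auto & layer_ids : children) {
        for (size_t layer_id : layer_ids) {
            assert(seen_layers.count(layer_id) == 0); // no layer owned by two children
            seen_layers.insert(layer_id);
            max_layer = std::max(max_layer, layer_id);
        }
    }
    assert(max_layer + 1 == seen_layers.size()); // contiguous, 0-indexed coverage
}

int main() {
    validate_layer_coverage({{0, 2, 4}, {1, 3, 5}}); // OK: e.g. attention on even layers, recurrent on odd
    // validate_layer_coverage({{0, 2}, {4, 5}});    // would fail: layers 1 and 3 are missing
    return 0;
}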

void llama_kv_cache_hybrid::clear() {
for (const auto & cache : m_children) {
cache->clear();
}
}

bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
// TODO: Will it cause problems if some caches are able to remove the seq
// but others aren't?
Comment on lines +2459 to +2460 (Collaborator):
Yes, it will cause problems if this breaks the coherency between caches (e.g. part of a sequence is removed in one cache but not the other).

This is what I was referring to in #12799 (comment) when I wrote:

The hardest part will be handling errors and properly keeping coherency between the different types of caches (because they don't necessarily roll-back states in the same way).

I think the seq_rm API might fundamentally be too specific to self-attention KV cache. Recurrent models can't rollback their state, because intermediate states are not kept since keeping them for all tokens would take too much space. (when seq_rm returns false, it means the states have to be re-calculated from scratch for the affected sequence (at least that was the intention in #5328))

Ideally, if there was some API to create snapshots and rollback to them, the implementation would be simpler for recurrent models (and for hybrid models by extension). (technically, sequences (with seq_id) already kind of do this (and are copy-on-write), but snapshots within sequences might be more convenient to manage in user code, since managing which state is the latest per sequence could be done transparently)

But that would also mean having to manage the lifetime of explicit state snapshots (in examples/server/server.cpp among others) instead of directly dealing with ranges of token positions (and might make things like largest-common-prefix context caching harder to handle). I've previously shared some ideas about state snapshots/checkpoints in #7531 (comment) (although the first half of the comment is about session restore as in state_read).

bool removed = true;
for (const auto & cache : m_children) {
removed = cache->seq_rm(seq_id, p0, p1) && removed;
}
return removed;
}
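
As a rough sketch of the snapshot/rollback idea described in the comment above (purely illustrative, not part of this PR; the type and method names llama_state_snapshot_i, snapshot_create, and snapshot_restore are invented here, and llama_seq_id is assumed from llama.h):

#include <memory>

// Hypothetical opaque handle to a captured per-sequence state.
struct llama_state_snapshot_i {
    virtual ~llama_state_snapshot_i() = default;
};

// Hypothetical extension of the memory interface. Instead of removing
// ranges of positions (seq_rm), user code would capture a snapshot and
// later roll back to it, which recurrent caches could support without
// keeping intermediate states for every token.
struct llama_memory_snapshot_i {
    virtual ~llama_memory_snapshot_i() = default;

    // capture the current state of a sequence; the caller owns the handle
    virtual std::unique_ptr<llama_state_snapshot_i> snapshot_create(llama_seq_id seq_id) = 0;

    // restore a sequence to a captured state; returns false if the
    // snapshot is no longer valid (e.g. its backing cells were reused)
    virtual bool snapshot_restore(llama_seq_id seq_id, const llama_state_snapshot_i & snap) = 0;
};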

void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
for (const auto & cache : m_children) {
cache->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
}

void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
for (const auto & cache : m_children) {
cache->seq_keep(seq_id);
}
}

void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
for (const auto & cache : m_children) {
cache->seq_add(seq_id, p0, p1, delta);
}
}

void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
for (const auto & cache : m_children) {
cache->seq_div(seq_id, p0, p1, d);
}
}

llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
llama_pos max_pos = 0;
for (const auto & cache : m_children) {
max_pos = std::max(max_pos, cache->seq_pos_max(seq_id));
}
return max_pos;
}

void llama_kv_cache_hybrid::restore() {
for (const auto & cache : m_children) {
cache->restore();
}
}

void llama_kv_cache_hybrid::commit() {
for (const auto & cache : m_children) {
cache->commit();
}
}

bool llama_kv_cache_hybrid::update(llama_context & ctx) {
bool updated = false;
for (const auto & cache : m_children) {
updated = cache->update(ctx) || updated;
}
return updated;
}

void llama_kv_cache_hybrid::defrag_sched(float thold) {
for (const auto & cache : m_children) {
cache->defrag_sched(thold);
}
}

void llama_kv_cache_hybrid::set_full() {
for (const auto & cache : m_children) {
cache->set_full();
}
}

llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
// If any of the caches are recurrent, require simple split
return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all);
Comment on lines +2533 to +2534 (@compilade, May 2, 2025):
Simple split should not be used with recurrent models; they expect equal split.

See #7531 (comment) which illustrates the splits

Suggested change:
-	// If any of the caches are recurrent, require simple split
-	return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all);
+	// If any of the caches are recurrent, require non-simple split
+	return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
}

llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
if (m_has_recurrent) {
return sbatch.split_simple(n_ubatch);
}
Comment on lines +2538 to +2540 (Collaborator):

This will not work; recurrent models expect split_equal to be used.

if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
return sbatch.split_seq(n_ubatch);
}
return sbatch.split_equal(n_ubatch);
}
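
Given the two review comments above, here is a sketch of the corrected split selection, assuming that recurrent children require equal splits (or sequence splits for pooled embeddings) and that only a pure-attention stack may use the simple split:

llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
    if (m_has_recurrent) {
        if (embd_pooled) {
            // Pooled embeddings cannot be split across ubatches (yet)
            return sbatch.split_seq(n_ubatch);
        }
        // recurrent children require all sequences to advance in lockstep
        return sbatch.split_equal(n_ubatch);
    }
    return sbatch.split_simple(n_ubatch);
}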

bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
bool found = true;
for (const auto & cache : m_children) {
found = cache->find_slot(batch) && found;
}
return found;
}

int32_t llama_kv_cache_hybrid::get_n_tokens() const {
// The number of tokens should be the same across all child caches
int32_t n_tokens = -1;
for (const auto & cache : m_children) {
const auto cache_n_tokens = cache->get_n_tokens();
GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
n_tokens = cache_n_tokens;
}
return n_tokens;
}

int32_t llama_kv_cache_hybrid::get_used_cells() const {
// TODO: Is this correct?
// Return the largest number of used cells
int32_t used_cells = -1;
for (const auto & cache : m_children) {
used_cells = std::max(used_cells, cache->get_used_cells());
}
return used_cells;
}

llama_pos llama_kv_cache_hybrid::get_pos_max() const {
llama_pos pos_max = -1;
for (const auto & cache : m_children) {
pos_max = std::max(pos_max, cache->get_pos_max());
}
return pos_max;
}

bool llama_kv_cache_hybrid::get_can_shift() const {
// TODO: Is this correct?
// If any children can shift, return true
for (const auto & cache : m_children) {
if (cache->get_can_shift()) {
return true;
}
}
Comment on lines +2586 to +2592 (@compilade, May 2, 2025):

Maybe this should be: if all children can shift, then return true.

But as you've noticed elsewhere, can_shift should technically always be true for all currently-implemented cache types, so I don't know if that part of the API will stay anyway.

return false;
}
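
For reference, a sketch of the "all children can shift" variant suggested above (illustrative only, not part of this PR):

bool llama_kv_cache_hybrid::get_can_shift() const {
    // only report shiftable if every child supports it
    for (const auto & cache : m_children) {
        if (!cache->get_can_shift()) {
            return false;
        }
    }
    return true;
}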

void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
// Write each cache state in order. Note that order is guaranteed at
// initialization by sorting the children by lowest layer ID
for (const auto & cache : m_children) {
cache->state_write(io, seq_id);
}
}

void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
// Read each cache state in order. Note that order is guaranteed at
// initialization by sorting the children by lowest layer ID
for (const auto & cache : m_children) {
cache->state_read(io, seq_id);
}
}

//
// kv cache view
//
src/llama-kv-cache.h: 74 additions & 0 deletions

@@ -9,6 +9,7 @@

#include <set>
#include <vector>
#include <unordered_map>

struct llama_cparams;
struct llama_hparams;
@@ -395,6 +396,79 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

//
// llama_kv_cache_hybrid
//

class llama_kv_cache_hybrid : public llama_kv_cache {
public:

struct child_cache {
llama_kv_cache * child;
std::vector<size_t> layer_ids;
};

llama_kv_cache_hybrid(
const llama_hparams & hparams,
const std::vector<child_cache> & children);

//
// llama_memory_i
//

void clear() override;

bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

llama_pos seq_pos_max(llama_seq_id seq_id) const override;

//
// llama_kv_cache
//

void restore() override;
void commit() override;

bool update(llama_context & ctx) override;

void defrag_sched(float thold) override;

void set_full() override;

llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;

llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

// updates the cache head
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
bool find_slot(const llama_ubatch & batch) override;

int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;

// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;

bool get_can_shift() const override;

// state write/load

void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;

private:

const llama_hparams & m_hparams;
const std::unordered_map<size_t, llama_kv_cache *> m_layer_cache_map;
const std::vector<llama_kv_cache *> m_children; // ordered by lowest layer ID for deterministic state IO
const bool m_has_recurrent;
};


//
// kv cache view