feat: First pass at llama_kv_cache_hybrid #13276

base: master
@@ -2384,6 +2384,231 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count
    return true;
}

//
// llama_kv_cache_hybrid
//

llama_kv_cache_hybrid::llama_kv_cache_hybrid(
        const llama_hparams & hparams,
        const std::vector<child_cache> & children) :
    m_hparams(hparams),
    m_layer_cache_map(
        [](const std::vector<child_cache> & caches) -> std::unordered_map<size_t, llama_kv_cache *> {
            std::unordered_map<size_t, llama_kv_cache *> map;
            for (const auto & cache : caches) {
                for (size_t layer_id : cache.layer_ids) {
                    map[layer_id] = cache.child;
                }
            }

            return map;
        }(children)
    ),
    m_children(
        [](std::vector<child_cache> caches) -> std::set<llama_kv_cache *> {
            // Sort the caches by the lowest layer ID so the order is repeatable
            for (auto & cache : caches) {
                GGML_ASSERT(cache.layer_ids.size() > 0);
                std::sort(cache.layer_ids.begin(), cache.layer_ids.end());
            }
            std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) {
                return a.layer_ids[0] < b.layer_ids[0];
            });
            std::set<llama_kv_cache *> unique_caches;
            for (const auto & cache : caches) {
                unique_caches.insert(cache.child);
            }
            return unique_caches;
        }(children)
    ),
    m_has_recurrent(
        [](const std::vector<child_cache> & caches) -> bool {
            for (const auto & cache : caches) {
                if (dynamic_cast<llama_kv_cache_recurrent *>(cache.child)) {
                    return true;
                }
            }
            return false;
        }(children)
    )
{
    // Ensure at least one child
    GGML_ASSERT(m_children.size() > 0);

    // Ensure layers are not overlapping and are contiguous
    std::set<size_t> seen_layers;
    size_t max_layer = 0;
    for (const auto & cache : children) {
        for (const auto & layer_id : cache.layer_ids) {
            GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
            seen_layers.insert(layer_id);
            if (layer_id > max_layer) {
                max_layer = layer_id;
            }
        }
    }
    // With 0-indexed layers, contiguity means the highest layer ID is one
    // less than the number of distinct layers seen
    GGML_ASSERT(max_layer + 1 == seen_layers.size());
}
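
For orientation (not part of the diff): a minimal sketch of how a caller might assemble the children argument, assuming child_cache is an aggregate holding a llama_kv_cache * child plus a std::vector<size_t> layer_ids, and that the two child caches and the per-layer predicate already exist.

// Hypothetical wiring: one cache for attention layers, one for recurrent
// (SSM) layers, with every layer assigned to exactly one child.
std::vector<child_cache> children(2);
children[0].child = attn_cache;   // e.g. a llama_kv_cache_unified *, assumed pre-built
children[1].child = ssm_cache;    // e.g. a llama_kv_cache_recurrent *, assumed pre-built
for (size_t il = 0; il < hparams.n_layer; ++il) {
    const bool recurrent = layer_is_recurrent(il);  // assumed per-layer predicate
    children[recurrent ? 1 : 0].layer_ids.push_back(il);
}
llama_kv_cache_hybrid kv(hparams, children);

Each child must own at least one layer, and together the children must cover layers 0..n_layer-1 with no overlap, or the constructor asserts.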

void llama_kv_cache_hybrid::clear() {
    for (const auto & cache : m_children) {
        cache->clear();
    }
}

bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    // TODO: Will it cause problems if some caches are able to remove the seq
    // but others aren't?
    bool removed = true;
    for (const auto & cache : m_children) {
        removed = cache->seq_rm(seq_id, p0, p1) && removed;
    }
    return removed;
}
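
A hypothetical all-or-nothing variant of the body above, sketched under the assumption that restore() rolls each child back to its last commit()ed state (which the current child implementations may not guarantee):

// If any child fails to remove the range, restore every child so the
// caches stay coherent with each other.
bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    bool removed = true;
    for (const auto & cache : m_children) {
        removed = cache->seq_rm(seq_id, p0, p1) && removed;
    }
    if (!removed) {
        // Assumes restore() undoes the partial removals above
        for (const auto & cache : m_children) {
            cache->restore();
        }
    }
    return removed;
}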

void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    for (const auto & cache : m_children) {
        cache->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    }
}

void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
    for (const auto & cache : m_children) {
        cache->seq_keep(seq_id);
    }
}

void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
    for (const auto & cache : m_children) {
        cache->seq_add(seq_id, p0, p1, delta);
    }
}

void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    for (const auto & cache : m_children) {
        cache->seq_div(seq_id, p0, p1, d);
    }
}

llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
    llama_pos max_pos = 0;
    for (const auto & cache : m_children) {
        max_pos = std::max(max_pos, cache->seq_pos_max(seq_id));
    }
    return max_pos;
}

void llama_kv_cache_hybrid::restore() {
    for (const auto & cache : m_children) {
        cache->restore();
    }
}

void llama_kv_cache_hybrid::commit() {
    for (const auto & cache : m_children) {
        cache->commit();
    }
}

bool llama_kv_cache_hybrid::update(llama_context & ctx) {
    bool updated = false;
    for (const auto & cache : m_children) {
        updated = cache->update(ctx) || updated;
    }
    return updated;
}

void llama_kv_cache_hybrid::defrag_sched(float thold) {
    for (const auto & cache : m_children) {
        cache->defrag_sched(thold);
    }
}

void llama_kv_cache_hybrid::set_full() {
    for (const auto & cache : m_children) {
        cache->set_full();
    }
}

llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
    // If any of the caches are recurrent, require simple split
    return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all);
}

Review comment on lines +2533 to +2534:
Simple split should not be used with recurrent models, they expect equal split. See #7531 (comment) which illustrates the splits.
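
Based on that comment, the flag presumably needs to be inverted so a simple split is only requested when no child is recurrent; a sketch of that reading (an assumption, not the reviewer's verbatim suggestion):

llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
    // Recurrent children require equal split, so only request a simple
    // split when no child is recurrent
    return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
}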

llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
    if (m_has_recurrent) {
        return sbatch.split_simple(n_ubatch);
    }
    if (embd_pooled) {
        // Pooled embeddings cannot be split across ubatches (yet)
        return sbatch.split_seq(n_ubatch);
    }
    return sbatch.split_equal(n_ubatch);
}

Review comment on lines +2538 to +2540:
This will not work, recurrent models expect equal splits.
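
Taken together, the two comments imply swapping the split types; a sketch of the corrected selection (an assumption, not code from this PR):

llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
    if (embd_pooled) {
        // Pooled embeddings cannot be split across ubatches (yet)
        return sbatch.split_seq(n_ubatch);
    }
    if (m_has_recurrent) {
        // Recurrent children need equal splits so each sequence advances
        // its state consistently within a ubatch
        return sbatch.split_equal(n_ubatch);
    }
    return sbatch.split_simple(n_ubatch);
}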

bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
    bool found = true;
    for (const auto & cache : m_children) {
        found = cache->find_slot(batch) && found;
    }
    return found;
}

int32_t llama_kv_cache_hybrid::get_n_tokens() const {
    // The number of tokens should be the same across all child caches
    int32_t n_tokens = -1;
    for (const auto & cache : m_children) {
        const auto cache_n_tokens = cache->get_n_tokens();
        GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens);
        n_tokens = cache_n_tokens;
    }
    return n_tokens;
}

int32_t llama_kv_cache_hybrid::get_used_cells() const {
    // TODO: Is this correct?
    // Return the largest number of used cells
    int32_t used_cells = -1;
    for (const auto & cache : m_children) {
        used_cells = std::max(used_cells, cache->get_used_cells());
    }
    return used_cells;
}

llama_pos llama_kv_cache_hybrid::get_pos_max() const {
    llama_pos pos_max = -1;
    for (const auto & cache : m_children) {
        pos_max = std::max(pos_max, cache->get_pos_max());
    }
    return pos_max;
}

bool llama_kv_cache_hybrid::get_can_shift() const {
    // TODO: Is this correct?
    // If any children can shift, return true
    for (const auto & cache : m_children) {
        if (cache->get_can_shift()) {
            return true;
        }
    }

    return false;
}

Review comment on lines +2586 to +2592:
Maybe this should be if all children can shift, then return true. But as you've noticed elsewhere, […]
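
A sketch of the all-children reading that comment suggests (an alternative, not this PR's code):

bool llama_kv_cache_hybrid::get_can_shift() const {
    // Shifting is only safe if every child supports it
    for (const auto & cache : m_children) {
        if (!cache->get_can_shift()) {
            return false;
        }
    }
    return true;
}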

void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    // Write each cache state in order. Note that order is guaranteed at
    // initialization by using an ordered set sorted by lowest layer ID
    for (const auto & cache : m_children) {
        cache->state_write(io, seq_id);
    }
}

void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    // Read each cache state in order. Note that order is guaranteed at
    // initialization by using an ordered set sorted by lowest layer ID
    for (const auto & cache : m_children) {
        cache->state_read(io, seq_id);
    }
}

//
// kv cache view
//

Review comment (on the seq_rm TODO above):
Yes it will cause problems if this breaks the coherency between caches (e.g. part of a sequence is removed in one cache but not the other).

This is what I was referring to in #12799 (comment) when I wrote:

    I think the seq_rm API might fundamentally be too specific to self-attention KV cache. Recurrent models can't rollback their state, because intermediate states are not kept, since keeping them for all tokens would take too much space. (When seq_rm returns false, it means the states have to be re-calculated from scratch for the affected sequence; at least that was the intention in #5328.)

Ideally, if there was some API to create snapshots and roll back to them, the implementation would be simpler for recurrent models (and for hybrid models by extension). (Technically, sequences (with seq_id) already kind of do this (and are copy-on-write), but snapshots within sequences might be more convenient to manage in user code, since managing which state is the latest per sequence could be done transparently.)

But that would also mean having to manage the lifetime of explicit state snapshots (in examples/server/server.cpp among others) instead of directly dealing with ranges of token positions (and might make things like largest-common-prefix context caching harder to handle). I've previously shared some ideas about state snapshots/checkpoints in #7531 (comment) (although the first half of that comment is about session restore, as in state_read).
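
To make the snapshot idea concrete, a purely hypothetical shape for such an API; none of these names exist in llama.cpp:

// Opaque handle to a saved per-sequence cache state.
struct llama_kv_snapshot;

// Capture the current state of seq_id across all (child) caches.
llama_kv_snapshot * llama_kv_snapshot_create(llama_kv_cache * kv, llama_seq_id seq_id);

// Roll seq_id back to the captured state; returns false if the snapshot
// can no longer be applied (e.g. it was invalidated by later operations).
bool llama_kv_snapshot_rollback(llama_kv_cache * kv, llama_kv_snapshot * snap);

// Snapshots are explicit objects, so callers (e.g. examples/server) are
// responsible for releasing them.
void llama_kv_snapshot_free(llama_kv_snapshot * snap);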