Skip to content

Commit 73df685

Browse files
committed
kv-cache : replace struct graph_params with llama_context &
ggml-ci
1 parent 7e4b545 commit 73df685

File tree

4 files changed

+85
-85
lines changed

4 files changed

+85
-85
lines changed

src/llama-context.cpp

+17-9
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,22 @@ const llama_model & llama_context::get_model() const {
396396
return model;
397397
}
398398

399+
// read-only accessor: returns a const reference to this context's cparams member
const llama_cparams & llama_context::get_cparams() const {
400+
return cparams;
401+
}
402+
403+
// read-only accessor: returns a const reference to the sched member (the pointer wrapper itself)
// callers that need the raw scheduler handle call .get() on the result, as the KV cache code does
const ggml_backend_sched_ptr & llama_context::get_sched() const {
404+
return sched;
405+
}
406+
407+
// read-only accessor: returns a const reference to the ctx_compute member (the pointer wrapper itself)
// callers that need the raw ggml_context call .get() on the result, as the KV cache graph builders do
const ggml_context_ptr & llama_context::get_ctx_compute() const {
408+
return ctx_compute;
409+
}
410+
411+
// read-only accessor: returns a const reference to the backends vector (no copy is made)
const std::vector<ggml_backend_ptr> & llama_context::get_backends() const {
412+
return backends;
413+
}
414+
399415
// returns the n_ctx field of this context's cparams
uint32_t llama_context::n_ctx() const {
400416
return cparams.n_ctx;
401417
}
@@ -439,15 +455,7 @@ void llama_context::kv_self_update() {
439455

440456
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
441457

442-
need_reserve = kv_self->update({
443-
/*.cparams =*/ cparams,
444-
/*.sched =*/ sched.get(),
445-
/*.backends =*/ backends,
446-
/*.n_max_nodes =*/ graph_max_nodes(),
447-
/*.get_ctx_compute =*/ [this]() { return ctx_compute.get(); },
448-
/*.graph_init =*/ [this]() { return graph_init(); },
449-
/*.graph_compute =*/ [this](ggml_cgraph * gf) { graph_compute(gf, false); },
450-
});
458+
need_reserve = kv_self->update(*this);
451459

452460
// reserve a worst case graph if needed
453461
if (need_reserve) {

src/llama-context.h

+15-6
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,14 @@ struct llama_context {
2727

2828
void synchronize();
2929

30-
const llama_model & get_model() const;
30+
const llama_model & get_model() const;
31+
const llama_cparams & get_cparams() const;
32+
33+
const ggml_backend_sched_ptr & get_sched() const;
34+
35+
const ggml_context_ptr & get_ctx_compute() const;
36+
37+
const std::vector<ggml_backend_ptr> & get_backends() const;
3138

3239
uint32_t n_ctx() const;
3340
uint32_t n_ctx_per_seq() const;
@@ -141,22 +148,24 @@ struct llama_context {
141148
// graph
142149
//
143150

151+
public:
144152
int32_t graph_max_nodes() const;
145153

146154
// zero-out inputs and create the ctx_compute for the compute graph
147155
ggml_cgraph * graph_init();
148156

157+
// returns the result of ggml_backend_sched_graph_compute_async execution
158+
ggml_status graph_compute(
159+
ggml_cgraph * gf,
160+
bool batched);
161+
162+
private:
149163
llm_graph_result_ptr graph_build(
150164
ggml_context * ctx,
151165
ggml_cgraph * gf,
152166
const llama_ubatch & ubatch,
153167
llm_graph_type gtype);
154168

155-
// returns the result of ggml_backend_sched_graph_compute_async execution
156-
ggml_status graph_compute(
157-
ggml_cgraph * gf,
158-
bool batched);
159-
160169
llm_graph_cb graph_get_cb() const;
161170

162171
// TODO: read/write lora adapters and cvec

src/llama-kv-cache.cpp

+37-36
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "llama-batch.h"
55
#include "llama-cparams.h"
66
#include "llama-model.h"
7+
#include "llama-context.h"
78

89
#include <algorithm>
910
#include <cassert>
@@ -367,10 +368,10 @@ void llama_kv_cache_unified::commit() {
367368
pending.ranges.clear();
368369
}
369370

370-
bool llama_kv_cache_unified::update(const graph_params & params) {
371+
bool llama_kv_cache_unified::update(llama_context & lctx) {
371372
bool need_reserve = false;
372373

373-
const auto & sched = params.sched;
374+
const auto & sched = lctx.get_sched();
374375

375376
if (has_shift) {
376377
if (!get_can_shift()) {
@@ -381,17 +382,17 @@ bool llama_kv_cache_unified::update(const graph_params & params) {
381382

382383
// apply K-shift if needed
383384
if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
384-
ggml_backend_sched_reset(sched);
385+
ggml_backend_sched_reset(sched.get());
385386

386-
auto * gf = params.graph_init();
387+
auto * gf = lctx.graph_init();
387388

388-
auto res = build_graph_shift(params, gf);
389+
auto res = build_graph_shift(lctx, gf);
389390

390-
ggml_backend_sched_alloc_graph(sched, gf);
391+
ggml_backend_sched_alloc_graph(sched.get(), gf);
391392

392393
res->set_inputs(nullptr);
393394

394-
params.graph_compute(gf);
395+
lctx.graph_compute(gf, false);
395396

396397
need_reserve = true;
397398
}
@@ -408,18 +409,18 @@ bool llama_kv_cache_unified::update(const graph_params & params) {
408409
if (do_defrag) {
409410
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
410411

411-
if (defrag_prepare(params.n_max_nodes)) {
412-
ggml_backend_sched_reset(sched);
412+
if (defrag_prepare(lctx.graph_max_nodes())) {
413+
ggml_backend_sched_reset(sched.get());
413414

414-
auto * gf = params.graph_init();
415+
auto * gf = lctx.graph_init();
415416

416-
auto res = build_graph_defrag(params, gf);
417+
auto res = build_graph_defrag(lctx, gf);
417418

418-
ggml_backend_sched_alloc_graph(sched, gf);
419+
ggml_backend_sched_alloc_graph(sched.get(), gf);
419420

420421
res->set_inputs(nullptr);
421422

422-
params.graph_compute(gf);
423+
lctx.graph_compute(gf, false);
423424

424425
need_reserve = true;
425426
}
@@ -591,17 +592,17 @@ size_t llama_kv_cache_unified::size_v_bytes() const {
591592
}
592593

593594
ggml_tensor * llama_kv_cache_unified::build_rope_shift(
594-
const graph_params & params,
595-
ggml_context * ctx,
596-
ggml_tensor * cur,
597-
ggml_tensor * shift,
598-
ggml_tensor * factors,
599-
float freq_base,
600-
float freq_scale,
601-
ggml_backend_buffer * bbuf) const {
602-
const auto & cparams = params.cparams;
603-
const auto & backends = params.backends;
604-
const auto & sched = params.sched;
595+
llama_context & lctx,
596+
ggml_context * ctx,
597+
ggml_tensor * cur,
598+
ggml_tensor * shift,
599+
ggml_tensor * factors,
600+
float freq_base,
601+
float freq_scale,
602+
ggml_backend_buffer * bbuf) const {
603+
const auto & cparams = lctx.get_cparams();
604+
const auto & backends = lctx.get_backends();
605+
const auto & sched = lctx.get_sched();
605606

606607
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
607608

@@ -626,7 +627,7 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
626627
for (const auto & backend : backends) {
627628
// Figure out which backend KV cache belongs to
628629
if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
629-
ggml_backend_sched_set_tensor_backend(sched, tmp, backend.get());
630+
ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
630631
break;
631632
}
632633
}
@@ -674,13 +675,13 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
674675
}
675676

676677
llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
677-
const graph_params & params,
678-
ggml_cgraph * gf) const {
678+
llama_context & lctx,
679+
ggml_cgraph * gf) const {
679680
auto res = std::make_unique<llm_graph_result>();
680681

681-
auto * ctx = params.get_ctx_compute();
682+
auto * ctx = lctx.get_ctx_compute().get();
682683

683-
const auto & cparams = params.cparams;
684+
const auto & cparams = lctx.get_cparams();
684685

685686
const auto & n_layer = hparams.n_layer;
686687

@@ -716,7 +717,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
716717
ggml_row_size(k_l[il]->type, n_embd_k_gqa),
717718
0);
718719

719-
ggml_tensor * cur = build_rope_shift(params, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, k_l[il]->buffer);
720+
ggml_tensor * cur = build_rope_shift(lctx, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, k_l[il]->buffer);
720721

721722
ggml_build_forward_expand(gf, cur);
722723
}
@@ -727,15 +728,15 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
727728
}
728729

729730
llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
730-
const graph_params & params,
731-
ggml_cgraph * gf) const {
731+
llama_context & lctx,
732+
ggml_cgraph * gf) const {
732733
auto res = std::make_unique<llm_graph_result>();
733734

734-
auto * ctx = params.get_ctx_compute();
735+
auto * ctx = lctx.get_ctx_compute().get();
735736

736737
const auto & ids = defrag_info.ids;
737738

738-
const auto & cparams = params.cparams;
739+
const auto & cparams = lctx.get_cparams();
739740

740741
#if 0
741742
// CPU defrag
@@ -1725,8 +1726,8 @@ void llama_kv_cache_recurrent::commit() {
17251726
pending.ranges.clear();
17261727
}
17271728

1728-
bool llama_kv_cache_recurrent::update(const graph_params & params) {
1729-
GGML_UNUSED(params);
1729+
// update() for the recurrent cache: does not use lctx and always returns false
// llama_context::kv_self_update() stores the result as need_reserve, so this never requests a re-reserve
bool llama_kv_cache_recurrent::update(llama_context & lctx) {
1730+
GGML_UNUSED(lctx);
17301731
return false;
17311732
}
17321733

src/llama-kv-cache.h

+16-34
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
#include "ggml-cpp.h"
99

10-
#include <functional>
1110
#include <set>
1211
#include <vector>
1312

@@ -16,26 +15,9 @@ struct llama_hparams;
1615
struct llama_ubatch;
1716
struct llama_sbatch;
1817
struct llama_model;
18+
struct llama_context;
1919

2020
struct llama_kv_cache : public llama_memory_i {
21-
struct graph_params {
22-
const llama_cparams & cparams;
23-
24-
const ggml_backend_sched_t & sched;
25-
26-
const std::vector<ggml_backend_ptr> & backends;
27-
28-
int32_t n_max_nodes;
29-
30-
std::function<ggml_context * ()> get_ctx_compute;
31-
32-
// function for creating ggml graphs
33-
std::function<ggml_cgraph * ()> graph_init;
34-
35-
// function for computing ggml graphs
36-
std::function<void(ggml_cgraph * gf)> graph_compute;
37-
};
38-
3921
virtual ~llama_kv_cache() = default;
4022

4123
// call if batch processing fails - restores the cache state
@@ -46,7 +28,7 @@ struct llama_kv_cache : public llama_memory_i {
4628

4729
// process any pending defrag/shift/etc. operations
4830
// optionally call once before processing a new batch
49-
virtual bool update(const graph_params & params) = 0;
31+
virtual bool update(llama_context & lctx) = 0;
5032

5133
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
5234
virtual void defrag_sched(float thold) = 0;
@@ -161,7 +143,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
161143
void restore() override;
162144
void commit() override;
163145

164-
bool update(const graph_params & params) override;
146+
bool update(llama_context & ctx) override;
165147

166148
void defrag_sched(float thold) override;
167149

@@ -251,22 +233,22 @@ class llama_kv_cache_unified : public llama_kv_cache {
251233
size_t size_v_bytes() const;
252234

253235
ggml_tensor * build_rope_shift(
254-
const graph_params & params,
255-
ggml_context * ctx,
256-
ggml_tensor * cur,
257-
ggml_tensor * shift,
258-
ggml_tensor * factors,
259-
float freq_base,
260-
float freq_scale,
261-
ggml_backend_buffer * bbuf) const;
236+
llama_context & lctx,
237+
ggml_context * ctx,
238+
ggml_tensor * cur,
239+
ggml_tensor * shift,
240+
ggml_tensor * factors,
241+
float freq_base,
242+
float freq_scale,
243+
ggml_backend_buffer * bbuf) const;
262244

263245
llm_graph_result_ptr build_graph_shift(
264-
const graph_params & params,
265-
ggml_cgraph * gf) const;
246+
llama_context & lctx,
247+
ggml_cgraph * gf) const;
266248

267249
llm_graph_result_ptr build_graph_defrag(
268-
const graph_params & params,
269-
ggml_cgraph * gf) const;
250+
llama_context & lctx,
251+
ggml_cgraph * gf) const;
270252

271253
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
272254
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -331,7 +313,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
331313
void restore() override;
332314
void commit() override;
333315

334-
bool update(const graph_params & params) override;
316+
bool update(llama_context & lctx) override;
335317

336318
void defrag_sched(float thold) override;
337319

0 commit comments

Comments
 (0)