 #include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-model.h"
+#include "llama-context.h"

 #include <algorithm>
 #include <cassert>
@@ -367,10 +368,10 @@ void llama_kv_cache_unified::commit() {
     pending.ranges.clear();
 }

-bool llama_kv_cache_unified::update(const graph_params & params) {
+bool llama_kv_cache_unified::update(llama_context & lctx) {
     bool need_reserve = false;

-    const auto & sched = params.sched;
+    const auto & sched = lctx.get_sched();

     if (has_shift) {
         if (!get_can_shift()) {
@@ -381,17 +382,17 @@ bool llama_kv_cache_unified::update(const graph_params & params) {

         // apply K-shift if needed
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(sched);
+            ggml_backend_sched_reset(sched.get());

-            auto * gf = params.graph_init();
+            auto * gf = lctx.graph_init();

-            auto res = build_graph_shift(params, gf);
+            auto res = build_graph_shift(lctx, gf);

-            ggml_backend_sched_alloc_graph(sched, gf);
+            ggml_backend_sched_alloc_graph(sched.get(), gf);

             res->set_inputs(nullptr);

-            params.graph_compute(gf);
+            lctx.graph_compute(gf, false);

             need_reserve = true;
         }
@@ -408,18 +409,18 @@ bool llama_kv_cache_unified::update(const graph_params & params) {
     if (do_defrag) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

-        if (defrag_prepare(params.n_max_nodes)) {
-            ggml_backend_sched_reset(sched);
+        if (defrag_prepare(lctx.graph_max_nodes())) {
+            ggml_backend_sched_reset(sched.get());

-            auto * gf = params.graph_init();
+            auto * gf = lctx.graph_init();

-            auto res = build_graph_defrag(params, gf);
+            auto res = build_graph_defrag(lctx, gf);

-            ggml_backend_sched_alloc_graph(sched, gf);
+            ggml_backend_sched_alloc_graph(sched.get(), gf);

             res->set_inputs(nullptr);

-            params.graph_compute(gf);
+            lctx.graph_compute(gf, false);

             need_reserve = true;
         }
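
The shift and defrag paths above now pull the scheduler from the context, build a fresh graph through lctx.graph_init(), and run it through lctx.graph_compute(), setting need_reserve whenever either operation was applied. Below is a minimal sketch of the calling side under the assumption that the context owns the cache in a kv_self member; the wrapper name kv_self_update() is illustrative, and only the kv_self->update(*this) call follows directly from the new signature.

// Hedged sketch of the calling side; the kv_self member and the
// kv_self_update() wrapper are assumptions, not part of this diff.
void llama_context::kv_self_update() {
    // update() applies any pending K-shift and/or defragmentation and
    // returns true when the caller should rebuild/re-reserve its graphs
    const bool need_reserve = kv_self->update(*this);

    if (need_reserve) {
        // re-reserve the worst-case compute graph here (details omitted)
    }
}
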
@@ -591,17 +592,17 @@ size_t llama_kv_cache_unified::size_v_bytes() const {
 }

 ggml_tensor * llama_kv_cache_unified::build_rope_shift(
-        const graph_params & params,
-        ggml_context * ctx,
-        ggml_tensor * cur,
-        ggml_tensor * shift,
-        ggml_tensor * factors,
-        float freq_base,
-        float freq_scale,
-        ggml_backend_buffer * bbuf) const {
-    const auto & cparams  = params.cparams;
-    const auto & backends = params.backends;
-    const auto & sched    = params.sched;
+        llama_context & lctx,
+        ggml_context * ctx,
+        ggml_tensor * cur,
+        ggml_tensor * shift,
+        ggml_tensor * factors,
+        float freq_base,
+        float freq_scale,
+        ggml_backend_buffer * bbuf) const {
+    const auto & cparams  = lctx.get_cparams();
+    const auto & backends = lctx.get_backends();
+    const auto & sched    = lctx.get_sched();

     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;

@@ -626,7 +627,7 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
         for (const auto & backend : backends) {
             // Figure out which backend KV cache belongs to
             if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
-                ggml_backend_sched_set_tensor_backend(sched, tmp, backend.get());
+                ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
                 break;
             }
         }
@@ -674,13 +675,13 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
 }

 llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
-        const graph_params & params,
-        ggml_cgraph * gf) const {
+        llama_context & lctx,
+        ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();

-    auto * ctx = params.get_ctx_compute();
+    auto * ctx = lctx.get_ctx_compute().get();

-    const auto & cparams = params.cparams;
+    const auto & cparams = lctx.get_cparams();

     const auto & n_layer = hparams.n_layer;

@@ -716,7 +717,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
                 ggml_row_size(k_l[il]->type, n_embd_k_gqa),
                 0);

-        ggml_tensor * cur = build_rope_shift(params, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(lctx, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, k_l[il]->buffer);

         ggml_build_forward_expand(gf, cur);
     }
@@ -727,15 +728,15 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 }

 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-        const graph_params & params,
-        ggml_cgraph * gf) const {
+        llama_context & lctx,
+        ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();

-    auto * ctx = params.get_ctx_compute();
+    auto * ctx = lctx.get_ctx_compute().get();

     const auto & ids = defrag_info.ids;

-    const auto & cparams = params.cparams;
+    const auto & cparams = lctx.get_cparams();

 #if 0
     // CPU defrag
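
Taken together, these hunks replace the old graph_params bundle with a small accessor surface on llama_context. The following sketch reconstructs that surface purely from the calls visible in this diff; the exact return types (smart-pointer wrappers, hence the .get() calls) and the container used for the backends are assumptions inferred from usage.

// Accessor surface implied by this diff; return types are assumptions.
struct llama_context {
    const llama_cparams & get_cparams() const;

    ggml_backend_sched_ptr & get_sched();       // used as sched.get()
    ggml_context_ptr       & get_ctx_compute(); // used as get_ctx_compute().get()

    std::vector<ggml_backend_ptr> & get_backends(); // iterated; each used via backend.get()

    int32_t graph_max_nodes() const;                // feeds defrag_prepare()

    ggml_cgraph * graph_init();                     // fresh graph for K-shift / defrag
    void          graph_compute(ggml_cgraph * gf, bool batched); // called with false here
};
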
@@ -1725,8 +1726,8 @@ void llama_kv_cache_recurrent::commit() {
     pending.ranges.clear();
 }

-bool llama_kv_cache_recurrent::update(const graph_params & params) {
-    GGML_UNUSED(params);
+bool llama_kv_cache_recurrent::update(llama_context & lctx) {
+    GGML_UNUSED(lctx);
     return false;
 }

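
Both cache implementations now share the same update(llama_context &) entry point, so the corresponding virtual declaration in the header (which is not part of this diff) presumably changes the same way. A sketch of the implied declarations follows; the base class name and the exact layout are assumptions.

// Implied interface change; the header is not shown in this diff, so the
// base class sketched here is an assumption.
struct llama_kv_cache {
    virtual ~llama_kv_cache() = default;

    // previously: virtual bool update(const graph_params & params) = 0;
    virtual bool update(llama_context & lctx) = 0;
};

struct llama_kv_cache_unified : llama_kv_cache {
    // applies pending K-shift / defragmentation; returns true when the
    // caller needs to re-reserve its compute graph
    bool update(llama_context & lctx) override;
};

struct llama_kv_cache_recurrent : llama_kv_cache {
    // currently a no-op that always returns false
    bool update(llama_context & lctx) override;
};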