
Commit 20010c4

graph : refactor context to not pass gf explicitly
ggml-ci
1 parent 01612b7 commit 20010c4

5 files changed, +292 -338 lines changed

src/llama-context.cpp

Lines changed: 2 additions & 2 deletions
@@ -694,7 +694,7 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }

-llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;

@@ -1363,7 +1363,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }

 llm_graph_params llama_context::graph_params(
-        llm_graph_result_i * res,
+        llm_graph_result * res,
         const llama_ubatch & ubatch,
         const llama_memory_context_i * mctx,
         llm_graph_type gtype) const {

src/llama-context.h

Lines changed: 2 additions & 2 deletions
@@ -94,7 +94,7 @@ struct llama_context {
     // if memory_context is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
-    llm_graph_result_i * process_ubatch(
+    llm_graph_result * process_ubatch(
             const llama_ubatch & ubatch,
             llm_graph_type gtype,
             llama_memory_context_i * mctx,

@@ -199,7 +199,7 @@ struct llama_context {

 private:
     llm_graph_params graph_params(
-            llm_graph_result_i * res,
+            llm_graph_result * res,
             const llama_ubatch & ubatch,
             const llama_memory_context_i * mctx,
             llm_graph_type gtype) const;
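
Taken together, the llama-context changes mean process_ubatch() and graph_params() now traffic in the concrete llm_graph_result instead of the llm_graph_result_i interface, so callers can read the result without a downcast. A minimal standalone sketch of that pattern, using hypothetical stand-in types (result, context) rather than the real llama.cpp classes:

#include <cstdio>
#include <memory>

// hypothetical stand-ins for llm_graph_result / llama_context, only to show the shape of the change
struct result {
    int logits = 42;                       // stands in for get_logits()
};

struct context {
    std::unique_ptr<result> res = std::make_unique<result>();

    // before this commit the equivalent of process_ubatch() returned an abstract
    // interface pointer, forcing a static_cast back to the concrete type;
    // now the concrete type is returned directly
    result * process(bool & ok) {
        ok = true;
        return res.get();
    }
};

int main() {
    context ctx;
    bool ok = false;
    result * r = ctx.process(ok);
    if (ok && r) {
        std::printf("logits: %d\n", r->logits);   // no cast needed at the call site
    }
    return 0;
}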

src/llama-graph.cpp

Lines changed: 15 additions & 19 deletions
@@ -484,6 +484,10 @@ llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
     return inputs.back().get();
 }

+void llm_graph_result::set_params(const llm_graph_params & params) {
+    this->params = params;
+}
+
 //
 // llm_graph_context
 //

@@ -525,9 +529,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     mctx      (params.mctx),
     cross     (params.cross),
     cb_func   (params.cb),
-    res       (static_cast<llm_graph_result *>(params.res)),
-    ctx0      (res->get_ctx()) {
-    res->params = params;
+    res       (params.res),
+    ctx0      (res->get_ctx()),
+    gf        (res->get_gf()) {
+    res->set_params(params);
 }

 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
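
The constructor hunk above is the core of the refactor: llm_graph_context now takes params.res directly (no static_cast), caches gf = res->get_gf() next to ctx0, and registers the parameters through the new set_params() instead of assigning res->params. A rough standalone sketch of that initialization pattern, with hypothetical stand-in types rather than the real ggml/llama.cpp ones:

#include <cstdio>

// hypothetical stand-ins for llm_graph_params / llm_graph_result / the ggml types
struct params   { int n_tokens = 0; };
struct gcontext { };                       // stands in for ggml_context
struct cgraph   { int n_nodes = 0; };      // stands in for ggml_cgraph

struct result {
    gcontext ctx;
    cgraph   graph;
    params   prev;

    gcontext * get_ctx() { return &ctx; }
    cgraph   * get_gf()  { return &graph; }
    void set_params(const params & p) { prev = p; }   // replaces the old `res->params = params`
};

struct graph_context {
    result   * res;
    gcontext * ctx0;
    cgraph   * gf;     // the new member: cached once, reused by every build_* helper

    graph_context(const params & p, result * r)
        : res (r),
          ctx0(r->get_ctx()),
          gf  (r->get_gf()) {               // mirrors `gf (res->get_gf())` in the initializer list
        res->set_params(p);
    }
};

int main() {
    result r;
    graph_context gc(params{8}, &r);
    std::printf("gf cached: %d\n", gc.gf == &r.graph);   // 1
    return 0;
}

The real initializer list copies many more fields out of params; only the members touched by this hunk are mirrored here.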
@@ -1117,7 +1122,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }

 ggml_tensor * llm_graph_context::build_attn_mha(
-        ggml_cgraph * gf,
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,

@@ -1251,7 +1255,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con

 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_no_cache * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,

@@ -1279,7 +1282,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {

@@ -1335,7 +1338,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()

 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,

@@ -1368,7 +1370,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {

@@ -1388,7 +1390,6 @@ ggml_tensor * llm_graph_context::build_attn(

 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,

@@ -1435,7 +1436,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {

@@ -1468,7 +1469,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {

 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_cross * inp,
-        ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,

@@ -1490,7 +1490,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {

@@ -1548,7 +1548,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 }

 ggml_tensor * llm_graph_context::build_rs(
-        ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
         int32_t state_size,

@@ -1606,21 +1605,19 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {

 ggml_tensor * llm_graph_context::build_rs(
         llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
         ggml_tensor * s,
         int32_t state_size,
         int32_t n_seqs,
         const llm_graph_get_rows_fn & get_state_rows) const {
     const auto * kv_state = inp->mctx;

-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
 }

 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         llm_graph_input_rs * inp,
-        ggml_cgraph * gf,
         const llama_ubatch & ubatch,
-                 int il) const {
+        int il) const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

     const auto token_shift_count = hparams.token_shift_count;

@@ -1630,7 +1627,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
     ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);

     ggml_tensor * token_shift = build_rs(
-            inp, gf, token_shift_all,
+            inp, token_shift_all,
             hparams.n_embd_r(), n_seqs);

     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

@@ -1670,7 +1667,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 }

 void llm_graph_context::build_pooling(
-        ggml_cgraph * gf,
         ggml_tensor * cls,
         ggml_tensor * cls_b,
         ggml_tensor * cls_out,
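
The remaining hunks in this file are the mechanical fallout: every build_attn/build_rs/build_pooling overload drops its ggml_cgraph * gf parameter, and the internal calls (e.g. to build_attn_mha) pick up the cached member instead. From the perspective of code that calls these helpers, the change looks roughly like the following standalone sketch (hypothetical stand-ins, not the real API):

#include <cstdio>
#include <vector>

// hypothetical stand-ins for ggml_cgraph and the graph builder; real llama.cpp names differ
struct cgraph { std::vector<const char *> nodes; };

struct builder {
    cgraph * gf;                            // set once, like llm_graph_context::gf

    explicit builder(cgraph * g) : gf(g) {}

    // before: build_attn(inp, gf, wo, wo_b, q, k, v, ...)
    // after:  build_attn(inp,     wo, wo_b, q, k, v, ...)
    // every helper expands into the same member graph, so there is no gf argument
    // to thread through (or accidentally mismatch) at each call site
    void build_attn()    { gf->nodes.push_back("attn"); }
    void build_pooling() { gf->nodes.push_back("pooling"); }
};

int main() {
    cgraph g;
    builder b(&g);
    b.build_attn();
    b.build_pooling();
    std::printf("%zu nodes\n", g.nodes.size());   // 2
    return 0;
}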

src/llama-graph.h

Lines changed: 20 additions & 44 deletions
@@ -371,31 +371,11 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
 // these are used by the llama_context to extact the relevant data, based on the compute parameters

-// TODO: this interface seems redundant - remove it
-class llm_graph_result_i {
-public:
-    virtual ~llm_graph_result_i() = default;
-
-    virtual ggml_tensor * get_tokens() const = 0;
-    virtual ggml_tensor * get_logits() const = 0;
-    virtual ggml_tensor * get_embd() const = 0;
-    virtual ggml_tensor * get_embd_pooled() const = 0;
-
-    virtual ggml_cgraph * get_gf() = 0;
-    virtual ggml_context * get_ctx() = 0;
-
-    virtual void reset() = 0;
-
-    virtual void set_inputs(const llama_ubatch * ubatch) = 0;
-
-    virtual bool can_reuse(const llm_graph_params & params) = 0;
-};
-
-using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
-
 // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
 using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;

+class llm_graph_result;
+
 struct llm_graph_params {
     llm_arch arch = LLM_ARCH_UNKNOWN;


@@ -418,8 +398,7 @@ struct llm_graph_params {

     llm_graph_cb cb;

-    // TODO: temporary
-    llm_graph_result_i * res;
+    llm_graph_result * res;

     // return true if the "other" params would result in a graph with the same topology as with the current params
     // having the same topology allows us to reuse the graph in some cases

@@ -464,35 +443,37 @@ struct llm_graph_params {
     }
 };

-class llm_graph_result : public llm_graph_result_i {
+class llm_graph_result {
 public:
     llm_graph_result(int64_t max_nodes);

     virtual ~llm_graph_result() = default;

-    ggml_tensor * get_tokens() const override { return t_tokens; }
-    ggml_tensor * get_logits() const override { return t_logits; }
-    ggml_tensor * get_embd() const override { return t_embd; }
-    ggml_tensor * get_embd_pooled() const override { return t_embd_pooled; }
+    ggml_tensor * get_tokens() const { return t_tokens; }
+    ggml_tensor * get_logits() const { return t_logits; }
+    ggml_tensor * get_embd() const { return t_embd; }
+    ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }

-    ggml_cgraph * get_gf() override { return gf; }
-    ggml_context * get_ctx() override { return ctx_compute.get(); }
+    ggml_cgraph * get_gf() const { return gf; }
+    ggml_context * get_ctx() const { return ctx_compute.get(); }

     int64_t get_max_nodes() const;

-    void reset() override;
+    void reset();

-    void set_inputs(const llama_ubatch * ubatch) override;
+    void set_inputs(const llama_ubatch * ubatch);

     // try to update the existing graph result using the new graph parameters in order to reuse it
     // this can only be done if we determine that the resulting graph using the new graph parameters
     // would be identical to the existing graph. in that case, we simply have to update the memory
     // contexts of the input tensors of the graph and we can reuse it for another computation
     // return true if the graph was updated and can be reused
-    bool can_reuse(const llm_graph_params & params) override;
+    bool can_reuse(const llm_graph_params & params);

     llm_graph_input_i * add_input(llm_graph_input_ptr input);

+    void set_params(const llm_graph_params & params);
+
     // important graph nodes
     ggml_tensor * t_tokens = nullptr;
     ggml_tensor * t_logits = nullptr;

@@ -510,6 +491,7 @@ class llm_graph_result : public llm_graph_result_i {

     int64_t max_nodes;

+private:
     // keep a copy of the previous graph parameters
     // we will use this to determine whether the graph can be reused by comparing them with the new parameters
     // note: these are updated after constructing the new graph

@@ -519,6 +501,8 @@ class llm_graph_result : public llm_graph_result_i {
     int debug = 0;
 };

+using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
+
 //
 // llm_graph_context
 //
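
With llm_graph_result_i gone, llm_graph_result is a plain concrete class: the override specifiers disappear, the cached previous-params copy moves behind a private set_params()/can_reuse() pair, and llm_graph_result_ptr now aliases the concrete type. A small standalone sketch of that shape (hypothetical stand-ins; the real can_reuse() also updates the graph's input tensors, which is omitted here):

#include <cstdio>
#include <memory>

// hypothetical stand-ins for llm_graph_params / llm_graph_result
struct params { int n_tokens = 0; };

class result {                                         // no base interface, no virtuals
public:
    // the cached copy used for reuse checks is written through a setter now
    void set_params(const params & p) { prev = p; }

    // simplified: reuse only if the new params would give the same graph topology
    bool can_reuse(const params & p) const { return p.n_tokens == prev.n_tokens; }

private:
    params prev;                                       // kept private, as in this commit
};

using result_ptr = std::unique_ptr<result>;            // same shape as llm_graph_result_ptr

int main() {
    result_ptr res = std::make_unique<result>();
    res->set_params(params{32});
    std::printf("reuse(32): %d\n", res->can_reuse(params{32}));   // 1
    std::printf("reuse(64): %d\n", res->can_reuse(params{64}));   // 0
    return 0;
}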
@@ -576,6 +560,7 @@ struct llm_graph_context {
     llm_graph_result * res;

     ggml_context * ctx0 = nullptr;
+    ggml_cgraph * gf = nullptr;

     llm_graph_context(const llm_graph_params & params);
     virtual ~llm_graph_context() = default;

@@ -661,7 +646,6 @@ struct llm_graph_context {
     //

     ggml_tensor * build_attn_mha(
-            ggml_cgraph * gf,
             ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)

@@ -674,7 +658,6 @@ struct llm_graph_context {

     ggml_tensor * build_attn(
             llm_graph_input_attn_no_cache * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]

@@ -689,7 +672,6 @@ struct llm_graph_context {

     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]

@@ -705,7 +687,6 @@ struct llm_graph_context {
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]

@@ -720,7 +701,6 @@ struct llm_graph_context {

     ggml_tensor * build_attn(
             llm_graph_input_attn_cross * inp,
-            ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]

@@ -742,7 +722,6 @@ struct llm_graph_context {
     // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     // `llama_memory_recurrent`
     ggml_tensor * build_rs(
-            ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
             int32_t state_size,

@@ -757,17 +736,15 @@ struct llm_graph_context {

     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             ggml_tensor * s,
             int32_t state_size,
             int32_t n_seqs,
             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

     ggml_tensor * build_rwkv_token_shift_load(
             llm_graph_input_rs * inp,
-            ggml_cgraph * gf,
             const llama_ubatch & ubatch,
-                     int il) const;
+            int il) const;

     ggml_tensor * build_rwkv_token_shift_store(
             ggml_tensor * token_shift,

@@ -784,7 +761,6 @@ struct llm_graph_context {
     //

     void build_pooling(
-            ggml_cgraph * gf,
             ggml_tensor * cls,
             ggml_tensor * cls_b,
             ggml_tensor * cls_out,
