diff --git a/common/arg.cpp b/common/arg.cpp
index 0f01bb31454a4..87b667ebc8135 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1498,6 +1498,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.flash_attn = true;
         }
     ).set_env("LLAMA_ARG_FLASH_ATTN"));
+    add_opt(common_arg(
+        {"-ffn", "--feed-forward-network"},
+        string_format("enable fused feed forward network (default: %s)", params.ffn ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.ffn = true;
+        }
+    ).set_env("LLAMA_ARG_FFN"));
     add_opt(common_arg(
         {"-p", "--prompt"}, "PROMPT",
         "prompt to start generation with; for system message, use -sys",
diff --git a/common/common.cpp b/common/common.cpp
index c6962d1d19b33..7df221245d470 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1170,6 +1170,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv = !params.no_kv_offload;
     cparams.flash_attn = params.flash_attn;
+    cparams.ffn = params.ffn;
     cparams.no_perf = params.no_perf;
     cparams.op_offload = !params.no_op_offload;
     cparams.swa_full = params.swa_full;
diff --git a/common/common.h b/common/common.h
index 5eab199af559e..e9e03eb7d676e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -347,6 +347,7 @@ struct common_params {
     bool simple_io = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
+    bool ffn = false; // fused feed forward network
     bool no_perf = false; // disable performance metrics
     bool ctx_shift = true; // context shift on inifinite text generation
     bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 2f06e1e39b225..86a3152249b25 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -543,6 +543,8 @@ extern "C" {
 
         GGML_OP_GLU,
 
+        GGML_OP_FFN,
+
         GGML_OP_COUNT,
     };
 
@@ -2097,6 +2099,21 @@ extern "C" {
             struct ggml_tensor * d,
             bool masked);
 
+    GGML_API struct ggml_tensor * ggml_ffn_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor * cur,
+            struct ggml_tensor * up,
+            struct ggml_tensor * up_b,
+            struct ggml_tensor * up_s,
+            struct ggml_tensor * gate,
+            struct ggml_tensor * gate_b,
+            struct ggml_tensor * gate_s,
+            struct ggml_tensor * down,
+            struct ggml_tensor * down_b,
+            struct ggml_tensor * down_s,
+            struct ggml_tensor * act_scales,
+            int type_gate);
+
     GGML_API struct ggml_tensor * ggml_ssm_conv(
             struct ggml_context * ctx,
             struct ggml_tensor * sx,
diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 07d6b8b67d47c..0b4563ddd0fea 100755
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -3397,3 +3397,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             GGML_ABORT("Function is not implemented.");
     }
 }
+
+void ggml_cann_ffn(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    // TODO: fused FFN is not implemented for CANN yet
+}
\ No newline at end of file
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index 5c510cc9932e8..0a9d011d5629f 100755
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -740,6 +740,8 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+void ggml_cann_ffn(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /*
  * @brief A generic wrapper for ACL resources with custom deleter support.
  */
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index cf575b367500a..7cfdfb787f309 100755
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1881,6 +1881,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_FLASH_ATTN_EXT:
             ggml_cann_flash_attn_ext(ctx, dst);
             break;
+        case GGML_OP_FFN:
+            ggml_cann_ffn(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2544,6 +2547,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             }
             return true;
         }
+        case GGML_OP_FFN:
+            return true;
         default:
             return false;
     }
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 55a76f8248c09..7263267dbf341 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1014,9 +1014,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 
     "GLU",
+    "FFN",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1115,9 +1116,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 
     "glu(x)",
+    "ffn(x)",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4958,6 +4960,40 @@ struct ggml_tensor * ggml_flash_attn_back(
     return result;
 }
 
+struct ggml_tensor * ggml_ffn_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor * cur,
+        struct ggml_tensor * up,
+        struct ggml_tensor * up_b,
+        struct ggml_tensor * up_s,
+        struct ggml_tensor * gate,
+        struct ggml_tensor * gate_b,
+        struct ggml_tensor * gate_s,
+        struct ggml_tensor * down,
+        struct ggml_tensor * down_b,
+        struct ggml_tensor * down_s,
+        struct ggml_tensor * act_scales,
+        int type_op) {
+    int64_t ne[] = {10, 10, 10, 10}; // TODO: placeholder shape, not the real FFN output shape
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne);
+
+    ggml_set_op_params_i32(result, 0, type_op);
+
+    result->op     = GGML_OP_FFN;
+    result->src[0] = up;
+    result->src[1] = up_b;
+    result->src[2] = up_s;
+    result->src[3] = gate;
+    result->src[4] = gate_b;
+    result->src[5] = gate_s;
+    result->src[6] = down;
+    result->src[7] = down_b;
+    result->src[8] = down_s;
+    result->src[9] = act_scales;
+
+    return result;
+}
+
 // ggml_ssm_conv
 
 struct ggml_tensor * ggml_ssm_conv(
diff --git a/include/llama.h b/include/llama.h
index 545e957e5f52b..5ee7e79e87893 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -332,6 +332,7 @@ extern "C" {
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // use flash attention [EXPERIMENTAL]
+        bool ffn;         // use fused ffn
         bool no_perf;     // measure performance timings
         bool op_offload;  // offload host tensor operations to device
         bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 26a5cf9c3f8db..c545b2e5b9f48 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -43,6 +43,7 @@ llama_context::llama_context(
     cparams.embeddings = params.embeddings;
     cparams.offload_kqv = params.offload_kqv;
     cparams.flash_attn = params.flash_attn;
+    cparams.ffn = params.ffn;
     cparams.no_perf = params.no_perf;
     cparams.pooling_type = params.pooling_type;
     cparams.warmup = false;
@@ -2265,6 +2266,7 @@ llama_context_params llama_context_default_params() {
         /*.embeddings  =*/ false,
         /*.offload_kqv =*/ true,
         /*.flash_attn  =*/ false,
+        /*.ffn         =*/ false,
         /*.no_perf     =*/ true,
         /*.op_offload  =*/ true,
         /*.swa_full    =*/ true,
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 38750affc500b..b19810a0f6083 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -34,6 +34,7 @@ struct llama_cparams {
     bool warmup;
     bool op_offload;
     bool kv_unified;
+    bool ffn;
 
     enum llama_pooling_type pooling_type;
 
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 053c72d6dc8d1..b8ac3543d4855 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -644,6 +644,12 @@ ggml_tensor * llm_graph_context::build_ffn(
         llm_ffn_op_type type_op,
         llm_ffn_gate_type type_gate,
         int il) const {
+
+    if (cparams.ffn) {
+        cur = ggml_ffn_ext(ctx0, cur, up, up_b, up_s, gate, gate_b,
+                           gate_s, down, down_b, down_s, act_scales, type_gate);
+        return cur;
+    }
     ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
     cb(tmp, "ffn_up", il);
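
For reference, a minimal C sketch (not part of the patch) of how a backend compute callback could unpack a GGML_OP_FFN node, assuming the src slot ordering and op_params layout that ggml_ffn_ext() establishes above; the helper name and the printf body are hypothetical.

    // Illustrative only: read back the operands of a GGML_OP_FFN node using the
    // src ordering set by ggml_ffn_ext(); a real backend would launch its fused
    // kernel here instead of printing.
    #include "ggml.h"
    #include <stdio.h>

    static void example_unpack_ffn(const struct ggml_tensor * dst) {
        const struct ggml_tensor * up         = dst->src[0];
        const struct ggml_tensor * up_b       = dst->src[1]; // may be NULL
        const struct ggml_tensor * up_s       = dst->src[2]; // may be NULL
        const struct ggml_tensor * gate       = dst->src[3]; // NULL for non-gated FFNs
        const struct ggml_tensor * gate_b     = dst->src[4]; // may be NULL
        const struct ggml_tensor * gate_s     = dst->src[5]; // may be NULL
        const struct ggml_tensor * down       = dst->src[6];
        const struct ggml_tensor * down_b     = dst->src[7]; // may be NULL
        const struct ggml_tensor * down_s     = dst->src[8]; // may be NULL
        const struct ggml_tensor * act_scales = dst->src[9]; // may be NULL

        // ggml_ffn_ext() stores its final integer argument in op_params[0]
        const int32_t type = dst->op_params[0];

        printf("FFN node: up=%s gate=%s down=%s type=%d\n",
               up   ? up->name   : "(none)",
               gate ? gate->name : "(none)",
               down ? down->name : "(none)",
               (int) type);

        (void) up_b; (void) up_s; (void) gate_b; (void) gate_s;
        (void) down_b; (void) down_s; (void) act_scales;
    }

With the patch applied, the fused path is opted into with -ffn / --feed-forward-network (or LLAMA_ARG_FFN), which sets cparams.ffn and makes build_ffn() take the ggml_ffn_ext() branch instead of the existing unfused graph.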