Skip to content

CANN: Add fused FFN op #15209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.flash_attn = true;
}
).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(common_arg(
{"-ffn", "--feed-forward-network"},
string_format("enable fused feed froward network (default: %s)", params.ffn ? "enabled" : "disabled"),
[](common_params & params) {
params.ffn = true;
}
).set_env("LLAMA_ARG_FFN"));
add_opt(common_arg(
{"-p", "--prompt"}, "PROMPT",
"prompt to start generation with; for system message, use -sys",
Expand Down
1 change: 1 addition & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.ffn = params.ffn;
cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ struct common_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
bool ffn = false; // fused feed forward network
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on inifinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
Expand Down
17 changes: 17 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,8 @@ extern "C" {

GGML_OP_GLU,

GGML_OP_FFN,

GGML_OP_COUNT,
};

Expand Down Expand Up @@ -2097,6 +2099,21 @@ extern "C" {
struct ggml_tensor * d,
bool masked);

GGML_API struct ggml_tensor * ggml_ffn_ext(
struct ggml_context * ctx,
struct ggml_tensor * cur,
struct ggml_tensor * up,
struct ggml_tensor * up_b,
struct ggml_tensor * up_s,
struct ggml_tensor * gate,
struct ggml_tensor * gate_b,
struct ggml_tensor * gate_s,
struct ggml_tensor * down,
struct ggml_tensor * down_b,
struct ggml_tensor * down_s,
struct ggml_tensor * act_scales,
int type_gate);

GGML_API struct ggml_tensor * ggml_ssm_conv(
struct ggml_context * ctx,
struct ggml_tensor * sx,
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3397,3 +3397,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
GGML_ABORT("Function is not implemented.");
}
}

void ggml_cann_ffn(ggml_backend_cann_context& ctx, ggml_tensor* dst) {

}
2 changes: 2 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,8 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
*/
void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);

void ggml_cann_ffn(ggml_backend_cann_context& ctx, ggml_tensor* dst);

/*
* @brief A generic wrapper for ACL resources with custom deleter support.
*/
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1881,6 +1881,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
case GGML_OP_FLASH_ATTN_EXT:
ggml_cann_flash_attn_ext(ctx, dst);
break;
case GGML_OP_FFN:
ggml_cann_ffn(ctx, dst);
default:
return false;
}
Expand Down Expand Up @@ -2544,6 +2546,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
}
return true;
}
case GGML_OP_FFN:
return true;
default:
return false;
}
Expand Down
40 changes: 38 additions & 2 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1014,9 +1014,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW",

"GLU",
"FFN",
};

static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
Expand Down Expand Up @@ -1115,9 +1116,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)",

"glu(x)",
"ffn(x)",
};

static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

Expand Down Expand Up @@ -4958,6 +4960,40 @@ struct ggml_tensor * ggml_flash_attn_back(
return result;
}

struct ggml_tensor * ggml_ffn_ext(
struct ggml_context * ctx,
struct ggml_tensor * cur,
struct ggml_tensor * up,
struct ggml_tensor * up_b,
struct ggml_tensor * up_s,
struct ggml_tensor * gate,
struct ggml_tensor * gate_b,
struct ggml_tensor * gate_s,
struct ggml_tensor * down,
struct ggml_tensor * down_b,
struct ggml_tensor * down_s,
struct ggml_tensor * act_scales,
int type_op) {
int64_t ne[] = {10, 10, 10, 10};
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne);

ggml_set_op_params_i32(result, 0, type_op);

result->op = GGML_OP_FFN;
result->src[0] = up;
result->src[1] = up_b;
result->src[2] = up_s;
result->src[3] = gate;
result->src[4] = gate_b;
result->src[5] = gate_s;
result->src[6] = down;
result->src[7] = down_b;
result->src[8] = down_s;
result->src[9] = act_scales;

return result;
}

// ggml_ssm_conv

struct ggml_tensor * ggml_ssm_conv(
Expand Down
1 change: 1 addition & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ extern "C" {
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // use flash attention [EXPERIMENTAL]
bool ffn; // use fused ffn
bool no_perf; // measure performance timings
bool op_offload; // offload host tensor operations to device
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
Expand Down
2 changes: 2 additions & 0 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ llama_context::llama_context(
cparams.embeddings = params.embeddings;
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
cparams.ffn = params.ffn;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;
Expand Down Expand Up @@ -2265,6 +2266,7 @@ llama_context_params llama_context_default_params() {
/*.embeddings =*/ false,
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
/*.ffn =*/ false,
/*.no_perf =*/ true,
/*.op_offload =*/ true,
/*.swa_full =*/ true,
Expand Down
1 change: 1 addition & 0 deletions src/llama-cparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ struct llama_cparams {
bool warmup;
bool op_offload;
bool kv_unified;
bool ffn;

enum llama_pooling_type pooling_type;

Expand Down
6 changes: 6 additions & 0 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,12 @@ ggml_tensor * llm_graph_context::build_ffn(
llm_ffn_op_type type_op,
llm_ffn_gate_type type_gate,
int il) const {

if (cparams.ffn) {
cur = ggml_ffn_ext(ctx0, cur, up, up_b, up_s, gate, gate_b,
gate_s, down, down_b, down_s, act_scales, type_gate);
return cur;
}
ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
cb(tmp, "ffn_up", il);

Expand Down
Loading