Commit 826d80b

fuse repeat_interleave and post attention residual onto other operators
1 parent b7d2834 commit 826d80b


applications/llama_3.2_1b/llama_npu.py

Lines changed: 19 additions & 22 deletions
@@ -201,26 +201,33 @@ def __init__(self, config, prompt_len):
             context=elf_ctx
         )
 
+        repeat_interleave_op = AIERepeat(
+            rows=config.n_kv_groups,
+            cols=prompt_len * config.head_dim,  # Max context length
+            repeat=config.n_heads // config.n_kv_groups,
+            transfer_size=config.head_dim,
+            context=self.context
+        )
+
         self.decode.post_attn_fused_op = FusedMLIROperator(
             "post_attn_decode",
             [
-                (rms_norm_op, "x_pre_norm", "W_norm2", "x_norm"),
+                (residual_add_op, "x", "attn_output", "x"),
+                (rms_norm_op, "x", "W_norm2", "x_norm"),
                 (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "ffn_gate"),
                 (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "ffn_up"),
                 (silu_ffn_op, "ffn_gate", "ffn_gate"),
                 (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"),
                 (gemv_ffn_down_op, "W_ffn_down", "ffn_hidden", "ffn_output"),
-                (residual_add_op, "x_pre_norm", "ffn_output", "x_out"),
+                (residual_add_op, "x", "ffn_output", "x"),
             ],
             input_args=[
-                "x_pre_norm",
                 "W_norm2",
                 "W_ffn_gate",
                 "W_ffn_up",
                 "W_ffn_down"
             ],
             output_args=[
-                "x_out"
             ],
             context=elf_ctx
         ).compile()
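
For reference, here is a rough PyTorch sketch of the computation the fused post-attention pipeline now covers: the attention residual add, RMSNorm, the SwiGLU FFN, and the second residual. This is illustrative only, not code from the repo; weights are assumed to be stored as (out_features, in_features), and eps is a typical RMSNorm epsilon rather than a value taken from this file.

    import torch
    import torch.nn.functional as F

    def post_attn_reference(x, attn_output, W_norm2, W_gate, W_up, W_down, eps=1e-5):
        # (residual_add_op, "x", "attn_output", "x")
        x = x + attn_output
        # (rms_norm_op, "x", "W_norm2", "x_norm")
        x_norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * W_norm2
        # gemv_ffn_up_gate_op on the gate and up projections
        ffn_gate = x_norm @ W_gate.T
        ffn_up = x_norm @ W_up.T
        # silu_ffn_op followed by eltwise_mul_ffn_op
        ffn_hidden = F.silu(ffn_gate) * ffn_up
        # gemv_ffn_down_op
        ffn_output = ffn_hidden @ W_down.T
        # (residual_add_op, "x", "ffn_output", "x")
        return x + ffn_output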
@@ -287,6 +294,8 @@ def __init__(self, config, prompt_len):
                 (rope_keys_op, "keys", "rope_angles", "keys"),
                 (strided_copy_cache_op, "keys", "keys_cache"),
                 (strided_copy_cache_op, "values", "values_cache"),
+                (repeat_interleave_op, "keys_cache", "attn_scores_keys"),
+                (repeat_interleave_op, "values_cache", "attn_scores_values"),
             ],
             input_args=[
                 "W_attn_query",
@@ -304,7 +313,7 @@ def __init__(self, config, prompt_len):
             ],
             buffer_sizes={
                 "keys_cache": cache_buffer_size,
-                "values_cache": cache_buffer_size
+                "values_cache": cache_buffer_size,
             },
             context=elf_ctx
         ).compile()
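
The two repeat_interleave_op stages expand the grouped KV cache so that every query head sees the keys and values of its KV group. A minimal PyTorch sketch of the intended semantics (names and the in-memory layout are illustrative; the on-NPU operator streams the flattened cache buffer in transfer_size=head_dim chunks):

    import torch

    def repeat_kv(kv_cache, n_heads, n_kv_groups):
        # kv_cache: (n_kv_groups, context_len, head_dim)
        # -> (n_heads, context_len, head_dim); each group is duplicated
        #    n_heads // n_kv_groups times, once per query head in the group
        return torch.repeat_interleave(kv_cache, n_heads // n_kv_groups, dim=0)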
@@ -445,16 +454,6 @@ def get_patch_locs(elf_data, magic):
             context=self.context
         ).compile().get_callable()
 
-        # Repeat interleave for keys: (n_kv_groups, context_len, head_dim) -> (n_heads, context_len, head_dim)
-        # Compile with max context length, then patch at runtime for actual context_len
-        self.decode.attn_repeat_interleave = AIERepeat(
-            rows=config.n_kv_groups,
-            cols=prompt_len * config.head_dim,  # Max context length
-            repeat=config.n_heads // config.n_kv_groups,
-            transfer_size=config.head_dim,
-            context=self.context
-        ).compile().get_callable()
-
         # Attention projection operators
         # Query projection: (seq_len, emb_dim) -> (seq_len, n_heads * head_dim)
         self.prefill.attn_query = AIEGEMM(
@@ -1067,22 +1066,22 @@ def llama_forward_pass_decode(config, state):
 def transformer_block_forward_decode(config, num_preceding_tokens, layer_idx):
     aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm1[layer_idx], aie_buffers.decode.x_norm)  # Step 1: RMS normalization
     grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx)  # Step 2: Attention; results stored in attn_output
-    aie_ops.decode.residual_add(aie_buffers.decode.x, aie_buffers.decode.attn_output, aie_buffers.decode.x)  # Step 3: Residual
 
     # Step 4-6: Fused post-norm + SwiGLU + residual
     fused_op = aie_ops.decode.post_attn_fused
     fused_op.input_buffer.view_as_torch().to("cpu")[:] = 0
     fused_op.output_buffer.view_as_torch().to("cpu")[:] = 0
     fused_op.scratch_buffer.view_as_torch().to("cpu")[:] = 0
-    fused_op.get_buffer("x_pre_norm").to("cpu").view_as_torch()[:] = aie_buffers.decode.x.to("cpu").view_as_torch().flatten()
+    fused_op.get_buffer("x").to("cpu").view_as_torch()[:] = aie_buffers.decode.x.to("cpu").view_as_torch().flatten()
+    fused_op.get_buffer("attn_output").to("cpu").view_as_torch()[:] = aie_buffers.decode.attn_output.to("cpu").view_as_torch().flatten()
     fused_op.get_buffer("W_norm2").to("cpu").view_as_torch()[:] = aie_buffers.W_norm2[layer_idx].to("cpu").view_as_torch().flatten()
     fused_op.get_buffer("W_ffn_gate").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_gate_decode[layer_idx].to("cpu").view_as_torch().flatten()
     fused_op.get_buffer("W_ffn_up").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_up_decode[layer_idx].to("cpu").view_as_torch().flatten()
     fused_op.get_buffer("W_ffn_down").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_down_decode[layer_idx].to("cpu").view_as_torch().flatten()
 
     fused_op()
 
-    aie_buffers.decode.x.to("cpu").view_as_torch()[:] = fused_op.get_buffer("x_out").to("cpu").view_as_torch()[:]
+    aie_buffers.decode.x.to("cpu").view_as_torch()[:] = fused_op.get_buffer("x").to("cpu").view_as_torch()[:]
 
 
 def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx):
@@ -1109,14 +1108,12 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx):
     aie_buffers.decode.values.to("cpu").view_as_torch().view(-1)[:] = fused_op.get_buffer("values").to("cpu").view_as_torch().flatten()
     aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("keys_cache").to("cpu").view_as_torch().flatten()
     aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("values_cache").to("cpu").view_as_torch().flatten()
+    aie_buffers.decode.attn_scores_keys.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores_keys").to("cpu").view_as_torch().flatten()
+    aie_buffers.decode.attn_scores_values.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores_values").to("cpu").view_as_torch().flatten()
     aie_buffers.decode.queries.to("npu")
     aie_buffers.decode.keys.to("npu")
     aie_buffers.decode.values.to("npu")
 
-    # Step 4: Repeat keys and values for grouped attention using AIERepeat on NPU
-    aie_ops.decode.attn_repeat_interleave(aie_buffers.keys_cache[layer_idx], aie_buffers.decode.attn_scores_keys)
-    aie_ops.decode.attn_repeat_interleave(aie_buffers.values_cache[layer_idx], aie_buffers.decode.attn_scores_values)
-
     # Step 5: Compute attention scores
     # Copy repeated keys from keys_repeated buffer to attn_scores_keys for GEMV
     aie_ops.decode.gemv_attn_scores(aie_buffers.decode.attn_scores_keys, aie_buffers.decode.queries, aie_buffers.decode.attn_scores)
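
The Step 5 GEMV consumes the repeated keys that the fused attention operator now writes into attn_scores_keys. A small runnable shape check of that grouped-query score computation, with example dimensions chosen for illustration (not read from this repo's config):

    import torch

    n_heads, n_kv_groups, head_dim, context_len = 32, 8, 64, 128
    keys_cache = torch.randn(n_kv_groups, context_len, head_dim)   # per-group key cache
    queries = torch.randn(n_heads, head_dim)                       # one decode token, all heads
    # Same expansion the fused op performs before the score GEMV
    attn_scores_keys = torch.repeat_interleave(keys_cache, n_heads // n_kv_groups, dim=0)
    # Per-head dot products against the cached keys -> (n_heads, context_len)
    attn_scores = torch.einsum("hcd,hd->hc", attn_scores_keys, queries)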
