@@ -201,26 +201,33 @@ def __init__(self, config, prompt_len):
             context=elf_ctx
         )
 
+        repeat_interleave_op = AIERepeat(
+            rows=config.n_kv_groups,
+            cols=prompt_len * config.head_dim,  # Max context length
+            repeat=config.n_heads // config.n_kv_groups,
+            transfer_size=config.head_dim,
+            context=self.context
+        )
+
         self.decode.post_attn_fused_op = FusedMLIROperator(
             "post_attn_decode",
             [
-                (rms_norm_op, "x_pre_norm", "W_norm2", "x_norm"),
+                (residual_add_op, "x", "attn_output", "x"),
+                (rms_norm_op, "x", "W_norm2", "x_norm"),
                 (gemv_ffn_up_gate_op, "W_ffn_gate", "x_norm", "ffn_gate"),
                 (gemv_ffn_up_gate_op, "W_ffn_up", "x_norm", "ffn_up"),
                 (silu_ffn_op, "ffn_gate", "ffn_gate"),
                 (eltwise_mul_ffn_op, "ffn_gate", "ffn_up", "ffn_hidden"),
                 (gemv_ffn_down_op, "W_ffn_down", "ffn_hidden", "ffn_output"),
-                (residual_add_op, "x_pre_norm", "ffn_output", "x_out"),
+                (residual_add_op, "x", "ffn_output", "x"),
             ],
             input_args=[
-                "x_pre_norm",
                 "W_norm2",
                 "W_ffn_gate",
                 "W_ffn_up",
                 "W_ffn_down"
             ],
             output_args=[
-                "x_out"
             ],
             context=elf_ctx
         ).compile()
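For reference, the operator list above encodes the following per-token dataflow. This is a minimal PyTorch sketch, not the NPU implementation: it assumes weights are stored row-major and applied as x @ W.T, and the RMS-norm epsilon is an assumed value, not taken from this repo.

import torch
import torch.nn.functional as F

def post_attn_decode_reference(x, attn_output, W_norm2, W_ffn_gate, W_ffn_up, W_ffn_down, eps=1e-5):
    # (residual_add_op, "x", "attn_output", "x"): fold the attention output into the residual stream first
    x = x + attn_output
    # (rms_norm_op, "x", "W_norm2", "x_norm")
    x_norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * W_norm2
    # SwiGLU: gate/up projections, SiLU on the gate, elementwise product, down projection
    ffn_gate = F.silu(x_norm @ W_ffn_gate.T)
    ffn_up = x_norm @ W_ffn_up.T
    ffn_output = (ffn_gate * ffn_up) @ W_ffn_down.T
    # (residual_add_op, "x", "ffn_output", "x"): second residual, written back into "x"
    return x + ffn_output

Folding the first residual add into the fused graph matches the removal of the standalone residual_add call from transformer_block_forward_decode later in this diff: "x" now enters as the pre-attention residual stream and leaves as the block output.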
@@ -287,6 +294,8 @@ def __init__(self, config, prompt_len):
                 (rope_keys_op, "keys", "rope_angles", "keys"),
                 (strided_copy_cache_op, "keys", "keys_cache"),
                 (strided_copy_cache_op, "values", "values_cache"),
+                (repeat_interleave_op, "keys_cache", "attn_scores_keys"),
+                (repeat_interleave_op, "values_cache", "attn_scores_values"),
             ],
             input_args=[
                 "W_attn_query",
@@ -304,7 +313,7 @@ def __init__(self, config, prompt_len):
             ],
             buffer_sizes={
                 "keys_cache": cache_buffer_size,
-                "values_cache": cache_buffer_size
+                "values_cache": cache_buffer_size,
             },
             context=elf_ctx
         ).compile()
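The two repeat_interleave_op stages added above expand the cached keys and values from n_kv_groups rows to n_heads rows so that every query head has matching data for grouped-query attention. A minimal functional sketch in PyTorch, assuming a row-major (n_kv_groups, context_len, head_dim) cache layout; how AIERepeat schedules the head_dim-sized transfers on the NPU is an implementation detail not modeled here.

import torch

def repeat_interleave_kv(kv_cache, n_kv_groups, context_len, head_dim, n_heads):
    # (n_kv_groups, context_len, head_dim) -> (n_heads, context_len, head_dim):
    # each KV group is duplicated for the n_heads // n_kv_groups query heads it serves.
    kv = kv_cache.view(n_kv_groups, context_len, head_dim)
    return kv.repeat_interleave(n_heads // n_kv_groups, dim=0)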
@@ -445,16 +454,6 @@ def get_patch_locs(elf_data, magic):
             context=self.context
         ).compile().get_callable()
 
-        # Repeat interleave for keys: (n_kv_groups, context_len, head_dim) -> (n_heads, context_len, head_dim)
-        # Compile with max context length, then patch at runtime for actual context_len
-        self.decode.attn_repeat_interleave = AIERepeat(
-            rows=config.n_kv_groups,
-            cols=prompt_len * config.head_dim,  # Max context length
-            repeat=config.n_heads // config.n_kv_groups,
-            transfer_size=config.head_dim,
-            context=self.context
-        ).compile().get_callable()
-
         # Attention projection operators
         # Query projection: (seq_len, emb_dim) -> (seq_len, n_heads * head_dim)
         self.prefill.attn_query = AIEGEMM(
@@ -1067,22 +1066,22 @@ def llama_forward_pass_decode(config, state):
 def transformer_block_forward_decode(config, num_preceding_tokens, layer_idx):
     aie_ops.decode.rms_norm(aie_buffers.decode.x, aie_buffers.W_norm1[layer_idx], aie_buffers.decode.x_norm)  # Step 1: RMS normalization
     grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx)  # Step 2: Attention; results stored in attn_output
-    aie_ops.decode.residual_add(aie_buffers.decode.x, aie_buffers.decode.attn_output, aie_buffers.decode.x)  # Step 3: Residual
 
     # Step 4-6: Fused post-norm + SwiGLU + residual
     fused_op = aie_ops.decode.post_attn_fused
     fused_op.input_buffer.view_as_torch().to("cpu")[:] = 0
     fused_op.output_buffer.view_as_torch().to("cpu")[:] = 0
     fused_op.scratch_buffer.view_as_torch().to("cpu")[:] = 0
-    fused_op.get_buffer("x_pre_norm").to("cpu").view_as_torch()[:] = aie_buffers.decode.x.to("cpu").view_as_torch().flatten()
+    fused_op.get_buffer("x").to("cpu").view_as_torch()[:] = aie_buffers.decode.x.to("cpu").view_as_torch().flatten()
+    fused_op.get_buffer("attn_output").to("cpu").view_as_torch()[:] = aie_buffers.decode.attn_output.to("cpu").view_as_torch().flatten()
     fused_op.get_buffer("W_norm2").to("cpu").view_as_torch()[:] = aie_buffers.W_norm2[layer_idx].to("cpu").view_as_torch().flatten()
     fused_op.get_buffer("W_ffn_gate").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_gate_decode[layer_idx].to("cpu").view_as_torch().flatten()
     fused_op.get_buffer("W_ffn_up").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_up_decode[layer_idx].to("cpu").view_as_torch().flatten()
     fused_op.get_buffer("W_ffn_down").to("cpu").view_as_torch()[:] = aie_buffers.W_ffn_down_decode[layer_idx].to("cpu").view_as_torch().flatten()
 
     fused_op()
 
-    aie_buffers.decode.x.to("cpu").view_as_torch()[:] = fused_op.get_buffer("x_out").to("cpu").view_as_torch()[:]
+    aie_buffers.decode.x.to("cpu").view_as_torch()[:] = fused_op.get_buffer("x").to("cpu").view_as_torch()[:]
 
 
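The buffer-staging lines above all repeat the same copy pattern; a small helper such as the hypothetical stage below (a sketch built only on the get_buffer / to / view_as_torch calls already used in this diff) captures it and shows how the two new inputs of the fused op are fed:

def stage(fused_op, name, src_buffer):
    # Copy a host-side AIE buffer into the fused operator's named buffer, flattened, via CPU.
    fused_op.get_buffer(name).to("cpu").view_as_torch()[:] = src_buffer.to("cpu").view_as_torch().flatten()

# e.g. the two new inputs of the fused post-attention operator:
# stage(fused_op, "x", aie_buffers.decode.x)
# stage(fused_op, "attn_output", aie_buffers.decode.attn_output)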
 def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_idx):
@@ -1109,14 +1108,12 @@ def grouped_query_attention_forward_decode(config, num_preceding_tokens, layer_i
     aie_buffers.decode.values.to("cpu").view_as_torch().view(-1)[:] = fused_op.get_buffer("values").to("cpu").view_as_torch().flatten()
     aie_buffers.keys_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("keys_cache").to("cpu").view_as_torch().flatten()
     aie_buffers.values_cache[layer_idx].to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("values_cache").to("cpu").view_as_torch().flatten()
+    aie_buffers.decode.attn_scores_keys.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores_keys").to("cpu").view_as_torch().flatten()
+    aie_buffers.decode.attn_scores_values.to("cpu").view_as_torch().flatten()[:] = fused_op.get_buffer("attn_scores_values").to("cpu").view_as_torch().flatten()
     aie_buffers.decode.queries.to("npu")
     aie_buffers.decode.keys.to("npu")
     aie_buffers.decode.values.to("npu")
 
-    # Step 4: Repeat keys and values for grouped attention using AIERepeat on NPU
-    aie_ops.decode.attn_repeat_interleave(aie_buffers.keys_cache[layer_idx], aie_buffers.decode.attn_scores_keys)
-    aie_ops.decode.attn_repeat_interleave(aie_buffers.values_cache[layer_idx], aie_buffers.decode.attn_scores_values)
-
     # Step 5: Compute attention scores
     # Copy repeated keys from keys_repeated buffer to attn_scores_keys for GEMV
     aie_ops.decode.gemv_attn_scores(aie_buffers.decode.attn_scores_keys, aie_buffers.decode.queries, aie_buffers.decode.attn_scores)
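With attn_scores_keys now produced directly by the fused operator, the score step reduces to one dot product per head between the single decode-token query and each cached key row. A minimal PyTorch sketch, under the assumption that queries has shape (n_heads, head_dim) and attn_scores_keys has shape (n_heads, context_len, head_dim); any 1/sqrt(head_dim) scaling and the softmax belong to later steps not shown in this hunk.

import torch

def gemv_attn_scores_reference(attn_scores_keys, queries):
    # scores[h, t] = <attn_scores_keys[h, t, :], queries[h, :]>  -- one GEMV per head
    return torch.einsum("htd,hd->ht", attn_scores_keys, queries)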