Bug: attention_naive rejects ONNX-compliant Group Query Attention (mismatched K/V heads)
On burn-flex @ 7640aa7d, attention_naive in crates/burn-flex/src/ops/attention.rs asserts strict equality between the query-head count and the key/value-head count. This rejects Group Query Attention (GQA) and Multi-Query Attention (MQA), both of which the ONNX Attention-23 operator explicitly supports.
The ONNX Attention spec allows kv_num_heads < q_num_heads as long as q_num_heads % kv_num_heads == 0; each K/V head is then shared by q_num_heads / kv_num_heads query heads. burn-ndarray handles this; burn-flex panics.
Failing assertions
// crates/burn-flex/src/ops/attention.rs
let batch = q_shape[0];
let heads = q_shape[1];
let seq_q = q_shape[2];
let head_dim = q_shape[3];
// ...
assert_eq!(k_shape[0], batch, "attention_naive: key batch mismatch");
assert_eq!(k_shape[1], heads, "attention_naive: key heads mismatch"); // ← rejects GQA
// ...
assert_eq!(v_shape[1], heads, "attention_naive: value heads mismatch"); // ← rejects GQA
The heads here is taken from Q. For GQA, k_shape[1] and v_shape[1] are smaller than q_shape[1], so these assertions fire.
Minimal repro
use burn::backend::Flex;
use burn::tensor::{Tensor, TensorData};
use burn::tensor::module::attention;
type B = Flex;
fn main() {
    let device = Default::default();

    // GQA: 9 query heads, 3 KV heads. Each KV head shared by 3 query heads.
    // q: [batch=2, q_heads=9, seq_q=4, head_dim=8]
    let q = Tensor::<B, 4>::from_floats(
        TensorData::from([[[[0.0f32; 8]; 4]; 9]; 2]),
        &device,
    );

    // k, v: [batch=2, kv_heads=3, seq_kv=6, head_dim=8]
    let k = Tensor::<B, 4>::from_floats(
        TensorData::from([[[[0.0f32; 8]; 6]; 3]; 2]),
        &device,
    );
    let v = k.clone();

    // Expected: attention output with shape [2, 9, 4, 8].
    // Actual: thread panics at attention.rs:658 with
    //   "attention_naive: key heads mismatch"
    //   left: 3, right: 9
    let _out = attention::scaled_dot_product_attention(
        q, k, v, None, None, Default::default(),
    );
}
Expected behavior
When kv_heads < q_heads and q_heads % kv_heads == 0, share each K/V head across q_heads / kv_heads consecutive query heads. The standard implementation maps query head h to KV head h * kv_heads / q_heads (or equivalently h / (q_heads / kv_heads)). The ONNX reference and burn-ndarray both implement this.
This same pattern unifies three cases (a minimal sketch of the mapping follows the list):
- MHA (q_heads == kv_heads): each query head maps to its matching KV head.
- GQA (q_heads > kv_heads, divisible): each KV head is shared by q_heads / kv_heads query heads.
- MQA (kv_heads == 1): one shared KV head for all queries.
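For concreteness, here is a minimal sketch of that mapping in plain Rust, independent of burn (kv_head_for is an illustrative name, not something that exists in attention.rs):
// Sketch of the head mapping described above; plain Rust, no burn types.
fn kv_head_for(h: usize, q_heads: usize, kv_heads: usize) -> usize {
    assert!(q_heads % kv_heads == 0, "q_heads must be divisible by kv_heads");
    h / (q_heads / kv_heads) // equivalently: h * kv_heads / q_heads
}

fn main() {
    let map = |q_heads: usize, kv_heads: usize| -> Vec<usize> {
        (0..q_heads).map(|h| kv_head_for(h, q_heads, kv_heads)).collect()
    };
    // MHA: 4 query heads, 4 KV heads -> identity mapping.
    assert_eq!(map(4, 4), vec![0, 1, 2, 3]);
    // GQA: 9 query heads, 3 KV heads -> each KV head serves 3 consecutive query heads.
    assert_eq!(map(9, 3), vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);
    // MQA: 5 query heads, 1 KV head -> every query head maps to KV head 0.
    assert_eq!(map(5, 1), vec![0; 5]);
}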
How it was surfaced
Found while wiring f16 and bf16 decoding through the onnx-official-tests harness in tracel-ai/burn-onnx#393. With f16 plumbed through, test_attention_4d_gqa_with_past_and_present_fp16 started actually running (previously it only verified codegen + compile) and now panics at attention.rs:658:
thread 'test_attention_4d_gqa_with_past_and_present_fp16' panicked at
burn-flex/src/ops/attention.rs:658:5:
assertion `left == right` failed: attention_naive: key heads mismatch
left: 3
right: 9
The model under test has Q [2, 9, 4, 8], K/V [2, 3, 6, 8] — i.e. 9 query heads, 3 KV heads, each KV head shared by 3 query heads (a textbook GQA configuration).
Environment
- burn: rev 7640aa7d0704a6a548587d67e77ef489a9a587b7
- Host: Darwin aarch64 (Apple Silicon), macOS 25.4
- Feature set: default
- Rust: stable
Suggested fix
Replace the strict head-count assert_eq! checks with a divisibility check:
let kv_heads = k_shape[1];
assert_eq!(v_shape[1], kv_heads, "attention_naive: K and V must agree on heads");
assert!(
    heads % kv_heads == 0,
    "attention_naive: q_heads ({heads}) must be divisible by kv_heads ({kv_heads})"
);
let q_per_kv = heads / kv_heads;
Then in the inner loop, index K/V with h / q_per_kv (or equivalently h * kv_heads / heads) instead of h. Strides and offsets need to use kv_heads rather than heads for K/V.
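As a rough sketch of what that indexing change amounts to (illustrative names only, not the actual variables or loop structure in attention.rs, and ignoring scaling, masking, and softmax), the per-head base offsets for contiguous row-major [batch, heads, seq, head_dim] tensors would look something like:
// Illustrative sketch only -- not the real attention_naive loop.
// Assumes Q laid out as [batch, q_heads, seq_q, head_dim] and K/V as
// [batch, kv_heads, seq_kv, head_dim], both contiguous and row-major.
fn base_offsets(
    b: usize,
    h: usize,        // query head index
    heads: usize,    // q_heads
    kv_heads: usize,
    seq_q: usize,
    seq_kv: usize,
    head_dim: usize,
) -> (usize, usize) {
    let q_per_kv = heads / kv_heads;
    let kv_h = h / q_per_kv; // previously just `h`, which assumed q_heads == kv_heads
    let q_base = (b * heads + h) * seq_q * head_dim;
    // K/V strides use kv_heads (and kv_h), not heads (and h).
    let kv_base = (b * kv_heads + kv_h) * seq_kv * head_dim;
    (q_base, kv_base)
}

fn main() {
    // With the repro shapes above, query head 7 of 9 maps to KV head 2 of 3.
    let (q_base, kv_base) = base_offsets(1, 7, 9, 3, 4, 6, 8);
    assert_eq!(q_base, (1 * 9 + 7) * 4 * 8);
    assert_eq!(kv_base, (1 * 3 + 2) * 6 * 8);
}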