burn-flex: attention_naive rejects ONNX-compliant Group Query Attention (mismatched K/V heads) #4930

Bug: attention_naive rejects ONNX-compliant Group Query Attention (mismatched K/V heads)

On burn-flex @ 7640aa7d, attention_naive in crates/burn-flex/src/ops/attention.rs asserts strict equality between the query-head count and the key/value-head count. This rejects Group Query Attention (GQA) and Multi-Query Attention (MQA), both of which the ONNX Attention-23 operator explicitly supports.

The ONNX Attention spec allows kv_num_heads < q_num_heads as long as q_num_heads % kv_num_heads == 0; each K/V head is then shared by q_num_heads / kv_num_heads query heads. burn-ndarray handles this; burn-flex panics.
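For reference, here is a compact sketch of that shape contract. check_attention_shapes is a hypothetical helper written for this issue, not burn or ONNX API; it only illustrates the rule described above.

fn check_attention_shapes(q: [usize; 4], k: [usize; 4], v: [usize; 4]) -> Result<(), String> {
    // Layouts: q [batch, q_heads, seq_q, head_dim], k/v [batch, kv_heads, seq_kv, head_dim].
    let [batch, q_heads, _seq_q, head_dim] = q;
    let [k_batch, kv_heads, seq_kv, k_dim] = k;
    if k_batch != batch || k_dim != head_dim {
        return Err("K must match Q's batch and head_dim".into());
    }
    if v != [batch, kv_heads, seq_kv, head_dim] {
        return Err("V must match K's batch, heads, seq_kv and Q's head_dim".into());
    }
    // GQA/MQA rule: fewer KV heads are allowed as long as they divide q_heads.
    if q_heads % kv_heads != 0 {
        return Err(format!("q_heads ({q_heads}) must be a multiple of kv_heads ({kv_heads})"));
    }
    Ok(())
}

fn main() {
    // The shapes from the repro below: accepted by the spec, rejected by burn-flex today.
    assert!(check_attention_shapes([2, 9, 4, 8], [2, 3, 6, 8], [2, 3, 6, 8]).is_ok());
    // 9 query heads with 4 KV heads is invalid: 9 % 4 != 0.
    assert!(check_attention_shapes([2, 9, 4, 8], [2, 4, 6, 8], [2, 4, 6, 8]).is_err());
}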

Failing assertions

// crates/burn-flex/src/ops/attention.rs
let batch = q_shape[0];
let heads = q_shape[1];
let seq_q = q_shape[2];
let head_dim = q_shape[3];
// ...
assert_eq!(k_shape[0], batch, "attention_naive: key batch mismatch");
assert_eq!(k_shape[1], heads, "attention_naive: key heads mismatch");      // ← rejects GQA
// ...
assert_eq!(v_shape[1], heads, "attention_naive: value heads mismatch");    // ← rejects GQA

The heads value here comes from Q's shape. For GQA, k_shape[1] and v_shape[1] are smaller than q_shape[1], so these assertions fire.

Minimal repro

use burn::backend::Flex;
use burn::tensor::{Tensor, TensorData};
use burn::tensor::module::attention;

type B = Flex;

fn main() {
    let device = Default::default();

    // GQA: 9 query heads, 3 KV heads. Each KV head shared by 3 query heads.
    // q: [batch=2, q_heads=9, seq_q=4, head_dim=8]
    let q = Tensor::<B, 4>::from_floats(
        TensorData::from([[[[0.0f32; 8]; 4]; 9]; 2]),
        &device,
    );
    // k, v: [batch=2, kv_heads=3, seq_kv=6, head_dim=8]
    let k = Tensor::<B, 4>::from_floats(
        TensorData::from([[[[0.0f32; 8]; 6]; 3]; 2]),
        &device,
    );
    let v = k.clone();

    // Expected: attention output with shape [2, 9, 4, 8].
    // Actual:   thread panics at attention.rs:658 with
    //   "attention_naive: key heads mismatch"
    //   left: 3, right: 9
    let _out = attention::scaled_dot_product_attention(
        q, k, v, None, None, Default::default(),
    );
}

Expected behavior

When kv_heads < q_heads and q_heads % kv_heads == 0, share each K/V head across q_heads / kv_heads consecutive query heads. The standard implementation maps query head h to KV head h * kv_heads / q_heads (or equivalently h / (q_heads / kv_heads)). The ONNX reference and burn-ndarray both implement this.

This same pattern unifies three cases (a small mapping sketch follows the list):

  • MHA (q_heads == kv_heads): each query head maps to its matching KV head.
  • GQA (q_heads > kv_heads, divisible): each KV head shared by q_heads / kv_heads query heads.
  • MQA (kv_heads == 1): one shared KV head for all queries.
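
A minimal sketch of that mapping, checked against all three cases. kv_head_for is a hypothetical helper named for this issue, not a burn function.

fn kv_head_for(q_head: usize, q_heads: usize, kv_heads: usize) -> usize {
    assert!(q_heads % kv_heads == 0, "q_heads must be divisible by kv_heads");
    let q_per_kv = q_heads / kv_heads;
    q_head / q_per_kv // equivalently: q_head * kv_heads / q_heads
}

fn main() {
    // MHA: 4 query heads, 4 KV heads -> identity mapping.
    assert_eq!((0..4).map(|h| kv_head_for(h, 4, 4)).collect::<Vec<_>>(), vec![0, 1, 2, 3]);
    // GQA: 9 query heads, 3 KV heads -> each KV head serves 3 consecutive query heads.
    assert_eq!((0..9).map(|h| kv_head_for(h, 9, 3)).collect::<Vec<_>>(), vec![0, 0, 0, 1, 1, 1, 2, 2, 2]);
    // MQA: 6 query heads, 1 KV head -> every query head reads KV head 0.
    assert_eq!((0..6).map(|h| kv_head_for(h, 6, 1)).collect::<Vec<_>>(), vec![0; 6]);
}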

How it was surfaced

Found while wiring f16 and bf16 decoding through the onnx-official-tests harness in tracel-ai/burn-onnx#393. With f16 plumbed through, test_attention_4d_gqa_with_past_and_present_fp16 started actually running (previously it only verified codegen and compilation) and now panics at attention.rs:658:

thread 'test_attention_4d_gqa_with_past_and_present_fp16' panicked at
  burn-flex/src/ops/attention.rs:658:5:
assertion `left == right` failed: attention_naive: key heads mismatch
  left: 3
 right: 9

The model under test has Q [2, 9, 4, 8], K/V [2, 3, 6, 8] — i.e. 9 query heads, 3 KV heads, each KV head shared by 3 query heads (a textbook GQA configuration).

Environment

  • burn rev 7640aa7d0704a6a548587d67e77ef489a9a587b7
  • Host: Darwin aarch64 (Apple Silicon), macOS 25.4
  • Feature set: default
  • Rust: stable

Suggested fix

Replace the strict head-count assert_eq! checks with a divisibility check:

let kv_heads = k_shape[1];
assert_eq!(v_shape[1], kv_heads, "attention_naive: K and V must agree on heads");
assert!(
    heads % kv_heads == 0,
    "attention_naive: q_heads ({heads}) must be divisible by kv_heads ({kv_heads})"
);
let q_per_kv = heads / kv_heads;

Then in the inner loop, index K/V with h / q_per_kv (or equivalently h * kv_heads / heads) instead of h. Strides and offsets need to use kv_heads rather than heads for K/V.
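
For concreteness, here is a standalone sketch of what that indexing looks like in a plain naive loop over flat, contiguous f32 buffers. This is not the actual burn-flex kernel (its layout, masking, and loop structure may differ); it only shows where the head remap happens and where kv_heads replaces heads in the K/V offsets.

// Layouts assumed contiguous: q/out [batch, q_heads, seq_q, d], k/v [batch, kv_heads, seq_kv, d].
fn attention_naive_gqa(
    q: &[f32], k: &[f32], v: &[f32], out: &mut [f32],
    batch: usize, q_heads: usize, kv_heads: usize,
    seq_q: usize, seq_kv: usize, d: usize,
) {
    assert!(q_heads % kv_heads == 0);
    let q_per_kv = q_heads / kv_heads;
    let scale = 1.0 / (d as f32).sqrt();

    for b in 0..batch {
        for h in 0..q_heads {
            let kv_h = h / q_per_kv; // GQA: several query heads share one KV head
            // Q/out offsets use q_heads in their strides, K/V offsets use kv_heads.
            let q_off = (b * q_heads + h) * seq_q * d;
            let kv_off = (b * kv_heads + kv_h) * seq_kv * d;
            for i in 0..seq_q {
                // scores[j] = (q_i . k_j) * scale, then softmax over j.
                let mut scores = vec![0.0f32; seq_kv];
                for j in 0..seq_kv {
                    let mut dot = 0.0;
                    for c in 0..d {
                        dot += q[q_off + i * d + c] * k[kv_off + j * d + c];
                    }
                    scores[j] = dot * scale;
                }
                let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let mut sum = 0.0;
                for s in scores.iter_mut() {
                    *s = (*s - max).exp();
                    sum += *s;
                }
                // out_i = sum_j softmax(scores)_j * v_j
                for c in 0..d {
                    let mut acc = 0.0;
                    for j in 0..seq_kv {
                        acc += (scores[j] / sum) * v[kv_off + j * d + c];
                    }
                    out[q_off + i * d + c] = acc;
                }
            }
        }
    }
}

With the remap in place, MHA (q_per_kv == 1) and MQA (kv_heads == 1) fall out of the same path with no special-casing.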
