A high-performance GPU implementation of Co-ALIBI (Contextual Attention with Linear Biases), a novel attention mechanism that extends ALiBi with contextual positional encoding through sigmoid-based penalty terms.

Key features:
- Contextual Position Encoding: Uses sigmoid-based cumulative penalties for position-aware attention
- Optimized Triton Kernels: Custom CUDA kernels via Triton for both forward and backward passes
- High Performance: Achieves ~160 TFLOPS (forward) and ~80 TFLOPS (backward) on H100
- FlashAttention-Compatible: Similar memory efficiency and computational complexity
- Accuracy: Passes validation against the reference implementation (`model.py`) with `eps=1e-4`
Installation:

```bash
pip install torch triton
```

Usage:

```python
import torch
from co_alibi_attn import co_alibi_attention

# Input tensors (B=batch, H=heads, S=sequence, D=head_dim)
B, H, S, D = 2, 16, 4096, 128  # example sizes
q = torch.randn(B, H, S, D, device='cuda', dtype=torch.float16)
k = torch.randn(B, H, S, D, device='cuda', dtype=torch.float16)
v = torch.randn(B, H, S, D, device='cuda', dtype=torch.float16)

# Apply Co-ALIBI attention
output = co_alibi_attention(q, k, v, causal=True)
```

Co-ALIBI modifies standard attention by introducing contextual position penalties:
- Compute raw attention scores: `p_raw = Q @ K^T * scale`
- Calculate sigmoid penalties: `σ(p_raw)` for all valid (causal) positions
- Apply cumulative penalty: for query `i` and key `j`, `z[i, j] = Σ_{j < k ≤ i} σ(p_raw[i, k])`
- Adjust scores: `p_adjusted = p_raw - slope * z`
- Apply softmax and compute output: `O = softmax(p_adjusted) @ V`
The key innovation is the sigmoid-based cumulative penalty that provides context-aware positional biases.
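For concreteness, here is a minimal, unoptimized PyTorch sketch of the steps above, assuming the cumulative penalty `z[i, j]` sums the sigmoid penalties of the valid keys strictly after position `j`, up to the query position `i`. The function name `co_alibi_reference` and the explicit `slopes` argument are illustrative; `model.py` and the Triton kernels are the authoritative implementation.

```python
import math
import torch

def co_alibi_reference(q, k, v, slopes):
    """Naive O(S^2) causal Co-ALIBI, following the steps listed above.

    q, k, v: (B, H, S, D) tensors; slopes: (H,) per-head ALiBi slopes.
    Illustrative sketch only -- see model.py for the real reference.
    """
    B, H, S, D = q.shape
    scale = 1.0 / math.sqrt(D)

    # 1. Raw attention scores.
    p_raw = torch.einsum("bhid,bhjd->bhij", q, k) * scale

    causal = torch.ones(S, S, dtype=torch.bool, device=q.device).tril()

    # 2. Sigmoid penalties, zeroed outside the causal window.
    sig = torch.sigmoid(p_raw).masked_fill(~causal, 0.0)

    # 3. Cumulative penalty: z[i, j] = sum over j < k <= i of sig[i, k]
    #    (an exclusive suffix sum along the key dimension).
    z = sig.flip(-1).cumsum(-1).flip(-1) - sig

    # 4. Subtract the head-specific slope times the contextual penalty.
    p_adj = p_raw - slopes.view(1, H, 1, 1) * z
    p_adj = p_adj.masked_fill(~causal, float("-inf"))

    # 5. Softmax over keys, then the weighted sum of values.
    return torch.softmax(p_adj, dim=-1) @ v
```

Unlike the fused Triton kernels, this sketch materializes the full S×S score matrix, so it is only practical for validation at short sequence lengths.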
On NVIDIA H100 (sequence length 4096, 16 heads, head_dim 128):
| Operation | TFLOPS | Latency (ms) |
|---|---|---|
| Forward | ~160 | ~0.88 |
| Backward | ~80 | ~3.4 |
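The TFLOPS column is derived from latency with a matmul-only FLOP count. Below is a hedged sketch of the usual convention; this is an assumption about the counting, not necessarily the formula used in `benchmark_flops.py`.

```python
def attention_tflops(batch, heads, seq_len, head_dim, latency_ms,
                     causal=True, backward=False):
    """Achieved TFLOPS under a FlashAttention-style counting convention
    (assumed here; benchmark_flops.py may count differently)."""
    # Forward: Q @ K^T and P @ V -> two matmuls of 2*S*S*D FLOPs per head.
    flops = 4 * batch * heads * seq_len ** 2 * head_dim
    if causal:
        flops /= 2            # roughly half the score matrix is masked out
    if backward:
        flops *= 2.5          # common convention: backward ~2.5x forward
    return flops / (latency_ms * 1e-3) / 1e12
```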
```
co_alibi_attn/
├── co_alibi_attn.py        # Main attention implementation
├── co_alibi_fwd_kernel.py  # Triton forward kernel
├── co_alibi_bwd_kernel.py  # Triton backward kernel
├── benchmark_flops.py      # Performance benchmarking
├── benchmark_fwd_pass.py   # Forward pass validation
└── benchmark_bwd_pass.py   # Backward pass validation
model.py                    # Reference implementation for testing
```

Implementation details:
- Causal Masking: Built-in support for autoregressive models
- Numerical Stability: Uses log-sum-exp trick for stable softmax computation
- Multi-Query Attention: Supports different numbers of Q and KV heads
- Configurable Slopes: ALiBi slopes are computed from the number of heads and the `bias_max` parameter (see the sketch below)
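As an illustration of the `bias_max` knob, a common ALiBi slope parameterization looks like the sketch below. This mirrors the widely used `alibi_bias_max` recipe; the exact formula and defaults in `co_alibi_attn.py` may differ.

```python
import math
import torch

def alibi_slopes(n_heads: int, bias_max: float = 8.0) -> torch.Tensor:
    """Geometric per-head slopes; bias_max caps the largest exponent.

    Illustrative only -- rounds n_heads up to a power of two so the
    geometric spacing stays regular, then keeps the first n_heads values.
    """
    n = 2 ** math.ceil(math.log2(n_heads))
    exponents = torch.arange(1, n + 1) * (bias_max / n)
    return torch.pow(2.0, -exponents)[:n_heads]

# e.g. 16 heads, bias_max=8 -> slopes 2^-0.5, 2^-1, ..., 2^-8
print(alibi_slopes(16))
```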
Compare performance with FlashAttention 2:

```bash
python co_alibi_attn/benchmark_flops.py
```

Validate accuracy:

```bash
python co_alibi_attn/benchmark_fwd_pass.py
python co_alibi_attn/benchmark_bwd_pass.py
```