
feat: Turing GPU (T4/RTX) compat via capability-aware FA3 fallback, fp16+GradScaler#268

Open
Nidhicodes wants to merge 1 commit into karpathy:master from Nidhicodes:t4-baseline
Conversation

@Nidhicodes Nidhicodes commented Mar 14, 2026

Summary

Adds automatic compatibility for pre-Ampere GPUs (Turing and older: T4, RTX 2000 series, GTX 1600 series) with zero behavior change for Ampere+ users (H100, A100, RTX 3000/4000 series).

This makes autoresearch runnable on free Google Colab (T4) out of the box, which is the most accessible GPU compute available to most people wanting to try this project.

Problem

Three things hard-crash on Turing GPUs:

  1. Flash Attention 3 - requires Ampere+ (SM 8.0+). T4 is SM 7.5. Crashes with RuntimeError: FlashAttention only supports Ampere GPUs or newer
  2. bfloat16 - T4 has no native bfloat16 support. PyTorch falls back to float32 emulation, causing silent OOM during torch.compile kernel generation
  3. Batch size - without FA3's memory efficiency, the default DEVICE_BATCH_SIZE=128 + seq_len=2048 materializes a (128, 2048, 2048) attention matrix, ~1GB per layer in fp16, which immediately OOMs on a 15GB T4
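The arithmetic behind point 3, as a quick sanity check (assuming 2 bytes per element in fp16):

```python
# One (batch, seq, seq) attention-score matrix per layer, 2 bytes/element in fp16.
default_bytes = 128 * 2048 * 2048 * 2
print(default_bytes / 2**30)  # 1.0 GiB per layer at the default config

# The fallback config shrinks both dimensions:
fallback_bytes = 16 * 512 * 512 * 2
print(fallback_bytes / 2**20)  # 8.0 MiB per layer
```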

Solution

All changes are gated on cap = torch.cuda.get_device_capability() which is already computed at the top of the file. Ampere+ (cap >= (8, 0)) gets the original code path unchanged.
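A minimal sketch of the gate; in the script, `cap` comes from `torch.cuda.get_device_capability()`, and compute capabilities compare naturally as Python tuples:

```python
# T4 = SM 7.5 -> (7, 5), A100 = (8, 0), H100 = (9, 0).
def is_ampere_plus(cap):
    return cap >= (8, 0)

print(is_ampere_plus((9, 0)))  # True  -> original FA3 + bfloat16 path
print(is_ampere_plus((8, 0)))  # True  -> A100, original path
print(is_ampere_plus((7, 5)))  # False -> SDPA + fp16 fallback
```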

Changes

| Location | Ampere+ (unchanged) | Pre-Ampere (new fallback) |
| --- | --- | --- |
| FA3 load | `fa3 = get_kernel(repo).flash_attn_interface` | `fa3 = None` |
| Attention | `fa3.flash_attn_func(...)` | `F.scaled_dot_product_attention(...)` |
| Precision | bfloat16 | float16 + GradScaler |
| Embeddings | cast to bfloat16 | stay float32 (GradScaler requires fp32 grads) |
| Rotary emb | cast to bfloat16 | stay float32 |
| Muon ortho | `g.bfloat16()` | `g.half()` |
| Seq length | `MAX_SEQ_LEN` (2048) | 512 |
| Batch size | 128 | 16 |
| Total batch | `2**19` (~524K tokens) | `2**16` (~65K tokens) |

Why SDPA instead of FA2?

torch.nn.functional.scaled_dot_product_attention is already available in the installed PyTorch version, requires no extra install, and automatically dispatches to the most efficient kernel available on the current hardware (Flash Attention 2 on supported GPUs, efficient attention otherwise).
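A sketch of the fallback call. The tensor layout here is an assumption, not a quote from the diff: FA3's `flash_attn_func` takes `(B, T, H, D)`, while SDPA expects `(B, H, T, D)`, so a drop-in swap needs transposes on the way in and out:

```python
import torch
import torch.nn.functional as F

B, T, H, D = 2, 128, 4, 64  # illustrative sizes
q, k, v = (torch.randn(B, T, H, D) for _ in range(3))

out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    is_causal=True,  # causal LM masking, as in the training script
).transpose(1, 2)
print(tuple(out.shape))  # (2, 128, 4, 64)
```

SDPA never materializes the full score matrix when a fused backend is available, which is what makes it a reasonable stand-in for FA3 on older hardware.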

Why fp32 embeddings?

GradScaler requires fp32 gradients; it cannot unscale fp16 gradients (ValueError: Attempting to unscale FP16 gradients). Since embedding weights receive gradients directly, they must stay fp32. The forward pass still runs in fp16 via autocast.

Why seq_len=512 and smaller batch?

Without FA3's fused kernel, the attention matrix is materialized explicitly in memory, and it scales with seq_len squared. At seq_len=2048 with batch=128, each layer needs ~1GB just for the attention matrix; dropping to seq_len=512 reduces this 16x. The T4 has 15GB VRAM but only ~12GB is usable after CUDA context overhead.

Results

Tested on Google Colab free tier (Tesla T4, SM 7.5, 15GB VRAM):

```
val_bpb:          1.333735
training_seconds: 300.5
total_seconds:    539.7
peak_vram_mb:     2496.9
mfu_percent:      1.18
total_tokens_M:   19.0
num_steps:        290
num_params_M:     50.3
```

For reference, the H100 baseline from the repo is val_bpb ~0.998 at seq_len=2048 with ~953 steps. The T4 result is expected to be worse due to:

  • 4x shorter context (512 vs 2048) - less signal per step
  • ~35x lower throughput (1.2% vs ~40% MFU) - fewer total steps
  • No FA3 memory efficiency - smaller effective batch

Limitations & Future Work

  • mfu_percent is computed against H100_BF16_PEAK_FLOPS - on T4 this reads as ~1%, which is misleading but not wrong (it's genuinely much slower)
  • seq_len=512 means the T4 baseline is not directly comparable to H100 results
  • A future improvement would be to also adjust H100_BF16_PEAK_FLOPS to the actual device's peak FLOPS for accurate MFU reporting
  • With Colab Pro (A100), all these fallbacks are bypassed automatically and the full H100-comparable config runs as-is
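The MFU fix in the third bullet could look something like this; the per-device peak numbers are public spec-sheet values and an assumption of this sketch, not constants from the repo:

```python
import torch

# Spec-sheet dense tensor-core peaks (assumed values, not from the repo).
PEAK_FLOPS = {
    "H100": 989e12,  # bf16
    "A100": 312e12,  # bf16
    "T4": 65e12,     # fp16 (no bf16 tensor cores)
}

def device_peak_flops(default=989e12):
    """Pick the MFU denominator from the detected device name."""
    if not torch.cuda.is_available():
        return default
    name = torch.cuda.get_device_name()
    for key, flops in PEAK_FLOPS.items():
        if key in name:
            return flops
    return default

# mfu_percent = 100 * achieved_flops_per_sec / device_peak_flops()
```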

Testing

  • Runs to completion on T4 (300s budget respected)
  • No changes to H100/A100 code path (cap >= (8, 0) guard)
