feat: Turing GPU (T4/RTX) compat via capability-aware FA3 fallback, fp16+GradScaler #268
Open
Nidhicodes wants to merge 1 commit into karpathy:master
…uced batch/seq for pre-Ampere GPUs
Summary
Adds automatic compatibility for pre-Ampere GPUs (Turing and older: T4, RTX 2000 series, GTX 1600 series) with zero behavior change for Ampere+ users (H100, A100, RTX 3000/4000 series).
This makes autoresearch runnable on free Google Colab (T4) out of the box, the most accessible GPU compute for anyone wanting to try this project.
Problem
Three things hard-crash on Turing GPUs:
- `RuntimeError: FlashAttention only supports Ampere GPUs or newer`
- `torch.compile` kernel generation
- `DEVICE_BATCH_SIZE=128` + `seq_len=2048` allocates a `(128, 2048, 2048)` attention matrix = 1GB per layer, immediately OOM on a 15GB T4
Solution
All changes are gated on `cap = torch.cuda.get_device_capability()`, which is already computed at the top of the file. Ampere+ (`cap >= (8, 0)`) gets the original code path unchanged.
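The gating described above can be sketched roughly as follows; `pick_attention_backend` is a hypothetical helper name (not from the PR), and the real code branches inline on `cap` rather than through a function:

```python
def pick_attention_backend(cap=None):
    """Choose the attention path for a given compute capability.

    `cap` is a (major, minor) tuple as returned by
    torch.cuda.get_device_capability(); pass it explicitly for testing.
    """
    if cap is None:
        import torch
        cap = torch.cuda.get_device_capability()
    # Ampere (SM 8.0) and newer keep the original FlashAttention 3 path;
    # Turing (SM 7.5) and older fall back to SDPA.
    return "fa3" if cap >= (8, 0) else "sdpa"

print(pick_attention_backend((7, 5)))  # T4 / RTX 2000 -> sdpa
print(pick_attention_backend((9, 0)))  # H100 -> fa3
```

Tuple comparison makes this robust across generations: `(8, 6)` (RTX 3000) and `(8, 9)` (RTX 4000) both compare `>= (8, 0)`.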
Changes

| Ampere+ (original path) | Pre-Ampere fallback |
| --- | --- |
| `fa3 = get_kernel(repo).flash_attn_interface` | `fa3 = None` |
| `fa3.flash_attn_func(...)` | `F.scaled_dot_product_attention(...)` |
| `bfloat16` | `float16` + `GradScaler` |
| `bfloat16` embeddings | `float32` (GradScaler requires fp32 grads) |
| `bfloat16` | `float32` |
| `g.bfloat16()` | `g.half()` |
| `MAX_SEQ_LEN` (2048) | 512 |
| `2**19` (~524K tokens) | `2**16` (~65K tokens) |
Why SDPA instead of FA2?
`torch.nn.functional.scaled_dot_product_attention` is already available in the installed PyTorch version, requires no extra install, and automatically dispatches to the most efficient kernel available on the current hardware (FlashAttention 2 on supported GPUs, memory-efficient attention otherwise).
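A minimal sketch of the fallback call, assuming the usual `(batch, heads, seq, head_dim)` layout; the actual call sites in the PR may pass different shapes and arguments:

```python
import torch
import torch.nn.functional as F

# Hypothetical small shapes for illustration.
q = torch.randn(2, 4, 128, 64)
k = torch.randn(2, 4, 128, 64)
v = torch.randn(2, 4, 128, 64)

# is_causal=True applies the same causal masking FlashAttention's causal
# mode does; PyTorch picks the fastest available backend for the device.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 128, 64])
```

The output shape always matches the query shape, so the call is a drop-in replacement for `fa3.flash_attn_func(...)` modulo any layout transposes the surrounding code performs.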
Why fp32 embeddings?
`GradScaler` requires fp32 gradients; it cannot unscale fp16 gradients (`ValueError: Attempting to unscale FP16 gradients`). Since embeddings are looked up and their gradients flow back directly, they must stay fp32. The forward pass still runs in fp16 via `autocast`.
Why seq_len=512 and smaller batch?
Without FA3's fused kernel, the attention matrix is materialized explicitly in memory. At seq_len=2048 with batch=128, each layer needs 1GB just for the attention matrix; seq_len=512 reduces this 16x. The T4 has 15GB VRAM, but only ~12GB is usable after CUDA context overhead.
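The memory arithmetic checks out: a `(128, 2048, 2048)` matrix in a 2-byte dtype (fp16/bf16) is exactly 1 GiB, and quartering the sequence length shrinks it quadratically:

```python
def attn_matrix_bytes(batch, seq_len, dtype_bytes=2):
    # One explicitly materialized (batch, seq, seq) attention matrix per layer,
    # at 2 bytes per element for fp16/bf16.
    return batch * seq_len * seq_len * dtype_bytes

GIB = 1024 ** 3
print(attn_matrix_bytes(128, 2048) / GIB)  # 1.0 GiB per layer at seq_len=2048
print(attn_matrix_bytes(128, 512) / GIB)   # 0.0625 GiB -> 16x smaller
```

The 16x factor is the square of the 4x sequence-length reduction (2048 -> 512), since both the row and column dimensions of the attention matrix scale with seq_len.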
Results
Tested on Google Colab free tier (Tesla T4, SM 7.5, 15GB VRAM):
For reference, the H100 baseline from the repo is `val_bpb ~0.998` at seq_len=2048 with ~953 steps. The T4 result is expected to be worse due to:
Limitations & Future Work
- `mfu_percent` is computed against `H100_BF16_PEAK_FLOPS` - on T4 this reads as ~1%, which is misleading but not wrong (it's genuinely much slower)
- Future work: change `H100_BF16_PEAK_FLOPS` to the actual device's peak FLOPS for accurate MFU reporting
Testing
- Ampere+ code path unchanged (`cap >= (8, 0)` guard)