train.py: 38 changes (26 additions, 12 deletions)
@@ -21,9 +21,10 @@
cap = torch.cuda.get_device_capability()
# varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs
repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
-fa3 = get_kernel(repo).flash_attn_interface
+fa3 = get_kernel(repo).flash_attn_interface if cap >= (8, 0) else None

-from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb
+from prepare import MAX_SEQ_LEN as _MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb
+MAX_SEQ_LEN = _MAX_SEQ_LEN if cap >= (8, 0) else 512

# ---------------------------------------------------------------------------
# GPT Model
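For context: `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, so the tuple comparison `cap >= (8, 0)` selects Ampere and newer, while `cap == (9, 0)` pins Hopper specifically. A minimal sketch of the same gating pattern, assuming the `kernels` package's `get_kernel` that train.py already uses (the `load_flash_attn` helper itself is hypothetical):

```python
import torch
from kernels import get_kernel  # hf "kernels" package, as used in train.py

def load_flash_attn():
    """Pick an attention backend by compute capability (sketch)."""
    cap = torch.cuda.get_device_capability()  # e.g. (9, 0) on H100, (7, 5) on T4
    if cap < (8, 0):
        return None  # pre-Ampere: caller should fall back to SDPA
    # varunneal's FA3 build is Hopper-only; kernels-community covers other Ampere+ GPUs
    repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
    return get_kernel(repo).flash_attn_interface
```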
@@ -90,7 +91,12 @@ def forward(self, x, ve, cos_sin, window_size):
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)

-        y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
+        if fa3 is not None:
+            y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
+        else:
+            q2, k2, v2 = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+            y = F.scaled_dot_product_attention(q2, k2, v2, is_causal=True)
+            y = y.transpose(1, 2)
        y = y.contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y
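Two things worth noting about the fallback branch. FA3's `flash_attn_func` takes tensors laid out as `(batch, seq, heads, head_dim)`, while `F.scaled_dot_product_attention` expects `(batch, heads, seq, head_dim)`, which is why the inputs are transposed on the way in and the output transposed back. And the fallback passes only `is_causal=True`, dropping `window_size`, so pre-Ampere runs compute full causal attention even in the sliding-window layers. A self-contained sketch of the layout round-trip (shapes illustrative):

```python
import torch
import torch.nn.functional as F

B, T, H, D = 2, 16, 4, 32
q = torch.randn(B, T, H, D)                   # FA3 layout: (batch, seq, heads, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# SDPA wants (batch, heads, seq, head_dim): transpose in, transpose out
q2, k2, v2 = (t.transpose(1, 2) for t in (q, k, v))
y = F.scaled_dot_product_attention(q2, k2, v2, is_causal=True)
y = y.transpose(1, 2)                         # back to (B, T, H, D)
assert y.shape == (B, T, H, D)
```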
@@ -176,9 +182,10 @@ def init_weights(self):
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin
        # Cast embeddings to bf16
-        self.transformer.wte.to(dtype=torch.bfloat16)
+        embed_dtype = torch.bfloat16 if torch.cuda.get_device_capability() >= (8, 0) else torch.float32
+        self.transformer.wte.to(dtype=embed_dtype)
        for ve in self.value_embeds.values():
-            ve.to(dtype=torch.bfloat16)
+            ve.to(dtype=embed_dtype)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        if device is None:
@@ -188,7 +195,8 @@ def _precompute_rotary_embeddings(seq_len, head_dim, base=10000, device=None):
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
-        cos, sin = cos.bfloat16(), sin.bfloat16()
+        cos_dtype = torch.bfloat16 if torch.cuda.get_device_capability() >= (8, 0) else torch.float32
+        cos, sin = cos.to(cos_dtype), sin.to(cos_dtype)
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin

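This hunk and the previous one repeat the same capability check to choose bf16 or fp32. PyTorch also exposes `torch.cuda.is_bf16_supported()`, which expresses the intent directly; if the pattern spreads further, a small helper could centralize it. A hedged sketch (the `param_dtype` helper is not part of the PR):

```python
import torch

def param_dtype() -> torch.dtype:
    """bf16 where supported (Ampere+, compute capability >= 8.0), else fp32."""
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float32
```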
@@ -321,7 +329,7 @@ def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momen
    momentum_buffer.lerp_(stacked_grads, 1 - momentum)
    g = stacked_grads.lerp_(momentum_buffer, momentum)
    # Polar express orthogonalization
-    X = g.bfloat16()
+    X = g.bfloat16() if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0) else g.half()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
    if g.size(-2) > g.size(-1):
        for a, b, c in polar_express_coeffs[:ns_steps]:
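On the fp16 path the pre-scaling on the following line does more work than it does for bf16: fp16 tops out at 65504, while bf16 keeps fp32's exponent range, so dividing by `norm * 1.02 + 1e-6` first keeps the orthogonalization iterates safely inside the unit ball before the matmul-heavy polynomial steps. A quick numeric illustration (values are made up):

```python
import torch

x = torch.tensor(300.0)
print((x * x).half())      # inf: 90000 exceeds the fp16 max of 65504
print((x * x).bfloat16())  # finite: bf16 shares fp32's exponent range

# hence the shrink before iterating:
g = torch.randn(128, 128)
X = g.half()
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
print(float(X.norm()))     # ~0.98: below unit Frobenius norm, so spectral norm <= 1
```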
@@ -435,7 +443,7 @@ def step(self):
WINDOW_PATTERN = "SSSL" # sliding window pattern: L=full, S=half context

# Optimization
-TOTAL_BATCH_SIZE = 2**19 # ~524K tokens per optimizer step
+TOTAL_BATCH_SIZE = 2**19 if cap >= (8, 0) else 2**16 # ~524K tokens per step; ~65K on pre-Ampere GPUs
EMBEDDING_LR = 0.6 # learning rate for token embeddings (Adam)
UNEMBEDDING_LR = 0.004 # learning rate for lm_head (Adam)
MATRIX_LR = 0.04 # learning rate for matrix parameters (Muon)
@@ -448,7 +456,7 @@ def step(self):

# Model size
DEPTH = 8 # number of transformer layers
-DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM)
+DEVICE_BATCH_SIZE = 128 if cap >= (8, 0) else 16 # per-device batch size (reduce if OOM); 16 on pre-Ampere GPUs

# ---------------------------------------------------------------------------
# Setup: tokenizer, model, optimizer, dataloader
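The two reduced sizes interact through gradient accumulation. Assuming the usual relation `grad_accum_steps = TOTAL_BATCH_SIZE // (DEVICE_BATCH_SIZE * MAX_SEQ_LEN)` (the training loop below divides the loss by `grad_accum_steps`, and pre-Ampere `MAX_SEQ_LEN` is capped at 512 at the top of the file), the pre-Ampere settings work out to a modest accumulation depth:

```python
# Pre-Ampere path; the grad_accum_steps formula is an assumption, not shown in this diff
TOTAL_BATCH_SIZE = 2**16                 # 65,536 tokens per optimizer step
DEVICE_BATCH_SIZE = 16
MAX_SEQ_LEN = 512
tokens_per_micro_step = DEVICE_BATCH_SIZE * MAX_SEQ_LEN      # 8,192
grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_micro_step
print(grad_accum_steps)                  # 8 micro-steps per optimizer step
```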
@@ -459,7 +467,9 @@
torch.cuda.manual_seed(42)
torch.set_float32_matmul_precision("high")
device = torch.device("cuda")
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+amp_dtype = torch.bfloat16 if cap >= (8, 0) else torch.float16
+autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
+scaler = torch.amp.GradScaler("cuda") if amp_dtype == torch.float16 else None
H100_BF16_PEAK_FLOPS = 989.5e12

tokenizer = Tokenizer.from_directory()
@@ -548,7 +558,7 @@ def get_weight_decay(progress):
        loss = model(x, y)
        train_loss = loss.detach()
        loss = loss / grad_accum_steps
-        loss.backward()
+        (scaler.scale(loss) if scaler else loss).backward()
        x, y, epoch = next(train_loader)

    # Progress and schedules
@@ -561,7 +571,11 @@
        if group['kind'] == 'muon':
            group["momentum"] = muon_momentum
            group["weight_decay"] = muon_weight_decay
-    optimizer.step()
+    if scaler:
+        scaler.step(optimizer)
+        scaler.update()
+    else:
+        optimizer.step()
    model.zero_grad(set_to_none=True)

    train_loss_f = train_loss.item()
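Taken together, the fp16 changes follow the standard `torch.amp` recipe: scale the loss before `backward()` so small gradients don't flush to zero in fp16, then let `scaler.step()` unscale the gradients and skip the step if any are inf/NaN, and let `scaler.update()` adapt the scale factor. On bf16 hardware `scaler` is `None` and both call sites reduce to the plain `backward()` / `optimizer.step()` they replaced. A condensed, self-contained version of the pattern (the model and data are stand-ins):

```python
import torch

model = torch.nn.Linear(32, 32).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
amp_dtype = torch.bfloat16 if torch.cuda.get_device_capability() >= (8, 0) else torch.float16
scaler = torch.amp.GradScaler("cuda") if amp_dtype == torch.float16 else None

for _ in range(3):
    x = torch.randn(8, 32, device="cuda")
    with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
        loss = model(x).square().mean()
    (scaler.scale(loss) if scaler else loss).backward()
    if scaler:
        scaler.step(opt)   # unscales grads; skips the step on inf/NaN
        scaler.update()    # grows/shrinks the loss scale
    else:
        opt.step()
    opt.zero_grad(set_to_none=True)
```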