Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,17 +340,20 @@ def refill_buffer():
# ---------------------------------------------------------------------------

@torch.no_grad()
def evaluate_bpb(model, tokenizer, batch_size):
def evaluate_bpb(model, tokenizer, batch_size, seq_len=None, max_steps=None):
"""
Bits per byte (BPB): vocab size-independent evaluation metric.
Sums per-token cross-entropy (in nats), sums target byte lengths,
then converts nats/byte to bits/byte. Special tokens (byte length 0)
are excluded from both sums.
Uses fixed MAX_SEQ_LEN so results are comparable across configs.
"""
seq_len = seq_len or MAX_SEQ_LEN
token_bytes = get_token_bytes(device="cuda")
val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)
val_loader = make_dataloader(tokenizer, batch_size, seq_len, "val")
steps = EVAL_TOKENS // (batch_size * seq_len)
if max_steps is not None:
steps = min(steps, max_steps)
total_nats = 0.0
total_bytes = 0
for _ in range(steps):
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ description = "Autonomous pretraining research swarm"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"kernels>=0.11.7",
"kernels>=0.12.2",
"matplotlib>=3.10.8",
"numpy>=2.2.6",
"pandas>=2.3.3",
"pyarrow>=21.0.0",
"requests>=2.32.0",
"rustbpe>=0.1.0",
"tiktoken>=0.11.0",
"torch==2.9.1",
"torch==2.8.0",
]

[tool.uv.sources]
Expand Down
Binary file added run.log
Binary file not shown.
85 changes: 70 additions & 15 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,26 @@
import torch.nn as nn
import torch.nn.functional as F

from kernels import get_kernel
cap = torch.cuda.get_device_capability()
# varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs
repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
fa3 = get_kernel(repo).flash_attn_interface
# Try flash-attn3 (Linux only); fall back to SDPA on Windows (no prebuilt FA3 wheels)
fa3 = None
try:
from kernels import get_kernel
cap = torch.cuda.get_device_capability()
repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
fa3 = get_kernel(repo).flash_attn_interface
except (FileNotFoundError, ModuleNotFoundError):
pass # Use SDPA fallback

from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb

# Triton/torch.compile requires CUDA capability >= 7.0 (e.g. GTX 1050 Ti is 6.1)
USE_TORCH_COMPILE = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7

def _maybe_compile(dynamic=False, fullgraph=False):
    """Decorator factory gating torch.compile behind USE_TORCH_COMPILE.

    When USE_TORCH_COMPILE is truthy, the decorated function is wrapped with
    ``torch.compile(fn, dynamic=..., fullgraph=...)``; otherwise the function
    is returned untouched (e.g. on GPUs too old for Triton — see the
    USE_TORCH_COMPILE definition above this block).

    Args:
        dynamic: forwarded to ``torch.compile``'s ``dynamic`` option.
        fullgraph: forwarded to ``torch.compile``'s ``fullgraph`` option.

    Returns:
        A decorator applying (or skipping) compilation as described.
    """
    def decorator(fn):
        if not USE_TORCH_COMPILE:
            return fn
        return torch.compile(fn, dynamic=dynamic, fullgraph=fullgraph)
    return decorator

# ---------------------------------------------------------------------------
# GPT Model
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -72,6 +84,22 @@ def __init__(self, config, layer_idx):
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
self.ve_gate_channels = 32
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
self._mask_cache = {} # for SDPA fallback

def _get_sdpa_mask(self, seq_len, window_size, device):
"""Build causal + sliding-window mask for SDPA (used when flash-attn3 unavailable, e.g. Windows)."""
window = window_size[0] if isinstance(window_size, tuple) else window_size
cache_key = (seq_len, int(window) if window is not None else -1, device.type, device.index)
mask = self._mask_cache.get(cache_key)
if mask is not None:
return mask
row = torch.arange(seq_len, device=device).unsqueeze(1)
col = torch.arange(seq_len, device=device).unsqueeze(0)
mask = col <= row # causal
if window is not None and window >= 0 and window < seq_len:
mask = mask & (col >= (row - window))
self._mask_cache[cache_key] = mask
return mask

def forward(self, x, ve, cos_sin, window_size):
B, T, C = x.size()
Expand All @@ -89,7 +117,20 @@ def forward(self, x, ve, cos_sin, window_size):
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k)

y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
if fa3 is not None:
y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
else:
# SDPA fallback (Windows; flash-attn3 has no prebuilt wheels)
q = q.transpose(1, 2) # (B, H, T, D)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
attn_mask = self._get_sdpa_mask(T, window_size, q.device)
y = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=False,
enable_gqa=self.n_kv_head < self.n_head,
)
y = y.transpose(1, 2)

y = y.contiguous().view(B, T, -1)
y = self.c_proj(y)
return y
Expand Down Expand Up @@ -301,7 +342,7 @@ def forward(self, idx, targets=None, reduction='mean'):
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
]

@torch.compile(dynamic=False, fullgraph=True)
@_maybe_compile(dynamic=False, fullgraph=True)
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
p.mul_(1 - lr_t * wd_t)
exp_avg.lerp_(grad, 1 - beta1_t)
Expand All @@ -312,7 +353,7 @@ def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_
step_size = lr_t / bias1
p.add_(exp_avg / denom, alpha=-step_size)

@torch.compile(dynamic=False, fullgraph=True)
@_maybe_compile(dynamic=False, fullgraph=True)
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
# Nesterov momentum
Expand Down Expand Up @@ -437,11 +478,11 @@ def step(self):
TOTAL_BATCH_SIZE = 2**19 # ~524K tokens per optimizer step
EMBEDDING_LR = 0.6 # learning rate for token embeddings (Adam)
UNEMBEDDING_LR = 0.004 # learning rate for lm_head (Adam)
MATRIX_LR = 0.04 # learning rate for matrix parameters (Muon)
MATRIX_LR = 0.05 # learning rate for matrix parameters (Muon)
SCALAR_LR = 0.5 # learning rate for per-layer scalars (Adam)
WEIGHT_DECAY = 0.2 # cautious weight decay for Muon
ADAM_BETAS = (0.8, 0.95) # Adam beta1, beta2
WARMUP_RATIO = 0.0 # fraction of time budget for LR warmup
WARMUP_RATIO = 0.1 # 10% warmup
WARMDOWN_RATIO = 0.5 # fraction of time budget for LR warmdown
FINAL_LR_FRAC = 0.0 # final LR as fraction of initial

Expand All @@ -458,6 +499,18 @@ def step(self):
torch.cuda.manual_seed(42)
torch.set_float32_matmul_precision("high")
device = torch.device("cuda")

# Low-VRAM preset for GPUs < 6GB (e.g. GTX 1050 Ti 4GB) - SDPA materializes full attn matrices
VRAM_GB = torch.cuda.get_device_properties(0).total_memory / 1e9
if VRAM_GB < 6:
DEVICE_BATCH_SIZE = 32
TOTAL_BATCH_SIZE = 2**14 # 16K tokens per step
WINDOW_PATTERN = "SSSL" # banded attention (half context in S layers)
TRAIN_SEQ_LEN = min(MAX_SEQ_LEN, 256) # SDPA OOMs on long seqs
DEPTH = 4
print(f"Low-VRAM mode: {VRAM_GB:.1f}GB GPU detected, using seq_len={TRAIN_SEQ_LEN}, batch={DEVICE_BATCH_SIZE}, depth={DEPTH}")
else:
TRAIN_SEQ_LEN = MAX_SEQ_LEN
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
H100_BF16_PEAK_FLOPS = 989.5e12

Expand All @@ -470,7 +523,7 @@ def build_model_config(depth):
model_dim = ((base_dim + HEAD_DIM - 1) // HEAD_DIM) * HEAD_DIM
num_heads = model_dim // HEAD_DIM
return GPTConfig(
sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size,
sequence_len=TRAIN_SEQ_LEN, vocab_size=vocab_size,
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
window_pattern=WINDOW_PATTERN,
)
Expand All @@ -491,7 +544,7 @@ def build_model_config(depth):
num_flops_per_token = model.estimate_flops()
print(f"Estimated FLOPs per token: {num_flops_per_token:e}")

tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
tokens_per_fwdbwd = DEVICE_BATCH_SIZE * TRAIN_SEQ_LEN
assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd

Expand All @@ -504,9 +557,10 @@ def build_model_config(depth):
weight_decay=WEIGHT_DECAY,
)

model = torch.compile(model, dynamic=False)
if USE_TORCH_COMPILE:
model = torch.compile(model, dynamic=False)

train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, TRAIN_SEQ_LEN, "train")
x, y, epoch = next(train_loader) # prefetch first batch

print(f"Time budget: {TIME_BUDGET}s")
Expand Down Expand Up @@ -609,7 +663,8 @@ def get_weight_decay(progress):
# Final eval
model.eval()
with autocast_ctx:
val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE,
seq_len=TRAIN_SEQ_LEN, max_steps=50 if VRAM_GB < 6 else None)

# Final summary
t_end = time.time()
Expand Down
Loading