diff --git a/train.py b/train.py
index 2e743974c..fd556c99a 100644
--- a/train.py
+++ b/train.py
@@ -21,9 +21,10 @@
 cap = torch.cuda.get_device_capability()
 # varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs
 repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
-fa3 = get_kernel(repo).flash_attn_interface
+fa3 = get_kernel(repo).flash_attn_interface if cap >= (8, 0) else None
 
-from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb
+from prepare import MAX_SEQ_LEN as _MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb
+MAX_SEQ_LEN = _MAX_SEQ_LEN if cap >= (8, 0) else 512
 
 # ---------------------------------------------------------------------------
 # GPT Model
@@ -90,7 +91,12 @@ def forward(self, x, ve, cos_sin, window_size):
         q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
         q, k = norm(q), norm(k)
-        y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
+        if fa3 is not None:
+            y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
+        else:
+            q2, k2, v2 = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+            y = F.scaled_dot_product_attention(q2, k2, v2, is_causal=True)
+            y = y.transpose(1, 2)
         y = y.contiguous().view(B, T, -1)
         y = self.c_proj(y)
         return y
 
@@ -176,9 +182,10 @@ def init_weights(self):
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.cos, self.sin = cos, sin
         # Cast embeddings to bf16
-        self.transformer.wte.to(dtype=torch.bfloat16)
+        embed_dtype = torch.bfloat16 if torch.cuda.get_device_capability() >= (8, 0) else torch.float32
+        self.transformer.wte.to(dtype=embed_dtype)
         for ve in self.value_embeds.values():
-            ve.to(dtype=torch.bfloat16)
+            ve.to(dtype=embed_dtype)
 
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         if device is None:
@@ -188,7 +195,8 @@ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=No
         t = torch.arange(seq_len, dtype=torch.float32, device=device)
         freqs = torch.outer(t, inv_freq)
         cos, sin = freqs.cos(), freqs.sin()
-        cos, sin = cos.bfloat16(), sin.bfloat16()
+        cos_dtype = torch.bfloat16 if torch.cuda.get_device_capability() >= (8, 0) else torch.float32
+        cos, sin = cos.to(cos_dtype), sin.to(cos_dtype)
         cos, sin = cos[None, :, None, :], sin[None, :, None, :]
         return cos, sin
 
@@ -321,7 +329,7 @@ def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momen
     momentum_buffer.lerp_(stacked_grads, 1 - momentum)
     g = stacked_grads.lerp_(momentum_buffer, momentum)
     # Polar express orthogonalization
-    X = g.bfloat16()
+    X = g.bfloat16() if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0) else g.half()
     X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
     if g.size(-2) > g.size(-1):
         for a, b, c in polar_express_coeffs[:ns_steps]:
@@ -435,7 +443,7 @@ def step(self):
 WINDOW_PATTERN = "SSSL" # sliding window pattern: L=full, S=half context
 
 # Optimization
-TOTAL_BATCH_SIZE = 2**19 # ~524K tokens per optimizer step
+TOTAL_BATCH_SIZE = 2**19 if cap >= (8, 0) else 2**16 # ~524K tokens, reduced for pre-Ampere GPUs
 EMBEDDING_LR = 0.6 # learning rate for token embeddings (Adam)
 UNEMBEDDING_LR = 0.004 # learning rate for lm_head (Adam)
 MATRIX_LR = 0.04 # learning rate for matrix parameters (Muon)
@@ -448,7 +456,7 @@ def step(self):
 
 # Model size
 DEPTH = 8 # number of transformer layers
-DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM)
+DEVICE_BATCH_SIZE = 128 if cap >= (8, 0) else 16 # reduced for pre-Ampere GPUs
 
 # ---------------------------------------------------------------------------
 # Setup: tokenizer, model, optimizer, dataloader
@@ -459,7 +467,9 @@ def step(self):
 torch.cuda.manual_seed(42)
 torch.set_float32_matmul_precision("high")
 device = torch.device("cuda")
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+amp_dtype = torch.bfloat16 if cap >= (8, 0) else torch.float16
+autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
+scaler = torch.amp.GradScaler("cuda") if amp_dtype == torch.float16 else None
 H100_BF16_PEAK_FLOPS = 989.5e12
 
 tokenizer = Tokenizer.from_directory()
@@ -548,7 +558,7 @@ def get_weight_decay(progress):
             loss = model(x, y)
         train_loss = loss.detach()
         loss = loss / grad_accum_steps
-        loss.backward()
+        (scaler.scale(loss) if scaler else loss).backward()
         x, y, epoch = next(train_loader)
 
     # Progress and schedules
@@ -561,7 +571,11 @@ def get_weight_decay(progress):
         if group['kind'] == 'muon':
             group["momentum"] = muon_momentum
             group["weight_decay"] = muon_weight_decay
-    optimizer.step()
+    if scaler:
+        scaler.step(optimizer)
+        scaler.update()
+    else:
+        optimizer.step()
     model.zero_grad(set_to_none=True)
 
     train_loss_f = train_loss.item()
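Two notes on the attention fallback above. First, fa3.flash_attn_func consumes tensors laid out as (batch, seq, heads, head_dim), while F.scaled_dot_product_attention expects (batch, heads, seq, head_dim), hence the transposes around the call. Second, SDPA has no window_size argument, so the fallback runs full causal attention and the sliding-window ("S") layers lose their window constraint on pre-Ampere GPUs. A minimal layout check, with illustrative shapes rather than values from train.py:

    import torch
    import torch.nn.functional as F

    # Illustrative shapes in the flash-attn layout (batch, seq, heads, head_dim)
    B, T, H, D = 2, 16, 4, 32
    q, k, v = (torch.randn(B, T, H, D) for _ in range(3))

    # SDPA wants (batch, heads, seq, head_dim), so transpose in and back out
    y = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    ).transpose(1, 2)
    assert y.shape == (B, T, H, D)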
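The fp16 path is the standard torch.amp loss-scaling recipe for GPUs without hardware bf16 (compute capability below (8, 0)): scale the loss before backward so fp16 gradients don't underflow, then step through the scaler, which un-scales gradients and skips steps containing inf/NaN. Below is a minimal sketch of the pattern, assuming PyTorch 2.3+ (where torch.amp.GradScaler takes a device string, as the patch does); the model, optimizer, and loss here are stand-ins, not objects from train.py:

    import torch

    cap = torch.cuda.get_device_capability()
    amp_dtype = torch.bfloat16 if cap >= (8, 0) else torch.float16
    scaler = torch.amp.GradScaler("cuda") if amp_dtype == torch.float16 else None

    model = torch.nn.Linear(512, 512).cuda()           # stand-in model
    optimizer = torch.optim.AdamW(model.parameters())  # stand-in optimizer

    x = torch.randn(8, 512, device="cuda")
    with torch.amp.autocast(device_type="cuda", dtype=amp_dtype):
        loss = model(x).square().mean()  # stand-in loss

    (scaler.scale(loss) if scaler else loss).backward()
    if scaler:
        scaler.step(optimizer)  # un-scales grads, skips the step on inf/NaN
        scaler.update()         # re-tunes the scale factor for the next step
    else:
        optimizer.step()
    model.zero_grad(set_to_none=True)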