diff --git a/train.py b/train.py
index 2e743974c..2b6d631ad 100644
--- a/train.py
+++ b/train.py
@@ -445,6 +445,7 @@ def step(self):
 WARMUP_RATIO = 0.0 # fraction of time budget for LR warmup
 WARMDOWN_RATIO = 0.5 # fraction of time budget for LR warmdown
 FINAL_LR_FRAC = 0.0 # final LR as fraction of initial
+GRAD_CLIP_NORM = 1.0 # max gradient norm (0.0 to disable)
 
 # Model size
 DEPTH = 8 # number of transformer layers
@@ -561,6 +562,8 @@ def get_weight_decay(progress):
         if group['kind'] == 'muon':
             group["momentum"] = muon_momentum
             group["weight_decay"] = muon_weight_decay
+    if GRAD_CLIP_NORM > 0.0:
+        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
     optimizer.step()
     model.zero_grad(set_to_none=True)
 