Integrate Muon optimizer (2725) #2803

Open · wants to merge 12 commits into main
112 changes: 112 additions & 0 deletions recipes/configs/qwen2/0.5B_full_single_device_muon.yaml
@@ -0,0 +1,112 @@
# Config for single device full finetuning in full_finetune_single_device.py
# using a Qwen2 0.5B model with the Muon optimizer
#
# This config assumes that you've run the following command before launching
# this run:
# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct
#
# To launch on a single device, run the following command from root:
# tune run full_finetune_single_device --config qwen2/0.5B_full_single_device_muon
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run full_finetune_single_device --config qwen2/0.5B_full_single_device_muon checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
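#
# As a hypothetical example (override keys taken from the optimizer section below), the Muon
# hyperparameters can be overridden the same way:
#   tune run full_finetune_single_device --config qwen2/0.5B_full_single_device_muon optimizer.muon_lr=0.01 optimizer.muon_momentum=0.9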
#
# This config works only for training on single device.

output_dir: /tmp/torchtune/qwen2_0_5B/full_single_device # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False # True increases speed
seed: null
shuffle: False

# Model Arguments
model:
  _component_: torchtune.models.qwen2.qwen2_0_5b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 3
epochs: 1
optimizer:
  _component_: torchtune.modules.optim.SingleDeviceMuonWithAuxAdam
  muon_lr: 0.02
  muon_momentum: 0.95
  weight_decay: 0
  adam_lr: 2e-5
  adam_betas: [0.9, 0.95]
  adam_eps: 1e-10
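  # Note (from torchtune/modules/optim.py in this PR): the default selector routes 2-D weight
  # matrices to Muon, while embeddings, the output head, biases, and other 1-D parameters fall
  # back to the internal Adam-style update, which is why both muon_* and adam_* settings appear here.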

loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1

max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
log_level: INFO # DEBUG, WARN, etc.


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
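For reference, a minimal sketch (not part of the PR) of how the optimizer section above maps onto the Muon class added in torchtune/modules/optim.py below. The config's _component_ points at SingleDeviceMuonWithAuxAdam while the class shown in this diff is named Muon, so the sketch uses the latter, and the keyword names follow that class's signature rather than the YAML keys:

from torchtune.models.qwen2 import qwen2_0_5b
from torchtune.modules.optim import Muon

model = qwen2_0_5b()
# Muon takes (name, param) pairs so its selector can send embeddings/head/biases to the AdamW fallback.
optimizer = Muon(
    model.named_parameters(),
    lr=0.02,                  # muon_lr in this config
    momentum=0.95,            # muon_momentum
    adamw_lr=2e-5,            # adam_lr
    adamw_betas=[0.9, 0.95],  # adam_betas
    adamw_eps=1e-10,          # adam_eps
    adamw_wd=0,               # weight_decay
)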
5 changes: 4 additions & 1 deletion recipes/full_finetune_single_device.py
@@ -437,8 +437,10 @@ def _setup_optimizer(
                **cfg_optimizer,
            )
        else:
            optimizer_cls = cfg_optimizer["_component_"]
            params = self._model.named_parameters() if 'muon' in optimizer_cls.lower() else self._model.parameters()
            optimizer = config.instantiate(
                cfg_optimizer, params=self._model.parameters()
                cfg_optimizer, params=params
            )
            if opt_state_dict:
                optimizer.load_state_dict(opt_state_dict)
@@ -604,6 +606,7 @@ def train(self) -> None:
                self.lr_scheduler.step()

                self.global_step += 1
                print(f"running_loss: {running_loss} ; num_tokens: {num_tokens}")
                loss_value = (
                    running_loss
                    / (num_tokens if not self.optimizer_in_bwd else 1.0)
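Outside the recipe, the routing added above amounts to the following sketch (stand-in model and component string; the real values come from the config):

import torch.nn as nn

model = nn.Sequential(nn.Embedding(100, 32), nn.Linear(32, 32))
optimizer_cls = "torchtune.modules.optim.SingleDeviceMuonWithAuxAdam"  # cfg_optimizer["_component_"]
# Muon variants need (name, param) pairs so the selector can filter by parameter name;
# every other optimizer gets the usual bare parameter iterator.
params = (
    model.named_parameters()
    if "muon" in optimizer_cls.lower()
    else model.parameters()
)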
203 changes: 203 additions & 0 deletions torchtune/modules/optim.py
@@ -8,6 +8,8 @@

import torch
from torch.optim import Optimizer
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor, DTensor

__all__ = ["OptimizerInBackward"]

@@ -82,3 +84,204 @@ def load_state_dict(self, state_dict):
        )
        for idx, opt in self._optimizers.items():
            opt.load_state_dict(state_dict["optimizers"][str(idx)])

class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-Schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - We believe this optimizer is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.

    Arguments:
        params: The parameters to be optimized, passed as (name, parameter) pairs.
        muon_selector: Optional callable (name, param) -> bool selecting which parameters get Muon updates.
        lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
        momentum: The momentum used by the internal SGD. (0.95 is a good default)
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
        adamw_lr: The learning rate for the internal AdamW.
        adamw_betas: The betas for the internal AdamW.
        adamw_eps: The epsilon for the internal AdamW.
        adamw_wd: The weight decay for the internal AdamW.
    """
    def __init__(self, params, muon_selector=None, lr=0.02, momentum=0.95, nesterov=True, ns_steps=6,
                 adamw_lr=3e-4, adamw_betas=[0.95, 0.95], adamw_eps=1e-8, adamw_wd=0):

        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps,
                        adamw_lr_ratio=adamw_lr/lr, adamw_betas=adamw_betas,
                        adamw_eps=adamw_eps, adamw_wd=adamw_wd)

        if muon_selector is None:
            muon_selector = lambda name, param: (
                param.requires_grad and
                param.ndim >= 2 and               # exclude scalars and 1-D params
                "embed" not in name.lower() and   # exclude embedding layers
                "tok" not in name.lower() and     # exclude token embeddings
                "head" not in name.lower() and    # exclude the output head
                "bias" not in name.lower()        # exclude bias terms
            )

        named_params = list(params)

        muon_params = [p for n, p in named_params if muon_selector(n, p)]
        adamw_params = [p for n, p in named_params if not muon_selector(n, p)]

        super().__init__([*muon_params, *adamw_params], defaults)

        # Sort parameters into those for which we will use Muon, and those for which we will not.
        # We can't pickle booleans for saving, so we use 1=True, 0=False.
        def assign_muon(p):
            if p.ndim >= 2 and p.size(0) < 10000:
                self.state[p]['use_muon'] = 1
            else:
                self.state[p]['use_muon'] = 0

        if len(muon_params) and isinstance(muon_params[0], dict):
            for group in muon_params:
                for p in group['params']:
                    assign_muon(p)
        else:
            for p in muon_params:
                assign_muon(p)

        def assign_adamw(p):
            # Do not use Muon for parameters in adamw_params
            self.state[p]['use_muon'] = 0

        if len(adamw_params) and isinstance(adamw_params[0], dict):
            for group in adamw_params:
                for p in group['params']:
                    assign_adamw(p)
        else:
            for p in adamw_params:
                assign_adamw(p)

        if torch.distributed.is_initialized():
            self.world_size = torch.distributed.get_world_size()
            self.rank = torch.distributed.get_rank()
        else:
            self.world_size = 1
            self.rank = 0

    def to_dist(self, x, from_local=False, **meta):
        if from_local:
            return DTensor.from_local(
                x,
                device_mesh=meta["device_mesh"],
                placements=meta["placements"],
                shape=meta["shape"],
                stride=meta["stride"],
            )
        else:
            return distribute_tensor(x, device_mesh=meta["device_mesh"], placements=meta["placements"])


    def to_local(self, x, keep_sharded=False):
        if isinstance(x, DTensor):
            meta = dict(
                device_mesh=x.device_mesh,
                placements=x.placements,
                shape=x.shape,
                stride=x.stride(),
            )
            if keep_sharded:
                return x.to_local(), meta
            else:
                return x.full_tensor(), meta

        return x, None

    def zeropower_via_newtonschulz5(self, G, steps=10, eps=1e-7):
        """
        Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
        quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
        of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
        zero even beyond the point where the iteration no longer converges all the way to one everywhere
        on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
        where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
        performance at all relative to UV^T, where USV^T = G is the SVD.
        """
        assert len(G.shape) == 2
        a, b, c = (3.4445, -4.7750, 2.0315)
        X = G.bfloat16()
        X /= (X.norm() + eps)  # ensure top singular value <= 1
        if G.size(0) > G.size(1):
            X = X.T
        for _ in range(steps):
            A = X @ X.T
            B = b * A + c * A @ A
            X = a * X + B @ X
        if G.size(0) > G.size(1):
            X = X.T
        return X

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group['momentum']
            for i, p in enumerate(group['params']):
                if self.state[p]['use_muon'] == 1:
                    g = p.grad
                    if g is None:
                        continue
                    if g.ndim > 2:
                        g = g.view(g.size(0), -1)
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(g)
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(g)
                    if group['nesterov']:
                        g = g.add(buf, alpha=momentum)

                    meta = None
                    if isinstance(g, DTensor):
                        # Newton-Schulz on a DTensor silently produces NaNs instead of raising an
                        # unsupported-op error, so gather to a full local tensor first.
                        g, meta = self.to_local(g, keep_sharded=False)
                    g = self.zeropower_via_newtonschulz5(g, steps=group['ns_steps'])
                    if meta is not None:
                        g = self.to_dist(g, **meta)
                    g *= max(1, g.size(0)/g.size(1))**0.5

                    g = g.view_as(p.data).type_as(p.data)
                    p.data.add_(g, alpha=-lr)
                else:
                    # These updates are all pointwise, so we can stay in DTensor.
                    g = p.grad
                    if g is None:
                        continue
                    state = self.state[p]
                    if 'step' not in state:
                        state['step'] = 0
                        state['moment1'] = torch.zeros_like(g)
                        state['moment2'] = torch.zeros_like(g)
                    state['step'] += 1
                    step = state['step']
                    buf1 = state['moment1']
                    buf2 = state['moment2']
                    buf1.lerp_(g, 1 - group['adamw_betas'][0])
                    buf2.lerp_(g.square(), 1 - group['adamw_betas'][1])

                    g = buf1 / (group['adamw_eps'] + buf2.sqrt())

                    bias_correction1 = 1 - group['adamw_betas'][0]**step
                    bias_correction2 = 1 - group['adamw_betas'][1]**step
                    scale = bias_correction1 / bias_correction2**0.5
                    # Recover the AdamW learning rate from the ratio stored in `defaults` so these
                    # parameters are not stepped with the Muon lr.
                    adamw_lr = group['adamw_lr_ratio'] * lr
                    p.data.mul_(1 - adamw_lr * group['adamw_wd'])
                    p.data.add_(g, alpha=-adamw_lr/scale)

        return loss
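As a quick sanity check of the Newton-Schulz orthogonalization described in the docstrings above, one could run something like the following (a sketch, not part of the PR; the helper is called directly on a throwaway optimizer instance):

import torch
from torchtune.modules.optim import Muon

opt = Muon([("layer.weight", torch.nn.Parameter(torch.randn(64, 64)))])
G = torch.randn(128, 64)
O = opt.zeropower_via_newtonschulz5(G, steps=6).float()
# The iteration pushes all singular values toward 1 (roughly into [0.5, 1.5] per the docstring),
# so the result is approximately orthogonal rather than exactly UV^T.
print(torch.linalg.svdvals(O))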
3 changes: 3 additions & 0 deletions torchtune/training/lr_schedulers.py
@@ -10,6 +10,7 @@
import torch
from torch.optim.lr_scheduler import LambdaLR
from torchtune.training.memory import OptimizerInBackwardWrapper
from torchtune.modules.optim import Muon


def get_cosine_schedule_with_warmup(
@@ -88,7 +89,9 @@ def get_lr(
    )

    # LR Schedulers are the same across all param groups for full_finetune right now

    lr = param_groups[0]["lr"]
    if isinstance(optimizer, Muon):
        return lr  # return Muon learning rate if Muon optimizer
    for group in param_groups:
        if group["lr"] != lr:
            raise RuntimeError("LR Schedulers are different across all param groups ")
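A small usage sketch of the behavior above (not part of the PR), assuming `get_lr` is importable from `torchtune.training.lr_schedulers` as in this file: with a Muon optimizer the function short-circuits and reports the Muon learning rate instead of checking that every param group shares one lr.

import torch.nn as nn
from torchtune.modules.optim import Muon
from torchtune.training.lr_schedulers import get_lr

model = nn.Linear(16, 16, bias=False)
opt = Muon(model.named_parameters(), lr=0.02)
print(get_lr(opt))  # 0.02, returned before the per-group consistency check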