Integrate Muon optimizer (2725) #2803
Conversation
Thanks for the first pass! Let's split the single-device and distributed versions of Muon into two separate files to improve readability. Speaking of plots: we will need comparison performance plots against AdamW, general Wandb loss plots, and results on the evaluation recipe.
I've added a few comments, but it looks great! There are two things we might need to think about, though:
Based on the above comments, two things came to my mind:
Muon will be assigned only to param1 and param2, while AdamW will be assigned to all remaining parameters.
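To make that concrete, here is a minimal sketch of the kind of per-parameter split being described; the parameter names and the Muon constructor used here are placeholders, not the final torchtune API.

import torch
import torch.nn as nn

# Hypothetical sketch of splitting parameters between Muon and AdamW.
# The names in muon_names stand in for "param1"/"param2" from the comment above.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
muon_names = {"0.weight", "1.weight"}

muon_params = [p for n, p in model.named_parameters() if n in muon_names]
adam_params = [p for n, p in model.named_parameters() if n not in muon_names]

# Every remaining parameter (the biases here) falls back to AdamW.
adam_opt = torch.optim.AdamW(adam_params, lr=3e-4, betas=(0.9, 0.95))
# muon_opt = Muon(muon_params, lr=0.02, momentum=0.95)  # assumed Muon signature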
This is awesome, thanks for porting this over!
The TL;DR of my thoughts is that we should optimize for the AuxAdam varieties of the Muon optimizer so that our code stays clean and organized. I made some suggestions so that instantiating these classes is easy from code or a config.
torchtune/modules/muon.py (outdated)
Can you move this to our optim.py file, where OptimizerInBackward lives?
Ah, I see Mark may have suggested two files, but I'd prefer to limit everything to optim.py until we see the need for new files. Our library has also gotten a little bloated.
torchtune/modules/muon.py (outdated)
class MuonWithAuxAdam(torch.optim.Optimizer):
I think we should optimize for this interface so that users don't need to specify multiple optimizers in their code and we can remove the messy branching logic.
My rough proposal is:
class MuonWithAuxAdam(torch.optim.Optimizer):
    def __init__(
        self,
        named_params,  # iterable of (name, param) pairs, e.g. model.named_parameters()
        *,
        muon_selector=None,
        muon_lr: float = 0.02,
        muon_momentum: float = 0.95,
        adam_lr: float = 3e-4,
        adam_betas=(0.9, 0.95),
        adam_eps: float = 1e-10,
        weight_decay: float = 0.0,
    ):
        # Materialize the generator so we can iterate over it twice.
        named_params = list(named_params)
        if muon_selector is None:
            # Default: Muon handles 2D+ weights; Adam handles embeddings, output, and 1D params.
            muon_selector = (
                lambda name, p: p.ndim >= 2
                and "tok_embeddings" not in name
                and "output" not in name
            )
        muon_params = [p for n, p in named_params if muon_selector(n, p)]
        adam_params = [p for n, p in named_params if not muon_selector(n, p)]
        muon_params.sort(key=lambda p: p.size(), reverse=True)
        super().__init__(
            [
                dict(
                    params=muon_params,
                    lr=muon_lr,
                    momentum=muon_momentum,
                    weight_decay=weight_decay,
                    use_muon=True,
                ),
                dict(
                    params=adam_params,
                    lr=adam_lr,
                    betas=adam_betas,
                    eps=adam_eps,
                    weight_decay=weight_decay,
                    use_muon=False,
                ),
            ],
            defaults={},
        )
What do you think?
Also cc @ebsmothers for thoughts.
cc @krammnic, I think this is the easiest way to go. WDYT?
Sounds good, I've seen a similar approach in other implementations, like https://github.com/ethansmith2000/fsdp_optimizers/blob/main/fsdp_optimizers/muon.py
@joecummings @krammnic: I have added Muon to optim.py.
There is a custom implementation of Adam in the Muon class. I tried using the existing PyTorch implementation, with the goal of reusing pre-existing optimizer implementations for the linear layers, but this is not possible because load_state_dict() supports storing only a single optimizer. However, I believe load_state_dict() in OptimizerInBackward supports multiple optimizers; please correct me if I am wrong. I have tried reducing the Muon checks, but there is still one Muon check in the main file. Also, I have updated the get_lr() method to return the learning rate of Muon only. Please let me know if this is correct.
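For illustration only (not the actual torchtune code), assuming the param-group layout from the MuonWithAuxAdam proposal above, where each group carries a use_muon flag, a get_lr() helper that reports only the Muon learning rate could look like this:

import torch

# Hypothetical sketch: return only the Muon group's learning rate, assuming
# each param group carries a use_muon flag as in the proposal above.
def get_lr(optimizer: torch.optim.Optimizer) -> float:
    for group in optimizer.param_groups:
        if group.get("use_muon", False):
            return group["lr"]
    # Fall back to the first param group if no Muon group is present.
    return optimizer.param_groups[0]["lr"]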
Then something is wrong; I don't like that we see worse performance, since the loss might be fixable with some HPO. I will review your changes tomorrow to maybe find a bottleneck...
Hi @joecummings @krammnic, I have updated the Muon optimizer implementation. Regarding the loss plot below: as suggested earlier, switching to another implementation and playing around with HPO helped.
Context
What is the purpose of this PR? Is it to add a new feature, fix a bug, update tests and/or documentation, or something else?
Please link to any issues this PR addresses.
#2725
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (make sure you've first installed them via pre-commit install)
- run unit tests via pytest tests
- run recipe tests via pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.