
Integrate Muon optimizer (2725) #2803


Open · Saurabh750 wants to merge 12 commits into base: main
Conversation

@Saurabh750 (Contributor) commented Jun 8, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.
#2725

Changelog

What are the changes made in this PR?

  • Integrate the Muon optimizer as a PyTorch implementation in torchtune.
  • Modify the recipes accordingly.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot (bot) commented Jun 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2803


@facebook-github-bot added the CLA Signed label on Jun 8, 2025
@krammnic (Collaborator) commented Jun 8, 2025

Thanks for the first pass! Let's split the single-device and distributed versions of Muon into two separate files to improve readability. Regarding plots: we will need comparison performance plots against AdamW, general Wandb loss plots, and results from the evaluation recipe.

@krammnic (Collaborator) commented Jun 8, 2025

I've added a few comments, but it looks great! There are two things we might need to think about, though:

  1. Can we reduce the number of "muon checks"? Maybe some special wrapper, similar to a fused optimizer?
  2. Maybe we should implement it in a slightly cleaner way and support it through builders?

@joecummings self-requested a review on June 9, 2025 at 17:40
@Saurabh750 (Contributor, Author) commented Jun 10, 2025

Based on the above comments, two things came to mind:

  1. As @krammnic suggested, I can implement a fused optimizer: a wrapper around Muon where the second optimizer for the linear layers is of the user's choice. This would eliminate the Muon checks and be a cleaner approach.
  2. If we want to provide more flexibility when assigning optimizers, we can do it at the parameter level.
    E.g., inside the config file:
optimizer:
  muon: [param1, param2]
  AdamW: []
muon:
  _component_: torchtune.modules.SingleDeviceMuon
  momentum: 0.95
  lr: 0.02
  weight_decay: 0
AdamW:
  _component_: bitsandbytes.optim.PagedAdamW
  lr: 1e-5

Muon would be assigned only to param1 and param2, while AdamW would be assigned to all remaining parameters (a rough sketch of how this could be resolved is included below).
Please do let me know your views, and I'll implement whichever is suitable!
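
A minimal sketch of how such a parameter-level assignment could be resolved (the helper name build_param_group_optimizers and its signature are hypothetical, not part of this PR; the factories would be built from the _component_ entries in the config):

from typing import Callable, Dict, List

import torch
from torch import nn


def build_param_group_optimizers(
    model: nn.Module,
    assignment: Dict[str, List[str]],  # e.g. {"muon": ["param1", "param2"], "AdamW": []}
    factories: Dict[str, Callable[[List[nn.Parameter]], torch.optim.Optimizer]],
) -> Dict[str, torch.optim.Optimizer]:
    """Give explicitly listed parameters to their optimizer; the optimizer whose
    list is empty acts as the default for all remaining parameters."""
    default = next(name for name, names in assignment.items() if not names)
    buckets: Dict[str, List[nn.Parameter]] = {name: [] for name in assignment}

    for param_name, param in model.named_parameters():
        owner = next(
            (name for name, names in assignment.items() if param_name in names),
            default,
        )
        buckets[owner].append(param)

    return {name: factories[name](params) for name, params in buckets.items() if params}

The recipe would then have to call step()/zero_grad() on each returned optimizer, which is part of what makes the wrapper approach in point 1 appealing.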

@joecummings (Contributor) left a comment:

This is awesome, thanks for porting this over!

The TL;DR of my thoughts is that we should optimize for the AuxAdam varieties of the Muon optimizer so that our code stays clean and organized. I made some suggestions so that the instantiation of these classes is easy from code or config.

@@ -0,0 +1,253 @@
######################################################
Contributor:

Can you move this to our optim.py file where OptimizerInBackward lives?

Contributor:

Ah I see Mark maybe suggested two files, but I'd prefer to limit everything to optim.py until we see the need for new files. Our library has also gotten a little bloated.

return buf1c / (buf2c.sqrt() + eps)


class MuonWithAuxAdam(torch.optim.Optimizer):
Contributor:

I think we should optimize for this interface so that users don't need to specify multiple optimizers in their code and we can remove the messy branching logic.

My rough proposal is:

class MuonWithAuxAdam(torch.optim.Optimizer):
    def __init__(
        self,
        params,  # iterable of (name, param) pairs, e.g. model.named_parameters()
        *,
        muon_selector=None,
        muon_lr: float = 0.02,
        muon_momentum: float = 0.95,
        adam_lr: float = 3e-4,
        adam_betas=(0.9, 0.95),
        adam_eps: float = 1e-10,
        weight_decay: float = 0.0,
    ):
        if muon_selector is None:
            # Default: use Muon for 2D+ weights, excluding embeddings and the output head
            muon_selector = (
                lambda name, p: p.ndim >= 2
                and "tok_embeddings" not in name
                and "output" not in name
            )

        named_params = list(params)
        muon_params = [p for n, p in named_params if muon_selector(n, p)]
        adam_params = [p for n, p in named_params if not muon_selector(n, p)]

        # Sort Muon params largest-first, as in the reference Muon implementation
        muon_params.sort(key=lambda p: p.size(), reverse=True)

        super().__init__(
            [
                dict(params=muon_params,
                     lr=muon_lr,
                     momentum=muon_momentum,
                     weight_decay=weight_decay,
                     use_muon=True),
                dict(params=adam_params,
                     lr=adam_lr,
                     betas=adam_betas,
                     eps=adam_eps,
                     weight_decay=weight_decay,
                     use_muon=False),
            ],
            defaults={},
        )

What do you think?

Also cc @ebsmothers for thoughts.
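
For illustration, a minimal usage sketch of the proposed constructor (the toy nn.ModuleDict model here is a placeholder, not part of the PR; it just mimics torchtune's tok_embeddings/output naming so the default selector has something to filter):

import torch.nn as nn

# Toy model whose parameter names mimic torchtune conventions
model = nn.ModuleDict(
    {
        "tok_embeddings": nn.Embedding(100, 32),
        "layers": nn.ModuleList([nn.Linear(32, 32) for _ in range(2)]),
        "output": nn.Linear(32, 100),
    }
)

opt = MuonWithAuxAdam(model.named_parameters(), muon_lr=0.02, adam_lr=3e-4)

# Expect one Muon group (the 2D layer weights) and one Adam group
# (embedding, output head, and biases)
for group in opt.param_groups:
    print(group["use_muon"], group["lr"], len(group["params"]))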

Contributor:

cc @krammnic, I think this is the easiest way to go. WDYT?

Contributor:

Sounds good, I've seen a similar approach in other implementations, like https://github.com/ethansmith2000/fsdp_optimizers/blob/main/fsdp_optimizers/muon.py

@Saurabh750 (Contributor, Author) commented Jun 17, 2025

@joecummings @krammnic: I have added Muon to optim.py.
I tried fine-tuning on the Alpaca dataset using Qwen2-0.5B and compared AdamW and Muon.
I trained for 5, 10, and 20 epochs with batch sizes of 5 and 10.
In all the experiments, AdamW performed better than Muon. I am attaching a snippet from one of the experiments with batch size 10 and 20 epochs:

[image: AdamW vs. Muon loss curves]

  • Should I try running the same experiments on a different model? According to this, the results we are getting are along similar lines: AdamW performs better than Muon when the model was not pretrained with Muon.

There is a custom implementation of Adam inside the Muon class. I tried using the existing PyTorch implementation, with a view to allowing any pre-existing optimizer implementation for the linear layers, but this does not seem possible because load_state_dict() supports storing only a single optimizer. However, I believe load_state_dict() in OptimizerInBackward supports multiple optimizers. Please correct me if I am wrong.
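
A minimal sketch (the wrapper name and structure are assumptions for illustration, not code from this PR or from OptimizerInBackward) of how a wrapper around two separate optimizers could still expose a single state_dict/load_state_dict by keying the inner optimizers' states:

import torch


class TwoOptimizerWrapper:
    """Hypothetical wrapper holding a Muon optimizer and an auxiliary optimizer."""

    def __init__(self, muon_opt: torch.optim.Optimizer, aux_opt: torch.optim.Optimizer):
        self.optimizers = {"muon": muon_opt, "aux": aux_opt}

    def step(self) -> None:
        for opt in self.optimizers.values():
            opt.step()

    def zero_grad(self, set_to_none: bool = True) -> None:
        for opt in self.optimizers.values():
            opt.zero_grad(set_to_none=set_to_none)

    def state_dict(self) -> dict:
        # Store each inner optimizer's state under its own key
        return {name: opt.state_dict() for name, opt in self.optimizers.items()}

    def load_state_dict(self, state_dict: dict) -> None:
        for name, opt in self.optimizers.items():
            opt.load_state_dict(state_dict[name])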

I have tried reducing the Muon checks, but there is still one Muon check in the main file. Also, I have updated the get_lr() method to return only the Muon learning rate (a sketch of the idea is below). Please suggest whether this is correct.
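
A minimal sketch of the idea (an assumption about the approach, not the PR's exact code), relying on a use_muon flag in the param groups as in the proposal above:

import torch


def get_lr(optimizer: torch.optim.Optimizer) -> float:
    """Return the learning rate of the Muon param group, if one exists."""
    for group in optimizer.param_groups:
        if group.get("use_muon", False):
            return group["lr"]
    # Fall back to the first param group when no Muon group is present
    return optimizer.param_groups[0]["lr"]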

@krammnic (Collaborator) commented:

Then something is wrong; I don't like the fact that we have worse performance, because the loss might be fixable with some HPO. I will review your changes tomorrow to maybe find a bottleneck...

@Saurabh750 (Contributor, Author) commented:

Hi @joecummings @krammnic, I have updated the Muon optimizer implementation.
In the image below:
Green -> Muon optimizer for the first 20 epochs with lr 5e-4
Blue -> Muon optimizer for epochs 20 to 40 with lr 5e-5

[image: Muon loss curves]

For the image below:
Blue -> Muon optimizer for epochs 20 to 40 with lr 5e-5
Red -> AdamW optimizer for epochs 0 to 20 with lr 2e-5

[image: Muon vs. AdamW loss curves]

As suggested earlier, switching to another implementation and playing around with HPO helped.
In the lr_scheduler, I am only returning the Muon learning rate. Please let me know if anything else needs attention.
