[muon] Introduce Muon optimizer to PyTorch #160213
Conversation
Wooo the approach is much simpler indeed now--thank you for the speedy turnaround on the PR. The one main API question I have is how we handle the NS config (whether it should be in the constructor) which I've commented below. Everything else looks super solid.
I know you've put in amazing work for the benchmarks and correctness compared to the original Muon, and I trust that you have verified this PR is still correct and appropriately fast locally. I will look out for the separate PR with those scripts!
test/test_optim.py
Outdated
params = [weight, bias]
if optim_cls.__name__ == "Muon":
    params = [weight]
nit: avoid reassigning
Suggested change:
params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight]
test/test_optim.py
Outdated
model = torch.nn.Sequential(
    torch.nn.Linear(10, 4, bias=False),
This can just be one Linear then, right? Or maybe it'd be more indicative to add another Linear in there?
Can you add a comment for why we branch here?
test/test_optim.py
Outdated
@@ -1577,14 +1629,26 @@ def test_can_load_from_to_named_state_dict(
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
    device, dtype, optim_info, skip=("differentiable",)
)

def _get_model_and_input(device, dtype, optim_cls):
let's only have one version of this helper, it looks the same as above
torch/optim/_muon.py
Outdated
    nesterov: bool = True,
    *,
    msign_fn: MsignFn = zeropower_via_newtonschulz,
    msign_fn_config: BaseMsignFnConfig = NewtonSchulzConfig(),
What is the pro of having this config live in the constructor as a struct vs. separate values? Is this because these values are only used if the msign_fn is zeropower_via_newtonschulz? If so, should this config not live in the Muon constructor at all, but instead be customizable via the user-provided msign_fn? What are your thoughts?
I also wonder if this config could be a regular dict, accepted in the constructor as Muon(..., msign_fn_config={'eps': 1e-5}) and then passed along as self.msign_fn(..., **self.msign_fn_config). That way it could be saved into state_dict() more easily.
...
I think it’s better to encapsulate the configs in a dedicated class, so the function signature stays clean and manageable. Just a preference carried over from my C++ days :)
I see Vadim's point, but I'm not sure it's feasible (or necessary) to store the callable in state_dict() in the first place.
Another option is to have the config as plain arguments on the function, and then have the user override them via functools.partial.
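A rough sketch of that option (the import path and keyword names here are assumptions based on this thread, not the PR's exact API):

import functools
import torch
from torch.optim._muon import Muon, zeropower_via_newtonschulz  # import path assumed from this PR

model = torch.nn.Linear(10, 4, bias=False)
# Bind the Newton-Schulz knobs once via partial application; the optimizer
# then only needs to call msign_fn(grad), with no separate config object.
my_msign = functools.partial(zeropower_via_newtonschulz, steps=7, eps=1e-7)
opt = Muon(model.parameters(), lr=0.02, msign_fn=my_msign)  # lr value is illustrative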
cc @albanD regarding API design for best practices
Ho that's interesting.
I do agree that this doesn't match how we do APIs in PyTorch in general.
For value configs, I would expect each to be passed in as its own argument (see the other optimizers).
If you need to override some specific methods and behavior, you can either have a set of pre-defined implementations that a flag toggles between or you can subclass the optimizer to override the particular method you care about.
Also I guess I'm missing some context on why we want to do it this way if there is only one option for each right now?
Force-pushed 5083654 to 0f0df7b
The current CI failures are because you (probably accidentally) committed the third_party differences -- please remove those!
Force-pushed 0f0df7b to 2e6bf8c
uh, they must have come from the rebase. removed
Force-pushed 1c82fc8 to 654f754
torch/optim/_muon.py
Outdated
assert steps < 100, (
    "Number of steps must be less than 100 for computational efficiency"
)
assert len(grad.shape) == 2, "Input tensor gradient must be a 2D matrix"
assert len(coefficients) == 3, "Coefficients must be a tuple of exactly 3 values"
No plain asserts, please raise appropriate Runtime/Value/Type errors
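For example, a sketch of the requested change applied to the asserts shown above:

if steps >= 100:
    raise ValueError("Number of steps must be less than 100 for computational efficiency")
if grad.ndim != 2:
    raise ValueError("Input tensor gradient must be a 2D matrix")
if len(coefficients) != 3:
    raise ValueError("Coefficients must be a tuple of exactly 3 values")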
torch/optim/_muon.py
Outdated
    has_complex: bool,
) -> None:
    lr = _to_scalar(lr)
    assert has_complex is False, "Complex parameters are not supported"
Let's remove plain asserts here as well
done!
Given that we had agreed to land the simplest single-device Muon into torch/optim as our first step, it'd be clearest to land what people accept as the original implementation, as defined in Keller Jordan's blog (https://kellerjordan.github.io/posts/muon/). As this implementation chooses Newton-Schulz as the algorithm, we should take the same stance. This means we can simplify the constructor API greatly (I will get to extensibility right after):
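A plausible sketch of such a simplified constructor, restricted to the knobs the final version exposes (the defaults shown here are illustrative, not the PR's actual values):

def __init__(
    self,
    params,
    lr=1e-2,            # illustrative default
    weight_decay=0.0,   # illustrative default
    momentum=0.95,      # illustrative default
    nesterov=True,
    *,
    ns_steps=5,                                  # Newton-Schulz iterations
    coefficients=(3.4445, -4.7750, 2.0315),      # quintic coefficients from Keller's post
    eps=1e-7,                                    # illustrative default
):
    ...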
I'm realizing that you have an interest in extending the algorithm to be distributed (vs. another orthogonalization algorithm for single-device). We are strict on keeping ...
Thank you, team, for the thoughtful suggestions and thorough review! I made the following updates according to the inputs, so that this version is simple and clear.
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Force-pushed 79b1460 to 7fc7a60
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
A single-device version of Muon. The algorithm follows Keller Jordan's Muon blog post and optionally incorporates Moonshot's learning rate adjustment strategy.
This implementation maintains a minimalist API and is consistent with other optimizer conventions. The PyTorch team prefers to handle parameter filtering at a higher level, with the Muon optimizer performing only the msign computation for orthogonalization on all parameters it receives. Users are responsible for grouping parameters for different optimizers as needed. An example usage is shown below, and a more detailed example will be added to the PyTorch examples directory.
Usage
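A minimal illustrative sketch of this grouping pattern (the module, hyper-parameters, and import path are placeholders and assumptions, not the PR's exact example):

import torch
from torch.optim._muon import Muon  # import path assumed from this PR

model = torch.nn.Sequential(
    torch.nn.Linear(10, 4, bias=False),
    torch.nn.Linear(4, 2),
)
# Muon only accepts 2D parameters; route everything else (biases, norms,
# embeddings, ...) to another optimizer such as AdamW, as described above.
muon_params = [p for p in model.parameters() if p.ndim == 2]
adamw_params = [p for p in model.parameters() if p.ndim != 2]
optimizers = [
    Muon(muon_params, lr=0.02),               # hyper-parameters are illustrative
    torch.optim.AdamW(adamw_params, lr=3e-4),
]

loss = model(torch.randn(8, 10)).sum()
loss.backward()
for opt in optimizers:
    opt.step()
    opt.zero_grad()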
Additional usage
Users were originally able to pass in a self-defined msign function for orthogonalization and a learning rate adjustment function. As discussed with the team and in the comments, we prefer to keep the interface simpler and cleaner, so we removed that callback interface and canonicalized the original Newton-Schulz algorithm for Muon. The only configs available to users are ns_steps, coefficients, and eps, configurable through kwargs. By default, we use 5-step Newton-Schulz with the coefficients proposed by Keller, and the learning rate adjustment proposed by Moonshot, which grafts the learning rate from AdamW.
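For reference, a self-contained sketch of the 5-step Newton-Schulz orthogonalization described in Keller's blog post (a paraphrase of the blog, not necessarily the exact code in this PR):

import torch

def newton_schulz_msign(g: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # Quintic Newton-Schulz iteration that pushes the singular values of g
    # towards 1, approximating msign(g) = U @ V^T from the reduced SVD of g.
    a, b, c = (3.4445, -4.7750, 2.0315)   # coefficients proposed by Keller
    x = g.bfloat16()
    transposed = g.size(0) > g.size(1)
    if transposed:                         # iterate on the wide orientation
        x = x.mT
    x = x / (x.norm() + eps)               # bound the spectral norm by ~1
    for _ in range(steps):
        s = x @ x.mT
        x = a * x + (b * s + c * s @ s) @ x
    if transposed:
        x = x.mT
    return x.to(g.dtype)                   # cast back; a convenience choice of this sketch

u = newton_schulz_msign(torch.randn(4, 10))  # u @ u.mT is approximately the identity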
Testing
1. Unit tests: the newly introduced Muon is covered in test/test_optim.py. We updated the test cases to pass named parameters to the optimizer under test, and introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior. As discussed, in order not to complicate the codebase, we prefer not to include a reference implementation in PyTorch; we also updated the interface so the FQN-based filtering no longer needs to be tested, and Muon is now covered by the existing test_optim.py unit tests.
2. End-to-end training: we trained for one epoch on the openwebtext-100k dataset and compared the resulting loss curve against the Moonshot implementation to confirm behavioral consistency.
Numerics
We evaluated our implementation against the existing implementation to confirm numerical consistency.
As discussed, our implementation closely follows the algorithm described in Keller's post, while incorporating the learning rate adjustment from Moonlight. This captures a key insight that allows users to reuse hyper-parameters tuned for AdamW, making Muon a drop-in swap. As expected, the numerics difference mainly comes from adjust_lr, with a maximum of ~5% relative difference in the example unit-test setup below.
[screenshot: unit-test numerics comparison]
As explained above, including this adjust_lr is preferable. This is validated by end-to-end training runs on a Qwen-2-like 0.5B model, where the loss curves show that training with adjust_lr converges more effectively than without.
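For context, the Moonlight-style grafting scales Muon's learning rate by the matrix dimensions so the update RMS roughly matches AdamW's. A sketch of our reading of that adjustment (stated as an assumption about what adjust_lr computes, not a verbatim copy of the PR):

import math

def adjust_lr(lr: float, d_out: int, d_in: int) -> float:
    # Scale by 0.2 * sqrt(max(dims)) so the update RMS roughly matches AdamW's
    # (~0.2), which is what lets AdamW-tuned learning rates be reused with Muon.
    return lr * 0.2 * math.sqrt(max(d_out, d_in))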
Performance
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP:
Muon runs ~20 s slower than AdamW for the epoch; assuming no other changes, this makes Muon about 2.5% slower than AdamW overall.
AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms

Muon: Optimizer.step() takes ~54 ms, step time ~960 ms

Note
We restrict the implementation to accept only 2D parameters.
An alternative approach is to allow parameters with more than two dimensions and apply orthogonalization over the last two dimensions. We opted not to go with this approach as it can be error-prone. For example, with a kernel shaped [in_channel, height, width, out_channel], applying orthogonalization to the last two dimensions is not meaningful. Since Muon is designed to perform orthogonalization on 2D matrices, preserving this assumption keeps the implementation clean and sound.
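If a user does want a Muon-style update for a higher-dimensional kernel, one option outside this optimizer (purely illustrative, not something this PR does) is to reshape the kernel into a 2D matrix first:

import torch

w = torch.randn(64, 3, 3, 3)       # a conv kernel in PyTorch's [out, in, h, w] layout
w2d = w.reshape(w.size(0), -1)     # treat it as an [out, in*h*w] matrix for orthogonalization
# ...orthogonalize w2d (e.g. with Newton-Schulz), then reshape the result back to w.shape...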
Next Steps
MuP
cc: @toothacher17, @vinaysrao, @jcui2, @haocizhang
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @mcarilli @ptrblck @leslie-fang-intel @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela