[muon] Introduce Muon optimizer to PyTorch #160213

Open - wants to merge 23 commits into main

Conversation

@chuanhaozhuge (Contributor) commented Aug 8, 2025

A single-device version of Muon. The algorithm follows Keller Jordan's Muon blog post and optionally incorporates Moonshot's learning-rate adjustment strategy.

This implementation maintains a minimalist API consistent with other optimizer conventions. The PyTorch team prefers to handle parameter filtering at a higher level, with the Muon optimizer performing only the msign computation for orthogonalization on all parameters it receives. Users are responsible for grouping parameters for the different optimizers as needed. An example usage is shown below, and a more detailed example will be added to the PyTorch examples directory.

Usage

    model = MyModelForCausalLM(...)  # instantiate your model
    # filter out your params manually
    muon_params = [...]
    adamw_params = [...]
    muon = Muon(
        params=muon_params,
        lr=lr,
        weight_decay=wd,
    )
    adamw = AdamW(
        params=adamw_params,
        lr=lr,
        weight_decay=wd,
    )

    # in training loop
    loss = model(input)
    loss.backward()
    muon.step()
    adamw.step()
    muon.zero_grad()
    adamw.zero_grad()
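
For the manual filtering step above, one possible split (a sketch only; the name checks for "embed" and "lm_head" are illustrative and depend on your model) is to route 2D weight matrices to Muon and everything else to AdamW:

    # route 2D hidden-layer weight matrices to Muon; embeddings, heads, and 1D params to AdamW
    muon_params = [
        p for name, p in model.named_parameters()
        if p.ndim == 2 and "embed" not in name and "lm_head" not in name
    ]
    muon_ids = {id(p) for p in muon_params}
    adamw_params = [p for p in model.parameters() if id(p) not in muon_ids]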

Additional usage
Users were originally able to pass in a self-defined msign function for orthogonalization and a learning-rate adjustment function, with the interfaces defined below (since removed):

~~AdjustLrFn: TypeAlias = Callable[[float, torch.Size], float]~~
~~MsignFn: TypeAlias = Callable[[Tensor, BaseMsignFnConfig], Tensor]~~

As discussed with the team and in the comments, we prefer a simpler, cleaner interface, so we removed the callback interface and canonicalized the original Newton-Schulz (NS) algorithm for Muon. The only configs available to users are ns_steps, coefficients, and eps, configurable through kwargs.
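
A minimal sketch of how these kwargs could be passed (the kwarg names follow the description above; the default values shown are assumptions for illustration, not the canonical signature):

    # assumes the kwargs described above: ns_steps, coefficients, eps
    muon = Muon(
        muon_params,
        lr=1e-3,
        weight_decay=0.1,
        ns_steps=5,                               # Newton-Schulz iteration count
        coefficients=(3.4445, -4.7750, 2.0315),   # quintic NS coefficients from Keller Jordan's post
        eps=1e-7,                                 # stabilizer used when normalizing the gradient
    )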

By default, we use 5-step Newton-Schulz with the coefficients proposed by Keller Jordan, and the learning-rate adjustment proposed by Moonshot, which grafts the learning rate from AdamW.
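
For reference, a minimal sketch of this default orthogonalization, following the Newton-Schulz iteration from Keller Jordan's blog post (the function name and exact signature inside torch.optim may differ):

    import torch

    def newton_schulz_msign(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
        # Quintic Newton-Schulz iteration approximating msign(G) = U @ V^T for G = U S V^T.
        # The coefficients are the ones proposed in Keller Jordan's blog post.
        a, b, c = (3.4445, -4.7750, 2.0315)
        X = G.bfloat16()
        X = X / (X.norm() + eps)          # Frobenius normalization keeps the spectral norm <= 1
        transposed = G.size(0) > G.size(1)
        if transposed:                    # iterate on the wide orientation
            X = X.mT
        for _ in range(steps):
            A = X @ X.mT
            B = b * A + c * A @ A
            X = a * X + B @ X
        if transposed:
            X = X.mT
        return X.to(G.dtype)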

Testing

1. Unit tests: the newly introduced Muon is covered in test/test_optim.py. We updated the test cases to pass named parameters to the optimizer under test. Additionally, we introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior.

As discussed, in order not to complicate the codebase, we prefer not to include a reference implementation in PyTorch. We also updated the interface so we no longer need to test the FQN-based filtering. Muon is covered by the existing test_optim.py unit tests.

2. End-to-end test: we added a training script that pre-trains a Qwen-like model on the openwebtext-100k dataset. We trained for one epoch, and the resulting loss curve is compared against the Moonshot implementation to confirm behavioral consistency.

Numerics
We evaluated our implementation against the existing implementation to confirm numerical consistency.

As discussed, our implementation closely follows the algorithm described in Keller's post, while incorporating the learning-rate adjustment from Moonlight. This captures a key insight that allows users to reuse hyper-parameters tuned for AdamW, making Muon a drop-in swap.
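
A sketch of this Moonshot/Moonlight-style RMS-matching adjustment, assuming the 0.2 * sqrt(max(fan_out, fan_in)) scaling from the Moonlight report (the helper name and signature are illustrative, not the PR's internal API):

    import math
    import torch

    def adjust_lr_rms_match(lr: float, shape: torch.Size) -> float:
        # Scale the orthogonalized update so its RMS roughly matches an AdamW update,
        # which is what allows AdamW-tuned lr/weight_decay to be reused with Muon.
        fan_out, fan_in = shape[0], shape[1]
        return lr * 0.2 * math.sqrt(max(fan_out, fan_in))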

As expected, the numerics difference comes mainly from adjust_lr, with a maximum relative difference of ~5% in the example unit-test setup below.

    import copy

    import torch
    from torch.nn import Linear, MSELoss
    from torch.optim import Muon  # Muon as introduced in this PR
    # KellySingleDeviceMuon: the reference single-device implementation used for comparison

    # dummy model and data
    model0 = Linear(10, 10, bias=False)
    model1 = copy.deepcopy(model0)
    inputs = torch.randn(8, 10)
    targets = torch.randn(8, 10)
    loss = MSELoss()

    lr = 1e-3
    wd = 0.1
    momentum = 0.95

    opt_ref_muon = KellySingleDeviceMuon(
        params=model0.parameters(), 
        lr=lr,
        weight_decay=wd,
        momentum=momentum,
    )

    opt_exp_muon = Muon(
        params=model1.parameters(),
        lr=lr,
        weight_decay=wd,
        momentum=momentum,
    )

    out_ref = model0(inputs)
    loss_ref = loss(out_ref, targets)
    opt_ref_muon.zero_grad()
    loss_ref.backward()
    opt_ref_muon.step()

    out_exp = model1(inputs)
    loss_exp = loss(out_exp, targets)
    opt_exp_muon.zero_grad()
    loss_exp.backward()
    opt_exp_muon.step()

    for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
        torch.testing.assert_close(p_ref, p_exp)

As explained above, including this adjust_lr is preferable. This is validated by end-to-end training runs on a Qwen2-like 0.5B model, where the curves show that training with adjust_lr converges more effectively than without.

Performance
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP:

  • adamw_ddp finishes in 13.12 min
  • pytorch_muon_ddp finishes in 13.45 min

Muon runs ~20 s slower than AdamW over the epoch, i.e. roughly 2.5% slower overall, assuming no other changes.

AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms

Muon: Optimizer.step() takes ~54 ms, step time ~960 ms

Note
We restrict the implementation to accept only 2D parameters.

An alternative approach is to allow parameters with more than two dimensions and apply orthogonalization over the last two dimensions. We opt not to go with this approach as it can be error-prone. For example, with a kernel shaped [in_channel, height, width, out_channel], applying orthogonalization to the last two dimensions is not meaningful.

Since Muon is designed to perform orthogonalization on 2D matrices, preserving this assumption keeps the implementation clean and sound.

Next Steps

  1. Add MuP
  2. Open-source an optimized Triton kernel for symmetric matmul. A preliminary benchmark found a 1.23x-1.48x speedup on small to large matrices (n = 256 to 16384).
  3. Open-source unsharded Muon co-designed with FSDP2.

cc: @toothacher17, @vinaysrao, @jcui2, @haocizhang

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @mcarilli @ptrblck @leslie-fang-intel @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela

pytorch-bot (bot) commented Aug 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160213

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (7 Unrelated Failures)

As of commit 7fc7a60 with merge base 3e5b021:

BROKEN TRUNK - The following jobs failed but were present on the merge base.

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@janeyx99 (Contributor) left a comment:

Wooo the approach is much simpler indeed now--thank you for the speedy turnaround on the PR. The one main API question I have is how we handle the NS config (whether it should be in the constructor) which I've commented below. Everything else looks super solid.

I know you've put in amazing work for the benchmarks and correctness compared to the original Muon, and I trust that you have verified this PR is still correct and appropriately fast locally. I will look out for the separate PR with those scripts!

Comment on lines 356 to 358:

    params = [weight, bias]
    if optim_cls.__name__ == "Muon":
        params = [weight]

Contributor: nit to not reassign.

Suggested change:

    params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight]

Comment on lines 1568 to 1569:

    model = torch.nn.Sequential(
        torch.nn.Linear(10, 4, bias=False),

Contributor: This can just be one Linear then, right? Or maybe it'd be more indicative to add another Linear in there?

Can you add a comment for why we branch here?

Comment on the diff @@ -1577,14 +1629,26 @@ def test_can_load_from_to_named_state_dict(:

    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )

    def _get_model_and_input(device, dtype, optim_cls):

Contributor: let's only have one version of this helper, it looks the same as above.

Comment on the constructor signature:

    nesterov: bool = True,
    *,
    msign_fn: MsignFn = zeropower_via_newtonschulz,
    msign_fn_config: BaseMsignFnConfig = NewtonSchulzConfig(),

Contributor: What is the pro of having these configs live in the constructor as a struct vs. separate values? Is this because these values are only used if the msign_fn is zeropower_via_newtonschulz? If so, should this config not live in the Muon constructor at all but be customizable via the user-provided msign_fn? What are your thoughts?

Contributor: I also wonder if this config could be a regular dict, accepted in the constructor as Muon(..., msign_fn_config={'eps': 1e-5}) and then just passed as self.msign_fn(..., **self.msign_fn_config); that way it could be more easily saved into a state_dict()...

Contributor (author): I think it's better to encapsulate the configs in a dedicated class, so the function signature stays clean and manageable. Just a preference carried over from my C++ days :)

I see Vadim's point, but I'm not sure it's feasible (or necessary) to store the callable in the state_dict in the first place.

Contributor: Another option is to have the config set as simple args to the function, and then have the user override them by calling functools.partial.
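
For illustration, the functools.partial pattern suggested here could look like the following (it assumes zeropower_via_newtonschulz exposes steps and eps as plain keyword arguments, which is an assumption about its signature rather than the PR's actual API):

    import functools

    # hypothetical override: bind non-default NS settings onto the msign function
    custom_msign = functools.partial(zeropower_via_newtonschulz, steps=7, eps=1e-6)
    muon = Muon(muon_params, lr=1e-3, msign_fn=custom_msign)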

Contributor: cc @albanD regarding API design for best practices.

Collaborator: Ho, that's interesting. I do agree that this doesn't match how we do APIs in PyTorch in general. For value config, I would expect each value to be passed in as its own argument (see other optimizers). If you need to override specific methods and behavior, you can either have a set of pre-defined implementations that a flag toggles between, or you can subclass the optimizer to override the particular method you care about.

Also, I guess I'm missing some context on why we want to do it this way if there is only one option for each right now?

@janeyx99 (Contributor) left a comment: The current CI failures are because you (probably accidentally) committed the third_party differences - please remove those!

@chuanhaozhuge (Contributor, author), replying to the comment above: uh, they must have come from the rebase. Removed.

@chuanhaozhuge chuanhaozhuge force-pushed the muon_dev branch 2 times, most recently from 1c82fc8 to 654f754 Compare August 12, 2025 05:07
@chuanhaozhuge chuanhaozhuge marked this pull request as ready for review August 12, 2025 05:08
@chuanhaozhuge chuanhaozhuge requested a review from albanD as a code owner August 12, 2025 05:08
Comment on lines 61 to 65:

    assert steps < 100, (
        "Number of steps must be less than 100 for computational efficiency"
    )
    assert len(grad.shape) == 2, "Input tensor gradient must be a 2D matrix"
    assert len(coefficients) == 3, "Coefficients must be a tuple of exactly 3 values"

Collaborator: No plain asserts, please; raise appropriate RuntimeError/ValueError/TypeError instead.
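
For example, the 2D check could be rewritten with an explicit exception (illustrative wording, not the exact message used in the PR):

    # replacing a plain assert with a descriptive ValueError
    if grad.dim() != 2:
        raise ValueError(f"Muon expects a 2D gradient matrix, got {grad.dim()} dimensions")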


Comment on:

    has_complex: bool,
    ) -> None:
        lr = _to_scalar(lr)
        assert has_complex is False, "Complex parameters are not supported"

Collaborator: Let's remove the plain asserts here as well.

Contributor (author): done!

@janeyx99 (Contributor):
Given that we had agreed to land the simplest single device Muon into torch/optim as our first step, it'd be clearest to land what people accept as the original implementation as defined in Keller Jordan's blog ( https://kellerjordan.github.io/posts/muon/). As this implementation chooses newton schulz as the algo, we should take the same stance. This means we can simplify the constructor API greatly (I will get to extensibility right after):

  • Remove the msign_fn callable argument (the algorithm will just call NS by default today). OSS folks will not expect to pass in a callable to the constructor, so we will not accept the PR with a kwarg that takes in a callable. Instead, for customization, we can intake a string enum (more on this later).
  • Remove the struct definitions and flatten the kwargs as top-level keyword arguments to the constructor. Having layers of configs is confusing and unnecessarily abstracted. Since we are going in on NS being the algorithm for single device Muon, we can explicitly list out these kwargs in the constructor.
  • Move the algorithm description of NS into the Muon doc, so people can see immediately what algo they are calling.
  • I am remembering that the original test scripts comparing this implementation to Keller Jordan's had high atol and rtols, which is surprising as the two algorithms both use PyTorch ops in python, and so I'd expect the results to be the same. Could you link a standalone script that can be run to ascertain correctness and explain why the high atol/rtols are necessary (if they still are)? As we want to land a trustworthy impl for folks to try out, we need to ensure the accuracy results are expected.

I'm realizing that you have interest in extending the algorithm to be distributed (vs another orthogonalization algo for single-device). We are strict on keeping torch/optim code single-device runnable and maximally composable, so we cannot land anything distributed in torch/optim and I'd propose landing the distributed optimizer solution in https://github.com/pytorch/pytorch/tree/main/torch/distributed/optim. With that, I see the possible extension options as below:
a) If we are interested in other orthogonalization techniques for single device, we'd recommend using string enum kwargs similar to line_search_fn in https://docs.pytorch.org/docs/stable/generated/torch.optim.LBFGS.html#lbfgs, where the default None is NS, and other strings can represent other algorithms.
b) If we are attempting to extend in a distributed manner, and the code (state_dict, etc) is easily shareable, we'd recommend subclassing Muon into a new distributed optimizer in torch/distributed/optim. If the code ends up not being so shareable, it is perfectly acceptable to have Dion or a different optim class entirely living in torch/distributed/optim.

@chuanhaozhuge (Contributor, author):
Thank you, team, for the thoughtful suggestions and thorough review! I made the following updates according to your inputs, so that this version is simple and clear:

  1. removed the callback interface as suggested, and canonicalized the original NS algorithm for Muon. The only configs users can tune are ns_steps, coefficients, and eps, configurable through kwargs.
  2. updated the docstring to follow other optimizers.
  3. used RMS matching (adjust_lr) from the Moonshot tech report to graft the lr and weight decay tuned for AdamW. Keller's impl is equivalent to Moonshot's under certain constraints.

@pytorchmergebot (Collaborator):
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@chuanhaozhuge (Contributor, author):
@pytorchbot rebase

@pytorchmergebot (Collaborator):
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator):
Tried to rebase and push PR #160213, but it was already up to date. Try rebasing against main by issuing:
@pytorchbot rebase -b main
