Commit Graph

764 Commits

Author SHA1 Message Date
Yuanyuan Chen
f91899ca6c [2/N] Add strict parameter to Python zip calls (#166257)
This PR adds `strict=True/False` to zip calls in test utils. strict=True is passed when possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166257
Approved by: https://github.com/janeyx99
2025-11-01 00:35:41 +00:00
PyTorch MergeBot
f60751024e Revert "[2/N] Add strict parameter to Python zip calls (#166257)"
This reverts commit 39e5cdddf7.

Reverted https://github.com/pytorch/pytorch/pull/166257 on behalf of https://github.com/atalman due to Failing: test/distributed/fsdp/test_fsdp_mixed_precision.py::TestFSDPTrainEval::test_train_ema_eval_flow [GH job link](https://github.com/pytorch/pytorch/actions/runs/18934047991/job/54057218160) [HUD commit link](39e5cdddf7) ([comment](https://github.com/pytorch/pytorch/pull/166257#issuecomment-3467955332))
2025-10-30 13:20:00 +00:00
Yuanyuan Chen
39e5cdddf7 [2/N] Add strict parameter to Python zip calls (#166257)
This PR adds `strict=True/False` to zip calls in test utils. strict=True is passed when possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166257
Approved by: https://github.com/janeyx99
2025-10-30 08:10:10 +00:00
Maggie Moss
d1a6e006e0 Fix syntax for pyrefly errors (#166496)
Last one! This ensures all existing suppressions match the syntax expected and will silence only one error code

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166496
Approved by: https://github.com/Skylion007, https://github.com/mlazos
2025-10-29 20:00:25 +00:00
Yuanyuan Chen
a60d9e1f6d Fix flake8 B028 warnings (#166224)
This PR fixes flake8 B028 warning by specifying stacklevel=2 in `warnings.warn`. The advantage is that users can know more contextual information about PyTorch warnings.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166224
Approved by: https://github.com/ezyang
2025-10-26 06:18:55 +00:00
Maggie Moss
c7eee49525 Fix pyrefly ignores 1/n (#166239)
First diff adjusting the syntax for pyrefly: ignore suppressions so they only hide one class of type error.

Test:
lintrunner
pyrefly check

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166239
Approved by: https://github.com/oulgen
2025-10-26 00:44:10 +00:00
Maggie Moss
eb83c3ca23 Clean up unused Pyrefly suppressions (#166178)
Cleaning up ignores that are no longer needed in the repo and adding select suppressions so the main branch is clean.

test plan:
`lintrunner -a`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166178
Approved by: https://github.com/oulgen
2025-10-25 05:32:21 +00:00
mansiag05
f8fccb1e48 [Code Clean] Clean asserts in torch/optim. (#165629)
Replaces 50 assert statements across 15 files in torch.optim with explicit  if-checks raising AssertionError to prevent assertions from being disabled with Python -O flag.

fix partially #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165629
Approved by: https://github.com/albanD
2025-10-23 15:56:29 +00:00
Maggie Moss
d795fb225a [RFC] Add pyrefly to lintrunner (#165179)
This will add pyrefly to lint runner as a warning only - and allow us to collect feedback about the tool before switching to pyrefly as the main type checker.

References the steps outlined here: : https://github.com/pytorch/pytorch/issues/163283:

test plan:
`lintrunner init`
`lintrunner`
confirm when pyrefly errors are present results look like: https://gist.github.com/maggiemoss/e6cb2d015dd1ded560ae1329098cf33f

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165179
Approved by: https://github.com/ezyang
2025-10-16 20:07:09 +00:00
Yuanyuan Chen
fbe0d20a17 [2/N] More ruff SIM fixes (#165031)
This is follow-up of #164695 to apply ruff SIM rules to more files. Most changes are about simplifying dict.get because None is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031
Approved by: https://github.com/mlazos
2025-10-14 14:22:54 +00:00
PyTorch MergeBot
b8be796a57 Revert "[2/N] More ruff SIM fixes (#165031)"
This reverts commit 38095fbd13.

Reverted https://github.com/pytorch/pytorch/pull/165031 on behalf of https://github.com/albanD due to One of the changed line started to fail on trunk ([comment](https://github.com/pytorch/pytorch/pull/165031#issuecomment-3390190870))
2025-10-10 13:42:14 +00:00
Yuanyuan Chen
38095fbd13 [2/N] More ruff SIM fixes (#165031)
This is follow-up of #164695 to apply ruff SIM rules to more files. Most changes are about simplifying dict.get because None is already the default value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031
Approved by: https://github.com/mlazos
2025-10-10 05:37:46 +00:00
Maggie Moss
086dec3235 Pyrefly suppressions 6/n (#164877)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:

INFO 0 errors (5,064 ignored)

Only four directories left to enable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164877
Approved by: https://github.com/oulgen
2025-10-08 02:30:57 +00:00
Maggie Moss
b13cd141b3 Add pyrefly suppressions (#164748)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the `project-excludes` field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:

0 errors (4,263 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164748
Approved by: https://github.com/oulgen
2025-10-07 17:31:18 +00:00
zeshengzong
fdc8ccc5bc Make Adam, AdamW work with nonzero-dim Tensor betas (#149939)
Fixes #147921

## Changes

- Convert tensor `betas` using `_to_scalar`
- Change annotation of `betas` param
- Change param type in docs

## Test Result

```bash
pytest -s test/test_optim.py -k test_tensor_lr -vv
```

![image](https://github.com/user-attachments/assets/312ee045-1e8b-4789-aa6e-ba63e6df7e81)

![image](https://github.com/user-attachments/assets/7e6ec274-645b-46b9-b1a6-2b340a685203)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149939
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2025-10-06 22:03:25 +00:00
Maggie Moss
4ab847bbc7 Pyrefly suppressions 4/n (#164615)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1

after:

0 errors (2,753 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615
Approved by: https://github.com/oulgen
2025-10-06 16:14:36 +00:00
PyTorch MergeBot
5d7360bb03 Revert "Enable all SIM rules except disabled ones (#164645)"
This reverts commit 321e602692.

Reverted https://github.com/pytorch/pytorch/pull/164645 on behalf of https://github.com/izaitsevfb due to causes lint failures ([comment](https://github.com/pytorch/pytorch/pull/164645#issuecomment-3369274351))
2025-10-05 19:32:21 +00:00
Yuanyuan Chen
321e602692 Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhances code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang
2025-10-05 07:38:25 +00:00
Yuanyuan Chen
a43c4c3972 [5/N] Apply ruff UP035 rule (#164423)
Continued code migration to enable ruff `UP035`. Most changes are about moving `Callable` from `typing` to `from collections.abc`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164423
Approved by: https://github.com/ezyang
2025-10-02 07:31:11 +00:00
Filip
2b6a74abf1 [optim] prevent unintended aliasing in lr_scheduler; update type annotations/docs (#163120)
1. Prevents unintended aliasing of `self._last_lr`/`get_last_lr(...)` with `group["lr"]` when `group["lr"]` is a tensor.
2. Prevents unintended aliasing of `LRScheduler.base_lrs` with the `group["initial_lr"]`s.
3. Updates `test/optim/test_lrscheduler.py` to test tensor LRs.
4. Changes type annotations for `_last_lr`, `get_last_lr()`, `base_lrs`, `get_lr()`, and `_get_closed_form_lr()` from `list[float]` to `list[float | Tensor]`; adds documentation.

Fixes #163103

LR schedulers can behave in unexpected ways when using a tensor LR due to patterns like this:
```python
self._last_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]
```

This PR adds a helper to address this:
```python
def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
    """Create a list containing group[key] for each optimizer param_group.
    Prevents aliasing when group[key] could be a Tensor.
    Raises a KeyError when group[key] does not exist.
    """
    return [
        group[key].clone() if isinstance(group[key], Tensor) else group[key]
        for group in optimizer.param_groups
    ]
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163120
Approved by: https://github.com/janeyx99
2025-09-25 06:58:58 +00:00
Filip
167ad09be5 [optim] override SWALR.state_dict and load_state_dict (#163122)
Fixes #163105

Note that the new `SWALR.load_state_dict` is **not backwards compatible**:
```python
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  """Load the scheduler's state.

  Args:
      state_dict (dict): scheduler state. Should be an object returned
          from a call to :meth:`state_dict`.
  """
  self.__dict__.update(state_dict)
  self._set_anneal_func(self._anneal_strategy)
```

If we'd like to maintain compatibility with old state_dicts (loaded with `weights_only=False`), we could use something along these lines:
```python
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load the scheduler's state.

    Args:
        state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    anneal_func = state_dict.pop("anneal_func", None)
    strategy = state_dict.get("_anneal_strategy")
    self.__dict__.update(state_dict)

    if anneal_func is not None:
        state_dict["anneal_func"] = anneal_func
        if strategy is None:
            if anneal_func == self._linear_anneal:
                strategy = "linear"
            elif anneal_func == self._cosine_anneal:
                strategy = "cos"

    if strategy is None:
        strategy = getattr(self, "_anneal_strategy", "cos")

    self._set_anneal_func(strategy)
```

But given the fact that loading an `SWALR` state_dict before this PR would have caused an error, this seems okay. A GitHub/Google search for `SWALR.load_state_dict` had no results. Happy to change if not, or add a warning just in case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163122
Approved by: https://github.com/janeyx99
2025-09-17 18:17:26 +00:00
Filip
bc38c5baa1 [optim] prevent problematic tensor aliasing in lr_scheduler (#163098)
Prevents edge cases in SequentialLR and ReduceLROnPlateau which could corrupt learning rates or trigger recompilation.

Supersedes #162360
Fixes #162359
Fixes #163093

While putting #162360 together, I noticed the class of issue I was fixing (i.e. unintended aliasing in lr_schedulers when using Tensor lrs) appeared in several other places. @janeyx99 suggested I put together a follow-up PR.

There are several bugs resembling the one fixed in #162360. I added a helper to fix these:
```python
def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor):
    """Set param_group[key] to val without aliasing or assignment when they're both tensors.
    Raises a KeyError if param_group[key] does not exist.
    """
    if isinstance(param_group[key], Tensor):
        param_group[key].fill_(_to_scalar(val))
    else:
        param_group[key] = val
```

And applied it to fix bugs in `SequentialLR.__init__` and `LRScheduler._update_lr`. I also added it to `CyclicLR.__init__` which was using an equivalent pattern, and `CosineAnnealingWarmRestarts.step` which *should* have had a similar issue:
```python
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
    param_group["lr"] = lr
```

But did not, because `get_lr()` actually returns tensors when using a tensor lr (despite its `list[float]` return type annotation). Relying on this propagation seems fragile, so I conservatively added the method here as well. I'll be fixing the type annotations and several related issues in followup PRs built off of this one.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163098
Approved by: https://github.com/janeyx99
2025-09-17 13:40:23 +00:00
PyTorch MergeBot
a5419743c6 Revert "remove unnecessary sync point in AveragedModel update (#158017)"
This reverts commit cb7f45fd34.

Reverted https://github.com/pytorch/pytorch/pull/158017 on behalf of https://github.com/wdvr due to discussed with author - expecting this to break checkpointing ([comment](https://github.com/pytorch/pytorch/pull/158017#issuecomment-3301790645))
2025-09-17 08:02:02 +00:00
Gael Le Lan
cb7f45fd34 remove unnecessary sync point in AveragedModel update (#158017)
Summary:
The test `bool(self.n_averaged == 0)` is a CPU/GPU synchronization point that is called for each update.
This test is only meant to know whether the AveragedModel copy has been initialized or not.
This diff introduces a CPU-based variable for that purpose.
When loading from checkpoint we also make sure the parameter is refreshed.

After this fix, each `update_parameter` call is reduced to 6ms from 333ms (98% reduction).

Test Plan:
contbuild & OSS CI
Test plan from GitHub:
CI

Rollback Plan:

Differential Revision: D78074709

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158017
Approved by: https://github.com/janeyx99
2025-09-16 18:57:55 +00:00
zeshengzong
fa127d9b20 Fix LBFGS wolfe max iteration (#161488)
Fixes #91581 , based on #135026

## Test Result

```bash
pytest test/test_optim.py

.........
========================== 1473 passed, 242 skipped in 2412.49s (0:40:12) ===========================
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161488
Approved by: https://github.com/albanD
2025-09-16 12:07:50 +00:00
Kushagra Rastogi
cfc539fe15 Improved error lr last epoch (#162368)
Fixes #160626

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162368
Approved by: https://github.com/janeyx99
2025-09-15 23:33:14 +00:00
Parshant Sharma
4ad9fbc83a Unify TypeAlias definitions in optimizer.py (#161493)
Fixes #160834

This issue unifies TypeAlias definitions in [optimizer.py](https://github.com/pytorch/pytorch/blob/main/torch/optim/optimizer.py)

This ensures the following:

- Consistency and Standardization
- Enhanced IDE support
- Prevents runtime confusion

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161493
Approved by: https://github.com/Skylion007
2025-08-30 00:35:02 +00:00
zeshengzong
448a7e7e31 Fix SequentialLR deprecate warning about invoke step(epoch) (#149392)
Fixes #116776 #76113 #113222 #67958
## Changes

- Refactor `LRScheduler.step` method, leave `epoch` check logic in public method `step`
- Move update `lr` logic to `_update_lr` method
- Make `SequentialLR` use `_update_lr` to avoid unnecessary warning message

## Test Result

```bash
pytest test/optim/test_lrscheduler.py -vv
```

![image](https://github.com/user-attachments/assets/e1c5527e-193e-4328-bf95-023139ea0416)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149392
Approved by: https://github.com/janeyx99
2025-08-29 11:45:11 +00:00
Chuanhao Zhuge
e9d42b3880 [small][muon] Use addmm for Newton–Schulz orthogonalization (#161379)
A performance optimization. Using `torch.addmm`, which fuses `matrix multiply + scale + add` into one op.

**Benchmark**
In a QWEN-like 0.5B model training we observed average `optimizer.step()` latency speedup: matmul ~44.5 ms -> addmm ~27.4 ms: a **1.62×** speedup.

matmul
<img width="1403" height="600" alt="Screenshot 2025-08-24 at 3 15 37 PM" src="https://github.com/user-attachments/assets/a77a68d4-da3c-473a-97f0-e6ef0a3b46d9" />

addmm
<img width="1426" height="602" alt="Screenshot 2025-08-24 at 3 13 42 PM" src="https://github.com/user-attachments/assets/e493af36-44d3-4026-9f7c-fd0f9cdbc7e5" />

**Testing**
End-to-end training:
We used a training script that pre-trains a QWEN-like model on `openwebtext-100k` dataset. We trained for one epoch and the resulting loss curves show consistency between normal matmul and addmm.
<img width="1035" height="434" alt="Screenshot 2025-08-24 at 2 56 21 PM" src="https://github.com/user-attachments/assets/b96b13e3-0a01-4908-853c-d917b41f3d75" />

Unit test:

```python
    # dummy model and data
    model0 = Linear(10, 10, bias=False)
    model1 = copy.deepcopy(model0)
    inputs = torch.randn(8, 10)
    targets = torch.randn(8, 10)
    loss = MSELoss()

    lr = 1e-3
    wd = 0.1
    momentum = 0.95

    opt_ref_muon = Muon(
        params=model0.parameters(),
        lr=lr,
        weight_decay=wd,
        momentum=momentum,
        nesterov=nesterov,
        adjust_lr_fn="original",
    )

    opt_exp_muon = Muon(
        params=model1.parameters(),
        lr=lr,
        weight_decay=wd,
        momentum=momentum,
        nesterov=nesterov,
        adjust_lr_fn="original",
        use_addmm=True,
    )

    out_ref = model0(inputs)
    loss_ref = loss(out_ref, targets)
    opt_ref_muon.zero_grad()
    loss_ref.backward()
    opt_ref_muon.step()

    out_exp = model1(inputs)
    loss_exp = loss(out_exp, targets)
    opt_exp_muon.zero_grad()
    loss_exp.backward()
    opt_exp_muon.step()

    for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
        torch.testing.assert_close(p_ref, p_exp)
```

shows numeric difference, but this is expected on bf16 precision:
```
Mismatched elements: 96 / 100 (96.0%)
Greatest absolute difference: 8.985400199890137e-05 at index (1, 9) (up to 1e-06 allowed)
Greatest relative difference: 0.007370449136942625 at index (0, 6) (up to 1e-05 allowed)
```

~~Introduced a flag that allows users to opt in, as there are numerical differences relative to the original implementation.~~
Update: since `addmm` fuses the math ops, there are fewer intermediate roundings and is therefore more numerically accurate compared to the original form. Based on this, we opt to make `addmm` the default and only option.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161379
Approved by: https://github.com/janeyx99
2025-08-26 09:17:28 +00:00
Zesheng Zong
8c442e4fd3 Fix LBFGS warning convert a tensor with requires_grad=True to a scalar (#160389)
Fixes #160197

## Test Result

```python
In [1]: import warnings
   ...: warnings.simplefilter('error')
   ...: import torch
   ...: print(torch.__version__)
   ...: a, b = torch.rand((2, 32, 32))
   ...: a.requires_grad_()
   ...: optimizer = torch.optim.LBFGS([a])
   ...: loss_fn = lambda x, y: (x-y).pow(2).mean()
   ...:
   ...: def closure():
   ...:     optimizer.zero_grad()
   ...:     loss = loss_fn(a, b)
   ...:     loss.backward()
   ...:     return loss
   ...:
   ...: for i in range(100):
   ...:     optimizer.step(closure)
   ...:     print(i, loss_fn(a, b))
   ...:
2.9.0a0+gitf33f3f8
0 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
1 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
2 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
3 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
4 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
5 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
6 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
7 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
8 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
9 tensor(5.8066e-11, grad_fn=<MeanBackward0>)
10 tensor(5.8066e-11, grad_fn=<MeanBackward0>)

...

```

```bash
pytest test/test_optim.py -vv

...
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_NAdam_cuda_float32 PASSED [2.7192s]                                                                                                                                           [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_RAdam_cuda_float32 PASSED [2.5370s]                                                                                                                                           [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_RMSprop_cuda_float32 PASSED [2.0190s]                                                                                                                                         [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_Rprop_cuda_float32 PASSED [1.8554s]                                                                                                                                           [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_SGD_cuda_float32 PASSED [2.0433s]                                                                                                                                             [ 99%]
test/test_optim.py::TestOptimRenewedCUDA::test_tensor_lr_num_dim_2_SparseAdam_cuda_float32 PASSED [1.1788s]                                                                                                                                      [100%]

================== 1471 passed, 242 skipped in 2440.52s (0:40:40) ============
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160389
Approved by: https://github.com/janeyx99

Co-authored-by: albanD <desmaison.alban@gmail.com>
2025-08-26 03:07:47 +00:00
Chuanhao Zhuge
74280d0913 [muon] Introduce Muon optimizer to PyTorch (#160213)
A single-device version of Muon. Algorithm refers Keller Jordan's [Muon blogpost](https://kellerjordan.github.io/posts/muon/), and optionally incorporates [Moonshot's](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf) learning rate adjustment strategy.

This implementation maintains a minimalist API and is consistent with other optimizer conventions. PyTorch team prefers to handle parameter filtering at a higher level, with the Muon optimizer performing only the msign computation for orthogonalization on all parameters it receives. Users are responsible for grouping parameters for different optimizers as needed. An example usage is shown below, and a more detailed example will be added to the [PyTorch examples](https://github.com/pytorch/examples) directory.

**Usage**

```python
    model = MyModelForCausalLM
    # filter out your params manually
    muon_params = [...]
    adamw_params = [...]
    muon = Muon(
        params = muon_params
        lr=lr,
        wd=wd,
    )
    adamw = AdamW(
        params = adamw_params
        lr=lr,
        wd=wd,
    )

    # in training loop
    loss = model(input)
    loss.backward()
    muon.step()
    adamw.step()
    muon.zero_grad()
    adamw.zero_grad()
```

~~**Additional usage**~~
~~Users are also able to pass in self-defined `msign` function for orthogonalization, and learning rate adjustment function. Interface defined below:~~

```python
~~AdjustLrFn: TypeAlias = Callable[[float, torch.Size], float]~~
~~MsignFn: TypeAlias = Callable[[Tensor, BaseMsignFnConfig], Tensor]~~
```

As discussed with team and in comment, we prefer to make the interface simpler and cleaner, thus we removed the callback interface, and canonicalize the original NS algorithm for Muon. The only configs available to users are `ns_steps`, `coefficients`, and `eps`, configurable through kwargs.

By default, we use 5-step Newton-Schulz, with coefficients proposed by [Keller](https://kellerjordan.github.io/posts/muon/). We use LR adjustment proposed by [Moonshot](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf), which grafts learning rate from AdamW.

**Testing**

~~1. Unit tests: the newly introduced Muon is covered in `test/test_optim.py`. We updated the test cases to pass named parameters to the optimizer under test. Additionally, we introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior.~~

As discussed, in order not to complicate the codebase, we prefer not to include reference implementation into PyTorch. We also updated the interface so we don't need to test the FQN based filtering. Muon is covered by the existing `test_optim.py` unit test.

2. End-to-end test: we added a training script that pre-trains a QWEN-like model on `openwebtext-100k` dataset. We trained for one epoch and the resulting loss curve is compared against the Moonshot implementation to confirm behavioral consistency.
<img width="1102" height="472" alt="Screenshot 2025-07-29 at 1 04 12 AM" src="https://github.com/user-attachments/assets/ceab0733-497d-4070-8032-02ae7995c64c" />

**Numerics**
We evaluate our implementation with existing implementation to confirm numerical consistency.

As discussed, our implementation closely follows the algorithm described in [Keller's post](https://kellerjordan.github.io/posts/muon/), while incorporating the learning rate adjustment from [Moonlight](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf). This captures a key insight that allows users to reuse hyper-parameters tuned for `adamW`, making Muon a drop-in swap.

As expected, the numerics difference mainly comes from `adjust_lr`, a max of ~5% relative diff in an example unit test setup below.

```python
    # dummy model and data
    model0 = Linear(10, 10, bias=False)
    model1 = copy.deepcopy(model0)
    inputs = torch.randn(8, 10)
    targets = torch.randn(8, 10)
    loss = MSELoss()

    lr = 1e-3
    wd = 0.1
    momentum = 0.95

    opt_ref_muon = KellySingleDeviceMuon(
        params=model0.parameters(),
        lr=lr,
        weight_decay=wd,
        momentum=momentum,
    )

    opt_exp_muon = Muon(
        params=model1.parameters(),
        lr=lr,
        weight_decay=wd,
        momentum=momentum,
    )

    out_ref = model0(inputs)
    loss_ref = loss(out_ref, targets)
    opt_ref_muon.zero_grad()
    loss_ref.backward()
    opt_ref_muon.step()

    out_exp = model1(inputs)
    loss_exp = loss(out_exp, targets)
    opt_exp_muon.zero_grad()
    loss_exp.backward()
    opt_exp_muon.step()

    for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
        torch.testing.assert_close(p_ref, p_exp)
```

As explained above, including this `adjust_lr` is preferable. This is validated by an e2e training runs on training a qwen-2-like 0.5b model, where the curves show that training with `adjust_lr` converges more effectively than without.
<img width="1179" height="464" alt="Screenshot 2025-08-18 at 10 12 33 AM" src="https://github.com/user-attachments/assets/e797d3da-c2f0-4187-b99e-5d48b7437c3c" />

**Performance**
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP:

- adamw_ddp finishes in 13.12 min
- pytorch_muon_ddp finishes in 13.45 min

Muon runs ~20s slower compared to AdamW. Assuming no other changes, Muon is *2.5%* slower than AdamW.

AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms
<img width="726" height="590" alt="Screenshot 2025-07-29 at 1 56 14 AM" src="https://github.com/user-attachments/assets/ebcd7e1c-d129-4b20-9396-39f568edf03d" />

Muon: Optimizer.step() takes ~54 ms, step time ~960 ms
<img width="751" height="597" alt="Screenshot 2025-07-29 at 2 02 20 AM" src="https://github.com/user-attachments/assets/72f5b904-ebd5-4502-a6ff-d3e9e5a6da81" />

**Note**
We restrict the implementation to accept only 2D parameters.

An alternative approach is to allow parameters with more than two dimensions and apply orthogonalization over the last two dimensions. We opt not to go with this approach as it can be error-prone. For example, with a kernel shaped `[in_channel, height, width, out_channel]`, applying orthogonalization to the last two dimensions is not meaningful.

Since Muon is designed to operate orthogonalization on 2D matrices, preserving this assumption keeps the implementation clean and sound.

**Next Steps**

1. Add `MuP`
2. Open-source optimized triton kernel for symmetric matmul. A preliminary benchmark found 1.23x - 1.48x speedup on small - large (n = 256 -> 16384) matrices.
3. Open-source unsharded Muon co-designed with FSDP2.

****

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160213
Approved by: https://github.com/janeyx99
2025-08-24 08:03:04 +00:00
zeshengzong
f8bd85827d Optimzie zero_grad description (#161239)
Optimize [zero_grad doc](https://docs.pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html) format and description.

## Test Result

### Before

<img width="996" height="534" alt="image" src="https://github.com/user-attachments/assets/e1db973c-57e8-4525-90e7-0500cde2263d" />

### After

<img width="890" height="496" alt="image" src="https://github.com/user-attachments/assets/5579c4fb-a857-4030-9303-34770083d1a5" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161239
Approved by: https://github.com/janeyx99
2025-08-22 06:18:25 +00:00
Jane Xu
9b803cdbe2 [BE] Remove more optim entries from docs coverage ignore list (#160194)
This PR does privatize ReduceLRSchedulerOnPlateau.is_better -> ReduceLRSchedulerOnPlateau._is_better because that API was never meant to be public. A GitHub search for it also reveals that the API is not commonly used much. https://github.com/search?q=.is_better%28&type=code&p=2

If you do use this API and you rely on it for some reason, please file an issue. In the meantime, you can access it through `_is_better(...)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160194
Approved by: https://github.com/albanD, https://github.com/Skylion007
2025-08-09 00:09:45 +00:00
cyy
f6c89c1ef3 Detach tensor before clone in SGD optimiser and other code (#159204)
Reverse the pattern of tensor clone followed by detach in SGD and other code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159204
Approved by: https://github.com/Skylion007
2025-07-27 03:31:12 +00:00
Jane Xu
7cc5d03dfc Document the rest of the specific optimizer module APIs (#158669)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158669
Approved by: https://github.com/albanD
ghstack dependencies: #158483
2025-07-19 07:27:15 +00:00
Jane Xu
f73594164a [BE] document Adadelta and Adagrad APIs properly (#158483)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158483
Approved by: https://github.com/albanD
2025-07-19 07:27:15 +00:00
Zeina Migeed
4f5be56612 [Pyrefly][Refactor] Replace dict() calls with literal dict syntax for improved readability (#157735)
There are 31 places that I spotted which construct literal dictionaries.

This PR refactors dictionary construction by replacing` dict(...) `calls with `literal {...}` syntax where applicable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157735
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2025-07-08 18:10:33 +00:00
Olga Gerasimova
2efa5eaa65 swa avoid stream sync (#157705)
Summary:
When AveragedModel updates_parameters it calls self.n_averaged == 0 for each parameter, where n_averated is a buffer on GPU. Moving check before the cycle to call sync once

It improves update_parameter from 74ms to 57ms ~22% improvement
{F1980011097}
{F1980011111}

Test Plan:
CI

Rollback Plan:

Differential Revision: D77723025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157705
Approved by: https://github.com/albanD, https://github.com/Skylion007, https://github.com/janeyx99
2025-07-07 20:47:35 +00:00
Xuehai Pan
db259bd6b8 [BE][12/16] fix typos in torch/ (#156602)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156602
Approved by: https://github.com/justinchuby, https://github.com/albanD
ghstack dependencies: #156318, #156320
2025-07-02 22:55:29 +00:00
morrison-turnansky
2815ade9a8 updated adafactor doc #154862 (#155248)
updated adafactor doc to reflect difference in implementation vs original paper

Fixes #154862

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155248
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2025-06-27 23:23:19 +00:00
Tom Ritchford
e2c9d8d641 Fix non-bitwise type annotations for Tensor operators (see #145838) (#146845)
Fix https://github.com/pytorch/pytorch/issues/145838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146845
Approved by: https://github.com/Skylion007
2025-06-24 15:41:34 +00:00
Aaron Orenstein
e95e8eed0a mypy 1.16.0 (#155821)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155821
Approved by: https://github.com/ezyang, https://github.com/zou3519
2025-06-14 18:18:43 +00:00
Xuehai Pan
596b418391 [BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang
2025-06-14 11:27:04 +00:00
zeshengzong
d7a83ab67b Fix lr_scheduler unexpectedly calls step() when init argument last_epoch is larger than -1 (#149312)
Fixes #102261

## Changes

- Use flag `_is_initial` to replace `self.last_epoch == 0` condition to judge whether `lr` should be initial value
- Add test for `ExponentialLR` checkpoint usecase

## Test Result

```python
pytest -s test/optim/test_lrscheduler.py  -vv
```

![image](https://github.com/user-attachments/assets/6fd32bcc-b4fb-4421-b891-620bd4900dc1)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149312
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2025-05-22 08:42:37 +00:00
zeshengzong
82dc3457e0 Add load_state_dict hint doc about invoke order work with lr_scheduler (#149942)
Fixes #119168

## Test Result

![image](https://github.com/user-attachments/assets/edb8124c-f103-475a-b903-20fbc71fdea6)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149942
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2025-05-15 01:07:36 +00:00
Aaron Gokaslan
f05b38aa26 [BE]: Improve decorator typing for Optimizer subclasses (#153374)
Improves typing so that all the optimizer subclasses (which all of them that subtype step) do not erase their type signature when this decorator is used. Now *kwarg values and returns will propogate

This complements @tsunghsienlee PR #153367  as the type signature of step() was being erased on all the optimizer subclasses by this untyped decorator

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153374
Approved by: https://github.com/janeyx99, https://github.com/tsunghsienlee
2025-05-12 22:55:25 +00:00
Tsung-Hsien Lee
ea4b65ab60 Fix the type hint of step() with default value (#153367)
Summary: Because the default value of `closure` is `None`, this fixes the situation when `step()`. The previous typing (https://github.com/pytorch/pytorch/pull/102593) could only be used as `step(closure=None)` and `step(None)`.

Test Plan: contbuild & OSS CI

Differential Revision: D74560785

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153367
Approved by: https://github.com/cyyever, https://github.com/Skylion007, https://github.com/janeyx99
2025-05-12 15:52:59 +00:00
Jacobgoss30
6a8006472e Fix doc cosineannealinglr 152081 (#152936)
## Summary

This PR updates the docstring for `CosineAnnealingLR` to accurately reflect its recursive learning rate schedule. The previous docstring displayed only the SGDR closed-form expression, which doesn't match the actual recursive implementation in code.

Changes:

- Added the recursive update formula used in `get_lr()`
- Retained the original closed-form SGDR expression for reference
- Clarified that warm restarts are not implemented in this scheduler

This addresses confusion raised in issue #152081.

## Related issue

[#152081](https://github.com/pytorch/pytorch/issues/152081)

## Testing

Doc-only change. Ran pre-commit to verify formatting.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152936
Approved by: https://github.com/janeyx99
2025-05-08 17:25:30 +00:00
Jane Xu
3bc69cc08d Document that dampening is skipped in SGD momentum first step (#152833)
Pointed out by https://x.com/hi_tysam/status/1917318692276174977/photo/2.

It would be BC breaking to change this behavior 7 years after it has been decided, so we are documenting it first at the very least.

<img width="642" alt="image" src="https://github.com/user-attachments/assets/3febcb07-e0ed-44a1-bd3b-a8e685711cb4" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152833
Approved by: https://github.com/albanD
2025-05-05 20:07:23 +00:00
kyo
a21090a38c Fix incorrect citation of authors in documentation (#145209)
This PR corrects the citation of Adafactor authors "Noam Shazeer" and "Mitchell Stern" in the documentation.
The current text incorrectly lists them as "Shazeer, Noam, and Mitchell Stern," which seems to be a result of a data parsing issue of some reference manager(s) [as you can find many papers with the same issue](https://www.google.com/search?q=%22Shazeer%2C+Noam%2C+and+Mitchell+Stern%22).
The updated citation follows standard conventions for author names.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145209
Approved by: https://github.com/janeyx99
2025-05-05 17:45:05 +00:00