Commit Graph

675 Commits

Author SHA1 Message Date
PyTorch MergeBot
1d28b8b6d5 Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit e84d1121ad.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. More details in D65483292 ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2458381056))
2024-11-05 23:10:38 +00:00
Xuehai Pan
e84d1121ad Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-11-05 10:44:56 +00:00
bskrlj
8e27833e30 Ensure SWA boundary conditions w.r.t. definition (#133773)
According to the documentation, decay is a number in [0,1] range,[ i.e.](https://pytorch.org/docs/stable/optim.html)
```
Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to get_ema_multi_avg_fn, the default is 0.999.
```
An inspection of `swa_utils.py`  indicates there are no checks for invalid values of `decay`. Adding asserts as suggested in this PR ensures valid compute range (one way to enforce correct behavior, there are perhaps more suitable ones). Papers `torch` cites for reference idea/implementation also consider exclusively this range (e.g., https://arxiv.org/pdf/2310.04415).

Fixes https://github.com/pytorch/pytorch/issues/133772

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133773
Approved by: https://github.com/janeyx99
2024-10-31 18:24:08 +00:00
Tom Ritchford
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
Matt Pitkin
8a5dd7f59b Allow SequentialLR to include ChainedScheduler (#133450)
This fixes #132745 and allows a `SequentialLR` to include schedulers that are compound scheduler types (i.e., a `ChainedScheduler`), which contain a list of schedulers in a `_schedulers` attribute.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133450
Approved by: https://github.com/janeyx99
2024-10-18 02:29:38 +00:00
Daniel Velkov
4abe38bc94 RMSprop docs: add missing input "epsilon" (#137854)
Adding a missing input argument in the docs for RMSprop. Like in the doc for AdamW https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137854
Approved by: https://github.com/janeyx99
2024-10-15 16:40:42 +00:00
ErezYosef
197601eeea Add Support for Tracking Parameter Names (named_parameters) in Optimizer State Dict (#134107)
A proposal addressing Issue #1489: **Optimizer should track parameter names and not id.**

(also mentioned in here: [[RFC] Introducing FQNs/clarity eyeglasses to optim state_dict](https://dev-discuss.pytorch.org/t/rfc-introducing-fqns-clarity-to-optim-state-dict/1552)

## Summary
This PR introduces a backward-compatible enhancement where optimizers track parameter names instead of just their id.
Optimizers can be initialized with `named_parameters()` as:
```python
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
```
This allows for greater clarity and ease when handling optimizers, as the parameters' names are preserved within the optimizer’s `state_dict` as:
```
state_dict =
{
    'state': {
    0: {'momentum_buffer': tensor(...), ...},
    1: {'momentum_buffer': tensor(...), ...},
    },
    'param_groups': [
        {
        'lr': 0.01,
        'weight_decay': 0,
        ...
        'params': [0,1]
        'param_names' ['layer.weight', 'layer.bias']  (optional)
        }
    ]
}
```
Loading `state_dict` is not changed (backward-compatible) and the `param_names` key will be ignored.

## Key Features
#### Named Parameters in Optimizer Initialization:
Optimizers can accept the output of `model.named_parameters()` during initialization, allowing them to store parameter names directly.
#### Parameter Names in `state_dict`:
The parameter names are saved as a list in the optimizer’s `state_dict` with key `param_names`, alongside the `params` indices, ensuring seamless tracking of both names and parameters.

## Backward Compatibility
#### No Breaking Changes:
This change is fully backward-compatible. The added `param_names` key in the optimizer's `state_dict` is ignored when loading a state to the optimizer.

#### Customization with Hooks:
For more control, the loaded state_dict can be modified using a custom `register_load_state_dict_pre_hook`, providing flexibility for different design needs.

## Documentation Updates
Please refer to the documentation changes for more details on how this feature is implemented and how it can be used effectively.

## Solution Example:

A suggested solution to the problem mentioned in #1489, for the same parameters but in a different order.
The following `register_load_state_dict_pre_hook` should be added to the optimizer before loading to enable loading the state dict :
```python
def adapt_state_dict_ids(optimizer, state_dict):
    # assuming a single param group.
    current_state_group = optimizer.state_dict()['param_groups'][0]
    loaded_state_group = state_dict['param_groups'][0]

    # same number of params, same names, only different ordering
    current_state_name_to_id_mapping = {}  # mapping --  param_name: id
    for i, name in enumerate(current_state_group['param_names']):
        current_state_name_to_id_mapping[name] = current_state_group['params'][i]

    # changing the ids of the loaded state dict to match the order of the given state dict.
    for i, name in enumerate(current_state_group['param_names']):
        loaded_state_group['params'][i] = current_state_name_to_id_mapping[name]

    return state_dict
```
In this code, the loaded `state_dict` ids are adapted to match the order of the current optimizer `state_dict`.
Both the previous and the current optimizers are required to be initiated with `named_parameters()` to have the 'param_names' key in the dict.

### Note
This is my first contribution to PyTorch, and I wish to receive feedback or suggestions for improvement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134107
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2024-10-14 19:24:44 +00:00
Jane Xu
f9ed39c989 Autoupdate min_lrs for ReduceLROnPlateau if possible, fixes #104361 (#137637)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137637
Approved by: https://github.com/albanD
2024-10-10 01:23:30 +00:00
Jane Xu
972822dea1 Minorly reorder optim kwargs in docs, fixes #137391 (#137531)
Closes #137391

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137531
Approved by: https://github.com/albanD
2024-10-09 04:14:45 +00:00
Jane Xu
ddc7b6d0b4 Removes confusing note, addresses #38006 (#137535)
Fixes #38006

The note was originally added in https://github.com/pytorch/pytorch/pull/30257, which tried to ensure that the gradient wasn't modified in the optimizer. This note creates more confusion than is helpful, so removing it is better than leaving it in, especially because most uses of closure that I know _does_ modify the grads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137535
Approved by: https://github.com/albanD
2024-10-09 04:00:38 +00:00
Jane Xu
b16167874d Minor SGD docs clarification fixing #137356, #137352 (#137528)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137528
Approved by: https://github.com/albanD
2024-10-08 23:05:08 +00:00
Sunishchal Dev
a8ed873ba2 Add missing input "eps" to adam docs (#135191)
Minor fix for missing input argument in the Adam optimizer docs page.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135191
Approved by: https://github.com/janeyx99
2024-09-25 20:17:23 +00:00
Mauricio Villegas
ece8267d2c Add back optim type hints that were lost when *.pyi files were removed (#136185)
When stub files (`*.pyi`) were removed from `optim` (#125556, #125452), some types that existed are no longer available. This pull request adds them back.

Just for reference, these types are used in `pytorch-lightning`'s `LightningCLI`. Command line interfaces are created automatically, and having type hints make them nicer.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136185
Approved by: https://github.com/janeyx99
2024-09-17 15:45:15 +00:00
Aaron Gokaslan
31715be72a [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.1 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-16 19:44:11 +00:00
PyTorch MergeBot
3117f2cf67 Revert "[BE]: Update mypy to 1.11.2 (#133816)"
This reverts commit 55299cfc22.

Reverted https://github.com/pytorch/pytorch/pull/133816 on behalf of https://github.com/jeanschmidt due to seems to have broken https://github.com/pytorch/pytorch/actions/runs/10865710499/job/30155699792 on main ([comment](https://github.com/pytorch/pytorch/pull/133816#issuecomment-2352377684))
2024-09-16 09:11:16 +00:00
Aaron Gokaslan
55299cfc22 [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.1 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-14 21:40:36 +00:00
Jane Xu
b1612569f6 [BE] Clarify defaulting behavior in optimizer (#135384)
Fixes #135340

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135384
Approved by: https://github.com/drisspg, https://github.com/jainapurva
2024-09-06 21:52:55 +00:00
Nikita Shulga
fb1c580892 [BE][optim] Make pyright recognize exported symbols (#135043)
Follows pattern introduced by https://github.com/pytorch/pytorch/pull/80955 which [pyright](https://github.com/microsoft/pyright) prefers over `__all__` symbol, see https://github.com/microsoft/pylance-release/issues/2953#issuecomment-1168956296
Fixes https://github.com/pytorch/pytorch/issues/134985

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135043
Approved by: https://github.com/janeyx99
2024-09-04 21:53:46 +00:00
Wil Kong
de06345e9b Avoid Host & Device Sync In LR Scheduler (#133663)
Fixes #133662.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133663
Approved by: https://github.com/janeyx99, https://github.com/eqy

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2024-08-22 03:52:43 +00:00
Sahdev Zala
06cc2e83f0 Make optim.swa.util content accessible from the torch.optim doc (#133393)
Link various classes and functions of the `optim.swa.util` to make doc content accessible from the `torch.optim` doc.

Currently, if you click the link,
https://pytorch.org/docs/stable/optim.html#module-torch.optim.swa_utils it goes to a blank, bottom of the page section of `torch.optim`.
Also,
`torch.optim.swa_utils.AveragedModel` and `torch.optim.swa_utils.SWALR` classes as well as `torch.optim.swa_utils.update_bn()` and `optim.swa_utils.get_ema_multi_avg_fn` are not linked to doc.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133393
Approved by: https://github.com/janeyx99
2024-08-21 00:43:46 +00:00
Masaki Kozuki
702c810780 move param's device check to _init_group for fused (#131153)
There could be some cases where the params have the meta device when calling optimizer's dunder init and those params are materialized in the first computation. This change would allow such situation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131153
Approved by: https://github.com/mlazos, https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2024-08-17 04:49:47 +00:00
Jane Xu
c23dceb8f1 Add Adafactor foreach impl (#132336)
This PR adds the foreach impl for Adafactor knowing that there are many ways to improve its runtime perf today (by adding more foreach support). After this PR:
- we have a foreach flag for Adafactor
- It is NOT the default. Why not? It is only slightly faster + uses O(n) more memory where n is the number of params in your max param group. People tend to use Adafactor for memory efficiency.

Next steps:
- make torch.compile possible on it
- make it faster (by adding more foreach apis)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132336
Approved by: https://github.com/albanD
ghstack dependencies: #133360
2024-08-15 17:00:33 +00:00
Xuehai Pan
758a0a88a2 [BE][Easy] enable ruff rule PIE790: unnecessary pass statement (#133200)
This PR removes unnecessary `pass` statement. This is semanticly safe because the bytecode for the Python code does not change.

Note that if there is a docstring in the function, a empty function does not need a `pass` statement as placeholder.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133200
Approved by: https://github.com/malfet, https://github.com/eqy, https://github.com/kit1980
2024-08-15 15:50:19 +00:00
Jane Xu
14750dd737 Correct return type of grouping helper function in Optimizer (#133360)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133360
Approved by: https://github.com/albanD
2024-08-14 01:56:02 +00:00
Pierre Chapuis
0e4c0ef29f fix type of eta_min parameter in CosineAnnealing (int -> float) (#132482)
This fixes errors with type checkers such as `pyright`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132482
Approved by: https://github.com/janeyx99
2024-08-12 18:22:26 +00:00
PyTorch MergeBot
cbee9c1fd2 Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit 0e7e61f7ce.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2272370386))
2024-08-07 00:05:20 +00:00
Xuehai Pan
0e7e61f7ce Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-08-03 09:43:38 +00:00
xinyu-intel
2ee9895304 Support optimizer capturable on hpu and xpu (#132119)
as title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132119
Approved by: https://github.com/jgong5, https://github.com/janeyx99
2024-08-02 08:19:52 +00:00
Xuehai Pan
30293319a8 [BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129771
Approved by: https://github.com/justinchuby, https://github.com/janeyx99
2024-08-01 17:07:14 +00:00
Joel Schlosser
e6cddc9271 Fix public API tests (#131386)
This PR fixes a bug in `test_correct_module_names` introduced in #130497. It also addresses post-fix test failures in:
* `torch/ao/quantization/__init__.py` - set the correct `__module__` for several public API helpers
* `torch/library.py` - add `register_vmap` to `__all__`
* `torch/nn/attention/flex_attention.py` - make `round_up_to_multiple` private by prepending an underscore
* `torch/storage.py` - introduce `__all__` to avoid `Self` being re-exported as a public API
* `torch/distributed/pipelining/schedules.py` - add `ZeroBubbleAlgorithm` to `__all__`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131386
Approved by: https://github.com/albanD
2024-07-30 18:42:54 +00:00
Jane Xu
3816f6420a [BE] remove unnecessary _dispatch_sqrt by using ** 0.5 (#131358)
Based on the discussion here where ** 0.5 is not slower than math.sqrt. https://github.com/pytorch/pytorch/pull/129905#discussion_r1675605075

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131358
Approved by: https://github.com/albanD
2024-07-30 18:08:17 +00:00
PyTorch MergeBot
8f5cf46405 Revert "Fix public API tests (#131386)"
This reverts commit 91fcfd8760.

Reverted https://github.com/pytorch/pytorch/pull/131386 on behalf of https://github.com/clee2000 due to reverting this to revert something else, only action you should need to do is to rebase and merge again, sorry for the churn ([comment](https://github.com/pytorch/pytorch/pull/131386#issuecomment-2254327487))
2024-07-28 03:23:04 +00:00
Matthew Hoffman
fdf1451bfa Add __all__ to torch.optim to define public interface (#131959)
There was a regression in the public interface for `torch.optim` introduced in #125452 when `torch/optim/__init__.pyi` was merged into `torch/optim/__init__.py`. [The import aliases were not preserved and so now `pyright` thinks that these classes are not publicly exported from `torch/optim/__init__.py`.](https://github.com/pytorch/pytorch/pull/125452/files#diff-941595c1e1aa06bec94578499dd3654532a5183d0bc1bcd94d1f33b47e0d0adfL1-L15)

```
error: "SGD" is not exported from module "torch.optim"
```

Adding these classes/modules to `__all__` fixes this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131959
Approved by: https://github.com/ezyang
2024-07-27 01:03:25 +00:00
Joel Schlosser
91fcfd8760 Fix public API tests (#131386)
This PR fixes a bug in `test_correct_module_names` introduced in #130497. It also addresses post-fix test failures in:
* `torch/ao/quantization/__init__.py` - set the correct `__module__` for several public API helpers
* `torch/library.py` - add `register_vmap` to `__all__`
* `torch/nn/attention/flex_attention.py` - make `round_up_to_multiple` private by prepending an underscore
* `torch/storage.py` - introduce `__all__` to avoid `Self` being re-exported as a public API
* `torch/distributed/pipelining/schedules.py` - add `ZeroBubbleAlgorithm` to `__all__`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131386
Approved by: https://github.com/albanD
2024-07-26 23:38:43 +00:00
PyTorch MergeBot
e4db5dc1c4 Revert "[BE] remove unnecessary _dispatch_sqrt by using ** 0.5 (#131358)"
This reverts commit 4c7f22dee2.

Reverted https://github.com/pytorch/pytorch/pull/131358 on behalf of https://github.com/janeyx99 due to Internal uses this private API and landing that has been a pain so we're reverting this first ([comment](https://github.com/pytorch/pytorch/pull/131358#issuecomment-2253190654))
2024-07-26 17:35:27 +00:00
PyTorch MergeBot
c9888c2739 Revert "[BE] typing for decorators - optim/optimizer (#131583)"
This reverts commit a1dad77dfa.

Reverted https://github.com/pytorch/pytorch/pull/131583 on behalf of https://github.com/atalman due to Breaks CI: [GH job link](https://github.com/pytorch/pytorch/actions/runs/10105959146/job/27947741162) [HUD commit link](a1dad77dfa) ([comment](https://github.com/pytorch/pytorch/pull/131583#issuecomment-2252784280))
2024-07-26 13:41:22 +00:00
Aaron Orenstein
a1dad77dfa [BE] typing for decorators - optim/optimizer (#131583)
See #131429
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131583
Approved by: https://github.com/janeyx99
ghstack dependencies: #131568, #131569, #131570, #131571, #131572, #131573, #131574, #131575, #131576, #131577, #131578, #131579, #131580, #131581, #131582
2024-07-26 05:00:07 +00:00
Jane Xu
9c4cf866c2 Adafactor forloop basic impl (#129905)
#109581

At this point, the vanilla implementation (the default) is good.
Docs: https://docs-preview.pytorch.org/pytorch/pytorch/129905/generated/torch.optim.Adafactor.html#torch.optim.Adafactor

Specifically, the impl in this PR, which attempts to replicate the paper,
```
optim = torch.optim.Adafactor([weight])
```
is close enough to https://pytorch-optimizers.readthedocs.io/en/latest/optimizer/#pytorch_optimizer.AdaFactor
```
optim_c = AdaFactor([weight], betas=(0, 0.999), scale_parameter=False)
```
is close enough to https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adafactor
```
optim = keras.optimizers.Adafactor(learning_rate=0.01)
```

The three results respectively for the same randomly generated weights:
```
# ours
tensor([[ 0.3807594, -0.3912092],
        [ 0.0762539,  0.5377805],
        [ 0.2459473,  0.4662207]])

# pytorch-optimizer
tensor([[ 0.3807592, -0.3912172],
        [ 0.0762507,  0.5377818],
        [ 0.2459457,  0.4662213]])

# keras
array([[ 0.38076326, -0.39121315],
        [ 0.0762547 ,  0.5377859 ],
        [ 0.24594972,  0.46622536]], dtype=float32)
```

This gives me confidence to move forward in speeding up the implementation now that a baseline has been established. If you're curious about differences:
* keras assigns step_size (rho_t in their code) to `min(lr, 1 / sqrt(step)` whereas the OG impl uses a hardcoded 0.01 instead of lr. We do the same thing as keras, but our lr default is 0.01.
* We differ from the pytorch-optimizers default in that our default will not track momentum (thus `beta1=0`) and we do not apply parameter scaling.

<details>

Keras collab: https://colab.research.google.com/drive/1i3xF8ChL7TWKJGV_5v_5nMhXKnYmQQ06?usp=sharing

My script repro:

```
import torch
from pytorch_optimizer import AdaFactor
torch.set_printoptions(precision=7)

weight = torch.tensor([[ 0.37697506, -0.39500135],
        [ 0.07246649,  0.53399765],
        [ 0.24216151,  0.46243715]], dtype=torch.float32)
# bias = torch.tensor([0, 0], dtype=torch.float32)

weight.grad = torch.tensor([[-0.5940447, -0.7743838],
        [-0.5940447, -0.7743838],
        [-0.5940447, -0.7743838]], dtype=torch.float32)
# bias.grad = torch.tensor([-2.5027974,  1.5422692], dtype=torch.float32)

weight_c = weight.clone()
weight_c.grad = weight.grad.clone()

optim = torch.optim.Adafactor([weight])
optim.step()
print(weight)

optim_c = AdaFactor([weight_c], betas=(0, 0.999), scale_parameter=False)
optim_c.step()
print(weight_c)
```

<details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129905
Approved by: https://github.com/albanD
2024-07-25 13:17:19 +00:00
Jane Xu
4c7f22dee2 [BE] remove unnecessary _dispatch_sqrt by using ** 0.5 (#131358)
Based on the discussion here where ** 0.5 is not slower than math.sqrt. https://github.com/pytorch/pytorch/pull/129905#discussion_r1675605075

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131358
Approved by: https://github.com/albanD
2024-07-24 14:58:57 +00:00
hxwang
276b5238ef [bug] Add is_compiling check for optimizers to avoid untracked tensor during graph tracing (#130909)
Hey folks, I was using the `stateless_func` [here](7c45476d38/torch/distributed/_spmd/api.py (L435)), which worked well before [this commit](https://github.com/pytorch/pytorch/pull/111084) but then introduced a `_tensor_constant0` and made this func non-stateless. Since there is no way to retrieve this constant tensor before compilation and performance is not an issue when tracing a graph, I think it might be good to fall back to the other branch.
![image](https://github.com/user-attachments/assets/6ee4487d-456b-47e0-8c1d-66cb5a641d47)

![image](https://github.com/user-attachments/assets/1ed46502-e50e-45c4-9751-49aa5a4590ae)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130909
Approved by: https://github.com/mlazos
2024-07-24 08:29:27 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function so even if the underlying function is fully typed then callers to it don't get any benefit from type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Li-Huai (Allan) Lin
99d9b369f4 [Optim] Support tensor lr for all optimizers and check it is 1-element (#131065)
Fixes: #130980
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131065
Approved by: https://github.com/janeyx99
2024-07-23 04:27:05 +00:00
Tianyi Tao
3477ee38e4 fix the use of initial learning rate in the OneCycleLR example (#130306)
Fixes #127649

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130306
Approved by: https://github.com/janeyx99
2024-07-09 18:58:07 +00:00
Li-Huai (Allan) Lin
8ec5ba960f [MPS] Add tensor_lr overloads to fused adam & adamw (#129451)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129451
Approved by: https://github.com/janeyx99
2024-07-02 19:46:30 +00:00
Michael Lazos
aa7ea6b45c Add wraps back (#129933)
Fixes https://github.com/pytorch/pytorch/issues/129922

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129933
Approved by: https://github.com/eqy, https://github.com/janeyx99
2024-07-02 18:24:02 +00:00
Sahdev Zala
9795dba1e0 Optim package docstring fix (#129086)
Fix docstrings in various files in optim package. This is a last remaining fix for the issue #112593

The fix can be verified by running pydocstyle path-to-file --count

Fixes #112593

Related #128248

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129086
Approved by: https://github.com/janeyx99
2024-06-21 14:30:53 +00:00
Li-Huai (Allan) Lin
9a7e2519d3 [MPS] Fused Adam & AdamW (#127242)
Summary:

This PR adds fused Adam and AdamW implementations.

Benchmark on Macbook Pro with M1 Max chip and 64GB unified memory:
**Fast math enabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
                                                                           |  Fused: True  |  Fused: False
1 threads: -----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100        |       10      |       100
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100       |        9      |        89
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100       |        9      |        90
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100      |        9      |        83
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100       |       12      |        94
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100      |       11      |        88
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100      |       12      |        90
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100     |       11      |       100
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100     |       27      |       100
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100    |       23      |       100
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100    |       27      |       100
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100   |       23      |        98
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500        |       82      |       480
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500       |       72      |       450
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500       |       82      |       450
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500      |       73      |       420
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500       |       91      |       500
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500      |       83      |       400
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500      |       94      |       500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500     |       78      |       400
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500     |      170      |       500
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500    |      140      |       600
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500    |      170      |       600
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500   |      140      |       500
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000       |      250      |       890
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000      |      220      |       850
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000      |      250      |       830
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000     |      220      |       770
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000      |      270      |       870
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000     |      230      |       840
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000     |      270      |       810
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000    |      240      |       800
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000    |      400      |      1000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000   |      360      |      2000
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000   |      430      |      2000
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000  |      360      |      1300

Times are in milliseconds (ms).
```

**Fast math disabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
                                                                           |  Fused: True  |  Fused: False
1 threads: -----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100        |       10      |       100
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100       |        9      |        84
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100       |        9      |        84
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100      |        9      |        79
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100       |       11      |        93
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100      |       10      |        90
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100      |       11      |        91
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100     |       11      |        81
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100     |       34      |       100
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100    |       31      |       100
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100    |       34      |        95
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100   |       31      |       100
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500        |       94      |       500
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500       |       82      |       430
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500       |       92      |       430
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500      |       81      |       390
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500       |       98      |       500
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500      |       88      |       430
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500      |      100      |       500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500     |       88      |       400
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500     |      210      |       500
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500    |      190      |       610
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500    |      210      |       510
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500   |      190      |       500
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000       |      300      |       900
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000      |      260      |       850
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000      |      295      |       900
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000     |      260      |       800
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000      |      320      |       910
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000     |      280      |       900
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000     |      320      |       900
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000    |      300      |       900
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000    |      500      |      2000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000   |      480      |      2000
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000   |      540      |      1500
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000  |      480      |      1200

Times are in milliseconds (ms).
```

```python
def profile_fused_adam():
    from torch.optim import adam, adamw
    import torch.utils.benchmark as benchmark

    import itertools

    def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
        fn(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            foreach=False,
            capturable=False,
            fused=fused,
            amsgrad=amsgrad,
            beta1=0.9,
            beta2=0.99,
            lr=1e-3,
            weight_decay=.0,
            eps=1e-5,
            maximize=False,
            grad_scale=None,
            found_inf=None,
        )
        torch.mps.synchronize()

    device = "mps"

    results = []

    for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]):
        print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
        params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
        max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else []
        state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)]
        if adamWflag:
            fn = adamw.adamw
        else:
            fn = adam.adam

        for fused in [True, False]:

            t = benchmark.Timer(
                    stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
                    label='Fused Adam',
                    sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
                    globals=locals(),
                    description= f"Fused: {fused}",
                ).blocked_autorange(min_run_time=5)
            results.append(t)

    compare = benchmark.Compare(results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)
    compare.print()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127242
Approved by: https://github.com/kulinseth, https://github.com/janeyx99
2024-06-18 19:59:50 +00:00
Ahmed Gheith
3dd5f0ecbb Remove circular import (#128875)
Summary: A spurious import is causing circular dependency errors

Test Plan: phabricator signals

Differential Revision: D58685676

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128875
Approved by: https://github.com/kit1980
2024-06-18 12:30:13 +00:00
Sahdev Zala
4ccbf711e2 Learning Rate Scheduler docstring fix (#128679)
Fix docstrings in Learning Rate Scheduler.

The fix can be verified by running pydocstyle path-to-file --count

Related #112593

**BEFORE the PR:**
pydocstyle torch/optim/lr_scheduler.py --count

92


**AFTER the PR:**
pydocstyle torch/optim/lr_scheduler.py --count

0

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128679
Approved by: https://github.com/janeyx99
2024-06-15 05:30:35 +00:00
Xuehai Pan
dcc0093dba [BE][Easy] export explicitly imported public submodules (#127703)
Add top-level submodules `torch.{storage,serialization,functional,amp,overrides,types}`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127703
Approved by: https://github.com/ezyang
2024-06-12 05:52:18 +00:00