Commit Graph

45 Commits

Author SHA1 Message Date
Michael Lazos
0f02e0aa39 Disable dynamo on functional optims if capturable=False (#123619)
This resolves a bug in eager mode where, if an old state dict (saved without the capturable flag) is loaded but the original dict had the capturable flag, state_steps would be on CUDA while we would take the non-capturable path. We now fall back to eager if capturable=False.

Current design doc and discussion: https://docs.google.com/document/d/1DmmbiaSp16CDZtGw1qzXKHFTY_0gqc0xpnBdviXq0vk/edit#heading=h.871u7bvwz7ze

A note on the actual fallback logic: TorchScript originally did not handle `*args, **kwargs` properly. After rectifying that with `functools.wraps`, there was an additional scoping bug that required the single-tensor implementation to be in the global scope at the time the fallback closure was created. I pass the single-tensor function into the `_disable_dynamo_if_unsupported` decorator to work around this bug.
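
For context, a simplified sketch of what such a decorator-based fallback could look like (illustrative only — not the actual `_disable_dynamo_if_unsupported` implementation; `torch.compiler.disable` is used here as the public way to keep a function out of dynamo):

```
import functools
import torch

def _fallback_if_not_capturable(func):
    # Sketch: run the optimizer step outside of dynamo whenever capturable=False.
    eager_func = torch.compiler.disable(func)

    @functools.wraps(func)  # wraps is needed so *args/**kwargs and metadata survive
    def maybe_fallback(*args, **kwargs):
        if not kwargs.get("capturable", False):
            return eager_func(*args, **kwargs)
        return func(*args, **kwargs)

    return maybe_fallback
```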

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123619
Approved by: https://github.com/janeyx99
2024-05-07 22:17:01 +00:00
Michael Lazos
787afc5180 Add LR as tensor tests (#123750)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123750
Approved by: https://github.com/janeyx99
2024-05-01 04:46:49 +00:00
FFFrog
ac74a6783b Part 2: UFMT fix 2 files in torch/optim due to the pr-sanity-checks (#124054)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124054
Approved by: https://github.com/ezyang
ghstack dependencies: #124048, #124053
2024-04-16 03:20:21 +00:00
Jane Xu
b5ba80828f [optim] Rectify capturable testing and fix bugs! (#118326)
This PR fixes several bugs, listed by priority:
1. `load_state_dict` with a non-tensor step was incorrect for the capturable and fused implementations since we didn't create the step tensors on the right device in `__setstate__`. This has been fixed (see the sketch after this list).
2. The most recently added capturable implementations forgot the check that all tensors should be on CUDA for eager. We've now added those checks.
3. The most recent change in Adamax only adds capturable for foreach but would be silently incorrect for forloop/single-tensor. I've added erroring and modified testing with many skips for that. Honestly, this PR has only further cemented my preference that we should just do the single-tensor and multi-tensor capturable implementations together in the future. @mlazos
4. The conditional for adding CUDA-supported configs to the optimizer infos was incorrect, so we hadn't been testing capturable! This has been rectified and was the trigger for this PR in the first place.
5. Similarly, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was sometimes incorrect as well. This has also been corrected.
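
As a rough sketch of the fix in item 1 (a simplified, hypothetical helper — not the actual diff):

```
import torch

def _restore_step(step, param, capturable, fused):
    # Old checkpoints may store `step` as a plain Python number. The capturable and
    # fused paths expect a tensor on the parameter's device, so convert accordingly;
    # otherwise keep the step on CPU as before.
    if isinstance(step, torch.Tensor):
        return step
    device = param.device if (capturable or fused) else "cpu"
    return torch.tensor(float(step), dtype=torch.float32, device=device)
```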

The following is not a bug, but simplifies life by removing the need to handle Nones: `optim_input_funcs` must now take in a `device`, which can be a string or a torch.device.

Details for posterity:
4. Running the test_foreach_matches_forloop test and printing the configs that get run shows that capturable is now included, which is correct.
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```
5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run shows that it is also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
2024-02-02 19:13:00 +00:00
PyTorch MergeBot
2964170f3a Revert "[optim] Rectify capturable testing and fix bugs! (#118326)"
This reverts commit d947b9d500.

Reverted https://github.com/pytorch/pytorch/pull/118326 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it looks like there are some relevant failures in trunk d947b9d500, may be a land race ([comment](https://github.com/pytorch/pytorch/pull/118326#issuecomment-1923125676))
2024-02-02 07:08:14 +00:00
Jane Xu
d947b9d500 [optim] Rectify capturable testing and fix bugs! (#118326)
This PR fixes several bugs, listed by priority:
1. `load_state_dict` with a non-tensor step was incorrect for the capturable and fused implementations since we didn't create the step tensors on the right device in `__setstate__`. This has been fixed.
2. The most recently added capturable implementations forgot the check that all tensors should be on CUDA for eager. We've now added those checks.
3. The most recent change in Adamax only adds capturable for foreach but would be silently incorrect for forloop/single-tensor. I've added erroring and modified testing with many skips for that. Honestly, this PR has only further cemented my preference that we should just do the single-tensor and multi-tensor capturable implementations together in the future. @mlazos
4. The conditional for adding CUDA-supported configs to the optimizer infos was incorrect, so we hadn't been testing capturable! This has been rectified and was the trigger for this PR in the first place.
5. Similarly, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was sometimes incorrect as well. This has also been corrected.

The following is not a bug, but simplifies life by removing the need to handle Nones: `optim_input_funcs` must now take in a `device`, which can be a string or a torch.device.

Details for posterity:
4. Running the test_foreach_matches_forloop test and printing the configs that get run shows that capturable is now included, which is correct.
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```
5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run shows that it is also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
2024-02-02 02:02:58 +00:00
Jane Xu
924f1b841a [optim] Allow torch.float64 scalars for forloop + foreach implementations (#115841)
Should allow for use cases mentioned in #110940

This would allow scalars to also be float64s in the foreach implementation. The fused implementation would still create a float32 step on Adam and AdamW. This PR also does NOT worry about performance and is mainly for enablement.
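
A hedged illustration of the kind of use case this enables (a usage sketch only; see the linked issue for the exact cases):

```
import torch

# Double-precision parameters with the foreach implementation.
params = [torch.randn(4, 4, dtype=torch.float64, requires_grad=True)]
params[0].sum().backward()

opt = torch.optim.Adam(params, lr=1e-3, foreach=True)
opt.step()
```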

Next steps:
- Relax the constraint on fused adam(w) and allow torch.float64 scalars there
- Allow _performant_ mixed dtypes in foreach (a bigger project in itself).

This PR will conflict with my other PRs; I will figure out a landing order.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115841
Approved by: https://github.com/albanD
2023-12-27 09:13:49 +00:00
Jon Chuang
62de29d06f [optim] be explicit about CPU scalar tensor dtypes (#111008)
Fixes https://github.com/pytorch/pytorch/issues/110940

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111008
Approved by: https://github.com/janeyx99
2023-11-21 22:44:50 +00:00
Jon Chuang
f74d766632 feat(optim): use has_complex shortcut flag for all applicable optimizers, use _view_as_real auxiliary function (#110706)
Follow up to: https://github.com/pytorch/pytorch/pull/110607
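
A hedged sketch of the pattern these changes refer to (illustrative helper, not the exact `_view_as_real` in torch/optim):

```
import torch

def _view_complex_as_real(tensorlist):
    # Optimizers update complex tensors through their real view; the has_complex
    # flag lets callers skip this pass entirely when no complex tensors are present.
    return [torch.view_as_real(t) if torch.is_complex(t) else t for t in tensorlist]

params = [torch.randn(3, dtype=torch.complex64), torch.randn(3)]
has_complex = any(torch.is_complex(p) for p in params)
if has_complex:
    params = _view_complex_as_real(params)
```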

CC: @lezcano @janeyx99
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110706
Approved by: https://github.com/lezcano
2023-10-31 20:33:03 +00:00
Jane Xu
93a9b1314b Make step() faster by passing in a tensor vs scalar 1 (#111084)
This is the culmination of https://github.com/pytorch/pytorch/pull/110954#issuecomment-1758520411.

We are making the code slightly more complicated to gain some perf by minimizing calls to `.copy_()` and `.to()`.

### Code
```
import torch
with torch.cuda.device(0):
    steps = [torch.zeros((), device="cpu", dtype=torch.float32) for i in range(1000)]

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        # New code:
        # step_device = steps[0].device
        # one = torch.tensor(1.0, device=step_device) if str(step_device) == "cpu" else 1
        # torch._foreach_add_(steps, one, 1.0)

        # Old code:
        torch._foreach_add_(steps, 1)

    print(p.key_averages().table(sort_by="cpu_time_total"))
```

### Profiles
**with old code**
```
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
      aten::_foreach_add_        35.31%      52.089ms        99.99%     147.495ms     147.495ms             1
               aten::add_        25.05%      36.949ms        64.68%      95.406ms      95.406us          1000
                 aten::to         3.97%       5.852ms        39.63%      58.457ms      58.457us          1000
           aten::_to_copy        10.11%      14.917ms        35.66%      52.605ms      52.605us          1000
              aten::copy_        21.65%      31.939ms        21.65%      31.939ms      31.939us          1000
      aten::empty_strided         3.90%       5.749ms         3.90%       5.749ms       5.749us          1000
    cudaDeviceSynchronize         0.01%      18.000us         0.01%      18.000us      18.000us             1
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 147.513ms
```

**with new code**
```
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
      aten::_foreach_add_        55.06%      49.963ms        99.86%      90.625ms      90.625ms             1
               aten::add_        44.81%      40.662ms        44.81%      40.662ms      40.662us          1000
            aten::detach_         0.01%       8.000us         0.05%      45.000us      45.000us             1
                  detach_         0.04%      37.000us         0.04%      37.000us      37.000us             1
              aten::empty         0.03%      30.000us         0.03%      30.000us      30.000us             1
                 aten::to         0.03%      23.000us         0.03%      23.000us      23.000us             1
    cudaDeviceSynchronize         0.02%      22.000us         0.02%      22.000us      22.000us             1
         aten::lift_fresh         0.01%       6.000us         0.01%       6.000us       6.000us             1
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 90.751ms
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111084
Approved by: https://github.com/albanD
ghstack dependencies: #111079
2023-10-20 01:34:08 +00:00
Jon Chuang
11047be10e feat(optim): Add NAdam support for complex, with has_complex shortcut (#110634)
Partial fix: https://github.com/pytorch/pytorch/issues/110606

More on `has_complex` shortcut: https://github.com/pytorch/pytorch/pull/110613#issuecomment-1749314805

CC: @janeyx99 @mlazos @lezcano
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110634
Approved by: https://github.com/lezcano
2023-10-06 03:31:48 +00:00
Jane Xu
c0ba9a7840 Fix docs, missed a // in LaTeX for nadam (#107736)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107736
Approved by: https://github.com/mikaylagawarecki
2023-08-23 21:36:27 +00:00
Jane Xu
bcee3d6fa4 [BE] Make nadam decoupled_weight_decay clearer, add missing setstate (#107706)
Inspired by @blizard's changes in #107507

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107706
Approved by: https://github.com/Skylion007
2023-08-22 20:48:28 +00:00
Muralidhar Andoorveedu
608afe8083 Added xla friendly codepath to single_tensor_adamw (#102858)
There are extra graph compilations on XLA when beta{1,2} ** step get too small. This PR addresses the issue by enabling the `capturable` interface for XLA and by switching to `torch.float_power`, which preserves the same behaviour as the non-capturable flow on XLA.
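
A hedged illustration of the precision point (not the kernel code itself): Python `**` on a float computes in double precision, and `torch.float_power` also promotes to float64, so the capturable tensor path can match the eager result.

```
import torch

beta1 = 0.9
step = torch.tensor(500.0)  # in capturable mode, step lives in a tensor

eager_bias_correction = 1 - beta1 ** 500                                       # Python float (double precision)
capturable_bias_correction = 1 - torch.float_power(torch.tensor(beta1), step)  # float64 tensor
```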
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102858
Approved by: https://github.com/janeyx99, https://github.com/albanD
2023-08-18 00:16:28 +00:00
Jane Xu
0208574db9 [NAdam] Add capturable API and tests + fix differentiable (#106615)
This PR:
- adds a capturable API for NAdam similar to Adam(W)
- adds tests accordingly
- discovered and fixed bugs in the differentiable implementation (now tested through the capturable codepath).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106615
Approved by: https://github.com/albanD
2023-08-07 19:49:11 +00:00
janEbert
b0708654c0 Implement NAdamW optimizer (#103881)
NAdamW, which is simply NAdam with the AdamW weight decay term, has shown strong performance in optimizer comparisons such as
1. https://arxiv.org/abs/2211.09760
1. https://arxiv.org/abs/2306.07179

[The VeLO paper](https://arxiv.org/abs/2211.09760) argues its power lies in its ability to act as a superset of other popular optimizers.

This PR adds NAdamW by ~~copying and making very small adaptations to the NAdam implementation (just like AdamW and Adam). To see the small changes in better detail, you can `diff torch/optim/nadam.py torch/optim/nadamw.py`.~~ adding a boolean flag `decoupled_weight_decay` that activates NAdamW behavior (`False` by default) to NAdam.
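
Usage, as a quick sketch:

```
import torch

model = torch.nn.Linear(8, 2)
# NAdamW behavior: NAdam with AdamW-style (decoupled) weight decay.
opt = torch.optim.NAdam(model.parameters(), lr=2e-3, weight_decay=0.01,
                        decoupled_weight_decay=True)
```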

Interest in the optimizer has also been shown in the PyTorch forums:
https://discuss.pytorch.org/t/nadamw-and-demon-optimizers/179778

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103881
Approved by: https://github.com/janeyx99
2023-07-24 19:29:26 +00:00
Aaron Gokaslan
6d43c89f37 [BE]: Update Ruff to 0.0.280 (#105724)
Removes unused loop values in Python dictionary iteration. Automated fix from Ruff master.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724
Approved by: https://github.com/ezyang, https://github.com/janeyx99
2023-07-22 23:03:34 +00:00
Justin Chu
3721fa5612 [BE] Enable ruff's UP rules and autoformat optim/ (#105426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105426
Approved by: https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi, https://github.com/janeyx99
2023-07-18 21:07:43 +00:00
Jane Xu
15aa401baa [foreach][NAdam] Minimize use of intermediates to decrease peak memory (#104910)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104910
Approved by: https://github.com/Skylion007, https://github.com/albanD
2023-07-11 17:08:07 +00:00
Jane Xu
231364fd06 [optim] use lerp whenever possible (#104796)
This is a better copy (with fixes) of #104781.

Test plan:
CI will pass once https://github.com/pytorch/pytorch/pull/104784 is landed

Internal CI (and the newly enabled compiled optim tests) will pass after https://github.com/pytorch/pytorch/pull/104866 is landed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104796
Approved by: https://github.com/albanD
2023-07-11 14:32:59 +00:00
PyTorch MergeBot
e7fe2a797c Revert "[optim] use lerp whenever possible (#104796)"
This reverts commit fbe2a7e50a.

Reverted https://github.com/pytorch/pytorch/pull/104796 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/104796#issuecomment-1628591105))
2023-07-10 09:36:41 +00:00
Jane Xu
fbe2a7e50a [optim] use lerp whenever possible (#104796)
This is a better copy (with fixes) of #104781.

Test plan:
CI will pass once https://github.com/pytorch/pytorch/pull/104784 is landed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104796
Approved by: https://github.com/albanD
2023-07-08 07:13:38 +00:00
Nikita Shulga
6d2887cc06 Reland "Move tensor grouping to ATen" (#103912)
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verify by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If the above condition is not met, instead of printing the static_assert output, VC++ raises very cryptic compilation errors; see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.

Also, using `std::hash` for string should result in a faster hash function.

(cherry picked from commit 74b7a6c75e)

### <samp>🤖 Generated by Copilot at 5914771</samp>

This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
2023-06-21 09:26:33 +00:00
PyTorch MergeBot
0cb5bc3b04 Revert "Move tensor grouping to ATen (#100007)"
This reverts commit 74b7a6c75e.

Reverted https://github.com/pytorch/pytorch/pull/100007 on behalf of https://github.com/izaitsevfb due to Breaks internal builds, see D46629727 ([comment](https://github.com/pytorch/pytorch/pull/100007#issuecomment-1587861598))
2023-06-12 18:30:33 +00:00
Masaki Kozuki
74b7a6c75e Move tensor grouping to ATen (#100007)
rel: #94344
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100007
Approved by: https://github.com/janeyx99
2023-06-09 15:44:46 +00:00
Michael Lazos
4da88447ea Disable grouping by dtype and device if compiling (#102771)
Disable grouping if we are compiling, since grouping happens during lowering instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102771
Approved by: https://github.com/janeyx99
2023-06-02 21:04:49 +00:00
Jane Xu
75cb99e549 [optim] Widen the cases for defaulting to foreach (#95820)
Big OOP correction continued. Also added a test this time to verify the defaulting was as expected.

The key here is realizing that the grouping for foreach already assumes that the non-param tensorlists follow suit in dtype and device, so it is too narrow to check that _all_ tensors were on CUDA. The main leeway this allowed was state_steps, which are sometimes cpu tensors. Since foreach _can_ handle cpu tensors, this should not introduce breakage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95820
Approved by: https://github.com/albanD
2023-03-02 04:15:33 +00:00
Jane Xu
097679478e [optim] Set defaults to foreach, NOT fused (#95241)
Rolling back the default change for Adam and rectifying the docs to reflect that AdamW never defaulted to fused.

Since our fused implementations are relatively newer, let's give them a longer bake-in time before flipping the switch for every user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95241
Approved by: https://github.com/ngimel
2023-02-22 04:47:32 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
Jane Xu
4fc19e1a71 [optim][adam] use fastest impl whenever possible, add util (#93184)
This makes it so that ONLY when the user doesn't set anything for foreach or fused do we switch the default; Adam then cascades so that we default to fused, then foreach, then single-tensor.

To clarify:
* if the user puts True in foreach _only_, it will run the foreach implementation.
* if the user puts True in fused _only_, it will run the fused implementation.
* if the user puts True in foreach AND for fused, it will run the fused implementation.

And:
* if the user puts False in foreach _only_, it will run the single tensor implementation.
* if the user puts False in fused _only_, it will still run the single tensor implementation.
* if the user puts False in foreach AND for fused, it will run the single tensor implementation.

I also didn't trust myself that much with the helper function, so I ran some local asserts on _default_to_fused_or_foreach. The only point left to really test is the `type(p)` check against torch.Tensor, but I think the distributed tests will catch that in CI.
```
import torch
from torch.optim.optimizer import _default_to_fused_or_foreach  # import path assumed for this snippet

cuda_only_fp_list = [
    torch.rand((1, 2), device="cuda", dtype=torch.float32),
    torch.rand((1, 2), device="cuda", dtype=torch.float64),
    torch.rand((1, 2), device="cuda", dtype=torch.float16),
    torch.rand((1, 2), device="cuda", dtype=torch.bfloat16),
]

cuda_only_int_list = [
    torch.randint(1024, (1, 2), device="cuda", dtype=torch.int64),
]

cpu_list = [
    torch.rand((1, 2), device="cpu", dtype=torch.float32),
    torch.rand((1, 2), device="cpu", dtype=torch.float64),
    torch.rand((1, 2), device="cpu", dtype=torch.float16),
]

none_list = [None]

# differentiable should always make it return false for both
assert _default_to_fused_or_foreach([cuda_only_fp_list], True, True) == (False, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list], True, False) == (False, False)

# cpu lists should always make it return false for both
assert _default_to_fused_or_foreach([cuda_only_fp_list, cpu_list], False, True) == (False, False)
assert _default_to_fused_or_foreach([cpu_list], False, True) == (False, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cpu_list], False, False) == (False, False)
assert _default_to_fused_or_foreach([cpu_list], False, False) == (False, False)

# has fused triggers correctly
assert _default_to_fused_or_foreach([cuda_only_fp_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list], False, False) == (False, True)

# ints always goes to foreach
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list], False, True) == (False, True)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list], False, False) == (False, True)

# Nones don't error
assert _default_to_fused_or_foreach([cuda_only_fp_list, none_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list, none_list], False, True) == (False, True)
assert _default_to_fused_or_foreach([none_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([none_list], False, False) == (False, True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93184
Approved by: https://github.com/albanD
2023-01-30 19:58:55 +00:00
Jane Xu
0d870b50d3 [optim][nadam] group tensors in foreach, make it default (#92715)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92715
Approved by: https://github.com/albanD
2023-01-21 05:43:37 +00:00
Jane Xu
de0375e79d [optim][foreach] Do NOT inplace modify gradients (#92706)
SGD and ASGD already had out-of-place grads.
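
A hedged sketch of what "out-of-place grads" means here, using the weight-decay step as the example (illustrative, not the exact diff):

```
import torch

params = [torch.randn(3) for _ in range(2)]
grads = [torch.randn(3) for _ in range(2)]
weight_decay = 0.01

# In-place would mutate the caller's gradient tensors:
#   torch._foreach_add_(grads, params, alpha=weight_decay)
# Out-of-place leaves the original gradients untouched:
grads = torch._foreach_add(grads, params, alpha=weight_decay)
```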

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92706
Approved by: https://github.com/ngimel, https://github.com/albanD
2023-01-21 00:12:28 +00:00
Jane Xu
2b885e1f6c [optim][NAdam] Fix discrepancy between mt vs st impl (#92699)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92699
Approved by: https://github.com/albanD
2023-01-21 00:12:28 +00:00
Jane Xu
0070c546b5 [BE][optim] abstract out docstrings, add differentiable docs (#92336)
1. Abstract out common docstrings --> I'm sure there are more, but let this be a first step.
2. Add differentiable docs to the optimizers that are actually differentiable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92336
Approved by: https://github.com/albanD
2023-01-18 15:09:28 +00:00
Soumith Chintala
06326a7721 [optim] skip .item calls in all optimizers when compiling with dynamo (#88173)
@mlazos: skips `item()` calls when compiling with dynamo by defining a helper function `_get_value`, which returns either the result of `.item()` or, when compiling with dynamo, the scalar CPU tensor itself. This was done because removing `item()` calls outright significantly regresses eager perf. Additionally, `_dispatch_sqrt` calls the appropriate sqrt function (math.sqrt or torch.sqrt).
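
A hedged, simplified sketch of the two helpers described above (the compile check shown, `torch.compiler.is_compiling()`, is today's public spelling; the PR itself used an internal equivalent):

```
import math
import torch

def _get_value(x):
    # In eager, .item() keeps downstream math on Python scalars (faster);
    # under dynamo, returning the scalar CPU tensor avoids the graph break.
    if torch.compiler.is_compiling():
        return x
    return x.item() if isinstance(x, torch.Tensor) else x

def _dispatch_sqrt(x):
    # Pick the sqrt that matches the value we were handed.
    return x.sqrt() if isinstance(x, torch.Tensor) else math.sqrt(x)
```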

Fixes https://github.com/pytorch/torchdynamo/issues/1083

This PR will no longer be needed once symint support is default.

This PR closes all remaining graph breaks in the optimizers (!!)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88173
Approved by: https://github.com/albanD
2022-12-12 17:32:35 +00:00
Michael Lazos
c63afb283c Disable dynamo on optimizer lazy initialization (#89902)
Helps with https://github.com/pytorch/torchdynamo/issues/1803

Separate out the group initialization and disable dynamo on it
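
Roughly the shape of the change, as a hedged sketch (the method name and the use of the public `torch.compiler.disable` are assumptions for illustration; the PR predates that API):

```
import torch

class MySGD(torch.optim.Optimizer):
    @torch.compiler.disable  # keep lazy state initialization out of the traced region
    def _init_group(self, group, params_with_grad, grads):
        # e.g. allocate `step` / momentum buffers on first use
        ...
```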

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89902
Approved by: https://github.com/soumith, https://github.com/albanD
2022-12-02 01:15:11 +00:00
Michael Lazos
3d47c74cfe Update code style for optimizer code (#89862)
Separating out whitespace-only changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89862
Approved by: https://github.com/albanD, https://github.com/soumith
2022-11-30 00:53:05 +00:00
Emilio Castillo
1b43883fd6 Make AdamW, NAdam & RAdam differentiable (#86183)
Blocked by #86096
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86183
Approved by: https://github.com/albanD
2022-10-17 04:32:08 +00:00
ProGamerGov
71d50f4f89 Change docstring type callable to Callable for consistency (#82487)
### Description

Across PyTorch's docstrings, both `callable` and `Callable` are used for variable types. `Callable` should be capitalized, as we are referring to the `Callable` type and not the Python `callable()` function.

### Testing

There shouldn't be any testing required.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82487
Approved by: https://github.com/albanD
2022-08-01 17:26:09 +00:00
anjali411
bda04e9f5e Add __all__ for torch.optim and torch.nn.modules modules (#80237)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80237
Approved by: https://github.com/albanD
2022-06-24 21:34:10 +00:00
Sergii Dymchenko
de7219e8a7 Use generators with all/any in torch/optim (#78142)
Generator comprehensions with any/all are less verbose and can potentially save memory/CPU: https://eklitzke.org/generator-comprehensions-and-using-any-and-all-in-python
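
For example (a hedged illustration of the pattern, using a typical optimizer-style check):

```
import torch

params = [torch.randn(3), torch.randn(3, dtype=torch.complex64)]

# List comprehension: builds the whole intermediate list before any() sees it.
has_complex = any([torch.is_complex(p) for p in params])

# Generator expression: allocates no list and short-circuits on the first True.
has_complex = any(torch.is_complex(p) for p in params)
```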

To make JIT work with this change, I added code to convert GeneratorExp to ListComp, so the whole PR is basically a no-op for JIT but a potential memory and speed improvement for eager mode.

Also, I removed a test from test/jit/test_parametrization.py. The test was bad: it had a TODO to actually implement it and only checked that UnsupportedNodeError is thrown, and with GeneratorExp support a different error would be thrown.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78142
Approved by: https://github.com/malfet, https://github.com/albanD
2022-06-24 17:23:45 +00:00
Mikayla Gawarecki
3653f07c7c Optim foreach cleanup for NAdam (#70229)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70229

Test Plan: Imported from OSS

Reviewed By: anjali411

Differential Revision: D33767873

Pulled By: mikaylagawarecki

fbshipit-source-id: 833ead14c1d1659351ebfbeb41045a3c7eb96dad
(cherry picked from commit 9415df6b5c)
2022-02-09 16:52:13 +00:00
Mikayla Gawarecki
ccc1a01dcb [optim] NAdam fold state updates into functional (#71334)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71334

Test Plan: Imported from OSS

Reviewed By: anjali411

Differential Revision: D33767864

Pulled By: mikaylagawarecki

fbshipit-source-id: 4d985e9e346f40110bd4231e0f16e5643fbc448d
(cherry picked from commit 58aa77e367)
2022-02-08 23:58:41 +00:00
Ilqar Ramazanli
9ccb9299e0 To add Nesterov Adam algorithm description to documentation (#63793)
Summary:
It has been discussed before that adding descriptions of optimization algorithms to the PyTorch core documentation may result in a nice optimization research tutorial. In the following tracking issue we mention all the necessary algorithms and links to the originally published papers: https://github.com/pytorch/pytorch/issues/63236.

In this PR we add a description of the Nesterov Adam algorithm to the documentation. For more details, we refer to the paper https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ

<img width="439" alt="NAdam" src="https://user-images.githubusercontent.com/73658284/131185124-e81b2edf-33d9-4a9d-a7bf-f7e5eea47d7c.png">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63793

Reviewed By: NivekT

Differential Revision: D30617057

Pulled By: iramazanli

fbshipit-source-id: cd2054b0d9b6883878be74576e86e307f32f1435
2021-08-27 19:29:34 -07:00
Ilqar Ramazanli
e8690dacb2 To add Nesterov Adam Algorithm to Optimizers (#59009)
Summary:
Fixes: https://github.com/pytorch/pytorch/issues/5804

In the paper https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ, Timothy Dozat suggested a new optimization algorithm that is, in essence, a combination of the NAG and Adam algorithms.

It is known that the momentum idea in optimization algorithms can be improved with Nesterov acceleration, and Dozat investigates applying this idea to the momentum component of the Adam algorithm. The author provides experimental evidence in his work to show the merit of the idea.

In this PR we implement NAdam, the algorithm proposed in the mentioned paper. The author has preliminary work, http://cs229.stanford.edu/proj2015/054_report.pdf, where he shows the decay base constant should be taken as 0.96; we follow the same choice here, as Keras does. Moreover, implementation and coding practices similar to Keras have been followed in some other places as well:

f9d3868495/tensorflow/python/keras/optimizer_v2/nadam.py
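
For reference, a hedged usage sketch of the resulting optimizer (the 0.96 decay base enters through the `momentum_decay` schedule; the defaults shown follow the current `torch.optim.NAdam` signature):

```
import torch

model = torch.nn.Linear(8, 2)
opt = torch.optim.NAdam(model.parameters(), lr=2e-3, betas=(0.9, 0.999),
                        eps=1e-8, weight_decay=0.0, momentum_decay=4e-3)
```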

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59009

Reviewed By: gchanan, vincentqb

Differential Revision: D29220375

Pulled By: iramazanli

fbshipit-source-id: 4b4bb4b15f7e16f7527f368bbf4207ed345751aa
2021-06-23 08:21:43 -07:00