Commit Graph

101 Commits

Author SHA1 Message Date
Jon Chuang
d279979102 perf(inductor): improve Adam compile times by shortcutting for loops (via has_complex) (#110607)
Adam part of: https://github.com/pytorch/pytorch/issues/110506

TODO:
- If this approach is validated as a good one, it an also be applied to all other optimizers which convert `complex` via list comprehensions

### Results:
`NUM_PARAMS=200, foreach=True`
- main: dynamo: 43s, inductor: 31s, total: 74s
- this PR: dynamo: 3.5s, inductor: 30s, total: 34s (dynamo speedup: 12.3x, overall speedup: 34s, 2.1x)

`NUM_PARAMS=1000, foreach=True, has_complex shortcut`:

```
<class 'torch.optim.adam.Adam'> {'lr': 0.01, 'foreach': True} torch.float32 TorchDynamo compilation metrics:
Function                              Runtimes (s)
------------------------------------  -------------------------------
_compile.<locals>.compile_inner       0.0329, 50.0806, 0.0041
OutputGraph.call_user_compiler        44.9924
```

`NUM_PARAMS=1000, foreach=True`:
```
<class 'torch.optim.adam.Adam'> {'lr': 0.01, 'foreach': True} torch.float32 TorchDynamo compilation metrics:
Function                              Runtimes (s)
------------------------------------  -------------------------------
_compile.<locals>.compile_inner       0.0389, 58.6069, 0.0043
OutputGraph.call_user_compiler        44.1425
```

### Discussion
- `has_complex` shortcut provides additional 2x dynamo speedup. It is not necessary to achieve a significant overall speedup.

CC: @janeyx99 @mlazos

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110607
Approved by: https://github.com/janeyx99, https://github.com/lezcano
2023-10-06 05:08:49 +00:00
Jon Chuang
df7d01aed5 perf(inductor): use for loop with shortcut in Optimizers to speedup against list comprehensions (e.g. complex conversion) (#110613)
Fully fixes: https://github.com/pytorch/pytorch/issues/110506

Depends: https://github.com/pytorch/pytorch/pull/110607
Potential merge conflicts:
- https://github.com/pytorch/pytorch/pull/110339
- https://github.com/pytorch/pytorch/pull/110345
- https://github.com/pytorch/pytorch/pull/110454

Related:
- https://github.com/pytorch/pytorch/issues/110606 (we can apply the improvements here orthogonally to the complex support)

### Results

Benchmark: 100 params.

Breakdowns (float32, dynamo):
```
Adagrad: this PR: 4.4s, main: 8.8s
Adam: this PR: 2.1s, main: 9.8s
AdamW: this PR: 2.5s, main: 8.2s
ASGD: this PR: 3.1s, main: 8.5s
RMSProp: this PR: 1.3s, main: 4.2s
RProp: this PR: 6.7s, main: 14.9s
```

Notes:
1. Adagrad is still slow due to `_get_value` list comprehension. Can be fixed in https://github.com/pytorch/pytorch/pull/110339/files by utilizing capturable path
2. Adamax is not actually compiled (it is currently disabled).
3. Inductor compile time is quite variable. We calculate dynamo by subtracting `call_user_compiler` from `compile_inner` timing.

<details>

This PR:
```
Adagrad (torch.float32): 28.47496461868286s
Adagrad (torch.complex64): 29.379547357559204s
Adam (torch.float32): 17.334211587905884s
Adam (torch.complex64): 29.637500524520874s
Adamax (torch.float32): 2.4749321937561035s
Adamax (torch.complex64): 3.1997995376586914s
AdamW (torch.float32): 18.06532859802246s
AdamW (torch.complex64): 28.25661015510559s
ASGD (torch.float32): 23.70255398750305s
ASGD (torch.complex64): 25.33756995201111s
RMSprop (torch.float32): 7.964028596878052s
RMSprop (torch.complex64): 12.909599781036377s
Rprop (torch.float32): 30.512362003326416s
Rprop (torch.complex64): 44.74405765533447s
```

Main
```
Adagrad (torch.float32): 26.919506072998047s
Adagrad (torch.complex64): 35.190622091293335s
Adam (torch.float32): 25.715000867843628s
Adam (torch.complex64): 24.17716670036316s
Adamax (torch.float32): 2.4404726028442383s
Adamax (torch.complex64): 3.3538928031921387s
AdamW (torch.float32): 25.2022807598114s
AdamW (torch.complex64): 28.915700912475586s
ASGD (torch.float32): 24.108731985092163s
ASGD (torch.complex64): 26.589075088500977s
RMSprop (torch.float32): 10.781344175338745s
RMSprop (torch.complex64): 15.136352777481079s
Rprop (torch.float32): 42.46482181549072s
Rprop (torch.complex64): 48.28277635574341s
```

Seems that it doesn't help the complex case by much (but that's not the majority case). torch.float32 is generally positive, when it does not show drastic improvement / regresses, it is due to inductor variance (by manually inspecting the logs).

</details>

### Benchmark Script
```python
import torch
import time
from torch.optim import Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop

OPTIMS = [Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop]
DTYPES = [torch.float, torch.cfloat]

NUM_PARAMS = 100
kwargs = { "lr": 0.01, "foreach": True }
summary = []

for optim_cls in OPTIMS:
    for dtype in DTYPES:
        torch._dynamo.reset()
        # torch._inductor.metrics.reset()
        input = torch.ones([10, 10], dtype=dtype, device="cuda:0")
        model = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, dtype=dtype, device="cuda:0") for _ in range(NUM_PARAMS)]
        )

        model(input).sum().abs().backward()
        opt_compiled = optim_cls(model.parameters(), **kwargs)
        compiled_step = torch.compile(opt_compiled.step)

        with torch.set_grad_enabled(False):
            start_time = time.time()
            compiled_step()
            summary.append(f"{optim_cls.__name__} ({dtype}): {time.time() - start_time}s")

        print(optim_cls, kwargs, dtype, torch._dynamo.utils.compile_times())

for s in summary:
    print(s)
```

CC: @janeyx99 @mlazos
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110613
Approved by: https://github.com/janeyx99
2023-10-05 23:10:52 +00:00
Jane Xu
1641d671e5 [optim] FusedAdam/W accepts lr: Tensor without h2ds (#106916)
Starts addressing #106802

This PR also conveniently does some BE:
- Fixes a bug in adamw where we use amsgrad instead of per group amsgrad
- Brings the impls of adamw and adam closer to correctness and to each other

I couldn't fully remove the .pyi's because mypy was going to complain about the entire files which scared me and shouldn't go in this PR anyway.

Test plan:
- Add tests to ensure that lr could be passed as a Tensor
- Did some profiling of the below code (runs 1k iterations of step for Adam)

```
import torch
from torch.testing._internal.common_utils import TestCase

param = torch.rand(2, 3, dtype=torch.float, device='cuda:0', requires_grad=True)
param.grad = torch.rand_like(param)

lr = torch.tensor(.001, device='cuda:0')
opt = torch.optim.Adam([param], lr=lr, fused=True)

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as p:
    for _ in range(1000):
        opt.step()

print(p.key_averages().table(sort_by="cpu_time_total"))

```

Before my change:
<img width="1381" alt="image" src="https://github.com/pytorch/pytorch/assets/31798555/cfc5175a-0f41-4829-941f-342554f3b152">

After my change (notice there are no d2h syncs and the CPU time is lower!):
![image](https://github.com/pytorch/pytorch/assets/31798555/726d7e66-dcff-4a4f-8a75-e84329961989)

Next steps long term:
- have all capturable foreach + forloop impls in Adam(W) handle tensor LR
- have all capturable impls handle tensor LR
- have all impls handle tensor LR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106916
Approved by: https://github.com/albanD
2023-08-21 23:00:44 +00:00
Muralidhar Andoorveedu
608afe8083 Added xla friendly codepath to single_tensor_adamw (#102858)
There are extra graph compilations on XLA when beta{1,2} ** step get too small. This PR addresses this issue by making the `capturable` interface enabled for XLA, as well as switching to `torch.float_power` which preserves the same behaviour as the non-capturable flow on XLA.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102858
Approved by: https://github.com/janeyx99, https://github.com/albanD
2023-08-18 00:16:28 +00:00
shibo19
21ede4547a remove duplicated code in optimizer (#106022)
Fixes #ISSUE_NUMBER
as the title, the check code  has duplicates
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106022
Approved by: https://github.com/janeyx99
2023-07-26 17:01:28 +00:00
Aaron Gokaslan
6d43c89f37 [BE]: Update Ruff to 0.0.280 (#105724)
Removes unusued loop values in python dictionary iteration. Automated fix from Ruff master

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724
Approved by: https://github.com/ezyang, https://github.com/janeyx99
2023-07-22 23:03:34 +00:00
Jane Xu
e1296a7f8d [Adam] Fix complex x amsgrad support (#104989)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104989
Approved by: https://github.com/albanD
2023-07-21 23:43:26 +00:00
Jane Xu
25d80c69ce [foreach] super minor BE: remove unnecessary cast (#105601)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105601
Approved by: https://github.com/albanD
2023-07-20 17:06:52 +00:00
Justin Chu
3721fa5612 [BE] Enable ruff's UP rules and autoformat optim/ (#105426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105426
Approved by: https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi, https://github.com/janeyx99
2023-07-18 21:07:43 +00:00
albanD
ef05c5f202 Use plain power operator in Adam/Adamw when capturing (#104254)
The goal is to fix the problem from https://github.com/pytorch/pytorch/pull/102858

The full error this used to raise was :
```
2023-06-27T15:12:15.0663239Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/optim/adamw.py", line 409, in _single_tensor_adamw
2023-06-27T15:12:15.0663699Z     bias_correction1 = 1 - beta1 ** step
2023-06-27T15:12:15.0664200Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_tensor.py", line 40, in wrapped
2023-06-27T15:12:15.0664547Z     return f(*args, **kwargs)
2023-06-27T15:12:15.0665031Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_tensor.py", line 882, in __rpow__
2023-06-27T15:12:15.0665483Z     return torch.tensor(other, dtype=dtype, device=self.device) ** self
2023-06-27T15:12:15.0665899Z RuntimeError: CUDA error: operation not permitted when stream is capturing
2023-06-27T15:12:15.0666401Z CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
```

This pow issue was fixed in https://github.com/pytorch/pytorch/pull/104264 and so this problem should be solvable now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104254
Approved by: https://github.com/janeyx99, https://github.com/aws-murandoo
2023-07-13 19:24:25 +00:00
Jane Xu
231364fd06 [optim] use lerp whenever possible (#104796)
This is a better copy (with fixes) of #104781.

Test plan:
CI will pass once https://github.com/pytorch/pytorch/pull/104784 is landed

Internal CI (and the newly enabled compiled optim tests) will pass after https://github.com/pytorch/pytorch/pull/104866 is landed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104796
Approved by: https://github.com/albanD
2023-07-11 14:32:59 +00:00
Jane Xu
35f0e35529 [foreach][Adam] Minimize use of intermediates to decrease peak memory (#104780)
Starts addressing https://github.com/pytorch/pytorch/issues/97712 by
- Minimizing intermediates usage for foreach Adam
- Document the extra memory usage
- Add comments within the code for clarity now that we reuse intermediates
- Add tests
- Did some refactoring

Next steps involve doing this for all other foreach implementations. Note that even after this change, foreach mem usage will be higher than forloop due to the fact that we have a minimum budget of 1 intermediate (to not muddle the input values) and the intermediate will be larger. For capturable, the memory usage is higher due to moving more tensors to CUDA.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104780
Approved by: https://github.com/albanD
2023-07-10 17:38:46 +00:00
PyTorch MergeBot
e7fe2a797c Revert "[optim] use lerp whenever possible (#104796)"
This reverts commit fbe2a7e50a.

Reverted https://github.com/pytorch/pytorch/pull/104796 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/104796#issuecomment-1628591105))
2023-07-10 09:36:41 +00:00
Jane Xu
fbe2a7e50a [optim] use lerp whenever possible (#104796)
This is a better copy (with fixes) of #104781.

Test plan:
CI will pass once https://github.com/pytorch/pytorch/pull/104784 is landed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104796
Approved by: https://github.com/albanD
2023-07-08 07:13:38 +00:00
Michael Lazos
a290cbf32b Enable fused foreach Adam compilation (#104121)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104121
Approved by: https://github.com/janeyx99
2023-07-05 23:40:03 +00:00
Nikita Shulga
6d2887cc06 Reland "Move tensor grouping to ATen" (#103912)
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verified by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If above conditon is not met, instead of printing a static_assert output, VC++ raises a very cryptic compilation errors,  see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.

Also, using `std::hash` for string should result in a faster hash function.

(cherry picked from commit 74b7a6c75e)

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 5914771</samp>

This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
2023-06-21 09:26:33 +00:00
Jane Xu
fa893f3f58 Fix optim state_dict casting to allow step to cast to CPU (#102619)
I'm guessing this should fix https://github.com/pytorch/pytorch/pull/88015#issuecomment-1569523106 but am waiting on @ychfan to supply more details so I could write a good test case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102619
Approved by: https://github.com/albanD
2023-06-13 00:46:40 +00:00
PyTorch MergeBot
0cb5bc3b04 Revert "Move tensor grouping to ATen (#100007)"
This reverts commit 74b7a6c75e.

Reverted https://github.com/pytorch/pytorch/pull/100007 on behalf of https://github.com/izaitsevfb due to Breaks internal builds, see D46629727 ([comment](https://github.com/pytorch/pytorch/pull/100007#issuecomment-1587861598))
2023-06-12 18:30:33 +00:00
Masaki Kozuki
74b7a6c75e Move tensor grouping to ATen (#100007)
rel: #94344
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100007
Approved by: https://github.com/janeyx99
2023-06-09 15:44:46 +00:00
shibo19
e4a42bcf56 add foreach support for custom device (#102047)
Fixes #ISSUE_NUMBER
for custom device, we want to support foreach, so I add a func that we could set other device type, and the default value is cuda.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047
Approved by: https://github.com/janeyx99
2023-06-07 13:59:20 +00:00
Michael Lazos
4da88447ea Disable grouping by dtype and device if compiling (#102771)
Disable grouping if we are compiling, this happens during lowering
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102771
Approved by: https://github.com/janeyx99
2023-06-02 21:04:49 +00:00
PyTorch MergeBot
9d77949b9e Revert "add foreach support for custom device (#102047)"
This reverts commit b088ff4677.

Reverted https://github.com/pytorch/pytorch/pull/102047 on behalf of https://github.com/malfet due to Broke inductor, see b088ff4677 ([comment](https://github.com/pytorch/pytorch/pull/102047#issuecomment-1572368942))
2023-06-01 16:33:03 +00:00
shibo19
b088ff4677 add foreach support for custom device (#102047)
Fixes #ISSUE_NUMBER
for custom device, we want to support foreach, so I add a func that we could set other device type, and the default value is cuda.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047
Approved by: https://github.com/janeyx99
2023-06-01 06:22:44 +00:00
Jane Xu
f558af2a55 [adam] Use the right params in weight_decay, rename for clarity, fixes #100707 (#100973)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100973
Approved by: https://github.com/Skylion007, https://github.com/albanD
2023-05-09 17:00:27 +00:00
Masaki Kozuki
22ea21da3d Change 1D Tensor of 1 element to 0D Tensor (#96994)
add 0d tensor to graph adam/adamw test

Affected:
- `torch.cuda.amp.GradScaler`'s `found_inf`, `_scale`, and `_growth_tracker`
- `step` of Adam & AdamW of `capturable`

Fixes #96776 🤞

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96994
Approved by: https://github.com/janeyx99
2023-03-21 18:24:19 +00:00
Masaki Kozuki
7d765cdc66 Fix wrong handling of grad_scale & found_inf in fused optimizers (#95847)
Fixes #95781.
The cause seems to be that the current implementation doesn't correctly pass `found_inf` when `grad_scale` is `None`. Therefore parameters can get mistakenly updated by gradients whose some elements are invalid, i.e. nan or inf.

Related #94060

I forgot about this wrong handling after #94344

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95847
Approved by: https://github.com/janeyx99
2023-03-04 01:21:21 +00:00
Jane Xu
75cb99e549 [optim] Widen the cases for defaulting to foreach (#95820)
Big OOP correction continued. Also added a test this time to verify the defaulting was as expected.

The key here is realizing that the grouping for foreach already assumes that the non-param tensorlists follow suit in dtype and device, so it is too narrow to check that _all_ tensors were on CUDA. The main leeway this allowed was state_steps, which are sometimes cpu tensors. Since foreach _can_ handle cpu tensors, this should not introduce breakage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95820
Approved by: https://github.com/albanD
2023-03-02 04:15:33 +00:00
Jane Xu
097679478e [optim] Set defaults to foreach, NOT fused (#95241)
Rolling back the default change for Adam and rectifying the docs to reflect that AdamW never defaulted to fused.

Since our fused implementations are relatively newer, let's give them a longer bake-in time before flipping the switch for every user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95241
Approved by: https://github.com/ngimel
2023-02-22 04:47:32 +00:00
Masaki Kozuki
3e9df622fb [mta] implement _foreach_pow (#92303)
Mainly for foreach path of `Adam` and `AdamW`

rel: https://github.com/pytorch/pytorch/issues/58833
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92303
Approved by: https://github.com/albanD
2023-02-16 02:28:26 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
Masaki Kozuki
6ba041fcae Look up group["capturable"], not defaults["capturable"] in Adam(W) (#94149)
We could set different values in each `param_group` when calling dunder init of `torch.optim` optimizers as in e.g.  https://github.com/pytorch/pytorch/issues/89987.

So check whether or not `capturable` is `True` among all the `param_group`s.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94149
Approved by: https://github.com/albanD
2023-02-07 00:24:35 +00:00
Masaki Kozuki
a23ed38f9a [mta][foreach] Implement fused adamw (#88015)
related: https://github.com/pytorch/pytorch/issues/68041, https://github.com/pytorch/pytorch/issues/71274, https://github.com/pytorch/pytorch/issues/80167
possibly related to https://github.com/pytorch/pytorch/issues/80595#issuecomment-1178519436

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88015
Approved by: https://github.com/albanD, https://github.com/ngimel
2023-02-01 19:32:29 +00:00
Masaki Kozuki
d7a3f2128f pass None instead of False inside Adam.__setstate__ (#93289)
with a061f139dc, `fused`'s type hint is `Optional[bool]` and its default value is `None`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93289
Approved by: https://github.com/janeyx99, https://github.com/Skylion007
2023-01-31 09:41:35 +00:00
Jane Xu
4fc19e1a71 [optim][adam] use fastest impl whenever possible, add util (#93184)
This allows it so that ONLY when the users don't set anything for foreach or fused do we switch the default and cascades adam so that we default to fused, then foreach, then single-tensor.

To clarify:
* if the user puts True in foreach _only_, it will run the foreach implementation.
* if the user puts True in fused _only_, it will run the fused implementation.
* if the user puts True in foreach AND for fused, it will run the fused implementation.

And:
* if the user puts False in foreach _only_, it will run the single tensor implementation.
* if the user puts False in fused _only_, it will still run the single tensor implementation.
* if the user puts False in foreach AND for fused, it will run the single tensor implementation.

I also didn't trust myself that much with the helper function, so I ran some local asserts on _default_to_fused_or_foreach. The only point left to really test is the type(p) -- torch.Tensor but I think the distributed tests will catch that in CI.
```
cuda_only_fp_list = [
    torch.rand((1, 2), device="cuda", dtype=torch.float32),
    torch.rand((1, 2), device="cuda", dtype=torch.float64),
    torch.rand((1, 2), device="cuda", dtype=torch.float16),
    torch.rand((1, 2), device="cuda", dtype=torch.bfloat16),
]

cuda_only_int_list = [
    torch.randint(1024, (1, 2), device="cuda", dtype=torch.int64),
]

cpu_list = [
    torch.rand((1, 2), device="cpu", dtype=torch.float32),
    torch.rand((1, 2), device="cpu", dtype=torch.float64),
    torch.rand((1, 2), device="cpu", dtype=torch.float16),
]

none_list = [None]

# differentiable should always make it return false for both
assert _default_to_fused_or_foreach([cuda_only_fp_list], True, True) == (False, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list], True, False) == (False, False)

# cpu lists should always make it return false for both
assert _default_to_fused_or_foreach([cuda_only_fp_list, cpu_list], False, True) == (False, False)
assert _default_to_fused_or_foreach([cpu_list], False, True) == (False, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cpu_list], False, False) == (False, False)
assert _default_to_fused_or_foreach([cpu_list], False, False) == (False, False)

# has fused triggers correctly
assert _default_to_fused_or_foreach([cuda_only_fp_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list], False, False) == (False, True)

# ints always goes to foreach
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list], False, True) == (False, True)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list], False, False) == (False, True)

# Nones don't error
assert _default_to_fused_or_foreach([cuda_only_fp_list, none_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list, none_list], False, True) == (False, True)
assert _default_to_fused_or_foreach([none_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([none_list], False, False) == (False, True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93184
Approved by: https://github.com/albanD
2023-01-30 19:58:55 +00:00
Jane Xu
de0375e79d [optim][foreach] Do NOT inplace modify gradients (#92706)
SGD and ASGD already had out-of-place grads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92706
Approved by: https://github.com/ngimel, https://github.com/albanD
2023-01-21 00:12:28 +00:00
milesial
e4d83d54a6 Foreach gradient clipping (#91846)
Faster gradient clipping using the foreach functions

```
[------------------------ (tensors, scalar) -------------------------]
                                   |  without foreach  |  with foreach |    apex
1 threads: ----------------------------------------------------------------------
      10 tensors of size 4         |         120.5     |       61.1    |     50.3
      100 tensors of size 4        |         946.2     |      239.5    |    136.3
      1000 tensors of size 4       |        9808.5     |     2151.1    |   1006.9
      10000 tensors of size 4      |       96871.2     |    22637.4    |  10119.1
      10 tensors of size 16        |         121.0     |       64.1    |     52.5
      100 tensors of size 16       |         993.4     |      252.6    |    136.7
      1000 tensors of size 16      |        9427.7     |     2151.2    |   1049.5
      10000 tensors of size 16     |       97437.1     |    22203.1    |  10340.0
      10 tensors of size 256       |         118.9     |       62.3    |     51.5
      100 tensors of size 256      |         955.2     |      243.1    |    134.2
      1000 tensors of size 256     |        9374.9     |     2140.7    |   1009.6
      10000 tensors of size 256    |       95302.5     |    21849.4    |  10215.5
      10 tensors of size 65536     |         118.5     |       62.4    |     51.1
      100 tensors of size 65536    |        1740.7     |      243.3    |    225.3
      1000 tensors of size 65536   |       17364.1     |     2228.7    |   2004.5
      10000 tensors of size 65536  |      177510.1     |    25410.4    |  20678.2
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91846
Approved by: https://github.com/janeyx99
2023-01-20 21:43:29 +00:00
Jane Xu
07800c52af [optim][adam] group tensors in foreach to maximize perf (#92349)
same idea as https://github.com/pytorch/pytorch/pull/92338
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92349
Approved by: https://github.com/albanD
2023-01-18 22:05:42 +00:00
Jane Xu
0070c546b5 [BE][optim] abstract out docstrings, add differentiable docs (#92336)
1. abstract out common doc strings --> I'm sure there are more, but let this be a first step.
2. Add differentiable docs to those who are actually differentiable
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92336
Approved by: https://github.com/albanD
2023-01-18 15:09:28 +00:00
Jane Xu
d41b5d7c14 [adam] Add not torch.jit.is_scripting() as a requirement for switching to fused (#92181)
A "fix" following https://github.com/pytorch/pytorch/pull/90865. Realized that fused is not compatible with torch.jit.is_scripting() when looking at a later line.

Took the opportunity to make the code cleaner/slightly more performant (with the extends) as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92181
Approved by: https://github.com/albanD
2023-01-14 19:05:27 +00:00
Nouran Ali
a60125e298 add docstring for adam differentiable parameter (#91881)
Fixes #90467

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91881
Approved by: https://github.com/janeyx99
2023-01-13 17:08:27 +00:00
Jane Xu
ed7885c254 [utils][foreach] Add group tensor by device and dtype util (#92014)
Add util that will be commonly used throughout optim
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92014
Approved by: https://github.com/albanD
2023-01-11 23:37:20 +00:00
Jane Xu
a061f139dc [optim] Adam defaults to fused when CUDA + differentiable=False (#90865)
Step 1 in faster default optimizers.

Preliminary benchmarks show gaps in improvement on CUDA for BERT_pytorch and resnet18:
![image](https://user-images.githubusercontent.com/31798555/207707118-14221802-77ce-4ee0-96e3-04638c07924c.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90865
Approved by: https://github.com/albanD
2022-12-27 01:28:47 +00:00
Soumith Chintala
06326a7721 [optim] skip .item calls in all optimizers when compiling with dynamo (#88173)
@mlazos: skips `item()` calls if compiling with dynamo, by defining a helper function `_get_value` which either returns the result of `.item()` or the scalar cpu tensor if compiling with dynamo. This was done because removing `item()` calls significantly regresses eager perf. Additionally, `_dispatch_sqrt` calls the appropriate sqrt function (math.sqrt, or torch.sqrt).

Fixes https://github.com/pytorch/torchdynamo/issues/1083

This PR will no longer be needed once symint support is default.

This PR closes all remaining graph breaks in the optimizers (!!)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88173
Approved by: https://github.com/albanD
2022-12-12 17:32:35 +00:00
Michael Lazos
c63afb283c Disable dynamo on optimizer lazy initialization (#89902)
Helps with https://github.com/pytorch/torchdynamo/issues/1803

Separate out the group initialization and disable dynamo on it

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89902
Approved by: https://github.com/soumith, https://github.com/albanD
2022-12-02 01:15:11 +00:00
Michael Lazos
3d47c74cfe Update code style for optimizer code (#89862)
Separating out whitespace-only changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89862
Approved by: https://github.com/albanD, https://github.com/soumith
2022-11-30 00:53:05 +00:00
Masaki Kozuki
5f26df0345 resubmit: "resubmit: [mta] APEX style Fused Adam (#81705) (#85507)" (#85739)
Embarrassingly move the pow implementations around [ATen/native/cuda/PowKernel.cu#L21-L66](849b08f14b/aten/src/ATen/native/cuda/PowKernel.cu (L21-L66)) to a new header file and let FusedAdam use them to tame MSVC, hopefully.

cc @ngimel @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85739
Approved by: https://github.com/ngimel
2022-09-29 16:58:59 +00:00
PyTorch MergeBot
7167996346 Revert "resubmit: [mta] APEX style Fused Adam (#81705) (#85507)"
This reverts commit 4615d1bcfa.

Reverted https://github.com/pytorch/pytorch/pull/85507 on behalf of https://github.com/atalman due to Break internal windows builds
2022-09-27 16:59:35 +00:00
Masaki Kozuki
4615d1bcfa resubmit: [mta] APEX style Fused Adam (#81705) (#85507)
This PR implements an APEX style FusedAdam in PyTorch. This is different from the APEX one in that this is compatible with `torch.cuda.amp.GradScaler` by setting `_step_supports_amp_scaling` to `True` and unscales gradients inside its CUDA kernel.

related: https://github.com/pytorch/pytorch/issues/68041, https://github.com/pytorch/pytorch/issues/71274, https://github.com/pytorch/pytorch/issues/80167 possibly related to https://github.com/pytorch/pytorch/issues/80595#issuecomment-1178519436

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81705
Approved by: https://github.com/ngimel

cc @ptrblck @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85507
Approved by: https://github.com/ngimel
2022-09-23 18:56:00 +00:00
PyTorch MergeBot
e505360eb8 Revert "[mta] APEX style Fused Adam (#81705)"
This reverts commit 7a6c4d0c50.

Reverted https://github.com/pytorch/pytorch/pull/81705 on behalf of https://github.com/dagitses due to broke internal builds, details to come
2022-09-22 19:37:29 +00:00
Masaki Kozuki
7a6c4d0c50 [mta] APEX style Fused Adam (#81705)
This PR implements an APEX style FusedAdam in PyTorch.
This is different from the APEX one in that this is compatible with `torch.cuda.amp.GradScaler` by setting `_step_supports_amp_scaling` to `True` and unscales gradients inside its CUDA kernel.

related: https://github.com/pytorch/pytorch/issues/68041, https://github.com/pytorch/pytorch/issues/71274, https://github.com/pytorch/pytorch/issues/80167
possibly related to https://github.com/pytorch/pytorch/issues/80595#issuecomment-1178519436

cc @ptrblck @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81705
Approved by: https://github.com/ngimel
2022-09-20 17:18:33 +00:00