Commit Graph

87 Commits

Author SHA1 Message Date
Michael Lazos
a290cbf32b Enable fused foreach Adam compilation (#104121)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104121
Approved by: https://github.com/janeyx99
2023-07-05 23:40:03 +00:00
Nikita Shulga
6d2887cc06 Reland "Move tensor grouping to ATen" (#103912)
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verify by compiling something like:
```cpp
#define _DEBUG  // simulate a debug build for this compile-only check
#include <type_traits>
#include <string>
// fires under Visual C++ debug builds, where std::string is not standard-layout
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If the above condition is not met, instead of reporting the static_assert message, VC++ raises very cryptic compilation errors; see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.

Also, using `std::hash` for string should result in a faster hash function.

(cherry picked from commit 74b7a6c75e)

### 🤖 Generated by Copilot at 5914771

This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
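For orientation, here is a minimal Python sketch of what grouping by device and dtype does. It is an illustration only (the function name and signature are assumptions), not the ATen implementation added by this PR:

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import torch


def group_tensors_by_device_and_dtype(
    tensorlists: List[List[torch.Tensor]],
) -> Dict[Tuple[torch.device, torch.dtype], List[List[torch.Tensor]]]:
    # The first tensorlist (the params) drives the grouping; the remaining lists are
    # assumed to match it positionally in device and dtype.
    grouped = defaultdict(lambda: [[] for _ in tensorlists])
    for i, ref in enumerate(tensorlists[0]):
        key = (ref.device, ref.dtype)
        for j, tl in enumerate(tensorlists):
            grouped[key][j].append(tl[i])
    return dict(grouped)
```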
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
2023-06-21 09:26:33 +00:00
Jane Xu
fa893f3f58 Fix optim state_dict casting to allow step to cast to CPU (#102619)
I'm guessing this should fix https://github.com/pytorch/pytorch/pull/88015#issuecomment-1569523106 but am waiting on @ychfan to supply more details so I could write a good test case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102619
Approved by: https://github.com/albanD
2023-06-13 00:46:40 +00:00
PyTorch MergeBot
0cb5bc3b04 Revert "Move tensor grouping to ATen (#100007)"
This reverts commit 74b7a6c75e.

Reverted https://github.com/pytorch/pytorch/pull/100007 on behalf of https://github.com/izaitsevfb due to Breaks internal builds, see D46629727 ([comment](https://github.com/pytorch/pytorch/pull/100007#issuecomment-1587861598))
2023-06-12 18:30:33 +00:00
Masaki Kozuki
74b7a6c75e Move tensor grouping to ATen (#100007)
rel: #94344
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100007
Approved by: https://github.com/janeyx99
2023-06-09 15:44:46 +00:00
shibo19
e4a42bcf56 add foreach support for custom device (#102047)
For custom devices, we want to support foreach, so I added a function that lets other device types be set; the default value remains cuda.
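A hedged sketch of the idea; the names below are assumptions, not the actual API added by this PR:

```python
# Module-level hook letting a custom backend opt in to the foreach fast path (illustrative).
_foreach_device_type = "cuda"  # default stays cuda


def set_foreach_device_type(device_type: str) -> None:
    global _foreach_device_type
    _foreach_device_type = device_type


def supports_foreach(tensor) -> bool:
    return tensor.device.type == _foreach_device_type
```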

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047
Approved by: https://github.com/janeyx99
2023-06-07 13:59:20 +00:00
Michael Lazos
4da88447ea Disable grouping by dtype and device if compiling (#102771)
Disable grouping if we are compiling; the grouping happens during lowering instead.
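A hedged sketch of the guard; the compile check shown is the modern public spelling and the helper names are assumptions, not the exact source:

```python
import torch


def maybe_group(tensorlists, group_fn):
    # when compiling, skip eager grouping and let Inductor handle fusion at lowering time
    if torch.compiler.is_compiling():
        return {(None, None): tensorlists}
    return group_fn(tensorlists)
```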
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102771
Approved by: https://github.com/janeyx99
2023-06-02 21:04:49 +00:00
PyTorch MergeBot
9d77949b9e Revert "add foreach support for custom device (#102047)"
This reverts commit b088ff4677.

Reverted https://github.com/pytorch/pytorch/pull/102047 on behalf of https://github.com/malfet due to Broke inductor, see b088ff4677 ([comment](https://github.com/pytorch/pytorch/pull/102047#issuecomment-1572368942))
2023-06-01 16:33:03 +00:00
shibo19
b088ff4677 add foreach support for custom device (#102047)
For custom devices, we want to support foreach, so I added a function that lets other device types be set; the default value remains cuda.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047
Approved by: https://github.com/janeyx99
2023-06-01 06:22:44 +00:00
Jane Xu
f558af2a55 [adam] Use the right params in weight_decay, rename for clarity, fixes #100707 (#100973)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100973
Approved by: https://github.com/Skylion007, https://github.com/albanD
2023-05-09 17:00:27 +00:00
Masaki Kozuki
22ea21da3d Change 1D Tensor of 1 element to 0D Tensor (#96994)
Add a 0d tensor to the graphed Adam/AdamW tests.

Affected:
- `torch.cuda.amp.GradScaler`'s `found_inf`, `_scale`, and `_growth_tracker`
- `step` of Adam & AdamW of `capturable`

Fixes #96776 🤞

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96994
Approved by: https://github.com/janeyx99
2023-03-21 18:24:19 +00:00
Masaki Kozuki
7d765cdc66 Fix wrong handling of grad_scale & found_inf in fused optimizers (#95847)
Fixes #95781.
The cause seems to be that the current implementation doesn't correctly pass `found_inf` when `grad_scale` is `None`. Therefore parameters can get mistakenly updated by gradients some of whose elements are invalid, i.e. nan or inf.
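A hedged sketch of the intended argument plumbing (names are assumptions, not the fused kernel's actual signature):

```python
def _scaling_kwargs(grad_scale=None, found_inf=None):
    kwargs = {}
    if grad_scale is not None:
        kwargs["grad_scale"] = grad_scale
    # found_inf must be forwarded even when grad_scale is None, so that
    # inf/nan gradients still cause the fused step to skip the update
    if found_inf is not None:
        kwargs["found_inf"] = found_inf
    return kwargs
```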

Related #94060

I forgot about this wrong handling after #94344

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95847
Approved by: https://github.com/janeyx99
2023-03-04 01:21:21 +00:00
Jane Xu
75cb99e549 [optim] Widen the cases for defaulting to foreach (#95820)
Big OOP correction continued. Also added a test this time to verify the defaulting was as expected.

The key here is realizing that the grouping for foreach already assumes that the non-param tensorlists follow suit in dtype and device, so it is too narrow to check that _all_ tensors were on CUDA. The main leeway this allowed was state_steps, which are sometimes cpu tensors. Since foreach _can_ handle cpu tensors, this should not introduce breakage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95820
Approved by: https://github.com/albanD
2023-03-02 04:15:33 +00:00
Jane Xu
097679478e [optim] Set defaults to foreach, NOT fused (#95241)
Rolling back the default change for Adam and rectifying the docs to reflect that AdamW never defaulted to fused.

Since our fused implementations are relatively newer, let's give them a longer bake-in time before flipping the switch for every user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95241
Approved by: https://github.com/ngimel
2023-02-22 04:47:32 +00:00
Masaki Kozuki
3e9df622fb [mta] implement _foreach_pow (#92303)
Mainly for the foreach path of `Adam` and `AdamW`.

rel: https://github.com/pytorch/pytorch/issues/58833
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92303
Approved by: https://github.com/albanD
2023-02-16 02:28:26 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Cases where the rewrite would change the semantics are kept unchanged, e.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
Masaki Kozuki
6ba041fcae Look up group["capturable"], not defaults["capturable"] in Adam(W) (#94149)
Different values can be set in each `param_group` when constructing `torch.optim` optimizers, as in e.g. https://github.com/pytorch/pytorch/issues/89987.

So check whether `capturable` is `True` among all the `param_group`s, rather than only reading `defaults["capturable"]`.
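A minimal sketch of the corrected lookup, as an assumption rather than the exact diff:

```python
def any_group_capturable(param_groups):
    # consult every param_group instead of only self.defaults["capturable"],
    # which ignores per-group overrides
    return any(group["capturable"] for group in param_groups)
```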
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94149
Approved by: https://github.com/albanD
2023-02-07 00:24:35 +00:00
Masaki Kozuki
a23ed38f9a [mta][foreach] Implement fused adamw (#88015)
related: https://github.com/pytorch/pytorch/issues/68041, https://github.com/pytorch/pytorch/issues/71274, https://github.com/pytorch/pytorch/issues/80167
possibly related to https://github.com/pytorch/pytorch/issues/80595#issuecomment-1178519436

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88015
Approved by: https://github.com/albanD, https://github.com/ngimel
2023-02-01 19:32:29 +00:00
Masaki Kozuki
d7a3f2128f pass None instead of False inside Adam.__setstate__ (#93289)
With a061f139dc, `fused`'s type hint is `Optional[bool]` and its default value is `None`.
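A hedged sketch of the change, shown as an illustrative subclass rather than the real `Adam.__setstate__`:

```python
import torch


class PatchedAdam(torch.optim.Adam):
    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            # fused is now Optional[bool]; default missing entries to None rather than False
            group.setdefault("fused", None)
```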

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93289
Approved by: https://github.com/janeyx99, https://github.com/Skylion007
2023-01-31 09:41:35 +00:00
Jane Xu
4fc19e1a71 [optim][adam] use fastest impl whenever possible, add util (#93184)
This makes it so that ONLY when the user doesn't set anything for foreach or fused do we switch the default, cascading Adam so that we default to fused, then foreach, then single-tensor.

To clarify:
* if the user puts True in foreach _only_, it will run the foreach implementation.
* if the user puts True in fused _only_, it will run the fused implementation.
* if the user puts True in foreach AND for fused, it will run the fused implementation.

And:
* if the user puts False in foreach _only_, it will run the single tensor implementation.
* if the user puts False in fused _only_, it will still run the single tensor implementation.
* if the user puts False in foreach AND for fused, it will run the single tensor implementation.

I also didn't trust myself that much with the helper function, so I ran some local asserts on `_default_to_fused_or_foreach`. The only point left to really test is the `type(p)` check against `torch.Tensor`, but I think the distributed tests will catch that in CI.
```
cuda_only_fp_list = [
    torch.rand((1, 2), device="cuda", dtype=torch.float32),
    torch.rand((1, 2), device="cuda", dtype=torch.float64),
    torch.rand((1, 2), device="cuda", dtype=torch.float16),
    torch.rand((1, 2), device="cuda", dtype=torch.bfloat16),
]

cuda_only_int_list = [
    torch.randint(1024, (1, 2), device="cuda", dtype=torch.int64),
]

cpu_list = [
    torch.rand((1, 2), device="cpu", dtype=torch.float32),
    torch.rand((1, 2), device="cpu", dtype=torch.float64),
    torch.rand((1, 2), device="cpu", dtype=torch.float16),
]

none_list = [None]

# differentiable should always make it return false for both
assert _default_to_fused_or_foreach([cuda_only_fp_list], True, True) == (False, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list], True, False) == (False, False)

# cpu lists should always make it return false for both
assert _default_to_fused_or_foreach([cuda_only_fp_list, cpu_list], False, True) == (False, False)
assert _default_to_fused_or_foreach([cpu_list], False, True) == (False, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cpu_list], False, False) == (False, False)
assert _default_to_fused_or_foreach([cpu_list], False, False) == (False, False)

# has fused triggers correctly
assert _default_to_fused_or_foreach([cuda_only_fp_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list], False, False) == (False, True)

# ints always goes to foreach
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list], False, True) == (False, True)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list], False, False) == (False, True)

# Nones don't error
assert _default_to_fused_or_foreach([cuda_only_fp_list, none_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list, none_list], False, True) == (False, True)
assert _default_to_fused_or_foreach([none_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([none_list], False, False) == (False, True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93184
Approved by: https://github.com/albanD
2023-01-30 19:58:55 +00:00
Jane Xu
de0375e79d [optim][foreach] Do NOT inplace modify gradients (#92706)
SGD and ASGD already had out-of-place grads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92706
Approved by: https://github.com/ngimel, https://github.com/albanD
2023-01-21 00:12:28 +00:00
milesial
e4d83d54a6 Foreach gradient clipping (#91846)
Faster gradient clipping using the foreach functions

```
[------------------------ (tensors, scalar) -------------------------]
                                   |  without foreach  |  with foreach |    apex
1 threads: ----------------------------------------------------------------------
      10 tensors of size 4         |         120.5     |       61.1    |     50.3
      100 tensors of size 4        |         946.2     |      239.5    |    136.3
      1000 tensors of size 4       |        9808.5     |     2151.1    |   1006.9
      10000 tensors of size 4      |       96871.2     |    22637.4    |  10119.1
      10 tensors of size 16        |         121.0     |       64.1    |     52.5
      100 tensors of size 16       |         993.4     |      252.6    |    136.7
      1000 tensors of size 16      |        9427.7     |     2151.2    |   1049.5
      10000 tensors of size 16     |       97437.1     |    22203.1    |  10340.0
      10 tensors of size 256       |         118.9     |       62.3    |     51.5
      100 tensors of size 256      |         955.2     |      243.1    |    134.2
      1000 tensors of size 256     |        9374.9     |     2140.7    |   1009.6
      10000 tensors of size 256    |       95302.5     |    21849.4    |  10215.5
      10 tensors of size 65536     |         118.5     |       62.4    |     51.1
      100 tensors of size 65536    |        1740.7     |      243.3    |    225.3
      1000 tensors of size 65536   |       17364.1     |     2228.7    |   2004.5
      10000 tensors of size 65536  |      177510.1     |    25410.4    |  20678.2
```
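A hedged sketch of foreach-style clipping (an illustration under the assumption of a single device/dtype group, not the PR's exact code):

```python
import torch


def clip_grad_norm_foreach(parameters, max_norm: float) -> torch.Tensor:
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    # one fused call over the whole list instead of a per-tensor Python loop
    norms = torch._foreach_norm(grads, 2.0)
    total_norm = torch.linalg.vector_norm(torch.stack(norms), 2.0)
    clip_coef = (max_norm / (total_norm + 1e-6)).clamp(max=1.0)
    torch._foreach_mul_(grads, clip_coef)  # in-place scale of every grad
    return total_norm
```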
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91846
Approved by: https://github.com/janeyx99
2023-01-20 21:43:29 +00:00
Jane Xu
07800c52af [optim][adam] group tensors in foreach to maximize perf (#92349)
same idea as https://github.com/pytorch/pytorch/pull/92338
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92349
Approved by: https://github.com/albanD
2023-01-18 22:05:42 +00:00
Jane Xu
0070c546b5 [BE][optim] abstract out docstrings, add differentiable docs (#92336)
1. abstract out common doc strings --> I'm sure there are more, but let this be a first step.
2. Add differentiable docs to those who are actually differentiable
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92336
Approved by: https://github.com/albanD
2023-01-18 15:09:28 +00:00
Jane Xu
d41b5d7c14 [adam] Add not torch.jit.is_scripting() as a requirement for switching to fused (#92181)
A "fix" following https://github.com/pytorch/pytorch/pull/90865. Realized that fused is not compatible with torch.jit.is_scripting() when looking at a later line.

Took the opportunity to make the code cleaner/slightly more performant (with the extends) as well.
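A hedged sketch of the check described above; the surrounding variable names are assumptions:

```python
import torch


def can_default_to_fused(all_params_cuda: bool, differentiable: bool) -> bool:
    # fused kernels are not scriptable, so never pick them under TorchScript
    return all_params_cuda and not differentiable and not torch.jit.is_scripting()
```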
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92181
Approved by: https://github.com/albanD
2023-01-14 19:05:27 +00:00
Nouran Ali
a60125e298 add docstring for adam differentiable parameter (#91881)
Fixes #90467

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91881
Approved by: https://github.com/janeyx99
2023-01-13 17:08:27 +00:00
Jane Xu
ed7885c254 [utils][foreach] Add group tensor by device and dtype util (#92014)
Add a util that will be commonly used throughout optim.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92014
Approved by: https://github.com/albanD
2023-01-11 23:37:20 +00:00
Jane Xu
a061f139dc [optim] Adam defaults to fused when CUDA + differentiable=False (#90865)
Step 1 in faster default optimizers.

Preliminary benchmarks show gaps in improvement on CUDA for BERT_pytorch and resnet18:
![image](https://user-images.githubusercontent.com/31798555/207707118-14221802-77ce-4ee0-96e3-04638c07924c.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90865
Approved by: https://github.com/albanD
2022-12-27 01:28:47 +00:00
Soumith Chintala
06326a7721 [optim] skip .item calls in all optimizers when compiling with dynamo (#88173)
@mlazos: skips `item()` calls when compiling with dynamo by defining a helper function `_get_value`, which returns either the result of `.item()` or, when compiling with dynamo, the scalar cpu tensor itself. This was done because removing `item()` calls entirely significantly regresses eager perf. Additionally, `_dispatch_sqrt` calls the appropriate sqrt function (math.sqrt or torch.sqrt).
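A hedged sketch of the two helpers; the bodies are assumptions, and the compile check has been spelled differently across versions:

```python
import math

import torch


def _get_value(x):
    # under dynamo, keep the scalar CPU tensor instead of calling .item()
    if torch.compiler.is_compiling():
        return x
    return x.item()


def _dispatch_sqrt(x):
    # tensors go through torch.sqrt, plain Python numbers through math.sqrt
    if isinstance(x, torch.Tensor):
        return x.sqrt()
    return math.sqrt(x)
```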

Fixes https://github.com/pytorch/torchdynamo/issues/1083

This PR will no longer be needed once symint support is default.

This PR closes all remaining graph breaks in the optimizers (!!)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88173
Approved by: https://github.com/albanD
2022-12-12 17:32:35 +00:00
Michael Lazos
c63afb283c Disable dynamo on optimizer lazy initialization (#89902)
Helps with https://github.com/pytorch/torchdynamo/issues/1803

Separate out the group initialization and disable dynamo on it

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89902
Approved by: https://github.com/soumith, https://github.com/albanD
2022-12-02 01:15:11 +00:00
Michael Lazos
3d47c74cfe Update code style for optimizer code (#89862)
Separating out whitespace-only changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89862
Approved by: https://github.com/albanD, https://github.com/soumith
2022-11-30 00:53:05 +00:00
Masaki Kozuki
5f26df0345 resubmit: "resubmit: [mta] APEX style Fused Adam (#81705) (#85507)" (#85739)
Embarrassingly move the pow implementations around [ATen/native/cuda/PowKernel.cu#L21-L66](849b08f14b/aten/src/ATen/native/cuda/PowKernel.cu (L21-L66)) to a new header file and let FusedAdam use them to tame MSVC, hopefully.

cc @ngimel @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85739
Approved by: https://github.com/ngimel
2022-09-29 16:58:59 +00:00
PyTorch MergeBot
7167996346 Revert "resubmit: [mta] APEX style Fused Adam (#81705) (#85507)"
This reverts commit 4615d1bcfa.

Reverted https://github.com/pytorch/pytorch/pull/85507 on behalf of https://github.com/atalman due to Break internal windows builds
2022-09-27 16:59:35 +00:00
Masaki Kozuki
4615d1bcfa resubmit: [mta] APEX style Fused Adam (#81705) (#85507)
This PR implements an APEX style FusedAdam in PyTorch. This is different from the APEX one in that this is compatible with `torch.cuda.amp.GradScaler` by setting `_step_supports_amp_scaling` to `True` and unscales gradients inside its CUDA kernel.
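A hedged sketch of the `GradScaler` hand-off flag only; the real fused step runs in a CUDA kernel:

```python
import torch


class FusedAdamLike(torch.optim.Adam):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # torch.cuda.amp.GradScaler checks this attribute: when True, it skips its own
        # per-parameter unscaling and lets the optimizer step consume the scale /
        # found_inf information itself
        self._step_supports_amp_scaling = True
```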

related: https://github.com/pytorch/pytorch/issues/68041, https://github.com/pytorch/pytorch/issues/71274, https://github.com/pytorch/pytorch/issues/80167 possibly related to https://github.com/pytorch/pytorch/issues/80595#issuecomment-1178519436

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81705
Approved by: https://github.com/ngimel

cc @ptrblck @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85507
Approved by: https://github.com/ngimel
2022-09-23 18:56:00 +00:00
PyTorch MergeBot
e505360eb8 Revert "[mta] APEX style Fused Adam (#81705)"
This reverts commit 7a6c4d0c50.

Reverted https://github.com/pytorch/pytorch/pull/81705 on behalf of https://github.com/dagitses due to broke internal builds, details to come
2022-09-22 19:37:29 +00:00
Masaki Kozuki
7a6c4d0c50 [mta] APEX style Fused Adam (#81705)
This PR implements an APEX style FusedAdam in PyTorch.
This is different from the APEX one in that this is compatible with `torch.cuda.amp.GradScaler` by setting `_step_supports_amp_scaling` to `True` and unscales gradients inside its CUDA kernel.

related: https://github.com/pytorch/pytorch/issues/68041, https://github.com/pytorch/pytorch/issues/71274, https://github.com/pytorch/pytorch/issues/80167
possibly related to https://github.com/pytorch/pytorch/issues/80595#issuecomment-1178519436

cc @ptrblck @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81705
Approved by: https://github.com/ngimel
2022-09-20 17:18:33 +00:00
albanD
84c4b07932 Make sure that we can load old optimizer checkpoint (#83588)
We want to make sure that we can load checkpoints that were saved with an older version of the code (which didn't contain the `differentiable` attribute).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83588
Approved by: https://github.com/mikaylagawarecki
2022-08-17 15:08:05 +00:00
Emilio Castillo
5aab57e112 Make Adam optimizer differentiable (#82205)
Continues [80938](https://github.com/pytorch/pytorch/pull/80938)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82205
Approved by: https://github.com/albanD
2022-08-17 07:20:37 +00:00
Masaki Kozuki
3139722679 [foreach][mta] Inplace maximum and minimum (#82523)
### Description
Implement `torch._foreach_maximum_` and `torch._foreach_minimum_` mainly for `_multi_tensor_adam` and `_multi_tensor_adamw` with `amsgrad=True` to correctly update their `max_exp_avg_sqs`.
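A minimal usage sketch of the new in-place op; the tensor list names mirror the message, but the snippet is illustrative:

```python
import torch

max_exp_avg_sqs = [torch.rand(3) for _ in range(2)]
exp_avg_sqs = [torch.rand(3) for _ in range(2)]

# elementwise, in-place maximum per tensor pair, which is what amsgrad needs
# to keep max_exp_avg_sqs up to date
torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)
```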

### Issue
- https://github.com/pytorch/pytorch/issues/78807
- https://github.com/pytorch/pytorch/pull/81894
- https://github.com/pytorch/pytorch/pull/81348
- https://github.com/pytorch/pytorch/pull/81705
- https://github.com/pytorch/pytorch/issues/58833
- https://github.com/pytorch/pytorch/issues/68041

### Testing
Updated `test_foreach.py::TestForeach::_minmax_test` to compare the outputs of `_foreach_maximum_` (and `_foreach_minimum_`) against those of `[torch.maximum(a, b) for a, b in zip(tensors1, tensors2)]`

cc @ngimel @albanD @mikaylagawarecki
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82523
Approved by: https://github.com/albanD
2022-08-03 03:40:42 +00:00
ProGamerGov
71d50f4f89 Change docstring type callable to Callable for consistency (#82487)
### Description

Across PyTorch's docstrings, both `callable` and `Callable` are used for variable types. `Callable` should be capitalized, as we are referring to the `Callable` type and not the Python `callable()` function.

### Testing

There shouldn't be any testing required.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82487
Approved by: https://github.com/albanD
2022-08-01 17:26:09 +00:00
ProGamerGov
357b7d589c Fix docstring inconsistencies: string -> str, boolean -> bool (#82410)
### Description

Throughout the PyTorch docs and codebase, the `string` type in docstrings is referred to by two separate names. This leads to inconsistent docs, as you can see here: https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d

This PR fixes the issue by ensuring that all mentions of the string type in docstrings use the format that Sphinx generates hyperlinks for.

### Testing
No testing should be required for this change

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82410
Approved by: https://github.com/jbschlosser
2022-07-28 21:29:57 +00:00
Rob Zinkov
f9ef363982 Modifying Adam to support complex numbers as 2d real numbers (#80279)
This commit addresses issues in #65711
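A minimal sketch of the "2d real" view this approach relies on (illustrative only):

```python
import torch

p = torch.randn(4, dtype=torch.complex64, requires_grad=True)
p_as_real = torch.view_as_real(p)  # shape (4, 2); shares storage with p
```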

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80279
Approved by: https://github.com/albanD
2022-07-27 18:39:40 +00:00
anjali411
bda04e9f5e Add __all__ for torch.optim and torch.nn.modules modules (#80237)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80237
Approved by: https://github.com/albanD
2022-06-24 21:34:10 +00:00
Sergii Dymchenko
de7219e8a7 Use generators with all/any in torch/optim (#78142)
Generator comprehensions with any/all are less verbose and can potentially save memory/CPU: https://eklitzke.org/generator-comprehensions-and-using-any-and-all-in-python

To make JIT work with this change, I added code to convert GeneratorExp to ListComp. So the whole PR is basically a no-op for JIT, but a potential memory and speed improvement for eager mode.

Also, I removed a test from test/jit/test_parametrization.py. The test was bad: it had a TODO to actually implement it and only checked that UnsupportedNodeError is thrown, and with GeneratorExp support a different error would be thrown.
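For example (illustrative):

```python
import torch

params = [torch.randn(2) for _ in range(3)]
all_cpu_list = all([p.device.type == "cpu" for p in params])  # builds a throwaway list
all_cpu_gen = all(p.device.type == "cpu" for p in params)     # short-circuits, no intermediate list
```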
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78142
Approved by: https://github.com/malfet, https://github.com/albanD
2022-06-24 17:23:45 +00:00
albanD
375668cd96 Remove overly restrictive assert in adam (#80222)
This is causing issues if the user has the step on cuda for a good reason.

These asserts cause code that used to run just fine to fail.
Note that this is a pretty bad thing to do for performance though so it is ok to try and push users away from doing it.

For the 1.12.1 milestone: this is not asking for a dot release to fix this (as this is bad practice anyways). But it would be a great thing to add if we do one: it is very low risk and will prevent breakage for users.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80222
Approved by: https://github.com/jbschlosser, https://github.com/ngimel
2022-06-24 17:08:34 +00:00
Michael Carilli
ba27ee9e8f [CUDA graphs] Allows Adam and AdamW to be capture-safe (#77862)
Near term fix for https://github.com/pytorch/pytorch/issues/76368.

Q. Why does the user need to request `capturable=True` in the optimizer constructor? Why can't capture safety be completely automatic?
A. We need to set up capture-safe (device-side) state variables before capture. If we don't, and step() internally detects capture is underway, it's too late: the best we could do is create a device state variable and copy the current CPU value into it, which is not something we want baked into the graph.

Q. Ok, why not just do the capture-safe approach with device-side state variables all the time?
A. It incurs several more kernel launches per parameter, which could really add up and regress cpu overhead for ungraphed step()s. If the optimizer won't be captured, we should allow step() to stick with its current cpu-side state handling.

Q. But cuda RNG is a stateful thing that maintains its state on the cpu outside of capture and replay, and we capture it automatically. Why can't we do the same thing here?
A. The graph object can handle RNG generator increments because its capture_begin, capture_end, and replay() methods can see and access the generator object. But the graph object has no explicit knowledge of or access to optimizer steps in its capture scope. We could let the user tell the graph object which optimizers will be stepped in its scope, i.e. something like
```python
graph.will_use_optimizer(opt)
graph.capture_begin()
...
```
but that seems clunkier than an optimizer constructor arg.

I'm open to other ideas, but right now I think constructor arg is necessary and the least bad approach.

Long term, https://github.com/pytorch/pytorch/issues/71274 is a better fix.
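A hedged sketch of capturing a whole Adam step with `capturable=True`, following the general CUDA graphs warmup-then-capture pattern; the model, shapes, and loss are placeholders:

```python
import torch

model = torch.nn.Linear(16, 16, device="cuda")
opt = torch.optim.Adam(model.parameters(), lr=1e-3, capturable=True)
static_input = torch.randn(8, 16, device="cuda")

# warmup on a side stream so lazy CUDA initialization happens outside the capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        model(static_input).sum().backward()
        opt.step()
torch.cuda.current_stream().wait_stream(s)

# capture forward + backward + optimizer step into one graph
g = torch.cuda.CUDAGraph()
opt.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_loss = model(static_input).sum()
    static_loss.backward()
    opt.step()

# later: copy fresh data into static_input and replay the captured step
static_input.copy_(torch.randn(8, 16, device="cuda"))
g.replay()
```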
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77862
Approved by: https://github.com/ezyang
2022-06-13 01:56:47 +00:00
Mikayla Gawarecki
5f9590681d Optim foreach cleanup for Adam (#70295)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70295

Test Plan: Imported from OSS

Reviewed By: anjali411

Differential Revision: D33767870

Pulled By: mikaylagawarecki

fbshipit-source-id: f922f15ecb0307458c8ecee737325c42c4f3ce8b
(cherry picked from commit 66233a8a3e)
2022-02-15 18:02:08 +00:00
Mikayla Gawarecki
7176c92687 [optim] update step in functional and pass state_steps instead of state (#71333)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71333

Updated
- Adagrad
- Adamax
- Adam
- AdamW
- RAdam
Make the multi_tensor functionals take `state_steps: List[Tensor]` instead of `states: List[Dict]`.
Change `state_steps: List[int]` to `state_steps: List[Tensor]`, where each entry is a singleton tensor so the step can be updated within the functional (a tiny illustration follows below).

(NAdam and ASGD were updated in separate diffs to fold their handling of state into the functionals.)
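A tiny illustration of the singleton step tensor (illustrative only):

```python
import torch

step = torch.tensor(0.0)  # singleton tensor held in the optimizer state


def bump(state_steps):
    for s in state_steps:
        s += 1  # in-place update inside the functional, visible to the caller


bump([step])
print(step)  # tensor(1.)
```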

Test Plan: Imported from OSS

Reviewed By: anjali411

Differential Revision: D33767872

Pulled By: mikaylagawarecki

fbshipit-source-id: 9baa7cafb6375eab839917df9287c65a437891f2
(cherry picked from commit 831c02b3d0)
2022-02-08 16:51:19 +00:00
Alban Desmaison
e1b84e1b6b fix loading of older models that don't have maximize (#71023)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71023

Reviewed By: jbschlosser

Differential Revision: D33483687

Pulled By: albanD

fbshipit-source-id: 2f3c6e97a9579be9ba15eca0756fc1f2c466fbb6
2022-01-10 06:01:24 -08:00
Adnios
15f14ce0dc fix typo in adam docs (#70387)
Summary:
Fix the typo in [adam docs in master branch](https://pytorch.org/docs/master/generated/torch.optim.Adam.html#torch.optim.Adam)

![image](https://user-images.githubusercontent.com/41060790/147345284-37e180d1-fd06-4a62-9c79-2d17b8aa5cd3.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70387

Reviewed By: H-Huang

Differential Revision: D33309283

Pulled By: albanD

fbshipit-source-id: d20c5d8f2498ac64013f71e202a6b50dcc069f2b
2021-12-28 07:35:39 -08:00