Commit Graph

122 Commits

Author SHA1 Message Date
Jane Xu
bcf1f312a0 Migrate nontensor step and CUDA params state_dict tests to OptimizerInfo (#116509)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116509
Approved by: https://github.com/albanD
2024-01-12 22:32:37 +00:00
Jane Xu
924f1b841a [optim] Allow torch.float64 scalars for forloop + foreach implementations (#115841)
Should allow for uses cases mentioned in #110940

This would allow scalars to also be float64s in the foreach implementation. The fused implementation would still create a float32 step on Adam and AdamW. This PR also does NOT worry about performance and is mainly for enablement.

Next steps:
- Relax the constraint on fused adam(w) and allow torch.float64 scalars there
- Allow _performant_ mixed dtypes in foreach (a bigger project in itself).

This PR will conflict with my other PRs, I will figure out a landing order

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115841
Approved by: https://github.com/albanD
2023-12-27 09:13:49 +00:00
Jane Xu
44b98c09ca [BE] migrate all assertRaises tests to OptimizerInfo test_errors (#116315)
Removes a part of the sparse adam test and the following three tests: `test_fused_optimizer_raises`, `test_duplicate_params_across_param_groups`, `test_duplicate_params_in_one_param_group`

```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (d2d129de)]$ python test/test_optim.py -k test_fused_optimizer_raises -k test_duplicate_params_across_param_groups -k test_duplicate_params_in_one_param_group
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
...
----------------------------------------------------------------------
Ran 3 tests in 0.023s

OK
```

Increases coverage by testing the duplicate param tests on ALL the optims instead of just one each. Also fixes SparseAdam bug which was accidentally calling torch.unbind through list instead of putting params in a list. This bug was caught by migrating the weird warning stuff to just one easy warning context manager, which checks that nothing else gets raised.

The new test_errors does not run slower than before, overhead is still king:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (d2d129de)]$ python test/test_optim.py -k test_errors
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
..........................
----------------------------------------------------------------------
Ran 26 tests in 10.337s

OK
```

Compared to test_errors BEFORE my commit :p
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b47aa696)]$ python test/test_optim.py -k test_errors
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
.............sssssssssssss
----------------------------------------------------------------------
Ran 26 tests in 11.980s

OK (skipped=13)
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b47aa696)]$
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116315
Approved by: https://github.com/mikaylagawarecki
2023-12-27 00:08:31 +00:00
Jane Xu
7c1a5012f0 [BE][SparseAdam] cleaner way to verify no sparse params (#114425)
Context:

https://github.com/pytorch/pytorch/pull/47724 fixed the problem that SparseAdam could not handle generators by using the `list(...)` construct. However, this meant that SparseAdam deviated from other optimizers in that it could _accept_ a raw Tensors/Parameter vs requiring a container of them. This is not really a big deal.

So why this PR?

I do think this PR is cleaner. It uses the fact that the Optimizer parent class already containerizes parameters into parameter groups, so we could reuse that here by calling `super().__init__` first and then filter the param_groups after. This change would also make SparseAdam consistent with the rest of our optimizers in that only containerized params are accepted, which technically is BC breaking SO I've added a deprecation warning that we should remove in May 2024.

(But is it really BC breaking when we've said in the docs that params should be an iterable this whole time? Maybe this is just a bug fix....😛)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114425
Approved by: https://github.com/drisspg
2023-11-29 19:47:03 +00:00
Axel Donath
174aef71af Clarify maximize option in optimizer.py (#112724)
While reading the documentation of the optimizers I noticed the description of the `maximize` option is misleading. It currently reads as if the parameters would we maximized, which is factually incorrect. This PR proposes a more clear description.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112724
Approved by: https://github.com/albanD
2023-11-02 16:34:37 +00:00
Jon Chuang
f74d766632 feat(optim): use has_complex shortcut flag for all applicable optimizers, use _view_as_real auxiliary function (#110706)
Follow up to: https://github.com/pytorch/pytorch/pull/110607

CC: @lezcano @janeyx99
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110706
Approved by: https://github.com/lezcano
2023-10-31 20:33:03 +00:00
isdanni
b460c30893 [BE] Enable Ruff's Flake8 PYI042 (#111114)
Enable [snake-case-type-alias (PYI042)](https://docs.astral.sh/ruff/rules/snake-case-type-alias/)

Link: #110950
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111114
Approved by: https://github.com/albanD
2023-10-13 16:33:07 +00:00
PyTorch MergeBot
3a3cf0e09d Revert "[optim] Make casting to match params a hook (#106725)"
This reverts commit 9f86d85172.

Reverted https://github.com/pytorch/pytorch/pull/106725 on behalf of https://github.com/janeyx99 due to We acknowledge this is a huge risk because people do not remember to call super().__init__ from their Optimizer subclasses and so this will break lots of load_state_dict behavior ([comment](https://github.com/pytorch/pytorch/pull/106725#issuecomment-1693386137))
2023-08-25 13:47:19 +00:00
Jane Xu
9f86d85172 [optim] Make casting to match params a hook (#106725)
Moves the logic to casting state to match parameters into a hook so that users can choose to enable their hooks before or after the casting has happened.

With this, there is a little bit of redundancy of the id_map building and the check that the param groups are still aligned in length.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106725
Approved by: https://github.com/albanD
2023-08-23 22:25:33 +00:00
Jane Xu
59d0dea90f Only make a shallow copy when loading optimizer state_dict (#106082)
The thing we do still deep copy is the param_groups, which is much lighter weight. This should also save memory when loading from a checkpoint.

The deepcopy was introduced in ecfcf39f30, but module.py had only a shallow copy at that point so it did not actually bring parity.

Incorporates an XLA fix, which is why I'm updating the pin to ca5eab87a7

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106082
Approved by: https://github.com/albanD, https://github.com/Skylion007
2023-08-01 05:33:31 +00:00
Jane Xu
ad3af0aead Change phrasing on optim state hook docs (#106209)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106209
Approved by: https://github.com/albanD
2023-07-28 18:59:21 +00:00
Jane Xu
dffa4e14b9 Add Optimizer state_dict hooks (#105953)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105953
Approved by: https://github.com/albanD
2023-07-28 11:52:41 +00:00
Jane Xu
ec0ffac33b [BE] Document optimizer state_dict better, use example (#105958)
![image](https://github.com/pytorch/pytorch/assets/31798555/50ce293c-d884-47ab-b5f5-9ba41e3b4bad)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105958
Approved by: https://github.com/albanD
2023-07-27 23:08:42 +00:00
Matthew Hoffman
0616952d13 Merge and improve torch optim optimizer type stubs (#102593)
Fixes #102428

Also improves hook registration type hints:

```python
from typing import Any, Dict, Tuple

from torch import nn
from torch.optim import Adam, Adagrad, Optimizer

linear = nn.Linear(2,2)
optimizer = Adam(linear.parameters(), lr=0.001)

def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def pre_hook_fn_return_modified(
    optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    return inputs, kwargs

def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

optimizer.register_step_post_hook(hook_fn)  # OK

optimizer.register_step_pre_hook(pre_hook_fn_return_none)  # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified)  # OK

optimizer.register_step_post_hook(hook_fn_other_optimizer)  # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99, https://github.com/malfet
2023-07-26 11:56:42 +00:00
Justin Chu
3721fa5612 [BE] Enable ruff's UP rules and autoformat optim/ (#105426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105426
Approved by: https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi, https://github.com/janeyx99
2023-07-18 21:07:43 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
1646d6f939 Revert "Merge and improve torch optim optimizer type stubs (#102593)"
This reverts commit 3279f06410.

Reverted https://github.com/pytorch/pytorch/pull/102593 on behalf of https://github.com/malfet due to There is nothing wrong with this PR, but it fails some internal builds that depend on outdated typing_extensions, will reland when update is done ([comment](https://github.com/pytorch/pytorch/pull/102593#issuecomment-1636062515))
2023-07-14 16:04:54 +00:00
PyTorch MergeBot
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to It's dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  -
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
Matthew Hoffman
3279f06410 Merge and improve torch optim optimizer type stubs (#102593)
Fixes #102428

Also improves hook registration type hints:

```python
from typing import Any, Dict, Tuple

from torch import nn
from torch.optim import Adam, Adagrad, Optimizer

linear = nn.Linear(2,2)
optimizer = Adam(linear.parameters(), lr=0.001)

def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def pre_hook_fn_return_modified(
    optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    return inputs, kwargs

def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

optimizer.register_step_post_hook(hook_fn)  # OK

optimizer.register_step_pre_hook(pre_hook_fn_return_none)  # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified)  # OK

optimizer.register_step_post_hook(hook_fn_other_optimizer)  # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99
2023-07-11 00:07:30 +00:00
Jane Xu
35f0e35529 [foreach][Adam] Minimize use of intermediates to decrease peak memory (#104780)
Starts addressing https://github.com/pytorch/pytorch/issues/97712 by
- Minimizing intermediates usage for foreach Adam
- Document the extra memory usage
- Add comments within the code for clarity now that we reuse intermediates
- Add tests
- Did some refactoring

Next steps involve doing this for all other foreach implementations. Note that even after this change, foreach mem usage will be higher than forloop due to the fact that we have a minimum budget of 1 intermediate (to not muddle the input values) and the intermediate will be larger. For capturable, the memory usage is higher due to moving more tensors to CUDA.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104780
Approved by: https://github.com/albanD
2023-07-10 17:38:46 +00:00
Animesh Jain
0444f9f85b [dynamo] Reland #104317 - Lazy disable_dynamo API out-of-dynamo (#104664)
Internal failed because of torch.deploy issues with disable_dynamo in fx/* and _jit/* files. Removing disable_dynamo for both. Added a comment in the code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104664
Approved by: https://github.com/wconstab
2023-07-06 00:48:02 +00:00
Michael Lazos
a290cbf32b Enable fused foreach Adam compilation (#104121)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104121
Approved by: https://github.com/janeyx99
2023-07-05 23:40:03 +00:00
PyTorch MergeBot
54e320d4d1 Revert "[dynamo] Lazy disable_dynamo API out-of-dynamo (#104317)"
This reverts commit 5c12a810ac.

Reverted https://github.com/pytorch/pytorch/pull/104317 on behalf of https://github.com/huydhn due to This has been reverted internally by D47166892, so I need to also revert it on OSS to keep them in sync ([comment](https://github.com/pytorch/pytorch/pull/104317#issuecomment-1621099151))
2023-07-05 06:21:48 +00:00
Animesh Jain
5c12a810ac [dynamo] Lazy disable_dynamo API out-of-dynamo (#104317)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104317
Approved by: https://github.com/jansel, https://github.com/wconstab, https://github.com/mlazos
2023-06-29 13:30:17 +00:00
Michael Lazos
5a97c947c6 Fix optimizer grad mode state interaction with dynamo (#103952)
Graph break before restoring the grad mode to ensure dynamo respects `no_grad`. This isn't a bug necessarily, but this will allow us to get good perf until aot is updated.

https://github.com/pytorch/pytorch/issues/104053

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103952
Approved by: https://github.com/janeyx99
2023-06-23 02:07:08 +00:00
Nikita Shulga
6d2887cc06 Reland "Move tensor grouping to ATen" (#103912)
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verified by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If above conditon is not met, instead of printing a static_assert output, VC++ raises a very cryptic compilation errors,  see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.

Also, using `std::hash` for string should result in a faster hash function.

(cherry picked from commit 74b7a6c75e)

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 5914771</samp>

This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
2023-06-21 09:26:33 +00:00
WEN Hao
67babf7a45 Enhance decorator _use_grad_for_differentiable (#103567)
Aim: enhance decorator _use_grad_for_differentiable so that functions (methods) decorated by it keep their docstrings and signatures unchanged.

Fixes #103566

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103567
Approved by: https://github.com/janeyx99
2023-06-16 18:33:31 +00:00
Andrew Gu
9152d0e5be Silence has_cuda deprecation in optim (#103610)
```
UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103610
Approved by: https://github.com/janeyx99, https://github.com/Skylion007
2023-06-14 22:09:22 +00:00
Jane Xu
fa893f3f58 Fix optim state_dict casting to allow step to cast to CPU (#102619)
I'm guessing this should fix https://github.com/pytorch/pytorch/pull/88015#issuecomment-1569523106 but am waiting on @ychfan to supply more details so I could write a good test case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102619
Approved by: https://github.com/albanD
2023-06-13 00:46:40 +00:00
PyTorch MergeBot
0cb5bc3b04 Revert "Move tensor grouping to ATen (#100007)"
This reverts commit 74b7a6c75e.

Reverted https://github.com/pytorch/pytorch/pull/100007 on behalf of https://github.com/izaitsevfb due to Breaks internal builds, see D46629727 ([comment](https://github.com/pytorch/pytorch/pull/100007#issuecomment-1587861598))
2023-06-12 18:30:33 +00:00
Masaki Kozuki
74b7a6c75e Move tensor grouping to ATen (#100007)
rel: #94344
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100007
Approved by: https://github.com/janeyx99
2023-06-09 15:44:46 +00:00
shibo19
e4a42bcf56 add foreach support for custom device (#102047)
Fixes #ISSUE_NUMBER
for custom device, we want to support foreach, so I add a func that we could set other device type, and the default value is cuda.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047
Approved by: https://github.com/janeyx99
2023-06-07 13:59:20 +00:00
Michael Lazos
00f1bb0963 Fix optimizer cuda health check graph break (can be done in the compiler) (#102765)
- Ignore the health check if we are compiling
- Don't disable the function anymore

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102765
Approved by: https://github.com/albanD
2023-06-03 03:42:23 +00:00
Michael Lazos
4da88447ea Disable grouping by dtype and device if compiling (#102771)
Disable grouping if we are compiling, this happens during lowering
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102771
Approved by: https://github.com/janeyx99
2023-06-02 21:04:49 +00:00
PyTorch MergeBot
9d77949b9e Revert "add foreach support for custom device (#102047)"
This reverts commit b088ff4677.

Reverted https://github.com/pytorch/pytorch/pull/102047 on behalf of https://github.com/malfet due to Broke inductor, see b088ff4677 ([comment](https://github.com/pytorch/pytorch/pull/102047#issuecomment-1572368942))
2023-06-01 16:33:03 +00:00
shibo19
b088ff4677 add foreach support for custom device (#102047)
Fixes #ISSUE_NUMBER
for custom device, we want to support foreach, so I add a func that we could set other device type, and the default value is cuda.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047
Approved by: https://github.com/janeyx99
2023-06-01 06:22:44 +00:00
PyTorch MergeBot
4637c5ae5b Revert "Simplify _use_grad_for_differentiable (#98706)"
This reverts commit b9da79d280.

Reverted https://github.com/pytorch/pytorch/pull/98706 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but a bunch of inductor tests are failing after this commit, so reverting the PR just to be sure
2023-04-22 00:35:56 +00:00
Jason Ansel
b9da79d280 Simplify _use_grad_for_differentiable (#98706)
This makes it so dynamo can trace through it

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98706
Approved by: https://github.com/janeyx99
2023-04-21 20:47:19 +00:00
Jane Xu
aacbf091db Allow fused optimizers to call _foreach_zero_ in zero_grad (#97159)
Fixes #97032

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97159
Approved by: https://github.com/Skylion007
2023-03-20 19:03:26 +00:00
Aaron Gokaslan
5471621497 [BE] Remove unnecessary dict comprehensions (#97116)
Removes unnecessary dict comprehensions that optimize creation of dicts from iterables

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97116
Approved by: https://github.com/kit1980
2023-03-20 00:56:57 +00:00
Aaron Gokaslan
dd9ade6377 Remove unnecessary items() call in zero_grad (#97040)
Micro-optimization to zero_grad() which is performance critical
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97040
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-03-17 21:34:14 +00:00
Jane Xu
75cb99e549 [optim] Widen the cases for defaulting to foreach (#95820)
Big OOP correction continued. Also added a test this time to verify the defaulting was as expected.

The key here is realizing that the grouping for foreach already assumes that the non-param tensorlists follow suit in dtype and device, so it is too narrow to check that _all_ tensors were on CUDA. The main leeway this allowed was state_steps, which are sometimes cpu tensors. Since foreach _can_ handle cpu tensors, this should not introduce breakage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95820
Approved by: https://github.com/albanD
2023-03-02 04:15:33 +00:00
Jane Xu
2bcf863fad [optim] include nn.Parameter as foreach supported (#95811)
This PR is a result of a realization that models are NOT subscribed to the foreach defaulting as have been claimed on our documentation for months now. BIG OOPS.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95811
Approved by: https://github.com/albanD
2023-03-02 04:15:33 +00:00
Jane Xu
e5b9d98752 Rephrase zero_grad docs (#95643)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95643
Approved by: https://github.com/albanD
2023-02-28 22:04:23 +00:00
Jane Xu
097679478e [optim] Set defaults to foreach, NOT fused (#95241)
Rolling back the default change for Adam and rectifying the docs to reflect that AdamW never defaulted to fused.

Since our fused implementations are relatively newer, let's give them a longer bake-in time before flipping the switch for every user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95241
Approved by: https://github.com/ngimel
2023-02-22 04:47:32 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
Masaki Kozuki
6ba041fcae Look up group["capturable"], not defaults["capturable"] in Adam(W) (#94149)
We could set different values in each `param_group` when calling dunder init of `torch.optim` optimizers as in e.g.  https://github.com/pytorch/pytorch/issues/89987.

So check whether or not `capturable` is `True` among all the `param_group`s.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94149
Approved by: https://github.com/albanD
2023-02-07 00:24:35 +00:00