mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
e4db5dc1c4
133 Commits
| Author | SHA1 | Message | Date | |
|---|---|---|---|---|
|
|
e4db5dc1c4 |
Revert "[BE] remove unnecessary _dispatch_sqrt by using ** 0.5 (#131358)"
This reverts commit
|
||
|
|
c9888c2739 |
Revert "[BE] typing for decorators - optim/optimizer (#131583)"
This reverts commit |
||
|
|
a1dad77dfa |
[BE] typing for decorators - optim/optimizer (#131583)
See #131429 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131583 Approved by: https://github.com/janeyx99 ghstack dependencies: #131568, #131569, #131570, #131571, #131572, #131573, #131574, #131575, #131576, #131577, #131578, #131579, #131580, #131581, #131582 |
||
|
|
4c7f22dee2 |
[BE] remove unnecessary _dispatch_sqrt by using ** 0.5 (#131358)
Based on the discussion here where ** 0.5 is not slower than math.sqrt. https://github.com/pytorch/pytorch/pull/129905#discussion_r1675605075 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131358 Approved by: https://github.com/albanD |
||
|
|
276b5238ef |
[bug] Add is_compiling check for optimizers to avoid untracked tensor during graph tracing (#130909)
Hey folks, I was using the `stateless_func` [here](
|
||
|
|
5a0068cc69 |
[BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function so even if the underlying function is fully typed then callers to it don't get any benefit from type annotations. Step 1 - Enable the error and override in all the offending files. #131429 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428 Approved by: https://github.com/justinchuby, https://github.com/oulgen |
||
|
|
99d9b369f4 |
[Optim] Support tensor lr for all optimizers and check it is 1-element (#131065)
Fixes: #130980 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131065 Approved by: https://github.com/janeyx99 |
||
|
|
8ec5ba960f |
[MPS] Add tensor_lr overloads to fused adam & adamw (#129451)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129451 Approved by: https://github.com/janeyx99 |
||
|
|
9a7e2519d3 |
[MPS] Fused Adam & AdamW (#127242)
Summary:
This PR adds fused Adam and AdamW implementations.
Benchmark on Macbook Pro with M1 Max chip and 64GB unified memory:
**Fast math enabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
| Fused: True | Fused: False
1 threads: -----------------------------------------------------------------------------------------------
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 89
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 90
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 83
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 12 | 94
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 88
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 12 | 90
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 100
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 27 | 100
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 23 | 100
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 27 | 100
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 23 | 98
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 480
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 72 | 450
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 82 | 450
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 73 | 420
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 91 | 500
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 83 | 400
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 94 | 500
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 78 | 400
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 170 | 500
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 140 | 600
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 170 | 600
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 140 | 500
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 250 | 890
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 220 | 850
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 250 | 830
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 220 | 770
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 270 | 870
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 230 | 840
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 270 | 810
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 240 | 800
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 400 | 1000
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 360 | 2000
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 430 | 2000
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 360 | 1300
Times are in milliseconds (ms).
```
**Fast math disabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
| Fused: True | Fused: False
1 threads: -----------------------------------------------------------------------------------------------
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 84
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 84
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 79
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 93
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 10 | 90
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 91
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 81
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 34 | 100
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 31 | 100
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 34 | 95
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 31 | 100
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 94 | 500
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 430
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 92 | 430
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 81 | 390
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 98 | 500
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 88 | 430
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 100 | 500
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 88 | 400
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 210 | 500
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 190 | 610
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 210 | 510
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 190 | 500
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 300 | 900
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 260 | 850
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 295 | 900
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 260 | 800
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 320 | 910
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 280 | 900
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 320 | 900
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 300 | 900
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 500 | 2000
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 480 | 2000
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 540 | 1500
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 480 | 1200
Times are in milliseconds (ms).
```
```python
def profile_fused_adam():
from torch.optim import adam, adamw
import torch.utils.benchmark as benchmark
import itertools
def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
fn(
params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
foreach=False,
capturable=False,
fused=fused,
amsgrad=amsgrad,
beta1=0.9,
beta2=0.99,
lr=1e-3,
weight_decay=.0,
eps=1e-5,
maximize=False,
grad_scale=None,
found_inf=None,
)
torch.mps.synchronize()
device = "mps"
results = []
for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]):
print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else []
state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)]
if adamWflag:
fn = adamw.adamw
else:
fn = adam.adam
for fused in [True, False]:
t = benchmark.Timer(
stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
label='Fused Adam',
sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
globals=locals(),
description= f"Fused: {fused}",
).blocked_autorange(min_run_time=5)
results.append(t)
compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127242
Approved by: https://github.com/kulinseth, https://github.com/janeyx99
|
||
|
|
90bb510ece |
Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit
|
||
|
|
27f9d3b0a1 |
Flip default value for mypy disallow_untyped_defs [8/11] (#127845)
See #127836 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127845 Approved by: https://github.com/oulgen ghstack dependencies: #127842, #127843, #127844 |
||
|
|
348b181a97 |
Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898. - #126898 ------ Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690 Approved by: https://github.com/Skylion007 |
||
|
|
033e733021 |
Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit
|
||
|
|
749a132fb0 |
[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves #126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
|
||
|
|
1a28f731dc |
[optim] Merge the pyi files into py files of optimizer (#125452)
Continue the work of pytorch/pytorch#125153 Pull Request resolved: https://github.com/pytorch/pytorch/pull/125452 Approved by: https://github.com/janeyx99 |
||
|
|
b805d3cbcb |
Modify device check in capturable optimizer to support more devices (#124919)
Fixes #124830 Modify device check in capturable optimizer to support more device Pull Request resolved: https://github.com/pytorch/pytorch/pull/124919 Approved by: https://github.com/janeyx99 |
||
|
|
0f02e0aa39 |
Disable dynamo on functional optims if capturable=False (#123619)
This resolves a bug in eager where if an old state dict is loaded (without the capturable flag) but the original dict had the capturable flag, then state_steps would be on cuda but we would take the non-capturable path. We now fallback to eager if capturable=False. Current design doc and discussion: https://docs.google.com/document/d/1DmmbiaSp16CDZtGw1qzXKHFTY_0gqc0xpnBdviXq0vk/edit#heading=h.871u7bvwz7ze Note on the actual fallback logic - there was an issue with torchscript originally not handling *args, **kwargs properly, after rectifying that by using `functools.wraps`, there was an additional bug with scoping which required the single tensor implementation to be in the global scope at the time of the fallback closure being created. I pass in the single tensor function to the `_disable_dynamo_if_unsupported` decorator to workaround this bug. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123619 Approved by: https://github.com/janeyx99 |
||
|
|
3c964ad1ca |
add fused_sgd_kernel support for CPU device (#123629)
Support fused_sgd_kernel support for CPU. ## Bench result: 32 core/sockets ICX Test Scripts: https://gist.github.com/zhuhaozhe/688763e17e93e4c5e12f25f676ec90d9 https://gist.github.com/zhuhaozhe/ad9938694bc7fae8b66d376f4dffc6c9 ``` Tensor Size: 262144, Num Tensor 4, Num Threads: 1 _single_tensor_sgd time: 0.2301 seconds _fused_sgd time: 0.0925 seconds Tensor Size: 4194304, Num Tensor 32, Num Threads: 32 _single_tensor_sgd time: 2.6195 seconds _fused_sgd time: 1.7543 seconds ``` ## Test Plan: ``` python test_optim.py -k test_fused_matches_forloop python test_optim.py -k test_fused_large_tensor python test_optim.py -k test_can_load_older_state_dict python test_optim.py -k test_grad_scaling_autocast_fused_optimizers python test_torch.py -k test_grad_scaling_autocast_fused python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step ``` Looks like we already have some PRs under this issue https://github.com/pytorch/pytorch/issues/123451 to unified the UTs, I did not modified UT in this PR. Co-authored-by: Jane Xu <janeyx@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/123629 Approved by: https://github.com/jgong5, https://github.com/janeyx99 |
||
|
|
b412b75b42 |
[optim] add fused_adam/adamw_kernel support for CPU device (#123074)
On par with `CUDA` implementation.
For `autocast` logic, same with `CUDA` + `Fused Adam`:
- check inf in `gradscalar.step`
- In fused kernel, if there is `inf`, do nothing. If not, unscale the grad ( also write back) and update the param.
**TestPlan**:
```
# extend CUDA only test for CPU fused adagrad
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused
# extend fused test
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict
# newly added test (follow
|
||
|
|
560efaa471 |
Part 1: UFMT partial files in torch/optim due to the pr-sanity-checks (#124053)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124053 Approved by: https://github.com/ezyang ghstack dependencies: #124048 |
||
|
|
b5ba80828f |
[optim] Rectify capturable testing and fix bugs! (#118326)
This PR fixes several bugs, listed in priority: 1. `load_state_dict` with a nontensor step was incorrect for capturable and fused implementations since we don't create the tensors on the right device in `__setstate__`. This has been fixed. 2. The most recently added capturable implementations forgot the check that all tensors should be on CUDA for eager. We've now added those checks 3. The most recent change in Adamax only adds capturable for foreach but will silently be incorrect for forloop/single-tensor. I've added erroring and modified testing with many many many skips for that. Honestly my preference after this PR has only been further cemented that we should just do the single tensor and multi tensor capturable implementations together in the future. @mlazos 4. The conditional for adding cuda-supported configs for the optimizer infos was incorrect! So we hadn't been testing capturable! This also stands rectified and was the trigger for this PR in the first place. 5. In a similar way, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was incorrect sometimes as well. This has also been corrected. The following is not a bug, but is just something to make life simpler by not needing to handle Nones: `optim_input_funcs` must now mandatorily take in a `device`, which could be a string or a torch.device. Details for posterity: 4. Running the test_foreach_matches_forloop test and printing the configs that get printed yields capturable getting included, which is correct. ``` (pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" params=None, kwargs={}, desc=default params=None, kwargs={'lr': 0.01}, desc=non-default lr params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad params=None, kwargs={'capturable': True}, desc=capturable params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad . ---------------------------------------------------------------------- Ran 1 test in 19.229s OK ``` 5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run is also now correct. ``` /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" params=None, kwargs={'differentiable': False}, desc=default params=None, kwargs={'differentiable': True}, desc=default & differentiable params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable .params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused . ---------------------------------------------------------------------- Ran 2 tests in 11.112s OK ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326 Approved by: https://github.com/mlazos |
||
|
|
2964170f3a |
Revert "[optim] Rectify capturable testing and fix bugs! (#118326)"
This reverts commit |
||
|
|
d947b9d500 |
[optim] Rectify capturable testing and fix bugs! (#118326)
This PR fixes several bugs, listed in priority: 1. `load_state_dict` with a nontensor step was incorrect for capturable and fused implementations since we don't create the tensors on the right device in `__setstate__`. This has been fixed. 2. The most recently added capturable implementations forgot the check that all tensors should be on CUDA for eager. We've now added those checks 3. The most recent change in Adamax only adds capturable for foreach but will silently be incorrect for forloop/single-tensor. I've added erroring and modified testing with many many many skips for that. Honestly my preference after this PR has only been further cemented that we should just do the single tensor and multi tensor capturable implementations together in the future. @mlazos 4. The conditional for adding cuda-supported configs for the optimizer infos was incorrect! So we hadn't been testing capturable! This also stands rectified and was the trigger for this PR in the first place. 5. In a similar way, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was incorrect sometimes as well. This has also been corrected. The following is not a bug, but is just something to make life simpler by not needing to handle Nones: `optim_input_funcs` must now mandatorily take in a `device`, which could be a string or a torch.device. Details for posterity: 4. Running the test_foreach_matches_forloop test and printing the configs that get printed yields capturable getting included, which is correct. ``` (pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" params=None, kwargs={}, desc=default params=None, kwargs={'lr': 0.01}, desc=non-default lr params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad params=None, kwargs={'capturable': True}, desc=capturable params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad . ---------------------------------------------------------------------- Ran 1 test in 19.229s OK ``` 5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run is also now correct. ``` /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" params=None, kwargs={'differentiable': False}, desc=default params=None, kwargs={'differentiable': True}, desc=default & differentiable params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable .params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused . ---------------------------------------------------------------------- Ran 2 tests in 11.112s OK ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326 Approved by: https://github.com/mlazos |
||
|
|
17ecd1e9cd |
Migrate test_complex_optimizer to OptimizerInfo (#118160)
This PR does what it says and more. 1. We increase coverage by a LOT! Previously, complex was not tested for many many configs, including foreach + maximize at the same time. Or the fused impls. Or just random configs people forgot about. 2. I rearranged the maximize conditional and the _view_as_real to preserve list-ness. This is needed for _view_as_real to function properly, I did add a comment in the Files Changed. This new order also just...makes more aesthetic sense. 3. Note that LBFGS and SparseAdam are skipped--they don't support complex and now we know. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118160 Approved by: https://github.com/mikaylagawarecki |
||
|
|
924f1b841a |
[optim] Allow torch.float64 scalars for forloop + foreach implementations (#115841)
Should allow for uses cases mentioned in #110940 This would allow scalars to also be float64s in the foreach implementation. The fused implementation would still create a float32 step on Adam and AdamW. This PR also does NOT worry about performance and is mainly for enablement. Next steps: - Relax the constraint on fused adam(w) and allow torch.float64 scalars there - Allow _performant_ mixed dtypes in foreach (a bigger project in itself). This PR will conflict with my other PRs, I will figure out a landing order Pull Request resolved: https://github.com/pytorch/pytorch/pull/115841 Approved by: https://github.com/albanD |
||
|
|
5b6b680517 |
Revert "Adamw refactor (#115983)"
This reverts commit
|
||
|
|
eafeba71c1 |
Adamw refactor (#115983)
Fixes #104899, refactors adamw by abstracting out common code in adam. Pull Request resolved: https://github.com/pytorch/pytorch/pull/115983 Approved by: https://github.com/janeyx99 |
||
|
|
62de29d06f |
[optim] be explicit about CPU scalar tensor dtypes (#111008)
Fixes https://github.com/pytorch/pytorch/issues/110940 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111008 Approved by: https://github.com/janeyx99 |
||
|
|
a2552d5521 |
Fixed docstring errors inside torch/cuda/ and torch/optim/ (Docathon H2) (#112964)
Fixes #112592 1) **File: torch/cuda/random.py** ``` Before: /content/pytorch/torch/cuda/random.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/cuda/random.py:21 in public function `get_rng_state`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') /content/pytorch/torch/cuda/random.py:43 in public function `get_rng_state_all`: D202: No blank lines allowed after function docstring (found 1) /content/pytorch/torch/cuda/random.py:43 in public function `get_rng_state_all`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') /content/pytorch/torch/cuda/random.py:54 in public function `set_rng_state`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets') /content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`: D208: Docstring is over-indented /content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`: D209: Multi-line docstring closing quotes should be on a separate line /content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets') /content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`: D414: Section has no content ('Args') /content/pytorch/torch/cuda/random.py:88 in public function `manual_seed`: D205: 1 blank line required between summary line and description (found 0) /content/pytorch/torch/cuda/random.py:88 in public function `manual_seed`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets') /content/pytorch/torch/cuda/random.py:110 in public function `manual_seed_all`: D205: 1 blank line required between summary line and description (found 0) /content/pytorch/torch/cuda/random.py:110 in public function `manual_seed_all`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets') /content/pytorch/torch/cuda/random.py:128 in public function `seed`: D205: 1 blank line required between summary line and description (found 0) /content/pytorch/torch/cuda/random.py:128 in public function `seed`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets') /content/pytorch/torch/cuda/random.py:146 in public function `seed_all`: D205: 1 blank line required between summary line and description (found 0) /content/pytorch/torch/cuda/random.py:146 in public function `seed_all`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets') /content/pytorch/torch/cuda/random.py:167 in public function `initial_seed`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') 18 ``` ``` After: /content/pytorch/torch/cuda/random.py:1 at module level: D100: Missing docstring in public module 1 ``` 2) **File: torch/cuda/amp/autocast_mode.py** ``` Before: /content/pytorch/torch/cuda/amp/autocast_mode.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/cuda/amp/autocast_mode.py:18 in public class `autocast`: D205: 1 blank line required between summary line and description (found 0) /content/pytorch/torch/cuda/amp/autocast_mode.py:23 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/cuda/amp/autocast_mode.py:38 in public method `__enter__`: D105: Missing docstring in magic method /content/pytorch/torch/cuda/amp/autocast_mode.py:44 in public method `__exit__`: D105: Missing docstring in magic method /content/pytorch/torch/cuda/amp/autocast_mode.py:49 in public method `__call__`: D102: Missing docstring in public method /content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`: D205: 1 blank line required between summary line and description (found 0) /content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`: D400: First line should end with a period (not 'f') /content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`: D401: First line should be in imperative mood; try rephrasing (found 'Helper') /content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`: D205: 1 blank line required between summary line and description (found 0) /content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`: D400: First line should end with a period (not 'f') /content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`: D401: First line should be in imperative mood; try rephrasing (found 'Helper') 12 ``` ``` After: /content/pytorch/torch/cuda/amp/autocast_mode.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/cuda/amp/autocast_mode.py:23 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/cuda/amp/autocast_mode.py:38 in public method `__enter__`: D105: Missing docstring in magic method /content/pytorch/torch/cuda/amp/autocast_mode.py:44 in public method `__exit__`: D105: Missing docstring in magic method /content/pytorch/torch/cuda/amp/autocast_mode.py:49 in public method `__call__`: D102: Missing docstring in public method 5 ``` 3) **File: torch/cuda/amp/grad_scaler.py** ``` Before: /content/pytorch/torch/cuda/amp/grad_scaler.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/cuda/amp/grad_scaler.py:17 in private class `_MultiDeviceReplicator`: D200: One-line docstring should fit on one line with quotes (found 3) /content/pytorch/torch/cuda/amp/grad_scaler.py:39 in public class `OptState`: D101: Missing docstring in public class /content/pytorch/torch/cuda/amp/grad_scaler.py:50 in public class `GradScaler`: D205: 1 blank line required between summary line and description (found 0) /content/pytorch/torch/cuda/amp/grad_scaler.py:50 in public class `GradScaler`: D400: First line should end with a period (not 'g') /content/pytorch/torch/cuda/amp/grad_scaler.py:115 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/cuda/amp/grad_scaler.py:354 in public method `step`: D400: First line should end with a period (not ':') /content/pytorch/torch/cuda/amp/grad_scaler.py:456 in public method `update`: D401: First line should be in imperative mood (perhaps 'Update', not 'Updates') /content/pytorch/torch/cuda/amp/grad_scaler.py:529 in public method `get_scale`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') /content/pytorch/torch/cuda/amp/grad_scaler.py:544 in public method `get_growth_factor`: D200: One-line docstring should fit on one line with quotes (found 3) /content/pytorch/torch/cuda/amp/grad_scaler.py:544 in public method `get_growth_factor`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') /content/pytorch/torch/cuda/amp/grad_scaler.py:550 in public method `set_growth_factor`: D205: 1 blank line required between summary line and description (found 0) /content/pytorch/torch/cuda/amp/grad_scaler.py:550 in public method `set_growth_factor`: D400: First line should end with a period (not ':') /content/pytorch/torch/cuda/amp/grad_scaler.py:557 in public method `get_backoff_factor`: D200: One-line docstring should fit on one line with quotes (found 3) /content/pytorch/torch/cuda/amp/grad_scaler.py:557 in public method `get_backoff_factor`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') /content/pytorch/torch/cuda/amp/grad_scaler.py:563 in public method `set_backoff_factor`: D205: 1 blank line required between summary line and description (found 0) /content/pytorch/torch/cuda/amp/grad_scaler.py:563 in public method `set_backoff_factor`: D400: First line should end with a period (not ':') /content/pytorch/torch/cuda/amp/grad_scaler.py:570 in public method `get_growth_interval`: D200: One-line docstring should fit on one line with quotes (found 3) /content/pytorch/torch/cuda/amp/grad_scaler.py:570 in public method `get_growth_interval`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') /content/pytorch/torch/cuda/amp/grad_scaler.py:576 in public method `set_growth_interval`: D205: 1 blank line required between summary line and description (found 0) /content/pytorch/torch/cuda/amp/grad_scaler.py:576 in public method `set_growth_interval`: D400: First line should end with a period (not ':') /content/pytorch/torch/cuda/amp/grad_scaler.py:592 in public method `is_enabled`: D200: One-line docstring should fit on one line with quotes (found 3) /content/pytorch/torch/cuda/amp/grad_scaler.py:592 in public method `is_enabled`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') /content/pytorch/torch/cuda/amp/grad_scaler.py:598 in public method `state_dict`: D400: First line should end with a period (not ':') /content/pytorch/torch/cuda/amp/grad_scaler.py:598 in public method `state_dict`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') /content/pytorch/torch/cuda/amp/grad_scaler.py:624 in public method `load_state_dict`: D401: First line should be in imperative mood (perhaps 'Load', not 'Loads') /content/pytorch/torch/cuda/amp/grad_scaler.py:649 in public method `__getstate__`: D105: Missing docstring in magic method /content/pytorch/torch/cuda/amp/grad_scaler.py:665 in public method `__setstate__`: D105: Missing docstring in magic method 28 ``` ``` After: /content/pytorch/torch/cuda/amp/grad_scaler.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/cuda/amp/grad_scaler.py:40 in public class `OptState`: D101: Missing docstring in public class /content/pytorch/torch/cuda/amp/grad_scaler.py:117 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/cuda/amp/grad_scaler.py:647 in public method `__getstate__`: D105: Missing docstring in magic method /content/pytorch/torch/cuda/amp/grad_scaler.py:663 in public method `__setstate__`: D105: Missing docstring in magic method 5 ``` 4) **File: torch/optim/_functional.py** ``` Before: /content/pytorch/torch/optim/_functional.py:1 at module level: D400: First line should end with a period (not 'e') 1 ``` ``` After: 0 ``` 5) **File: torch/optim/__init__.py** ``` Before: /content/pytorch/torch/optim/__init__.py:1 at module level: D205: 1 blank line required between summary line and description (found 0) 1 ``` ``` After: 0 ``` 6) **File: torch/optim/lbfgs.py** ``` Before: /content/pytorch/torch/optim/lbfgs.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/lbfgs.py:185 in public class `LBFGS`: D205: 1 blank line required between summary line and description (found 0) /content/pytorch/torch/optim/lbfgs.py:185 in public class `LBFGS`: D400: First line should end with a period (not 'c') /content/pytorch/torch/optim/lbfgs.py:215 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/lbfgs.py:285 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs') 5 ``` ``` After: /content/pytorch/torch/optim/lbfgs.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/lbfgs.py:217 in public method `__init__`: D107: Missing docstring in __init__ 2 ``` 7)**File: torch/optim/sparse_adam.py** ``` Before: /content/pytorch/torch/optim/sparse_adam.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/sparse_adam.py:7 in public class `SparseAdam`: D101: Missing docstring in public class /content/pytorch/torch/optim/sparse_adam.py:8 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/sparse_adam.py:40 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs') 4 ``` ``` After: /content/pytorch/torch/optim/sparse_adam.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/sparse_adam.py:7 in public class `SparseAdam`: D101: Missing docstring in public class /content/pytorch/torch/optim/sparse_adam.py:8 in public method `__init__`: D107: Missing docstring in __init__ 3 ``` 8) **File:torch/optim/adadelta.py** ``` Before: /content/pytorch/torch/optim/adadelta.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/adadelta.py:11 in public class `Adadelta`: D101: Missing docstring in public class /content/pytorch/torch/optim/adadelta.py:12 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/adadelta.py:44 in public method `__setstate__`: D105: Missing docstring in magic method /content/pytorch/torch/optim/adadelta.py:82 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs') /content/pytorch/torch/optim/adadelta.py:193 in public function `adadelta`: D202: No blank lines allowed after function docstring (found 1) 6 ``` ``` After: /content/pytorch/torch/optim/adadelta.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/adadelta.py:11 in public class `Adadelta`: D101: Missing docstring in public class /content/pytorch/torch/optim/adadelta.py:12 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/adadelta.py:44 in public method `__setstate__`: D105: Missing docstring in magic method 4 ``` 9) **File: torch/optim/adagrad.py** ``` Before: /content/pytorch/torch/optim/adagrad.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/adagrad.py:11 in public class `Adagrad`: D101: Missing docstring in public class /content/pytorch/torch/optim/adagrad.py:12 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/adagrad.py:63 in public method `__setstate__`: D105: Missing docstring in magic method /content/pytorch/torch/optim/adagrad.py:78 in public method `share_memory`: D102: Missing docstring in public method /content/pytorch/torch/optim/adagrad.py:100 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs') /content/pytorch/torch/optim/adagrad.py:201 in public function `adagrad`: D202: No blank lines allowed after function docstring (found 1) 7 ``` ``` After: /content/pytorch/torch/optim/adagrad.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/adagrad.py:11 in public class `Adagrad`: D101: Missing docstring in public class /content/pytorch/torch/optim/adagrad.py:12 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/adagrad.py:63 in public method `__setstate__`: D105: Missing docstring in magic method /content/pytorch/torch/optim/adagrad.py:78 in public method `share_memory`: D102: Missing docstring in public method 5 ``` 10) **File: torch/optim/adam.py** ``` Before: /content/pytorch/torch/optim/adam.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/adam.py:14 in public class `Adam`: D101: Missing docstring in public class /content/pytorch/torch/optim/adam.py:15 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/adam.py:65 in public method `__setstate__`: D105: Missing docstring in magic method /content/pytorch/torch/optim/adam.py:135 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs') /content/pytorch/torch/optim/adam.py:281 in public function `adam`: D202: No blank lines allowed after function docstring (found 1) /content/pytorch/torch/optim/adam.py:281 in public function `adam`: D205: 1 blank line required between summary line and description (found 0) 7 ``` ``` After: /content/pytorch/torch/optim/adam.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/adam.py:14 in public class `Adam`: D101: Missing docstring in public class /content/pytorch/torch/optim/adam.py:15 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/adam.py:65 in public method `__setstate__`: D105: Missing docstring in magic method 4 ``` 11) **File: torch/optim/adamax.py** ``` Before: /content/pytorch/torch/optim/adamax.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/adamax.py:12 in public class `Adamax`: D101: Missing docstring in public class /content/pytorch/torch/optim/adamax.py:13 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/adamax.py:47 in public method `__setstate__`: D105: Missing docstring in magic method /content/pytorch/torch/optim/adamax.py:91 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs') /content/pytorch/torch/optim/adamax.py:203 in public function `adamax`: D202: No blank lines allowed after function docstring (found 1) 6 ``` ``` After: /content/pytorch/torch/optim/adamax.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/adamax.py:12 in public class `Adamax`: D101: Missing docstring in public class /content/pytorch/torch/optim/adamax.py:13 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/adamax.py:47 in public method `__setstate__`: D105: Missing docstring in magic method 4 ``` 12) **File: torch/optim/adamw.py** ``` Before: /content/pytorch/torch/optim/adamw.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/adamw.py:12 in public class `AdamW`: D101: Missing docstring in public class /content/pytorch/torch/optim/adamw.py:13 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/adamw.py:73 in public method `__setstate__`: D105: Missing docstring in magic method /content/pytorch/torch/optim/adamw.py:153 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs') /content/pytorch/torch/optim/adamw.py:304 in public function `adamw`: D202: No blank lines allowed after function docstring (found 1) 6 ``` ``` After: /content/pytorch/torch/optim/adamw.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/adamw.py:12 in public class `AdamW`: D101: Missing docstring in public class /content/pytorch/torch/optim/adamw.py:13 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/adamw.py:73 in public method `__setstate__`: D105: Missing docstring in magic method 4 ``` 13) **File: torch/optim/asgd.py** ``` Before: /content/pytorch/torch/optim/asgd.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/asgd.py:17 in public class `ASGD`: D101: Missing docstring in public class /content/pytorch/torch/optim/asgd.py:18 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/asgd.py:52 in public method `__setstate__`: D105: Missing docstring in magic method /content/pytorch/torch/optim/asgd.py:107 in public method `step`: D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs') /content/pytorch/torch/optim/asgd.py:195 in public function `asgd`: D202: No blank lines allowed after function docstring (found 1) 6 ``` ``` After: /content/pytorch/torch/optim/asgd.py:1 at module level: D100: Missing docstring in public module /content/pytorch/torch/optim/asgd.py:17 in public class `ASGD`: D101: Missing docstring in public class /content/pytorch/torch/optim/asgd.py:18 in public method `__init__`: D107: Missing docstring in __init__ /content/pytorch/torch/optim/asgd.py:52 in public method `__setstate__`: D105: Missing docstring in magic method 4 ``` Resolved docstring errors as listed. I initially changed in the main branch of forked repo which caused changes to appear in my PR to other issue. I have fixed that and hope this PR won't have any conflicts. Kindly review @svekars @jbschlosser. In case of any other issues please let me know. Thanks! Pull Request resolved: https://github.com/pytorch/pytorch/pull/112964 Approved by: https://github.com/kit1980 |
||
|
|
f74d766632 |
feat(optim): use has_complex shortcut flag for all applicable optimizers, use _view_as_real auxiliary function (#110706)
Follow up to: https://github.com/pytorch/pytorch/pull/110607 CC: @lezcano @janeyx99 Pull Request resolved: https://github.com/pytorch/pytorch/pull/110706 Approved by: https://github.com/lezcano |
||
|
|
93a9b1314b |
Make step() faster by passing in a tensor vs scalar 1 (#111084)
This is the culminated result of https://github.com/pytorch/pytorch/pull/110954#issuecomment-1758520411. We are making the code slightly more complicated to gain some perf in minimizing calls to `.copy_()` and `.to()`. ### Code ``` import torch with torch.cuda.device(0): steps = [torch.zeros((), device="cpu", dtype=torch.float32) for i in range(1000)] with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ] ) as p: # New code: # step_device = steps[0].device # one = torch.tensor(1.0, device=step_device) if str(step_device) == "cpu" else 1 # torch._foreach_add_(steps, one, 1.0) # Old code: torch._foreach_add_(steps, 1) print(p.key_averages().table(sort_by="cpu_time_total")) ``` ### Profiles **with old code** ``` ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ aten::_foreach_add_ 35.31% 52.089ms 99.99% 147.495ms 147.495ms 1 aten::add_ 25.05% 36.949ms 64.68% 95.406ms 95.406us 1000 aten::to 3.97% 5.852ms 39.63% 58.457ms 58.457us 1000 aten::_to_copy 10.11% 14.917ms 35.66% 52.605ms 52.605us 1000 aten::copy_ 21.65% 31.939ms 21.65% 31.939ms 31.939us 1000 aten::empty_strided 3.90% 5.749ms 3.90% 5.749ms 5.749us 1000 cudaDeviceSynchronize 0.01% 18.000us 0.01% 18.000us 18.000us 1 ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Self CPU time total: 147.513ms ``` **with new code** ``` ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ aten::_foreach_add_ 55.06% 49.963ms 99.86% 90.625ms 90.625ms 1 aten::add_ 44.81% 40.662ms 44.81% 40.662ms 40.662us 1000 aten::detach_ 0.01% 8.000us 0.05% 45.000us 45.000us 1 detach_ 0.04% 37.000us 0.04% 37.000us 37.000us 1 aten::empty 0.03% 30.000us 0.03% 30.000us 30.000us 1 aten::to 0.03% 23.000us 0.03% 23.000us 23.000us 1 cudaDeviceSynchronize 0.02% 22.000us 0.02% 22.000us 22.000us 1 aten::lift_fresh 0.01% 6.000us 0.01% 6.000us 6.000us 1 ------------------------- ------------ ------------ ------------ ------------ ------------ ------------ Self CPU time total: 90.751ms ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/111084 Approved by: https://github.com/albanD ghstack dependencies: #111079 |
||
|
|
b460c30893 |
[BE] Enable Ruff's Flake8 PYI042 (#111114)
Enable [snake-case-type-alias (PYI042)](https://docs.astral.sh/ruff/rules/snake-case-type-alias/) Link: #110950 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111114 Approved by: https://github.com/albanD |
||
|
|
d279979102 |
perf(inductor): improve Adam compile times by shortcutting for loops (via has_complex) (#110607)
Adam part of: https://github.com/pytorch/pytorch/issues/110506 TODO: - If this approach is validated as a good one, it an also be applied to all other optimizers which convert `complex` via list comprehensions ### Results: `NUM_PARAMS=200, foreach=True` - main: dynamo: 43s, inductor: 31s, total: 74s - this PR: dynamo: 3.5s, inductor: 30s, total: 34s (dynamo speedup: 12.3x, overall speedup: 34s, 2.1x) `NUM_PARAMS=1000, foreach=True, has_complex shortcut`: ``` <class 'torch.optim.adam.Adam'> {'lr': 0.01, 'foreach': True} torch.float32 TorchDynamo compilation metrics: Function Runtimes (s) ------------------------------------ ------------------------------- _compile.<locals>.compile_inner 0.0329, 50.0806, 0.0041 OutputGraph.call_user_compiler 44.9924 ``` `NUM_PARAMS=1000, foreach=True`: ``` <class 'torch.optim.adam.Adam'> {'lr': 0.01, 'foreach': True} torch.float32 TorchDynamo compilation metrics: Function Runtimes (s) ------------------------------------ ------------------------------- _compile.<locals>.compile_inner 0.0389, 58.6069, 0.0043 OutputGraph.call_user_compiler 44.1425 ``` ### Discussion - `has_complex` shortcut provides additional 2x dynamo speedup. It is not necessary to achieve a significant overall speedup. CC: @janeyx99 @mlazos Pull Request resolved: https://github.com/pytorch/pytorch/pull/110607 Approved by: https://github.com/janeyx99, https://github.com/lezcano |
||
|
|
df7d01aed5 |
perf(inductor): use for loop with shortcut in Optimizers to speedup against list comprehensions (e.g. complex conversion) (#110613)
Fully fixes: https://github.com/pytorch/pytorch/issues/110506 Depends: https://github.com/pytorch/pytorch/pull/110607 Potential merge conflicts: - https://github.com/pytorch/pytorch/pull/110339 - https://github.com/pytorch/pytorch/pull/110345 - https://github.com/pytorch/pytorch/pull/110454 Related: - https://github.com/pytorch/pytorch/issues/110606 (we can apply the improvements here orthogonally to the complex support) ### Results Benchmark: 100 params. Breakdowns (float32, dynamo): ``` Adagrad: this PR: 4.4s, main: 8.8s Adam: this PR: 2.1s, main: 9.8s AdamW: this PR: 2.5s, main: 8.2s ASGD: this PR: 3.1s, main: 8.5s RMSProp: this PR: 1.3s, main: 4.2s RProp: this PR: 6.7s, main: 14.9s ``` Notes: 1. Adagrad is still slow due to `_get_value` list comprehension. Can be fixed in https://github.com/pytorch/pytorch/pull/110339/files by utilizing capturable path 2. Adamax is not actually compiled (it is currently disabled). 3. Inductor compile time is quite variable. We calculate dynamo by subtracting `call_user_compiler` from `compile_inner` timing. <details> This PR: ``` Adagrad (torch.float32): 28.47496461868286s Adagrad (torch.complex64): 29.379547357559204s Adam (torch.float32): 17.334211587905884s Adam (torch.complex64): 29.637500524520874s Adamax (torch.float32): 2.4749321937561035s Adamax (torch.complex64): 3.1997995376586914s AdamW (torch.float32): 18.06532859802246s AdamW (torch.complex64): 28.25661015510559s ASGD (torch.float32): 23.70255398750305s ASGD (torch.complex64): 25.33756995201111s RMSprop (torch.float32): 7.964028596878052s RMSprop (torch.complex64): 12.909599781036377s Rprop (torch.float32): 30.512362003326416s Rprop (torch.complex64): 44.74405765533447s ``` Main ``` Adagrad (torch.float32): 26.919506072998047s Adagrad (torch.complex64): 35.190622091293335s Adam (torch.float32): 25.715000867843628s Adam (torch.complex64): 24.17716670036316s Adamax (torch.float32): 2.4404726028442383s Adamax (torch.complex64): 3.3538928031921387s AdamW (torch.float32): 25.2022807598114s AdamW (torch.complex64): 28.915700912475586s ASGD (torch.float32): 24.108731985092163s ASGD (torch.complex64): 26.589075088500977s RMSprop (torch.float32): 10.781344175338745s RMSprop (torch.complex64): 15.136352777481079s Rprop (torch.float32): 42.46482181549072s Rprop (torch.complex64): 48.28277635574341s ``` Seems that it doesn't help the complex case by much (but that's not the majority case). torch.float32 is generally positive, when it does not show drastic improvement / regresses, it is due to inductor variance (by manually inspecting the logs). </details> ### Benchmark Script ```python import torch import time from torch.optim import Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop OPTIMS = [Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop] DTYPES = [torch.float, torch.cfloat] NUM_PARAMS = 100 kwargs = { "lr": 0.01, "foreach": True } summary = [] for optim_cls in OPTIMS: for dtype in DTYPES: torch._dynamo.reset() # torch._inductor.metrics.reset() input = torch.ones([10, 10], dtype=dtype, device="cuda:0") model = torch.nn.Sequential( *[torch.nn.Linear(10, 10, dtype=dtype, device="cuda:0") for _ in range(NUM_PARAMS)] ) model(input).sum().abs().backward() opt_compiled = optim_cls(model.parameters(), **kwargs) compiled_step = torch.compile(opt_compiled.step) with torch.set_grad_enabled(False): start_time = time.time() compiled_step() summary.append(f"{optim_cls.__name__} ({dtype}): {time.time() - start_time}s") print(optim_cls, kwargs, dtype, torch._dynamo.utils.compile_times()) for s in summary: print(s) ``` CC: @janeyx99 @mlazos Pull Request resolved: https://github.com/pytorch/pytorch/pull/110613 Approved by: https://github.com/janeyx99 |
||
|
|
1641d671e5 |
[optim] FusedAdam/W accepts lr: Tensor without h2ds (#106916)
Starts addressing #106802 This PR also conveniently does some BE: - Fixes a bug in adamw where we use amsgrad instead of per group amsgrad - Brings the impls of adamw and adam closer to correctness and to each other I couldn't fully remove the .pyi's because mypy was going to complain about the entire files which scared me and shouldn't go in this PR anyway. Test plan: - Add tests to ensure that lr could be passed as a Tensor - Did some profiling of the below code (runs 1k iterations of step for Adam) ``` import torch from torch.testing._internal.common_utils import TestCase param = torch.rand(2, 3, dtype=torch.float, device='cuda:0', requires_grad=True) param.grad = torch.rand_like(param) lr = torch.tensor(.001, device='cuda:0') opt = torch.optim.Adam([param], lr=lr, fused=True) with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ] ) as p: for _ in range(1000): opt.step() print(p.key_averages().table(sort_by="cpu_time_total")) ``` Before my change: <img width="1381" alt="image" src="https://github.com/pytorch/pytorch/assets/31798555/cfc5175a-0f41-4829-941f-342554f3b152"> After my change (notice there are no d2h syncs and the CPU time is lower!):  Next steps long term: - have all capturable foreach + forloop impls in Adam(W) handle tensor LR - have all capturable impls handle tensor LR - have all impls handle tensor LR Pull Request resolved: https://github.com/pytorch/pytorch/pull/106916 Approved by: https://github.com/albanD |
||
|
|
608afe8083 |
Added xla friendly codepath to single_tensor_adamw (#102858)
There are extra graph compilations on XLA when beta{1,2} ** step get too small. This PR addresses this issue by making the `capturable` interface enabled for XLA, as well as switching to `torch.float_power` which preserves the same behaviour as the non-capturable flow on XLA.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102858
Approved by: https://github.com/janeyx99, https://github.com/albanD
|
||
|
|
21ede4547a |
remove duplicated code in optimizer (#106022)
Fixes #ISSUE_NUMBER as the title, the check code has duplicates Pull Request resolved: https://github.com/pytorch/pytorch/pull/106022 Approved by: https://github.com/janeyx99 |
||
|
|
6d43c89f37 |
[BE]: Update Ruff to 0.0.280 (#105724)
Removes unusued loop values in python dictionary iteration. Automated fix from Ruff master Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724 Approved by: https://github.com/ezyang, https://github.com/janeyx99 |
||
|
|
e1296a7f8d |
[Adam] Fix complex x amsgrad support (#104989)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104989 Approved by: https://github.com/albanD |
||
|
|
25d80c69ce |
[foreach] super minor BE: remove unnecessary cast (#105601)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105601 Approved by: https://github.com/albanD |
||
|
|
3721fa5612 |
[BE] Enable ruff's UP rules and autoformat optim/ (#105426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105426 Approved by: https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi, https://github.com/janeyx99 |
||
|
|
ef05c5f202 |
Use plain power operator in Adam/Adamw when capturing (#104254)
The goal is to fix the problem from https://github.com/pytorch/pytorch/pull/102858 The full error this used to raise was : ``` 2023-06-27T15:12:15.0663239Z File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/optim/adamw.py", line 409, in _single_tensor_adamw 2023-06-27T15:12:15.0663699Z bias_correction1 = 1 - beta1 ** step 2023-06-27T15:12:15.0664200Z File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_tensor.py", line 40, in wrapped 2023-06-27T15:12:15.0664547Z return f(*args, **kwargs) 2023-06-27T15:12:15.0665031Z File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_tensor.py", line 882, in __rpow__ 2023-06-27T15:12:15.0665483Z return torch.tensor(other, dtype=dtype, device=self.device) ** self 2023-06-27T15:12:15.0665899Z RuntimeError: CUDA error: operation not permitted when stream is capturing 2023-06-27T15:12:15.0666401Z CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. ``` This pow issue was fixed in https://github.com/pytorch/pytorch/pull/104264 and so this problem should be solvable now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104254 Approved by: https://github.com/janeyx99, https://github.com/aws-murandoo |
||
|
|
231364fd06 |
[optim] use lerp whenever possible (#104796)
This is a better copy (with fixes) of #104781. Test plan: CI will pass once https://github.com/pytorch/pytorch/pull/104784 is landed Internal CI (and the newly enabled compiled optim tests) will pass after https://github.com/pytorch/pytorch/pull/104866 is landed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104796 Approved by: https://github.com/albanD |
||
|
|
35f0e35529 |
[foreach][Adam] Minimize use of intermediates to decrease peak memory (#104780)
Starts addressing https://github.com/pytorch/pytorch/issues/97712 by - Minimizing intermediates usage for foreach Adam - Document the extra memory usage - Add comments within the code for clarity now that we reuse intermediates - Add tests - Did some refactoring Next steps involve doing this for all other foreach implementations. Note that even after this change, foreach mem usage will be higher than forloop due to the fact that we have a minimum budget of 1 intermediate (to not muddle the input values) and the intermediate will be larger. For capturable, the memory usage is higher due to moving more tensors to CUDA. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104780 Approved by: https://github.com/albanD |
||
|
|
e7fe2a797c |
Revert "[optim] use lerp whenever possible (#104796)"
This reverts commit
|
||
|
|
fbe2a7e50a |
[optim] use lerp whenever possible (#104796)
This is a better copy (with fixes) of #104781. Test plan: CI will pass once https://github.com/pytorch/pytorch/pull/104784 is landed Pull Request resolved: https://github.com/pytorch/pytorch/pull/104796 Approved by: https://github.com/albanD |
||
|
|
a290cbf32b |
Enable fused foreach Adam compilation (#104121)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/104121 Approved by: https://github.com/janeyx99 |
||
|
|
6d2887cc06 |
Reland "Move tensor grouping to ATen" (#103912)
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verified by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If above conditon is not met, instead of printing a static_assert output, VC++ raises a very cryptic compilation errors, see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.
Also, using `std::hash` for string should result in a faster hash function.
(cherry picked from commit
|
||
|
|
fa893f3f58 |
Fix optim state_dict casting to allow step to cast to CPU (#102619)
I'm guessing this should fix https://github.com/pytorch/pytorch/pull/88015#issuecomment-1569523106 but am waiting on @ychfan to supply more details so I could write a good test case. Pull Request resolved: https://github.com/pytorch/pytorch/pull/102619 Approved by: https://github.com/albanD |
||
|
|
0cb5bc3b04 |
Revert "Move tensor grouping to ATen (#100007)"
This reverts commit
|