Commit Graph

89 Commits

Author SHA1 Message Date
Yuanyuan Chen
f91899ca6c [2/N] Add strict parameter to Python zip calls (#166257)
This PR adds `strict=True/False` to zip calls in test utils. strict=True is passed when possible.
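
For illustration (not code from the PR), a minimal sketch of what `strict=True` changes, assuming Python 3.10+:

```python
a = [1, 2, 3]
b = ["x", "y"]

# Default behavior: the longer iterable is silently truncated.
print(list(zip(a, b)))  # [(1, 'x'), (2, 'y')]

# With strict=True, a length mismatch raises ValueError instead of being
# hidden -- which is what the test utils rely on where lengths must match.
try:
    list(zip(a, b, strict=True))
except ValueError as e:
    print(e)  # e.g. "zip() argument 2 is shorter than argument 1"
```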

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166257
Approved by: https://github.com/janeyx99
2025-11-01 00:35:41 +00:00
PyTorch MergeBot
f60751024e Revert "[2/N] Add strict parameter to Python zip calls (#166257)"
This reverts commit 39e5cdddf7.

Reverted https://github.com/pytorch/pytorch/pull/166257 on behalf of https://github.com/atalman due to Failing: test/distributed/fsdp/test_fsdp_mixed_precision.py::TestFSDPTrainEval::test_train_ema_eval_flow [GH job link](https://github.com/pytorch/pytorch/actions/runs/18934047991/job/54057218160) [HUD commit link](39e5cdddf7) ([comment](https://github.com/pytorch/pytorch/pull/166257#issuecomment-3467955332))
2025-10-30 13:20:00 +00:00
Yuanyuan Chen
39e5cdddf7 [2/N] Add strict parameter to Python zip calls (#166257)
This PR adds `strict=True/False` to zip calls in test utils. strict=True is passed when possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166257
Approved by: https://github.com/janeyx99
2025-10-30 08:10:10 +00:00
Maggie Moss
d1a6e006e0 Fix syntax for pyrefly errors (#166496)
Last one! This ensures all existing suppressions match the expected syntax and silence only one error code.

pyrefly check
lintrunner
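
Purely illustrative: the suppression-comment form below is an assumption, not taken from this PR (pyrefly's docs define the exact syntax). The intent is that each suppression names a single error code rather than blanket-silencing the line:

```python
# Hypothetical example -- the `# pyrefly: ignore[...]` form and the error-code
# name below are assumptions, not copied from the PR.
def parse_port(raw: str) -> int:
    value: int = raw  # pyrefly: ignore[bad-assignment]
    return value
```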

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166496
Approved by: https://github.com/Skylion007, https://github.com/mlazos
2025-10-29 20:00:25 +00:00
mansiag05
f8fccb1e48 [Code Clean] Clean asserts in torch/optim. (#165629)
Replaces 50 assert statements across 15 files in torch.optim with explicit if-checks that raise AssertionError, so the checks are not stripped when Python runs with the -O flag.
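
A minimal sketch of the transformation pattern (illustrative function, not code from the PR):

```python
# Before: stripped entirely when Python runs with -O / -OO.
def _check_lr_before(lr: float) -> None:
    assert lr > 0, "learning rate must be positive"


# After: the check survives optimized mode because it is an explicit raise.
def _check_lr_after(lr: float) -> None:
    if not lr > 0:
        raise AssertionError("learning rate must be positive")
```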

Partially fixes #164878.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165629
Approved by: https://github.com/albanD
2025-10-23 15:56:29 +00:00
Maggie Moss
4ab847bbc7 Pyrefly suppressions 4/n (#164615)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1

after:

0 errors (2,753 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615
Approved by: https://github.com/oulgen
2025-10-06 16:14:36 +00:00
Zeina Migeed
4f5be56612 [Pyrefly][Refactor] Replace dict() calls with literal dict syntax for improved readability (#157735)
I spotted 31 places that construct dictionaries from literal key-value pairs via `dict(...)`.

This PR refactors dictionary construction by replacing `dict(...)` calls with literal `{...}` syntax where applicable.
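
A small sketch of the pattern (illustrative values, not lines from this PR):

```python
# Before: dict() call with keyword arguments.
defaults = dict(lr=1e-3, weight_decay=0.0, maximize=False)

# After: a dict literal -- slightly faster to construct and, with explicit
# string keys, friendlier to readers and type checkers.
defaults = {"lr": 1e-3, "weight_decay": 0.0, "maximize": False}
```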

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157735
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2025-07-08 18:10:33 +00:00
Xuehai Pan
db259bd6b8 [BE][12/16] fix typos in torch/ (#156602)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156602
Approved by: https://github.com/justinchuby, https://github.com/albanD
ghstack dependencies: #156318, #156320
2025-07-02 22:55:29 +00:00
Xuehai Pan
596b418391 [BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang
2025-06-14 11:27:04 +00:00
Tony-Y
78715a181f Convert Tensor lr to 0-dim as needed for the optimizer to normally work (#145674)
Fixes #145461
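
A hedged usage sketch (illustrative, not from the PR), assuming a PyTorch version where tensor learning rates are supported for SGD:

```python
import torch
from torch import nn, optim

model = nn.Linear(4, 2)

# A scalar learning rate carried as a 0-dim tensor, e.g. so it can be updated
# in place without rebuilding the optimizer.
opt = optim.SGD(model.parameters(), lr=torch.tensor(1e-3))

model(torch.randn(8, 4)).sum().backward()
opt.step()

# Per this change, a 1-element tensor such as torch.tensor([1e-3]) is
# normalized to 0-dim as needed instead of breaking the update math.
```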

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145674
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2025-03-17 23:07:05 +00:00
Aaron Orenstein
7178b827d7 PEP585: Missed conversions (#145342)
Differential Revision: [D68785969](https://our.internmc.facebook.com/intern/diff/D68785969)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145342
Approved by: https://github.com/bobrenjc93
2025-01-29 05:24:36 +00:00
Aaron Orenstein
45ef3309e3 [BE] typing for decorators (#144161)
Summary:
Untyped decorators strip annotations from the decorated items.

- _compile
- _inductor/fx_passes/post_grad
- _inductor/lowering
- _library/custom_ops
- _meta_registrations
- _ops
- _refs/nn/functional
- ao/quantization/quantizer/xnnpack_quantizer_utils
- distributed/_composable/contract
- fx/experimental/graph_gradual_typechecker
- fx/experimental/migrate_gradual_types/constraint_generator
- optim/optimizer
- signal/windows/windows
- testing/_internal/common_device_type
- torch/_inductor/decomposition
- utils/flop_counter

Test Plan: unit tests

Differential Revision: D62302684

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144161
Approved by: https://github.com/Skylion007, https://github.com/albanD
2025-01-04 16:40:09 +00:00
Jane Xu
7b69f7b449 Clarify what we mean by decoupled weight decay in the *AdamWs (#144101)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144101
Approved by: https://github.com/albanD
2025-01-03 19:06:00 +00:00
Xuehai Pan
e1196dfe51 Deprecate torch._utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-12-08 22:55:36 +00:00
Aaron Gokaslan
08db735629 [BE]: Update mypy to 1.13.0 (#140808)
Update mypy to 1.13.0. This should reduce linting time; 1.13 supports orjson cache serialization, which should improve mypy cache performance if orjson is installed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140808
Approved by: https://github.com/ezyang, https://github.com/malfet
2024-12-03 02:50:10 +00:00
PyTorch MergeBot
daa77f3d9f Revert "[BE]: Update mypy to 1.13.0 (#140808)"
This reverts commit 00134d68af.

Reverted https://github.com/pytorch/pytorch/pull/140808 on behalf of https://github.com/huydhn due to This is failing a distributed test in trunk, target determination missed this test and did not run it on PR ([comment](https://github.com/pytorch/pytorch/pull/140808#issuecomment-2512788426))
2024-12-02 20:47:43 +00:00
Aaron Gokaslan
00134d68af [BE]: Update mypy to 1.13.0 (#140808)
Update mypy to 1.13.0. This should reduce linting time; 1.13 supports orjson cache serialization, which should improve mypy cache performance if orjson is installed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140808
Approved by: https://github.com/ezyang, https://github.com/malfet
2024-12-02 18:47:54 +00:00
PyTorch MergeBot
1d28b8b6d5 Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit e84d1121ad.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. More details in D65483292 ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2458381056))
2024-11-05 23:10:38 +00:00
Xuehai Pan
e84d1121ad Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-11-05 10:44:56 +00:00
ErezYosef
197601eeea Add Support for Tracking Parameter Names (named_parameters) in Optimizer State Dict (#134107)
A proposal addressing Issue #1489: **Optimizer should track parameter names and not id.**

(also mentioned here: [[RFC] Introducing FQNs/clarity eyeglasses to optim state_dict](https://dev-discuss.pytorch.org/t/rfc-introducing-fqns-clarity-to-optim-state-dict/1552))

## Summary
This PR introduces a backward-compatible enhancement where optimizers track parameter names instead of just their ids.
Optimizers can be initialized with `named_parameters()` as:
```python
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
```
This allows for greater clarity and ease when handling optimizers, as the parameters' names are preserved within the optimizer’s `state_dict` as:
```
state_dict = {
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0, 1],
            'param_names': ['layer.weight', 'layer.bias']  # optional
        }
    ]
}
```
Loading `state_dict` is not changed (backward-compatible) and the `param_names` key will be ignored.

## Key Features
#### Named Parameters in Optimizer Initialization:
Optimizers can accept the output of `model.named_parameters()` during initialization, allowing them to store parameter names directly.
#### Parameter Names in `state_dict`:
The parameter names are saved as a list in the optimizer’s `state_dict` with key `param_names`, alongside the `params` indices, ensuring seamless tracking of both names and parameters.

## Backward Compatibility
#### No Breaking Changes:
This change is fully backward-compatible. The added `param_names` key in the optimizer's `state_dict` is ignored when loading a state to the optimizer.

#### Customization with Hooks:
For more control, the loaded state_dict can be modified using a custom `register_load_state_dict_pre_hook`, providing flexibility for different design needs.

## Documentation Updates
Please refer to the documentation changes for more details on how this feature is implemented and how it can be used effectively.

## Solution Example:

A suggested solution to the problem mentioned in #1489, for the same parameters given in a different order.
The following `register_load_state_dict_pre_hook` should be added to the optimizer before loading to enable loading the state dict:
```python
def adapt_state_dict_ids(optimizer, state_dict):
    # assuming a single param group.
    current_state_group = optimizer.state_dict()['param_groups'][0]
    loaded_state_group = state_dict['param_groups'][0]

    # same number of params, same names, only different ordering
    loaded_name_to_id_mapping = {}  # mapping -- param_name: id in the loaded state dict
    for i, name in enumerate(loaded_state_group['param_names']):
        loaded_name_to_id_mapping[name] = loaded_state_group['params'][i]

    # reorder the loaded ids so that position i holds the id of the parameter
    # whose name matches the current optimizer's i-th parameter name.
    for i, name in enumerate(current_state_group['param_names']):
        loaded_state_group['params'][i] = loaded_name_to_id_mapping[name]

    return state_dict
```
In this code, the loaded `state_dict` ids are adapted to match the order of the current optimizer `state_dict`.
Both the previous and the current optimizers are required to be initialized with `named_parameters()` so that the `param_names` key is present in the dict.
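
For completeness, a hedged sketch of wiring the hook above up (variable names are illustrative, not from the PR; it assumes the optimizer accepts any iterable of `(name, parameter)` pairs, as with `named_parameters()`):

```python
import torch
from torch import nn, optim

model = nn.Linear(4, 2)

# The optimizer we are restoring into, created from named_parameters() so its
# state_dict carries 'param_names'.
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
optimizer.register_load_state_dict_pre_hook(adapt_state_dict_ids)

# A previously saved state_dict over the same named parameters, here produced
# with the parameters in a different order for demonstration.
saved_state_dict = optim.SGD(
    reversed(list(model.named_parameters())), lr=0.01, momentum=0.9
).state_dict()

optimizer.load_state_dict(saved_state_dict)  # the hook remaps ids by name first
```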

### Note
This is my first contribution to PyTorch, and I wish to receive feedback or suggestions for improvement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134107
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2024-10-14 19:24:44 +00:00
Jane Xu
972822dea1 Minorly reorder optim kwargs in docs, fixes #137391 (#137531)
Closes #137391

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137531
Approved by: https://github.com/albanD
2024-10-09 04:14:45 +00:00
Jane Xu
14750dd737 Correct return type of grouping helper function in Optimizer (#133360)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133360
Approved by: https://github.com/albanD
2024-08-14 01:56:02 +00:00
PyTorch MergeBot
cbee9c1fd2 Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit 0e7e61f7ce.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2272370386))
2024-08-07 00:05:20 +00:00
Xuehai Pan
0e7e61f7ce Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-08-03 09:43:38 +00:00
Xuehai Pan
30293319a8 [BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129771
Approved by: https://github.com/justinchuby, https://github.com/janeyx99
2024-08-01 17:07:14 +00:00
Jane Xu
3816f6420a [BE] remove unnecessary _dispatch_sqrt by using ** 0.5 (#131358)
Based on the discussion at https://github.com/pytorch/pytorch/pull/129905#discussion_r1675605075, where `** 0.5` was found to be no slower than `math.sqrt`.
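
A trivial check of the equivalence (not from the PR):

```python
import math

x = 2.0
# ** 0.5 gives the same value without the import or a dispatch helper.
print(x ** 0.5, math.sqrt(x))  # 1.4142135623730951 1.4142135623730951
```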

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131358
Approved by: https://github.com/albanD
2024-07-30 18:08:17 +00:00
PyTorch MergeBot
e4db5dc1c4 Revert "[BE] remove unnecessary _dispatch_sqrt by using ** 0.5 (#131358)"
This reverts commit 4c7f22dee2.

Reverted https://github.com/pytorch/pytorch/pull/131358 on behalf of https://github.com/janeyx99 due to Internal uses this private API and landing that has been a pain so we're reverting this first ([comment](https://github.com/pytorch/pytorch/pull/131358#issuecomment-2253190654))
2024-07-26 17:35:27 +00:00
PyTorch MergeBot
c9888c2739 Revert "[BE] typing for decorators - optim/optimizer (#131583)"
This reverts commit a1dad77dfa.

Reverted https://github.com/pytorch/pytorch/pull/131583 on behalf of https://github.com/atalman due to Breaks CI: [GH job link](https://github.com/pytorch/pytorch/actions/runs/10105959146/job/27947741162) [HUD commit link](a1dad77dfa) ([comment](https://github.com/pytorch/pytorch/pull/131583#issuecomment-2252784280))
2024-07-26 13:41:22 +00:00
Aaron Orenstein
a1dad77dfa [BE] typing for decorators - optim/optimizer (#131583)
See #131429
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131583
Approved by: https://github.com/janeyx99
ghstack dependencies: #131568, #131569, #131570, #131571, #131572, #131573, #131574, #131575, #131576, #131577, #131578, #131579, #131580, #131581, #131582
2024-07-26 05:00:07 +00:00
Jane Xu
4c7f22dee2 [BE] remove unnecessary _dispatch_sqrt by using ** 0.5 (#131358)
Based on the discussion at https://github.com/pytorch/pytorch/pull/129905#discussion_r1675605075, where `** 0.5` was found to be no slower than `math.sqrt`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131358
Approved by: https://github.com/albanD
2024-07-24 14:58:57 +00:00
hxwang
276b5238ef [bug] Add is_compiling check for optimizers to avoid untracked tensor during graph tracing (#130909)
Hey folks, I was using the `stateless_func` [here](7c45476d38/torch/distributed/_spmd/api.py (L435)), which worked well before [this commit](https://github.com/pytorch/pytorch/pull/111084) but which afterwards introduced a `_tensor_constant0` and made the func non-stateless. Since there is no way to retrieve this constant tensor before compilation, and performance is not an issue when tracing a graph, I think it is reasonable to fall back to the other branch.
![image](https://github.com/user-attachments/assets/6ee4487d-456b-47e0-8c1d-66cb5a641d47)

![image](https://github.com/user-attachments/assets/1ed46502-e50e-45c4-9751-49aa5a4590ae)
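
A hedged sketch of the guard pattern being described (illustrative, not the PR's diff), assuming `torch.compiler.is_compiling()` as the check:

```python
import torch


def _bump_step(step: int):
    # Illustrative only: while a graph is being traced, stay on the plain
    # Python branch instead of materializing a new constant tensor that the
    # tracer would record as an untracked _tensor_constant.
    if torch.compiler.is_compiling():
        return step + 1
    return torch.tensor(step + 1)
```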

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130909
Approved by: https://github.com/mlazos
2024-07-24 08:29:27 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from the function they decorate, so even if the underlying function is fully typed, its callers get no benefit from the annotations.
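
For context, a minimal sketch (not from the PR) of the difference, using `ParamSpec` to keep the wrapped signature visible to checkers:

```python
import functools
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def untyped_passthrough(fn):  # callers of the result see an untyped Any
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def typed_passthrough(fn: Callable[P, R]) -> Callable[P, R]:  # signature kept
    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        return fn(*args, **kwargs)
    return wrapper


@typed_passthrough
def lerp(a: float, b: float, t: float) -> float:
    return a + (b - a) * t

# A checker still sees lerp as (float, float, float) -> float; with the
# untyped decorator it would degrade to Any.
```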

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Li-Huai (Allan) Lin
99d9b369f4 [Optim] Support tensor lr for all optimizers and check it is 1-element (#131065)
Fixes: #130980
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131065
Approved by: https://github.com/janeyx99
2024-07-23 04:27:05 +00:00
Sahdev Zala
9795dba1e0 Optim package docstring fix (#129086)
Fix docstrings in various files in the optim package. This is the last remaining fix for issue #112593.

The fix can be verified by running `pydocstyle <path-to-file> --count`.

Fixes #112593

Related #128248

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129086
Approved by: https://github.com/janeyx99
2024-06-21 14:30:53 +00:00
PyTorch MergeBot
90bb510ece Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit 348b181a97.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/clee2000 due to sorry I think https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456 is still relevant, I will reach out to them to see what needs to be done in internal to get this remerged ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2159248859))
2024-06-10 20:44:42 +00:00
Aaron Orenstein
27f9d3b0a1 Flip default value for mypy disallow_untyped_defs [8/11] (#127845)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127845
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843, #127844
2024-06-08 18:49:56 +00:00
Xuehai Pan
348b181a97 Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007
2024-06-08 15:25:03 +00:00
PyTorch MergeBot
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
Xuehai Pan
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
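
A small sketch of the two styles (illustrative function and message, assuming `typing_extensions >= 4.5` for `deprecated`):

```python
import warnings

from typing_extensions import deprecated


@deprecated(
    "old_entry_point() is deprecated; use new_entry_point() instead.",
    category=FutureWarning,
)
def old_entry_point() -> None:
    ...


def legacy_path() -> None:
    # Where the decorator does not fit, keep an explicit category.
    warnings.warn(
        "legacy_path() is deprecated; use new_path() instead.",
        FutureWarning,
        stacklevel=2,
    )
```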

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
feifan
22712ba5c5 Radam support the flag for "maximize" (#126765)
Fixes [#126642](https://github.com/pytorch/pytorch/issues/126642)

I referenced the `maximize` implementation in `Adam` and added `RAdam`'s `maximize` flag. If this PR is OK, I will add another PR for `NAdam`.
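
For reference, a minimal usage sketch (not from the PR):

```python
import torch
from torch import nn, optim

model = nn.Linear(4, 1)
# maximize=True flips the update direction so the objective is ascended,
# mirroring the flag already available on Adam.
opt = optim.RAdam(model.parameters(), lr=1e-3, maximize=True)

objective = model(torch.randn(8, 4)).sum()
objective.backward()
opt.step()
```
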
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126765
Approved by: https://github.com/janeyx99
2024-05-27 06:34:50 +00:00
David Chiu
1a28f731dc [optim] Merge the pyi files into py files of optimizer (#125452)
Continue the work of pytorch/pytorch#125153
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125452
Approved by: https://github.com/janeyx99
2024-05-14 18:24:50 +00:00
daitian1995
b805d3cbcb Modify device check in capturable optimizer to support more devices (#124919)
Fixes #124830

Modify the device check in the capturable optimizer to support more devices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124919
Approved by: https://github.com/janeyx99
2024-05-14 05:56:00 +00:00
Michael Lazos
0f02e0aa39 Disable dynamo on functional optims if capturable=False (#123619)
This resolves a bug in eager mode where, if an old state dict (without the capturable flag) is loaded but the original dict had the capturable flag, `state_steps` would be on CUDA yet we would take the non-capturable path. We now fall back to eager if `capturable=False`.

Current design doc and discussion: https://docs.google.com/document/d/1DmmbiaSp16CDZtGw1qzXKHFTY_0gqc0xpnBdviXq0vk/edit#heading=h.871u7bvwz7ze

Note on the actual fallback logic: TorchScript originally did not handle `*args, **kwargs` properly. After rectifying that with `functools.wraps`, there was an additional scoping bug that required the single-tensor implementation to be in the global scope at the time the fallback closure is created. I pass the single-tensor function into the `_disable_dynamo_if_unsupported` decorator to work around this bug.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123619
Approved by: https://github.com/janeyx99
2024-05-07 22:17:01 +00:00
FFFrog
791e5db705 Part 3: UFMT fix the rest files in torch/optim due to the pr-sanity-checks (#124055)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124055
Approved by: https://github.com/ezyang
ghstack dependencies: #124048, #124053, #124054
2024-04-16 03:22:39 +00:00
Jane Xu
24821fec26 Add RAdam capturable API for forloop (#121260)
Implementation thanks to @MarouaneMaatouk in https://github.com/pytorch/pytorch/pull/118697, though I've since cleaned it up a lot to save perf on the rect < 5 eager case. It also just looks better now :) Added tests and the cudagraph health check.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121260
Approved by: https://github.com/mlazos
2024-03-08 00:00:30 +00:00
Jane Xu
b5ba80828f [optim] Rectify capturable testing and fix bugs! (#118326)
This PR fixes several bugs, listed in priority:
1. `load_state_dict` with a nontensor step was incorrect for capturable and fused implementations since we don't create the tensors on the right device in `__setstate__`. This has been fixed.
2. The most recently added capturable implementations forgot the check that all tensors should be on CUDA for eager. We've now added those checks.
3. The most recent change in Adamax only adds capturable for foreach but will silently be incorrect for forloop/single-tensor. I've added erroring and modified testing with many many many skips for that. Honestly, my preference after this PR has only been further cemented that we should just do the single-tensor and multi-tensor capturable implementations together in the future. @mlazos
4. The conditional for adding cuda-supported configs for the optimizer infos was incorrect! So we hadn't been testing capturable! This also stands rectified and was the trigger for this PR in the first place.
5. In a similar way, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was incorrect sometimes as well. This has also been corrected.

The following is not a bug, but is just something to make life simpler by not needing to handle Nones: `optim_input_funcs` must now mandatorily take in a `device`, which could be a string or a torch.device.

Details for posterity:
4. Running the test_foreach_matches_forloop test and printing the configs that get printed yields capturable getting included, which is correct.
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```
5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run is also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
2024-02-02 19:13:00 +00:00
PyTorch MergeBot
2964170f3a Revert "[optim] Rectify capturable testing and fix bugs! (#118326)"
This reverts commit d947b9d500.

Reverted https://github.com/pytorch/pytorch/pull/118326 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it looks like there are some relevant failures in trunk d947b9d500, may be a land race ([comment](https://github.com/pytorch/pytorch/pull/118326#issuecomment-1923125676))
2024-02-02 07:08:14 +00:00
Jane Xu
d947b9d500 [optim] Rectify capturable testing and fix bugs! (#118326)
This PR fixes several bugs, listed in priority:
1. `load_state_dict` with a nontensor step was incorrect for capturable and fused implementations since we don't create the tensors on the right device in `__setstate__`. This has been fixed.
2. The most recently added capturable implementations forgot the check that all tensors should be on CUDA for eager. We've now added those checks.
3. The most recent change in Adamax only adds capturable for foreach but will silently be incorrect for forloop/single-tensor. I've added erroring and modified testing with many many many skips for that. Honestly, my preference after this PR has only been further cemented that we should just do the single-tensor and multi-tensor capturable implementations together in the future. @mlazos
4. The conditional for adding cuda-supported configs for the optimizer infos was incorrect! So we hadn't been testing capturable! This also stands rectified and was the trigger for this PR in the first place.
5. In a similar way, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was incorrect sometimes as well. This has also been corrected.

The following is not a bug, but is just something to make life simpler by not needing to handle Nones: `optim_input_funcs` must now mandatorily take in a `device`, which could be a string or a torch.device.

Details for posterity:
4. Running the test_foreach_matches_forloop test and printing the configs that get printed yields capturable getting included, which is correct.
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```
5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run is also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
2024-02-02 02:02:58 +00:00
Michael Lazos
800e2e823f Add compilable foreach RAdam support (#117912)
Fixes https://github.com/pytorch/pytorch/issues/117807

This brings the number of supported optimizers with `torch.compile` to 11/13 (!)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117912
Approved by: https://github.com/janeyx99
2024-01-27 04:32:27 +00:00
Jane Xu
924f1b841a [optim] Allow torch.float64 scalars for forloop + foreach implementations (#115841)
Should allow for use cases mentioned in #110940

This would allow scalars to also be float64s in the foreach implementation. The fused implementation would still create a float32 step on Adam and AdamW. This PR also does NOT worry about performance and is mainly for enablement.

Next steps:
- Relax the constraint on fused adam(w) and allow torch.float64 scalars there
- Allow _performant_ mixed dtypes in foreach (a bigger project in itself).

This PR will conflict with my other PRs; I will figure out a landing order.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115841
Approved by: https://github.com/albanD
2023-12-27 09:13:49 +00:00