1. Prevents unintended aliasing of `self._last_lr`/`get_last_lr(...)` with `group["lr"]` when `group["lr"]` is a tensor.
2. Prevents unintended aliasing of `LRScheduler.base_lrs` with the `group["initial_lr"]`s.
3. Updates `test/optim/test_lrscheduler.py` to test tensor LRs.
4. Changes type annotations for `_last_lr`, `get_last_lr()`, `base_lrs`, `get_lr()`, and `_get_closed_form_lr()` from `list[float]` to `list[float | Tensor]`; adds documentation.
Fixes #163103
LR schedulers can behave in unexpected ways when using a tensor LR due to patterns like this:
```python
self._last_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]
```
This PR adds a helper to address this:
```python
def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
    """Create a list containing group[key] for each optimizer param_group.

    Prevents aliasing when group[key] could be a Tensor.
    Raises a KeyError when group[key] does not exist.
    """
    return [
        group[key].clone() if isinstance(group[key], Tensor) else group[key]
        for group in optimizer.param_groups
    ]
```
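For illustration only (this snippet is not part of the PR), here is a minimal sketch of the aliasing problem and of what the cloning in the helper buys us, assuming an optimizer constructed with a tensor LR:
```python
import torch
from torch import Tensor
from torch.optim import SGD

opt = SGD([torch.zeros(3, requires_grad=True)], lr=torch.tensor(0.1))  # tensor LR

# Naive pattern: the list holds the very same tensor object as group["lr"].
last_lr = [group["lr"] for group in opt.param_groups]
opt.param_groups[0]["lr"].fill_(0.05)
print(last_lr[0])  # tensor(0.0500) -- silently changed along with the group

# Cloned pattern (what _param_groups_val_list does): no aliasing.
last_lr = [
    g["lr"].clone() if isinstance(g["lr"], Tensor) else g["lr"]
    for g in opt.param_groups
]
opt.param_groups[0]["lr"].fill_(0.01)
print(last_lr[0])  # still tensor(0.0500) -- unaffected by the later update
```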
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163120
Approved by: https://github.com/janeyx99
Prevents edge cases in SequentialLR and ReduceLROnPlateau which could corrupt learning rates or trigger recompilation.
Supersedes #162360. Fixes #162359. Fixes #163093.
While putting #162360 together, I noticed the class of issue I was fixing (i.e. unintended aliasing in lr_schedulers when using Tensor lrs) appeared in several other places. @janeyx99 suggested I put together a follow-up PR.
There are several bugs resembling the one fixed in #162360. I added a helper to fix these:
```python
def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor):
    """Set param_group[key] to val without aliasing or assignment when they're both tensors.

    Raises a KeyError if param_group[key] does not exist.
    """
    if isinstance(param_group[key], Tensor):
        param_group[key].fill_(_to_scalar(val))
    else:
        param_group[key] = val
```
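As a rough usage illustration (again not code from the PR, and with a simplified inline stand-in for the helper so the snippet is self-contained): the tensor LR stays the same object and is updated in place, whereas plain assignment rebinds the key.
```python
import torch
from torch import Tensor
from torch.optim import SGD

def _update_param_group_val_demo(param_group, key, val):
    # Simplified stand-in; the real helper goes through _to_scalar(val).
    if isinstance(param_group[key], Tensor):
        param_group[key].fill_(val)
    else:
        param_group[key] = val

opt = SGD([torch.zeros(3, requires_grad=True)], lr=torch.tensor(0.1))
group = opt.param_groups[0]
lr_handle = group["lr"]  # e.g. a reference captured elsewhere or baked into a compiled graph

_update_param_group_val_demo(group, "lr", 0.05)
assert group["lr"] is lr_handle            # same tensor object, new value

group["lr"] = 0.05                         # the buggy pattern: rebinds the key to a float
assert not isinstance(group["lr"], Tensor)
```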
I applied it to fix bugs in `SequentialLR.__init__` and `LRScheduler._update_lr`. I also added it to `CyclicLR.__init__`, which used an equivalent pattern, and to `CosineAnnealingWarmRestarts.step`, which *should* have had a similar issue:
```python
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
    param_group["lr"] = lr
```
But it did not, because `get_lr()` actually returns tensors when using a tensor LR (despite its `list[float]` return type annotation). Relying on this propagation seems fragile, so I conservatively applied the helper here as well. I'll be fixing the type annotations and several related issues in follow-up PRs built off of this one.
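A quick way to see the propagation mentioned above (exact behavior may vary by version, and `CosineAnnealingWarmRestarts` is just one example):
```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

opt = SGD([torch.zeros(3, requires_grad=True)], lr=torch.tensor(0.1))
sched = CosineAnnealingWarmRestarts(opt, T_0=10)

# Despite the list[float] annotation, the computed LRs are tensors here,
# because they are derived arithmetically from the tensor base LR.
# (Calling get_lr() outside step() emits a UserWarning but still returns values.)
print(type(sched.get_lr()[0]))  # <class 'torch.Tensor'>
```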
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163098
Approved by: https://github.com/janeyx99
## Summary
This PR updates the docstring for `CosineAnnealingLR` to accurately reflect its recursive learning rate schedule. The previous docstring displayed only the SGDR closed-form expression, which doesn't match the actual recursive implementation in code.
Changes:
- Added the recursive update formula used in `get_lr()`
- Retained the original closed-form SGDR expression for reference
- Clarified that warm restarts are not implemented in this scheduler
This addresses confusion raised in issue #152081.
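For reference, the two expressions in question look roughly like this (paraphrased from the updated docstring; treat the rendered documentation as authoritative):
```latex
% Recursive update actually applied by get_lr() at step t (no warm restarts):
\eta_{t+1} = \eta_{\min} + (\eta_t - \eta_{\min})
    \frac{1 + \cos\left(\frac{(T_{cur} + 1)\pi}{T_{\max}}\right)}
         {1 + \cos\left(\frac{T_{cur}\pi}{T_{\max}}\right)}

% Closed-form SGDR expression retained for reference:
\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_{\max} - \eta_{\min})
    \left(1 + \cos\left(\frac{T_{cur}}{T_{\max}}\pi\right)\right)
```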
## Related issue
[#152081](https://github.com/pytorch/pytorch/issues/152081)
## Testing
Doc-only change. Ran pre-commit to verify formatting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152936
Approved by: https://github.com/janeyx99
When stub files (`*.pyi`) were removed from `optim` (#125556, #125452), some types that existed are no longer available. This pull request adds them back.
Just for reference, these types are used by `pytorch-lightning`'s `LightningCLI`. Command-line interfaces are created automatically, and having type hints makes them nicer.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136185
Approved by: https://github.com/janeyx99
Fix docstrings in Learning Rate Scheduler.
The fix can be verified by running `pydocstyle path-to-file --count`.
Related #112593
**BEFORE the PR:**
```
pydocstyle torch/optim/lr_scheduler.py --count
92
```
**AFTER the PR:**
```
pydocstyle torch/optim/lr_scheduler.py --count
0
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128679
Approved by: https://github.com/janeyx99
I'm currently locked into jsonargparse version 4.19.0, and it complains when used in combination with LightningCLI (v2.0.8), because it cares about the types declared in Google-style docstrings. This causes a problem when it tries to work out how to cast arguments while constructing an LRScheduler instance: the docstrings declare the `verbose` parameter as a `bool`, but the default recently changed to the string `"deprecated"`, so the type should really be `bool | str`.
This PR adds a `| str` to the docstring type in each learning rate scheduler class. This will prevent jsonargparse from complaining.
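A hypothetical illustration of the kind of one-line docstring change involved (the class and surrounding text here are made up; only the `verbose` line reflects the actual edit pattern):
```python
class SomeLRScheduler:
    """Example scheduler docstring after the change.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        verbose (bool | str): If ``True``, prints a message to stdout for
            each update. Default: ``"deprecated"``.
    """
```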
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127943
Approved by: https://github.com/janeyx99
Enables LRScheduler to handle tensor LRs.
Note on test changes:
For the test modifications I just removed `itertools.product` and created two loops. This lets us create a new set of optim_inputs on each iteration so that mutations of the tensor LR do not carry over across iterations. Nothing else in those tests was modified.
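Roughly, the restructuring looks like the following simplified, self-contained sketch (names such as `make_optim_inputs` are made up here and are not the actual test helpers):
```python
import itertools
import torch

def make_optim_inputs():
    # Fresh tensors each call, so a tensor LR mutated in one case
    # cannot leak into the next.
    return [{"lr": torch.tensor(0.1)}, {"lr": torch.tensor(0.01)}]

impls = ["for-loop", "foreach"]

# Before: a single shared list fed to itertools.product, so an in-place
# change to a tensor LR carried over into later iterations.
shared_inputs = make_optim_inputs()
for optim_input, impl in itertools.product(shared_inputs, impls):
    optim_input["lr"].mul_(0.5)  # mutation persists across the product

# After: two explicit loops, regenerating the inputs per outer iteration.
for impl in impls:
    for optim_input in make_optim_inputs():
        optim_input["lr"].mul_(0.5)  # starts from a clean value every time
```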
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123753
Approved by: https://github.com/janeyx99
ghstack dependencies: #123751, #123752
Fixes https://github.com/pytorch/pytorch/issues/98921
There were two issues detected:
- `MultiStepLR`: issue is described in https://github.com/pytorch/pytorch/issues/98921, this is resolved by allowlisting `collections.Counter`
- `OneCycleLR`: `state_dict['anneal_func']` is either `<function OneCycleLR._annealing_cos at 0x7f364186f5b0>` or
`<function OneCycleLR._annealing_linear at 0x7f39aa483640>` depending on the `anneal_func` kwarg.
This leads to `WeightsUnpickler error: Unsupported class __builtin__.getattr` from the `weights_only` Unpickler.
Fixed the above in a BC-compatible manner by adding `OneCycleLR._anneal_func_type` as a string attribute and removing `OneCycleLR.anneal_func`.
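A rough, self-contained sketch of the approach (simplified; aside from `_anneal_func_type`, `_annealing_cos`, and `_annealing_linear`, the names here are stand-ins rather than the real `OneCycleLR` API):
```python
import math

class _AnnealFuncDemo:
    """Simplified stand-in: store a picklable string instead of a bound
    function, and dispatch to the annealing implementation at call time."""

    def __init__(self, anneal_strategy: str = "cos"):
        # A plain string survives state_dict()/load_state_dict() under the
        # weights_only unpickler, unlike a function object.
        self._anneal_func_type = anneal_strategy

    @staticmethod
    def _annealing_cos(start: float, end: float, pct: float) -> float:
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    @staticmethod
    def _annealing_linear(start: float, end: float, pct: float) -> float:
        return (end - start) * pct + start

    def _anneal_func(self, start: float, end: float, pct: float) -> float:
        if self._anneal_func_type == "cos":
            return self._annealing_cos(start, end, pct)
        return self._annealing_linear(start, end, pct)
```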
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123775
Approved by: https://github.com/albanD, https://github.com/malfet