pytorch/torch/optim/sgd.py
ErezYosef 197601eeea Add Support for Tracking Parameter Names (named_parameters) in Optimizer State Dict (#134107)
A proposal addressing Issue #1489: **Optimizer should track parameter names and not id.**

(also mentioned here: [[RFC] Introducing FQNs/clarity eyeglasses to optim state_dict](https://dev-discuss.pytorch.org/t/rfc-introducing-fqns-clarity-to-optim-state-dict/1552))

## Summary
This PR introduces a backward-compatible enhancement where optimizers track parameter names instead of just their id.
Optimizers can be initialized with `named_parameters()` as:
```python
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
```
This allows for greater clarity and ease when handling optimizers, as the parameters' names are preserved within the optimizer’s `state_dict` as:
```
state_dict =
{
    'state': {
    0: {'momentum_buffer': tensor(...), ...},
    1: {'momentum_buffer': tensor(...), ...},
    },
    'param_groups': [
        {
        'lr': 0.01,
        'weight_decay': 0,
        ...
        'params': [0,1],
        'param_names': ['layer.weight', 'layer.bias']  (optional)
        }
    ]
}
```
Loading `state_dict` is not changed (backward-compatible) and the `param_names` key will be ignored.

## Key Features
#### Named Parameters in Optimizer Initialization:
Optimizers can accept the output of `model.named_parameters()` during initialization, allowing them to store parameter names directly.
#### Parameter Names in `state_dict`:
The parameter names are saved as a list in the optimizer’s `state_dict` with key `param_names`, alongside the `params` indices, ensuring seamless tracking of both names and parameters.

## Backward Compatibility
#### No Breaking Changes:
This change is fully backward-compatible. The added `param_names` key in the optimizer's `state_dict` is ignored when loading a state to the optimizer.

#### Customization with Hooks:
For more control, the loaded state_dict can be modified using a custom `register_load_state_dict_pre_hook`, providing flexibility for different design needs.

## Documentation Updates
Please refer to the documentation changes for more details on how this feature is implemented and how it can be used effectively.

## Solution Example:

A suggested solution to the problem mentioned in #1489, for the same parameters but in a different order.
The following `register_load_state_dict_pre_hook` should be added to the optimizer before loading to enable loading the state dict :
```python
def adapt_state_dict_ids(optimizer, state_dict):
    # assuming a single param group.
    current_state_group = optimizer.state_dict()['param_groups'][0]
    loaded_state_group = state_dict['param_groups'][0]

    # same number of params, same names, only different ordering
    current_state_name_to_id_mapping = {}  # mapping --  param_name: id
    for i, name in enumerate(current_state_group['param_names']):
        current_state_name_to_id_mapping[name] = current_state_group['params'][i]

    # changing the ids of the loaded state dict to match the order of the given state dict.
    for i, name in enumerate(current_state_group['param_names']):
        loaded_state_group['params'][i] = current_state_name_to_id_mapping[name]

    return state_dict
```
In this code, the loaded `state_dict` ids are adapted to match the order of the current optimizer `state_dict`.
Both the previous and the current optimizers must be initialized with `named_parameters()` so that the 'param_names' key is present in the dict.

### Note
This is my first contribution to PyTorch, and I wish to receive feedback or suggestions for improvement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134107
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2024-10-14 19:24:44 +00:00

514 lines
19 KiB
Python

# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""
from typing import cast, List, Optional, Union
import torch
from torch import Tensor
from .optimizer import (
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_foreach_doc,
_fused_doc,
_maximize_doc,
_params_doc,
_use_grad_for_differentiable,
DeviceDict,
Optimizer,
ParamsT,
)
__all__ = ["SGD", "sgd"]
class SGD(Optimizer): # noqa: D101
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
momentum: float = 0,
dampening: float = 0,
weight_decay: float = 0,
nesterov: bool = False,
*,
maximize: bool = False,
foreach: Optional[bool] = None,
differentiable: bool = False,
fused: Optional[bool] = None,
): # noqa: D107
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
maximize=maximize,
foreach=foreach,
differentiable=differentiable,
fused=fused,
)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)
if fused:
self._step_supports_amp_scaling = True
self._need_device_dtype_check_for_fused = True
if differentiable:
raise RuntimeError("`fused` does not support `differentiable`")
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
def __setstate__(self, state): # noqa: D105
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)
group.setdefault("maximize", False)
group.setdefault("foreach", None)
group.setdefault("differentiable", False)
group.setdefault("fused", False)
def _init_group(self, group, params, grads, momentum_buffer_list):
has_sparse_grad = False
for p in group["params"]:
if p.grad is not None:
if group["fused"] and getattr(
self, "_need_device_dtype_check_for_fused", True
):
_device_dtype_check_for_fused(p)
self._need_device_dtype_check_for_fused = False
params.append(p)
grads.append(p.grad)
if p.grad.is_sparse:
has_sparse_grad = True
if group["momentum"] != 0:
state = self.state[p]
momentum_buffer_list.append(state.get("momentum_buffer"))
return has_sparse_grad
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params: List[Tensor] = []
grads: List[Tensor] = []
momentum_buffer_list: List[Optional[Tensor]] = []
has_sparse_grad = self._init_group(
group, params, grads, momentum_buffer_list
)
sgd(
params,
grads,
momentum_buffer_list,
weight_decay=group["weight_decay"],
momentum=group["momentum"],
lr=group["lr"],
dampening=group["dampening"],
nesterov=group["nesterov"],
maximize=group["maximize"],
has_sparse_grad=has_sparse_grad,
foreach=group["foreach"],
fused=group["fused"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
)
if group["momentum"] != 0:
# update momentum_buffers in state
for p, momentum_buffer in zip(params, momentum_buffer_list):
state = self.state[p]
state["momentum_buffer"] = momentum_buffer
return loss
# The class docstring is assembled here from three concatenated string
# literals and assigned to SGD.__doc__:
#   1. a raw string holding the algorithm pseudo-code (LaTeX),
#   2. a raw f-string whose Args section interpolates the shared fragments
#      _params_doc, _maximize_doc, _foreach_doc, _differentiable_doc and
#      _fused_doc imported from .optimizer,
#   3. a raw string with the usage example and the note on how this
#      momentum formulation differs from Sutskever et al.
SGD.__doc__ = (
r"""Implements stochastic gradient descent (optionally with momentum).
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
\text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
&\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
\:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
&\hspace{10mm}\textbf{if} \: t > 1 \\
&\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
&\hspace{10mm}\textbf{else} \\
&\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
&\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
&\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
&\hspace{10mm}\textbf{else} \\[-1.ex]
&\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
&\hspace{5mm}\textbf{if} \: \textit{maximize} \\
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex]
&\hspace{5mm}\textbf{else} \\[-1.ex]
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
"""
+ rf"""
Args:
{_params_doc}
lr (float, Tensor, optional): learning rate (default: 1e-3)
momentum (float, optional): momentum factor (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
nesterov (bool, optional): enables Nesterov momentum. Only applicable
when momentum is non-zero. (default: False)
{_maximize_doc}
{_foreach_doc}
{_differentiable_doc}
{_fused_doc}
"""
+ r"""
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
.. note::
The implementation of SGD with Momentum/Nesterov subtly differs from
Sutskever et al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
.. math::
\begin{aligned}
v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
\end{aligned}
where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
parameters, gradient, velocity, and momentum respectively.
This is in contrast to Sutskever et al. and
other frameworks which employ an update of the form
.. math::
\begin{aligned}
v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
p_{t+1} & = p_{t} - v_{t+1}.
\end{aligned}
The Nesterov version is analogously modified.
Moreover, the initial value of the momentum buffer is set to the
gradient value at the first step. This is in contrast to some other
frameworks that initialize it to all zeros.
"""
)
def sgd(
    params: List[Tensor],
    d_p_list: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: Optional[bool] = None,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.

    Selects one of three implementations and forwards all arguments to it:
    ``_multi_tensor_sgd`` when ``foreach`` is true, otherwise ``_fused_sgd``
    when ``fused`` is true, otherwise ``_single_tensor_sgd``.

    Args:
        params: parameters to update.
        d_p_list: gradients, passed to the implementation as its ``grads``
            argument.
        momentum_buffer_list: momentum buffers; entries may be ``None``.
        has_sparse_grad: forwarded unchanged to the implementation.
        foreach: request the multi-tensor implementation. If both this and
            ``fused`` are ``None``, the choice is made by
            ``_default_to_fused_or_foreach`` (called with ``use_fused=False``)
            when not scripting, and both become ``False`` when scripting.
        fused: request the fused implementation; ``None`` is treated as
            ``False`` once defaults are resolved.
        grad_scale: forwarded unchanged to the implementation.
        found_inf: forwarded unchanged to the implementation.
        weight_decay, momentum, lr, dampening, nesterov, maximize:
            hyperparameters forwarded unchanged to the implementation.

    Raises:
        RuntimeError: if ``foreach`` or ``fused`` is true while running
            under ``torch.jit`` scripting.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if foreach is None and fused is None:
        # why must we be explicit about an if statement for torch.jit.is_scripting here?
        # because JIT can't handle Optionals nor fancy conditionals when scripting
        if not torch.jit.is_scripting():
            fused, foreach = _default_to_fused_or_foreach(
                params, differentiable=False, use_fused=False
            )
        else:
            foreach = False
            fused = False
    # Only one of the two was specified: the unspecified one counts as disabled.
    if foreach is None:
        foreach = False
    if fused is None:
        fused = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # foreach takes precedence over fused when both are set.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    elif fused and not torch.jit.is_scripting():
        func = _fused_sgd
    else:
        func = _single_tensor_sgd

    func(
        params,
        d_p_list,
        momentum_buffer_list,
        weight_decay=weight_decay,
        momentum=momentum,
        lr=lr,
        dampening=dampening,
        nesterov=nesterov,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
def _single_tensor_sgd(
params: List[Tensor],
grads: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
weight_decay: float,
momentum: float,
lr: float,
dampening: float,
nesterov: bool,
maximize: bool,
has_sparse_grad: bool,
):
assert grad_scale is None and found_inf is None
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)
if momentum != 0:
buf = momentum_buffer_list[i]
if buf is None:
buf = torch.clone(grad).detach()
momentum_buffer_list[i] = buf
else:
buf.mul_(momentum).add_(grad, alpha=1 - dampening)
if nesterov:
grad = grad.add(buf, alpha=momentum)
else:
grad = buf
param.add_(grad, alpha=-lr)
def _multi_tensor_sgd(
params: List[Tensor],
grads: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
weight_decay: float,
momentum: float,
lr: float,
dampening: float,
nesterov: bool,
maximize: bool,
has_sparse_grad: bool,
):
assert grad_scale is None and found_inf is None
if len(params) == 0:
return
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], with_indices=True # type: ignore[list-item]
)
for (
device_params_,
device_grads_,
device_momentum_buffer_list,
), indices in grouped_tensors.values():
device_params: List[Tensor] = cast(List[Tensor], device_params_)
device_grads: List[Tensor] = cast(List[Tensor], device_grads_)
device_has_sparse_grad = has_sparse_grad and any(
grad.is_sparse for grad in device_grads
)
if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
if weight_decay != 0:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
if momentum != 0:
bufs: List[Tensor] = []
all_states_with_momentum_buffer = True
for i in range(len(device_momentum_buffer_list)):
if device_momentum_buffer_list[i] is None:
all_states_with_momentum_buffer = False
break
else:
bufs.append(cast(Tensor, device_momentum_buffer_list[i]))
if all_states_with_momentum_buffer:
torch._foreach_mul_(bufs, momentum)
torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
else:
bufs = []
for i in range(len(device_momentum_buffer_list)):
if device_momentum_buffer_list[i] is None:
buf = device_momentum_buffer_list[i] = momentum_buffer_list[
indices[i]
] = torch.clone(device_grads[i]).detach()
else:
buf = cast(Tensor, device_momentum_buffer_list[i])
buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)
bufs.append(buf)
if nesterov:
torch._foreach_add_(device_grads, bufs, alpha=momentum)
else:
device_grads = bufs
if not device_has_sparse_grad:
# handle internal item() call if lr is a tensor
if isinstance(lr, torch.Tensor) and torch._utils.is_compiling():
grads_x_lr = torch._foreach_mul(device_grads, -lr)
torch._foreach_add_(device_params, grads_x_lr)
else:
torch._foreach_add_(device_params, device_grads, alpha=-lr)
else:
# foreach APIs don't support sparse
for i in range(len(device_params)):
device_params[i].add_(device_grads[i], alpha=-lr)
def _fused_sgd(
params: List[Tensor],
grads: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
weight_decay: float,
momentum: float,
lr: float,
dampening: float,
nesterov: bool,
maximize: bool,
has_sparse_grad: bool,
) -> None:
if not params:
return
if has_sparse_grad:
raise RuntimeError("`_fused_sgd` does not support sparse gradients")
grad_scale_dict: DeviceDict = (
{grad_scale.device: grad_scale} if grad_scale is not None else {}
)
found_inf_dict: DeviceDict = (
{found_inf.device: found_inf} if found_inf is not None else {}
)
no_momentum_buffer = momentum == 0
is_first_step = (
all(t is None for t in momentum_buffer_list) and not no_momentum_buffer
)
if is_first_step:
for i, g in enumerate(grads):
momentum_buffer_list[i] = torch.empty_like(g)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], with_indices=False # type: ignore[list-item]
)
for (device, _), (
(device_params_, device_grads_, device_momentum_buffer_list),
_,
) in grouped_tensors.items():
device_params: List[Tensor] = cast(List[Tensor], device_params_)
device_grads: List[Tensor] = cast(List[Tensor], device_grads_)
device_grad_scale, device_found_inf = None, None
if grad_scale is not None:
device_grad_scale = grad_scale_dict.setdefault(
device, grad_scale.to(device)
)
if found_inf_dict is not None and found_inf is not None:
device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device))
torch._fused_sgd_(
device_params,
device_grads,
[]
if no_momentum_buffer
else cast(List[Tensor], device_momentum_buffer_list),
weight_decay=weight_decay,
momentum=momentum,
lr=lr,
dampening=dampening,
nesterov=nesterov,
maximize=maximize,
is_first_step=is_first_step,
grad_scale=device_grad_scale,
found_inf=device_found_inf,
)