mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Fully fixes: https://github.com/pytorch/pytorch/issues/110506

Depends: https://github.com/pytorch/pytorch/pull/110607

Potential merge conflicts:
- https://github.com/pytorch/pytorch/pull/110339
- https://github.com/pytorch/pytorch/pull/110345
- https://github.com/pytorch/pytorch/pull/110454

Related:
- https://github.com/pytorch/pytorch/issues/110606 (we can apply the improvements here orthogonally to the complex support)

### Results

Benchmark: 100 params.

Breakdowns (float32, dynamo):
```
Adagrad: this PR: 4.4s, main: 8.8s
Adam: this PR: 2.1s, main: 9.8s
AdamW: this PR: 2.5s, main: 8.2s
ASGD: this PR: 3.1s, main: 8.5s
RMSProp: this PR: 1.3s, main: 4.2s
RProp: this PR: 6.7s, main: 14.9s
```

Notes:
1. Adagrad is still slow due to the `_get_value` list comprehension. This can be fixed in https://github.com/pytorch/pytorch/pull/110339/files by utilizing the capturable path.
2. Adamax is not actually compiled (it is currently disabled).
3. Inductor compile time is quite variable. We calculate the dynamo time by subtracting `call_user_compiler` from the `compile_inner` timing.
<details>

This PR:
```
Adagrad (torch.float32): 28.47496461868286s
Adagrad (torch.complex64): 29.379547357559204s
Adam (torch.float32): 17.334211587905884s
Adam (torch.complex64): 29.637500524520874s
Adamax (torch.float32): 2.4749321937561035s
Adamax (torch.complex64): 3.1997995376586914s
AdamW (torch.float32): 18.06532859802246s
AdamW (torch.complex64): 28.25661015510559s
ASGD (torch.float32): 23.70255398750305s
ASGD (torch.complex64): 25.33756995201111s
RMSprop (torch.float32): 7.964028596878052s
RMSprop (torch.complex64): 12.909599781036377s
Rprop (torch.float32): 30.512362003326416s
Rprop (torch.complex64): 44.74405765533447s
```

Main:
```
Adagrad (torch.float32): 26.919506072998047s
Adagrad (torch.complex64): 35.190622091293335s
Adam (torch.float32): 25.715000867843628s
Adam (torch.complex64): 24.17716670036316s
Adamax (torch.float32): 2.4404726028442383s
Adamax (torch.complex64): 3.3538928031921387s
AdamW (torch.float32): 25.2022807598114s
AdamW (torch.complex64): 28.915700912475586s
ASGD (torch.float32): 24.108731985092163s
ASGD (torch.complex64): 26.589075088500977s
RMSprop (torch.float32): 10.781344175338745s
RMSprop (torch.complex64): 15.136352777481079s
Rprop (torch.float32): 42.46482181549072s
Rprop (torch.complex64): 48.28277635574341s
```

It seems that this doesn't help the complex case by much (but that's not the majority case). torch.float32 is generally positive; when it does not show a drastic improvement or regresses, it is due to inductor variance (confirmed by manually inspecting the logs).
</details>

### Benchmark Script

```python
import torch
import time
from torch.optim import Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop

OPTIMS = [Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop]
DTYPES = [torch.float, torch.cfloat]

NUM_PARAMS = 100
kwargs = { "lr": 0.01, "foreach": True }
summary = []

for optim_cls in OPTIMS:
    for dtype in DTYPES:
        torch._dynamo.reset()
        # torch._inductor.metrics.reset()
        input = torch.ones([10, 10], dtype=dtype, device="cuda:0")
        model = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, dtype=dtype, device="cuda:0") for _ in range(NUM_PARAMS)]
        )

        model(input).sum().abs().backward()
        opt_compiled = optim_cls(model.parameters(), **kwargs)
        compiled_step = torch.compile(opt_compiled.step)

        with torch.set_grad_enabled(False):
            start_time = time.time()
            compiled_step()
            summary.append(f"{optim_cls.__name__} ({dtype}): {time.time() - start_time}s")

        print(optim_cls, kwargs, dtype, torch._dynamo.utils.compile_times())

for s in summary:
    print(s)
```

CC: @janeyx99 @mlazos

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110613
Approved by: https://github.com/janeyx99
376 lines
13 KiB
Python
376 lines
13 KiB
Python
import torch
|
|
from torch import Tensor
|
|
|
|
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
|
|
_default_to_fused_or_foreach, _differentiable_doc, _foreach_doc, _maximize_doc)
|
|
from typing import List, Optional
|
|
|
|
__all__ = ["Adagrad", "adagrad"]
|
|
|
|
|
|
class Adagrad(Optimizer):
    # NOTE: the user-facing class docstring is attached after the class body
    # through ``Adagrad.__doc__ = ...`` so it can interpolate shared snippets.

    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        eps=1e-10,
        foreach: Optional[bool] = None,
        *,
        maximize: bool = False,
        differentiable: bool = False,
    ):
        """Validate the hyperparameters and create the per-parameter state.

        Raises:
            ValueError: if ``lr``, ``lr_decay``, ``weight_decay``,
                ``initial_accumulator_value`` or ``eps`` fails the ``0.0 <= value`` check.
        """
        # ``not 0.0 <= x`` (rather than ``x < 0.0``) also rejects NaN.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= lr_decay:
            raise ValueError(f"Invalid lr_decay value: {lr_decay}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= initial_accumulator_value:
            raise ValueError(
                f"Invalid initial_accumulator_value value: {initial_accumulator_value}"
            )
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")

        defaults = {
            "lr": lr,
            "lr_decay": lr_decay,
            "eps": eps,
            "weight_decay": weight_decay,
            "initial_accumulator_value": initial_accumulator_value,
            "foreach": foreach,
            "maximize": maximize,
            "differentiable": differentiable,
        }
        super().__init__(params, defaults)

        # State is created eagerly here rather than on the first call to step().
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["step"] = torch.tensor(0.0)
                if torch.is_complex(p):
                    # Seed the real and imaginary parts with the same value.
                    init_value = complex(initial_accumulator_value, initial_accumulator_value)
                else:
                    init_value = initial_accumulator_value
                state["sum"] = torch.full_like(
                    p, init_value, memory_format=torch.preserve_format
                )

    def __setstate__(self, state):
        """Restore the optimizer, filling in options and step format missing from ``state``."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)

        per_param_states = list(self.state.values())
        # Only the first entry is inspected to decide whether steps are already tensors.
        if per_param_states and torch.is_tensor(per_param_states[0]["step"]):
            return
        for entry in per_param_states:
            entry["step"] = torch.tensor(float(entry["step"]))

    def share_memory(self):
        """Call ``share_memory_()`` on the ``sum`` accumulator of every parameter."""
        for group in self.param_groups:
            for p in group["params"]:
                self.state[p]["sum"].share_memory_()

    def _init_group(self, group, params_with_grad, grads, state_sums, state_steps):
        """Append the tensors of every parameter in ``group`` that has a gradient.

        The four list arguments are extended in place with the parameter, its
        gradient, its ``sum`` state and its ``step`` state respectively.

        Returns:
            bool: ``True`` if at least one collected gradient is sparse.
        """
        has_sparse_grad = False
        for p in group["params"]:
            if p.grad is None:
                continue
            if p.grad.is_sparse:
                has_sparse_grad = True
            params_with_grad.append(p)
            grads.append(p.grad)
            state = self.state[p]
            state_sums.append(state["sum"])
            state_steps.append(state["step"])

        return has_sparse_grad

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            # The closure is evaluated with gradients enabled.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad, grads, state_sums, state_steps = [], [], [], []
            has_sparse_grad = self._init_group(
                group, params_with_grad, grads, state_sums, state_steps
            )

            adagrad(
                params_with_grad,
                grads,
                state_sums,
                state_steps,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                lr_decay=group["lr_decay"],
                eps=group["eps"],
                has_sparse_grad=has_sparse_grad,
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
            )

        return loss
|
|
|
|
|
|
Adagrad.__doc__ = r"""Implements Adagrad algorithm.
|
|
|
|
.. math::
|
|
\begin{aligned}
|
|
&\rule{110mm}{0.4pt} \\
|
|
&\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
|
|
\text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
|
|
&\hspace{12mm} \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)}\\
|
|
&\textbf{initialize} : state\_sum_0 \leftarrow 0 \\[-1.ex]
|
|
&\rule{110mm}{0.4pt} \\
|
|
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
|
|
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
|
|
&\hspace{5mm} \tilde{\gamma} \leftarrow \gamma / (1 +(t-1) \eta) \\
|
|
&\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
|
|
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
|
|
&\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\
|
|
&\hspace{5mm}\theta_t \leftarrow
|
|
\theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon} \\
|
|
&\rule{110mm}{0.4pt} \\[-1.ex]
|
|
&\bf{return} \: \theta_t \\[-1.ex]
|
|
&\rule{110mm}{0.4pt} \\[-1.ex]
|
|
\end{aligned}
|
|
|
|
For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning
|
|
and Stochastic Optimization`_.
|
|
""" + fr"""
|
|
Args:
|
|
params (iterable): iterable of parameters to optimize or dicts defining
|
|
parameter groups
|
|
lr (float, optional): learning rate (default: 1e-2)
|
|
lr_decay (float, optional): learning rate decay (default: 0)
|
|
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
|
eps (float, optional): term added to the denominator to improve
|
|
numerical stability (default: 1e-10)
|
|
{_foreach_doc}
|
|
{_maximize_doc}
|
|
{_differentiable_doc}
|
|
|
|
.. _Adaptive Subgradient Methods for Online Learning and Stochastic
|
|
Optimization: http://jmlr.org/papers/v12/duchi11a.html
|
|
|
|
"""
|
|
|
|
|
|
def adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting these as kwargs for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.

    Args:
        params: parameters to update in place.
        grads: gradients, one per entry of ``params``.
        state_sums: accumulator tensors, one per parameter, updated in place.
        state_steps: step counters, one tensor per parameter, updated in place.
        has_sparse_grad: whether any gradient is sparse. Defaults to ``False``;
            it is only ever used in a truthiness context, so this is equivalent
            to the former ``None`` default while matching the ``bool`` annotation.
        foreach: force (``True``) or forbid (``False``) the multi-tensor
            implementation; ``None`` lets ``_default_to_fused_or_foreach`` decide.
        differentiable: forwarded to the selected implementation.
        lr, weight_decay, lr_decay, eps, maximize: hyperparameters forwarded
            unchanged to the selected implementation.

    Raises:
        RuntimeError: if ``state_steps`` contains a non-tensor, or if ``foreach``
            is requested while scripting with ``torch.jit.script``.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # The scripting check is repeated here, as in the original dispatch, even
    # though the raise above already covers the eager case.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adagrad
    else:
        func = _single_tensor_adagrad

    func(
        params,
        grads,
        state_sums,
        state_steps,
        lr=lr,
        weight_decay=weight_decay,
        lr_decay=lr_decay,
        eps=eps,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        differentiable=differentiable,
    )
|
|
|
|
|
|
def _make_sparse(grad, grad_indices, values):
|
|
size = grad.size()
|
|
if grad_indices.numel() == 0 or values.numel() == 0:
|
|
return torch.empty_like(grad)
|
|
return torch.sparse_coo_tensor(grad_indices, values, size)
|
|
|
|
|
|
def _single_tensor_adagrad(
|
|
params: List[Tensor],
|
|
grads: List[Tensor],
|
|
state_sums: List[Tensor],
|
|
state_steps: List[Tensor],
|
|
*,
|
|
lr: float,
|
|
weight_decay: float,
|
|
lr_decay: float,
|
|
eps: float,
|
|
has_sparse_grad: bool,
|
|
maximize: bool,
|
|
differentiable: bool,
|
|
):
|
|
|
|
for (param, grad, state_sum, step_t) in zip(params, grads, state_sums, state_steps):
|
|
# update step
|
|
step_t += 1
|
|
step = _get_value(step_t)
|
|
grad = grad if not maximize else -grad
|
|
|
|
if weight_decay != 0:
|
|
if grad.is_sparse:
|
|
raise RuntimeError(
|
|
"weight_decay option is not compatible with sparse gradients"
|
|
)
|
|
grad = grad.add(param, alpha=weight_decay)
|
|
|
|
clr = lr / (1 + (step - 1) * lr_decay)
|
|
|
|
if grad.is_sparse:
|
|
grad = grad.coalesce() # the update is non-linear so indices must be unique
|
|
grad_indices = grad._indices()
|
|
grad_values = grad._values()
|
|
|
|
state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
|
|
std = state_sum.sparse_mask(grad)
|
|
std_values = std._values().sqrt_().add_(eps)
|
|
param.add_(
|
|
_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr
|
|
)
|
|
else:
|
|
is_complex = torch.is_complex(param)
|
|
if is_complex:
|
|
grad = torch.view_as_real(grad)
|
|
state_sum = torch.view_as_real(state_sum)
|
|
param = torch.view_as_real(param)
|
|
state_sum.addcmul_(grad, grad, value=1)
|
|
if differentiable:
|
|
std = state_sum.sqrt() + eps
|
|
else:
|
|
std = state_sum.sqrt().add_(eps)
|
|
param.addcdiv_(grad, std, value=-clr)
|
|
if is_complex:
|
|
param = torch.view_as_complex(param)
|
|
state_sum = torch.view_as_complex(state_sum)
|
|
|
|
|
|
def _multi_tensor_adagrad(
|
|
params: List[Tensor],
|
|
grads: List[Tensor],
|
|
state_sums: List[Tensor],
|
|
state_steps: List[Tensor],
|
|
*,
|
|
lr: float,
|
|
weight_decay: float,
|
|
lr_decay: float,
|
|
eps: float,
|
|
has_sparse_grad: bool,
|
|
maximize: bool,
|
|
differentiable: bool,
|
|
):
|
|
|
|
assert not differentiable, "_foreach ops don't support autograd"
|
|
|
|
# Foreach functions will throw errors if given empty lists
|
|
if len(params) == 0:
|
|
return
|
|
|
|
grouped_tensorlists = Optimizer._group_tensors_by_device_and_dtype([params, grads, state_sums, state_steps])
|
|
for ((device_params, device_grads, device_state_sums, device_state_steps), _) in grouped_tensorlists.values():
|
|
device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads)
|
|
|
|
if device_has_sparse_grad:
|
|
_single_tensor_adagrad(
|
|
device_params,
|
|
device_grads,
|
|
device_state_sums,
|
|
device_state_steps,
|
|
lr=lr,
|
|
weight_decay=weight_decay,
|
|
lr_decay=lr_decay,
|
|
eps=eps,
|
|
has_sparse_grad=True,
|
|
maximize=False,
|
|
differentiable=differentiable,
|
|
)
|
|
continue
|
|
|
|
if maximize:
|
|
device_grads = torch._foreach_neg(device_grads)
|
|
|
|
# Handle complex parameters
|
|
for i in range(len(device_params)):
|
|
if torch.is_complex(device_params[i]):
|
|
device_params[i] = torch.view_as_real(device_params[i])
|
|
device_grads[i] = torch.view_as_real(device_grads[i])
|
|
device_state_sums[i] = torch.view_as_real(device_state_sums[i])
|
|
|
|
# Update steps
|
|
torch._foreach_add_(device_state_steps, 1)
|
|
|
|
if weight_decay != 0:
|
|
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
|
if maximize:
|
|
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
|
else:
|
|
device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)
|
|
|
|
minus_clr = [-lr / (1 + (_get_value(step) - 1) * lr_decay) for step in device_state_steps]
|
|
|
|
torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1)
|
|
|
|
std = torch._foreach_sqrt(device_state_sums)
|
|
torch._foreach_add_(std, eps)
|
|
|
|
if weight_decay != 0 or maximize:
|
|
# Again, re-use the intermediate memory (device_grads) already allocated
|
|
torch._foreach_mul_(device_grads, minus_clr)
|
|
numerator = device_grads
|
|
else:
|
|
numerator = torch._foreach_mul(device_grads, minus_clr)
|
|
|
|
torch._foreach_addcdiv_(device_params, numerator, std)
|