pytorch/torch/optim/adamax.py
Jon Chuang df7d01aed5 perf(inductor): use for loop with shortcut in Optimizers to speedup against list comprehensions (e.g. complex conversion) (#110613)
Fully fixes: https://github.com/pytorch/pytorch/issues/110506

Depends: https://github.com/pytorch/pytorch/pull/110607
Potential merge conflicts:
- https://github.com/pytorch/pytorch/pull/110339
- https://github.com/pytorch/pytorch/pull/110345
- https://github.com/pytorch/pytorch/pull/110454

Related:
- https://github.com/pytorch/pytorch/issues/110606 (we can apply the improvements here orthogonally to the complex support)

### Results

Benchmark: 100 params.

Breakdowns (float32, dynamo):
```
Adagrad: this PR: 4.4s, main: 8.8s
Adam: this PR: 2.1s, main: 9.8s
AdamW: this PR: 2.5s, main: 8.2s
ASGD: this PR: 3.1s, main: 8.5s
RMSProp: this PR: 1.3s, main: 4.2s
RProp: this PR: 6.7s, main: 14.9s
```

Notes:
1. Adagrad is still slow because of the `_get_value` list comprehension. This can be fixed in https://github.com/pytorch/pytorch/pull/110339/files by using the capturable path.
2. Adamax is not actually compiled (it is currently disabled).
3. Inductor compile time is quite variable. We calculate the Dynamo time by subtracting the `call_user_compiler` timing from the `compile_inner` timing.

<details>

This PR:
```
Adagrad (torch.float32): 28.47496461868286s
Adagrad (torch.complex64): 29.379547357559204s
Adam (torch.float32): 17.334211587905884s
Adam (torch.complex64): 29.637500524520874s
Adamax (torch.float32): 2.4749321937561035s
Adamax (torch.complex64): 3.1997995376586914s
AdamW (torch.float32): 18.06532859802246s
AdamW (torch.complex64): 28.25661015510559s
ASGD (torch.float32): 23.70255398750305s
ASGD (torch.complex64): 25.33756995201111s
RMSprop (torch.float32): 7.964028596878052s
RMSprop (torch.complex64): 12.909599781036377s
Rprop (torch.float32): 30.512362003326416s
Rprop (torch.complex64): 44.74405765533447s
```

Main
```
Adagrad (torch.float32): 26.919506072998047s
Adagrad (torch.complex64): 35.190622091293335s
Adam (torch.float32): 25.715000867843628s
Adam (torch.complex64): 24.17716670036316s
Adamax (torch.float32): 2.4404726028442383s
Adamax (torch.complex64): 3.3538928031921387s
AdamW (torch.float32): 25.2022807598114s
AdamW (torch.complex64): 28.915700912475586s
ASGD (torch.float32): 24.108731985092163s
ASGD (torch.complex64): 26.589075088500977s
RMSprop (torch.float32): 10.781344175338745s
RMSprop (torch.complex64): 15.136352777481079s
Rprop (torch.float32): 42.46482181549072s
Rprop (torch.complex64): 48.28277635574341s
```

It seems that this change does not help the complex case by much (but that is not the majority case). The torch.float32 results are generally positive; where they do not show a drastic improvement, or where they regress, it is due to Inductor variance (confirmed by manually inspecting the logs).

</details>

### Benchmark Script
```python
# Benchmark: time the first call to a torch.compile'd optimizer step for
# several optimizers and parameter dtypes, then print one summary line each.
import torch
import time
from torch.optim import Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop

# Optimizer classes and parameter dtypes (real and complex) to sweep over.
OPTIMS = [Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop]
DTYPES = [torch.float, torch.cfloat]

# Number of Linear layers stacked into the benchmark model.
NUM_PARAMS = 100
# Keyword arguments passed unchanged to every optimizer constructor.
kwargs = { "lr": 0.01, "foreach": True }
# One formatted timing line per (optimizer, dtype) pair, printed at the end.
summary = []

for optim_cls in OPTIMS:
    for dtype in DTYPES:
        # NOTE(review): presumably clears cached compilation state so each
        # configuration is compiled from scratch -- confirm.
        torch._dynamo.reset()
        # torch._inductor.metrics.reset()
        input = torch.ones([10, 10], dtype=dtype, device="cuda:0")
        model = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, dtype=dtype, device="cuda:0") for _ in range(NUM_PARAMS)]
        )

        # One forward/backward pass so the parameters carry gradients before
        # the optimizer step is invoked.
        model(input).sum().abs().backward()
        opt_compiled = optim_cls(model.parameters(), **kwargs)
        compiled_step = torch.compile(opt_compiled.step)

        # Time a single call to the compiled step with grad mode disabled.
        with torch.set_grad_enabled(False):
            start_time = time.time()
            compiled_step()
            summary.append(f"{optim_cls.__name__} ({dtype}): {time.time() - start_time}s")

        # Dump the compile-time breakdown reported by torch._dynamo.utils.
        print(optim_cls, kwargs, dtype, torch._dynamo.utils.compile_times())

for s in summary:
    print(s)
```

CC: @janeyx99 @mlazos
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110613
Approved by: https://github.com/janeyx99
2023-10-05 23:10:52 +00:00

345 lines
12 KiB
Python

import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
_default_to_fused_or_foreach, _differentiable_doc, _maximize_doc, _foreach_doc)
from typing import List, Optional
__all__ = ["Adamax", "adamax"]
class Adamax(Optimizer):
    # The user-facing class documentation is assigned to ``Adamax.__doc__``
    # after the class body, so no class docstring is declared here.

    def __init__(
        self,
        params,
        lr=2e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        foreach: Optional[bool] = None,
        *,
        maximize: bool = False,
        differentiable: bool = False,
    ):
        """Validate the hyperparameters and register them as group defaults.

        Raises:
            ValueError: if ``lr``, ``eps`` or ``weight_decay`` is not
                ``>= 0``, or if either of the first two ``betas`` entries
                lies outside ``[0, 1)``.
        """
        # The checks run in a fixed order so the first offending argument is
        # the one that gets reported.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        super().__init__(
            params,
            {
                "lr": lr,
                "betas": betas,
                "eps": eps,
                "weight_decay": weight_decay,
                "foreach": foreach,
                "maximize": maximize,
                "differentiable": differentiable,
            },
        )

    def __setstate__(self, state):
        """Restore optimizer state, filling in options and tensor step counters.

        Groups that lack ``foreach``, ``maximize`` or ``differentiable`` get
        the defaults ``None``, ``False`` and ``False``. If the first stored
        ``step`` is not a tensor (or there is no state), every stored ``step``
        is converted to a float tensor.
        """
        super().__setstate__(state)

        for param_group in self.param_groups:
            param_group.setdefault("foreach", None)
            param_group.setdefault("maximize", False)
            param_group.setdefault("differentiable", False)

        per_param_states = list(self.state.values())
        # Only the first entry is inspected to decide whether conversion is needed.
        if per_param_states and torch.is_tensor(per_param_states[0]["step"]):
            return
        for entry in per_param_states:
            entry["step"] = torch.tensor(float(entry["step"]))

    def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps):
        """Append the tensors of every parameter in ``group`` that has a gradient.

        The five list arguments are filled in place. State for a parameter is
        created on first use: a zero ``step`` tensor plus zero-filled
        ``exp_avg`` and ``exp_inf`` buffers shaped like the parameter.

        Raises:
            RuntimeError: if a parameter's gradient is sparse.
        """
        for p in group["params"]:
            gradient = p.grad
            if gradient is None:
                continue

            params_with_grad.append(p)
            if gradient.is_sparse:
                raise RuntimeError("Adamax does not support sparse gradients")
            grads.append(gradient)

            param_state = self.state[p]
            if not param_state:
                # Lazy state initialization on the first step for this parameter.
                param_state["step"] = torch.tensor(0.0)
                param_state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                param_state["exp_inf"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            exp_avgs.append(param_state["exp_avg"])
            exp_infs.append(param_state["exp_inf"])
            state_steps.append(param_state["step"])

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss. It is called with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Read the group's hyperparameters before collecting its tensors.
            beta1, beta2 = group["betas"]
            options = {
                "eps": group["eps"],
                "lr": group["lr"],
                "weight_decay": group["weight_decay"],
                "foreach": group["foreach"],
                "maximize": group["maximize"],
                "differentiable": group["differentiable"],
            }

            # Fresh buffers per group; _init_group fills them in place.
            params_with_grad, grads = [], []
            exp_avgs, exp_infs, state_steps = [], [], []
            self._init_group(group, params_with_grad, grads, exp_avgs, exp_infs, state_steps)

            adamax(
                params_with_grad,
                grads,
                exp_avgs,
                exp_infs,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                **options,
            )

        return loss
# The class documentation is attached after the class body. It is built from
# two pieces: a raw string holding the LaTeX listing of the algorithm, and a
# raw f-string for the argument list, which interpolates the shared
# _foreach_doc, _maximize_doc and _differentiable_doc snippets imported from
# .optimizer.
Adamax.__doc__ = r"""Implements Adamax algorithm (a variant of Adam based on infinity norm).
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
\text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)},
\: \lambda \text{ (weight decay)}, \\
&\hspace{13mm} \epsilon \text{ (epsilon)} \\
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}if \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
&\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
""" + fr"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 2e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
{_foreach_doc}
{_maximize_doc}
{_differentiable_doc}
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
"""
def adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
):
    r"""Functional entry point for the Adamax update.

    Picks the multi-tensor (foreach) or the single-tensor implementation and
    forwards all tensors and hyperparameters to it. When ``foreach`` is
    ``None`` the choice is delegated to ``_default_to_fused_or_foreach``.

    See :class:`~torch.optim.Adamax` for details.

    Raises:
        RuntimeError: if any entry of ``state_steps`` is not a tensor, or if
            the foreach implementation is requested while scripting.
    """
    # Step counters must be tensors; reject anything else up front.
    for step_entry in state_steps:
        if not isinstance(step_entry, torch.Tensor):
            raise RuntimeError(
                "API has changed, `state_steps` argument must contain a list of singleton tensors"
            )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # Select the implementation once; both share the same keyword interface.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamax
    else:
        func = _single_tensor_adamax

    func(
        params,
        grads,
        exp_avgs,
        exp_infs,
        state_steps,
        eps=eps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
    )
def _single_tensor_adamax(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_infs: List[Tensor],
state_steps: List[Tensor],
*,
eps: float,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
maximize: bool,
differentiable: bool,
):
for i, param in enumerate(params):
grad = grads[i]
grad = grad if not maximize else -grad
exp_avg = exp_avgs[i]
exp_inf = exp_infs[i]
step_t = state_steps[i]
# update step
step_t += 1
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)
if torch.is_complex(param):
param = torch.view_as_real(param)
grad = torch.view_as_real(grad)
exp_avg = torch.view_as_real(exp_avg)
exp_inf = torch.view_as_real(exp_inf)
# Update biased first moment estimate.
exp_avg.lerp_(grad, 1 - beta1)
# Update the exponentially weighted infinity norm.
norm_buf = torch.cat(
[exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)], 0
)
if not differentiable:
torch.amax(norm_buf, 0, keepdim=False, out=exp_inf)
else:
exp_inf.copy_(torch.amax(norm_buf, 0, keepdim=False))
bias_correction = 1 - beta1 ** _get_value(step_t)
clr = lr / bias_correction
param.addcdiv_(exp_avg, exp_inf, value=-clr)
def _multi_tensor_adamax(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_infs: List[Tensor],
state_steps: List[Tensor],
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float,
maximize: bool,
differentiable: bool,
):
assert not differentiable, "_foreach ops don't support autograd"
if len(params) == 0:
return
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
for ((grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs, grouped_state_steps), _) in grouped_tensors.values():
if maximize:
grouped_grads = torch._foreach_neg(grouped_grads)
for i in range(len(grouped_params)):
if torch.is_complex(grouped_params[i]):
grouped_params[i] = torch.view_as_real(grouped_params[i])
grouped_grads[i] = torch.view_as_real(grouped_grads[i])
grouped_exp_avgs[i] = torch.view_as_real(grouped_exp_avgs[i])
grouped_exp_infs[i] = torch.view_as_real(grouped_exp_infs[i])
# Update steps
torch._foreach_add_(grouped_state_steps, 1)
if weight_decay != 0:
if maximize:
# Re-use the intermediate memory (grouped_grads) already allocated for maximize
torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
else:
grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)
# Update biased first moment estimate.
torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)
# Update the exponentially weighted infinity norm.
torch._foreach_mul_(grouped_exp_infs, beta2)
for exp_inf, grad in zip(grouped_exp_infs, grouped_grads):
norm_buf = torch.cat(
[exp_inf.unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)], 0
)
torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long()))
bias_corrections = [1 - beta1 ** _get_value(step) for step in grouped_state_steps]
clr = _stack_if_compiling([-1 * (lr / bias_correction) for bias_correction in bias_corrections])
torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, grouped_exp_infs, clr)