Summary:
This PR adds fused Adam and AdamW implementations for the MPS backend.
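For context, the fused path is reached through the existing `fused=True` flag on `torch.optim.Adam`/`torch.optim.AdamW`; below is a minimal usage sketch on an MPS device (the toy model and shapes are illustrative, not from this PR):
```python
# Hypothetical usage sketch: opting into the fused kernels on MPS via fused=True.
import torch

model = torch.nn.Linear(128, 10).to("mps")  # illustrative toy model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

inp = torch.randn(32, 128, device="mps")
loss = model(inp).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```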
Benchmark on a MacBook Pro with an M1 Max chip and 64 GB of unified memory:
**Fast math enabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
| Fused: True | Fused: False
1 threads: -----------------------------------------------------------------------------------------------
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 89
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 90
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 83
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 12 | 94
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 88
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 12 | 90
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 100
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 27 | 100
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 23 | 100
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 27 | 100
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 23 | 98
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 480
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 72 | 450
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 82 | 450
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 73 | 420
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 91 | 500
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 83 | 400
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 94 | 500
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 78 | 400
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 170 | 500
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 140 | 600
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 170 | 600
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 140 | 500
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 250 | 890
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 220 | 850
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 250 | 830
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 220 | 770
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 270 | 870
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 230 | 840
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 270 | 810
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 240 | 800
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 400 | 1000
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 360 | 2000
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 430 | 2000
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 360 | 1300
Times are in milliseconds (ms).
```
**Fast math disabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
| Fused: True | Fused: False
1 threads: -----------------------------------------------------------------------------------------------
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100 | 10 | 100
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100 | 9 | 84
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 84
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100 | 9 | 79
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100 | 11 | 93
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100 | 10 | 90
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 91
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100 | 11 | 81
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100 | 34 | 100
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100 | 31 | 100
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100 | 34 | 95
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100 | 31 | 100
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500 | 94 | 500
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500 | 82 | 430
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500 | 92 | 430
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500 | 81 | 390
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500 | 98 | 500
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500 | 88 | 430
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500 | 100 | 500
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500 | 88 | 400
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500 | 210 | 500
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500 | 190 | 610
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500 | 210 | 510
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500 | 190 | 500
amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000 | 300 | 900
amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000 | 260 | 850
amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000 | 295 | 900
amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000 | 260 | 800
amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000 | 320 | 910
amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000 | 280 | 900
amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000 | 320 | 900
amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000 | 300 | 900
amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000 | 500 | 2000
amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000 | 480 | 2000
amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000 | 540 | 1500
amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000 | 480 | 1200
Times are in milliseconds (ms).
```
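The tables above were produced with the following benchmark script: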
```python
def profile_fused_adam():
    import torch  # added so the snippet is self-contained
    from torch.optim import adam, adamw
    import torch.utils.benchmark as benchmark
    import itertools

    def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
        fn(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            foreach=False,
            capturable=False,
            fused=fused,
            amsgrad=amsgrad,
            beta1=0.9,
            beta2=0.99,
            lr=1e-3,
            weight_decay=0.0,
            eps=1e-5,
            maximize=False,
            grad_scale=None,
            found_inf=None,
        )
        torch.mps.synchronize()

    device = "mps"
    results = []

    for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]):
        print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
        params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
        max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else []
        state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)]
        if adamWflag:
            fn = adamw.adamw
        else:
            fn = adam.adam

        for fused in [True, False]:
            t = benchmark.Timer(
                stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
                label='Fused Adam',
                sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
                globals=locals(),
                description=f"Fused: {fused}",
            ).blocked_autorange(min_run_time=5)
            results.append(t)

    compare = benchmark.Compare(results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)
    compare.print()
```
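As a separate sanity check (not part of this PR), the fused and unfused functional paths can be compared on identical inputs using the same call signature as the benchmark above; the state setup below is arbitrary, it assumes an MPS-capable machine, and tolerances are loose since the fused kernel need not be bitwise identical:
```python
# Illustrative correctness sketch: fused vs. unfused torch.optim.adam.adam on MPS.
import torch
from torch.optim import adam

def make_state(device="mps", num_tensors=4, numel=16):
    params = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)]
    grads = [torch.ones(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)]
    exp_avgs = [torch.zeros(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)]
    exp_avg_sqs = [torch.zeros(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)]
    state_steps = [torch.tensor([1.0], device=device) for _ in range(num_tensors)]
    return params, grads, exp_avgs, exp_avg_sqs, state_steps

ref, fused_state = make_state(), make_state()
for (params, grads, exp_avgs, exp_avg_sqs, state_steps), use_fused in ((ref, False), (fused_state, True)):
    adam.adam(
        params, grads, exp_avgs, exp_avg_sqs, [], state_steps,
        foreach=False, capturable=False, fused=use_fused, amsgrad=False,
        beta1=0.9, beta2=0.99, lr=1e-3, weight_decay=0.0, eps=1e-5,
        maximize=False, grad_scale=None, found_inf=None,
    )
torch.mps.synchronize()
# Compare the updated parameters from both paths with loose tolerances.
torch.testing.assert_close(ref[0], fused_state[0], rtol=1e-4, atol=1e-4)
```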
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127242
Approved by: https://github.com/kulinseth, https://github.com/janeyx99
from typing import List, Dict, Tuple, Optional

import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad
from typing_extensions import TypeAlias


def _get_foreach_kernels_supported_devices() -> List[str]:
    r"""Return the device type list that supports foreach kernels."""
    return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]


def _get_fused_kernels_supported_devices() -> List[str]:
    r"""Return the device type list that supports fused kernels in optimizer."""
    return ["mps", "cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()]


TensorListList: TypeAlias = List[List[Optional[Tensor]]]
Indices: TypeAlias = List[int]
_foreach_supported_types = [torch.Tensor]


# This util function splits tensors into groups by device and dtype, which is useful before sending
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
#   - tensorlists CAN be None
#   - all tensors in the first specified list cannot be None
#   - given an index i, all specified tensorlist[i]s match in dtype and device
# with_indices (bool, optional): whether to track previous indices as the last list per dictionary entry.
#   It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out.
#   Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
#   original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
#   may be necessary. Check out torch/optim/sgd.py for an example.
@no_grad()
def _group_tensors_by_device_and_dtype(
    tensorlistlist: TensorListList,
    with_indices: bool = False,
) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
    return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)


def _device_has_foreach_support(device: torch.device) -> bool:
    return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()


def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
    return _device_has_foreach_support(device) and all(t is None or type(t) in _foreach_supported_types for t in tensors)
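For illustration only (not part of the file above), here is a small sketch of how the grouping helper buckets parallel tensor lists by device and dtype, following the return type annotated on `_group_tensors_by_device_and_dtype`; the example tensors are arbitrary CPU tensors and the helper is a private API:
```python
# Illustrative sketch: the dict maps (device, dtype) to (grouped tensorlistlist,
# original indices when with_indices=True), per the annotation above.
import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

params = [torch.zeros(3), torch.zeros(3, dtype=torch.float64), torch.zeros(3)]
grads = [torch.ones(3), torch.ones(3, dtype=torch.float64), torch.ones(3)]

grouped = _group_tensors_by_device_and_dtype([params, grads], with_indices=True)
for (device, dtype), (tensorlists, indices) in grouped.items():
    # tensorlists mirrors the input structure: [params subset, grads subset].
    print(device, dtype, len(tensorlists[0]), indices)
```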