Faster gradient clipping using the foreach functions
```
[------------------------ (tensors, scalar) -------------------------]
| without foreach | with foreach | apex
1 threads: ----------------------------------------------------------------------
10 tensors of size 4 | 120.5 | 61.1 | 50.3
100 tensors of size 4 | 946.2 | 239.5 | 136.3
1000 tensors of size 4 | 9808.5 | 2151.1 | 1006.9
10000 tensors of size 4 | 96871.2 | 22637.4 | 10119.1
10 tensors of size 16 | 121.0 | 64.1 | 52.5
100 tensors of size 16 | 993.4 | 252.6 | 136.7
1000 tensors of size 16 | 9427.7 | 2151.2 | 1049.5
10000 tensors of size 16 | 97437.1 | 22203.1 | 10340.0
10 tensors of size 256 | 118.9 | 62.3 | 51.5
100 tensors of size 256 | 955.2 | 243.1 | 134.2
1000 tensors of size 256 | 9374.9 | 2140.7 | 1009.6
10000 tensors of size 256 | 95302.5 | 21849.4 | 10215.5
10 tensors of size 65536 | 118.5 | 62.4 | 51.1
100 tensors of size 65536 | 1740.7 | 243.3 | 225.3
1000 tensors of size 65536 | 17364.1 | 2228.7 | 2004.5
10000 tensors of size 65536 | 177510.1 | 25410.4 | 20678.2
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91846
Approved by: https://github.com/janeyx99
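
The speedup comes from replacing a Python-level loop of one kernel launch per gradient with the private `torch._foreach_*` multi-tensor ops. The snippet below is only a minimal sketch of that idea, not the actual `torch.nn.utils.clip_grad_norm_` implementation: `clip_grads_foreach` is a hypothetical helper, and it assumes all gradients share a device and dtype (the utility shown afterwards handles the general, mixed case).

```python
import torch

def clip_grads_foreach(grads, max_norm, eps=1e-6):
    # Hypothetical sketch of foreach-style clipping, not the real
    # torch.nn.utils.clip_grad_norm_ code. Assumes all grads share one
    # device and dtype.
    norms = [torch.linalg.vector_norm(g) for g in grads]
    total_norm = torch.linalg.vector_norm(torch.stack(norms))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # One fused multi-tensor call instead of one mul_() launch per tensor.
        torch._foreach_mul_(grads, clip_coef.item())
    return total_norm

# Usage: clip a model's gradients in place after backward().
model = torch.nn.Linear(16, 16)
model(torch.randn(4, 16)).sum().backward()
grads = [p.grad for p in model.parameters() if p.grad is not None]
clip_grads_foreach(grads, max_norm=1.0)
```

The utility below does the bookkeeping the foreach kernels require: it groups tensors by device and dtype so that each fused call only ever sees a homogeneous tensor list.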
```python
from collections import defaultdict
from typing import List, Dict, Tuple, Optional, Union

import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad


# This util function splits tensors into groups by device and dtype, which is useful before sending
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
#   - tensorlists CAN be None
#   - all tensors in the first specified list cannot be None
#   - given an index i, all specified tensorlist[i]s match in dtype and device
# with_indices (bool, optional): whether to track previous indices as the last list per dictionary entry.
#   It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out.
#   Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
#   original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
#   may be necessary. Check out torch/optim/sgd.py for an example.
@no_grad()
def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
                                       with_indices: Optional[bool] = False) -> \
        Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]]:
    assert all(not x or len(x) == len(tensorlistlist[0]) for x in tensorlistlist), (
        "all specified tensorlists must match in length")
    per_device_and_dtype_tensors: Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
        lambda: [[] for _ in range(len(tensorlistlist) + (1 if with_indices else 0))])
    for i, t in enumerate(tensorlistlist[0]):
        key = (t.device, t.dtype)
        for j in range(len(tensorlistlist)):
            # a tensorlist may be empty/None
            if tensorlistlist[j]:
                per_device_and_dtype_tensors[key][j].append(tensorlistlist[j][i])
        if with_indices:
            # tack on the index of this tensor in the original (outer) list
            per_device_and_dtype_tensors[key][j + 1].append(i)
    return per_device_and_dtype_tensors
```
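
For illustration, here is a short usage sketch (not part of the file) showing how a mixed-dtype gradient list gets split into homogeneous groups, each of which can then be handled by a single fused foreach call. It assumes `_group_tensors_by_device_and_dtype` is in scope and uses the existing `torch._foreach_mul_` op.

```python
import torch

# Gradients with mixed dtypes; in practice they could also span devices.
grads = [torch.randn(8, dtype=torch.float32),
         torch.randn(8, dtype=torch.float16),
         torch.randn(8, dtype=torch.float32)]

grouped = _group_tensors_by_device_and_dtype([grads], with_indices=True)

for (device, dtype), (group, indices) in grouped.items():
    # Each group is homogeneous, so one fused foreach kernel covers it.
    torch._foreach_mul_(group, 0.5)
    # The in-place mul_ mutates the original tensors in `grads` as well;
    # `indices` records where each group member came from in the input list.
    print(device, dtype, indices)
```

Because the grouped lists hold the same tensor objects as the input, in-place foreach ops propagate back to the original list, which is exactly what the header comment about mutation describes.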