import copy

import torch

from torch.cuda.comm import broadcast_coalesced
import torch.distributed as dist

if dist.is_available():
    from torch.distributed.distributed_c10d import _get_default_group

from ..modules import Module
from .replicate import replicate
from .scatter_gather import scatter_kwargs, gather
from .parallel_apply import parallel_apply
from torch.cuda._utils import _get_device_index


class DistributedDataParallel(Module):
    r"""Implements distributed data parallelism that is based on the
    ``torch.distributed`` package at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the batch
    dimension. The module is replicated on each machine and each device, and
    each such replica handles a portion of the input. During the backwards
    pass, gradients from each node are averaged.

    The batch size should be larger than the number of GPUs used locally.

    See also: :ref:`distributed-basics` and :ref:`cuda-nn-dataparallel-instead`.
    The same constraints on input as in :class:`torch.nn.DataParallel` apply.

    Creation of this class requires that ``torch.distributed`` be already
    initialized, by calling :func:`torch.distributed.init_process_group`.

    ``DistributedDataParallel`` can be used in the following two ways:

    (1) Single-Process Multi-GPU

    In this case, a single process will be
    spawned on each host/node and each process will operate on all the GPUs
    of the node where it's running. To use ``DistributedDataParallel`` in
    this way, you can simply construct the model as follows:

        >>> torch.distributed.init_process_group(backend="nccl")
        >>> model = DistributedDataParallel(model) # device_ids will include all GPU devices by default

    (2) Multi-Process Single-GPU

    This is the highly recommended way to use ``DistributedDataParallel``, with
    multiple processes, each of which operates on a single GPU. This is
    currently the fastest approach to do data parallel training using PyTorch
    and applies to both single-node (multi-GPU) and multi-node data
    parallel training. It is proven to be significantly faster than
    :class:`torch.nn.DataParallel` for single-node multi-GPU data
    parallel training.

    Here is how to use it: on each host with N GPUs, you should spawn N
    processes, while ensuring that each process individually works on a single
    GPU from 0 to N-1. Therefore, it is your job to ensure that your training
    script operates on a single given GPU by calling:

        >>> torch.cuda.set_device(i)

    where i is from 0 to N-1. In each process, you should refer to the
    following to construct this module:

        >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
        >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)

    In order to spawn multiple processes per node, you can use either
    ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.
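
    For example, on a single node with 4 GPUs, a minimal sketch using
    ``torch.multiprocessing.spawn`` (``run_worker`` and ``MyModel`` are
    placeholder names, and ``init_method`` is elided as above):

        >>> import torch.multiprocessing as mp
        >>> def run_worker(rank, world_size):
        >>>     torch.distributed.init_process_group(backend='nccl', init_method='...',
        >>>                                          world_size=world_size, rank=rank)
        >>>     torch.cuda.set_device(rank)
        >>>     model = DistributedDataParallel(MyModel().cuda(rank),
        >>>                                     device_ids=[rank], output_device=rank)
        >>>
        >>> mp.spawn(run_worker, args=(4,), nprocs=4)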

    .. note:: The ``nccl`` backend is currently the fastest and most highly
        recommended backend to be used with Multi-Process Single-GPU
        distributed training, and this applies to both single-node and
        multi-node distributed training.

    .. note:: This module also supports mixed-precision distributed training.
        This means that your model can have different types of parameters,
        such as a mix of fp16 and fp32 types, and the gradient reduction on
        these mixed types of parameters will just work fine.

        Also note that the ``nccl`` backend is currently the fastest and most
        highly recommended backend for fp16/fp32 mixed-precision training.
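
        For example, a minimal sketch that keeps one (hypothetical) submodule
        in fp16 while the rest of the model stays in fp32:

            >>> model.fp16_block.half()  # ``fp16_block`` is a placeholder submodule name
            >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)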

    .. warning::
        This module works only with the ``gloo`` and ``nccl`` backends.

    .. warning::
        The constructor, the forward method, and differentiation of the output
        (or a function of the output of this module) are distributed
        synchronization points. Take that into account in case different
        processes might be executing different code.

    .. warning::
        This module assumes all parameters are registered in the model by the
        time it is created. No parameters should be added or removed later.
        The same applies to buffers.

    .. warning::
        This module assumes that the parameters of the model are registered in
        the same order in every distributed process. The module itself will
        conduct gradient all-reduction following the reverse order of the
        model's registered parameters. In other words, it is the user's
        responsibility to ensure that each distributed process has the exact
        same model and thus the exact same parameter registration order.

    .. warning::
        This module assumes all buffers and gradients are dense.

    .. warning::
        This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
        only work if gradients are to be accumulated in ``.grad`` attributes of
        parameters).

    .. warning::

        If you plan on using this module with a ``nccl`` backend or a ``gloo``
        backend (that uses Infiniband), together with a DataLoader that uses
        multiple workers, please change the multiprocessing start method to
        ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately,
        Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will
        likely experience deadlocks if you don't change this setting.
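
        A minimal sketch of switching the start method in your main script
        (the start method can only be set once, before any worker processes
        are created):

            >>> import torch.multiprocessing as mp
            >>> if __name__ == '__main__':
            >>>     mp.set_start_method('forkserver')  # or 'spawn'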

    .. warning::
        Forward and backward hooks defined on :attr:`module` and its submodules
        won't be invoked anymore, unless the hooks are initialized in the
        :meth:`forward` method.

    .. warning::
        You should never try to change your model's parameters after wrapping
        up your model with DistributedDataParallel. This is because, when
        wrapping up your model with DistributedDataParallel, the constructor of
        DistributedDataParallel will register the additional gradient
        reduction functions on all the parameters of the model itself at the
        time of construction. If you change the model's parameters after
        the DistributedDataParallel construction, this is not supported and
        unexpected behaviors can happen, since some parameters' gradient
        reduction functions might not get called.

    .. note::
        Parameters are never broadcast between processes. The module performs
        an all-reduce step on gradients and assumes that they will be modified
        by the optimizer in all processes in the same way. Buffers
        (e.g. BatchNorm stats) are broadcast from the module in the process of
        rank 0 to all other replicas in the system in every iteration.

    Args:
        module (Module): module to be parallelized
        device_ids (list of int or torch.device): CUDA devices (default: all devices)
        output_device (int or torch.device): device location of output (default: device_ids[0])
        broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
                          the module at the beginning of the forward function.
                          (default: ``True``)
        process_group: the process group to be used for distributed data
                       all-reduction. If ``None``, the default process group, which
                       is created by ``torch.distributed.init_process_group``,
                       will be used. (default: ``None``)
        bucket_cap_mb: DistributedDataParallel will bucket parameters into
                       multiple buckets so that gradient reduction of each
                       bucket can potentially overlap with backward computation.
                       :attr:`bucket_cap_mb` controls the bucket size in MegaBytes (MB)
                       (default: 25)
        check_reduction: when set to ``True``, it enables DistributedDataParallel
                         to automatically check if the previous iteration's
                         backward reductions were successfully issued at the
                         beginning of every iteration's forward function.
                         You normally don't need this option enabled unless you
                         are observing weird behaviors such as different ranks
                         getting different gradients, which should not
                         happen if DistributedDataParallel is correctly used.
                         (default: ``False``)

    Attributes:
        module (Module): the module to be parallelized

    Example::

        >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
        >>> net = torch.nn.parallel.DistributedDataParallel(model)
    """

    def __init__(self, module, device_ids=None,
                 output_device=None, dim=0, broadcast_buffers=True,
                 process_group=None, bucket_cap_mb=25,
                 check_reduction=False):

        super(DistributedDataParallel, self).__init__()

        # Use all devices by default
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))

        if output_device is None:
            output_device = device_ids[0]

        if process_group is None:
            self.process_group = _get_default_group()
        else:
            self.process_group = process_group

        self.dim = dim
        self.module = module
        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.broadcast_buffers = broadcast_buffers
        self.check_reduction = check_reduction

        MB = 1024 * 1024

        # used for intra-node param sync and inter-node sync as well
        self.broadcast_bucket_size = 250 * MB

        # reduction bucket size
        self.bucket_bytes_cap = bucket_cap_mb * MB

        # Sync params and buffers
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)

        self._ddp_init_helper()

    def _ddp_init_helper(self):
        """
        Initialization helper function that does the following:

        (1) replicating the module from device[0] to the other devices
        (2) bucketing the parameters for reductions
        (3) resetting the bucketing states
        (4) registering the grad hooks
        (5) passing a handle of DDP to SyncBatchNorm Layer
        """
        if len(self.device_ids) > 1:
            # TODO: we don't need to replicate params in here. they're always going to
            # be broadcasted using larger blocks in broadcast_coalesced, so it might be
            # better to not pollute the caches with these small blocks
            self._module_copies = replicate(self.module, self.device_ids, detach=True)
            self._module_copies[0] = self.module

            for module_copy in self._module_copies[1:]:
                for param, copy_param in zip(self.module.parameters(), module_copy.parameters()):
                    copy_param.requires_grad = param.requires_grad

        else:
            self._module_copies = [self.module]

        self.modules_params_data = [[] for _ in range(len(self.device_ids))]
        self.modules_buffers_data = [[] for _ in range(len(self.device_ids))]

        for dev_idx, module in enumerate(self._module_copies):
            self.modules_params_data[dev_idx] = [p.data for p in module.parameters()]
            self.modules_buffers_data[dev_idx] = [b.data for b in module.buffers()]

        # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
        param_buckets = []

        # Split the parameters into buckets, and by types as well. We only
        # need to bucket and reduce parameters that require grad; this also
        # holds for the backward pass, since only the backward hooks of
        # parameters that require grad will be registered with the gradient
        # reduction functions.
        params_to_bucket = [[] for _ in self._module_copies]
        for dev_idx, m in enumerate(self._module_copies):
            for p in m.parameters():
                if p.requires_grad:
                    params_to_bucket[dev_idx].append(p)

        param_buckets = [dist._dist_bucket_tensors(dev_params_to_bucket,
                                                   int(self.bucket_bytes_cap),
                                                   fine_grained=False)
                         for dev_params_to_bucket in params_to_bucket]

        self.bucket_sizes = []
        self.bucket_map = {}

        # We transpose param_buckets, so the loop is over buckets.
        # param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
        for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)):
            self.bucket_sizes.append(0)
            # Now, we transpose again, so we iterate over bucket_elems, but getting tuples
            # of params from each device.
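            # For example (assuming two devices and a bucket holding params p0 and p1):
            #   param_buckets_tuple == ([p0_dev0, p1_dev0], [p0_dev1, p1_dev1])
            #   zip(*param_buckets_tuple) yields (p0_dev0, p0_dev1), then (p1_dev0, p1_dev1)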
            for param_tuple in zip(*param_buckets_tuple):
                if not param_tuple[0].requires_grad:
                    continue
                for p in param_tuple:
                    self.bucket_map[p] = (bucket_idx, self.bucket_sizes[bucket_idx])
                self.bucket_sizes[bucket_idx] += 1

        self.buckets = [[[None for _ in range(self.bucket_sizes[i])]
                        for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]
        # The number of params ready in each bucket
        self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]

        # coalesced bucket for only device 0
        self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))]
        # We will always reduce the buckets following the reverse order,
        # that is, always reducing in the order: n - 1, n - 2, ..., 0
        self.next_bucket = len(self.bucket_sizes) - 1
        # When all buckets are reduced, this will be set to True. This flag is
        # useful for sanity checks to ensure that each iteration's backward has
        # always reduced all buckets
        self.all_buckets_reduced = False
        self.check_previous_reduction = False
        self.ready_buckets_not_reduced = set()
        self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
        self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
        self._register_grad_hooks()

        # passing a handle to torch.nn.SyncBatchNorm layer
        self._passing_sync_batchnorm_handle(self._module_copies)

    def __getstate__(self):
        self._check_default_group()
        attrs = copy.copy(self.__dict__)
        del attrs['process_group'], \
            attrs['default_streams'], \
            attrs['_grad_accs']
        return attrs

    def __setstate__(self, state):
        # If serializable, then the process group should be the default one
        self.process_group = _get_default_group()
        self.check_previous_reduction = False
        super(DistributedDataParallel, self).__setstate__(state)
        self._ddp_init_helper()

    def _check_default_group(self):
        pickle_not_supported = False
        try:
            if self.process_group != _get_default_group():
                pickle_not_supported = True
        except RuntimeError:
            pickle_not_supported = True

        if pickle_not_supported:
            raise RuntimeError("DDP Pickling/Unpickling are only supported "
                               "when using DDP with the default process "
                               "group. That is, when you have called "
                               "init_process_group and have not passed "
                               "process_group argument to DDP constructor")

    def _check_previous_reduction(self):
        if not self.training:
            return
        # self.check_previous_reduction will be False in the first iteration
        # and is then toggled to True for all future iterations.
        if self.check_previous_reduction is False:
            self.check_previous_reduction = True
        else:
            if not self.all_buckets_reduced:
                raise RuntimeError("Not all gradients have been reduced from "
                                   "the backward of the previous iteration. "
                                   "This is an unexpected and fatal error. "
                                   "Please check and ensure that the model's "
                                   "parameters are not changed after you wrap "
                                   "up the model with DistributedDataParallel.")
            self.all_buckets_reduced = False

    def forward(self, *inputs, **kwargs):
        if self.check_reduction:
            self._check_previous_reduction()
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        self._sync_params()
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)

    def train(self, mode=True):
        self.check_previous_reduction = False
        super(DistributedDataParallel, self).train(mode)
        for module in self._module_copies[1:]:
            module.train(mode)

    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        dist._dist_broadcast_coalesced(self.process_group, tensors, buffer_size, False)

    def _sync_params(self):
        if len(self.device_ids) > 1:
            # intra-node parameter sync
            result = broadcast_coalesced(self.modules_params_data[0],
                                         self.device_ids,
                                         self.broadcast_bucket_size)
            for tensors, module_params_data in zip(result[1:], self.modules_params_data[1:]):
                for tensor, param_data in zip(tensors, module_params_data):
                    param_data.set_(tensor)

        # module buffer sync
        if self.broadcast_buffers:
            if len(self.modules_buffers_data[0]) > 0:
                # cross-node buffer sync
                self._dist_broadcast_coalesced(self.modules_buffers_data[0],
                                               self.broadcast_bucket_size)
                if len(self.device_ids) > 1:
                    # intra-node buffer sync
                    result = broadcast_coalesced(self.modules_buffers_data[0],
                                                 self.device_ids,
                                                 self.broadcast_bucket_size)
                    for tensors, module_buffers_data in zip(result[1:], self.modules_buffers_data[1:]):
                        for tensor, buffer_data in zip(tensors, module_buffers_data):
                            buffer_data.set_(tensor)

    def _passing_sync_batchnorm_handle(self, module_copies):
        for dev_idx, module in enumerate(module_copies):
            for layer in module.modules():
                if isinstance(layer, torch.nn.modules.SyncBatchNorm):
                    layer._specify_ddp_gpu_num(len(self.device_ids))

    def _register_grad_hooks(self):
        self._grad_accs = []  # need to keep them in scope

        # default stream tracking to launch nccl reduce kernels
        self.default_streams = []
        for dev_id in self.device_ids:
            with torch.cuda.device(dev_id):
                self.default_streams.append(torch.cuda.current_stream())

        for device_idx, module in enumerate(self._module_copies):
            for p in module.parameters():
                if p.requires_grad:
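                    # p.expand_as(p) creates an autograd view whose grad_fn's
                    # next_functions[0][0] is the AccumulateGrad node for p, so the
                    # hook registered below fires after p's gradient has been
                    # accumulated into p.grad.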
                    p_tmp = p.expand_as(p)
                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_param_hook(p, device_idx))
                    self._grad_accs.append(grad_acc)

    def _make_param_hook(self, param, device_idx):
        bucket_idx, bucket_offset = self.bucket_map[param]

        def distributed_data_parallel_hook(*unused):
            if param.grad.requires_grad:
                raise RuntimeError("DistributedDataParallel only works "
                                   "with gradients that don't require grad")
            bucket = self.buckets[bucket_idx][device_idx]
            bucket[bucket_offset] = param.grad.data
            self.buckets_ready_size[bucket_idx][device_idx] += 1

            # We can flush these and save memory for replicas
            if device_idx > 0:
                param.grad = None
                param.data.set_()

            # Current device's bucket is full
            if self.buckets_ready_size[bucket_idx][device_idx] == self.bucket_sizes[bucket_idx]:
                self.devs_ready[bucket_idx] += 1
                if self.devs_ready[bucket_idx] < len(self.device_ids):
                    return

                # Now all devices' buckets with index: bucket_idx are ready
                if bucket_idx == self.next_bucket:
                    self._queue_reduction(bucket_idx)
                    self.next_bucket -= 1
                    # Now reduce anything that is ready but not yet reduced
                    if len(self.ready_buckets_not_reduced) > 0:
                        sorted_todo = sorted(self.ready_buckets_not_reduced, reverse=True)
                        for i in sorted_todo:
                            # Nothing can be reduced now
                            if i < self.next_bucket:
                                break
                            self._queue_reduction(i)
                            self.ready_buckets_not_reduced.remove(i)
                            if i == self.next_bucket:
                                self.next_bucket -= 1
                else:
                    self.ready_buckets_not_reduced.add(bucket_idx)

                # When all devices' buckets have been reduced, do a final sync
                if self.next_bucket == -1:
                    # A final sync for all the reduction works
                    self._sync_reduction_works()
                    self.all_buckets_reduced = True

        return distributed_data_parallel_hook

    def _queue_reduction(self, bucket_idx):
        # _queue_reduction will use a separate CUDA stream to coalesce
        # the small tensors to achieve more parallelism, before passing the
        # coalesced tensor into the c10d CUDA stream for reduction
        result = dist._queue_reduction(self.process_group,
                                       self.buckets[bucket_idx],
                                       self.device_ids)
        self.reduction_works[bucket_idx] = result[0]
        self.buckets_coalesced[bucket_idx] = result[1]

    def _sync_reduction_works(self):
        # Now only work on the first GPU of self.device_ids
        # _sync_reduction will use a separate CUDA stream to uncoalesce
        # the coalesced tensors to achieve more parallelism
        for bucket_idx, grads_batch in enumerate(self.buckets):
            dist._sync_reduction(self.reduction_works[bucket_idx],
                                 grads_batch[0],
                                 self.buckets_coalesced[bucket_idx])

        # Reset the module states
        self.next_bucket = len(self.bucket_sizes) - 1
        self.ready_buckets_not_reduced = set()
        self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
        self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]

        self.buckets = [[[None for _ in range(self.bucket_sizes[i])]
                        for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]
        self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))]
        self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]