import sys
import threading
import copy

import torch
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors, \
    _take_tensors
import torch.utils.hooks

from torch.cuda.comm import broadcast_coalesced
from torch.cuda import nccl
import torch.distributed.deprecated as dist

from ...modules import Module
from ..replicate import replicate
from ..scatter_gather import scatter_kwargs, gather
from ..parallel_apply import parallel_apply

if sys.version_info[0] == 3:
    import queue
else:
    import Queue as queue


class DistributedDataParallel(Module):
    r"""Implements distributed data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the batch
    dimension. The module is replicated on each machine and each device, and
    each such replica handles a portion of the input. During the backwards
    pass, gradients from each node are averaged.

    The batch size should be larger than the number of GPUs used locally. It
    should also be an integer multiple of the number of GPUs, so that each
    chunk is the same size (and hence each GPU processes the same number of
    samples).

    See also: :ref:`distributed-basics` and :ref:`cuda-nn-dataparallel-instead`.
    The same constraints on input as in :class:`torch.nn.DataParallel` apply.

    Creation of this class requires the distributed package to be already
    initialized in the process group mode
    (see :func:`torch.distributed.deprecated.init_process_group`).

    .. warning::
        This module works only with the ``nccl`` and ``gloo`` backends.

    .. warning::
        Constructor, forward method, and differentiation of the output (or a
        function of the output of this module) are distributed synchronization
        points. Take that into account in case different processes might be
        executing different code.

    .. warning::
        This module assumes all parameters are registered in the model by the
        time it is created. No parameters should be added or removed later.
        The same applies to buffers.

    .. warning::
        This module assumes all buffers and gradients are dense.

    .. warning::
        This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
        only work if gradients are to be accumulated in ``.grad`` attributes of
        parameters).

    .. warning::
        If you plan on using this module with a ``nccl`` backend or a ``gloo``
        backend (that uses Infiniband), together with a DataLoader that uses
        multiple workers, please change the multiprocessing start method to
        ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately,
        Gloo (when using Infiniband) and NCCL2 are not fork-safe, and you will
        likely experience deadlocks if you don't change this setting.

    .. note::
        Parameters are never broadcast between processes. The module performs
        an all-reduce step on gradients and assumes that they will be modified
        by the optimizer in all processes in the same way. Buffers
        (e.g. BatchNorm stats) are broadcast from the module in the process of
        rank 0 to all other replicas in the system in every iteration.

    .. warning::
        Forward and backward hooks defined on :attr:`module` and its submodules
        won't be invoked anymore, unless the hooks are initialized in the
        :meth:`forward` method.

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)
        output_device: device location of output (default: device_ids[0])
        broadcast_buffers: flag that enables syncing (broadcasting) buffers of
            the module at the beginning of the forward function.
            (default: True)

    Attributes:
        module (Module): the module to be parallelized

    Example::

        >>> torch.distributed.deprecated.init_process_group(world_size=4, init_method='...')
        >>> net = torch.nn.DistributedDataParallel(model)
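        >>> # Illustrative continuation (not additional API): a typical
        >>> # training loop around the wrapper; ``data_loader``, ``optimizer``
        >>> # and ``loss_fn`` are assumed to be defined elsewhere.
        >>> for inputs, targets in data_loader:
        >>>     optimizer.zero_grad()
        >>>     loss = loss_fn(net(inputs), targets)
        >>>     loss.backward()  # gradients are all-reduced across processes here
        >>>     optimizer.step()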
    """

    def __init__(self, module, device_ids=None, output_device=None, dim=0,
                 broadcast_buffers=True):
        super(DistributedDataParallel, self).__init__()
        if dist._backend not in (dist.dist_backend.NCCL, dist.dist_backend.GLOO):
            raise ValueError('Invalid backend, only NCCL and GLOO backends are supported by DistributedDataParallel')

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.dim = dim
        self.module = module
        self.device_ids = device_ids
        self.output_device = output_device
        self.broadcast_buffers = broadcast_buffers

        # Flag used by the NCCL backend to make sure we only reduce gradients
        # one time in the execution engine
        self.need_reduction = False

        MB = 1024 * 1024
        # used for intra-node param sync and inter-node sync as well
        self.broadcast_bucket_size = 10 * MB
        self.nccl_reduce_bucket_size = 256 * MB

        # Sync params and buffers
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)

        if len(device_ids) > 1:
            # TODO: we don't need to replicate params in here. they're always going to
            # be broadcast using larger blocks in broadcast_coalesced, so it might be
            # better to not pollute the caches with these small blocks
            self._module_copies = replicate(self.module, self.device_ids, detach=True)
            self._module_copies[0] = self.module

            for module_copy in self._module_copies[1:]:
                for param, copy_param in zip(self.module.parameters(), module_copy.parameters()):
                    copy_param.requires_grad = param.requires_grad

        else:
            self._module_copies = [self.module]

        # For the NCCL backend, since every single NCCL call is asynchronous,
        # we directly enqueue all the NCCL reduction calls onto the default
        # CUDA stream without spawning additional reduction threads.
        # This achieves the best performance.
        if dist._backend == dist.dist_backend.NCCL:
            self._register_nccl_grad_hook()
            return

        bucket_bytes_cap = 1 * MB

        # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
        param_buckets = []
        # Split the parameters into buckets and by types as well
        for dev_idx, module in enumerate(self._module_copies):
            param_buckets.append(list(_take_tensors(module.parameters(), bucket_bytes_cap)))

        self.bucket_sizes = []
        self.bucket_map = {}

        # We transpose param_buckets, so the loop is over buckets.
        # param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
        for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)):
            self.bucket_sizes.append(0)
            # Now, we transpose again, so we iterate over bucket_elems, but getting tuples
            # of params from each device.
            for idx, param_tuple in enumerate(zip(*param_buckets_tuple)):
                if idx == 0:
                    # Bucket parameter type tracking
                    bucket_param_type = param_tuple[0].type()
                    # Only gloo and nccl support half-precision
                    if bucket_param_type == torch.cuda.HalfTensor and \
                            dist._backend != dist.dist_backend.GLOO:
                        raise RuntimeError("DistributedDataParallel currently only "
                                           "supports half precision parameters "
                                           "with Nccl and Gloo backend")
                if not param_tuple[0].requires_grad:
                    continue
                for p in param_tuple:
                    self.bucket_map[p] = bucket_idx
                self.bucket_sizes[bucket_idx] += 1
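
        # After the loop above, ``self.bucket_map`` maps every trainable
        # parameter (on every device copy) to the index of its bucket, and
        # ``self.bucket_sizes[i]`` is the number of gradients that bucket ``i``
        # must collect per device before it can be reduced.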
        self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))]
        self.bucket_events = [[None] * len(self.device_ids) for _ in range(len(self.bucket_sizes))]
        self.reduced = [False] * len(self.bucket_sizes)

        self._register_grad_hooks()
        self.dispatch_lock = threading.Lock()
        self._start_reduction_threads()

    def __getstate__(self):
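        # The grad accumulators, reduction threads, CUDA streams, queues and
        # the dispatch lock are not picklable; for the threaded (non-NCCL)
        # path they are dropped here and recreated in __setstate__.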
        attrs = copy.copy(self.__dict__)
        if dist._backend != dist.dist_backend.NCCL:
            del attrs['_grad_accs'], attrs['_reduction_queues'], \
                attrs['_reduction_streams'], attrs['_reduction_threads'], \
                attrs['_nccl_streams'], attrs['_default_streams'], \
                attrs['dispatch_lock']
        return attrs

    def __setstate__(self, state):
        super(DistributedDataParallel, self).__setstate__(state)
        if dist._backend == dist.dist_backend.NCCL:
            self._register_nccl_grad_hook()
        else:
            self._register_grad_hooks()
            self.dispatch_lock = threading.Lock()
            self._start_reduction_threads()

    def forward(self, *inputs, **kwargs):
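        # Scatter the inputs across devices, broadcast parameters/buffers to
        # the local replicas, run the replicas in parallel, and gather the
        # outputs on the output device. Gradient reduction itself happens
        # later, during the backward pass, via the registered hooks.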
        self.need_reduction = True
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        self._sync_params()
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)

    def train(self, mode=True):
        super(DistributedDataParallel, self).train(mode)
        for module in self._module_copies[1:]:
            module.train(mode)

    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        """
        Broadcast a sequence of tensors to the default group from rank 0.
        Small tensors are first coalesced into a buffer to reduce the number of
        broadcasts.

        tensors (sequence): tensors to broadcast. Each tensor needs to be on the
            same GPU.
        buffer_size (int): maximum size of the buffer for coalescing
        """
        for tensors in _take_tensors(tensors, buffer_size):
            flat_tensors = _flatten_dense_tensors(tensors)
            dist.broadcast(flat_tensors, 0)
            for tensor, synced in zip(tensors,
                                      _unflatten_dense_tensors(flat_tensors, tensors)):
                tensor.copy_(synced)

    def _sync_params(self):
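        # Two kinds of synchronization happen here: parameters are broadcast
        # from the first local replica to the other GPUs on this node, and
        # (optionally) buffers are broadcast from rank 0 across nodes and then
        # to the local replicas.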
        if len(self.device_ids) > 1:
            # intra-node parameter sync
            params = [p.data for p in self.module.parameters()]
            result = broadcast_coalesced(params, self.device_ids, self.broadcast_bucket_size)
            for tensors, module in zip(result[1:], self._module_copies[1:]):
                for tensor, param in zip(tensors, module.parameters()):
                    param.data.set_(tensor)

        # module buffer sync
        if self.broadcast_buffers:
            buffers = [b.data for b in self.module.buffers()]
            if len(buffers) > 0:
                # cross-node buffer sync
                self._dist_broadcast_coalesced(buffers, self.broadcast_bucket_size)

                if len(self.device_ids) > 1:
                    # intra-node buffer sync
                    result = broadcast_coalesced(buffers, self.device_ids, self.broadcast_bucket_size)
                    for tensors, module in zip(result[1:], self._module_copies[1:]):
                        for tensor, buf in zip(tensors, module.buffers()):
                            buf.data.set_(tensor)

    def _register_grad_hooks(self):
        self._grad_accs = []  # need to keep them in scope
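        # ``p.expand_as(p)`` creates a view whose ``grad_fn.next_functions[0][0]``
        # is the AccumulateGrad node for ``p``; registering a hook on that node
        # makes the hook fire once the gradient for ``p`` has been accumulated
        # during the backward pass. Storing the grad accumulators in
        # ``self._grad_accs`` keeps the hooks alive.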
        for device_idx, module in enumerate(self._module_copies):
            for p in module.parameters():
                if p.requires_grad:
                    p_tmp = p.expand_as(p)
                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_param_hook(p, device_idx))
                    self._grad_accs.append(grad_acc)

    def _register_nccl_grad_hook(self):
        """
        This function registers the callback all-reduction function for the
        NCCL backend. All gradients will be all-reduced in one single step.
        The NCCL reduction will directly be enqueued into the
        default CUDA stream. Therefore, no synchronization is needed.
        """
        # Creating a new group
        self.nccl_reduction_group_id = dist.new_group()

        def reduction_fn_nccl():
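            # Overview of the steps below: collect the gradients of every
            # local replica, split them into buckets of at most
            # ``self.nccl_reduce_bucket_size`` bytes, flatten each bucket,
            # all-reduce the flattened buckets across all devices and
            # processes, divide by the world size to average, and copy the
            # results back into the ``.grad`` tensors on the first device.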
            # This function only needs to be called once
            if not self.need_reduction:
                return

            self.need_reduction = False
            all_grads = [[] for _ in range(len(self._module_copies))]
            all_grads_buckets_iters = []

            # Bucketing all the gradients
            for dev_idx, module in enumerate(self._module_copies):
                for param in module.parameters():
                    if not param.requires_grad or param.grad is None:
                        continue
                    if param.grad.requires_grad:
                        raise RuntimeError("DistributedDataParallel only works "
                                           "with gradients that don't require "
                                           "grad")
                    # Adding the gradients for reduction
                    all_grads[dev_idx].append(param.grad.data)

                # Now bucketing the parameters
                dev_grads_buckets = _take_tensors(all_grads[dev_idx],
                                                  self.nccl_reduce_bucket_size)

                all_grads_buckets_iters.append(dev_grads_buckets)

            # Now reduce each bucket one after another
            for grads_batch in zip(*all_grads_buckets_iters):
                grads_batch_coalesced = []
                # Coalesce each bucket
                for dev_idx, dev_grads_batch in enumerate(grads_batch):
                    dev_id = self.device_ids[dev_idx]
                    with torch.cuda.device(dev_id):
                        dev_grads_batch_coalesced = _flatten_dense_tensors(dev_grads_batch)
                        grads_batch_coalesced.append(dev_grads_batch_coalesced)

                # We will only use device 0's results, but this single op should be
                # faster than doing the following two operations sequentially:
                # (1) intra-node reduce to the lead GPU, followed by
                # (2) inter-node allreduce for the lead GPUs of all nodes
                dist.all_reduce_multigpu(grads_batch_coalesced,
                                         group=self.nccl_reduction_group_id)

                # Now only work on the first device of self.device_ids and
                # uncoalesce the gradients for each bucket
                grads_batch_coalesced[0] /= dist.get_world_size()
                grads_batch_reduced = _unflatten_dense_tensors(grads_batch_coalesced[0], grads_batch[0])
                for grad, reduced in zip(grads_batch[0], grads_batch_reduced):
                    grad.copy_(reduced)

            # Clear the gradients to save memory on the replicas
            for module in self._module_copies[1:]:
                for param in module.parameters():
                    if param.requires_grad:
                        param.grad = None
                        param.data.set_()

        # Now register the reduction hook on the parameters
        for p in self.module.parameters():
            if not p.requires_grad:
                continue

            @torch.utils.hooks.unserializable_hook
            def allreduce_hook(*unused):
                Variable._execution_engine.queue_callback(reduction_fn_nccl)

            p.register_hook(allreduce_hook)

    def _make_param_hook(self, param, device_idx):
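        # Returns a closure that is registered on the grad accumulator of
        # ``param``. During backward it appends the finished gradient to this
        # device's bucket and, once the bucket has received all of its
        # gradients, records a CUDA event and hands the bucket off for
        # reduction.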
        bucket_idx = self.bucket_map[param]

        def distributed_data_parallel_hook(*unused):
            if param.grad.requires_grad:
                raise RuntimeError("DistributedDataParallel only works with "
                                   "gradients that don't require grad")
            bucket = self.buckets[bucket_idx][device_idx]
            bucket.append(param.grad.data)

            # We can flush these and save memory for replicas
            if device_idx > 0:
                param.grad = None
                param.data.set_()

            # Current device's bucket is full
            if len(bucket) == self.bucket_sizes[bucket_idx]:
                with torch.cuda.device(self.device_ids[device_idx]):
                    event = torch.cuda.Event()
                    event.record()
                with self.dispatch_lock:
                    self.bucket_events[bucket_idx][device_idx] = event
                    self._queue_reduction(bucket_idx)

        return distributed_data_parallel_hook

    def _queue_reduction(self, bucket_idx):
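        # Once every device has recorded an event for this bucket, push the
        # bucket onto its reduction queue and register a callback so that
        # ``backward()`` does not return before the reduction has finished.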
        dev_buckets = self.buckets[bucket_idx]
        dev_events = self.bucket_events[bucket_idx]

        # Check if it's ready
        if any(evt is None for evt in dev_events):
            return

        # Queue the reduction and make sure backward waits for it
        event = threading.Event()
        self._reduction_queues[bucket_idx].put((dev_buckets, dev_events, event))
        Variable._execution_engine.queue_callback(lambda: event.wait())

        # Reset bucket state
        self.buckets[bucket_idx] = [[] for _ in range(len(self.device_ids))]
        self.bucket_events[bucket_idx] = [None] * len(self.device_ids)
        self.reduced[bucket_idx] = True
        if all(self.reduced):
            self.reduced = [False] * len(self.bucket_sizes)

            def sync_reduction_streams():
                # We only have to sync with the first one, but it's safer to do it this way
                # in case we change the way in which we parallelize work
                r_streams = zip(*self._reduction_streams)
                for dev_id, default_stream, dev_r_streams in zip(self.device_ids, self._default_streams, r_streams):
                    with torch.cuda.device(dev_id):
                        for reduction_stream in dev_r_streams:
                            default_stream.wait_stream(reduction_stream)
            Variable._execution_engine.queue_callback(sync_reduction_streams)

    def _start_reduction_threads(self):
        num_buckets = len(self.bucket_sizes)
        self._reduction_queues = [queue.Queue() for _ in range(num_buckets)]
        self._reduction_threads = []
        self._reduction_streams = [[] for _ in range(num_buckets)]
        self._nccl_streams = []
        self._default_streams = []
        for dev_id in self.device_ids:
            with torch.cuda.device(dev_id):
                # TODO: don't assume we're on a default stream
                self._default_streams.append(torch.cuda.current_stream())
                self._nccl_streams.append(torch.cuda.Stream())
        for reduction_queue, reduction_streams in zip(self._reduction_queues, self._reduction_streams):
            for dev_id in self.device_ids:
                with torch.cuda.device(dev_id):
                    reduction_streams.append(torch.cuda.Stream())
            # We only use the first device for distributed reductions
            dist._register_stream(reduction_streams[0])

            group_id = dist.new_group()

            self._reduction_threads.append(threading.Thread(
                target=self._reduction_thread_fn,
                args=(reduction_queue, group_id, self.device_ids, reduction_streams, self._nccl_streams)))
            self._reduction_threads[-1].daemon = True
            self._reduction_threads[-1].start()

    @staticmethod
    def _reduction_thread_fn(queue, group_id, device_ids, reduction_streams, nccl_streams):
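        # Each reduction thread services one bucket: it flattens the bucket's
        # gradients on every local GPU, reduces them to the first device with
        # NCCL, divides by the world size and all-reduces across processes,
        # then copies the averaged result back into the original gradient
        # tensors and signals the waiting backward pass.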
        def _process_batch():
            dev_grad_batch, dev_events, job_event = queue.get()
            dev_coalesced = []
            # Coalesce the tensors on all devices and start a local reduction
            for dev_id, grad_batch, event, stream in zip(device_ids, dev_grad_batch, dev_events, reduction_streams):
                with torch.cuda.device(dev_id), torch.cuda.stream(stream):
                    stream.wait_event(event)
                    coalesced = _flatten_dense_tensors(grad_batch)
                    dev_coalesced.append(coalesced)
            # Wait for all copies to complete before starting the NCCL kernel
            for stream in reduction_streams:
                stream.synchronize()
            nccl.reduce(dev_coalesced, root=0, streams=nccl_streams)

            # From now on we're only going to work on the first device (from device_ids)
            grad_batch = dev_grad_batch[0]
            coalesced = dev_coalesced[0]
            reduce_stream = reduction_streams[0]
            with torch.cuda.stream(reduce_stream):
                reduce_stream.wait_stream(nccl_streams[0])
                coalesced /= dist.get_world_size()
                dist.all_reduce(coalesced, group=group_id)
                for grad, reduced in zip(grad_batch, _unflatten_dense_tensors(coalesced, grad_batch)):
                    grad.copy_(reduced)
            job_event.set()

        with torch.cuda.device(device_ids[0]):
            while True:
                _process_batch()  # just to have a clear scope