mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Decouple DataParallel/DistributedDataParallel from CUDA (#38454)
Summary: Decouple DataParallel/DistributedDataParallel from CUDA to support more device types. - Move torch/cuda/comm.py to torch/nn/parallel/comm.py with minor changes for common devices support. torch.cuda.comm is kept as is for backward compatibility. - Provide common APIs to arbitrary device types without changing existing CUDA APIs in torch.cuda space. - Replace the torch.cuda calls in DataParallel/DistributedDataParallel with the new APIs. Related RFC: [https://github.com/pytorch/pytorch/issues/36160](https://github.com/pytorch/pytorch/issues/36160) Pull Request resolved: https://github.com/pytorch/pytorch/pull/38454 Differential Revision: D22051557 Pulled By: mrshenli fbshipit-source-id: 7842dad0e5d3ca0f6fb760bda49182dcf6653af8
This commit is contained in:
parent
75155df8b4
commit
8d570bc708
3
mypy.ini
3
mypy.ini
|
|
@ -203,6 +203,9 @@ ignore_errors = True
|
|||
[mypy-torch.nn.parallel._functions]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.nn.parallel.comm]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.nn.quantized.functional]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
|
|
@ -2113,20 +2113,20 @@ class DistributedDataParallelTest(MultiProcessTestCase):
|
|||
gpus = gpus[:2]
|
||||
model = DoubleGpuNet(gpus)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, "output_device .* single-device CUDA"):
|
||||
with self.assertRaisesRegex(AssertionError, "output_device .* single-device GPU"):
|
||||
ddp_model = DistributedDataParallel(
|
||||
model, output_device=gpus[1], process_group=process_group)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, "device_ids .* single-device CUDA"):
|
||||
with self.assertRaisesRegex(AssertionError, "device_ids .* single-device GPU"):
|
||||
ddp_model = DistributedDataParallel(
|
||||
model, device_ids=gpus, process_group=process_group)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, "only works with CUDA devices"):
|
||||
with self.assertRaisesRegex(AssertionError, "input module must be on the same type of devices"):
|
||||
model.fc1 = model.fc1.cpu()
|
||||
ddp_model = DistributedDataParallel(model, process_group=process_group)
|
||||
|
||||
model = model.cpu()
|
||||
with self.assertRaisesRegex(AssertionError, "device_ids .* single-device CUDA"):
|
||||
with self.assertRaisesRegex(AssertionError, "device_ids .* single-device GPU"):
|
||||
ddp_model = DistributedDataParallel(
|
||||
model, device_ids=gpus, process_group=process_group)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import torch
|
||||
import torch._six
|
||||
from typing import Optional
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
import sys
|
||||
|
|
@ -416,3 +418,68 @@ class ExceptionWrapper(object):
|
|||
# (https://bugs.python.org/issue2651), so we work around it.
|
||||
msg = KeyErrorMessage(msg)
|
||||
raise self.exc_type(msg)
|
||||
|
||||
|
||||
def _get_available_device_type():
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
# add more available device types here
|
||||
return None
|
||||
|
||||
|
||||
def _get_device_attr(get_member):
|
||||
device_type = _get_available_device_type()
|
||||
if device_type.lower() == "cuda":
|
||||
return get_member(torch.cuda)
|
||||
# add more available device types here
|
||||
return None
|
||||
|
||||
|
||||
def _get_current_device_index():
|
||||
# current device index
|
||||
return _get_device_attr(lambda m: m.current_device())
|
||||
|
||||
|
||||
def _get_all_device_indices():
|
||||
# all device index
|
||||
return _get_device_attr(lambda m: list(range(m.device_count())))
|
||||
|
||||
|
||||
def _get_devices_properties(device_ids):
|
||||
# all device properties
|
||||
return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]
|
||||
|
||||
|
||||
def _get_device_index(device, optional=False, allow_cpu=False) -> int:
|
||||
r"""Gets the device index from :attr:`device`, which can be a torch.device
|
||||
object, a Python integer, or ``None``.
|
||||
|
||||
If :attr:`device` is a torch.device object, returns the device index if it
|
||||
has index. Note that for a device without a specified index,
|
||||
i.e., ``torch.device('xxx')``, this will return the current default
|
||||
device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
|
||||
CPU devices will be accepted and ``-1`` will be returned in this case.
|
||||
|
||||
If :attr:`device` is a Python integer, it is returned as is.
|
||||
|
||||
If :attr:`device` is ``None``, this will return the current default
|
||||
device of the supported runtime platform if :attr:`optional` is ``True``.
|
||||
i.e., the current default CUDA device will be returned if CUDA runtime is supported.
|
||||
"""
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
device_idx: Optional[int]
|
||||
device_idx = None
|
||||
if isinstance(device, torch.device):
|
||||
if not allow_cpu and device.type == 'cpu':
|
||||
raise ValueError('Expected a non cpu device, but got: {}'.format(device))
|
||||
device_idx = -1 if device.type == 'cpu' else device.index
|
||||
if isinstance(device, int):
|
||||
device_idx = device
|
||||
if device_idx is None:
|
||||
if optional:
|
||||
device_idx = _get_current_device_index()
|
||||
else:
|
||||
raise ValueError('Expected a torch.device with a specified index '
|
||||
'or an integer, but got:{}'.format(device))
|
||||
return device_idx
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import torch
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
from torch.types import Device
|
||||
# _get_device_index has been moved to torch._utils._get_device_index
|
||||
from torch._utils import _get_device_index as _torch_get_device_index
|
||||
|
||||
|
||||
def _get_device_index(device: Union[Device, str, int], optional: bool = False,
|
||||
|
|
@ -21,27 +23,13 @@ def _get_device_index(device: Union[Device, str, int], optional: bool = False,
|
|||
"""
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
device_idx: Optional[int]
|
||||
if isinstance(device, torch.device):
|
||||
dev_type = device.type
|
||||
if allow_cpu:
|
||||
if device.type not in {'cuda', 'cpu'}:
|
||||
raise ValueError('Expected a cuda or cpu device, but got: {}'.format(device))
|
||||
elif device.type != 'cuda':
|
||||
raise ValueError('Expected a cuda device, but got: {}'.format(device))
|
||||
device_idx = -1 if device.type == 'cpu' else device.index
|
||||
else:
|
||||
if device is not None and not isinstance(device, torch._six.int_classes):
|
||||
raise ValueError('Cannot recognize device {}'.format(device))
|
||||
device_idx = device
|
||||
if device_idx is None:
|
||||
if optional:
|
||||
# default cuda device index
|
||||
return torch.cuda.current_device()
|
||||
else:
|
||||
raise ValueError('Expected a cuda device with a specified index '
|
||||
'or an integer, but got: {}'.format(device))
|
||||
return device_idx
|
||||
return _torch_get_device_index(device, optional, allow_cpu)
|
||||
|
||||
|
||||
def _dummy_type(name: str) -> type:
|
||||
|
|
|
|||
|
|
@ -1,233 +1,5 @@
|
|||
import warnings
|
||||
# The functions here have been moved to torch.nn.parallel.comm
|
||||
from torch.nn.parallel.comm import broadcast, broadcast_coalesced, reduce_add, \
|
||||
reduce_add_coalesced, scatter, gather
|
||||
|
||||
import torch
|
||||
|
||||
from . import nccl
|
||||
from torch._utils import _take_tensors, _flatten_dense_tensors, \
|
||||
_unflatten_dense_tensors, _reorder_tensors_as
|
||||
|
||||
|
||||
def broadcast(tensor, devices=None, *, out=None):
|
||||
r"""Broadcasts a tensor to specified CUDA devices.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): tensor to broadcast. Can be on CPU or CUDA.
|
||||
devices (Iterable[torch.device, str or int], optional): an iterable of
|
||||
CUDA devices, among which to broadcast.
|
||||
out (Sequence[Tensor], optional, keyword-only): the CUDA tensors to
|
||||
store output results.
|
||||
|
||||
.. note::
|
||||
Exactly one of :attr:`devices` and :attr:`out` must be specified.
|
||||
|
||||
Returns:
|
||||
- If :attr:`devices` is specified,
|
||||
a tuple containing copies of :attr:`tensor`, placed on
|
||||
:attr:`devices`.
|
||||
- If :attr:`out` is specified,
|
||||
a tuple containing :attr:`out` tensors, each containing a copy of
|
||||
:attr:`tensor`.
|
||||
"""
|
||||
if not ((devices is None) ^ (out is None)):
|
||||
raise RuntimeError(
|
||||
"Exactly one of 'devices' and 'out' must be specified, but got "
|
||||
"devices={} and out={}".format(devices, out))
|
||||
if devices is not None:
|
||||
devices = [torch.cuda._utils._get_device_index(d) for d in devices]
|
||||
return torch._C._broadcast(tensor, devices)
|
||||
else:
|
||||
return torch._C._broadcast_out(tensor, out)
|
||||
|
||||
|
||||
def broadcast_coalesced(tensors, devices, buffer_size=10485760):
|
||||
r"""Broadcasts a sequence tensors to the specified CUDA devices.
|
||||
Small tensors are first coalesced into a buffer to reduce the number
|
||||
of synchronizations.
|
||||
|
||||
Arguments:
|
||||
tensors (sequence): tensors to broadcast. Must be on the same device,
|
||||
either CPU or CUDA.
|
||||
devices (Iterable[torch.device, str or int]): an iterable of CUDA
|
||||
devices, among which to broadcast.
|
||||
buffer_size (int): maximum size of the buffer used for coalescing
|
||||
|
||||
Returns:
|
||||
A tuple containing copies of :attr:`tensor`, placed on :attr:`devices`.
|
||||
"""
|
||||
devices = [torch.cuda._utils._get_device_index(d) for d in devices]
|
||||
return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
|
||||
|
||||
|
||||
def reduce_add(inputs, destination=None):
|
||||
"""Sums tensors from multiple GPUs.
|
||||
|
||||
All inputs should have matching shapes, dtype, and layout. The output tensor
|
||||
will be of the same shape, dtype, and layout.
|
||||
|
||||
Arguments:
|
||||
inputs (Iterable[Tensor]): an iterable of tensors to add.
|
||||
destination (int, optional): a device on which the output will be
|
||||
placed (default: current device).
|
||||
|
||||
Returns:
|
||||
A tensor containing an elementwise sum of all inputs, placed on the
|
||||
:attr:`destination` device.
|
||||
"""
|
||||
destination = torch.cuda._utils._get_device_index(destination, optional=True)
|
||||
input_size = inputs[0].size()
|
||||
root_index = None # index of input tensor that already is on the correct device
|
||||
for i, inp in enumerate(inputs):
|
||||
assert inp.is_cuda, "reduce_add expects all inputs to be on GPUs"
|
||||
if inp.get_device() == destination:
|
||||
root_index = i
|
||||
if inp.size() != input_size:
|
||||
got = 'x'.join(str(x) for x in inp.size())
|
||||
expected = 'x'.join(str(x) for x in input_size)
|
||||
raise ValueError("input {} has invalid size: got {}, but expected "
|
||||
"{}".format(i, got, expected))
|
||||
if root_index is None:
|
||||
raise RuntimeError("reduce_add expects destination to be on the same GPU with one of the tensors")
|
||||
|
||||
if len(inputs) == 1:
|
||||
return inputs[0]
|
||||
|
||||
if nccl.is_available(inputs):
|
||||
result = torch.empty_like(inputs[root_index])
|
||||
nccl.reduce(inputs, output=result, root=root_index)
|
||||
else:
|
||||
nonroot = [t for i, t in enumerate(inputs) if i != root_index]
|
||||
result = inputs[root_index] + nonroot[0].cuda(destination, non_blocking=True) # make a new tensor w/o clone
|
||||
for other in nonroot[1:]:
|
||||
result.add_(other.cuda(destination, non_blocking=True))
|
||||
return result
|
||||
|
||||
|
||||
def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760):
|
||||
"""Sums tensors from multiple GPUs.
|
||||
|
||||
Small tensors are first coalesced into a buffer to reduce the number
|
||||
of synchronizations.
|
||||
|
||||
Arguments:
|
||||
inputs (Iterable[Iterable[Tensor]]): iterable of iterables that
|
||||
contain tensors from a single device.
|
||||
destination (int, optional): a device on which the output will be
|
||||
placed (default: current device).
|
||||
buffer_size (int): maximum size of the buffer used for coalescing
|
||||
|
||||
Returns:
|
||||
A tuple of tensors containing an elementwise sum of each group of
|
||||
inputs, placed on the ``destination`` device.
|
||||
"""
|
||||
# TODO: When `len(inputs) == 1` and all inputs are on `destination`, just
|
||||
# return `inputs`.
|
||||
dense_tensors = [[] for _ in inputs] # shape (num_gpus, num_tensors)
|
||||
output = []
|
||||
ref_order = []
|
||||
# process sparse ones first since they may have different sizes on different gpus
|
||||
for tensor_at_gpus in zip(*inputs):
|
||||
if all(t.is_sparse for t in tensor_at_gpus):
|
||||
result = reduce_add(tensor_at_gpus, destination) # this will be sparse too
|
||||
output.append(result)
|
||||
ref_order.append(tensor_at_gpus[0])
|
||||
else:
|
||||
for coll, t in zip(dense_tensors, tensor_at_gpus):
|
||||
coll.append(t.to_dense() if t.is_sparse else t)
|
||||
ref_order.append(dense_tensors[0][-1])
|
||||
itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors]
|
||||
# now the dense ones, which have consistent sizes
|
||||
for chunks in zip(*itrs):
|
||||
flat_tensors = [_flatten_dense_tensors(chunk) for chunk in chunks] # (num_gpus,)
|
||||
flat_result = reduce_add(flat_tensors, destination)
|
||||
for t in _unflatten_dense_tensors(flat_result, chunks[0]):
|
||||
# The unflattened tensors do not share storage, and we don't expose
|
||||
# base flat tensor anyways, so give them different version counters.
|
||||
# See NOTE [ Version Counter in comm.*_coalesced ]
|
||||
output.append(t.data)
|
||||
return tuple(_reorder_tensors_as(output, ref_order))
|
||||
|
||||
|
||||
def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None):
|
||||
r"""Scatters a tensor across multiple CUDA devices.
|
||||
|
||||
Arguments:
|
||||
tensor (Tensor): tensor to scatter. Can be on CPU or CUDA.
|
||||
devices (Iterable[torch.device, str or int], optional): an iterable of
|
||||
CUDA devices, among which to scatter.
|
||||
chunk_sizes (Iterable[int], optional): sizes of chunks to be placed on
|
||||
each device. It should match :attr:`devices` in length and sums to
|
||||
``tensor.size(dim)``. If not specified, :attr:`tensor` will be divided
|
||||
into equal chunks.
|
||||
dim (int, optional): A dimension along which to chunk :attr:`tensor`.
|
||||
Default: ``0``.
|
||||
out (Sequence[Tensor], optional, keyword-only): the CUDA tensors to
|
||||
store output results. Sizes of these tensors must match that of
|
||||
:attr:`tensor`, except for :attr:`dim`, where the total size must
|
||||
sum to ``tensor.size(dim)``.
|
||||
|
||||
.. note::
|
||||
Exactly one of :attr:`devices` and :attr:`out` must be specified. When
|
||||
:attr:`out` is specified, :attr:`chunk_sizes` must not be specified and
|
||||
will be inferred from sizes of :attr:`out`.
|
||||
|
||||
Returns:
|
||||
- If :attr:`devices` is specified,
|
||||
a tuple containing chunks of :attr:`tensor`, placed on
|
||||
:attr:`devices`.
|
||||
- If :attr:`out` is specified,
|
||||
a tuple containing :attr:`out` tensors, each containing a chunk of
|
||||
:attr:`tensor`.
|
||||
"""
|
||||
if out is None:
|
||||
devices = [torch.cuda._utils._get_device_index(d) for d in devices]
|
||||
return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
|
||||
else:
|
||||
if devices is not None:
|
||||
raise RuntimeError(
|
||||
"'devices' must not be specified when 'out' is specified, but "
|
||||
"got devices={}".format(devices))
|
||||
if chunk_sizes is not None:
|
||||
raise RuntimeError(
|
||||
"'chunk_sizes' must not be specified when 'out' is specified, "
|
||||
"but got chunk_sizes={}".format(chunk_sizes))
|
||||
return tuple(torch._C._scatter_out(tensor, out, dim, streams))
|
||||
|
||||
def gather(tensors, dim=0, destination=None, *, out=None):
|
||||
r"""Gathers tensors from multiple CUDA devices.
|
||||
|
||||
Arguments:
|
||||
tensors (Iterable[Tensor]): an iterable of tensors to gather.
|
||||
Tensor sizes in all dimensions other than :attr:`dim` have to match.
|
||||
dim (int, optional): a dimension along which the tensors will be
|
||||
concatenated. Default: ``0``.
|
||||
destination (torch.device, str, or int, optional): the output device.
|
||||
Can be CPU or CUDA. Default: the current CUDA device.
|
||||
out (Tensor, optional, keyword-only): the tensor to store gather result.
|
||||
Its sizes must match those of :attr:`tensors`, except for :attr:`dim`,
|
||||
where the size must equal ``sum(tensor.size(dim) for tensor in tensors)``.
|
||||
Can be on CPU or CUDA.
|
||||
|
||||
.. note::
|
||||
:attr:`destination` must not be specified when :attr:`out` is specified.
|
||||
|
||||
Returns:
|
||||
- If :attr:`destination` is specified,
|
||||
a tensor located on :attr:`destination` device, that is a result of
|
||||
concatenating :attr:`tensors` along :attr:`dim`.
|
||||
- If :attr:`out` is specified,
|
||||
the :attr:`out` tensor, now containing results of concatenating
|
||||
:attr:`tensors` along :attr:`dim`.
|
||||
"""
|
||||
if out is None:
|
||||
if destination == -1:
|
||||
warnings.warn(
|
||||
'Using -1 to represent CPU tensor is deprecated. Please use a '
|
||||
'device object or string instead, e.g., "cpu".')
|
||||
destination = torch.cuda._utils._get_device_index(destination, allow_cpu=True, optional=True)
|
||||
return torch._C._gather(tensors, dim, destination)
|
||||
else:
|
||||
if destination is not None:
|
||||
raise RuntimeError(
|
||||
"'destination' must not be specified when 'out' is specified, but "
|
||||
"got destination={}".format(destination))
|
||||
return torch._C._gather_out(tensors, out, dim)
|
||||
# Entries of ``__all__`` must be strings naming the exported attributes;
# ``from torch.cuda.comm import *`` raises TypeError for non-string items.
__all__ = ['broadcast', 'broadcast_coalesced', 'reduce_add',
           'reduce_add_coalesced', 'scatter', 'gather']
|
||||
|
|
|
|||
|
|
@ -1,17 +1,18 @@
|
|||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.cuda.comm as comm
|
||||
from . import comm
|
||||
from torch.autograd import Function
|
||||
from torch.cuda._utils import _get_device_index
|
||||
from torch._utils import _get_device_index
|
||||
|
||||
|
||||
class Broadcast(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, target_gpus, *inputs):
|
||||
if not all(input.is_cuda for input in inputs):
|
||||
raise TypeError('Broadcast function not implemented for CPU tensors')
|
||||
assert all(map(lambda i: i.device.type != 'cpu', inputs)), (
|
||||
'Broadcast function not implemented for CPU tensors'
|
||||
)
|
||||
target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
|
||||
ctx.target_gpus = target_gpus
|
||||
if len(inputs) == 0:
|
||||
|
|
@ -51,7 +52,9 @@ class Gather(Function):
|
|||
|
||||
@staticmethod
|
||||
def forward(ctx, target_device, dim, *inputs):
|
||||
assert all(map(lambda i: i.is_cuda, inputs))
|
||||
assert all(map(lambda i: i.device.type != 'cpu', inputs)), (
|
||||
'Gather function not implemented for CPU tensors'
|
||||
)
|
||||
target_device = _get_device_index(target_device, True)
|
||||
ctx.target_device = target_device
|
||||
ctx.dim = dim
|
||||
|
|
@ -81,9 +84,9 @@ class Scatter(Function):
|
|||
def forward(ctx, target_gpus, chunk_sizes, dim, input):
|
||||
target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
|
||||
ctx.dim = dim
|
||||
ctx.input_device = input.get_device() if input.is_cuda else -1
|
||||
ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
|
||||
streams = None
|
||||
if ctx.input_device == -1:
|
||||
if torch.cuda.is_available() and ctx.input_device == -1:
|
||||
# Perform CPU to GPU copies in a background stream
|
||||
streams = [_get_stream(device) for device in target_gpus]
|
||||
outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
|
||||
|
|
|
|||
233
torch/nn/parallel/comm.py
Normal file
233
torch/nn/parallel/comm.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
import warnings
|
||||
import torch
|
||||
from torch.cuda import nccl
|
||||
from torch._utils import _take_tensors, _flatten_dense_tensors, \
|
||||
_unflatten_dense_tensors, _reorder_tensors_as, _get_device_index
|
||||
|
||||
|
||||
def broadcast(tensor, devices=None, *, out=None):
    r"""Copies a tensor to each of the given GPU devices.

    Arguments:
        tensor (Tensor): the tensor to broadcast. Can be on CPU or GPU.
        devices (Iterable[torch.device, str or int], optional): the GPU
            devices to broadcast to.
        out (Sequence[Tensor], optional, keyword-only): GPU tensors that
            receive the result.

    .. note::
        Exactly one of :attr:`devices` and :attr:`out` must be specified;
        otherwise ``RuntimeError`` is raised.

    Returns:
        - If :attr:`devices` is specified,
            a tuple of copies of :attr:`tensor`, placed on :attr:`devices`.
        - If :attr:`out` is specified,
            a tuple of the :attr:`out` tensors, each holding a copy of
            :attr:`tensor`.
    """
    have_devices = devices is not None
    have_out = out is not None
    # Both given or both missing is an error.
    if have_devices == have_out:
        raise RuntimeError(
            "Exactly one of 'devices' and 'out' must be specified, but got "
            "devices={} and out={}".format(devices, out))
    if have_out:
        return torch._C._broadcast_out(tensor, out)
    device_indices = [_get_device_index(d) for d in devices]
    return torch._C._broadcast(tensor, device_indices)
|
||||
|
||||
|
||||
def broadcast_coalesced(tensors, devices, buffer_size=10485760):
    """Copies a sequence of tensors to the given GPUs.

    Small tensors are first coalesced into a buffer to reduce the number
    of synchronizations.

    Arguments:
        tensors (sequence): the tensors to broadcast. Must all be on the same
            device, either CPU or GPU.
        devices (Iterable[torch.device, str or int]): the GPU devices to
            broadcast to.
        buffer_size (int): maximum size of the buffer used for coalescing

    Returns:
        A tuple containing copies of :attr:`tensor`, placed on :attr:`devices`.
    """
    device_indices = list(map(_get_device_index, devices))
    return torch._C._broadcast_coalesced(tensors, device_indices, buffer_size)
|
||||
|
||||
|
||||
def reduce_add(inputs, destination=None):
    """Computes the elementwise sum of tensors located on multiple GPUs.

    All inputs should have matching shapes, dtype, and layout. The output tensor
    will be of the same shape, dtype, and layout.

    Arguments:
        inputs (Iterable[Tensor]): the tensors to add. One of them must
            already be on the :attr:`destination` device.
        destination (int, optional): the device on which the output will be
            placed (default: current device).

    Returns:
        A tensor holding the elementwise sum of all inputs, placed on the
        :attr:`destination` device. When there is a single input, that input
        itself is returned.
    """
    destination = _get_device_index(destination, optional=True)
    expected_size = inputs[0].size()
    root_index = None  # position of the input that already lives on `destination`
    for position, tensor in enumerate(inputs):
        assert tensor.device.type != "cpu", "reduce_add expects all inputs to be on GPUs"
        if tensor.get_device() == destination:
            root_index = position
        if tensor.size() != expected_size:
            got = 'x'.join(map(str, tensor.size()))
            expected = 'x'.join(map(str, expected_size))
            raise ValueError("input {} has invalid size: got {}, but expected "
                             "{}".format(position, got, expected))
    if root_index is None:
        raise RuntimeError("reduce_add expects destination to be on the same GPU with one of the tensors")

    if len(inputs) == 1:
        return inputs[0]

    if nccl.is_available(inputs):
        result = torch.empty_like(inputs[root_index])
        nccl.reduce(inputs, output=result, root=root_index)
        return result

    root = inputs[root_index]
    destination_device = torch.device(root.device.type, destination)
    others = [t for i, t in enumerate(inputs) if i != root_index]
    # The first addition allocates the result, so no clone of the root is needed.
    result = root + others[0].to(device=destination_device, non_blocking=True)
    for other in others[1:]:
        result.add_(other.to(device=destination_device, non_blocking=True))
    return result
|
||||
|
||||
|
||||
def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760):
    """Sums tensors from multiple GPUs.

    Small tensors are first coalesced into a buffer to reduce the number
    of synchronizations.

    Tensors at the same position across the per-device iterables are summed
    together with :func:`reduce_add`. Positions where every tensor is sparse
    are reduced individually; all other positions are densified where needed,
    flattened into chunks of at most :attr:`buffer_size`, reduced per chunk
    and unflattened again.

    Arguments:
        inputs (Iterable[Iterable[Tensor]]): iterable of iterables that
            contain tensors from a single device.
        destination (int, optional): a device on which the output will be
            placed (default: current device).
        buffer_size (int): maximum size of the buffer used for coalescing

    Returns:
        A tuple of tensors containing an elementwise sum of each group of
        inputs, placed on the ``destination`` device.
    """
    # TODO: When `len(inputs) == 1` and all inputs are on `destination`, just
    # return `inputs`.
    dense_tensors = [[] for _ in inputs]  # shape (num_gpus, num_tensors)
    output = []
    # Reference tensors (taken from the first device) used at the end to
    # put `output` back into order via `_reorder_tensors_as`.
    ref_order = []
    # process sparse ones first since they may have different sizes on different gpus
    for tensor_at_gpus in zip(*inputs):
        if all(t.is_sparse for t in tensor_at_gpus):
            result = reduce_add(tensor_at_gpus, destination)  # this will be sparse too
            output.append(result)
            ref_order.append(tensor_at_gpus[0])
        else:
            # Mixed or fully dense position: collect a dense tensor per device.
            for coll, t in zip(dense_tensors, tensor_at_gpus):
                coll.append(t.to_dense() if t.is_sparse else t)
            ref_order.append(dense_tensors[0][-1])
    # One chunk iterator per device, each limited by `buffer_size`.
    itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors]
    # now the dense ones, which have consistent sizes
    for chunks in zip(*itrs):
        flat_tensors = [_flatten_dense_tensors(chunk) for chunk in chunks]  # (num_gpus,)
        flat_result = reduce_add(flat_tensors, destination)
        for t in _unflatten_dense_tensors(flat_result, chunks[0]):
            # The unflattened tensors do not share storage, and we don't expose
            # base flat tensor anyways, so give them different version counters.
            # See NOTE [ Version Counter in comm.*_coalesced ]
            output.append(t.data)
    return tuple(_reorder_tensors_as(output, ref_order))
|
||||
|
||||
|
||||
def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None):
    """Splits a tensor into chunks and distributes them across multiple GPUs.

    Arguments:
        tensor (Tensor): the tensor to scatter. Can be on CPU or GPU.
        devices (Iterable[torch.device, str or int], optional): the GPU
            devices to scatter to.
        chunk_sizes (Iterable[int], optional): sizes of chunks to be placed on
            each device. It should match :attr:`devices` in length and sums to
            ``tensor.size(dim)``. If not specified, :attr:`tensor` will be divided
            into equal chunks.
        dim (int, optional): the dimension along which :attr:`tensor` is
            chunked. Default: ``0``.
        streams (optional): passed through unchanged to the underlying
            scatter implementation.
        out (Sequence[Tensor], optional, keyword-only): the GPU tensors to
            store output results. Sizes of these tensors must match that of
            :attr:`tensor`, except for :attr:`dim`, where the total size must
            sum to ``tensor.size(dim)``.

    .. note::
        Exactly one of :attr:`devices` and :attr:`out` must be specified. When
        :attr:`out` is specified, :attr:`chunk_sizes` must not be specified and
        will be inferred from sizes of :attr:`out`.

    Returns:
        - If :attr:`devices` is specified,
            a tuple of chunks of :attr:`tensor`, placed on :attr:`devices`.
        - If :attr:`out` is specified,
            a tuple of the :attr:`out` tensors, each holding a chunk of
            :attr:`tensor`.
    """
    if out is not None:
        # Writing into caller-provided tensors: placement and chunk sizes
        # come from `out`, so neither may be passed explicitly.
        if devices is not None:
            raise RuntimeError(
                "'devices' must not be specified when 'out' is specified, but "
                "got devices={}".format(devices))
        if chunk_sizes is not None:
            raise RuntimeError(
                "'chunk_sizes' must not be specified when 'out' is specified, "
                "but got chunk_sizes={}".format(chunk_sizes))
        return tuple(torch._C._scatter_out(tensor, out, dim, streams))

    device_indices = [_get_device_index(d) for d in devices]
    return tuple(torch._C._scatter(tensor, device_indices, chunk_sizes, dim, streams))
|
||||
|
||||
def gather(tensors, dim=0, destination=None, *, out=None):
    r"""Concatenates tensors from multiple GPU devices into a single tensor.

    Arguments:
        tensors (Iterable[Tensor]): the tensors to gather.
          Tensor sizes in all dimensions other than :attr:`dim` have to match.
        dim (int, optional): the dimension along which the tensors are
          concatenated. Default: ``0``.
        destination (torch.device, str, or int, optional): the output device.
          Can be CPU or CUDA. Default: the current CUDA device.
        out (Tensor, optional, keyword-only): the tensor to store gather result.
          Its sizes must match those of :attr:`tensors`, except for :attr:`dim`,
          where the size must equal ``sum(tensor.size(dim) for tensor in tensors)``.
          Can be on CPU or CUDA.

    .. note::
        :attr:`destination` must not be specified when :attr:`out` is specified;
        doing so raises ``RuntimeError``.

    Returns:
        - If :attr:`destination` is specified,
            a tensor located on :attr:`destination` device, that is a result of
            concatenating :attr:`tensors` along :attr:`dim`.
        - If :attr:`out` is specified,
            the :attr:`out` tensor, now containing results of concatenating
            :attr:`tensors` along :attr:`dim`.
    """
    if out is not None:
        if destination is not None:
            raise RuntimeError(
                "'destination' must not be specified when 'out' is specified, but "
                "got destination={}".format(destination))
        return torch._C._gather_out(tensors, out, dim)

    if destination == -1:
        # Legacy spelling of "CPU": still accepted, but warn the caller.
        warnings.warn(
            'Using -1 to represent CPU tensor is deprecated. Please use a '
            'device object or string instead, e.g., "cpu".')
    target_index = _get_device_index(destination, allow_cpu=True, optional=True)
    return torch._C._gather(tensors, dim, target_index)
|
||||
|
|
@ -6,8 +6,12 @@ from ..modules import Module
|
|||
from .scatter_gather import scatter_kwargs, gather
|
||||
from .replicate import replicate
|
||||
from .parallel_apply import parallel_apply
|
||||
from torch.cuda._utils import _get_device_index
|
||||
|
||||
from torch._utils import (
|
||||
_get_all_device_indices,
|
||||
_get_available_device_type,
|
||||
_get_device_index,
|
||||
_get_devices_properties
|
||||
)
|
||||
|
||||
def _check_balance(device_ids):
|
||||
imbalance_warn = """
|
||||
|
|
@ -16,7 +20,7 @@ def _check_balance(device_ids):
|
|||
the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
|
||||
environment variable."""
|
||||
device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
|
||||
dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
|
||||
dev_props = _get_devices_properties(device_ids)
|
||||
|
||||
def warn_imbalance(get_prop):
|
||||
values = [get_prop(props) for props in dev_props]
|
||||
|
|
@ -117,13 +121,15 @@ class DataParallel(Module):
|
|||
def __init__(self, module, device_ids=None, output_device=None, dim=0):
|
||||
super(DataParallel, self).__init__()
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
device_type = _get_available_device_type()
|
||||
if device_type is None:
|
||||
self.module = module
|
||||
self.device_ids = []
|
||||
return
|
||||
|
||||
if device_ids is None:
|
||||
device_ids = list(range(torch.cuda.device_count()))
|
||||
device_ids = _get_all_device_indices()
|
||||
|
||||
if output_device is None:
|
||||
output_device = device_ids[0]
|
||||
|
||||
|
|
@ -131,12 +137,12 @@ class DataParallel(Module):
|
|||
self.module = module
|
||||
self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
|
||||
self.output_device = _get_device_index(output_device, True)
|
||||
self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))
|
||||
self.src_device_obj = torch.device(device_type, self.device_ids[0])
|
||||
|
||||
_check_balance(self.device_ids)
|
||||
|
||||
if len(self.device_ids) == 1:
|
||||
self.module.cuda(device_ids[0])
|
||||
self.module.to(self.src_device_obj)
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
if not self.device_ids:
|
||||
|
|
@ -186,15 +192,17 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo
|
|||
if not isinstance(inputs, tuple):
|
||||
inputs = (inputs,)
|
||||
|
||||
device_type = _get_available_device_type()
|
||||
|
||||
if device_ids is None:
|
||||
device_ids = list(range(torch.cuda.device_count()))
|
||||
device_ids = _get_all_device_indices()
|
||||
|
||||
if output_device is None:
|
||||
output_device = device_ids[0]
|
||||
|
||||
device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
|
||||
output_device = _get_device_index(output_device, True)
|
||||
src_device_obj = torch.device("cuda:{}".format(device_ids[0]))
|
||||
src_device_obj = torch.device(device_type, device_ids[0])
|
||||
|
||||
for t in chain(module.parameters(), module.buffers()):
|
||||
if t.device != src_device_obj:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import itertools
|
|||
|
||||
import torch
|
||||
|
||||
import torch.cuda.comm
|
||||
from . import comm
|
||||
import torch.distributed as dist
|
||||
|
||||
if dist.is_available():
|
||||
|
|
@ -14,7 +14,7 @@ from ..modules import Module
|
|||
from .replicate import replicate
|
||||
from .scatter_gather import scatter_kwargs, gather
|
||||
from .parallel_apply import parallel_apply
|
||||
from torch.cuda._utils import _get_device_index
|
||||
from torch._utils import _get_device_index, _get_all_device_indices
|
||||
|
||||
|
||||
def _find_tensors(obj):
|
||||
|
|
@ -268,21 +268,26 @@ class DistributedDataParallel(Module):
|
|||
)
|
||||
|
||||
self.is_multi_device_module = len({p.device for p in module.parameters()}) > 1
|
||||
self.is_cuda = all([p.device.type == 'cuda' for p in module.parameters()])
|
||||
distinct_device_types = {p.device.type for p in module.parameters()}
|
||||
assert len(distinct_device_types) == 1, (
|
||||
"DistributedDataParallel's input module must be on "
|
||||
"the same type of devices, but input module parameters locate in {}."
|
||||
).format(distinct_device_types)
|
||||
self.device_type = list(distinct_device_types)[0]
|
||||
|
||||
if not self.is_cuda or self.is_multi_device_module:
|
||||
if self.device_type == "cpu" or self.is_multi_device_module:
|
||||
assert not device_ids and not output_device, (
|
||||
"DistributedDataParallel device_ids and output_device arguments "
|
||||
"only work with single-device CUDA modules, but got "
|
||||
"only work with single-device GPU modules, but got "
|
||||
"device_ids {}, output_device {}, and module parameters {}."
|
||||
).format(device_ids, output_device, {p.device for p in module.parameters()})
|
||||
|
||||
self.device_ids = None
|
||||
self.output_device = None
|
||||
else:
|
||||
# Use all devices by default for single-device CUDA modules
|
||||
# Use all devices by default for single-device GPU modules
|
||||
if device_ids is None:
|
||||
device_ids = list(range(torch.cuda.device_count()))
|
||||
device_ids = _get_all_device_indices()
|
||||
|
||||
self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
|
||||
|
||||
|
|
@ -291,12 +296,6 @@ class DistributedDataParallel(Module):
|
|||
|
||||
self.output_device = _get_device_index(output_device, True)
|
||||
|
||||
if self.is_multi_device_module:
|
||||
assert self.is_cuda, (
|
||||
"DistributedDataParallel with multi-device module only works "
|
||||
"with CUDA devices, but module parameters locate in {}."
|
||||
).format({p.device for p in module.parameters()})
|
||||
|
||||
if process_group is None:
|
||||
self.process_group = _get_default_group()
|
||||
else:
|
||||
|
|
@ -551,7 +550,7 @@ class DistributedDataParallel(Module):
|
|||
# CUDA modules
|
||||
if self.device_ids and len(self.device_ids) > 1:
|
||||
# intra-node parameter sync
|
||||
result = torch.cuda.comm.broadcast_coalesced(
|
||||
result = comm.broadcast_coalesced(
|
||||
self.modules_params[0],
|
||||
self.device_ids,
|
||||
self.broadcast_bucket_size)
|
||||
|
|
@ -584,7 +583,7 @@ class DistributedDataParallel(Module):
|
|||
# CUDA modules
|
||||
if self.device_ids and len(self.device_ids) > 1:
|
||||
# intra-node buffer sync
|
||||
result = torch.cuda.comm.broadcast_coalesced(
|
||||
result = comm.broadcast_coalesced(
|
||||
self.modules_buffers[0],
|
||||
self.device_ids,
|
||||
self.broadcast_bucket_size)
|
||||
|
|
@ -597,6 +596,6 @@ class DistributedDataParallel(Module):
|
|||
for dev_idx, module in enumerate(module_copies):
|
||||
for layer in module.modules():
|
||||
if isinstance(layer, torch.nn.modules.SyncBatchNorm):
|
||||
assert self.is_cuda, "SyncBatchNorm layers only work with CUDA modules"
|
||||
assert self.device_type != 'cpu', "SyncBatchNorm layers only work with GPU modules"
|
||||
layer._specify_ddp_gpu_num(
|
||||
len(self.device_ids) if self.device_ids else 1)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import torch.cuda.comm as comm
|
||||
from torch.cuda._utils import _get_device_index
|
||||
from . import comm
|
||||
from torch._utils import _get_device_index
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user