import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed.deprecated as dist
from torch.nn.modules import Module
from collections import defaultdict
from torch.autograd import Variable
import torch.utils.hooks


class DistributedDataParallelCPU(Module):
    r"""Implements distributed data parallelism for CPU at the module level.

    This module supports the ``mpi``, ``gloo``, and ``tcp`` backends.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the batch
    dimension. The module is replicated on each machine, and each such replica
    handles a portion of the input. During the backwards pass, gradients from
    each node are averaged.

    This module could be used in conjunction with DistributedSampler
    (see :class:`torch.utils.data.distributed.DistributedSampler`),
    which will load a subset of the original dataset for each node with the
    same batch size. So strong scaling should be configured like this
    (a minimal sketch of the sampler setup follows the list)::

        n = 1, batch size = 128
        n = 2, batch size = 64
        n = 4, batch size = 32
        n = 8, batch size = 16
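
    For illustration only, a minimal sketch of that pairing; ``dataset`` and
    the batch size here are assumptions, not part of this module::

        >>> sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        >>> loader = torch.utils.data.DataLoader(dataset, batch_size=32,
        ...                                      sampler=sampler)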

    Creation of this class requires the distributed package to be already
    initialized in the process group mode
    (see :func:`torch.distributed.deprecated.init_process_group`).

    .. warning::
        Constructor, forward method, and differentiation of the output (or a
        function of the output of this module) is a distributed synchronization
        point. Take that into account in case different nodes might be
        executing different code.

    .. warning::
        This module assumes all parameters are registered in the model by the
        time it is created. No parameters should be added or removed later.

    .. warning::
        This module assumes all gradients are dense.

    .. warning::
        This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
        only work if gradients are to be accumulated in ``.grad`` attributes of
        parameters).

    .. note::
        Parameters are broadcast between nodes in the __init__() function. The
        module performs an all-reduce step on gradients and assumes that they
        will be modified by the optimizer in all nodes in the same way.

    .. warning::
        Forward and backward hooks defined on :attr:`module` and its submodules
        won't be invoked anymore, unless the hooks are initialized in the
        :meth:`forward` method.

    Args:
        module: module to be parallelized

    Example::

        >>> torch.distributed.deprecated.init_process_group(world_size=4, init_method='...')
        >>> net = torch.nn.parallel.deprecated.DistributedDataParallelCPU(model)
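
    An illustrative training step with the wrapped module; ``optimizer``,
    ``loss_fn``, ``data``, and ``target`` are assumed to exist and are not
    part of this module::

        >>> output = net(data)
        >>> loss_fn(output, target).backward()  # hooks all-reduce the gradients
        >>> optimizer.step()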
    """

    def __init__(self, module):
        super(DistributedDataParallelCPU, self).__init__()
        self.module = module
        self.sync_parameters()

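        # Queued as an autograd callback by the per-parameter hooks below;
        # the needs_reduction flag ensures the reduction runs at most once
        # per backward pass.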
        def allreduce_params():
            if self.needs_reduction:
                self.needs_reduction = False
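                # Group gradients by tensor type so each group can be
                # flattened into one contiguous buffer for a single all-reduce.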
                buckets = defaultdict(list)
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        buckets[tp].append(param)

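                # Flatten each bucket into one tensor, all-reduce it in a
                # single call, average by world size, and copy the averaged
                # values back into each parameter's .grad.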
                for bucket in buckets.values():
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)

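        # Register a backward hook on every parameter that requires gradients;
        # each hook queues allreduce_params to run when backward completes.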
        for param in list(self.module.parameters()):
            @torch.utils.hooks.unserializable_hook
            def allreduce_hook(*unused):
                Variable._execution_engine.queue_callback(allreduce_params)

            if param.requires_grad:
                param.register_hook(allreduce_hook)

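    # Broadcast rank 0's parameter values so every node starts from identical
    # weights.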
    def sync_parameters(self):
        for param in self.module.parameters():
            dist.broadcast(param.data, 0)

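    # Setting needs_reduction here arms the backward hooks: the next backward
    # pass performs exactly one gradient all-reduce.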
    def forward(self, *inputs, **kwargs):
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)