pytorch/torch/nn/parallel/data_parallel.py
Sam Gross ea728e7c5e Add DataParallel container (#268)
Adds a container version of the `data_parallel` function. This is a
drop-in replacement for the DataParallel class in the ImageNet example.
2016-11-29 16:36:01 -05:00


import torch
from torch.nn.modules import Container
from .scatter_gather import scatter, gather
from .replicate import replicate
from .parallel_apply import parallel_apply


class DataParallel(Container):
"""Implements data parallelism at the module level.
This container parallelizes the application of the given module by
splitting the input across the specified devices. In the forward pass, the
module is replicated on each device, and each replica handles a portion of
the input. During the backwards pass, gradients from each replica are
summed into the original module.
Args:
module: module to be parallelized
device_ids: CUDA devices (default: all devices)
output_device: device location of output (default: device_ids[0])
Example:
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input)
"""

    def __init__(self, module, device_ids=None, output_device=None):
        super(DataParallel, self).__init__()
        if device_ids is None:
            # Default to every visible CUDA device.
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.module = module
        self.device_ids = device_ids
        self.output_device = output_device
        if len(self.device_ids) == 1:
            # With a single device there is nothing to replicate; just move
            # the module onto that device.
            self.module.cuda(device_ids[0])

    def forward(self, input):
        if len(self.device_ids) == 1:
            # Single-device fast path: no scatter/gather needed.
            return self.module(input.cuda(self.device_ids[0]))
        # Copy the module onto every device, split the input across the
        # devices, run the replicas in parallel, then gather the outputs
        # back onto the output device. Scatter may produce fewer chunks
        # than devices for small inputs, so drop the unused replicas.
        replicas = self.replicate(self.module, self.device_ids)
        inputs = self.scatter(input, self.device_ids)
        replicas = replicas[:len(inputs)]
        outputs = self.parallel_apply(replicas, inputs)
        return self.gather(outputs, self.output_device)

    def replicate(self, module, device_ids):
        return replicate(module, device_ids)

    def scatter(self, input, device_ids):
        return scatter(input, device_ids)

    def parallel_apply(self, replicas, inputs):
        return parallel_apply(replicas, inputs)

    def gather(self, outputs, output_device):
        return gather(outputs, output_device)


def data_parallel(module, input, device_ids, output_device=None):
    """Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module: the module to evaluate in parallel
        input: input to the module
        device_ids: GPU ids on which to replicate module
        output_device: GPU location of the output. Use -1 to indicate the CPU.
            (default: device_ids[0])

    Returns:
        a Variable containing the result of module(input), located on
        output_device
    """
    if not device_ids:
        return module(input)

    if output_device is None:
        output_device = device_ids[0]

    replicas = replicate(module, device_ids)
    inputs = scatter(input, device_ids)
    replicas = replicas[:len(inputs)]
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)
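

# ---------------------------------------------------------------------------
# Minimal usage sketch for both forms above. It assumes a CUDA build of torch
# with at least two visible GPUs and the Variable-based autograd API of this
# era; the Linear layer and tensor sizes are arbitrary illustrations.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn
    from torch.autograd import Variable

    if torch.cuda.device_count() >= 2:
        model = nn.Linear(10, 5).cuda()

        # Container form: the batch of 8 is scattered across the two devices,
        # the module is replicated onto each of them, and the per-device
        # outputs are gathered back onto device_ids[0].
        net = DataParallel(model, device_ids=[0, 1])
        output = net(Variable(torch.randn(8, 10).cuda()))
        print(output.size())

        # Functional form: a one-shot parallel evaluation of module(input)
        # without keeping a container around.
        output = data_parallel(model, Variable(torch.randn(8, 10).cuda()), [0, 1])
        print(output.size())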