import torch
from torch.autograd import Variable

from ._functions import Scatter, Gather


def scatter(inputs, target_gpus, dim=0):
    """
    Slices variables into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not variables. Does not
    support Tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, Variable):
            # Scatter.apply splits the Variable along `dim` and copies one
            # chunk to each device in target_gpus.
            return Scatter.apply(target_gpus, None, dim, obj)
        assert not torch.is_tensor(obj), "Tensors not supported in scatter."
        if isinstance(obj, tuple):
            # Scatter each element recursively, then regroup so the result
            # is one tuple of elements per target GPU.
            return tuple(zip(*map(scatter_map, obj)))
        if isinstance(obj, list):
            return tuple(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict):
            return tuple(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # Non-variable objects are replicated by reference, once per GPU.
        return tuple(obj for _ in target_gpus)

    return scatter_map(inputs)
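

# Illustrative usage (a hedged sketch, not part of the original file; it
# assumes two visible CUDA devices with indices 0 and 1):
#
#   x = Variable(torch.randn(4, 8).cuda(0))
#   chunks = scatter((x,), target_gpus=[0, 1])
#   # chunks == ((x_rows_0_1,), (x_rows_2_3,)): one tuple per device, with
#   # the Variable split into two chunks of two rows each along dim 0.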


def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    """Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, dim)
    if kwargs is None or len(kwargs) == 0:
        # No kwargs to scatter: pair every input chunk with an empty dict.
        kwargs = tuple({} for _ in inputs)
    else:
        # Trim in case kwargs produced more per-GPU copies than inputs did.
        kwargs = scatter(kwargs, target_gpus, dim)[:len(inputs)]
    return inputs, kwargs
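

# Illustrative usage (hedged sketch, same assumed two-device setup, where
# x and m are CUDA Variables):
#
#   inputs, kwargs = scatter_kwargs((x,), {'mask': m}, target_gpus=[0, 1])
#   # len(inputs) == len(kwargs) == 2; inputs[i] and kwargs[i] together form
#   # the argument set for the replica running on device target_gpus[i].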


def gather(outputs, target_device, dim=0):
    """
    Gathers variables from different GPUs on a specified device
    (-1 means the CPU).
    """
    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, Variable):
            # Gather.apply concatenates the per-GPU Variables along `dim`
            # on target_device.
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        # For containers, gather element-wise and rebuild the same type.
        return type(out)(map(gather_map, zip(*outputs)))

    return gather_map(outputs)
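

# Illustrative round trip (hedged sketch, same assumed two-device setup):
#
#   chunks = scatter((x,), target_gpus=[0, 1])
#   y = gather([c[0] for c in chunks], target_device=0)
#   # y is a single Variable on GPU 0, concatenated along dim 0 back to the
#   # shape of x.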