mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
* Codemod to update our codebase to 0.4 standard * Update some of the test scri[ts * remove Variable in test_clip_grad_value * fix _symbolic_override_wrapper_maker
71 lines
2.6 KiB
Python
71 lines
2.6 KiB
Python
import torch
|
|
from ._functions import Scatter, Gather
|
|
|
|
|
|
def scatter(inputs, target_gpus, dim=0):
|
|
r"""
|
|
Slices tensors into approximately equal chunks and
|
|
distributes them across given GPUs. Duplicates
|
|
references to objects that are not tensors. Does not
|
|
support Tensors.
|
|
"""
|
|
def scatter_map(obj):
|
|
if isinstance(obj, torch.Tensor):
|
|
return Scatter.apply(target_gpus, None, dim, obj)
|
|
if isinstance(obj, tuple) and len(obj) > 0:
|
|
return list(zip(*map(scatter_map, obj)))
|
|
if isinstance(obj, list) and len(obj) > 0:
|
|
return list(map(list, zip(*map(scatter_map, obj))))
|
|
if isinstance(obj, dict) and len(obj) > 0:
|
|
return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
|
|
return [obj for targets in target_gpus]
|
|
|
|
# After scatter_map is called, a scatter_map cell will exist. This cell
|
|
# has a reference to the actual function scatter_map, which has references
|
|
# to a closure that has a reference to the scatter_map cell (because the
|
|
# fn is recursive). To avoid this reference cycle, we set the function to
|
|
# None, clearing the cell
|
|
try:
|
|
return scatter_map(inputs)
|
|
finally:
|
|
scatter_map = None
|
|
|
|
|
|
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
|
|
r"""Scatter with support for kwargs dictionary"""
|
|
inputs = scatter(inputs, target_gpus, dim) if inputs else []
|
|
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
|
|
if len(inputs) < len(kwargs):
|
|
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
|
|
elif len(kwargs) < len(inputs):
|
|
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
|
|
inputs = tuple(inputs)
|
|
kwargs = tuple(kwargs)
|
|
return inputs, kwargs
|
|
|
|
|
|
def gather(outputs, target_device, dim=0):
|
|
r"""
|
|
Gathers tensors from different GPUs on a specified device
|
|
(-1 means the CPU).
|
|
"""
|
|
def gather_map(outputs):
|
|
out = outputs[0]
|
|
if isinstance(out, torch.Tensor):
|
|
return Gather.apply(target_device, dim, *outputs)
|
|
if out is None:
|
|
return None
|
|
if isinstance(out, dict):
|
|
if not all((len(out) == len(d) for d in outputs)):
|
|
raise ValueError('All dicts must have the same number of keys')
|
|
return type(out)(((k, gather_map([d[k] for d in outputs]))
|
|
for k in out))
|
|
return type(out)(map(gather_map, zip(*outputs)))
|
|
|
|
# Recursive function calls like this create reference cycles.
|
|
# Setting the function to None clears the refcycle.
|
|
try:
|
|
return gather_map(outputs)
|
|
finally:
|
|
gather_map = None
|