# pytorch/torch/distributed/collectives.py

import torch
from . import _INITIALIZED_PG, _INITIALIZED_MW


class reduce_op(object):
    """Reduction operations accepted by ``all_reduce`` and ``reduce``."""
    SUM = object()
    PRODUCT = object()
    MAX = object()
    MIN = object()


class group(object):
    """Predefined process groups; ``WORLD`` contains every initialized rank."""
    WORLD = object()


class _DistributedRequest(object):
    """Handle for an in-flight ``isend``/``irecv`` operation."""

    def __init__(self, request):
        self.request = request

    def is_completed(self):
        return torch._C._dist_request_is_completed(self.request)

    def wait(self):
        torch._C._dist_request_wait(self.request)


def get_rank():
    """Return the 0-based rank of the current process."""
    assert torch.distributed._initialized
    return torch._C._dist_get_rank()


def get_world_size():
    """Return the total number of processes in the job."""
    assert torch.distributed._initialized
    return torch._C._dist_get_num_processes()


def isend(tensor, dst):
    """Send ``tensor`` to rank ``dst`` asynchronously; returns a request handle."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return _DistributedRequest(torch._C._dist_isend(tensor, dst))


def irecv(tensor, src):
    """Receive into ``tensor`` from rank ``src`` asynchronously; returns a request handle."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return _DistributedRequest(torch._C._dist_irecv(tensor, src))


def send(tensor, dst):
    """Send ``tensor`` to rank ``dst``, blocking until it has been sent."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_send(tensor, dst)


def recv(tensor, src=None):
    """Receive into ``tensor``; if ``src`` is None, accept data from any sender."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    if src is None:
        return torch._C._dist_recv_any_source(tensor)
    return torch._C._dist_recv(tensor, src)
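

# A minimal point-to-point sketch (illustrative, not part of this module).
# It assumes the process group has already been initialized, e.g. via
# torch.distributed.init_process_group(), and that at least two ranks exist;
# the tensor shapes and rank numbers are made up for the example.
#
#   buf = torch.zeros(4)
#   if get_rank() == 0:
#       send(torch.ones(4), dst=1)      # blocking send to rank 1
#   else:
#       recv(buf, src=0)                # blocking receive from rank 0
#
#   # Non-blocking variant: overlap communication with computation.
#   req = irecv(buf, src=0)
#   # ... do other work here ...
#   req.wait()                          # block until the receive completes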


def broadcast(tensor, src, group=group.WORLD):
    """Copy ``tensor`` from rank ``src`` to every other rank in ``group``."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_broadcast(tensor, src, group)


def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
    """Reduce ``tensor`` across ``group`` in-place; every rank gets the result."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_all_reduce(tensor, op, group)


def reduce(tensor, dst, op=reduce_op.SUM, group=group.WORLD):
    """Reduce ``tensor`` across ``group``; only rank ``dst`` gets the result."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_reduce(tensor, dst, op, group)
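

# Sketch of the reduction calls above (illustrative values): after
# all_reduce, every rank holds the element-wise sum across ranks; after
# reduce, only the destination rank holds it.
#
#   t = torch.ones(2, 2)
#   all_reduce(t, op=reduce_op.SUM)     # t == world_size * ones on all ranks
#   reduce(t, dst=0)                    # summed result lands on rank 0 only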


def all_gather(tensor_list, tensor, group=group.WORLD):
    """Gather ``tensor`` from every rank in ``group`` into ``tensor_list`` on all ranks."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_all_gather(tensor_list, tensor, group)


def gather_send(tensor, dst, group=group.WORLD):
    """Non-root side of a gather: contribute ``tensor`` to rank ``dst``."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_gather_send(tensor, dst, group)


def gather_recv(tensor_list, tensor, group=group.WORLD):
    """Root side of a gather: collect one tensor per rank into ``tensor_list``."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_gather_recv(tensor_list, tensor, group)


def scatter_send(tensor_list, tensor, group=group.WORLD):
    """Root side of a scatter: distribute ``tensor_list``, receiving this rank's slice in ``tensor``."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_scatter_send(tensor_list, tensor, group)


def scatter_recv(tensor, src, group=group.WORLD):
    """Non-root side of a scatter: receive this rank's slice from rank ``src``."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_scatter_recv(tensor, src, group)


def barrier(group=group.WORLD):
    """Block until every rank in ``group`` reaches this call."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_barrier(group)


def new_group(ranks=None):
    """Create a new process group from ``ranks`` (defaults to all ranks)."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    if ranks is None:
        ranks = list(range(get_world_size()))
    return torch._C._dist_new_group(ranks)
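

# Sketch of subgroup collectives (assumes a world of at least two ranks;
# tensor ``t`` is illustrative): a new group restricts a collective to a
# subset of ranks, and barrier() re-synchronizes everyone afterwards.
#
#   g = new_group(ranks=[0, 1])
#   if get_rank() in (0, 1):
#       all_reduce(t, group=g)          # only ranks 0 and 1 participate
#   barrier()                           # all ranks wait here before moving on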