import torch

from . import _INITIALIZED_PG, _INITIALIZED_MW


class reduce_op(object):
    """Reduction operations accepted by reduce and all_reduce."""
    SUM = object()
    PRODUCT = object()
    MAX = object()
    MIN = object()


class group(object):
    """Predefined process groups; WORLD covers every process in the job."""
    WORLD = object()


class _DistributedRequest(object):
    """Handle for a pending non-blocking transfer returned by isend/irecv."""

    def __init__(self, request):
        self.request = request

    def is_completed(self):
        # Poll without blocking: True once the transfer has finished.
        return torch._C._dist_request_is_completed(self.request)

    def wait(self):
        # Block the calling process until the transfer has finished.
        torch._C._dist_request_wait(self.request)


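# Usage sketch (illustrative comment only, not executed at import): pairing
# isend/irecv, defined below, and waiting on the returned request handles.
# Assumes torch.distributed has been initialized in process-group mode and
# that at least two processes are running.
#
#   if get_rank() == 0:
#       req = isend(torch.ones(4), dst=1)
#   else:
#       buf = torch.zeros(4)
#       req = irecv(buf, src=0)
#   req.wait()  # block until the transfer has completed

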
def get_rank():
    """Return the rank (0-based index) of the current process."""
    assert torch.distributed._initialized
    return torch._C._dist_get_rank()


def get_world_size():
    """Return the number of processes participating in the job."""
    assert torch.distributed._initialized
    return torch._C._dist_get_num_processes()


def isend(tensor, dst):
    """Send ``tensor`` to rank ``dst`` without blocking; returns a request."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return _DistributedRequest(torch._C._dist_isend(tensor, dst))


def irecv(tensor, src):
    """Receive into ``tensor`` from rank ``src`` without blocking; returns a request."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return _DistributedRequest(torch._C._dist_irecv(tensor, src))


def send(tensor, dst):
    """Send ``tensor`` to rank ``dst``, blocking until the send completes."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_send(tensor, dst)


def recv(tensor, src=None):
    """Receive into ``tensor``, blocking; if ``src`` is None, accept any sender."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    if src is None:
        return torch._C._dist_recv_any_source(tensor)
    return torch._C._dist_recv(tensor, src)


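# Usage sketch (illustrative comment only): a blocking ping from rank 0 to
# rank 1. Passing src=None to recv accepts a message from any sender.
# Assumes an initialized process group with a world size of at least 2.
#
#   t = torch.zeros(2)
#   if get_rank() == 0:
#       send(t.fill_(42), dst=1)
#   else:
#       recv(t)  # no src given, so any sender is accepted

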
def broadcast(tensor, src, group=group.WORLD):
    """Copy ``tensor`` from rank ``src`` to every process in ``group``, in place."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_broadcast(tensor, src, group)


def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
    """Reduce ``tensor`` across ``group`` with ``op``; every rank gets the result."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_all_reduce(tensor, op, group)


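# Usage sketch (illustrative comment only): summing a tensor across all
# processes; after the call every rank holds the same result in place.
#
#   t = torch.ones(3) * get_rank()
#   all_reduce(t, op=reduce_op.SUM)  # t is now the sum over all ranks

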
def reduce(tensor, dst, op=reduce_op.SUM, group=group.WORLD):
    """Reduce ``tensor`` across ``group`` with ``op``; only rank ``dst`` gets the result."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_reduce(tensor, dst, op, group)


def all_gather(tensor_list, tensor, group=group.WORLD):
    """Gather ``tensor`` from every rank in ``group`` into ``tensor_list`` on all ranks."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_all_gather(tensor_list, tensor, group)


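# Usage sketch (illustrative comment only): collecting one tensor per rank on
# every process. tensor_list is preallocated with world_size tensors shaped
# like the input.
#
#   out = [torch.zeros(3) for _ in range(get_world_size())]
#   all_gather(out, torch.ones(3) * get_rank())

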
def gather_send(tensor, dst, group=group.WORLD):
    """Contribute ``tensor`` to a gather rooted at rank ``dst`` (non-root side)."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_gather_send(tensor, dst, group)


def gather_recv(tensor_list, tensor, group=group.WORLD):
    """Root side of a gather: collect one tensor per rank into ``tensor_list``."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_gather_recv(tensor_list, tensor, group)


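# Usage sketch (illustrative comment only): gathering onto rank 0. The root
# rank calls gather_recv with a preallocated list of world_size tensors,
# while every other rank calls gather_send toward the root.
#
#   t = torch.ones(3) * get_rank()
#   if get_rank() == 0:
#       out = [torch.zeros(3) for _ in range(get_world_size())]
#       gather_recv(out, t)
#   else:
#       gather_send(t, dst=0)

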
def scatter_send(tensor_list, tensor, group=group.WORLD):
    """Root side of a scatter: distribute one tensor from ``tensor_list`` to each rank."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_scatter_send(tensor_list, tensor, group)


def scatter_recv(tensor, src, group=group.WORLD):
    """Receive into ``tensor`` this rank's piece of a scatter rooted at ``src``."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_scatter_recv(tensor, src, group)


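# Usage sketch (illustrative comment only): scattering from rank 0, the
# mirror image of the gather pair above. The root supplies one tensor per
# rank; every other rank receives its piece via scatter_recv.
#
#   buf = torch.zeros(3)
#   if get_rank() == 0:
#       chunks = [torch.ones(3) * r for r in range(get_world_size())]
#       scatter_send(chunks, buf)
#   else:
#       scatter_recv(buf, src=0)

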
def barrier(group=group.WORLD):
    """Block until every process in ``group`` has reached this call."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    return torch._C._dist_barrier(group)


def new_group(ranks=None):
    """Create a new process group containing ``ranks`` (defaults to all ranks)."""
    assert torch.distributed._initialized == _INITIALIZED_PG, \
        "collective only supported in process-group mode"
    if ranks is None:
        ranks = list(range(get_world_size()))
    return torch._C._dist_new_group(ranks)


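# Usage sketch (illustrative comment only): restricting a collective to the
# even ranks. Assumes every process calls new_group with the same rank list.
#
#   evens = new_group(ranks=list(range(0, get_world_size(), 2)))
#   if get_rank() % 2 == 0:
#       t = torch.ones(1)
#       all_reduce(t, group=evens)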