Rename arguments to distributed collectives

This commit is contained in:
Adam Paszke 2017-06-10 12:33:26 -04:00 committed by Soumith Chintala
parent 714351ff39
commit d9d50f80c7
5 changed files with 42 additions and 42 deletions

View File

@ -64,7 +64,7 @@ class Barrier(object):
data = f.read()
if int(data) >= cls.barrier_id:
arrived += 1
if arrived == dist.get_num_processes():
if arrived == dist.get_world_size():
break
if time.time() - start_time > timeout:
@ -87,7 +87,7 @@ class _DistTestBase(object):
return (group, group_id, rank)
def _init_global_test(self):
group = [i for i in range(0, dist.get_num_processes())]
group = [i for i in range(0, dist.get_world_size())]
group_id = dist.group.WORLD
rank = dist.get_rank()
return (group, group_id, rank)
@ -96,7 +96,7 @@ class _DistTestBase(object):
def test_get_rank(self):
test_dir = os.path.join(TEMP_DIR, 'test_dir')
pid = str(os.getpid())
num_processes = dist.get_num_processes()
num_processes = dist.get_world_size()
with open(os.path.join(test_dir, pid), 'w') as f:
f.write(str(dist.get_rank()))
@ -121,12 +121,12 @@ class _DistTestBase(object):
def test_send_recv(self):
rank = dist.get_rank()
tensor = _build_tensor(rank + 1)
for dest in range(0, dist.get_num_processes()):
for dest in range(0, dist.get_world_size()):
if dest == rank:
continue
dist.send(tensor, dest)
for src in range(0, dist.get_num_processes()):
for src in range(0, dist.get_world_size()):
if src == rank:
continue
tensor = _build_tensor(src + 1, value=-1)
@ -142,27 +142,27 @@ class _DistTestBase(object):
def test_send_recv_any_source(self):
rank = dist.get_rank()
tensor = _build_tensor(10, rank)
for dest in range(0, dist.get_num_processes()):
for dest in range(0, dist.get_world_size()):
if dest == rank:
continue
dist.send(tensor, dest)
recv_ranks = set()
for src in range(0, dist.get_num_processes()):
for src in range(0, dist.get_world_size()):
if src == rank:
continue
tensor = _build_tensor(10, value=-1)
dist.recv(tensor)
recv_ranks.add(tensor.resize_(1)[0])
self.assertEqual(len(recv_ranks), dist.get_num_processes() - 1)
self.assertEqual(len(recv_ranks), dist.get_world_size() - 1)
self._barrier()
# ISEND
@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support isend")
def test_isend(self):
rank = dist.get_rank()
world_size = dist.get_num_processes()
world_size = dist.get_world_size()
if rank == 0:
requests = [
@ -182,7 +182,7 @@ class _DistTestBase(object):
@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support irecv")
def test_irecv(self):
rank = dist.get_rank()
world_size = dist.get_num_processes()
world_size = dist.get_world_size()
if rank == 0:
expected_tensors = [_build_tensor(src, -1) for src in range(1, world_size)]

View File

@ -29,41 +29,41 @@ def get_rank():
return torch._C._dist_get_rank()
def get_num_processes():
def get_world_size():
assert torch.distributed._initialized
return torch._C._dist_get_num_processes()
def isend(tensor, dst_rank):
def isend(tensor, dst):
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
return _DistributedRequest(torch._C._dist_isend(tensor, dst_rank))
return _DistributedRequest(torch._C._dist_isend(tensor, dst))
def irecv(tensor, src_rank):
def irecv(tensor, src):
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
return _DistributedRequest(torch._C._dist_irecv(tensor, src_rank))
return _DistributedRequest(torch._C._dist_irecv(tensor, src))
def send(tensor, dst_rank):
def send(tensor, dst):
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
return torch._C._dist_send(tensor, dst_rank)
return torch._C._dist_send(tensor, dst)
def recv(tensor, src_rank=None):
def recv(tensor, src=None):
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
if src_rank is None:
if src is None:
return torch._C._dist_recv_any_source(tensor)
return torch._C._dist_recv(tensor, src_rank)
return torch._C._dist_recv(tensor, src)
def broadcast(tensor, src_rank, group=group.WORLD):
def broadcast(tensor, src, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
return torch._C._dist_broadcast(tensor, src_rank, group)
return torch._C._dist_broadcast(tensor, src, group)
def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
@ -72,40 +72,40 @@ def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
return torch._C._dist_all_reduce(tensor, op, group)
def reduce(tensor, dst_rank, op=reduce_op.SUM, group=group.WORLD):
def reduce(tensor, dst, op=reduce_op.SUM, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
return torch._C._dist_reduce(tensor, dst_rank, op, group)
return torch._C._dist_reduce(tensor, dst, op, group)
def all_gather(tensors, tensor, group=group.WORLD):
def all_gather(tensor_list, tensor, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
return torch._C._dist_all_gather(tensors, tensor, group)
return torch._C._dist_all_gather(tensor_list, tensor, group)
def gather_send(tensor, dst_rank, group=group.WORLD):
def gather_send(tensor, dst, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
return torch._C._dist_gather_send(tensor, dst_rank, group)
return torch._C._dist_gather_send(tensor, dst, group)
def gather_recv(tensors, tensor, group=group.WORLD):
def gather_recv(tensor_list, tensor, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
return torch._C._dist_gather_recv(tensors, tensor, group)
return torch._C._dist_gather_recv(tensor_list, tensor, group)
def scatter_send(tensors, tensor, group=group.WORLD):
def scatter_send(tensor_list, tensor, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
return torch._C._dist_scatter_send(tensors, tensor, group)
return torch._C._dist_scatter_send(tensor_list, tensor, group)
def scatter_recv(tensor, src_rank, group=group.WORLD):
def scatter_recv(tensor, src, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
return torch._C._dist_scatter_recv(tensor, src_rank, group)
return torch._C._dist_scatter_recv(tensor, src, group)
def barrier(group=group.WORLD):
@ -118,5 +118,5 @@ def new_group(ranks=None):
assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode"
if ranks is None:
ranks = list(range(get_num_processes()))
ranks = list(range(get_world_size()))
return torch._C._dist_new_group(ranks)

View File

@ -134,7 +134,7 @@ if rank == 0:
print_header("scatter")
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
tensor = torch.ByteTensor(bytes).fill_(42)
tensors = [tensor for n in range(0, dist.get_num_processes())]
tensors = [tensor for n in range(0, dist.get_world_size())]
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
start = timer()
for i in range(0, num_tensors):
@ -154,7 +154,7 @@ if rank == 0:
print_header("gather")
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
tensor = torch.ByteTensor(bytes).fill_(42)
tensors = [tensor for n in range(0, dist.get_num_processes())]
tensors = [tensor for n in range(0, dist.get_world_size())]
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
start = timer()
for i in range(0, num_tensors):
@ -174,7 +174,7 @@ if rank == 0:
print_header("all gather")
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
tensor = torch.ByteTensor(bytes).fill_(42)
tensors = [tensor for n in range(0, dist.get_num_processes())]
tensors = [tensor for n in range(0, dist.get_world_size())]
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
start = timer()
for i in range(0, num_tensors):
@ -185,7 +185,7 @@ if rank == 0:
else:
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
tensor = torch.ByteTensor(bytes).fill_(42)
tensors = [tensor for n in range(0, dist.get_num_processes())]
tensors = [tensor for n in range(0, dist.get_world_size())]
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
for i in range(0, num_tensors):
dist.all_gather(tensors, tensor)

View File

@ -305,7 +305,7 @@ class DistributedDataParallel(Module):
reduce_stream = reduction_streams[0]
with torch.cuda.stream(reduce_stream):
reduce_stream.wait_stream(nccl_streams[0])
coalesced /= dist.get_num_processes()
coalesced /= dist.get_world_size()
dist.all_reduce(coalesced, group=group_id)
for grad, reduced in zip(grad_batch, _unflatten_tensors(coalesced, grad_batch)):
grad.copy_(reduced)

View File

@ -1,7 +1,7 @@
import math
import torch
from .sampler import Sampler
from torch.distributed.collectives import get_num_processes, get_rank
from torch.distributed.collectives import get_world_size, get_rank
class DistributedSampler(Sampler):
@ -24,7 +24,7 @@ class DistributedSampler(Sampler):
def __init__(self, dataset, num_replicas=None, rank=None):
if num_replicas is None:
num_replicas = get_num_processes()
num_replicas = get_world_size()
if rank is None:
rank = get_rank()
self.dataset = dataset