Rename arguments to distributed collectives and rename get_num_processes to get_world_size

This commit is contained in:
Adam Paszke 2017-06-10 12:33:26 -04:00 committed by Soumith Chintala
parent 714351ff39
commit d9d50f80c7
5 changed files with 42 additions and 42 deletions

View File

@ -64,7 +64,7 @@ class Barrier(object):
data = f.read() data = f.read()
if int(data) >= cls.barrier_id: if int(data) >= cls.barrier_id:
arrived += 1 arrived += 1
if arrived == dist.get_num_processes(): if arrived == dist.get_world_size():
break break
if time.time() - start_time > timeout: if time.time() - start_time > timeout:
@ -87,7 +87,7 @@ class _DistTestBase(object):
return (group, group_id, rank) return (group, group_id, rank)
def _init_global_test(self): def _init_global_test(self):
group = [i for i in range(0, dist.get_num_processes())] group = [i for i in range(0, dist.get_world_size())]
group_id = dist.group.WORLD group_id = dist.group.WORLD
rank = dist.get_rank() rank = dist.get_rank()
return (group, group_id, rank) return (group, group_id, rank)
@ -96,7 +96,7 @@ class _DistTestBase(object):
def test_get_rank(self): def test_get_rank(self):
test_dir = os.path.join(TEMP_DIR, 'test_dir') test_dir = os.path.join(TEMP_DIR, 'test_dir')
pid = str(os.getpid()) pid = str(os.getpid())
num_processes = dist.get_num_processes() num_processes = dist.get_world_size()
with open(os.path.join(test_dir, pid), 'w') as f: with open(os.path.join(test_dir, pid), 'w') as f:
f.write(str(dist.get_rank())) f.write(str(dist.get_rank()))
@ -121,12 +121,12 @@ class _DistTestBase(object):
def test_send_recv(self): def test_send_recv(self):
rank = dist.get_rank() rank = dist.get_rank()
tensor = _build_tensor(rank + 1) tensor = _build_tensor(rank + 1)
for dest in range(0, dist.get_num_processes()): for dest in range(0, dist.get_world_size()):
if dest == rank: if dest == rank:
continue continue
dist.send(tensor, dest) dist.send(tensor, dest)
for src in range(0, dist.get_num_processes()): for src in range(0, dist.get_world_size()):
if src == rank: if src == rank:
continue continue
tensor = _build_tensor(src + 1, value=-1) tensor = _build_tensor(src + 1, value=-1)
@ -142,27 +142,27 @@ class _DistTestBase(object):
def test_send_recv_any_source(self): def test_send_recv_any_source(self):
rank = dist.get_rank() rank = dist.get_rank()
tensor = _build_tensor(10, rank) tensor = _build_tensor(10, rank)
for dest in range(0, dist.get_num_processes()): for dest in range(0, dist.get_world_size()):
if dest == rank: if dest == rank:
continue continue
dist.send(tensor, dest) dist.send(tensor, dest)
recv_ranks = set() recv_ranks = set()
for src in range(0, dist.get_num_processes()): for src in range(0, dist.get_world_size()):
if src == rank: if src == rank:
continue continue
tensor = _build_tensor(10, value=-1) tensor = _build_tensor(10, value=-1)
dist.recv(tensor) dist.recv(tensor)
recv_ranks.add(tensor.resize_(1)[0]) recv_ranks.add(tensor.resize_(1)[0])
self.assertEqual(len(recv_ranks), dist.get_num_processes() - 1) self.assertEqual(len(recv_ranks), dist.get_world_size() - 1)
self._barrier() self._barrier()
# ISEND # ISEND
@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support isend") @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support isend")
def test_isend(self): def test_isend(self):
rank = dist.get_rank() rank = dist.get_rank()
world_size = dist.get_num_processes() world_size = dist.get_world_size()
if rank == 0: if rank == 0:
requests = [ requests = [
@ -182,7 +182,7 @@ class _DistTestBase(object):
@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support irecv") @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support irecv")
def test_irecv(self): def test_irecv(self):
rank = dist.get_rank() rank = dist.get_rank()
world_size = dist.get_num_processes() world_size = dist.get_world_size()
if rank == 0: if rank == 0:
expected_tensors = [_build_tensor(src, -1) for src in range(1, world_size)] expected_tensors = [_build_tensor(src, -1) for src in range(1, world_size)]

View File

@ -29,41 +29,41 @@ def get_rank():
return torch._C._dist_get_rank() return torch._C._dist_get_rank()
def get_num_processes(): def get_world_size():
assert torch.distributed._initialized assert torch.distributed._initialized
return torch._C._dist_get_num_processes() return torch._C._dist_get_num_processes()
def isend(tensor, dst_rank): def isend(tensor, dst):
assert torch.distributed._initialized == _INITIALIZED_PG, \ assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode" "collective only supported in process-group mode"
return _DistributedRequest(torch._C._dist_isend(tensor, dst_rank)) return _DistributedRequest(torch._C._dist_isend(tensor, dst))
def irecv(tensor, src_rank): def irecv(tensor, src):
assert torch.distributed._initialized == _INITIALIZED_PG, \ assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode" "collective only supported in process-group mode"
return _DistributedRequest(torch._C._dist_irecv(tensor, src_rank)) return _DistributedRequest(torch._C._dist_irecv(tensor, src))
def send(tensor, dst_rank): def send(tensor, dst):
assert torch.distributed._initialized == _INITIALIZED_PG, \ assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode" "collective only supported in process-group mode"
return torch._C._dist_send(tensor, dst_rank) return torch._C._dist_send(tensor, dst)
def recv(tensor, src_rank=None): def recv(tensor, src=None):
assert torch.distributed._initialized == _INITIALIZED_PG, \ assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode" "collective only supported in process-group mode"
if src_rank is None: if src is None:
return torch._C._dist_recv_any_source(tensor) return torch._C._dist_recv_any_source(tensor)
return torch._C._dist_recv(tensor, src_rank) return torch._C._dist_recv(tensor, src)
def broadcast(tensor, src_rank, group=group.WORLD): def broadcast(tensor, src, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \ assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode" "collective only supported in process-group mode"
return torch._C._dist_broadcast(tensor, src_rank, group) return torch._C._dist_broadcast(tensor, src, group)
def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD): def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
@ -72,40 +72,40 @@ def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
return torch._C._dist_all_reduce(tensor, op, group) return torch._C._dist_all_reduce(tensor, op, group)
def reduce(tensor, dst_rank, op=reduce_op.SUM, group=group.WORLD): def reduce(tensor, dst, op=reduce_op.SUM, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \ assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode" "collective only supported in process-group mode"
return torch._C._dist_reduce(tensor, dst_rank, op, group) return torch._C._dist_reduce(tensor, dst, op, group)
def all_gather(tensors, tensor, group=group.WORLD): def all_gather(tensor_list, tensor, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \ assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode" "collective only supported in process-group mode"
return torch._C._dist_all_gather(tensors, tensor, group) return torch._C._dist_all_gather(tensor_list, tensor, group)
def gather_send(tensor, dst_rank, group=group.WORLD): def gather_send(tensor, dst, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \ assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode" "collective only supported in process-group mode"
return torch._C._dist_gather_send(tensor, dst_rank, group) return torch._C._dist_gather_send(tensor, dst, group)
def gather_recv(tensors, tensor, group=group.WORLD): def gather_recv(tensor_list, tensor, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \ assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode" "collective only supported in process-group mode"
return torch._C._dist_gather_recv(tensors, tensor, group) return torch._C._dist_gather_recv(tensor_list, tensor, group)
def scatter_send(tensors, tensor, group=group.WORLD): def scatter_send(tensor_list, tensor, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \ assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode" "collective only supported in process-group mode"
return torch._C._dist_scatter_send(tensors, tensor, group) return torch._C._dist_scatter_send(tensor_list, tensor, group)
def scatter_recv(tensor, src_rank, group=group.WORLD): def scatter_recv(tensor, src, group=group.WORLD):
assert torch.distributed._initialized == _INITIALIZED_PG, \ assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode" "collective only supported in process-group mode"
return torch._C._dist_scatter_recv(tensor, src_rank, group) return torch._C._dist_scatter_recv(tensor, src, group)
def barrier(group=group.WORLD): def barrier(group=group.WORLD):
@ -118,5 +118,5 @@ def new_group(ranks=None):
assert torch.distributed._initialized == _INITIALIZED_PG, \ assert torch.distributed._initialized == _INITIALIZED_PG, \
"collective only supported in process-group mode" "collective only supported in process-group mode"
if ranks is None: if ranks is None:
ranks = list(range(get_num_processes())) ranks = list(range(get_world_size()))
return torch._C._dist_new_group(ranks) return torch._C._dist_new_group(ranks)

View File

@ -134,7 +134,7 @@ if rank == 0:
print_header("scatter") print_header("scatter")
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]: for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
tensor = torch.ByteTensor(bytes).fill_(42) tensor = torch.ByteTensor(bytes).fill_(42)
tensors = [tensor for n in range(0, dist.get_num_processes())] tensors = [tensor for n in range(0, dist.get_world_size())]
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]: for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
start = timer() start = timer()
for i in range(0, num_tensors): for i in range(0, num_tensors):
@ -154,7 +154,7 @@ if rank == 0:
print_header("gather") print_header("gather")
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]: for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
tensor = torch.ByteTensor(bytes).fill_(42) tensor = torch.ByteTensor(bytes).fill_(42)
tensors = [tensor for n in range(0, dist.get_num_processes())] tensors = [tensor for n in range(0, dist.get_world_size())]
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]: for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
start = timer() start = timer()
for i in range(0, num_tensors): for i in range(0, num_tensors):
@ -174,7 +174,7 @@ if rank == 0:
print_header("all gather") print_header("all gather")
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]: for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
tensor = torch.ByteTensor(bytes).fill_(42) tensor = torch.ByteTensor(bytes).fill_(42)
tensors = [tensor for n in range(0, dist.get_num_processes())] tensors = [tensor for n in range(0, dist.get_world_size())]
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]: for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
start = timer() start = timer()
for i in range(0, num_tensors): for i in range(0, num_tensors):
@ -185,7 +185,7 @@ if rank == 0:
else: else:
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]: for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
tensor = torch.ByteTensor(bytes).fill_(42) tensor = torch.ByteTensor(bytes).fill_(42)
tensors = [tensor for n in range(0, dist.get_num_processes())] tensors = [tensor for n in range(0, dist.get_world_size())]
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]: for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
for i in range(0, num_tensors): for i in range(0, num_tensors):
dist.all_gather(tensors, tensor) dist.all_gather(tensors, tensor)

View File

@ -305,7 +305,7 @@ class DistributedDataParallel(Module):
reduce_stream = reduction_streams[0] reduce_stream = reduction_streams[0]
with torch.cuda.stream(reduce_stream): with torch.cuda.stream(reduce_stream):
reduce_stream.wait_stream(nccl_streams[0]) reduce_stream.wait_stream(nccl_streams[0])
coalesced /= dist.get_num_processes() coalesced /= dist.get_world_size()
dist.all_reduce(coalesced, group=group_id) dist.all_reduce(coalesced, group=group_id)
for grad, reduced in zip(grad_batch, _unflatten_tensors(coalesced, grad_batch)): for grad, reduced in zip(grad_batch, _unflatten_tensors(coalesced, grad_batch)):
grad.copy_(reduced) grad.copy_(reduced)

View File

@ -1,7 +1,7 @@
import math import math
import torch import torch
from .sampler import Sampler from .sampler import Sampler
from torch.distributed.collectives import get_num_processes, get_rank from torch.distributed.collectives import get_world_size, get_rank
class DistributedSampler(Sampler): class DistributedSampler(Sampler):
@ -24,7 +24,7 @@ class DistributedSampler(Sampler):
def __init__(self, dataset, num_replicas=None, rank=None): def __init__(self, dataset, num_replicas=None, rank=None):
if num_replicas is None: if num_replicas is None:
num_replicas = get_num_processes() num_replicas = get_world_size()
if rank is None: if rank is None:
rank = get_rank() rank = get_rank()
self.dataset = dataset self.dataset = dataset