From d9d50f80c7f162c9c6c6d95c0b51f9e6fd8fccee Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Sat, 10 Jun 2017 12:33:26 -0400 Subject: [PATCH] Rename arguments to distributed collectives --- test/test_distributed.py | 20 +++++------ torch/distributed/collectives.py | 50 ++++++++++++++-------------- torch/lib/THD/benchmark/benchmark.py | 8 ++--- torch/nn/parallel/distributed.py | 2 +- torch/utils/data/distributed.py | 4 +-- 5 files changed, 42 insertions(+), 42 deletions(-) diff --git a/test/test_distributed.py b/test/test_distributed.py index 9106888582d..5cf3de6b4b1 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -64,7 +64,7 @@ class Barrier(object): data = f.read() if int(data) >= cls.barrier_id: arrived += 1 - if arrived == dist.get_num_processes(): + if arrived == dist.get_world_size(): break if time.time() - start_time > timeout: @@ -87,7 +87,7 @@ class _DistTestBase(object): return (group, group_id, rank) def _init_global_test(self): - group = [i for i in range(0, dist.get_num_processes())] + group = [i for i in range(0, dist.get_world_size())] group_id = dist.group.WORLD rank = dist.get_rank() return (group, group_id, rank) @@ -96,7 +96,7 @@ class _DistTestBase(object): def test_get_rank(self): test_dir = os.path.join(TEMP_DIR, 'test_dir') pid = str(os.getpid()) - num_processes = dist.get_num_processes() + num_processes = dist.get_world_size() with open(os.path.join(test_dir, pid), 'w') as f: f.write(str(dist.get_rank())) @@ -121,12 +121,12 @@ class _DistTestBase(object): def test_send_recv(self): rank = dist.get_rank() tensor = _build_tensor(rank + 1) - for dest in range(0, dist.get_num_processes()): + for dest in range(0, dist.get_world_size()): if dest == rank: continue dist.send(tensor, dest) - for src in range(0, dist.get_num_processes()): + for src in range(0, dist.get_world_size()): if src == rank: continue tensor = _build_tensor(src + 1, value=-1) @@ -142,27 +142,27 @@ class _DistTestBase(object): def test_send_recv_any_source(self): rank = dist.get_rank() tensor = _build_tensor(10, rank) - for dest in range(0, dist.get_num_processes()): + for dest in range(0, dist.get_world_size()): if dest == rank: continue dist.send(tensor, dest) recv_ranks = set() - for src in range(0, dist.get_num_processes()): + for src in range(0, dist.get_world_size()): if src == rank: continue tensor = _build_tensor(10, value=-1) dist.recv(tensor) recv_ranks.add(tensor.resize_(1)[0]) - self.assertEqual(len(recv_ranks), dist.get_num_processes() - 1) + self.assertEqual(len(recv_ranks), dist.get_world_size() - 1) self._barrier() # ISEND @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support isend") def test_isend(self): rank = dist.get_rank() - world_size = dist.get_num_processes() + world_size = dist.get_world_size() if rank == 0: requests = [ @@ -182,7 +182,7 @@ class _DistTestBase(object): @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support irecv") def test_irecv(self): rank = dist.get_rank() - world_size = dist.get_num_processes() + world_size = dist.get_world_size() if rank == 0: expected_tensors = [_build_tensor(src, -1) for src in range(1, world_size)] diff --git a/torch/distributed/collectives.py b/torch/distributed/collectives.py index 1ea6ccf6169..49c3714241a 100644 --- a/torch/distributed/collectives.py +++ b/torch/distributed/collectives.py @@ -29,41 +29,41 @@ def get_rank(): return torch._C._dist_get_rank() -def get_num_processes(): +def get_world_size(): assert torch.distributed._initialized return torch._C._dist_get_num_processes() -def isend(tensor, dst_rank): +def isend(tensor, dst): assert torch.distributed._initialized == _INITIALIZED_PG, \ "collective only supported in process-group mode" - return _DistributedRequest(torch._C._dist_isend(tensor, dst_rank)) + return _DistributedRequest(torch._C._dist_isend(tensor, dst)) -def irecv(tensor, src_rank): +def irecv(tensor, src): assert torch.distributed._initialized == _INITIALIZED_PG, \ "collective only supported in process-group mode" - return _DistributedRequest(torch._C._dist_irecv(tensor, src_rank)) + return _DistributedRequest(torch._C._dist_irecv(tensor, src)) -def send(tensor, dst_rank): +def send(tensor, dst): assert torch.distributed._initialized == _INITIALIZED_PG, \ "collective only supported in process-group mode" - return torch._C._dist_send(tensor, dst_rank) + return torch._C._dist_send(tensor, dst) -def recv(tensor, src_rank=None): +def recv(tensor, src=None): assert torch.distributed._initialized == _INITIALIZED_PG, \ "collective only supported in process-group mode" - if src_rank is None: + if src is None: return torch._C._dist_recv_any_source(tensor) - return torch._C._dist_recv(tensor, src_rank) + return torch._C._dist_recv(tensor, src) -def broadcast(tensor, src_rank, group=group.WORLD): +def broadcast(tensor, src, group=group.WORLD): assert torch.distributed._initialized == _INITIALIZED_PG, \ "collective only supported in process-group mode" - return torch._C._dist_broadcast(tensor, src_rank, group) + return torch._C._dist_broadcast(tensor, src, group) def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD): @@ -72,40 +72,40 @@ def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD): return torch._C._dist_all_reduce(tensor, op, group) -def reduce(tensor, dst_rank, op=reduce_op.SUM, group=group.WORLD): +def reduce(tensor, dst, op=reduce_op.SUM, group=group.WORLD): assert torch.distributed._initialized == _INITIALIZED_PG, \ "collective only supported in process-group mode" - return torch._C._dist_reduce(tensor, dst_rank, op, group) + return torch._C._dist_reduce(tensor, dst, op, group) -def all_gather(tensors, tensor, group=group.WORLD): +def all_gather(tensor_list, tensor, group=group.WORLD): assert torch.distributed._initialized == _INITIALIZED_PG, \ "collective only supported in process-group mode" - return torch._C._dist_all_gather(tensors, tensor, group) + return torch._C._dist_all_gather(tensor_list, tensor, group) -def gather_send(tensor, dst_rank, group=group.WORLD): +def gather_send(tensor, dst, group=group.WORLD): assert torch.distributed._initialized == _INITIALIZED_PG, \ "collective only supported in process-group mode" - return torch._C._dist_gather_send(tensor, dst_rank, group) + return torch._C._dist_gather_send(tensor, dst, group) -def gather_recv(tensors, tensor, group=group.WORLD): +def gather_recv(tensor_list, tensor, group=group.WORLD): assert torch.distributed._initialized == _INITIALIZED_PG, \ "collective only supported in process-group mode" - return torch._C._dist_gather_recv(tensors, tensor, group) + return torch._C._dist_gather_recv(tensor_list, tensor, group) -def scatter_send(tensors, tensor, group=group.WORLD): +def scatter_send(tensor_list, tensor, group=group.WORLD): assert torch.distributed._initialized == _INITIALIZED_PG, \ "collective only supported in process-group mode" - return torch._C._dist_scatter_send(tensors, tensor, group) + return torch._C._dist_scatter_send(tensor_list, tensor, group) -def scatter_recv(tensor, src_rank, group=group.WORLD): +def scatter_recv(tensor, src, group=group.WORLD): assert torch.distributed._initialized == _INITIALIZED_PG, \ "collective only supported in process-group mode" - return torch._C._dist_scatter_recv(tensor, src_rank, group) + return torch._C._dist_scatter_recv(tensor, src, group) def barrier(group=group.WORLD): @@ -118,5 +118,5 @@ def new_group(ranks=None): assert torch.distributed._initialized == _INITIALIZED_PG, \ "collective only supported in process-group mode" if ranks is None: - ranks = list(range(get_num_processes())) + ranks = list(range(get_world_size())) return torch._C._dist_new_group(ranks) diff --git a/torch/lib/THD/benchmark/benchmark.py b/torch/lib/THD/benchmark/benchmark.py index 2c3ed1e53c5..844827a3aba 100644 --- a/torch/lib/THD/benchmark/benchmark.py +++ b/torch/lib/THD/benchmark/benchmark.py @@ -134,7 +134,7 @@ if rank == 0: print_header("scatter") for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]: tensor = torch.ByteTensor(bytes).fill_(42) - tensors = [tensor for n in range(0, dist.get_num_processes())] + tensors = [tensor for n in range(0, dist.get_world_size())] for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]: start = timer() for i in range(0, num_tensors): @@ -154,7 +154,7 @@ if rank == 0: print_header("gather") for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]: tensor = torch.ByteTensor(bytes).fill_(42) - tensors = [tensor for n in range(0, dist.get_num_processes())] + tensors = [tensor for n in range(0, dist.get_world_size())] for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]: start = timer() for i in range(0, num_tensors): @@ -174,7 +174,7 @@ if rank == 0: print_header("all gather") for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]: tensor = torch.ByteTensor(bytes).fill_(42) - tensors = [tensor for n in range(0, dist.get_num_processes())] + tensors = [tensor for n in range(0, dist.get_world_size())] for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]: start = timer() for i in range(0, num_tensors): @@ -185,7 +185,7 @@ if rank == 0: else: for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]: tensor = torch.ByteTensor(bytes).fill_(42) - tensors = [tensor for n in range(0, dist.get_num_processes())] + tensors = [tensor for n in range(0, dist.get_world_size())] for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]: for i in range(0, num_tensors): dist.all_gather(tensors, tensor) diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 19d75745055..45173664844 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -305,7 +305,7 @@ class DistributedDataParallel(Module): reduce_stream = reduction_streams[0] with torch.cuda.stream(reduce_stream): reduce_stream.wait_stream(nccl_streams[0]) - coalesced /= dist.get_num_processes() + coalesced /= dist.get_world_size() dist.all_reduce(coalesced, group=group_id) for grad, reduced in zip(grad_batch, _unflatten_tensors(coalesced, grad_batch)): grad.copy_(reduced) diff --git a/torch/utils/data/distributed.py b/torch/utils/data/distributed.py index 06a7a4ce5a8..c1ec77cae42 100644 --- a/torch/utils/data/distributed.py +++ b/torch/utils/data/distributed.py @@ -1,7 +1,7 @@ import math import torch from .sampler import Sampler -from torch.distributed.collectives import get_num_processes, get_rank +from torch.distributed.collectives import get_world_size, get_rank class DistributedSampler(Sampler): @@ -24,7 +24,7 @@ class DistributedSampler(Sampler): def __init__(self, dataset, num_replicas=None, rank=None): if num_replicas is None: - num_replicas = get_num_processes() + num_replicas = get_world_size() if rank is None: rank = get_rank() self.dataset = dataset