mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Rename arguments to distributed collectives
This commit is contained in:
parent
714351ff39
commit
d9d50f80c7
|
|
@ -64,7 +64,7 @@ class Barrier(object):
|
|||
data = f.read()
|
||||
if int(data) >= cls.barrier_id:
|
||||
arrived += 1
|
||||
if arrived == dist.get_num_processes():
|
||||
if arrived == dist.get_world_size():
|
||||
break
|
||||
|
||||
if time.time() - start_time > timeout:
|
||||
|
|
@ -87,7 +87,7 @@ class _DistTestBase(object):
|
|||
return (group, group_id, rank)
|
||||
|
||||
def _init_global_test(self):
|
||||
group = [i for i in range(0, dist.get_num_processes())]
|
||||
group = [i for i in range(0, dist.get_world_size())]
|
||||
group_id = dist.group.WORLD
|
||||
rank = dist.get_rank()
|
||||
return (group, group_id, rank)
|
||||
|
|
@ -96,7 +96,7 @@ class _DistTestBase(object):
|
|||
def test_get_rank(self):
|
||||
test_dir = os.path.join(TEMP_DIR, 'test_dir')
|
||||
pid = str(os.getpid())
|
||||
num_processes = dist.get_num_processes()
|
||||
num_processes = dist.get_world_size()
|
||||
with open(os.path.join(test_dir, pid), 'w') as f:
|
||||
f.write(str(dist.get_rank()))
|
||||
|
||||
|
|
@ -121,12 +121,12 @@ class _DistTestBase(object):
|
|||
def test_send_recv(self):
|
||||
rank = dist.get_rank()
|
||||
tensor = _build_tensor(rank + 1)
|
||||
for dest in range(0, dist.get_num_processes()):
|
||||
for dest in range(0, dist.get_world_size()):
|
||||
if dest == rank:
|
||||
continue
|
||||
dist.send(tensor, dest)
|
||||
|
||||
for src in range(0, dist.get_num_processes()):
|
||||
for src in range(0, dist.get_world_size()):
|
||||
if src == rank:
|
||||
continue
|
||||
tensor = _build_tensor(src + 1, value=-1)
|
||||
|
|
@ -142,27 +142,27 @@ class _DistTestBase(object):
|
|||
def test_send_recv_any_source(self):
|
||||
rank = dist.get_rank()
|
||||
tensor = _build_tensor(10, rank)
|
||||
for dest in range(0, dist.get_num_processes()):
|
||||
for dest in range(0, dist.get_world_size()):
|
||||
if dest == rank:
|
||||
continue
|
||||
dist.send(tensor, dest)
|
||||
|
||||
recv_ranks = set()
|
||||
for src in range(0, dist.get_num_processes()):
|
||||
for src in range(0, dist.get_world_size()):
|
||||
if src == rank:
|
||||
continue
|
||||
tensor = _build_tensor(10, value=-1)
|
||||
dist.recv(tensor)
|
||||
recv_ranks.add(tensor.resize_(1)[0])
|
||||
|
||||
self.assertEqual(len(recv_ranks), dist.get_num_processes() - 1)
|
||||
self.assertEqual(len(recv_ranks), dist.get_world_size() - 1)
|
||||
self._barrier()
|
||||
|
||||
# ISEND
|
||||
@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support isend")
|
||||
def test_isend(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_num_processes()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
if rank == 0:
|
||||
requests = [
|
||||
|
|
@ -182,7 +182,7 @@ class _DistTestBase(object):
|
|||
@unittest.skipIf(BACKEND == 'gloo', "Gloo does not support irecv")
|
||||
def test_irecv(self):
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_num_processes()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
if rank == 0:
|
||||
expected_tensors = [_build_tensor(src, -1) for src in range(1, world_size)]
|
||||
|
|
|
|||
|
|
@ -29,41 +29,41 @@ def get_rank():
|
|||
return torch._C._dist_get_rank()
|
||||
|
||||
|
||||
def get_num_processes():
|
||||
def get_world_size():
|
||||
assert torch.distributed._initialized
|
||||
return torch._C._dist_get_num_processes()
|
||||
|
||||
|
||||
def isend(tensor, dst_rank):
|
||||
def isend(tensor, dst):
|
||||
assert torch.distributed._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return _DistributedRequest(torch._C._dist_isend(tensor, dst_rank))
|
||||
return _DistributedRequest(torch._C._dist_isend(tensor, dst))
|
||||
|
||||
|
||||
def irecv(tensor, src_rank):
|
||||
def irecv(tensor, src):
|
||||
assert torch.distributed._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return _DistributedRequest(torch._C._dist_irecv(tensor, src_rank))
|
||||
return _DistributedRequest(torch._C._dist_irecv(tensor, src))
|
||||
|
||||
|
||||
def send(tensor, dst_rank):
|
||||
def send(tensor, dst):
|
||||
assert torch.distributed._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_send(tensor, dst_rank)
|
||||
return torch._C._dist_send(tensor, dst)
|
||||
|
||||
|
||||
def recv(tensor, src_rank=None):
|
||||
def recv(tensor, src=None):
|
||||
assert torch.distributed._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
if src_rank is None:
|
||||
if src is None:
|
||||
return torch._C._dist_recv_any_source(tensor)
|
||||
return torch._C._dist_recv(tensor, src_rank)
|
||||
return torch._C._dist_recv(tensor, src)
|
||||
|
||||
|
||||
def broadcast(tensor, src_rank, group=group.WORLD):
|
||||
def broadcast(tensor, src, group=group.WORLD):
|
||||
assert torch.distributed._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_broadcast(tensor, src_rank, group)
|
||||
return torch._C._dist_broadcast(tensor, src, group)
|
||||
|
||||
|
||||
def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
|
||||
|
|
@ -72,40 +72,40 @@ def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
|
|||
return torch._C._dist_all_reduce(tensor, op, group)
|
||||
|
||||
|
||||
def reduce(tensor, dst_rank, op=reduce_op.SUM, group=group.WORLD):
|
||||
def reduce(tensor, dst, op=reduce_op.SUM, group=group.WORLD):
|
||||
assert torch.distributed._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_reduce(tensor, dst_rank, op, group)
|
||||
return torch._C._dist_reduce(tensor, dst, op, group)
|
||||
|
||||
|
||||
def all_gather(tensors, tensor, group=group.WORLD):
|
||||
def all_gather(tensor_list, tensor, group=group.WORLD):
|
||||
assert torch.distributed._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_all_gather(tensors, tensor, group)
|
||||
return torch._C._dist_all_gather(tensor_list, tensor, group)
|
||||
|
||||
|
||||
def gather_send(tensor, dst_rank, group=group.WORLD):
|
||||
def gather_send(tensor, dst, group=group.WORLD):
|
||||
assert torch.distributed._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_gather_send(tensor, dst_rank, group)
|
||||
return torch._C._dist_gather_send(tensor, dst, group)
|
||||
|
||||
|
||||
def gather_recv(tensors, tensor, group=group.WORLD):
|
||||
def gather_recv(tensor_list, tensor, group=group.WORLD):
|
||||
assert torch.distributed._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_gather_recv(tensors, tensor, group)
|
||||
return torch._C._dist_gather_recv(tensor_list, tensor, group)
|
||||
|
||||
|
||||
def scatter_send(tensors, tensor, group=group.WORLD):
|
||||
def scatter_send(tensor_list, tensor, group=group.WORLD):
|
||||
assert torch.distributed._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_scatter_send(tensors, tensor, group)
|
||||
return torch._C._dist_scatter_send(tensor_list, tensor, group)
|
||||
|
||||
|
||||
def scatter_recv(tensor, src_rank, group=group.WORLD):
|
||||
def scatter_recv(tensor, src, group=group.WORLD):
|
||||
assert torch.distributed._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
return torch._C._dist_scatter_recv(tensor, src_rank, group)
|
||||
return torch._C._dist_scatter_recv(tensor, src, group)
|
||||
|
||||
|
||||
def barrier(group=group.WORLD):
|
||||
|
|
@ -118,5 +118,5 @@ def new_group(ranks=None):
|
|||
assert torch.distributed._initialized == _INITIALIZED_PG, \
|
||||
"collective only supported in process-group mode"
|
||||
if ranks is None:
|
||||
ranks = list(range(get_num_processes()))
|
||||
ranks = list(range(get_world_size()))
|
||||
return torch._C._dist_new_group(ranks)
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ if rank == 0:
|
|||
print_header("scatter")
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
tensors = [tensor for n in range(0, dist.get_num_processes())]
|
||||
tensors = [tensor for n in range(0, dist.get_world_size())]
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
start = timer()
|
||||
for i in range(0, num_tensors):
|
||||
|
|
@ -154,7 +154,7 @@ if rank == 0:
|
|||
print_header("gather")
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
tensors = [tensor for n in range(0, dist.get_num_processes())]
|
||||
tensors = [tensor for n in range(0, dist.get_world_size())]
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
start = timer()
|
||||
for i in range(0, num_tensors):
|
||||
|
|
@ -174,7 +174,7 @@ if rank == 0:
|
|||
print_header("all gather")
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
tensors = [tensor for n in range(0, dist.get_num_processes())]
|
||||
tensors = [tensor for n in range(0, dist.get_world_size())]
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
start = timer()
|
||||
for i in range(0, num_tensors):
|
||||
|
|
@ -185,7 +185,7 @@ if rank == 0:
|
|||
else:
|
||||
for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
|
||||
tensor = torch.ByteTensor(bytes).fill_(42)
|
||||
tensors = [tensor for n in range(0, dist.get_num_processes())]
|
||||
tensors = [tensor for n in range(0, dist.get_world_size())]
|
||||
for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
|
||||
for i in range(0, num_tensors):
|
||||
dist.all_gather(tensors, tensor)
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ class DistributedDataParallel(Module):
|
|||
reduce_stream = reduction_streams[0]
|
||||
with torch.cuda.stream(reduce_stream):
|
||||
reduce_stream.wait_stream(nccl_streams[0])
|
||||
coalesced /= dist.get_num_processes()
|
||||
coalesced /= dist.get_world_size()
|
||||
dist.all_reduce(coalesced, group=group_id)
|
||||
for grad, reduced in zip(grad_batch, _unflatten_tensors(coalesced, grad_batch)):
|
||||
grad.copy_(reduced)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import math
|
||||
import torch
|
||||
from .sampler import Sampler
|
||||
from torch.distributed.collectives import get_num_processes, get_rank
|
||||
from torch.distributed.collectives import get_world_size, get_rank
|
||||
|
||||
|
||||
class DistributedSampler(Sampler):
|
||||
|
|
@ -24,7 +24,7 @@ class DistributedSampler(Sampler):
|
|||
|
||||
def __init__(self, dataset, num_replicas=None, rank=None):
|
||||
if num_replicas is None:
|
||||
num_replicas = get_num_processes()
|
||||
num_replicas = get_world_size()
|
||||
if rank is None:
|
||||
rank = get_rank()
|
||||
self.dataset = dataset
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user