Rename arguments to distributed collectives
This commit is contained in:
parent 714351ff39
commit d9d50f80c7
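In short, the commit renames the world-size query and the rank / tensor-list arguments of the point-to-point and collective wrappers in torch.distributed. A hedged summary of how call sites change, drawn from the hunks below (positional calls keep working; only keyword calls and get_num_processes() need updating):

    import torch.distributed as dist

    # old name / keyword                      new name / keyword
    # dist.get_num_processes()            ->  dist.get_world_size()
    # dist.send(t, dst_rank=1)            ->  dist.send(t, dst=1)
    # dist.recv(t, src_rank=0)            ->  dist.recv(t, src=0)
    # dist.broadcast(t, src_rank=0)       ->  dist.broadcast(t, src=0)
    # dist.reduce(t, dst_rank=0)          ->  dist.reduce(t, dst=0)
    # dist.all_gather(tensors=out, ...)   ->  dist.all_gather(tensor_list=out, tensor=t)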
@@ -64,7 +64,7 @@ class Barrier(object):
                         data = f.read()
                         if int(data) >= cls.barrier_id:
                             arrived += 1
-            if arrived == dist.get_num_processes():
+            if arrived == dist.get_world_size():
                 break

             if time.time() - start_time > timeout:
@@ -87,7 +87,7 @@ class _DistTestBase(object):
         return (group, group_id, rank)

     def _init_global_test(self):
-        group = [i for i in range(0, dist.get_num_processes())]
+        group = [i for i in range(0, dist.get_world_size())]
         group_id = dist.group.WORLD
         rank = dist.get_rank()
         return (group, group_id, rank)
@@ -96,7 +96,7 @@ class _DistTestBase(object):
     def test_get_rank(self):
         test_dir = os.path.join(TEMP_DIR, 'test_dir')
         pid = str(os.getpid())
-        num_processes = dist.get_num_processes()
+        num_processes = dist.get_world_size()
         with open(os.path.join(test_dir, pid), 'w') as f:
             f.write(str(dist.get_rank()))

@@ -121,12 +121,12 @@ class _DistTestBase(object):
     def test_send_recv(self):
         rank = dist.get_rank()
         tensor = _build_tensor(rank + 1)
-        for dest in range(0, dist.get_num_processes()):
+        for dest in range(0, dist.get_world_size()):
             if dest == rank:
                 continue
             dist.send(tensor, dest)

-        for src in range(0, dist.get_num_processes()):
+        for src in range(0, dist.get_world_size()):
             if src == rank:
                 continue
             tensor = _build_tensor(src + 1, value=-1)
@@ -142,27 +142,27 @@ class _DistTestBase(object):
     def test_send_recv_any_source(self):
         rank = dist.get_rank()
         tensor = _build_tensor(10, rank)
-        for dest in range(0, dist.get_num_processes()):
+        for dest in range(0, dist.get_world_size()):
             if dest == rank:
                 continue
             dist.send(tensor, dest)

         recv_ranks = set()
-        for src in range(0, dist.get_num_processes()):
+        for src in range(0, dist.get_world_size()):
             if src == rank:
                 continue
             tensor = _build_tensor(10, value=-1)
             dist.recv(tensor)
             recv_ranks.add(tensor.resize_(1)[0])

-        self.assertEqual(len(recv_ranks), dist.get_num_processes() - 1)
+        self.assertEqual(len(recv_ranks), dist.get_world_size() - 1)
         self._barrier()

     # ISEND
     @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support isend")
     def test_isend(self):
         rank = dist.get_rank()
-        world_size = dist.get_num_processes()
+        world_size = dist.get_world_size()

         if rank == 0:
             requests = [
@@ -182,7 +182,7 @@ class _DistTestBase(object):
     @unittest.skipIf(BACKEND == 'gloo', "Gloo does not support irecv")
     def test_irecv(self):
         rank = dist.get_rank()
-        world_size = dist.get_num_processes()
+        world_size = dist.get_world_size()

         if rank == 0:
             expected_tensors = [_build_tensor(src, -1) for src in range(1, world_size)]
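The tests above all follow the same pattern: every rank sends its own tensor to each peer and receives each peer's tensor back. A minimal sketch of that round-trip against the renamed API, assuming a process group is already initialized and using a hypothetical _build_tensor helper modeled on the one in the test file:

    import torch
    import torch.distributed as dist


    def _build_tensor(size, value=None):
        # Hypothetical helper modeled on the test file: a cube of side
        # `size` filled with `value` (defaulting to `size`).
        if value is None:
            value = size
        return torch.FloatTensor(size, size, size).fill_(value)


    def exchange_with_peers():
        rank = dist.get_rank()
        world_size = dist.get_world_size()    # was dist.get_num_processes()

        tensor = _build_tensor(rank + 1)
        for dst in range(world_size):         # send to every other rank
            if dst != rank:
                dist.send(tensor, dst)

        for src in range(world_size):         # receive from every other rank
            if src != rank:
                received = _build_tensor(src + 1, value=-1)
                dist.recv(received, src)
                assert received.equal(_build_tensor(src + 1))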
@@ -29,41 +29,41 @@ def get_rank():
     return torch._C._dist_get_rank()


-def get_num_processes():
+def get_world_size():
     assert torch.distributed._initialized
     return torch._C._dist_get_num_processes()


-def isend(tensor, dst_rank):
+def isend(tensor, dst):
     assert torch.distributed._initialized == _INITIALIZED_PG, \
         "collective only supported in process-group mode"
-    return _DistributedRequest(torch._C._dist_isend(tensor, dst_rank))
+    return _DistributedRequest(torch._C._dist_isend(tensor, dst))


-def irecv(tensor, src_rank):
+def irecv(tensor, src):
     assert torch.distributed._initialized == _INITIALIZED_PG, \
         "collective only supported in process-group mode"
-    return _DistributedRequest(torch._C._dist_irecv(tensor, src_rank))
+    return _DistributedRequest(torch._C._dist_irecv(tensor, src))


-def send(tensor, dst_rank):
+def send(tensor, dst):
     assert torch.distributed._initialized == _INITIALIZED_PG, \
         "collective only supported in process-group mode"
-    return torch._C._dist_send(tensor, dst_rank)
+    return torch._C._dist_send(tensor, dst)


-def recv(tensor, src_rank=None):
+def recv(tensor, src=None):
     assert torch.distributed._initialized == _INITIALIZED_PG, \
         "collective only supported in process-group mode"
-    if src_rank is None:
+    if src is None:
         return torch._C._dist_recv_any_source(tensor)
-    return torch._C._dist_recv(tensor, src_rank)
+    return torch._C._dist_recv(tensor, src)


-def broadcast(tensor, src_rank, group=group.WORLD):
+def broadcast(tensor, src, group=group.WORLD):
     assert torch.distributed._initialized == _INITIALIZED_PG, \
         "collective only supported in process-group mode"
-    return torch._C._dist_broadcast(tensor, src_rank, group)
+    return torch._C._dist_broadcast(tensor, src, group)


 def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
@@ -72,40 +72,40 @@ def all_reduce(tensor, op=reduce_op.SUM, group=group.WORLD):
     return torch._C._dist_all_reduce(tensor, op, group)


-def reduce(tensor, dst_rank, op=reduce_op.SUM, group=group.WORLD):
+def reduce(tensor, dst, op=reduce_op.SUM, group=group.WORLD):
     assert torch.distributed._initialized == _INITIALIZED_PG, \
         "collective only supported in process-group mode"
-    return torch._C._dist_reduce(tensor, dst_rank, op, group)
+    return torch._C._dist_reduce(tensor, dst, op, group)


-def all_gather(tensors, tensor, group=group.WORLD):
+def all_gather(tensor_list, tensor, group=group.WORLD):
     assert torch.distributed._initialized == _INITIALIZED_PG, \
         "collective only supported in process-group mode"
-    return torch._C._dist_all_gather(tensors, tensor, group)
+    return torch._C._dist_all_gather(tensor_list, tensor, group)


-def gather_send(tensor, dst_rank, group=group.WORLD):
+def gather_send(tensor, dst, group=group.WORLD):
     assert torch.distributed._initialized == _INITIALIZED_PG, \
         "collective only supported in process-group mode"
-    return torch._C._dist_gather_send(tensor, dst_rank, group)
+    return torch._C._dist_gather_send(tensor, dst, group)


-def gather_recv(tensors, tensor, group=group.WORLD):
+def gather_recv(tensor_list, tensor, group=group.WORLD):
     assert torch.distributed._initialized == _INITIALIZED_PG, \
         "collective only supported in process-group mode"
-    return torch._C._dist_gather_recv(tensors, tensor, group)
+    return torch._C._dist_gather_recv(tensor_list, tensor, group)


-def scatter_send(tensors, tensor, group=group.WORLD):
+def scatter_send(tensor_list, tensor, group=group.WORLD):
     assert torch.distributed._initialized == _INITIALIZED_PG, \
         "collective only supported in process-group mode"
-    return torch._C._dist_scatter_send(tensors, tensor, group)
+    return torch._C._dist_scatter_send(tensor_list, tensor, group)


-def scatter_recv(tensor, src_rank, group=group.WORLD):
+def scatter_recv(tensor, src, group=group.WORLD):
     assert torch.distributed._initialized == _INITIALIZED_PG, \
         "collective only supported in process-group mode"
-    return torch._C._dist_scatter_recv(tensor, src_rank, group)
+    return torch._C._dist_scatter_recv(tensor, src, group)


 def barrier(group=group.WORLD):
@@ -118,5 +118,5 @@ def new_group(ranks=None):
     assert torch.distributed._initialized == _INITIALIZED_PG, \
         "collective only supported in process-group mode"
     if ranks is None:
-        ranks = list(range(get_num_processes()))
+        ranks = list(range(get_world_size()))
     return torch._C._dist_new_group(ranks)
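A short usage sketch of the renamed wrappers above, assuming dist.init_process_group(...) has already been called; shapes and values are illustrative only:

    import torch
    import torch.distributed as dist

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # broadcast: the source rank argument is now `src` (was `src_rank`).
    t = torch.FloatTensor(4).fill_(rank)
    dist.broadcast(t, src=0)

    # reduce: the destination rank argument is now `dst` (was `dst_rank`).
    s = torch.FloatTensor(4).fill_(1)
    dist.reduce(s, dst=0)                     # rank 0 ends up holding the sum

    # all_gather: the output list argument is now `tensor_list` (was `tensors`).
    gathered = [torch.FloatTensor(4).fill_(-1) for _ in range(world_size)]
    dist.all_gather(tensor_list=gathered, tensor=torch.FloatTensor(4).fill_(rank))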
@@ -134,7 +134,7 @@ if rank == 0:
     print_header("scatter")
     for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
         tensor = torch.ByteTensor(bytes).fill_(42)
-        tensors = [tensor for n in range(0, dist.get_num_processes())]
+        tensors = [tensor for n in range(0, dist.get_world_size())]
         for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
             start = timer()
             for i in range(0, num_tensors):
@@ -154,7 +154,7 @@ if rank == 0:
     print_header("gather")
     for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
         tensor = torch.ByteTensor(bytes).fill_(42)
-        tensors = [tensor for n in range(0, dist.get_num_processes())]
+        tensors = [tensor for n in range(0, dist.get_world_size())]
         for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
             start = timer()
             for i in range(0, num_tensors):
@@ -174,7 +174,7 @@ if rank == 0:
     print_header("all gather")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
         tensor = torch.ByteTensor(bytes).fill_(42)
-        tensors = [tensor for n in range(0, dist.get_num_processes())]
+        tensors = [tensor for n in range(0, dist.get_world_size())]
         for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
             start = timer()
             for i in range(0, num_tensors):
@@ -185,7 +185,7 @@ if rank == 0:
 else:
     for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
         tensor = torch.ByteTensor(bytes).fill_(42)
-        tensors = [tensor for n in range(0, dist.get_num_processes())]
+        tensors = [tensor for n in range(0, dist.get_world_size())]
         for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
             for i in range(0, num_tensors):
                 dist.all_gather(tensors, tensor)
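All of the benchmark hunks above share one shape: build a payload, repeat the collective, and time the loop. A hedged sketch of that pattern for all_gather using the renamed world-size call (constants such as MIN_BYTES and helpers such as print_header from the original script are not reproduced here):

    from timeit import default_timer as timer

    import torch
    import torch.distributed as dist

    world_size = dist.get_world_size()        # was dist.get_num_processes()

    nbytes = 2 ** 20                          # 1 MiB payload, illustrative only
    tensor = torch.ByteTensor(nbytes).fill_(42)
    outputs = [torch.ByteTensor(nbytes) for _ in range(world_size)]

    iterations = 100
    start = timer()
    for _ in range(iterations):
        dist.all_gather(outputs, tensor)
    elapsed = timer() - start
    if dist.get_rank() == 0:
        print('all_gather: %.1f us per call' % (elapsed / iterations * 1e6))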
@@ -305,7 +305,7 @@ class DistributedDataParallel(Module):
             reduce_stream = reduction_streams[0]
             with torch.cuda.stream(reduce_stream):
                 reduce_stream.wait_stream(nccl_streams[0])
-                coalesced /= dist.get_num_processes()
+                coalesced /= dist.get_world_size()
                 dist.all_reduce(coalesced, group=group_id)
                 for grad, reduced in zip(grad_batch, _unflatten_tensors(coalesced, grad_batch)):
                     grad.copy_(reduced)
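The DistributedDataParallel hunk divides the coalesced gradients by the world size before all_reduce, turning the summed result into an average. A standalone sketch of that averaging step, stripped of the CUDA-stream bookkeeping shown above (the function name is illustrative):

    import torch.distributed as dist


    def average_gradients(parameters):
        # Assumes a process group is initialized. Divide first, then
        # all_reduce (SUM), so each rank ends up with the mean gradient.
        world_size = dist.get_world_size()
        for p in parameters:
            if p.grad is None:
                continue
            p.grad.data.div_(world_size)
            dist.all_reduce(p.grad.data)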
@@ -1,7 +1,7 @@
 import math
 import torch
 from .sampler import Sampler
-from torch.distributed.collectives import get_num_processes, get_rank
+from torch.distributed.collectives import get_world_size, get_rank


 class DistributedSampler(Sampler):
@@ -24,7 +24,7 @@ class DistributedSampler(Sampler):

     def __init__(self, dataset, num_replicas=None, rank=None):
         if num_replicas is None:
-            num_replicas = get_num_processes()
+            num_replicas = get_world_size()
         if rank is None:
             rank = get_rank()
         self.dataset = dataset
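With the import switched to get_world_size, DistributedSampler still defaults num_replicas to the full world size. A hedged usage sketch with a throwaway dataset, assuming the process group has already been initialized:

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.randn(1024, 8), torch.zeros(1024))
    sampler = DistributedSampler(dataset)     # num_replicas defaults to get_world_size()
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    for inputs, targets in loader:
        pass                                  # each rank iterates over its own shard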