Summary: Fix https://github.com/pytorch/pytorch/issues/20421

`ProcessGroupGloo` only requires input/output tensors to be contiguous. Contiguous tensors might not start from the beginning of the underlying storage, e.g., `chunk(..., dim=0)[1]`. The current implementation passes the `tensor.storage().data()` ptr to the gloo buffer, which leads to wrong results if the tensor has a non-zero storage offset. The proposed solution is to use `tensor.data_ptr()` instead. Let's see if this breaks any tests.

cc qijianan777

Pull Request resolved: https://github.com/pytorch/pytorch/pull/21490
Differential Revision: D15768907
Pulled By: mrshenli
fbshipit-source-id: 9d7d1e9baf0461b31187c7d21a4a53b1fbb07397
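The failure mode is easy to reproduce outside the process-group machinery. A minimal standalone sketch (not part of the test file below) of how a contiguous chunk can sit at a non-zero storage offset:

import torch

t = torch.arange(4).reshape(2, 2)
second = torch.chunk(t, 2, dim=0)[1]  # a contiguous view into t's storage
assert second.is_contiguous()
assert second.storage_offset() == 2   # does not start at the storage base
# Reading from the storage base pointer would return t[0]'s data; only
# data_ptr() accounts for the offset.
assert second.data_ptr() == t.data_ptr() + second.storage_offset() * t.element_size()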
213 lines
7.5 KiB
Python
import sys
import tempfile
import unittest

import torch
import torch.distributed as c10d
import torch.multiprocessing as mp

from common_cuda import TEST_MULTIGPU
from common_utils import TestCase, load_tests, run_tests
from common_utils import NO_MULTIPROCESSING_SPAWN

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if not c10d.is_available():
    print('c10d not available, skipping tests')
    sys.exit(0)


if NO_MULTIPROCESSING_SPAWN:
    print('spawn not available, skipping tests')
    sys.exit(0)


NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL")


class ProcessGroupShareTensorTest(TestCase):

    world_size = 2

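    # Gloo options: communicate over the loopback TCP device with a short
    # timeout so a stuck collective fails fast.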
    @classmethod
    def opts(cls, threads=2):
        opts = c10d.ProcessGroupGloo.Options()
        opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
        opts.timeout = 5.0
        opts.threads = threads
        return opts

    @classmethod
    def _init_pg_gloo(cls, rank, filename, world_size):
        store = c10d.FileStore(filename, world_size)
        return c10d.ProcessGroupGloo(
            store, rank, world_size, ProcessGroupShareTensorTest.opts())

    @classmethod
    def _init_pg_nccl(cls, rank, filename, world_size):
        store = c10d.FileStore(filename, world_size)
        return c10d.ProcessGroupNCCL(store, rank, world_size)

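    # Driver: spawn one process per rank, collect (rank, expected, result)
    # tuples from the c2p queue and assert on them, then signal the children
    # to exit through the p2c queue.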
    def _test_multiprocess(self, f, shared_tensors, init_pg, n_output):
        ws = self.world_size
        # file store will delete the test file on destruction
        file = tempfile.NamedTemporaryFile(delete=False)
        ctx = mp.get_context('spawn')
        c2p = ctx.Queue(2)
        p2c = ctx.Queue(2)
        ps = []
        for i in range(ws):
            p = ctx.Process(
                target=f,
                args=(i, file.name, shared_tensors, ws, init_pg, c2p, p2c))

            p.start()
            ps.append(p)

        for _ in range(ws * n_output):
            pid, expected, result = c2p.get()
            self.assertEqual(
                expected,
                result,
                (
                    "Expect rank {} to receive tensor {} but got {}."
                ).format(pid, expected, result)
            )

        for _ in range(ws):
            p2c.put(0)

        for p in ps:
            p.join(2)

    # Why classmethod? multiprocessing cannot pickle TestCase subclass when in
    # spawn mode. See https://bugs.python.org/issue33884.
    @classmethod
    def _test_broadcast_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        pg.broadcast(xs).wait()
        c2p.put((rank, torch.zeros(2, 2), xs[0].to("cpu")))
        p2c.get()

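    # With inputs of ones * rank, the broadcast root (rank 0, the default)
    # holds all zeros, so every rank expects a zero tensor afterwards.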
    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    def test_shared_broadcast_gloo(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_broadcast_process,
            [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_gloo,
            1)


    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    @unittest.skipIf(NO_NCCL, "NCCL needed")
    def test_shared_broadcast_nccl(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_broadcast_process,
            [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_nccl,
            1)

    @classmethod
    def _test_allreduce_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        pg.allreduce(xs, op=c10d.ReduceOp.SUM).wait()
        c2p.put((rank, torch.ones(2, 2) * 2, xs[0].to("cpu")))
        p2c.get()

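    # Both ranks contribute a tensor of ones, so the SUM allreduce leaves a
    # tensor of twos on every rank.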
    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    def test_shared_allreduce_gloo(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_allreduce_process,
            [torch.ones(2, 2).to(i) for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_gloo,
            1)

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    @unittest.skipIf(NO_NCCL, "NCCL needed")
    def test_shared_allreduce_nccl(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_allreduce_process,
            [torch.ones(2, 2).to(i) for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_nccl,
            1)

    @classmethod
    def _test_reduce_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        x = shared_tensors[rank]
        pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
        if rank == 0:
            c2p.put((rank, torch.ones(2, 2) * 2, x.to("cpu")))
        else:
            c2p.put((rank, torch.ones(2, 2), x.to("cpu")))
        p2c.get()

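    # reduce writes the SUM only into the root's tensor (rank 0); the
    # non-root rank's input of ones is left untouched.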
    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    @unittest.skipIf(NO_NCCL, "NCCL needed")
    def test_shared_reduce_nccl(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_reduce_process,
            [torch.ones(2, 2).to(i) for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_nccl,
            1)

    @classmethod
    def _test_allgather_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        ys = [[torch.zeros_like(xs[0]) for i in range(world_size)]]
        pg.allgather(ys, xs).wait()
        for i in range(world_size):
            c2p.put((rank, torch.ones(2, 2) * i, ys[0][i].to("cpu")))

        p2c.get()

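    # After allgather, output slot i holds rank i's input (ones * i) on every
    # rank, so each rank reports world_size results.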
    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    def test_shared_allgather_gloo(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_allgather_process,
            [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_gloo,
            self.world_size)

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    @unittest.skipIf(NO_NCCL, "NCCL needed")
    def test_shared_allgather_nccl(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_allgather_process,
            [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_nccl,
            self.world_size)

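    # Regression test for issue 20421: chunks[1] is contiguous but has a
    # non-zero storage offset, so the process group must read from
    # tensor.data_ptr() rather than the storage base pointer.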
    @classmethod
    def _test_allgather_chunk_process(
            cls, rank, filename, shared_tensor, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        chunks = torch.chunk(shared_tensor, world_size, dim=0)
        x = chunks[rank]
        ys = [[torch.zeros_like(x) for _ in range(world_size)]]
        pg.allgather(ys, [x]).wait()
        c2p.put((rank, chunks[0].to("cpu"), ys[0][0].to("cpu")))
        c2p.put((rank, chunks[1].to("cpu"), ys[0][1].to("cpu")))
        p2c.get()

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    def test_shared_allgather_chunk_gloo(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_allgather_chunk_process,
            torch.tensor(range(4)).reshape(2, 2),
            ProcessGroupShareTensorTest._init_pg_gloo,
            self.world_size)


if __name__ == '__main__':
    run_tests()