Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[PyTorch][Distributed] Enable Reduce Scatter and modify all_to_all for sharded linear with more test cases. (#68786)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68786

To enable autograd for the sharded linear, we need to make some changes to the current nn function API (the c10d API with autograd enabled). This change does the following:
1. Add a new API, `reduce_scatter`, since we need it for row-wise sharding.
2. Modify the `all_to_all` API to make it consistent with the one in distributed_c10d.py.
3. Fix the C++ binding of `reduce_scatter`, which was missing an input param, and add more unit tests to cover these cases.
4. Sync the NN tests from gloo to nccl.

ghstack-source-id: 144860208

Test Plan: CI + Unit Test

Reviewed By: pritamdamania87

Differential Revision: D32569674

fbshipit-source-id: 9bd613f91bbf7a39eede0af32a5a5db0f2ade43b
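For illustration only (not part of this diff), here is a minimal sketch of how the autograd-enabled collectives touched by this change are meant to be used, spawning two CPU processes on the gloo backend. The helper name `_demo`, the master address/port, and the tensor shapes are arbitrary assumptions; `torch.distributed.nn.reduce_scatter` itself is NCCL-only in this change, so it is only indicated in a comment.

import os

import torch
import torch.distributed as dist
import torch.distributed.nn
import torch.multiprocessing as mp


def _demo(rank, world_size):
    # env:// rendezvous; the address/port here are assumptions for the sketch.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # New all_to_all signature: output list first, then input list, matching
    # torch.distributed.all_to_all in distributed_c10d.py.
    inputs = [torch.ones(2, 2) * (rank + 1) for _ in range(world_size)]
    for t in inputs:
        t.requires_grad_()
    outputs = [torch.empty(2, 2) for _ in range(world_size)]
    gathered = torch.distributed.nn.all_to_all(outputs, inputs)

    # The collective participates in autograd, so gradients flow back to inputs.
    sum(t.sum() for t in gathered).backward()
    print(rank, [t.grad.sum().item() for t in inputs])

    # The new reduce_scatter is NCCL-only in this change; on a 2-GPU box it
    # would look like (not run here):
    #   y = torch.empty(2, 2, device=f"cuda:{rank}")
    #   y = torch.distributed.nn.reduce_scatter(y, [t.cuda(rank) for t in inputs])

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_demo, args=(2,), nprocs=2)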
This commit is contained in:
parent e032dae329
commit 7c2489bdae
@@ -611,6 +611,35 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
            # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
            self.assertEqualIgnoreType(expected, output[i])

        # Test the scenario where the input params are overridden, i.e., when
        # the input is a list and the output is just one tensor.
        # Sum
        output_tensor = torch.empty_like(input_per_gpu[0][0]).cuda(self.rank)
        input_list = [tensor[0].cuda(self.rank) for tensor in input_per_gpu]
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.SUM).wait()
        expected = torch.tensor(
            float((1 + self.world_size) * self.world_size / 2) + self.world_size * self.rank
        )
        self.assertEqualIgnoreType(expected, output_tensor)

        # Min
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MIN).wait()
        expected = torch.tensor(self.rank + 1)
        self.assertEqualIgnoreType(expected, output_tensor)

        # Max
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MAX).wait()
        expected = torch.tensor(self.rank + self.world_size)
        self.assertEqualIgnoreType(expected, output_tensor)

        # Product
        pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.PRODUCT).wait()
        prod_val = self.rank + 1
        for k in range(1, self.world_size):
            prod_val = prod_val * (self.rank + 1 + k)
        expected = torch.tensor(prod_val)
        self.assertEqualIgnoreType(expected, output_tensor)

    @requires_nccl()
    @sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
    def test_reduce_scatter_base_ops(self):
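For concreteness, a worked check of the expected values above with world_size = 2 (assuming each peer rank p contributes the value rank + p + 1 to this rank's slot, which is what the formulas encode): rank 0 reduces {1, 2}, so SUM = 3, MIN = 1, MAX = 2, PRODUCT = 2; rank 1 reduces {2, 3}, so SUM = 5, MIN = 2, MAX = 3, PRODUCT = 6.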
@@ -1,13 +1,16 @@
# Owner(s): ["oncall: distributed"]

import os
import sys
import tempfile

import torch
import torch.distributed as c10d
import torch.multiprocessing as mp
from torch.testing._internal.common_utils import NO_MULTIPROCESSING_SPAWN
from torch.testing._internal.common_utils import load_tests
from torch.testing._internal.common_distributed import \
    MultiProcessTestCase
from torch.testing._internal.common_utils import load_tests,\
    NO_MULTIPROCESSING_SPAWN

# Torch distributed.nn is not available in windows
# check #42095, it errors on import.
@@ -96,3 +99,154 @@ class AbstractProcessGroupShareTensorTest(object):
            c2p.put((rank, torch.ones(2, 2) * i, ys[0][i].to("cpu")))

        p2c.get()


class TestDistributedNNFunctions(MultiProcessTestCase):
    def setUp(self):
        super(TestDistributedNNFunctions, self).setUp()
        self._spawn_processes()

    def tearDown(self):
        super(TestDistributedNNFunctions, self).tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def op_timeout_sec(self):
        return 1

    @property
    def world_size(self):
        return 2

    def _test_broadcast(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into c10d
        # and need the process group to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        y = torch.distributed.nn.broadcast(x, 1)
        self.assertEqual(y, 1 + torch.ones(5, 5))
        z = y.sin().sum()
        z.backward()
        # We can't check the gradient of communications numerically, so we have to do some calculations.
        if self.rank == 1:
            self.assertEqual(x.grad, 2 * torch.cos(x))
        elif self.rank == 0:
            self.assertEqual(x.grad, torch.zeros(5, 5, device=device))

    def _test_reduce(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into c10d
        # and need the process group to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        y = torch.distributed.nn.reduce(x, 1, op=c10d.ReduceOp.SUM)

        if self.rank == 1:
            self.assertEqual(y, 3 * torch.ones(5, 5, device=device))

        z = y.sin().sum()
        z.backward()
        # Gradients are broadcast to both ranks.
        x_g = (3 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x.grad, x_g)

    def _test_allreduce(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into c10d
        # and need the process group to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        y = torch.distributed.nn.all_reduce(x, op=c10d.ReduceOp.SUM)

        self.assertEqual(y, 3 * torch.ones(5, 5, device=device))

        z = y.sin().sum()
        z.backward()
        x_g = 2 * (3 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x.grad, x_g)

    def _test_all_gather(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into c10d
        # and need the process group to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        tensors = torch.distributed.nn.all_gather(x)
        for i, t in enumerate(tensors):
            self.assertEqual(t, torch.ones(5, 5, device=device) + i)
        y = torch.sum(torch.stack(tensors), axis=0)
        z = y.sin().sum()
        z.backward()

        x_s = 2 * (3 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x.grad, x_s)

    def _test_all_to_all(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into c10d
        # and need the process group to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x0 = torch.ones(5, 5, device=device) + 2 * self.rank
        x1 = torch.ones(5, 5, device=device) + 2 * self.rank
        x0.requires_grad = True
        x1.requires_grad = True
        y0 = torch.empty_like(x0)
        y1 = torch.empty_like(x1)
        tensors = torch.distributed.nn.all_to_all([y0, y1], [x0, x1])
        for i, t in enumerate(tensors):
            self.assertEqual(t, torch.ones(5, 5, device=device) + 2 * i)
        y = torch.sum(torch.stack(tensors), axis=0)
        z = y.sin().sum()
        z.backward()
        x_s = (4 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x0.grad, x_s)
        self.assertEqual(x1.grad, x_s)

    def _test_all_to_all_single(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into c10d
        # and need the process group to be initialized.
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), 5, device=device) * (self.rank + 1)
        x.requires_grad = True
        y = torch.empty_like(x)
        split_sizes = [(i + 1) * (self.rank + 1) for i in range(self.world_size)]
        y = torch.distributed.nn.all_to_all_single(
            y, x, output_split_sizes=split_sizes, input_split_sizes=split_sizes
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)
        z = y.sin().sum()
        z.backward()
        x_s = ((self.rank + 1) * torch.ones(int(row), 5, device=device)).cos()
        self.assertEqual(x.grad, x_s)
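As a worked example of the split-size arithmetic in _test_all_to_all_single with world_size = 2: rank 0 gets row = 3 and split_sizes = [1, 2], rank 1 gets row = 6 and split_sizes = [2, 4]. The splits are consistent because the rows rank r sends to rank j (input_split_sizes[j] on rank r) match the rows rank j expects from rank r (output_split_sizes[r] on rank j); for example, rank 0 sends 2 rows to rank 1 and rank 1 expects 2 rows from rank 0.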
@@ -9,10 +9,10 @@ import test_c10d_spawn
import torch
import torch.distributed as c10d
import torch.nn as nn
from test_c10d_spawn import _torch_dist_nn_available
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_distributed import requires_gloo, \
    create_device, MultiProcessTestCase, skip_if_lt_x_gpu
    create_device, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import TestCase, run_tests, sandcastle_skip_if, TEST_WITH_DEV_DBG_ASAN

# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
@@ -176,47 +176,45 @@ class DistributedDataParallelSingleProcessTest(TestCase):

# Skip dev-asan as torch + multiprocessing spawn have known issues
if not TEST_WITH_DEV_DBG_ASAN:
    class TestDistributedNNFunctions(MultiProcessTestCase):
        def setUp(self):
            super(TestDistributedNNFunctions, self).setUp()
            self._spawn_processes()

        def tearDown(self):
            super(TestDistributedNNFunctions, self).tearDown()
            try:
                os.remove(self.file_name)
            except OSError:
                pass

        @property
        def op_timeout_sec(self):
            return 1

        @property
        def world_size(self):
            return 2

    class TestDistributedNNFunctionsGloo(TestDistributedNNFunctions):
        # Test Common Ops First.
        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_broadcast(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call directly into c10d
            # and need the process group to be initialized.
            c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
            device = torch.device(f"cuda:{self.rank}")
            x = torch.ones(5, 5, device=device) + self.rank
            x.requires_grad = True
            y = torch.distributed.nn.broadcast(x, 1)
            self.assertEqual(y, 1 + torch.ones(5, 5))
            z = y.sin().sum()
            z.backward()
            # We can't check the gradient of communications numerically, so we have to do some calculations.
            if self.rank == 1:
                self.assertEqual(x.grad, 2 * torch.cos(x))
            elif self.rank == 0:
                self.assertEqual(x.grad, torch.zeros(5, 5, device=device))
            self._test_broadcast("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_reduce(self):
            self._test_reduce("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_allreduce(self):
            self._test_allreduce("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_gather(self):
            self._test_all_gather("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_to_all(self):
            self._test_all_to_all("gloo")

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_to_all_single(self):
            self._test_all_to_all_single("gloo")

        # Test Ops only supported in GLOO.
        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
@@ -275,92 +273,6 @@ if not TEST_WITH_DEV_DBG_ASAN:
            if self.rank == 0:
                self.assertEqual(x0.grad, torch.zeros(5, 5, device=device))

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_reduce(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call directly into c10d
            # and need the process group to be initialized.
            c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
            device = torch.device(f"cuda:{self.rank}")
            x = torch.ones(5, 5, device=device) + self.rank
            x.requires_grad = True
            y = torch.distributed.nn.reduce(x, 1, op=c10d.ReduceOp.SUM)

            if self.rank == 1:
                self.assertEqual(y, 3 * torch.ones(5, 5, device=device))

            z = y.sin().sum()
            z.backward()
            # Gradients are broadcast to both ranks.
            x_g = (3 * torch.ones(5, 5, device=device)).cos()
            self.assertEqual(x.grad, x_g)

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_allreduce(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call directly into c10d
            # and need the process group to be initialized.
            c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
            device = torch.device(f"cuda:{self.rank}")
            x = torch.ones(5, 5, device=device) + self.rank
            x.requires_grad = True
            y = torch.distributed.nn.all_reduce(x, op=c10d.ReduceOp.SUM)

            self.assertEqual(y, 3 * torch.ones(5, 5, device=device))

            z = y.sin().sum()
            z.backward()
            x_g = 2 * (3 * torch.ones(5, 5, device=device)).cos()
            self.assertEqual(x.grad, x_g)

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_gather(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call directly into c10d
            # and need the process group to be initialized.
            c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
            device = torch.device(f"cuda:{self.rank}")
            x = torch.ones(5, 5, device=device) + self.rank
            x.requires_grad = True
            tensors = torch.distributed.nn.all_gather(x)
            for i, t in enumerate(tensors):
                self.assertEqual(t, torch.ones(5, 5, device=device) + i)
            y = torch.sum(torch.stack(tensors), axis=0)
            z = y.sin().sum()
            z.backward()

            x_s = 2 * (3 * torch.ones(5, 5, device=device)).cos()
            self.assertEqual(x.grad, x_s)

        @requires_gloo()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_to_all(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call directly into c10d
            # and need the process group to be initialized.
            c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
            device = torch.device(f"cuda:{self.rank}")
            x0 = torch.ones(5, 5, device=device) + 2 * self.rank
            x1 = torch.ones(5, 5, device=device) + 2 * self.rank
            x0.requires_grad = True
            x1.requires_grad = True
            tensors = torch.distributed.nn.all_to_all([x0, x1])
            for i, t in enumerate(tensors):
                self.assertEqual(t, torch.ones(5, 5, device=device) + 2 * i)
            y = torch.sum(torch.stack(tensors), axis=0)
            z = y.sin().sum()
            z.backward()
            x_s = (4 * torch.ones(5, 5, device=device)).cos()
            self.assertEqual(x0.grad, x_s)
            self.assertEqual(x1.grad, x_s)


if __name__ == '__main__':
    run_tests()
@@ -4,15 +4,27 @@ import sys
import test_c10d_spawn
import torch
import torch.distributed as c10d
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import TestCase, run_tests, sandcastle_skip_if
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    sandcastle_skip_if,
    TEST_WITH_DEV_DBG_ASAN,
)

NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL")

# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):
    class ProcessGroupShareTensorTest(test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase):

    class ProcessGroupShareTensorTest(
        test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
    ):
        @classmethod
        def _init_pg_nccl(cls, rank, filename, world_size):
            store = c10d.FileStore(filename, world_size)
@@ -25,7 +37,8 @@ if sys.version_info < (3, 9):
                ProcessGroupShareTensorTest._test_broadcast_process,
                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_nccl,
                1)
                1,
            )

        @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
        @sandcastle_skip_if(NO_NCCL, "NCCL needed")
@@ -34,11 +47,13 @@ if sys.version_info < (3, 9):
                ProcessGroupShareTensorTest._test_allreduce_process,
                [torch.ones(2, 2).to(i) for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_nccl,
                1)
                1,
            )

        @classmethod
        def _test_reduce_process(
                cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
        ):
            pg = init_pg(rank, filename, world_size)
            x = shared_tensors[rank]
            pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
@@ -55,7 +70,8 @@ if sys.version_info < (3, 9):
                ProcessGroupShareTensorTest._test_reduce_process,
                [torch.ones(2, 2).to(i) for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_nccl,
                1)
                1,
            )

        @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
        @sandcastle_skip_if(NO_NCCL, "NCCL needed")
@@ -64,8 +80,80 @@ if sys.version_info < (3, 9):
                ProcessGroupShareTensorTest._test_allgather_process,
                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_nccl,
                self.world_size)
                self.world_size,
            )


if __name__ == '__main__':
# Skip dev-asan as torch + multiprocessing spawn have known issues
if not TEST_WITH_DEV_DBG_ASAN:

    class TestDistributedNNFunctionsNccl(TestDistributedNNFunctions):
        # Test Common Ops First.
        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(
            not _torch_dist_nn_available, "torch.distributed.nn is not available"
        )
        def test_broadcast(self):
            self._test_broadcast("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_reduce(self):
            self._test_reduce("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_allreduce(self):
            self._test_allreduce("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_gather(self):
            self._test_all_gather("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_to_all(self):
            self._test_all_to_all("nccl")

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_all_to_all_single(self):
            self._test_all_to_all_single("nccl")

        # Test Ops only supported in NCCL.
        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
        def test_reduce_scatter(self):
            store = c10d.FileStore(self.file_name, self.world_size)
            # This is required because these functions call directly into c10d
            # and need the process group to be initialized.
            c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl')
            device = torch.device(f"cuda:{self.rank}")
            x0 = torch.ones(5, 5, device=device) + self.rank
            x1 = torch.ones(5, 5, device=device) + self.rank + 1
            x0.requires_grad = True
            x1.requires_grad = True
            y = torch.empty_like(x0)
            expected = (1 + self.world_size) * self.world_size / 2 + self.world_size * self.rank
            y = torch.distributed.nn.reduce_scatter(y, [x0, x1])
            self.assertEqual(y, torch.ones(5, 5, device=device) * expected)
            z = y.sin().sum()
            z.backward()
            expected_0 = (1 + self.world_size) * self.world_size / 2
            expected_1 = expected_0 + self.world_size
            x_s_0 = (expected_0 * torch.ones(5, 5, device=device)).cos()
            x_s_1 = (expected_1 * torch.ones(5, 5, device=device)).cos()
            self.assertEqual(x0.grad, x_s_0)
            self.assertEqual(x1.grad, x_s_1)


if __name__ == "__main__":
    run_tests()
@@ -1151,14 +1151,17 @@ Arguments:
              "reduce_scatter",
              [](::c10d::ProcessGroup& pg,
                 at::Tensor& output,
                 std::vector<at::Tensor>& input) {
                 std::vector<at::Tensor>& input,
                 ::c10d::ReduceOp op) {
                std::vector<at::Tensor> outputs = {output};
                std::vector<std::vector<at::Tensor>> inputs = {input};
                return pg.reduce_scatter(
                    outputs, inputs, ::c10d::ReduceScatterOptions());
                ::c10d::ReduceScatterOptions opts;
                opts.reduceOp = op;
                return pg.reduce_scatter(outputs, inputs, opts);
              },
              py::arg("output_tensors"),
              py::arg("input_tensor"),
              py::arg("op") = ::c10d::ReduceOp::SUM,
              py::call_guard<py::gil_scoped_release>())

          .def(
@@ -1,6 +1,6 @@
import torch
from torch.autograd import Function
import torch.distributed as dist
from torch.autograd import Function


def broadcast(tensor, src, group=dist.group.WORLD):
@@ -79,6 +79,25 @@ def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=dist.group.WORLD):
    return _Reduce.apply(dst, op, group, tensor)


def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=dist.group.WORLD):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    Arguments:
        output (Tensor): Output tensor.
        input_list (list[Tensor]): List of tensors to reduce and scatter.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum. Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _Reduce_Scatter.apply(op, group, output, *input_list)


def all_gather(tensor, group=dist.group.WORLD):
    """
    Gathers tensors from the whole group in a list.
@@ -88,26 +107,58 @@ def all_gather(tensor, group=dist.group.WORLD):
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple[Tensor]): Output of the collective.
        tuple([Tensor]): Output of the collective.

    """
    return _AllGather.apply(group, tensor)


def all_to_all(tensors, group=dist.group.WORLD):
def all_to_all(output_tensor_list, input_tensor_list, group=dist.group.WORLD):
    """
    Each process scatters a list of input tensors to all processes in a group and
    returns the gathered list of tensors in the output list.

    Arguments:
        tensors (list[Tensor]): List of tensors to scatter one per rank.
        out_tensor_list (list[Tensor]): List of tensors to gather, one per rank.
        input_tensor_list (list[Tensor]): List of tensors to scatter, one per rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple[Tensor]): Output of the collective.
        tuple([Tensor]): Output of the collective.

    """
    return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)


def all_to_all_single(
    output,
    input,
    output_split_sizes=None,
    input_split_sizes=None,
    group=dist.group.WORLD,
):
    """
    Each process splits the input tensor and then scatters the split list
    to all processes in a group. It then concatenates the received tensors from
    all processes in the group and returns a single output tensor.

    Arguments:
        output (Tensor): Gathered concatenated output tensor.
        input (Tensor): Input tensor to scatter.
        output_split_sizes (list[Int], optional): Output split sizes for dim 0.
            If specified as None or empty, dim 0 of the ``output`` tensor must
            divide equally by ``world_size``.
        input_split_sizes (list[Int], optional): Input split sizes for dim 0.
            If specified as None or empty, dim 0 of the ``input`` tensor must
            divide equally by ``world_size``.

    Returns:
        Tensor: Output of the collective.

    """
    return _AlltoAllSingle.apply(
        group, output, output_split_sizes, input_split_sizes, input
    )


def all_reduce(tensor, op=dist.ReduceOp.SUM, group=dist.group.WORLD):
@@ -207,6 +258,20 @@ class _Reduce(Function):
        return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)


class _Reduce_Scatter(Function):
    @staticmethod
    def forward(ctx, op, group, tensor, *input_tensor_list):
        ctx.group = group
        dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + _AllGather.apply(
            ctx.group, grad_output.contiguous()
        )


class _AllGather(Function):
    @staticmethod
    def forward(ctx, group, tensor):
@@ -219,19 +284,19 @@ class _AllGather(Function):

    @staticmethod
    def backward(ctx, *grad_outputs):
        gxs = _AlltoAll.apply(ctx.group, *grad_outputs)
        tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs]
        gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
        gx = torch.sum(torch.stack(gxs), dim=0)
        return (None, gx)


class _AlltoAll(Function):
    @staticmethod
    def forward(ctx, group, *tensors):
    def forward(ctx, group, out_tensor_list, *tensors):
        ctx.group = group
        out_tensor_list = [
            torch.empty_like(tensors[i]) for i in range(dist.get_world_size(group=group))
        ctx.input_tensor_size_list = [
            tensors[i].size() for i in range(dist.get_world_size(group=group))
        ]
        reqs = [None] * dist.get_world_size(group=group)
        my_rank = dist.get_rank(group=group)
        # Implement it by means of scatter/gather; send/recv async operations have issues.
        if dist.get_backend(group=group) is dist.Backend.GLOO:
@@ -241,12 +306,51 @@ class _AlltoAll(Function):
                    to_send = list(tensors)
                dist.scatter(out_tensor_list[i], to_send, i, group=group)
        else:
            dist.all_to_all(out_tensor_list, list(tensors), group=group)
            dist.all_to_all(
                out_tensor_list,
                list(tensors),
                group=group,
            )
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None,) + _AlltoAll.apply(ctx.group, *grad_outputs)
        tensor_list = [
            torch.empty(size, device=grad_outputs[0].device)
            for size in ctx.input_tensor_size_list
        ]
        grad_outputs = tuple(tensor.contiguous() for tensor in grad_outputs)
        return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)


class _AlltoAllSingle(Function):
    @staticmethod
    def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
        ctx.group = group
        ctx.input_size = input.size()
        ctx.output_split_sizes = input_split_sizes
        ctx.input_split_sizes = output_split_sizes
        dist.all_to_all_single(
            output,
            input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        tensor = torch.empty(ctx.input_size, device=grad_output.device)
        return (None, None, None, None) + (
            _AlltoAllSingle.apply(
                ctx.group,
                tensor,
                ctx.output_split_sizes,
                ctx.input_split_sizes,
                grad_output.contiguous(),
            ),
        )


class _AllReduce(Function):
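To make the GLOO fallback above concrete, here is a small single-process sketch (no process group involved; the fill values are made up) of why an all-to-all among N ranks can be emulated with N scatters, one rooted at each rank, which is the loop structure used in _AlltoAll.forward on the gloo backend.

import torch

world_size = 3
# inputs[p][q]: the tensor rank p wants to deliver to rank q (arbitrary fill values).
inputs = [
    [torch.full((2,), 10 * p + q) for q in range(world_size)]
    for p in range(world_size)
]

# Reference all-to-all semantics: rank r receives inputs[p][r] from every rank p.
expected = [[inputs[p][r] for p in range(world_size)] for r in range(world_size)]

# Emulation with one scatter per root: the scatter rooted at rank i hands
# inputs[i][j] to rank j, so looping i over all ranks fills slot i of every
# rank's output list.
outputs = [[None] * world_size for _ in range(world_size)]
for i in range(world_size):        # scatter rooted at rank i
    for j in range(world_size):    # rank j receives its piece from root i
        outputs[j][i] = inputs[i][j]

assert all(
    torch.equal(outputs[r][p], expected[r][p])
    for r in range(world_size)
    for p in range(world_size)
)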