[PyTorch][Distributed] Enable Reduce Scatter and modify all_to_all for sharded linear with more test cases. (#68786)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68786

To enable autograd for the sharded linear, we found that we need to make some changes to the current nn functional API (the c10d API with autograd enabled), so we made the following changes (a brief usage sketch of the updated Python APIs follows the list):

1. Add a new `reduce_scatter` API, since it is needed for row-wise sharding.
2. Modify the `all_to_all` API so that it is consistent with the one in distributed_c10d.py.
3. Fix the C++ binding of `reduce_scatter`, which was missing the reduce-op input parameter, and add more unit tests to cover these cases.
4. Sync the NN tests from Gloo to NCCL.
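
For reference, here is a minimal per-rank usage sketch of the two Python APIs this change touches, `torch.distributed.nn.reduce_scatter` and the updated `torch.distributed.nn.all_to_all`. It is not part of the diff: the `demo` helper, the tensor shapes, and the assumption that the process group and CUDA device are already set up are illustrative only.

import torch
import torch.distributed as dist
import torch.distributed.nn


def demo(rank, world_size):
    # Assumes dist.init_process_group(...) has already been called for this rank.
    device = torch.device(f"cuda:{rank}")

    # New autograd-aware reduce_scatter: one output tensor, world_size inputs.
    inputs = [torch.ones(5, 5, device=device, requires_grad=True) for _ in range(world_size)]
    out = torch.empty(5, 5, device=device)
    out = torch.distributed.nn.reduce_scatter(out, inputs, op=dist.ReduceOp.SUM)
    out.sin().sum().backward()  # gradients flow back to every input tensor

    # Updated all_to_all: explicit output list plus input list, matching
    # torch.distributed.all_to_all in distributed_c10d.py.
    xs = [torch.ones(5, 5, device=device, requires_grad=True) for _ in range(world_size)]
    ys = [torch.empty(5, 5, device=device) for _ in range(world_size)]
    ys = torch.distributed.nn.all_to_all(ys, xs)
    torch.stack(ys).sum().backward()
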
ghstack-source-id: 144860208

Test Plan: CI + Unit Test

Reviewed By: pritamdamania87

Differential Revision: D32569674

fbshipit-source-id: 9bd613f91bbf7a39eede0af32a5a5db0f2ade43b
Author: Junjie Wang, 2021-12-06 13:37:10 -08:00 (committed by Facebook GitHub Bot)
Parent: e032dae329
Commit: 7c2489bdae
6 changed files with 440 additions and 150 deletions

View File

@ -611,6 +611,35 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
# TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
self.assertEqualIgnoreType(expected, output[i])
# Test the overridden input-param scenario, i.e., when the input is
# a list and the output is a single tensor.
# Sum
output_tensor = torch.empty_like(input_per_gpu[0][0]).cuda(self.rank)
input_list = [tensor[0].cuda(self.rank) for tensor in input_per_gpu]
pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.SUM).wait()
expected = torch.tensor(
float((1 + self.world_size) * self.world_size / 2) + self.world_size * self.rank
)
self.assertEqualIgnoreType(expected, output_tensor)
# Min
pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MIN).wait()
expected = torch.tensor(self.rank + 1)
self.assertEqualIgnoreType(expected, output_tensor)
# Max
pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MAX).wait()
expected = torch.tensor(self.rank + self.world_size)
self.assertEqualIgnoreType(expected, output_tensor)
# Product
pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.PRODUCT).wait()
prod_val = self.rank + 1
for k in range(1, self.world_size):
prod_val = prod_val * (self.rank + 1 + k)
expected = torch.tensor(prod_val)
self.assertEqualIgnoreType(expected, output_tensor)
@requires_nccl()
@sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
def test_reduce_scatter_base_ops(self):

View File

@ -1,13 +1,16 @@
# Owner(s): ["oncall: distributed"]
import os
import sys
import tempfile
import torch
import torch.distributed as c10d
import torch.multiprocessing as mp
from torch.testing._internal.common_utils import NO_MULTIPROCESSING_SPAWN
from torch.testing._internal.common_utils import load_tests
from torch.testing._internal.common_distributed import \
MultiProcessTestCase
from torch.testing._internal.common_utils import load_tests,\
NO_MULTIPROCESSING_SPAWN
# Torch distributed.nn is not available in windows
# check #42095, it errors on import.
@ -96,3 +99,154 @@ class AbstractProcessGroupShareTensorTest(object):
c2p.put((rank, torch.ones(2, 2) * i, ys[0][i].to("cpu")))
p2c.get()
class TestDistributedNNFunctions(MultiProcessTestCase):
def setUp(self):
super(TestDistributedNNFunctions, self).setUp()
self._spawn_processes()
def tearDown(self):
super(TestDistributedNNFunctions, self).tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
@property
def op_timeout_sec(self):
return 1
@property
def world_size(self):
return 2
def _test_broadcast(self, backend):
store = c10d.FileStore(self.file_name, self.world_size)
# This is required because these functions call directly into dist and need
# the world to be initialized
c10d.init_process_group(
store=store, rank=self.rank, world_size=self.world_size, backend=backend
)
device = torch.device(f"cuda:{self.rank}")
x = torch.ones(5, 5, device=device) + self.rank
x.requires_grad = True
y = torch.distributed.nn.broadcast(x, 1)
self.assertEqual(y, 1 + torch.ones(5, 5))
z = y.sin().sum()
z.backward()
# We can't check the gradient of communications numerically so we have to do some calculations
if self.rank == 1:
self.assertEqual(x.grad, 2 * torch.cos(x))
elif self.rank == 0:
self.assertEqual(x.grad, torch.zeros(5, 5, device=device))
def _test_reduce(self, backend):
store = c10d.FileStore(self.file_name, self.world_size)
# This is required because these functions call directly into dist and need
# the world to be initialized
c10d.init_process_group(
store=store, rank=self.rank, world_size=self.world_size, backend=backend
)
device = torch.device(f"cuda:{self.rank}")
x = torch.ones(5, 5, device=device) + self.rank
x.requires_grad = True
y = torch.distributed.nn.reduce(x, 1, op=c10d.ReduceOp.SUM)
if self.rank == 1:
self.assertEqual(y, 3 * torch.ones(5, 5, device=device))
z = y.sin().sum()
z.backward()
# Gradients are broadcasted to both ranks
x_g = (3 * torch.ones(5, 5, device=device)).cos()
self.assertEqual(x.grad, x_g)
def _test_allreduce(self, backend):
store = c10d.FileStore(self.file_name, self.world_size)
# This is required because these functions call directly into dist and need
# the world to be initialized
c10d.init_process_group(
store=store, rank=self.rank, world_size=self.world_size, backend=backend
)
device = torch.device(f"cuda:{self.rank}")
x = torch.ones(5, 5, device=device) + self.rank
x.requires_grad = True
y = torch.distributed.nn.all_reduce(x, op=c10d.ReduceOp.SUM)
self.assertEqual(y, 3 * torch.ones(5, 5, device=device))
z = y.sin().sum()
z.backward()
x_g = 2 * (3 * torch.ones(5, 5, device=device)).cos()
self.assertEqual(x.grad, x_g)
def _test_all_gather(self, backend):
store = c10d.FileStore(self.file_name, self.world_size)
# This is required because these functions call directly into dist and need
# the world to be initialized
c10d.init_process_group(
store=store, rank=self.rank, world_size=self.world_size, backend=backend
)
device = torch.device(f"cuda:{self.rank}")
x = torch.ones(5, 5, device=device) + self.rank
x.requires_grad = True
tensors = torch.distributed.nn.all_gather(x)
for i, t in enumerate(tensors):
self.assertEqual(t, torch.ones(5, 5, device=device) + i)
y = torch.sum(torch.stack(tensors), axis=0)
z = y.sin().sum()
z.backward()
x_s = 2 * (3 * torch.ones(5, 5, device=device)).cos()
self.assertEqual(x.grad, x_s)
def _test_all_to_all(self, backend):
store = c10d.FileStore(self.file_name, self.world_size)
# This is required because these functions call directly into dist and need
# the world to be initialized
c10d.init_process_group(
store=store, rank=self.rank, world_size=self.world_size, backend=backend
)
device = torch.device(f"cuda:{self.rank}")
x0 = torch.ones(5, 5, device=device) + 2 * self.rank
x1 = torch.ones(5, 5, device=device) + 2 * self.rank
x0.requires_grad = True
x1.requires_grad = True
y0 = torch.empty_like(x0)
y1 = torch.empty_like(x1)
tensors = torch.distributed.nn.all_to_all([y0, y1], [x0, x1])
for i, t in enumerate(tensors):
self.assertEqual(t, torch.ones(5, 5, device=device) + 2 * i)
y = torch.sum(torch.stack(tensors), axis=0)
z = y.sin().sum()
z.backward()
x_s = (4 * torch.ones(5, 5, device=device)).cos()
self.assertEqual(x0.grad, x_s)
self.assertEqual(x1.grad, x_s)
def _test_all_to_all_single(self, backend):
store = c10d.FileStore(self.file_name, self.world_size)
# This is required because these functions call directly into dist and need
# the world to be initialized
c10d.init_process_group(
store=store, rank=self.rank, world_size=self.world_size, backend=backend
)
device = torch.device(f"cuda:{self.rank}")
row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
x = torch.ones(int(row), 5, device=device) * (self.rank + 1)
x.requires_grad = True
y = torch.empty_like(x)
split_sizes = [(i + 1) * (self.rank + 1) for i in range(self.world_size)]
y = torch.distributed.nn.all_to_all_single(
y, x, output_split_sizes=split_sizes, input_split_sizes=split_sizes
)
expected = []
for idx, tensor in enumerate(torch.split(x, split_sizes)):
expected.append(torch.full_like(tensor, (idx + 1)))
expected = torch.cat(expected)
self.assertEqual(y, expected)
z = y.sin().sum()
z.backward()
x_s = ((self.rank + 1) * torch.ones(int(row), 5, device=device)).cos()
self.assertEqual(x.grad, x_s)

View File

@ -9,10 +9,10 @@ import test_c10d_spawn
import torch
import torch.distributed as c10d
import torch.nn as nn
from test_c10d_spawn import _torch_dist_nn_available
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_distributed import requires_gloo, \
create_device, MultiProcessTestCase, skip_if_lt_x_gpu
create_device, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import TestCase, run_tests, sandcastle_skip_if, TEST_WITH_DEV_DBG_ASAN
# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
@ -176,47 +176,45 @@ class DistributedDataParallelSingleProcessTest(TestCase):
# Skip dev-asan as torch + multiprocessing spawn have known issues
if not TEST_WITH_DEV_DBG_ASAN:
class TestDistributedNNFunctions(MultiProcessTestCase):
def setUp(self):
super(TestDistributedNNFunctions, self).setUp()
self._spawn_processes()
def tearDown(self):
super(TestDistributedNNFunctions, self).tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
@property
def op_timeout_sec(self):
return 1
@property
def world_size(self):
return 2
class TestDistributedNNFunctionsGloo(TestDistributedNNFunctions):
# Test Common Ops First.
@requires_gloo()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_broadcast(self):
store = c10d.FileStore(self.file_name, self.world_size)
# This is required because these functions calls directly to the .dist and needs
# the world to be initialized
c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
device = torch.device(f"cuda:{self.rank}")
x = torch.ones(5, 5, device=device) + self.rank
x.requires_grad = True
y = torch.distributed.nn.broadcast(x, 1)
self.assertEqual(y, 1 + torch.ones(5, 5))
z = y.sin().sum()
z.backward()
# We can't check the gradient of communications numerically so we have to do some calculations
if self.rank == 1:
self.assertEqual(x.grad, 2 * torch.cos(x))
elif self.rank == 0:
self.assertEqual(x.grad, torch.zeros(5, 5, device=device))
self._test_broadcast("gloo")
@requires_gloo()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_reduce(self):
self._test_reduce("gloo")
@requires_gloo()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_allreduce(self):
self._test_allreduce("gloo")
@requires_gloo()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_all_gather(self):
self._test_all_gather("gloo")
@requires_gloo()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_all_to_all(self):
self._test_all_to_all("gloo")
@requires_gloo()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_all_to_all_single(self):
self._test_all_to_all_single("gloo")
# Test Ops only supported in GLOO.
@requires_gloo()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
@ -275,92 +273,6 @@ if not TEST_WITH_DEV_DBG_ASAN:
if self.rank == 0:
self.assertEqual(x0.grad, torch.zeros(5, 5, device=device))
@requires_gloo()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_reduce(self):
store = c10d.FileStore(self.file_name, self.world_size)
# This is required because these functions calls directly to the .dist and needs
# the world to be initialized
c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
device = torch.device(f"cuda:{self.rank}")
x = torch.ones(5, 5, device=device) + self.rank
x.requires_grad = True
y = torch.distributed.nn.reduce(x, 1, op=c10d.ReduceOp.SUM)
if self.rank == 1:
self.assertEqual(y, 3 * torch.ones(5, 5, device=device))
z = y.sin().sum()
z.backward()
# Gradients are broadcasted to both ranks
x_g = (3 * torch.ones(5, 5, device=device)).cos()
self.assertEqual(x.grad, x_g)
@requires_gloo()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_allreduce(self):
store = c10d.FileStore(self.file_name, self.world_size)
# This is required because these functions calls directly to the .dist and needs
# the world to be initialized
c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
device = torch.device(f"cuda:{self.rank}")
x = torch.ones(5, 5, device=device) + self.rank
x.requires_grad = True
y = torch.distributed.nn.all_reduce(x, op=c10d.ReduceOp.SUM)
self.assertEqual(y, 3 * torch.ones(5, 5, device=device))
z = y.sin().sum()
z.backward()
x_g = 2 * (3 * torch.ones(5, 5, device=device)).cos()
self.assertEqual(x.grad, x_g)
@requires_gloo()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_all_gather(self):
store = c10d.FileStore(self.file_name, self.world_size)
# This is required because these functions calls directly to the .dist and needs
# the world to be initialized
c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
device = torch.device(f"cuda:{self.rank}")
x = torch.ones(5, 5, device=device) + self.rank
x.requires_grad = True
tensors = torch.distributed.nn.all_gather(x)
for i, t in enumerate(tensors):
self.assertEqual(t, torch.ones(5, 5, device=device) + i)
y = torch.sum(torch.stack(tensors), axis=0)
z = y.sin().sum()
z.backward()
x_s = 2 * (3 * torch.ones(5, 5, device=device)).cos()
self.assertEqual(x.grad, x_s)
@requires_gloo()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_all_to_all(self):
store = c10d.FileStore(self.file_name, self.world_size)
# This is required because these functions calls directly to the .dist and needs
# the world to be initialized
c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo')
device = torch.device(f"cuda:{self.rank}")
x0 = torch.ones(5, 5, device=device) + 2 * self.rank
x1 = torch.ones(5, 5, device=device) + 2 * self.rank
x0.requires_grad = True
x1.requires_grad = True
tensors = torch.distributed.nn.all_to_all([x0, x1])
for i, t in enumerate(tensors):
self.assertEqual(t, torch.ones(5, 5, device=device) + 2 * i)
y = torch.sum(torch.stack(tensors), axis=0)
z = y.sin().sum()
z.backward()
x_s = (4 * torch.ones(5, 5, device=device)).cos()
self.assertEqual(x0.grad, x_s)
self.assertEqual(x1.grad, x_s)
if __name__ == '__main__':
run_tests()

View File

@ -4,15 +4,27 @@ import sys
import test_c10d_spawn
import torch
import torch.distributed as c10d
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import TestCase, run_tests, sandcastle_skip_if
from torch.testing._internal.common_distributed import (
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
sandcastle_skip_if,
TEST_WITH_DEV_DBG_ASAN,
)
NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL")
# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):
class ProcessGroupShareTensorTest(test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase):
class ProcessGroupShareTensorTest(
test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
):
@classmethod
def _init_pg_nccl(cls, rank, filename, world_size):
store = c10d.FileStore(filename, world_size)
@ -25,7 +37,8 @@ if sys.version_info < (3, 9):
ProcessGroupShareTensorTest._test_broadcast_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
1)
1,
)
@sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
@sandcastle_skip_if(NO_NCCL, "NCCL needed")
@ -34,11 +47,13 @@ if sys.version_info < (3, 9):
ProcessGroupShareTensorTest._test_allreduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
1)
1,
)
@classmethod
def _test_reduce_process(
cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
):
pg = init_pg(rank, filename, world_size)
x = shared_tensors[rank]
pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
@ -55,7 +70,8 @@ if sys.version_info < (3, 9):
ProcessGroupShareTensorTest._test_reduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
1)
1,
)
@sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
@sandcastle_skip_if(NO_NCCL, "NCCL needed")
@ -64,8 +80,80 @@ if sys.version_info < (3, 9):
ProcessGroupShareTensorTest._test_allgather_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
self.world_size)
self.world_size,
)
if __name__ == '__main__':
# Skip dev-asan as torch + multiprocessing spawn have known issues
if not TEST_WITH_DEV_DBG_ASAN:
class TestDistributedNNFunctionsNccl(TestDistributedNNFunctions):
# Test Common Ops First.
@requires_nccl()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(
not _torch_dist_nn_available, "torch.distributed.nn is not available"
)
def test_broadcast(self):
self._test_broadcast("nccl")
@requires_nccl()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_reduce(self):
self._test_reduce("nccl")
@requires_nccl()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_allreduce(self):
self._test_allreduce("nccl")
@requires_nccl()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_all_gather(self):
self._test_all_gather("nccl")
@requires_nccl()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_all_to_all(self):
self._test_all_to_all("nccl")
@requires_nccl()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_all_to_all_single(self):
self._test_all_to_all_single("nccl")
# Test Ops only supported in NCCL.
@requires_nccl()
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
def test_reduce_scatter(self):
store = c10d.FileStore(self.file_name, self.world_size)
# This is required because these functions call directly into dist and need
# the world to be initialized
c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl')
device = torch.device(f"cuda:{self.rank}")
x0 = torch.ones(5, 5, device=device) + self.rank
x1 = torch.ones(5, 5, device=device) + self.rank + 1
x0.requires_grad = True
x1.requires_grad = True
y = torch.empty_like(x0)
expected = (1 + self.world_size) * self.world_size / 2 + self.world_size * self.rank
y = torch.distributed.nn.reduce_scatter(y, [x0, x1])
self.assertEqual(y, torch.ones(5, 5, device=device) * expected)
z = y.sin().sum()
z.backward()
expected_0 = (1 + self.world_size) * self.world_size / 2
expected_1 = expected_0 + self.world_size
x_s_0 = (expected_0 * torch.ones(5, 5, device=device)).cos()
x_s_1 = (expected_1 * torch.ones(5, 5, device=device)).cos()
self.assertEqual(x0.grad, x_s_0)
self.assertEqual(x1.grad, x_s_1)
if __name__ == "__main__":
run_tests()

View File

@ -1151,14 +1151,17 @@ Arguments:
"reduce_scatter",
[](::c10d::ProcessGroup& pg,
at::Tensor& output,
std::vector<at::Tensor>& input) {
std::vector<at::Tensor>& input,
::c10d::ReduceOp op) {
std::vector<at::Tensor> outputs = {output};
std::vector<std::vector<at::Tensor>> inputs = {input};
return pg.reduce_scatter(
outputs, inputs, ::c10d::ReduceScatterOptions());
::c10d::ReduceScatterOptions opts;
opts.reduceOp = op;
return pg.reduce_scatter(outputs, inputs, opts);
},
py::arg("output_tensors"),
py::arg("input_tensor"),
py::arg("op") = ::c10d::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(

View File

@ -1,6 +1,6 @@
import torch
from torch.autograd import Function
import torch.distributed as dist
from torch.autograd import Function
def broadcast(tensor, src, group=dist.group.WORLD):
@ -79,6 +79,25 @@ def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=dist.group.WORLD):
return _Reduce.apply(dst, op, group, tensor)
def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=dist.group.WORLD):
"""
Reduces, then scatters a list of tensors to all processes in a group.
Arguments:
output (Tensor): Output tensor.
input_list (list[Tensor]): List of tensors to reduce and scatter.
op (optional): One of the values from
``torch.distributed.ReduceOp``
enum. Specifies an operation used for element-wise reductions.
group (ProcessGroup, optional): The process group to work on.
Returns:
Tensor: Output of the collective.
"""
return _Reduce_Scatter.apply(op, group, output, *input_list)
def all_gather(tensor, group=dist.group.WORLD):
"""
Gathers tensors from the whole group in a list.
@ -88,26 +107,58 @@ def all_gather(tensor, group=dist.group.WORLD):
group (ProcessGroup, optional): The process group to work on.
Returns:
tuple[Tensor]): Output of the collective.
tuple([Tensor]): Output of the collective.
"""
return _AllGather.apply(group, tensor)
def all_to_all(tensors, group=dist.group.WORLD):
def all_to_all(output_tensor_list, input_tensor_list, group=dist.group.WORLD):
"""
Each process scatters list of input tensors to all processes in a group and
return gathered list of tensors in output list.
Arguments:
tensors (list[Tensor]): List of tensors to scatter one per rank.
out_tensor_list (list[Tensor]): list of tensors to gather one per rank.
input_tensor_list (list[Tensor]): List of tensors to scatter one per rank.
group (ProcessGroup, optional): The process group to work on.
Returns:
tuple[Tensor]): Output of the collective.
tuple([Tensor]): Output of the collective.
"""
return _AlltoAll.apply(group, *tensors)
return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)
def all_to_all_single(
output,
input,
output_split_sizes=None,
input_split_sizes=None,
group=dist.group.WORLD,
):
"""
Each process splits input tensor and then scatters the split list
to all processes in a group. Then concatenate the received tensors from all
the processes in the group and return single output tensor.
Arguments:
output (Tensor): Gathered concatenated output tensor.
input (Tensor): Input tensor to scatter.
output_split_sizes: (list[Int], optional): Output split sizes for dim 0
if specified None or empty, dim 0 of ``output`` tensor must divide
equally by ``world_size``.
input_split_sizes: (list[Int], optional): Input split sizes for dim 0
if specified None or empty, dim 0 of ``input`` tensor must divide
equally by ``world_size``.
Returns:
Tensor: Output of the collective.
"""
return _AlltoAllSingle.apply(
group, output, output_split_sizes, input_split_sizes, input
)
def all_reduce(tensor, op=dist.ReduceOp.SUM, group=dist.group.WORLD):
@ -207,6 +258,20 @@ class _Reduce(Function):
return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)
class _Reduce_Scatter(Function):
@staticmethod
def forward(ctx, op, group, tensor, *input_tensor_list):
ctx.group = group
dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
return tensor
@staticmethod
def backward(ctx, grad_output):
return (None, None, None) + _AllGather.apply(
ctx.group, grad_output.contiguous()
)
class _AllGather(Function):
@staticmethod
def forward(ctx, group, tensor):
@ -219,19 +284,19 @@ class _AllGather(Function):
@staticmethod
def backward(ctx, *grad_outputs):
gxs = _AlltoAll.apply(ctx.group, *grad_outputs)
tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs]
gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
gx = torch.sum(torch.stack(gxs), dim=0)
return (None, gx)
class _AlltoAll(Function):
@staticmethod
def forward(ctx, group, *tensors):
def forward(ctx, group, out_tensor_list, *tensors):
ctx.group = group
out_tensor_list = [
torch.empty_like(tensors[i]) for i in range(dist.get_world_size(group=group))
ctx.input_tensor_size_list = [
tensors[i].size() for i in range(dist.get_world_size(group=group))
]
reqs = [None] * dist.get_world_size(group=group)
my_rank = dist.get_rank(group=group)
# Implement it by means of scatter/gather; send/recv async operations have issues
if dist.get_backend(group=group) is dist.Backend.GLOO:
@ -241,12 +306,51 @@ class _AlltoAll(Function):
to_send = list(tensors)
dist.scatter(out_tensor_list[i], to_send, i, group=group)
else:
dist.all_to_all(out_tensor_list, list(tensors), group=group)
dist.all_to_all(
out_tensor_list,
list(tensors),
group=group,
)
return tuple(out_tensor_list)
@staticmethod
def backward(ctx, *grad_outputs):
return (None,) + _AlltoAll.apply(ctx.group, *grad_outputs)
tensor_list = [
torch.empty(size, device=grad_outputs[0].device)
for size in ctx.input_tensor_size_list
]
grad_outputs = tuple(tensor.contiguous() for tensor in grad_outputs)
return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
class _AlltoAllSingle(Function):
@staticmethod
def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
ctx.group = group
ctx.input_size = input.size()
ctx.output_split_sizes = input_split_sizes
ctx.input_split_sizes = output_split_sizes
dist.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
return output
@staticmethod
def backward(ctx, grad_output):
tensor = torch.empty(ctx.input_size, device=grad_output.device)
return (None, None, None, None) + (
_AlltoAllSingle.apply(
ctx.group,
tensor,
ctx.output_split_sizes,
ctx.input_split_sizes,
grad_output.contiguous(),
),
)
class _AllReduce(Function):
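
A note on the autograd wiring this diff adds (not part of the commit): `_Reduce_Scatter.backward` is implemented as an all_gather of the output gradient. The single-process sketch below simulates the math with plain tensors, with `world_size`, shapes, and variable names chosen purely for illustration, to show why that is the correct gradient.

import torch

world_size = 2
# inputs[p][r]: the r-th input tensor held by rank p before reduce_scatter.
inputs = [
    [torch.randn(3, requires_grad=True) for _ in range(world_size)]
    for _ in range(world_size)
]
# Forward on rank r: out[r] = sum over ranks p of inputs[p][r].
outs = [sum(inputs[p][r] for p in range(world_size)) for r in range(world_size)]
# Distinct upstream gradients, one per rank.
grads = [torch.full((3,), float(r + 1)) for r in range(world_size)]
torch.autograd.backward(outs, grads)
# Backward: rank p's r-th input receives grads[r], i.e. the all_gather of the
# per-rank output gradients, which is what _Reduce_Scatter.backward computes.
for p in range(world_size):
    for r in range(world_size):
        assert torch.equal(inputs[p][r].grad, grads[r])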