Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68786

To enable autograd for the sharded linear, we found we needed to make some changes to the current nn function API (the c10d API with autograd enabled). We made the following changes:
1. Add a new `reduce_scatter` API, which is needed for row-wise sharding.
2. Modify the `all_to_all` API to make it consistent with the one in distributed_c10d.py.
3. Fix the C++ input params of `reduce_scatter`, which were missing an input param, and add more unit tests to cover these cases.
4. Sync the NN tests from Gloo to NCCL.

ghstack-source-id: 144860208

Test Plan: CI + Unit Test

Reviewed By: pritamdamania87

Differential Revision: D32569674

fbshipit-source-id: 9bd613f91bbf7a39eede0af32a5a5db0f2ade43b
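The pattern this change enables, as a minimal sketch (not part of the file below; it assumes a two-rank process group has already been initialized via c10d.init_process_group, and mirrors the all_reduce test in this file):

    import torch
    import torch.distributed as dist
    import torch.distributed.nn

    # Assumes dist.init_process_group(...) was already called with world_size == 2.
    rank = dist.get_rank()
    x = torch.ones(5, 5) + rank        # 1.0 on rank 0, 2.0 on rank 1
    x.requires_grad = True
    y = torch.distributed.nn.all_reduce(x, op=dist.ReduceOp.SUM)  # 3.0 everywhere
    y.sin().sum().backward()
    # The backward of all_reduce is itself an all_reduce, so both ranks end up
    # with x.grad == 2 * cos(3), which is exactly what _test_allreduce below asserts.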
# Owner(s): ["oncall: distributed"]

import os
import sys
import tempfile

import torch
import torch.distributed as c10d
import torch.multiprocessing as mp
from torch.testing._internal.common_distributed import \
    MultiProcessTestCase
from torch.testing._internal.common_utils import load_tests,\
    NO_MULTIPROCESSING_SPAWN

# torch.distributed.nn is not available on Windows;
# see #42095, it errors on import.
_torch_dist_nn_available = True
try:
    import torch.distributed.nn
except ImportError:
    _torch_dist_nn_available = False

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if not c10d.is_available():
    print('c10d not available, skipping tests', file=sys.stderr)
    sys.exit(0)

if NO_MULTIPROCESSING_SPAWN:
    print('spawn not available, skipping tests', file=sys.stderr)
    sys.exit(0)


class AbstractProcessGroupShareTensorTest(object):
    world_size = 2

    def _test_multiprocess(self, f, shared_tensors, init_pg, n_output):
        ws = self.world_size
        # file store will delete the test file on destruction
        file = tempfile.NamedTemporaryFile(delete=False)
        ctx = mp.get_context('spawn')
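        # c2p carries (rank, expected, actual) tuples from the workers back to
        # the parent; p2c is used by the parent to tell the workers to exit.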
        c2p = ctx.Queue(2)
        p2c = ctx.Queue(2)
        ps = []
        for i in range(ws):
            p = ctx.Process(
                target=f,
                args=(i, file.name, shared_tensors, ws, init_pg, c2p, p2c))

            p.start()
            ps.append(p)

        for _ in range(ws * n_output):
            pid, expected, result = c2p.get()
            self.assertEqual(
                expected,
                result,
                msg=(
                    "Expect rank {} to receive tensor {} but got {}."
                ).format(pid, expected, result)
            )

        for _ in range(ws):
            p2c.put(0)

        for p in ps:
            p.join(2)

    # Why classmethod? multiprocessing cannot pickle TestCase subclass when in
    # spawn mode. See https://bugs.python.org/issue33884.
    @classmethod
    def _test_broadcast_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        pg.broadcast(xs).wait()
        c2p.put((rank, torch.zeros(2, 2), xs[0].to("cpu")))
        p2c.get()

    @classmethod
    def _test_allreduce_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        pg.allreduce(xs, op=c10d.ReduceOp.SUM).wait()
        c2p.put((rank, torch.ones(2, 2) * 2, xs[0].to("cpu")))
        p2c.get()

    @classmethod
    def _test_allgather_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        ys = [[torch.zeros_like(xs[0]) for i in range(world_size)]]
        pg.allgather(ys, xs).wait()
        for i in range(world_size):
            c2p.put((rank, torch.ones(2, 2) * i, ys[0][i].to("cpu")))

        p2c.get()


class TestDistributedNNFunctions(MultiProcessTestCase):
    def setUp(self):
        super(TestDistributedNNFunctions, self).setUp()
        self._spawn_processes()

    def tearDown(self):
        super(TestDistributedNNFunctions, self).tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def op_timeout_sec(self):
        return 1

    @property
    def world_size(self):
        return 2

    def _test_broadcast(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into torch.distributed
        # and need the default process group to be initialized
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        y = torch.distributed.nn.broadcast(x, 1)
        self.assertEqual(y, 1 + torch.ones(5, 5))
        z = y.sin().sum()
        z.backward()
        # We can't check the gradient of communications numerically, so we have to do some calculations
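        # broadcast's backward accumulates the gradients from every rank onto the
        # source rank (rank 1 here), so rank 1 sees 2 * cos(x), while rank 0's
        # input never reaches the output and its gradient stays zero.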
        if self.rank == 1:
            self.assertEqual(x.grad, 2 * torch.cos(x))
        elif self.rank == 0:
            self.assertEqual(x.grad, torch.zeros(5, 5, device=device))

    def _test_reduce(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into torch.distributed
        # and need the default process group to be initialized
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        y = torch.distributed.nn.reduce(x, 1, op=c10d.ReduceOp.SUM)

        if self.rank == 1:
            self.assertEqual(y, 3 * torch.ones(5, 5, device=device))

        z = y.sin().sum()
        z.backward()
        # Gradients are broadcasted to both ranks
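        # The forward reduces 1 + 2 = 3 onto rank 1; the backward of reduce is a
        # broadcast from rank 1, so every rank receives a gradient of cos(3).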
        x_g = (3 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x.grad, x_g)

    def _test_allreduce(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into torch.distributed
        # and need the default process group to be initialized
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        y = torch.distributed.nn.all_reduce(x, op=c10d.ReduceOp.SUM)

        self.assertEqual(y, 3 * torch.ones(5, 5, device=device))

        z = y.sin().sum()
        z.backward()
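        # all_reduce sums 1 + 2 = 3 on every rank, and its backward is another
        # all_reduce, so each rank's gradient is cos(3) accumulated from both ranks.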
        x_g = 2 * (3 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x.grad, x_g)

    def _test_all_gather(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into torch.distributed
        # and need the default process group to be initialized
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x = torch.ones(5, 5, device=device) + self.rank
        x.requires_grad = True
        tensors = torch.distributed.nn.all_gather(x)
        for i, t in enumerate(tensors):
            self.assertEqual(t, torch.ones(5, 5, device=device) + i)
        y = torch.sum(torch.stack(tensors), axis=0)
        z = y.sin().sum()
        z.backward()

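        # y == 3 on every rank, so every gathered copy of x receives a gradient of
        # cos(3); all_gather's backward sums the copies' gradients from both ranks
        # back onto x, giving 2 * cos(3).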
        x_s = 2 * (3 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x.grad, x_s)

    def _test_all_to_all(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into torch.distributed
        # and need the default process group to be initialized
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
        x0 = torch.ones(5, 5, device=device) + 2 * self.rank
        x1 = torch.ones(5, 5, device=device) + 2 * self.rank
        x0.requires_grad = True
        x1.requires_grad = True
        y0 = torch.empty_like(x0)
        y1 = torch.empty_like(x1)
        tensors = torch.distributed.nn.all_to_all([y0, y1], [x0, x1])
        for i, t in enumerate(tensors):
            self.assertEqual(t, torch.ones(5, 5, device=device) + 2 * i)
        y = torch.sum(torch.stack(tensors), axis=0)
        z = y.sin().sum()
        z.backward()
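        # Every rank receives tensors filled with 1 and 3, so y == 4; each input
        # tensor is consumed by exactly one rank, so x0 and x1 each get back a
        # single gradient of cos(4).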
        x_s = (4 * torch.ones(5, 5, device=device)).cos()
        self.assertEqual(x0.grad, x_s)
        self.assertEqual(x1.grad, x_s)

    def _test_all_to_all_single(self, backend):
        store = c10d.FileStore(self.file_name, self.world_size)
        # This is required because these functions call directly into torch.distributed
        # and need the default process group to be initialized
        c10d.init_process_group(
            store=store, rank=self.rank, world_size=self.world_size, backend=backend
        )
        device = torch.device(f"cuda:{self.rank}")
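        # row is the total number of rows this rank exchanges: the sum of its split
        # sizes (i + 1) * (rank + 1) over i, i.e. (rank + 1) * ws * (ws + 1) / 2.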
        row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), 5, device=device) * (self.rank + 1)
        x.requires_grad = True
        y = torch.empty_like(x)
        split_sizes = [(i + 1) * (self.rank + 1) for i in range(self.world_size)]
        y = torch.distributed.nn.all_to_all_single(
            y, x, output_split_sizes=split_sizes, input_split_sizes=split_sizes
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)
        z = y.sin().sum()
        z.backward()
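        # Each chunk of x was sent to one peer, where the corresponding output
        # chunk keeps the value (rank + 1); that peer therefore sends back a
        # gradient of cos(rank + 1) for the whole chunk, i.e. cos of x itself.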
        x_s = ((self.rank + 1) * torch.ones(int(row), 5, device=device)).cos()
        self.assertEqual(x.grad, x_s)