import contextlib
import unittest
from copy import deepcopy
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.parallel as dp
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_utils import run_tests, TestCase, skipIfRocm, repeat_test_for_types, ALL_TENSORTYPES, PY3
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
import torch.nn.functional as F

torch.set_default_dtype(torch.double)


class TestDataParallel(TestCase):

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_buffers_requiring_grad(self):
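        # A buffer that requires grad should be broadcast to the replicas in a
        # way that keeps it in the autograd graph, so gradcheck w.r.t. the
        # original buffer must still pass through nn.DataParallel.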
        class TestModule(nn.Module):
            def __init__(self, t):
                super(TestModule, self).__init__()
                self.register_buffer('t_rg', t)
                self.register_buffer('t_not_rg', t.clone().detach())

            def forward(self, x):
                return x * self.t_rg + self.t_not_rg

        m = TestModule(torch.randn(100, device='cuda', requires_grad=True))
        self.assertTrue(m.t_rg.requires_grad)

        dpm = nn.DataParallel(m, [0, 1])
        inp = torch.randn(2, 100, device='cuda')

        def fn(t):
            return dpm(inp)

        torch.autograd.gradcheck(fn, (m.t_rg,))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_rnn(self):
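        # An LSTM that calls flatten_parameters() inside forward() should end
        # up with the same parameter updates with and without DataParallel.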

        class TestModule(torch.nn.Module):

            def __init__(self):
                super(TestModule, self).__init__()
                self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)

            def forward(self, x):
                self.rnn.flatten_parameters()
                return self.rnn(x)

        def step(model):
            opt = torch.optim.SGD(model.parameters(), lr=10)
            input = torch.ones(4, 4, 300).to(0)
            output = model(input)
            loss = F.mse_loss(output[0], torch.zeros_like(output[0]))
            loss.backward()
            opt.step()

        with torch.no_grad():
            model = TestModule().to(0)
            model_dp = torch.nn.DataParallel(deepcopy(model))

            # make sure DP does not crash when grad is disabled.
            # See #21108
            model_dp(torch.rand(2, 4, 300).to(0))

        step(model)
        step(model_dp)

        for p1, p2 in zip(model.parameters(), model_dp.parameters()):
            self.assertTrue(p1.allclose(p2))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_parallel_apply(self):
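        # parallel_apply should run each module on its own device and return
        # the same outputs as calling the modules directly.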
        l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
        l2 = nn.Linear(10, 5).to("cuda:1", torch.float)
        i1 = torch.randn(2, 10, device="cuda:0", dtype=torch.float)
        i2 = torch.randn(2, 10, device="cuda:1", dtype=torch.float)
        expected1 = l1(i1)
        expected2 = l2(i2)
        modules = (l1, l2)
        expected_outputs = (expected1, expected2)

        # each input can be either a collection of positional arguments
        # or an object representing the single argument
        for inputs in [((i1,), (i2,)), (i1, i2)]:
            outputs = dp.parallel_apply(modules, inputs, None)
            for out, expected in zip(outputs, expected_outputs):
                self.assertEqual(out, expected)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_parallel_apply_passes_exception(self):
        # we define and instantiate a module that will throw a KeyError
        class TestModule(nn.Module):

            def forward(self, *args):
                return {}['wonderful']

        l1 = TestModule().to("cuda", torch.float)
        # and check that parallel_apply passes on the exception
        # (we can use a single device twice for this test)
        with self.assertRaisesRegex(KeyError,
                                    'Caught KeyError in replica \\d '
                                    'on device 0.\nOriginal Traceback'
                                    '[\\s\\S]+wonderful'):
            dp.parallel_apply(modules=(l1, l1), inputs=(None, None))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_multiple_input(self):
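        # Forward and backward through data_parallel with several positional
        # tensor inputs, a non-tensor argument, and an optional keyword tensor
        # should match running the module directly.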
        class TestModule(nn.Module):

            def forward(self, var1, var2, float1, var3=None):
                if var3 is None:
                    return float1 * (var1 * var2)
                else:
                    return float1 * (var1 * var2 + var3)

        m = TestModule()
        var1 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
        var2 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
        var3 = torch.randn(5, 5, dtype=torch.float, requires_grad=False)

        float1 = torch.randn(1).item()

        expected = m(var1, var2, float1)
        loss = expected.sum()
        loss.backward()
        gvar1_exp = var1.grad.clone()
        gvar2_exp = var2.grad.clone()

        def local_test(out):
            with torch.no_grad():
                var1.grad.fill_(0.0)
                var2.grad.fill_(0.0)
            loss = out.sum()
            loss.backward()
            self.assertEqual(out, expected)
            self.assertEqual(gvar1_exp, var1.grad)
            self.assertEqual(gvar2_exp, var2.grad)

        out = dp.data_parallel(m, (var1, var2, float1), (0, 1))
        local_test(out)

        out = dp.data_parallel(m, (var1, var2, float1), (1, 0))
        local_test(out)

        out = dp.data_parallel(m, (var1, var2, float1), (0,))
        local_test(out)

        with torch.no_grad():
            var1.grad.fill_(0.0)
            var2.grad.fill_(0.0)
        expected = m(var1, var2, float1, var3=var3)
        loss = expected.sum()
        loss.backward()
        gvar1_exp = var1.grad.clone()
        gvar2_exp = var2.grad.clone()

        dpm = nn.DataParallel(TestModule())
        out = dpm(var1, var2, float1, var3=var3)
        local_test(out)

        dpm = nn.DataParallel(TestModule(), device_ids=[0])
        out = dpm(var1, var2, float1, var3=var3)
        local_test(out)

        kwarg_wrap = {'var3': var3}
        out = dp.data_parallel(
            m, (var1, var2, float1), (0, 1), module_kwargs=kwarg_wrap)
        local_test(out)

        out = dp.data_parallel(
            m, (var1, var2, float1), (0,), module_kwargs=kwarg_wrap)
        local_test(out)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_small_back(self):
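        # A batch scattered across two devices should still produce the same
        # output as running the module on a single device.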
        l = nn.Linear(10, 5).float().cuda()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda")
        out = dp.data_parallel(l, i, (0, 1))
        self.assertEqual(out, l(i))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_model_device(self):
        r"""Test device[0] check at forward time.
        """
        l = nn.Linear(2, 2)
        inp = torch.randn(2, 2)
        inp_cuda0 = inp.cuda(0)
        inp_cuda1 = inp.cuda(1)

        error_msg = "module must have its parameters and buffers on device {}"

        @contextlib.contextmanager
        def dummy_ctx_manager():
            yield

        def test(inner_m, dp_device, inp, device_ids, should_fail):
            if device_ids is None:
                device_ids = list(range(torch.cuda.device_count()))

            if isinstance(device_ids[0], torch.device):
                expect_device = device_ids[0]
            else:
                expect_device = torch.device("cuda:{}".format(device_ids[0]))

            if should_fail:
                def assert_correct():
                    return self.assertRaisesRegex(RuntimeError, error_msg.format(expect_device))
            else:
                assert_correct = dummy_ctx_manager

            # test DataParallel module
            dpm = nn.DataParallel(inner_m, device_ids)
            if dp_device is not None:
                dpm = dpm.to(dp_device)

            with assert_correct():
                dpm(inp)

            # test functional
            with assert_correct():
                nn.parallel.data_parallel(inner_m.to(dp_device), inp, device_ids)

        test(l.to('cpu'), None, inp, None, should_fail=True)
        test(l.cuda(1), None, inp_cuda0, None, should_fail=True)
        test(l.cuda(), None, inp_cuda0, [1, 0], should_fail=True)

        test(l.cuda(), None, inp_cuda0, None, should_fail=False)
        test(l.cpu(), 'cuda', inp_cuda0, None, should_fail=False)
        test(l.cuda(1), None, inp_cuda1, [1, 0], should_fail=False)
        test(l.cpu(), 'cuda:1', inp_cuda1, [1, 0], should_fail=False)

        s = nn.Sequential(l.cpu())
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(deepcopy(l).cpu(), l.cuda())
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(l.cuda(), deepcopy(l).cuda(1))
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(l.cuda(), deepcopy(l).cuda())
        test(s, None, inp, None, should_fail=False)
        test(s, None, inp, [0, 1], should_fail=False)
        test(s, None, inp, [1, 0], should_fail=True)
        test(s.cpu(), None, inp, [1, 0], should_fail=True)
        test(s.cuda(1), None, inp, [1, 0], should_fail=False)

    @unittest.skipIf(not TEST_MULTIGPU or not PY3, "multi-GPU not supported")
    @skipIfRocm
    def test_data_parallel_model_no_refcycles(self):
        # Python 2.7 will create reference cycles with the following
        # Module on multiple GPUs, but Python 3 shouldn't unless
        # there are refcycles on the PyTorch side (or the defined module)
        import gc

        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.linear = nn.Linear(1, 1)

            def forward(self, x):
                return self.linear(x)

        gc.collect()
        model = nn.DataParallel(Model().cuda())
        data = torch.randn(1, device="cuda")
        model(data)

        refcycles = gc.collect()
        self.assertEqual(refcycles, 0)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_no_grad(self):
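        # forward() inside data_parallel should see gradients disabled when
        # the call happens under torch.no_grad(), and enabled otherwise.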
        test = self

        class Layer(nn.Module):
            def forward(self, x):
                test.assertFalse(torch.is_grad_enabled())
                return x

        l = Layer()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda")
        with torch.no_grad():
            dp.data_parallel(l, i, (0, 1))
        self.assertRaises(AssertionError, lambda: dp.data_parallel(l, i, (0, 1)))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel(self):
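        # data_parallel output and gradients should match the single-GPU
        # reference for both device orderings, and device_ids=None should
        # default to all visible devices.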
        l = nn.Linear(10, 5).float().cuda()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda:1")
        l.cuda(1)
        expected_out = l(i)
        loss = expected_out.sum()
        loss.backward()
        expected_grads = []
        for param in l.parameters():
            expected_grads.append(param.grad.clone())
        dev_ids_list = [(0, 1), (1, 0)]
        for dev_id in dev_ids_list:
            with torch.cuda.device(dev_id[0]):
                l.cuda()
                l.zero_grad()
                out = dp.data_parallel(l, i, dev_id)
                loss = out.sum()
                loss.backward()
                self.assertEqual(out.get_device(), dev_id[0])
                self.assertEqual(out, expected_out)
                for expected, param in zip(expected_grads, l.parameters()):
                    self.assertEqual(param.grad, expected)

        # Check for None device_ids
        l = l.cuda()
        out = dp.data_parallel(l, i)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_sparse(self):
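        # Same check as test_data_parallel, but with a sparse-gradient
        # nn.Embedding to exercise sparse gradient handling.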
        l = nn.Embedding(10, 5, sparse=True).to("cuda:1")
        i = torch.randint(10, (20, 5), device="cuda:1", dtype=torch.long)
        expected_out = l(i)
        loss = expected_out.sum()
        loss.backward()
        expected_grads = []
        for param in l.parameters():
            expected_grads.append(param.grad.clone())
        dev_ids_list = [(0, 1), (1, 0)]
        for dev_id in dev_ids_list:
            with torch.cuda.device(dev_id[0]):
                l.cuda()
                l.zero_grad()
                out = dp.data_parallel(l, i, dev_id)
                loss = out.sum()
                loss.backward()
                self.assertEqual(out.get_device(), dev_id[0])
                self.assertEqual(out, expected_out)
                for expected, param in zip(expected_grads, l.parameters()):
                    self.assertEqual(param.grad, expected)

        # Check for None device_ids
        l = l.cuda()
        out = dp.data_parallel(l, i)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_nested_output(self):
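        # Outputs nested in lists, tuples, and dicts should be gathered
        # recursively while preserving the container structure.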
        def fn(input):
            return [
                input, (input.sin(), input.cos(), [input.add(1)]), input,
                OrderedDict(a=input, b=[input.sin()])
            ]

        class Net(nn.Module):
            def forward(self, input):
                return fn(input)

        i = torch.randn(2, 2).float().cuda(1)
        gpus = range(torch.cuda.device_count())
        output = dp.data_parallel(Net(), i, gpus)
        self.assertEqual(output, fn(i))
        self.assertIsInstance(output[0], torch.Tensor)
        self.assertIsInstance(output[1], tuple)
        self.assertIsInstance(output[1][0], torch.Tensor)
        self.assertIsInstance(output[1][1], torch.Tensor)
        self.assertIsInstance(output[1][2], list)
        self.assertIsInstance(output[1][2][0], torch.Tensor)
        self.assertIsInstance(output[2], torch.Tensor)
        self.assertIsInstance(output[3], dict)
        self.assertEqual(len(output[3]), 2)
        self.assertIn('a', output[3])
        self.assertIn('b', output[3])
        self.assertIsInstance(output[3]['a'], torch.Tensor)
        self.assertIsInstance(output[3]['b'], list)
        self.assertIsInstance(output[3]['b'][0], torch.Tensor)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_nested_input(self):
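        # Inputs nested inside tuples should be scattered recursively across
        # the devices.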
        def fn(input):
            return input[1][0]

        class Net(nn.Module):
            def forward(self, *input):
                return fn(input)

        i = torch.randn(20, 3, dtype=torch.float, device="cuda:1")
        input = (i.cos(), (i.sin(), i), i.sin())
        gpus = range(torch.cuda.device_count())
        output = dp.data_parallel(Net(), input, gpus)
        self.assertEqual(output, fn(input))

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @repeat_test_for_types(ALL_TENSORTYPES)
    def test_data_parallel_module(self, dtype=torch.float):
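        # The nn.DataParallel wrapper should reproduce the wrapped module's
        # output on device 0 for every supported tensor dtype.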
        l = nn.Linear(10, 5).to("cuda", dtype)
        i = torch.randn(20, 10, device="cuda", dtype=dtype)
        expected_out = l(i)
        net = nn.DataParallel(l)
        out = net(i)
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, dtype2prec_DONTUSE[dtype])

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @repeat_test_for_types(ALL_TENSORTYPES)
    def test_data_parallel_module_kwargs_only(self, dtype=torch.float):
        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.l = l

            def forward(self, input):
                return self.l(input)

        l = nn.Linear(10, 5).to("cuda", dtype)
        i = torch.randn(20, 10, device="cuda", dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input=i)
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, dtype2prec_DONTUSE[dtype])

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @repeat_test_for_types(ALL_TENSORTYPES)
    def test_data_parallel_module_kwargs_only_empty_list(self, dtype=torch.float):
        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.l = l

            def forward(self, input):
                return self.l(input['data'])

        l = nn.Linear(10, 5).to("cuda", dtype)
        i = torch.randn(20, 10, device="cuda", dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input={'data': i, 'unused': []})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, dtype2prec_DONTUSE[dtype])

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @repeat_test_for_types(ALL_TENSORTYPES)
    def test_data_parallel_module_kwargs_only_empty_dict(self, dtype=torch.float):
        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.l = l

            def forward(self, input):
                return self.l(input['data'])

        l = nn.Linear(10, 5).to("cuda", dtype)
        i = torch.randn(20, 10, device="cuda", dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input={'data': i, 'unused': {}})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, dtype2prec_DONTUSE[dtype])

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @repeat_test_for_types(ALL_TENSORTYPES)
    def test_data_parallel_module_kwargs_only_empty_tuple(self, dtype=torch.float):
        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.l = l

            def forward(self, input):
                return self.l(input['data'])

        l = nn.Linear(10, 5).to("cuda", dtype)
        i = torch.randn(20, 10, device="cuda", dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input={'data': i, 'unused': ()})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, dtype2prec_DONTUSE[dtype])

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_device_args(self):
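        # device_ids and output_device should accept torch.device objects as
        # well as plain integer indices.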
        cuda0 = torch.device('cuda:0')
        cuda1 = torch.device('cuda:1')

        # test output_device
        l = nn.Linear(10, 5).to(cuda0, torch.float)
        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
        out = dp.data_parallel(l, i, device_ids=(0, 1), output_device=cuda0)
        self.assertEqual(out, l(i))

        # test device_ids
        l = nn.Linear(10, 5).to(cuda0, torch.float)
        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
        out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
        self.assertEqual(out, l(i))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_function_deletion(self):
        # this test case originated from #16532
        def gradient_penalty(net, x):
            output = net(x)
            loss = torch.autograd.grad(
                outputs=output, inputs=x,
                grad_outputs=x.new_ones(output.size()),
                create_graph=True, retain_graph=True)[0].mean()
            return loss

        net = nn.Linear(4, 1).cuda()
        dpn = nn.DataParallel(net, [0, 1])
        x = torch.ones(2, 4, requires_grad=True).cuda()

        dpn.zero_grad()
        loss = gradient_penalty(dpn, x)
        loss.backward()
        grads = [p.grad for p in net.parameters()]
        self.assertEqual(2, len(grads))
        self.assertEqual(
            torch.tensor([[0.25, 0.25, 0.25, 0.25]], device='cuda:0'),
            grads[0])
        self.assertEqual(torch.tensor([0.0], device='cuda:0'), grads[1])

    def _test_scatter(self, tensor):
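        # Shared helper: scatter should split the tensor across devices 0 and
        # 1, route gradients back to the right slices, and pass gradcheck.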
        x = tensor.detach().requires_grad_()
        result = dp.scatter(x, (0, 1))
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0], x[:2])
        self.assertEqual(result[0].get_device(), 0)
        self.assertEqual(result[1], x[2:])
        self.assertEqual(result[1].get_device(), 1)
        grad = result[0].detach().clone().fill_(2)
        result[0].backward(grad)
        self.assertEqual(x.grad[:2], grad)
        self.assertEqual(x.grad[2:], grad.clone().zero_())
        _assertGradAndGradgradChecks(self, lambda y: dp.scatter(y, (0, 1)), (x,))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_scatter_cpu(self):
        self._test_scatter(torch.randn((4, 4)))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_scatter_gpu(self):
        self._test_scatter(torch.randn((4, 4)).cuda())

    def _test_gather(self, output_device):
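        # Shared helper: gather should concatenate per-device tensors (and
        # stack scalars) onto output_device, or onto the CPU when it is -1,
        # and route gradients back to the corresponding inputs.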
        inputs = (
            torch.randn(2, 4, device='cuda:0', requires_grad=True),
            torch.randn(2, 4, device='cuda:1', requires_grad=True),
        )
        result = dp.gather(inputs, output_device)
        self.assertEqual(result.size(), torch.Size([4, 4]))
        self.assertEqual(result[:2], inputs[0])
        self.assertEqual(result[2:], inputs[1])
        if output_device != -1:
            self.assertEqual(result.get_device(), output_device)
        else:
            self.assertFalse(result.is_cuda)
        grad = torch.randn((4, 4))
        if output_device != -1:
            grad = grad.cuda(output_device)
        result.backward(grad)
        self.assertEqual(inputs[0].grad, grad[:2])
        self.assertEqual(inputs[1].grad, grad[2:])
        _assertGradAndGradgradChecks(self, lambda x, y: dp.gather((x, y), output_device), inputs)

        # test scalar inputs, should stack into a vector in this case
        inputs = (
            torch.randn((), device='cuda:0', requires_grad=True),
            torch.randn((), device='cuda:1', requires_grad=True),
        )
        result = dp.gather(inputs, output_device)
        self.assertEqual(result.size(), torch.Size([2]))
        self.assertEqual(result[0], inputs[0])
        self.assertEqual(result[1], inputs[1])
        if output_device != -1:
            self.assertEqual(result.get_device(), output_device)
        else:
            self.assertFalse(result.is_cuda)
        grad = torch.randn(2)
        if output_device != -1:
            grad = grad.cuda(output_device)
        result.backward(grad)
        self.assertEqual(inputs[0].grad, grad[0])
        self.assertEqual(inputs[1].grad, grad[1])
        _assertGradAndGradgradChecks(self, lambda x, y: dp.gather((x, y), output_device), inputs)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_cpu(self):
        self._test_gather(-1)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_gpu(self):
        self._test_gather(0)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_different_len_dicts(self):
        inputs = (
            {'a': torch.randn(1, 2, requires_grad=True, device="cuda:0")},
            {
                'b': torch.randn(1, 2, requires_grad=True, device="cuda:1"),
                'a': torch.randn(1, 2, requires_grad=True, device="cuda:1"),
            }
        )
        with self.assertRaises(ValueError):
            _ = dp.gather(inputs, target_device=0)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_replicate(self):
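        # Each replica produced by replicate() should live entirely on its own
        # device and compute the same output as the source module.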
        module = nn.Linear(10, 5).float().cuda()
        input = torch.randn(2, 10, dtype=torch.float, device="cuda")
        expected_output = module(input)
        for devices in [(0, 1), [0, 1]]:
            replicas = dp.replicate(module, devices)
            for i, replica in enumerate(replicas):
                for p in replica.parameters():
                    self.assertEqual(p.get_device(), i)
                replica_input = input.cuda(i)
                self.assertEqual(replica(replica_input), expected_output)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_replicate_buffers(self):
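        # Buffers (the BatchNorm running stats and batch counter) should be
        # moved to each replica's device as well.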
        net = nn.Module()
        net.bn = nn.BatchNorm2d(10)
        net.cuda()
        for devices in [(0, 1), [0, 1]]:
            replicas = dp.replicate(net, devices)
            for i, replica in enumerate(replicas):
                self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
                self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
                self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_zero_grad(self):
        # calling zero_grad() from inside forward of a module wrapped in
        # nn.DataParallel should warn that it has no effect

        class Net(torch.nn.Module):
            def __init__(self, testcase):
                super(Net, self).__init__()
                self._testcase = testcase

            def forward(self, x):
                self._testcase.assertWarnsRegex(
                    lambda: self.zero_grad(),
                    r"Calling \.zero_grad\(\) from a module that was passed to a nn\.DataParallel\(\) has no effect.")
                return x

        module = Net(self).cuda()
        dpm = dp.DataParallel(module)
        dpm(torch.rand(4, 3, 6, 5))


if __name__ == '__main__':
    run_tests()