mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Initial integration of eager autocasting, supporting out-of-place ops only for easier review. Relevant issue/RFC: https://github.com/pytorch/pytorch/issues/25081 In-place ops and ops with user-supplied `out=...` can certainly be supported as well (my initial WIP https://github.com/pytorch/pytorch/pull/29552 handled many) but require substantially more complex special casing in the autocasting backend and tests. Support for these ops (much of which has already been written) will be broken into later PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/32140 Differential Revision: D20346700 Pulled By: ezyang fbshipit-source-id: 12d77b3917310186fbddf11c59b2794dc859131f
674 lines
26 KiB
Python
674 lines
26 KiB
Python
import contextlib
|
|
import unittest
|
|
from copy import deepcopy
|
|
from collections import OrderedDict
|
|
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.parallel as dp
|
|
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
|
|
from torch.testing._internal.common_utils import run_tests, TestCase, skipIfRocm, repeat_test_for_types, ALL_TENSORTYPES, PY3
|
|
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks
|
|
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
|
|
import torch.nn.functional as F
|
|
|
|
# Make double precision the default floating-point dtype for this module; tests
# below that need single precision request torch.float explicitly.
torch.set_default_dtype(torch.double)
|
|
|
|
class TestDataParallel(TestCase):
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_buffers_requiring_grad(self):
    # Gradcheck a DataParallel-wrapped module that holds one registered buffer
    # requiring grad next to a detached copy of the same values.
    class TestModule(nn.Module):
        def __init__(self, t):
            super(TestModule, self).__init__()
            # Same values twice: once as passed in, once cloned and detached.
            self.register_buffer('t_rg', t)
            self.register_buffer('t_not_rg', t.clone().detach())

        def forward(self, x):
            scaled = x * self.t_rg
            return scaled + self.t_not_rg

    grad_buffer = torch.randn(100, device='cuda', requires_grad=True)
    module = TestModule(grad_buffer)
    self.assertTrue(module.t_rg.requires_grad)

    parallel_module = nn.DataParallel(module, [0, 1])
    batch = torch.randn(2, 100, device='cuda')

    def run_parallel(t):
        # The argument is not read here; the buffer reaches the computation
        # through the module wrapped by parallel_module.
        return parallel_module(batch)

    torch.autograd.gradcheck(run_parallel, (module.t_rg,))
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_rnn(self):
    # After one SGD step, a bidirectional LSTM must hold the same parameters
    # whether it ran directly or through DataParallel.

    class TestModule(torch.nn.Module):

        def __init__(self):
            super(TestModule, self).__init__()
            self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)

        def forward(self, x):
            self.rnn.flatten_parameters()
            return self.rnn(x)

    def step(net):
        # Single optimisation step on an all-ones batch against a zero target.
        optimizer = torch.optim.SGD(net.parameters(), lr=10)
        batch = torch.ones(4, 4, 300).to(0)
        result = net(batch)
        prediction = result[0]
        F.mse_loss(prediction, torch.zeros_like(prediction)).backward()
        optimizer.step()

    with torch.no_grad():
        plain = TestModule().to(0)
        wrapped = torch.nn.DataParallel(deepcopy(plain))

        # make sure DP does not crash when grad is disabled.
        # See #21108
        wrapped(torch.rand(2, 4, 300).to(0))

    for net in (plain, wrapped):
        step(net)

    for plain_param, wrapped_param in zip(plain.parameters(), wrapped.parameters()):
        self.assertTrue(plain_param.allclose(wrapped_param))
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_parallel_apply(self):
    # parallel_apply must return, per module, what that module yields when
    # called directly on its own input.
    first = nn.Linear(10, 5).to("cuda:0", torch.float)
    second = nn.Linear(10, 5).to("cuda:1", torch.float)
    first_in = torch.randn(2, 10, device="cuda:0", dtype=torch.float)
    second_in = torch.randn(2, 10, device="cuda:1", dtype=torch.float)
    expected_outputs = (first(first_in), second(second_in))
    modules = (first, second)

    # each input can be either a collection of positional arguments
    # or an object representing the single argument
    wrapped_inputs = ((first_in,), (second_in,))
    bare_inputs = (first_in, second_in)
    for inputs in (wrapped_inputs, bare_inputs):
        outputs = dp.parallel_apply(modules, inputs, None)
        for actual, expected in zip(outputs, expected_outputs):
            self.assertEqual(actual, expected)
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_parallel_apply_passes_exception(self):
    # A module whose forward always raises KeyError.
    class TestModule(nn.Module):

        def forward(self, *args):
            return {}['wonderful']

    failing = TestModule().to("cuda", torch.float)
    # The error surfaced by parallel_apply must still be a KeyError and its
    # message must name the replica, the device and the original traceback.
    expected_message = (
        'Caught KeyError in replica \\d '
        'on device 0.\nOriginal Traceback'
        '[\\s\\S]+wonderful'
    )
    # (a single device used twice is enough for this test)
    with self.assertRaisesRegex(KeyError, expected_message):
        dp.parallel_apply(modules=(failing, failing), inputs=(None, None))
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_multiple_input(self):
    # Outputs and input gradients must match a direct call when the module
    # takes several tensors, a Python float and an optional keyword tensor.
    class TestModule(nn.Module):

        def forward(self, var1, var2, float1, var3=None):
            if var3 is not None:
                return float1 * (var1 * var2 + var3)
            return float1 * (var1 * var2)

    m = TestModule()
    var1 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
    var2 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
    var3 = torch.randn(5, 5, dtype=torch.float, requires_grad=False)

    float1 = torch.randn(1).item()
    args = (var1, var2, float1)

    def zero_input_grads():
        with torch.no_grad():
            var1.grad.fill_(0.0)
            var2.grad.fill_(0.0)

    # Reference run without var3.
    expected = m(*args)
    expected.sum().backward()
    gvar1_exp = var1.grad.clone()
    gvar2_exp = var2.grad.clone()

    def local_test(out):
        # expected / gvar*_exp are looked up at call time, so this helper
        # follows the rebinding done for the var3 case further down.
        zero_input_grads()
        out.sum().backward()
        self.assertEqual(out, expected)
        self.assertEqual(gvar1_exp, var1.grad)
        self.assertEqual(gvar2_exp, var2.grad)

    for device_ids in ((0, 1), (1, 0), (0,)):
        local_test(dp.data_parallel(m, args, device_ids))

    # Reference run with var3 supplied as a keyword.
    zero_input_grads()
    expected = m(*args, var3=var3)
    expected.sum().backward()
    gvar1_exp = var1.grad.clone()
    gvar2_exp = var2.grad.clone()

    # Module wrapper: default device list first, then a single device.
    for dp_kwargs in ({}, {'device_ids': [0]}):
        dpm = nn.DataParallel(TestModule(), **dp_kwargs)
        local_test(dpm(*args, var3=var3))

    # Functional form, forwarding var3 through module_kwargs.
    kwarg_wrap = {'var3': var3}
    for device_ids in ((0, 1), (0,)):
        local_test(dp.data_parallel(m, args, device_ids, module_kwargs=kwarg_wrap))
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_small_back(self):
    # data_parallel over devices (0, 1) must reproduce a direct forward pass.
    linear = nn.Linear(10, 5).float().cuda()
    batch = torch.randn(20, 10, dtype=torch.float, device="cuda")
    parallel_out = dp.data_parallel(linear, batch, (0, 1))
    self.assertEqual(parallel_out, linear(batch))
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_model_device(self):
    r"""Test device[0] check at forward time.

    Each case places a module on some device(s), runs it through both
    ``nn.DataParallel`` and ``nn.parallel.data_parallel``, and expects either
    success or a ``RuntimeError`` naming the first entry of ``device_ids``.
    """
    l = nn.Linear(2, 2)
    inp = torch.randn(2, 2)
    inp_cuda0 = inp.cuda(0)
    inp_cuda1 = inp.cuda(1)

    # Template for the expected error; formatted with the expected device.
    error_msg = "module must have its parameters and buffers on device {}"

    # No-op context manager used when the call is expected to succeed.
    @contextlib.contextmanager
    def dummy_ctx_manager():
        yield

    def test(inner_m, dp_device, inp, device_ids, should_fail):
        # inner_m:     module to wrap
        # dp_device:   if not None, device the wrapper/module is moved to first
        # inp:         input tensor fed to the wrapped module
        # device_ids:  device list for DataParallel; None means every visible GPU
        # should_fail: whether the forward call must raise the device error
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))

        # The error message names the first device in the list.
        if isinstance(device_ids[0], torch.device):
            expect_device = device_ids[0]
        else:
            expect_device = torch.device("cuda:{}".format(device_ids[0]))

        if should_fail:
            def assert_correct():
                return self.assertRaisesRegex(RuntimeError, error_msg.format(expect_device))
        else:
            assert_correct = dummy_ctx_manager

        # test DataParallel module
        dpm = nn.DataParallel(inner_m, device_ids)
        if dp_device is not None:
            dpm = dpm.to(dp_device)

        with assert_correct():
            dpm(inp)

        # test functional
        with assert_correct():
            nn.parallel.data_parallel(inner_m.to(dp_device), inp, device_ids)

    # NOTE(review): the cases below appear to rely on .to()/.cuda()/.cpu()
    # returning the module itself, so one Linear is moved around between
    # calls and their order matters -- confirm before reordering.

    # Single Linear on the wrong device: must fail.
    test(l.to('cpu'), None, inp, None, should_fail=True)
    test(l.cuda(1), None, inp_cuda0, None, should_fail=True)
    test(l.cuda(), None, inp_cuda0, [1, 0], should_fail=True)

    # Single Linear on (or moved to) the first listed device: must pass.
    test(l.cuda(), None, inp_cuda0, None, should_fail=False)
    test(l.cpu(), 'cuda', inp_cuda0, None, should_fail=False)
    test(l.cuda(1), None, inp_cuda1, [1, 0], should_fail=False)
    test(l.cpu(), 'cuda:1', inp_cuda1, [1, 0], should_fail=False)

    # Sequential living entirely on CPU.
    s = nn.Sequential(l.cpu())
    test(s, None, inp, None, should_fail=True)
    test(s, None, inp, [0, 1], should_fail=True)
    test(s, None, inp, [1, 0], should_fail=True)

    # Sequential split between CPU and CUDA.
    s = nn.Sequential(deepcopy(l).cpu(), l.cuda())
    test(s, None, inp, None, should_fail=True)
    test(s, None, inp, [0, 1], should_fail=True)
    test(s, None, inp, [1, 0], should_fail=True)

    # Sequential split between two CUDA devices.
    s = nn.Sequential(l.cuda(), deepcopy(l).cuda(1))
    test(s, None, inp, None, should_fail=True)
    test(s, None, inp, [0, 1], should_fail=True)
    test(s, None, inp, [1, 0], should_fail=True)

    # Sequential with both children moved via .cuda(); outcome depends on
    # which device heads device_ids.
    s = nn.Sequential(l.cuda(), deepcopy(l).cuda())
    test(s, None, inp, None, should_fail=False)
    test(s, None, inp, [0, 1], should_fail=False)
    test(s, None, inp, [1, 0], should_fail=True)
    test(s.cpu(), None, inp, [1, 0], should_fail=True)
    test(s.cuda(1), None, inp, [1, 0], should_fail=False)
|
|
|
|
# Two separate skip conditions so the reported reason matches the actual cause
# (previously a Python 2 skip was reported as "multi-GPU not supported").
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@unittest.skipIf(not PY3, "Python 2 creates reference cycles for this module")
@skipIfRocm
def test_data_parallel_model_no_refcycles(self):
    # Python 2.7 will create reference cycles with the following
    # Module on multiple GPUs, but Python 3 shouldn't unless
    # there are refcycles on the PyTorch side (or the defined module)
    import gc

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear = nn.Linear(1, 1)

        def forward(self, x):
            return self.linear(x)

    # Clear any garbage left by earlier code so the count below only reflects
    # objects created by the DataParallel forward pass.
    gc.collect()
    model = nn.DataParallel(Model().cuda())
    data = torch.randn(1, device="cuda")
    model(data)

    # gc.collect() returns the number of unreachable objects it found.
    refcycles = gc.collect()
    self.assertEqual(refcycles, 0)
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_no_grad(self):
    # The grad-enabled state of the caller must be visible inside the
    # replicas' forward.
    # Alias needed because `self` is rebound inside Layer.forward.
    testcase = self

    class Layer(nn.Module):
        def forward(self, x):
            testcase.assertFalse(torch.is_grad_enabled())
            return x

    layer = Layer()
    batch = torch.randn(20, 10, dtype=torch.float, device="cuda")

    # Under no_grad the in-forward assertion holds ...
    with torch.no_grad():
        dp.data_parallel(layer, batch, (0, 1))

    # ... and with grad enabled it fails, surfacing as an AssertionError.
    with self.assertRaises(AssertionError):
        dp.data_parallel(layer, batch, (0, 1))
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel(self):
    # Output, output device and parameter gradients of data_parallel must
    # match a plain forward/backward for both device orderings.
    linear = nn.Linear(10, 5).float().cuda()
    batch = torch.randn(20, 10, dtype=torch.float, device="cuda:1")

    # Reference forward/backward on device 1, where the input lives.
    linear.cuda(1)
    expected_out = linear(batch)
    expected_out.sum().backward()
    expected_grads = [param.grad.clone() for param in linear.parameters()]

    for dev_id in [(0, 1), (1, 0)]:
        primary = dev_id[0]
        with torch.cuda.device(primary):
            linear.cuda()
            linear.zero_grad()
            out = dp.data_parallel(linear, batch, dev_id)
            out.sum().backward()
            self.assertEqual(out.get_device(), primary)
            self.assertEqual(out, expected_out)
            for expected, param in zip(expected_grads, linear.parameters()):
                self.assertEqual(param.grad, expected)

    # Check for None device_ids
    linear = linear.cuda()
    dp.data_parallel(linear, batch)
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_sparse(self):
    # Same checks as the dense data_parallel test, on an Embedding created
    # with sparse=True.
    embedding = nn.Embedding(10, 5, sparse=True).to("cuda:1")
    indices = torch.randint(10, (20, 5), device="cuda:1", dtype=torch.long)

    # Reference forward/backward on device 1.
    expected_out = embedding(indices)
    expected_out.sum().backward()
    expected_grads = [param.grad.clone() for param in embedding.parameters()]

    for dev_id in [(0, 1), (1, 0)]:
        primary = dev_id[0]
        with torch.cuda.device(primary):
            embedding.cuda()
            embedding.zero_grad()
            out = dp.data_parallel(embedding, indices, dev_id)
            out.sum().backward()
            self.assertEqual(out.get_device(), primary)
            self.assertEqual(out, expected_out)
            for expected, param in zip(expected_grads, embedding.parameters()):
                self.assertEqual(param.grad, expected)

    # Check for None device_ids
    embedding = embedding.cuda()
    dp.data_parallel(embedding, indices)
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_nested_output(self):
    # A nested list/tuple/dict output must be gathered with both its values
    # and its container types preserved.
    def fn(tensor):
        return [
            tensor, (tensor.sin(), tensor.cos(), [tensor.add(1)]), tensor,
            OrderedDict(a=tensor, b=[tensor.sin()])
        ]

    class Net(nn.Module):
        def forward(self, input):
            return fn(input)

    source = torch.randn(2, 2).float().cuda(1)
    gpus = range(torch.cuda.device_count())
    output = dp.data_parallel(Net(), source, gpus)
    self.assertEqual(output, fn(source))

    # Element 0: plain tensor.
    self.assertIsInstance(output[0], torch.Tensor)

    # Element 1: tuple of (tensor, tensor, [tensor]).
    nested = output[1]
    self.assertIsInstance(nested, tuple)
    self.assertIsInstance(nested[0], torch.Tensor)
    self.assertIsInstance(nested[1], torch.Tensor)
    self.assertIsInstance(nested[2], list)
    self.assertIsInstance(nested[2][0], torch.Tensor)

    # Element 2: plain tensor again.
    self.assertIsInstance(output[2], torch.Tensor)

    # Element 3: dict with exactly the keys 'a' and 'b'.
    mapping = output[3]
    self.assertIsInstance(mapping, dict)
    self.assertEqual(len(mapping), 2)
    self.assertIn('a', mapping)
    self.assertIn('b', mapping)
    self.assertIsInstance(mapping['a'], torch.Tensor)
    self.assertIsInstance(mapping['b'], list)
    self.assertIsInstance(mapping['b'][0], torch.Tensor)
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_nested_input(self):
    # A tuple nested inside the positional inputs must reach the module
    # with its structure intact.
    def fn(args):
        # First element of the nested tuple at position 1.
        return args[1][0]

    class Net(nn.Module):
        def forward(self, *input):
            return fn(input)

    base = torch.randn(20, 3, dtype=torch.float, device="cuda:1")
    nested_input = (base.cos(), (base.sin(), base), base.sin())
    gpus = range(torch.cuda.device_count())
    output = dp.data_parallel(Net(), nested_input, gpus)
    self.assertEqual(output, fn(nested_input))
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
def test_data_parallel_module(self, dtype=torch.float):
    # nn.DataParallel around a Linear must match the bare module, within the
    # per-dtype tolerance, and produce its output on device 0.
    linear = nn.Linear(10, 5).to("cuda", dtype)
    batch = torch.randn(20, 10, device="cuda", dtype=dtype)
    expected_out = linear(batch)
    parallel_out = nn.DataParallel(linear)(batch)
    self.assertEqual(parallel_out.get_device(), 0)
    self.assertEqual(parallel_out, expected_out, dtype2prec_DONTUSE[dtype])
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
def test_data_parallel_module_kwargs_only(self, dtype=torch.float):
    # The wrapped module is called with a keyword argument only.
    linear = nn.Linear(10, 5).to("cuda", dtype)

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            # Shares the Linear created above.
            self.l = linear

        def forward(self, input):
            return self.l(input)

    batch = torch.randn(20, 10, device="cuda", dtype=dtype)
    expected_out = linear(batch)
    parallel_net = nn.DataParallel(Net())
    parallel_out = parallel_net(input=batch)
    self.assertEqual(parallel_out.get_device(), 0)
    self.assertEqual(parallel_out, expected_out, dtype2prec_DONTUSE[dtype])
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
def test_data_parallel_module_kwargs_only_empty_list(self, dtype=torch.float):
    # Keyword-only call whose dict argument also carries an empty list.
    linear = nn.Linear(10, 5).to("cuda", dtype)

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            # Shares the Linear created above.
            self.l = linear

        def forward(self, input):
            return self.l(input['data'])

    batch = torch.randn(20, 10, device="cuda", dtype=dtype)
    expected_out = linear(batch)
    parallel_net = nn.DataParallel(Net())
    parallel_out = parallel_net(input={'data': batch, 'unused': []})
    self.assertEqual(parallel_out.get_device(), 0)
    self.assertEqual(parallel_out, expected_out, dtype2prec_DONTUSE[dtype])
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
def test_data_parallel_module_kwargs_only_empty_dict(self, dtype=torch.float):
    # Keyword-only call whose dict argument also carries an empty dict.
    linear = nn.Linear(10, 5).to("cuda", dtype)

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            # Shares the Linear created above.
            self.l = linear

        def forward(self, input):
            return self.l(input['data'])

    batch = torch.randn(20, 10, device="cuda", dtype=dtype)
    expected_out = linear(batch)
    parallel_net = nn.DataParallel(Net())
    parallel_out = parallel_net(input={'data': batch, 'unused': {}})
    self.assertEqual(parallel_out.get_device(), 0)
    self.assertEqual(parallel_out, expected_out, dtype2prec_DONTUSE[dtype])
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(ALL_TENSORTYPES)
def test_data_parallel_module_kwargs_only_empty_tuple(self, dtype=torch.float):
    # Keyword-only call whose dict argument also carries an empty tuple.
    linear = nn.Linear(10, 5).to("cuda", dtype)

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            # Shares the Linear created above.
            self.l = linear

        def forward(self, input):
            return self.l(input['data'])

    batch = torch.randn(20, 10, device="cuda", dtype=dtype)
    expected_out = linear(batch)
    parallel_net = nn.DataParallel(Net())
    parallel_out = parallel_net(input={'data': batch, 'unused': ()})
    self.assertEqual(parallel_out.get_device(), 0)
    self.assertEqual(parallel_out, expected_out, dtype2prec_DONTUSE[dtype])
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_device_args(self):
    # data_parallel must accept torch.device objects for output_device and
    # for the entries of device_ids.
    cuda0 = torch.device('cuda:0')
    cuda1 = torch.device('cuda:1')

    # First pass: integer device_ids with a torch.device output_device.
    # Second pass: torch.device objects for device_ids as well.
    for device_ids in ((0, 1), (cuda0, cuda1)):
        linear = nn.Linear(10, 5).to(cuda0, torch.float)
        batch = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
        out = dp.data_parallel(linear, batch, device_ids=device_ids, output_device=cuda0)
        self.assertEqual(out, linear(batch))
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_function_deletion(self):
    # this test case is originated from #16532
    def gradient_penalty(net, x):
        # Mean of d(output)/d(x), built with create_graph so that it can
        # itself be differentiated afterwards.
        output = net(x)
        input_grad, = torch.autograd.grad(
            outputs=output, inputs=x,
            grad_outputs=x.new_ones(output.size()),
            create_graph=True, retain_graph=True)[:1]
        return input_grad.mean()

    net = nn.Linear(4, 1).cuda()
    parallel_net = nn.DataParallel(net, [0, 1])
    x = torch.ones(2, 4, requires_grad=True).cuda()

    parallel_net.zero_grad()
    gradient_penalty(parallel_net, x).backward()

    grads = [p.grad for p in net.parameters()]
    self.assertEqual(2, len(grads))
    weight_grad, bias_grad = grads[0], grads[1]
    self.assertEqual(
        torch.tensor([[0.25, 0.25, 0.25, 0.25]], device='cuda:0'),
        weight_grad)
    self.assertEqual(torch.tensor([0.0], device='cuda:0'), bias_grad)
|
|
|
|
def _test_scatter(self, tensor):
    """Check dp.scatter forward and backward across devices 0 and 1.

    ``tensor`` is detached and made to require grad, then scattered. The
    first two rows must land on device 0 and the remaining rows on device 1.
    Backpropagating through the first chunk must leave a zero gradient for
    the rows that went to the second chunk.
    """
    x = tensor.detach().requires_grad_()
    chunks = dp.scatter(x, (0, 1))
    self.assertEqual(len(chunks), 2)

    # Chunk k holds the matching row slice and lives on device k.
    row_slices = (slice(None, 2), slice(2, None))
    for device, rows in enumerate(row_slices):
        self.assertEqual(chunks[device], x[rows])
        self.assertEqual(chunks[device].get_device(), device)

    first = chunks[0]
    grad = first.detach().clone().fill_(2)
    first.backward(grad)
    self.assertEqual(x.grad[:2], grad)
    self.assertEqual(x.grad[2:], grad.clone().zero_())

    def scatter_fn(y):
        return dp.scatter(y, (0, 1))

    _assertGradAndGradgradChecks(self, scatter_fn, (x,))
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_scatter_cpu(self):
    # Run the shared scatter checks on a 4x4 tensor created on the CPU.
    self._test_scatter(torch.randn((4, 4)))
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_scatter_gpu(self):
    # Run the shared scatter checks on a 4x4 tensor moved to CUDA first.
    self._test_scatter(torch.randn((4, 4)).cuda())
|
|
|
|
def _test_gather(self, output_device):
    """Check dp.gather forward and backward onto ``output_device``.

    ``output_device`` is a CUDA device index, or -1, in which case the
    gathered result is expected not to be a CUDA tensor. Two cases are run:
    a pair of (2, 4) tensors gathered into a (4, 4) result, and a pair of
    zero-dim tensors gathered into a length-2 vector.
    """
    to_cpu = output_device == -1

    def check_placement(gathered):
        # The result must sit on the requested device (or off CUDA for -1).
        if to_cpu:
            self.assertFalse(gathered.is_cuda)
        else:
            self.assertEqual(gathered.get_device(), output_device)

    def place(cpu_tensor):
        # Move an upstream gradient to wherever the gathered result lives.
        if to_cpu:
            return cpu_tensor
        return cpu_tensor.cuda(output_device)

    def gather_pair(x, y):
        return dp.gather((x, y), output_device)

    # Case 1: two (2, 4) tensors, one per device.
    inputs = (
        torch.randn(2, 4, device='cuda:0', requires_grad=True),
        torch.randn(2, 4, device='cuda:1', requires_grad=True),
    )
    result = dp.gather(inputs, output_device)
    self.assertEqual(result.size(), torch.Size([4, 4]))
    self.assertEqual(result[:2], inputs[0])
    self.assertEqual(result[2:], inputs[1])
    check_placement(result)
    grad = place(torch.randn((4, 4)))
    result.backward(grad)
    self.assertEqual(inputs[0].grad, grad[:2])
    self.assertEqual(inputs[1].grad, grad[2:])
    _assertGradAndGradgradChecks(self, gather_pair, inputs)

    # test scalar inputs, should stack into a vector in this case
    inputs = (
        torch.randn((), device='cuda:0', requires_grad=True),
        torch.randn((), device='cuda:1', requires_grad=True),
    )
    result = dp.gather(inputs, output_device)
    self.assertEqual(result.size(), torch.Size([2]))
    self.assertEqual(result[0], inputs[0])
    self.assertEqual(result[1], inputs[1])
    check_placement(result)
    grad = place(torch.randn(2))
    result.backward(grad)
    self.assertEqual(inputs[0].grad, grad[0])
    self.assertEqual(inputs[1].grad, grad[1])
    _assertGradAndGradgradChecks(self, gather_pair, inputs)
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_gather_cpu(self):
    # -1 selects the branch of _test_gather that expects a non-CUDA result.
    self._test_gather(-1)
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_gather_gpu(self):
    # Gather onto CUDA device 0.
    self._test_gather(0)
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_gather_different_len_dicts(self):
    # Gathering dicts that do not have the same number of keys must raise
    # ValueError.
    one_key = {'a': torch.randn(1, 2, requires_grad=True, device="cuda:0")}
    two_keys = {
        'b': torch.randn(1, 2, requires_grad=True, device="cuda:1"),
        'a': torch.randn(1, 2, requires_grad=True, device="cuda:1"),
    }
    mismatched = (one_key, two_keys)
    with self.assertRaises(ValueError):
        dp.gather(mismatched, target_device=0)
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_replicate(self):
    # Every replica must hold its parameters on its own device and compute
    # the same output as the source module.
    source = nn.Linear(10, 5).float().cuda()
    batch = torch.randn(2, 10, dtype=torch.float, device="cuda")
    expected_output = source(batch)

    # The device list is accepted both as a tuple and as a list.
    for devices in [(0, 1), [0, 1]]:
        for index, replica in enumerate(dp.replicate(source, devices)):
            for param in replica.parameters():
                self.assertEqual(param.get_device(), index)
            local_batch = batch.cuda(index)
            self.assertEqual(replica(local_batch), expected_output)
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_replicate_buffers(self):
    # The BatchNorm buffers of each replica must live on that replica's device.
    net = nn.Module()
    net.bn = nn.BatchNorm2d(10)
    net.cuda()

    buffer_names = ('running_mean', 'running_var', 'num_batches_tracked')
    for devices in [(0, 1), [0, 1]]:
        for index, replica in enumerate(dp.replicate(net, devices)):
            for name in buffer_names:
                buf = getattr(replica.bn, name)
                self.assertEqual(buf.get_device(), index, 'buffer on wrong device')
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_zero_grad(self):
    # zero_grad should warn about using gradients inside forward

    expected_warning = (
        r"Calling \.zero_grad\(\) from a module that was passed to a nn\.DataParallel\(\) has no effect."
    )

    class Net(torch.nn.Module):
        def __init__(self, testcase):
            super(Net, self).__init__()
            # Kept so forward can call the test case's assertion helpers.
            self._testcase = testcase

        def forward(self, x):
            self._testcase.assertWarnsRegex(
                lambda: self.zero_grad(),
                expected_warning)
            return x

    net = Net(self).cuda()
    parallel_net = dp.DataParallel(net)
    parallel_net(torch.rand(4, 3, 6, 5))
|
|
|
|
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@skipIfRocm
def test_autocast(self):
    # A forward decorated with torch.cuda.amp.autocast() must still produce
    # float16 output when the module is run through DataParallel with float32
    # parameters and input.
    class Model(torch.nn.Linear):
        def __init__(self):
            super(Model, self).__init__(8, 8)

        @torch.cuda.amp.autocast()
        def forward(self, input):
            return super(Model, self).forward(input)

    model = dp.DataParallel(Model().cuda().to(dtype=torch.float32))
    input = torch.randn((8, 8), dtype=torch.float32, device="cuda")
    # assertIs reports the offending dtype on failure, unlike
    # assertTrue(a is b), which only reports "False is not true".
    self.assertIs(model(input).dtype, torch.float16)
|
|
|
|
|
|
# When executed as a script, hand control to run_tests, imported above from
# torch.testing._internal.common_utils.
if __name__ == '__main__':
    run_tests()
|