pytorch/test/distributed/test_data_parallel.py
Michael Carilli 0f0271e255 [RELAND2] Eager autocasting, out-of-place ops only (with MSVC 2017 fix) (#35102)
Summary:
This is the second reland attempt for https://github.com/pytorch/pytorch/pull/32140.

The first reland attempt https://github.com/pytorch/pytorch/pull/35011 failed due to a [small incompatible change](https://github.com/pytorch/pytorch/pull/35011#issuecomment-601754216) in recent master (`skipIfRocm` was removed from `test_data_parallel.py`).

The present PR restores skipIfRocm.

Description from first reland attempt https://github.com/pytorch/pytorch/pull/35011:

> https://github.com/pytorch/pytorch/pull/32140 was approved and merged, but [reverted](d0577e19f0) because it broke builds with versions of Visual Studio older than 15.8 that were not represented in public CI.  The build failures were caused by a [known VS bug](https://developercommunity.visualstudio.com/content/problem/27729/allow-function-with-internal-linkage-as-template-n.html), fixed in versions 15.8 and newer.
>
> The present PR reverts the revert (restoring https://github.com/pytorch/pytorch/pull/32140 's diffs) and adds a workaround to enable compilation with VS < 15.8.  The workaround isn't pretty, but it's guarded by macros such that it's only used when compiling with VS < 15.8.  All other builds compile with the same code/control flow as was merged in https://github.com/pytorch/pytorch/pull/32140.
>
> Original description of https://github.com/pytorch/pytorch/pull/32140:
> > Initial integration of eager autocasting, supporting out-of-place ops only for easier review.
> > Relevant issue/RFC: https://github.com/pytorch/pytorch/issues/25081
>
> > In-place ops and ops with user-supplied out=... can certainly be supported as well (my initial WIP https://github.com/pytorch/pytorch/issues/29552 handled many) but require substantially more complex special casing in the autocasting backend and tests. Support for these ops (much of which has already been written) will be broken into later PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35102

Differential Revision: D20596918

Pulled By: ezyang

fbshipit-source-id: 60caa279bb0ce4a9bb0b28c1d585d42cf1cc7e50
2020-03-24 09:08:04 -07:00

674 lines
26 KiB
Python

import contextlib
import unittest
from copy import deepcopy
from collections import OrderedDict
import torch
from torch import nn
import torch.nn.parallel as dp
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_utils import run_tests, TestCase, repeat_test_for_types, ALL_TENSORTYPES, PY3
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
from torch.testing._internal.common_utils import skipIfRocm
import torch.nn.functional as F
# Make float64 the module-wide default floating dtype. Tests that want
# float32 request it explicitly (dtype=torch.float or .float()), while
# tensors created without a dtype -- e.g. torch.randn((4, 4)) in the scatter
# tests -- come out double precision. Presumably this also keeps the
# gradcheck-based tests in double precision (TODO confirm).
torch.set_default_dtype(torch.float64)
class TestDataParallel(TestCase):
    """Tests for ``nn.DataParallel`` and the ``torch.nn.parallel`` primitives.

    Covers ``data_parallel``, ``parallel_apply``, ``scatter``, ``gather`` and
    ``replicate``. Every test requires CUDA, and most require at least two
    GPUs (``TEST_MULTIGPU``); otherwise they are skipped.
    """

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_buffers_requiring_grad(self):
        """gradcheck through DataParallel w.r.t. a buffer that requires grad."""

        class TestModule(nn.Module):
            def __init__(self, t):
                super(TestModule, self).__init__()
                self.register_buffer('t_rg', t)
                self.register_buffer('t_not_rg', t.clone().detach())

            def forward(self, x):
                return x * self.t_rg + self.t_not_rg

        m = TestModule(torch.randn(100, device='cuda', requires_grad=True))
        self.assertTrue(m.t_rg.requires_grad)

        dpm = nn.DataParallel(m, [0, 1])
        inp = torch.randn(2, 100, device='cuda')

        # `t` is ignored on purpose: dpm already reads m.t_rg, which is the
        # tensor gradcheck perturbs.
        def fn(t):
            return dpm(inp)

        torch.autograd.gradcheck(fn, (m.t_rg,))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_rnn(self):
        """One SGD step on an LSTM yields the same parameters with or without DP."""

        class TestModule(torch.nn.Module):

            def __init__(self):
                super(TestModule, self).__init__()
                self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)

            def forward(self, x):
                self.rnn.flatten_parameters()
                return self.rnn(x)

        def step(model):
            opt = torch.optim.SGD(model.parameters(), lr=10)
            input = torch.ones(4, 4, 300).to(0)
            output = model(input)
            loss = F.mse_loss(output[0], torch.zeros_like(output[0]))
            loss.backward()
            opt.step()

        with torch.no_grad():
            model = TestModule().to(0)
            model_dp = torch.nn.DataParallel(deepcopy(model))

            # make sure DP does not crash when grad is disabled.
            # See #21108
            model_dp(torch.rand(2, 4, 300).to(0))

        # Training steps run outside no_grad because they call backward().
        step(model)
        step(model_dp)

        for p1, p2 in zip(model.parameters(), model_dp.parameters()):
            self.assertTrue(p1.allclose(p2))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_parallel_apply(self):
        """parallel_apply runs each module on its own input and device."""
        l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
        l2 = nn.Linear(10, 5).to("cuda:1", torch.float)
        i1 = torch.randn(2, 10, device="cuda:0", dtype=torch.float)
        i2 = torch.randn(2, 10, device="cuda:1", dtype=torch.float)
        expected1 = l1(i1)
        expected2 = l2(i2)
        modules = (l1, l2)
        expected_outputs = (expected1, expected2)

        # each input can be either a collection of positional arguments
        # or an object representing the single argument
        for inputs in [((i1,), (i2,)), (i1, i2)]:
            outputs = dp.parallel_apply(modules, inputs, None)
            for out, expected in zip(outputs, expected_outputs):
                self.assertEqual(out, expected)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_parallel_apply_passes_exception(self):
        """An exception raised in a replica is re-raised with replica/device context."""
        # we define and instantiate a module that will throw a KeyError
        class TestModule(nn.Module):

            def forward(self, *args):
                return {}['wonderful']

        l1 = TestModule().to("cuda", torch.float)
        # and check that parallel_apply passes on the exception
        # (we can use a single device twice for this test)
        with self.assertRaisesRegex(KeyError,
                                    'Caught KeyError in replica \\d '
                                    'on device 0.\nOriginal Traceback'
                                    '[\\s\\S]+wonderful'):
            dp.parallel_apply(modules=(l1, l1), inputs=(None, None))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_multiple_input(self):
        """Outputs and input grads match the plain module for mixed positional,
        scalar and keyword arguments."""

        class TestModule(nn.Module):

            def forward(self, var1, var2, float1, var3=None):
                if var3 is None:
                    return float1 * (var1 * var2)
                else:
                    return float1 * (var1 * var2 + var3)

        m = TestModule()
        var1 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
        var2 = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
        var3 = torch.randn(5, 5, dtype=torch.float, requires_grad=False)

        float1 = torch.randn(1).item()

        expected = m(var1, var2, float1)
        loss = expected.sum()
        loss.backward()
        gvar1_exp = var1.grad.clone()
        gvar2_exp = var2.grad.clone()

        # Closes over `expected`, `gvar1_exp` and `gvar2_exp`; they are
        # reassigned below for the var3 case, and later calls see the new values.
        def local_test(out):
            with torch.no_grad():
                var1.grad.fill_(0.0)
                var2.grad.fill_(0.0)
            loss = out.sum()
            loss.backward()
            self.assertEqual(out, expected)
            self.assertEqual(gvar1_exp, var1.grad)
            self.assertEqual(gvar2_exp, var2.grad)

        out = dp.data_parallel(m, (var1, var2, float1), (0, 1))
        local_test(out)

        out = dp.data_parallel(m, (var1, var2, float1), (1, 0))
        local_test(out)

        out = dp.data_parallel(m, (var1, var2, float1), (0,))
        local_test(out)

        with torch.no_grad():
            var1.grad.fill_(0.0)
            var2.grad.fill_(0.0)
        expected = m(var1, var2, float1, var3=var3)
        loss = expected.sum()
        loss.backward()
        gvar1_exp = var1.grad.clone()
        gvar2_exp = var2.grad.clone()

        dpm = nn.DataParallel(TestModule())
        out = dpm(var1, var2, float1, var3=var3)
        local_test(out)

        dpm = nn.DataParallel(TestModule(), device_ids=[0])
        out = dpm(var1, var2, float1, var3=var3)
        local_test(out)

        kwarg_wrap = {'var3': var3}
        out = dp.data_parallel(
            m, (var1, var2, float1), (0, 1), module_kwargs=kwarg_wrap)
        local_test(out)

        out = dp.data_parallel(
            m, (var1, var2, float1), (0,), module_kwargs=kwarg_wrap)
        local_test(out)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_small_back(self):
        """data_parallel over GPUs 0 and 1 matches a direct call of the module."""
        l = nn.Linear(10, 5).float().cuda()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda")
        out = dp.data_parallel(l, i, (0, 1))
        self.assertEqual(out, l(i))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_model_device(self):
        r"""Test device[0] check at forward time.

        Both nn.DataParallel and nn.parallel.data_parallel must raise unless
        all of the module's parameters and buffers live on device_ids[0].
        """
        l = nn.Linear(2, 2)
        inp = torch.randn(2, 2)
        inp_cuda0 = inp.cuda(0)
        inp_cuda1 = inp.cuda(1)

        error_msg = "module must have its parameters and buffers on device {}"

        @contextlib.contextmanager
        def dummy_ctx_manager():
            yield

        def test(inner_m, dp_device, inp, device_ids, should_fail):
            # Mirror DataParallel's default: None means "all visible GPUs".
            if device_ids is None:
                device_ids = list(range(torch.cuda.device_count()))

            if isinstance(device_ids[0], torch.device):
                expect_device = device_ids[0]
            else:
                expect_device = torch.device("cuda:{}".format(device_ids[0]))

            if should_fail:
                def assert_correct():
                    return self.assertRaisesRegex(RuntimeError, error_msg.format(expect_device))
            else:
                assert_correct = dummy_ctx_manager

            # test DataParallel module
            dpm = nn.DataParallel(inner_m, device_ids)
            if dp_device is not None:
                dpm = dpm.to(dp_device)

            with assert_correct():
                dpm(inp)

            # test functional
            with assert_correct():
                nn.parallel.data_parallel(inner_m.to(dp_device), inp, device_ids)

        test(l.to('cpu'), None, inp, None, should_fail=True)
        test(l.cuda(1), None, inp_cuda0, None, should_fail=True)
        test(l.cuda(), None, inp_cuda0, [1, 0], should_fail=True)

        test(l.cuda(), None, inp_cuda0, None, should_fail=False)
        test(l.cpu(), 'cuda', inp_cuda0, None, should_fail=False)
        test(l.cuda(1), None, inp_cuda1, [1, 0], should_fail=False)
        test(l.cpu(), 'cuda:1', inp_cuda1, [1, 0], should_fail=False)

        s = nn.Sequential(l.cpu())
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(deepcopy(l).cpu(), l.cuda())
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(l.cuda(), deepcopy(l).cuda(1))
        test(s, None, inp, None, should_fail=True)
        test(s, None, inp, [0, 1], should_fail=True)
        test(s, None, inp, [1, 0], should_fail=True)

        s = nn.Sequential(l.cuda(), deepcopy(l).cuda())
        test(s, None, inp, None, should_fail=False)
        test(s, None, inp, [0, 1], should_fail=False)
        test(s, None, inp, [1, 0], should_fail=True)
        test(s.cpu(), None, inp, [1, 0], should_fail=True)
        test(s.cuda(1), None, inp, [1, 0], should_fail=False)

    @unittest.skipIf(not TEST_MULTIGPU or not PY3, "multi-GPU not supported or Python 2")
    def test_data_parallel_model_no_refcycles(self):
        """A DataParallel forward pass leaves no reference cycles for gc."""
        # Python 2.7 will create reference cycles with the following
        # Module on multiple GPUs, but Python 3 shouldn't unless
        # there are refcycles on the PyTorch side (or the defined module)
        import gc

        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.linear = nn.Linear(1, 1)

            def forward(self, x):
                return self.linear(x)

        # Clear pre-existing garbage so the count below reflects only this run.
        gc.collect()
        model = nn.DataParallel(Model().cuda())
        data = torch.randn(1, device="cuda")
        model(data)

        refcycles = gc.collect()
        self.assertEqual(refcycles, 0)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_no_grad(self):
        """The caller's grad mode (torch.no_grad) is visible inside replicas."""
        test = self

        class Layer(nn.Module):
            def forward(self, x):
                test.assertFalse(torch.is_grad_enabled())
                return x

        l = Layer()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda")
        with torch.no_grad():
            dp.data_parallel(l, i, (0, 1))
        # With grad enabled, the assertFalse inside Layer.forward fails in a
        # replica, and the AssertionError propagates out of data_parallel.
        self.assertRaises(AssertionError, lambda: dp.data_parallel(l, i, (0, 1)))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel(self):
        """Output, output device and grads match single-device execution
        for both device orderings and for device_ids=None."""
        l = nn.Linear(10, 5).float().cuda()
        i = torch.randn(20, 10, dtype=torch.float, device="cuda:1")
        l.cuda(1)
        expected_out = l(i)
        loss = expected_out.sum()
        loss.backward()
        expected_grads = []
        for param in l.parameters():
            expected_grads.append(param.grad.clone())
        dev_ids_list = [(0, 1), (1, 0)]
        for dev_id in dev_ids_list:
            # Move l onto dev_id[0]; data_parallel requires the module there.
            with torch.cuda.device(dev_id[0]):
                l.cuda()
                l.zero_grad()
                out = dp.data_parallel(l, i, dev_id)
                loss = out.sum()
                loss.backward()
                self.assertEqual(out.get_device(), dev_id[0])
                self.assertEqual(out, expected_out)
                for expected, param in zip(expected_grads, l.parameters()):
                    self.assertEqual(param.grad, expected)

        # Check for None device_ids
        # (l.cuda() outside a device context moves l to the current device,
        # presumably device 0 == device_ids[0] when device_ids is None.)
        l = l.cuda()
        out = dp.data_parallel(l, i)
        self.assertEqual(out, expected_out)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_sparse(self):
        """Same checks as test_data_parallel, for an Embedding with sparse grads."""
        l = nn.Embedding(10, 5, sparse=True).to("cuda:1")
        i = torch.randint(10, (20, 5), device="cuda:1", dtype=torch.long)
        expected_out = l(i)
        loss = expected_out.sum()
        loss.backward()
        expected_grads = []
        for param in l.parameters():
            expected_grads.append(param.grad.clone())
        dev_ids_list = [(0, 1), (1, 0)]
        for dev_id in dev_ids_list:
            # Move l onto dev_id[0]; data_parallel requires the module there.
            with torch.cuda.device(dev_id[0]):
                l.cuda()
                l.zero_grad()
                out = dp.data_parallel(l, i, dev_id)
                loss = out.sum()
                loss.backward()
                self.assertEqual(out.get_device(), dev_id[0])
                self.assertEqual(out, expected_out)
                for expected, param in zip(expected_grads, l.parameters()):
                    self.assertEqual(param.grad, expected)

        # Check for None device_ids
        # (l.cuda() outside a device context moves l to the current device,
        # presumably device 0 == device_ids[0] when device_ids is None.)
        l = l.cuda()
        out = dp.data_parallel(l, i)
        self.assertEqual(out, expected_out)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_nested_output(self):
        """Nested list/tuple/dict outputs are gathered with container types preserved."""
        def fn(input):
            return [
                input, (input.sin(), input.cos(), [input.add(1)]), input,
                OrderedDict(a=input, b=[input.sin()])
            ]

        class Net(nn.Module):
            def forward(self, input):
                return fn(input)

        i = torch.randn(2, 2).float().cuda(1)
        gpus = range(torch.cuda.device_count())
        output = dp.data_parallel(Net(), i, gpus)
        self.assertEqual(output, fn(i))
        self.assertIsInstance(output[0], torch.Tensor)
        self.assertIsInstance(output[1], tuple)
        self.assertIsInstance(output[1][0], torch.Tensor)
        self.assertIsInstance(output[1][1], torch.Tensor)
        self.assertIsInstance(output[1][2], list)
        self.assertIsInstance(output[1][2][0], torch.Tensor)
        self.assertIsInstance(output[2], torch.Tensor)
        self.assertIsInstance(output[3], dict)
        self.assertEqual(len(output[3]), 2)
        self.assertIn('a', output[3])
        self.assertIn('b', output[3])
        self.assertIsInstance(output[3]['a'], torch.Tensor)
        self.assertIsInstance(output[3]['b'], list)
        self.assertIsInstance(output[3]['b'][0], torch.Tensor)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_nested_input(self):
        """Nested tuple inputs are scattered and reach forward intact."""
        def fn(input):
            return input[1][0]

        class Net(nn.Module):
            def forward(self, *input):
                return fn(input)

        i = torch.randn(20, 3, dtype=torch.float, device="cuda:1")
        input = (i.cos(), (i.sin(), i), i.sin())
        gpus = range(torch.cuda.device_count())
        output = dp.data_parallel(Net(), input, gpus)
        self.assertEqual(output, fn(input))

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @repeat_test_for_types(ALL_TENSORTYPES)
    def test_data_parallel_module(self, dtype=torch.float):
        """nn.DataParallel with default devices matches the wrapped module."""
        l = nn.Linear(10, 5).to("cuda", dtype)
        i = torch.randn(20, 10, device="cuda", dtype=dtype)
        expected_out = l(i)
        net = nn.DataParallel(l)
        out = net(i)
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, dtype2prec_DONTUSE[dtype])

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @repeat_test_for_types(ALL_TENSORTYPES)
    def test_data_parallel_module_kwargs_only(self, dtype=torch.float):
        """A tensor passed only as a keyword argument is scattered correctly."""
        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.l = l

            def forward(self, input):
                return self.l(input)

        l = nn.Linear(10, 5).to("cuda", dtype)
        i = torch.randn(20, 10, device="cuda", dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input=i)
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, dtype2prec_DONTUSE[dtype])

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @repeat_test_for_types(ALL_TENSORTYPES)
    def test_data_parallel_module_kwargs_only_empty_list(self, dtype=torch.float):
        """A keyword dict input that also holds an empty list is handled."""
        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.l = l

            def forward(self, input):
                return self.l(input['data'])

        l = nn.Linear(10, 5).to("cuda", dtype)
        i = torch.randn(20, 10, device="cuda", dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input={'data': i, 'unused': []})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, dtype2prec_DONTUSE[dtype])

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @repeat_test_for_types(ALL_TENSORTYPES)
    def test_data_parallel_module_kwargs_only_empty_dict(self, dtype=torch.float):
        """A keyword dict input that also holds an empty dict is handled."""
        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.l = l

            def forward(self, input):
                return self.l(input['data'])

        l = nn.Linear(10, 5).to("cuda", dtype)
        i = torch.randn(20, 10, device="cuda", dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input={'data': i, 'unused': {}})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, dtype2prec_DONTUSE[dtype])

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @repeat_test_for_types(ALL_TENSORTYPES)
    def test_data_parallel_module_kwargs_only_empty_tuple(self, dtype=torch.float):
        """A keyword dict input that also holds an empty tuple is handled."""
        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.l = l

            def forward(self, input):
                return self.l(input['data'])

        l = nn.Linear(10, 5).to("cuda", dtype)
        i = torch.randn(20, 10, device="cuda", dtype=dtype)
        expected_out = l(i)
        n = nn.DataParallel(Net())
        out = n(input={'data': i, 'unused': ()})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out, expected_out, dtype2prec_DONTUSE[dtype])

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_device_args(self):
        """device_ids and output_device accept torch.device objects."""
        cuda0 = torch.device('cuda:0')
        cuda1 = torch.device('cuda:1')

        # test output_device
        l = nn.Linear(10, 5).to(cuda0, torch.float)
        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
        out = dp.data_parallel(l, i, device_ids=(0, 1), output_device=cuda0)
        self.assertEqual(out, l(i))

        # test device_ids
        l = nn.Linear(10, 5).to(cuda0, torch.float)
        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
        out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
        self.assertEqual(out, l(i))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_data_parallel_function_deletion(self):
        """Double backward (gradient penalty) through DataParallel yields correct grads."""
        # this test case is originated from #16532
        def gradient_penalty(net, x):
            output = net(x)
            loss = torch.autograd.grad(
                outputs=output, inputs=x,
                grad_outputs=x.new_ones(output.size()),
                create_graph=True, retain_graph=True)[0].mean()
            return loss

        net = nn.Linear(4, 1).cuda()
        dpn = nn.DataParallel(net, [0, 1])
        x = torch.ones(2, 4, requires_grad=True).cuda()

        dpn.zero_grad()
        loss = gradient_penalty(dpn, x)
        loss.backward()
        grads = [p.grad for p in net.parameters()]
        self.assertEqual(2, len(grads))
        self.assertEqual(
            torch.tensor([[0.25, 0.25, 0.25, 0.25]], device='cuda:0'),
            grads[0])
        self.assertEqual(torch.tensor([0.0], device='cuda:0'), grads[1])

    def _test_scatter(self, tensor):
        """Check that dp.scatter splits ``tensor`` across GPUs 0 and 1 and
        routes gradients back to the source.

        Args:
            tensor: source tensor with 4 rows (both callers pass a 4x4 tensor,
                on CPU or GPU); rows 0-1 should land on GPU 0, rows 2-3 on GPU 1.
        """
        x = tensor.detach().requires_grad_()
        result = dp.scatter(x, (0, 1))
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0], x[:2])
        self.assertEqual(result[0].get_device(), 0)
        self.assertEqual(result[1], x[2:])
        self.assertEqual(result[1].get_device(), 1)
        grad = result[0].detach().clone().fill_(2)
        result[0].backward(grad)
        # Only the chunk that received a gradient should have a nonzero grad.
        self.assertEqual(x.grad[:2], grad)
        self.assertEqual(x.grad[2:], grad.clone().zero_())
        _assertGradAndGradgradChecks(self, lambda y: dp.scatter(y, (0, 1)), (x,))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_scatter_cpu(self):
        """Scatter a CPU tensor onto GPUs 0 and 1."""
        self._test_scatter(torch.randn((4, 4)))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_scatter_gpu(self):
        """Scatter a GPU tensor onto GPUs 0 and 1."""
        self._test_scatter(torch.randn((4, 4)).cuda())

    def _test_gather(self, output_device):
        """Check that dp.gather concatenates per-GPU tensors (and stacks 0-d
        tensors) onto ``output_device`` and routes gradients back.

        Args:
            output_device: CUDA device index to gather onto, or -1 to gather
                onto the CPU.
        """
        inputs = (
            torch.randn(2, 4, device='cuda:0', requires_grad=True),
            torch.randn(2, 4, device='cuda:1', requires_grad=True),
        )
        result = dp.gather(inputs, output_device)
        self.assertEqual(result.size(), torch.Size([4, 4]))
        self.assertEqual(result[:2], inputs[0])
        self.assertEqual(result[2:], inputs[1])
        if output_device != -1:
            self.assertEqual(result.get_device(), output_device)
        else:
            self.assertFalse(result.is_cuda)
        grad = torch.randn((4, 4))
        if output_device != -1:
            grad = grad.cuda(output_device)
        result.backward(grad)
        self.assertEqual(inputs[0].grad, grad[:2])
        self.assertEqual(inputs[1].grad, grad[2:])
        _assertGradAndGradgradChecks(self, lambda x, y: dp.gather((x, y), output_device), inputs)

        # test scalar inputs, should stack into a vector in this case
        inputs = (
            torch.randn((), device='cuda:0', requires_grad=True),
            torch.randn((), device='cuda:1', requires_grad=True),
        )
        result = dp.gather(inputs, output_device)
        self.assertEqual(result.size(), torch.Size([2]))
        self.assertEqual(result[0], inputs[0])
        self.assertEqual(result[1], inputs[1])
        if output_device != -1:
            self.assertEqual(result.get_device(), output_device)
        else:
            self.assertFalse(result.is_cuda)
        grad = torch.randn(2)
        if output_device != -1:
            grad = grad.cuda(output_device)
        result.backward(grad)
        self.assertEqual(inputs[0].grad, grad[0])
        self.assertEqual(inputs[1].grad, grad[1])
        _assertGradAndGradgradChecks(self, lambda x, y: dp.gather((x, y), output_device), inputs)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_cpu(self):
        """Gather onto the CPU (output_device=-1)."""
        self._test_gather(-1)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_gpu(self):
        """Gather onto GPU 0."""
        self._test_gather(0)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_gather_different_len_dicts(self):
        """gather raises ValueError for dict outputs with different numbers of keys."""
        inputs = (
            {'a': torch.randn(1, 2, requires_grad=True, device="cuda:0")},
            {
                'b': torch.randn(1, 2, requires_grad=True, device="cuda:1"),
                'a': torch.randn(1, 2, requires_grad=True, device="cuda:1"),
            }
        )
        with self.assertRaises(ValueError):
            _ = dp.gather(inputs, target_device=0)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_replicate(self):
        """Replicas live on the requested devices and compute the same output;
        devices may be given as a tuple or a list."""
        module = nn.Linear(10, 5).float().cuda()
        input = torch.randn(2, 10, dtype=torch.float, device="cuda")
        expected_output = module(input)
        for devices in [(0, 1), [0, 1]]:
            replicas = dp.replicate(module, devices)
            for i, replica in enumerate(replicas):
                for p in replica.parameters():
                    self.assertEqual(p.get_device(), i)
                replica_input = input.cuda(i)
                self.assertEqual(replica(replica_input), expected_output)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_replicate_buffers(self):
        """BatchNorm buffers are copied onto each replica's device."""
        net = nn.Module()
        net.bn = nn.BatchNorm2d(10)
        net.cuda()
        for devices in [(0, 1), [0, 1]]:
            replicas = dp.replicate(net, devices)
            for i, replica in enumerate(replicas):
                self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
                self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
                self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_zero_grad(self):
        """Calling zero_grad() on a replica inside forward emits a warning."""
        # zero_grad should warn about using gradients inside forward

        class Net(torch.nn.Module):
            def __init__(self, testcase):
                super(Net, self).__init__()
                self._testcase = testcase

            def forward(self, x):
                self._testcase.assertWarnsRegex(
                    lambda: self.zero_grad(),
                    r"Calling \.zero_grad\(\) from a module that was passed to a nn\.DataParallel\(\) has no effect.")
                return x

        module = Net(self).cuda()
        dpm = dp.DataParallel(module)
        dpm(torch.rand(4, 3, 6, 5))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @skipIfRocm
    def test_autocast(self):
        """An autocast-decorated forward yields float16 output through DataParallel."""
        class Model(torch.nn.Linear):
            def __init__(self):
                super(Model, self).__init__(8, 8)

            @torch.cuda.amp.autocast()
            def forward(self, input):
                return super(Model, self).forward(input)

        model = dp.DataParallel(Model().cuda().to(dtype=torch.float32))
        input = torch.randn((8, 8), dtype=torch.float32, device="cuda")
        # assertIs reports the actual dtype on failure, unlike assertTrue(... is ...).
        self.assertIs(model(input).dtype, torch.float16)
if __name__ == '__main__':
    # Run this module's tests when executed directly as a script.
    run_tests()