Remove repeat test for types in test nn (#70872)

Summary:
Helps fix part of https://github.com/pytorch/pytorch/issues/69865

The first commit just migrates everything as is.

The second commit uses the "device" variable instead of passing "cuda" everywhere.

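For reference, the migration follows the device-generic test pattern: instead of decorating a `TestNN` method with `@repeat_test_for_types(...)` and hard-coding `"cuda"`, each test moves into `TestNNDeviceType`, takes `device` and `dtype` parameters, and is gated with `@onlyCUDA` / `@dtypes(...)`. A minimal sketch of the pattern, assuming a CUDA build; the `ExampleTest` class and its `test_linear_forward` body are illustrative only, not part of this PR:

```python
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, onlyCUDA, dtypes)

class ExampleTest(TestCase):
    # Old style (removed by this PR): a TestNN method decorated with
    # @repeat_test_for_types([torch.float, torch.half]) that hard-coded "cuda".
    # New style: the device-generic framework passes device/dtype explicitly.
    @onlyCUDA
    @dtypes(torch.float, torch.half)
    def test_linear_forward(self, device, dtype):
        m = nn.Linear(4, 2).to(device, dtype)
        x = torch.randn(3, 4, device=device, dtype=dtype)
        self.assertEqual(m(x).shape, (3, 2))

# Generates one test per device/dtype combination, e.g.
# ExampleTestCUDA.test_linear_forward_cuda_float32 and ..._cuda_float16.
instantiate_device_type_tests(ExampleTest, globals())

if __name__ == '__main__':
    run_tests()
```

Because the migrated tests no longer go through `repeat_test_for_types`, their entries in `REPEAT_TEST_FOR_TYPES_TESTS` (the second file in this diff) can be dropped as well.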
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70872

Reviewed By: jbschlosser

Differential Revision: D33455941

Pulled By: janeyx99

fbshipit-source-id: 9d9ec8c95f1714c40d55800e652ccd69b0c314dc
Author: Jane Xu, 2022-01-06 09:17:45 -08:00 (committed by Facebook GitHub Bot)
parent bc514cb425
commit c00d33033c
2 changed files with 310 additions and 324 deletions


@@ -37,7 +37,7 @@ from torch.nn.parallel._functions import Broadcast
from torch.testing._internal.common_dtype import integral_types, get_all_fp_dtypes, get_all_math_dtypes
from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
skipIfRocmVersionLessThan, skipIfNotMiopenSuggestNHWC, TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, download_file, \
get_function_arglist, load_tests, repeat_test_for_types, ALL_TENSORTYPES, \
get_function_arglist, load_tests, ALL_TENSORTYPES, \
ALL_TENSORTYPES2, suppress_warnings, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC, \
parametrize as parametrize_test, subtest
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
@@ -73,8 +73,6 @@ if TEST_SCIPY:
if TEST_NUMPY:
import numpy as np
DOUBLE_TENSORTYPES = [torch.double]
# WARNING: If you add a new top-level test case to this file, you MUST
# update test/run_test.py to list it, otherwise it will NOT be run in
@@ -6155,25 +6153,6 @@ class TestNN(NNTestCase):
# but it should work with the same type
nn.functional.conv2d(inputs.float(), weights.float(), bias.float())
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
@unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
@repeat_test_for_types(get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
def test_Conv2d_deterministic_cudnn(self, dtype=torch.float):
inputs = torch.randn(2, 3, 5, 5, device="cuda", dtype=dtype, requires_grad=True)
with cudnn.flags(enabled=True, benchmark=True, deterministic=True):
conv1 = torch.nn.Conv2d(3, 3, 3).to("cuda", dtype)
conv2 = torch.nn.Conv2d(3, 3, 3).to("cuda", dtype)
conv2.bias.data.copy_(conv1.bias.data)
conv2.weight.data.copy_(conv1.weight.data)
out1 = conv1(inputs)
out2 = conv2(inputs)
self.assertEqual(out1, out2, atol=0.0, rtol=0)
y = torch.randn(out1.size(), device="cuda", dtype=dtype)
out1.backward(y)
out2.backward(y)
self.assertEqual(conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0, rtol=0)
self.assertEqual(conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0, rtol=0)
def test_Conv2d_missing_argument(self):
c = nn.Conv2d(3, 3, 3)
self.assertRaises(TypeError, lambda: c(None))
@@ -6186,27 +6165,6 @@ class TestNN(NNTestCase):
self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True',
lambda: o1.sum().backward())
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
@repeat_test_for_types(get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
def test_Conv2d_large_workspace(self, dtype=torch.float):
# These sizes require huge cuDNN workspaces. Make sure we choose a
# reasonable algorithm that does not run out of memory
sizes = [
(1, 256, 109, 175),
(1, 256, 80, 128),
(1, 256, 120, 192),
]
def run_test(benchmark):
with torch.backends.cudnn.flags(benchmark=benchmark):
conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to("cuda", dtype)
for size in sizes:
x = torch.randn(size, device="cuda", dtype=dtype)
out = conv(x.detach().clone().requires_grad_())
out.backward(torch.ones_like(out))
run_test(benchmark=False)
run_test(benchmark=True)
def test_conv_modules_raise_error_on_incorrect_input_size(self):
for dtype in [torch.bfloat16, torch.double, torch.float]:
@@ -6308,25 +6266,10 @@ class TestNN(NNTestCase):
output = deconv(inputs)
output.mean().backward()
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
@repeat_test_for_types([torch.half, torch.float])
def test_ConvTranspose2d_large_output_padding(self, dtype=torch.half):
net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
.to(device='cuda', dtype=dtype)
net2 = torch.nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\
.to(device='cuda', dtype=dtype)
net3 = torch.nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1)\
.to(device='cuda', dtype=dtype)
x = torch.rand(1, 128, 6, 6, device='cuda', dtype=dtype, requires_grad=True)
x = net1(x)
x = net2(x)
x = net3(x)
x.backward(torch.randn_like(x))
torch.cuda.synchronize()
@skipIfRocm
# For https://github.com/pytorch/pytorch/pull/1273
# Almost identical to the above `test_Conv2d_naive_groups`
@skipIfRocm
def test_Conv2d_groups_nobias(self):
dev_dtypes = [("cpu", torch.float)]
if TEST_CUDA:
@@ -6464,89 +6407,7 @@ class TestNN(NNTestCase):
torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
atol=dtype2prec_DONTUSE[torch.float], rtol=dtype2prec_DONTUSE[torch.float])
# Very similar to test_Conv2d_naive_groups but with special care to handle
# the number of groups == number of input channels
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
@repeat_test_for_types(ALL_TENSORTYPES)
@tf32_on_and_off(0.01)
def test_Conv2d_depthwise_naive_groups_cuda(self, dtype=torch.float):
for depth_multiplier in [1, 2]:
m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to("cuda", dtype)
i = torch.randn(2, 2, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_()
output = m(i)
grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, device="cuda", dtype=dtype) / 2
output.backward(grad_output)
offset = 1 * depth_multiplier
m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to("cuda", dtype)
m1.weight.data = m.weight.data[:offset].clone()
m1.bias.data = m.bias.data[:offset].clone()
i1 = i.detach()[:, :1].clone().requires_grad_()
output1 = m1(i1)
output1.backward(grad_output[:, :offset].contiguous())
m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to("cuda", dtype)
m2.weight.data.copy_(m.weight.data[offset:])
m2.bias.data.copy_(m.bias.data[offset:])
i2 = i.detach()[:, 1:].clone().requires_grad_()
output2 = m2(i2)
output2.backward(grad_output[:, offset:].contiguous())
self.assertEqual(output, torch.cat([output1, output2], 1),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(i.grad.data,
torch.cat([i1.grad.data, i2.grad.data], 1),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(m.bias.grad.data,
torch.cat([m1.bias.grad.data,
m2.bias.grad.data], 0),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(m.weight.grad.data,
torch.cat([m1.weight.grad.data,
m2.weight.grad.data], 0),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
@repeat_test_for_types(ALL_TENSORTYPES)
@tf32_on_and_off(0.005)
def test_Conv3d_depthwise_naive_groups_cuda(self, dtype=torch.float):
for depth_multiplier in [1, 2]:
m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to("cuda", dtype)
i = torch.randn(2, 2, 6, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_()
output = m(i)
grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, 4, device="cuda", dtype=dtype) / 2
output.backward(grad_output)
offset = 1 * depth_multiplier
m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to("cuda", dtype)
m1.weight.data = m.weight.data[:offset].clone()
m1.bias.data = m.bias.data[:offset].clone()
i1 = i.detach()[:, :1].clone().requires_grad_()
output1 = m1(i1)
output1.backward(grad_output[:, :offset].contiguous())
m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to("cuda", dtype)
m2.weight.data.copy_(m.weight.data[offset:])
m2.bias.data.copy_(m.bias.data[offset:])
i2 = i.detach()[:, 1:].clone().requires_grad_()
output2 = m2(i2)
output2.backward(grad_output[:, offset:].contiguous())
self.assertEqual(output, torch.cat([output1, output2], 1),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(i.grad.data,
torch.cat([i1.grad.data, i2.grad.data], 1),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(m.bias.grad.data,
torch.cat([m1.bias.grad.data,
m2.bias.grad.data], 0),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(m.weight.grad.data,
torch.cat([m1.weight.grad.data,
m2.weight.grad.data], 0),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
def test_MaxUnpool2d_output_size(self):
m = nn.MaxPool2d(3, stride=2, return_indices=True)
@@ -9059,23 +8920,6 @@ class TestNN(NNTestCase):
output.backward(grad_output)
self.assertEqual(grad_output, grad_output_clone)
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
@repeat_test_for_types(get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
def test_noncontig_conv_grad_cuda(self, dtype=torch.float):
# FIXME: remove after adding non-contiguous grad tests for all modules
module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to("cuda", dtype)
input = torch.randn(2, 3, 10, 10, dtype=dtype, device="cuda", requires_grad=True)
output = module(input)
grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device="cuda")[:, 1]
assert not grad.is_contiguous()
output.backward(grad, retain_graph=True)
self.assertIsNotNone(input.grad)
result = input.grad.data.clone()
input.grad.data.zero_()
output.backward(grad.contiguous())
self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0)
def test_pixel_shuffle_unshuffle(self):
def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
@@ -9599,13 +9443,6 @@ class TestNN(NNTestCase):
output = m(input)
self.assertEqualTypeString(output, input)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types([torch.float, torch.half])
def test_batchnorm_large_batch(self, dtype=torch.float):
bn = nn.BatchNorm2d(1).to('cuda', dtype)
data = torch.rand(880801, 1, 1, 1, device="cuda", dtype=dtype)
out = bn(data).sum().backward()
def test_batchnorm_raises_error_if_less_than_one_value_per_channel(self):
x = torch.rand(10)[None, :, None]
with self.assertRaises(ValueError):
@@ -11083,49 +10920,6 @@ class TestNN(NNTestCase):
gradcheck(lambda i, w, b, pad: F.conv_tbc(i, w, b, pad), (inp, weight, bias, 3))
def run_conv_double_back_test(self, kern, stride, padding, chan_in, chan_out, batch_size,
inp_size, dilation, no_weight, groups=1, use_cuda=False,
use_bias=True, dtype=torch.double):
if use_cuda:
device = torch.device("cuda")
else:
device = torch.device("cpu")
x = torch.randn(batch_size, chan_in, inp_size, inp_size, device=device,
dtype=dtype, requires_grad=True)
weight = torch.randn(chan_out, chan_in // groups, kern, kern, device=device,
dtype=dtype, requires_grad=not no_weight)
if use_bias:
bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True)
else:
bias = None
def func(*inputs):
if use_bias:
lx, lweight, lbias = inputs
else:
lx, lweight = inputs
lbias = None
# We disable cudnn during forward to avoid finite difference imprecision issues
with cudnn.flags(enabled=False):
out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)
return out
if use_bias:
inputs = x, weight, bias
else:
inputs = x, weight
dummy_out = func(*inputs)
grad_y = torch.randn_like(dummy_out, device=device, dtype=dtype, requires_grad=True)
# Issue #15353: test mkldnn double backward, don't run gradgradcheck due
# to imprecision issues
if dtype == torch.float:
g, = torch.autograd.grad(dummy_out.sum(), x, create_graph=True)
return g.requires_grad
return gradgradcheck(func, inputs, (grad_y,))
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
@@ -11165,91 +10959,6 @@ class TestNN(NNTestCase):
out = conv(input)
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
def test_conv_double_backward(self):
batch_size = 2
for kern, inp_size, dilations in [(3, 6, [1, 2]), (3, 7, [1]), (4, 9, [1])]:
for stride, padding, chan_in, chan_out, dilation in \
product([1, 2], [0, 1, 2], [2], [3], dilations):
for no_weight in (True, False):
for dtype in (torch.float, torch.double):
result = self.run_conv_double_back_test(kern, stride,
padding, chan_in, chan_out,
batch_size, inp_size, dilation,
no_weight, dtype=dtype)
self.assertTrue(result,
"Conv double backward test failed with parameters:" +
"\nkern: " + str(kern) +
"\nstride: " + str(stride) +
"\npadding: " + str(padding) +
"\nchan_in: " + str(chan_in) +
"\nchan_out: " + str(chan_out) +
"\nbatch_size: " + str(batch_size) +
"\ninp_size: " + str(inp_size) +
"\ndilation: " + str(dilation) +
"\ndtype: " + str(dtype))
def test_conv_double_backward_no_bias(self):
kern = 3
stride = 2
chan_in, chan_out = 2, 4
batch_size = 2
inp_size = 5
padding = 1
dilation = 1
no_weight = False
use_bias = True
result = self.run_conv_double_back_test(kern, stride,
padding, chan_in, chan_out,
batch_size, inp_size, dilation,
no_weight, use_bias=use_bias)
self.assertTrue(result,
"Conv double backward test failed with parameters:" +
"\nkern: " + str(kern) +
"\nstride: " + str(stride) +
"\npadding: " + str(padding) +
"\nchan_in: " + str(chan_in) +
"\nchan_out: " + str(chan_out) +
"\nbatch_size: " + str(batch_size) +
"\ninp_size: " + str(inp_size) +
"\ndilation: " + str(dilation))
def test_conv_double_backward_groups(self):
kern = 3
stride = 1
padding = 2
chan_in, chan_out = 2, 4
batch_size = 2
inp_size = 6
dilation = 1
no_weight = False
groups = 2
result = self.run_conv_double_back_test(kern, stride,
padding, chan_in * groups, chan_out * groups,
batch_size, inp_size, dilation,
no_weight, groups=groups)
self.assertTrue(result,
"Conv double backward test failed with parameters:" +
"\nkern: " + str(kern) +
"\nstride: " + str(stride) +
"\npadding: " + str(padding) +
"\nchan_in: " + str(chan_in) +
"\nchan_out: " + str(chan_out) +
"\nbatch_size: " + str(batch_size) +
"\ninp_size: " + str(inp_size) +
"\ndilation: " + str(dilation) +
"\ngroups: " + str(groups))
def test_conv_double_backward_stride(self):
batch_size = 2
# Cannot provide ggW when stride is > 1
for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]:
for stride, padding, chan_in, chan_out, dilation in product([2], [0, 1], [1], [2], dilations):
no_weight = False
self.run_conv_double_back_test(kern, stride,
padding, chan_in, chan_out,
batch_size, inp_size, dilation,
no_weight)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_cudnn_noncontiguous_weight(self):
@@ -11261,29 +10970,6 @@ class TestNN(NNTestCase):
self.assertEqual(F.conv1d(input, weights1, bias=None, stride=2, dilation=2),
F.conv1d(input, weights2, bias=None, stride=2, dilation=2))
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(DOUBLE_TENSORTYPES)
def test_conv_double_backward_cuda(self, dtype=torch.double):
with torch.backends.cudnn.flags(deterministic=True):
# Double backward only runs with DoubleTensor due to precison reason
batch_size = 1
for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]:
for stride, padding, chan_in, chan_out, dilation in product([1], [2], [2], [3], dilations):
no_weight = stride == 2
result = self.run_conv_double_back_test(kern, stride,
padding, chan_in, chan_out,
batch_size, inp_size, dilation,
no_weight, use_cuda=True, dtype=dtype)
self.assertTrue(result,
"Conv double backward test failed with parameters:" +
"\nkern: " + str(kern) +
"\nstride: " + str(stride) +
"\npadding: " + str(padding) +
"\nchan_in: " + str(chan_in) +
"\nchan_out: " + str(chan_out) +
"\nbatch_size: " + str(batch_size) +
"\ninp_size: " + str(inp_size) +
"\ndilation: " + str(dilation))
def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient='input'):
for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
@@ -12725,6 +12411,50 @@ def _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_ra
class TestNNDeviceType(NNTestCase):
def run_conv_double_back_test(self, kern, stride, padding, chan_in, chan_out, batch_size,
inp_size, dilation, no_weight, groups=1, use_cuda=False,
use_bias=True, dtype=torch.double):
if use_cuda:
device = torch.device("cuda")
else:
device = torch.device("cpu")
x = torch.randn(batch_size, chan_in, inp_size, inp_size, device=device,
dtype=dtype, requires_grad=True)
weight = torch.randn(chan_out, chan_in // groups, kern, kern, device=device,
dtype=dtype, requires_grad=not no_weight)
if use_bias:
bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True)
else:
bias = None
def func(*inputs):
if use_bias:
lx, lweight, lbias = inputs
else:
lx, lweight = inputs
lbias = None
# We disable cudnn during forward to avoid finite difference imprecision issues
with cudnn.flags(enabled=False):
out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)
return out
if use_bias:
inputs = x, weight, bias
else:
inputs = x, weight
dummy_out = func(*inputs)
grad_y = torch.randn_like(dummy_out, device=device, dtype=dtype, requires_grad=True)
# Issue #15353: test mkldnn double backward, don't run gradgradcheck due
# to imprecision issues
if dtype == torch.float:
g, = torch.autograd.grad(dummy_out.sum(), x, create_graph=True)
return g.requires_grad
return gradgradcheck(func, inputs, (grad_y,))
def _test_dropout(self, cls, device, input, memory_format=torch.contiguous_format):
p = 0.2
input = input.to(device).fill_(1 - p)
@@ -13243,6 +12973,270 @@ class TestNNDeviceType(NNTestCase):
self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
@onlyCUDA
@skipCUDAIfNoCudnn
@dtypes(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
def test_Conv2d_deterministic_cudnn(self, device, dtype):
inputs = torch.randn(2, 3, 5, 5, device=device, dtype=dtype, requires_grad=True)
with cudnn.flags(enabled=True, benchmark=True, deterministic=True):
conv1 = torch.nn.Conv2d(3, 3, 3).to(device, dtype)
conv2 = torch.nn.Conv2d(3, 3, 3).to(device, dtype)
conv2.bias.data.copy_(conv1.bias.data)
conv2.weight.data.copy_(conv1.weight.data)
out1 = conv1(inputs)
out2 = conv2(inputs)
self.assertEqual(out1, out2, atol=0.0, rtol=0)
y = torch.randn(out1.size(), device=device, dtype=dtype)
out1.backward(y)
out2.backward(y)
self.assertEqual(conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0, rtol=0)
self.assertEqual(conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0, rtol=0)
@onlyCUDA
@dtypes(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
def test_Conv2d_large_workspace(self, device, dtype):
# These sizes require huge cuDNN workspaces. Make sure we choose a
# reasonable algorithm that does not run out of memory
sizes = [
(1, 256, 109, 175),
(1, 256, 80, 128),
(1, 256, 120, 192),
]
def run_test(benchmark):
with torch.backends.cudnn.flags(benchmark=benchmark):
conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to(device, dtype)
for size in sizes:
x = torch.randn(size, device=device, dtype=dtype)
out = conv(x.detach().clone().requires_grad_())
out.backward(torch.ones_like(out))
run_test(benchmark=False)
run_test(benchmark=True)
@onlyCUDA
@dtypes(torch.half, torch.float)
def test_ConvTranspose2d_large_output_padding(self, device, dtype):
net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
.to(device=device, dtype=dtype)
net2 = torch.nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\
.to(device=device, dtype=dtype)
net3 = torch.nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1)\
.to(device=device, dtype=dtype)
x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True)
x = net1(x)
x = net2(x)
x = net3(x)
x.backward(torch.randn_like(x))
torch.cuda.synchronize()
@onlyCUDA
@tf32_on_and_off(0.01)
@dtypes(*ALL_TENSORTYPES)
# Very similar to test_Conv2d_naive_groups but with special care to handle
# the number of groups == number of input channels
def test_Conv2d_depthwise_naive_groups(self, device, dtype):
for depth_multiplier in [1, 2]:
m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(device, dtype)
i = torch.randn(2, 2, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_()
output = m(i)
grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype) / 2
output.backward(grad_output)
offset = 1 * depth_multiplier
m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
m1.weight.data = m.weight.data[:offset].clone()
m1.bias.data = m.bias.data[:offset].clone()
i1 = i.detach()[:, :1].clone().requires_grad_()
output1 = m1(i1)
output1.backward(grad_output[:, :offset].contiguous())
m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
m2.weight.data.copy_(m.weight.data[offset:])
m2.bias.data.copy_(m.bias.data[offset:])
i2 = i.detach()[:, 1:].clone().requires_grad_()
output2 = m2(i2)
output2.backward(grad_output[:, offset:].contiguous())
self.assertEqual(output, torch.cat([output1, output2], 1),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(i.grad.data,
torch.cat([i1.grad.data, i2.grad.data], 1),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(m.bias.grad.data,
torch.cat([m1.bias.grad.data,
m2.bias.grad.data], 0),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(m.weight.grad.data,
torch.cat([m1.weight.grad.data,
m2.weight.grad.data], 0),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
@onlyCUDA
@dtypes(*ALL_TENSORTYPES)
@tf32_on_and_off(0.005)
def test_Conv3d_depthwise_naive_groups(self, device, dtype):
for depth_multiplier in [1, 2]:
m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(device, dtype)
i = torch.randn(2, 2, 6, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_()
output = m(i)
grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype) / 2
output.backward(grad_output)
offset = 1 * depth_multiplier
m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
m1.weight.data = m.weight.data[:offset].clone()
m1.bias.data = m.bias.data[:offset].clone()
i1 = i.detach()[:, :1].clone().requires_grad_()
output1 = m1(i1)
output1.backward(grad_output[:, :offset].contiguous())
m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
m2.weight.data.copy_(m.weight.data[offset:])
m2.bias.data.copy_(m.bias.data[offset:])
i2 = i.detach()[:, 1:].clone().requires_grad_()
output2 = m2(i2)
output2.backward(grad_output[:, offset:].contiguous())
self.assertEqual(output, torch.cat([output1, output2], 1),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(i.grad.data,
torch.cat([i1.grad.data, i2.grad.data], 1),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(m.bias.grad.data,
torch.cat([m1.bias.grad.data,
m2.bias.grad.data], 0),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
self.assertEqual(m.weight.grad.data,
torch.cat([m1.weight.grad.data,
m2.weight.grad.data], 0),
atol=dtype2prec_DONTUSE[dtype], rtol=0)
@onlyCUDA
@dtypes(*get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
def test_noncontig_conv_grad(self, device, dtype):
# FIXME: remove after adding non-contiguous grad tests for all modules
module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype)
input = torch.randn(2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True)
output = module(input)
grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1]
assert not grad.is_contiguous()
output.backward(grad, retain_graph=True)
self.assertIsNotNone(input.grad)
result = input.grad.data.clone()
input.grad.data.zero_()
output.backward(grad.contiguous())
self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0)
@onlyCUDA
@dtypes(torch.float, torch.half)
def test_batchnorm_large_batch(self, device, dtype):
bn = nn.BatchNorm2d(1).to(device, dtype)
data = torch.rand(880801, 1, 1, 1, device=device, dtype=dtype)
out = bn(data).sum().backward()
@onlyCUDA
@dtypes(torch.double)
def test_conv_double_backward(self, device, dtype):
with torch.backends.cudnn.flags(deterministic=True):
# Double backward only runs with DoubleTensor due to precision reason
batch_size = 1
for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]:
for stride, padding, chan_in, chan_out, dilation in product([1], [2], [2], [3], dilations):
no_weight = stride == 2
result = self.run_conv_double_back_test(kern, stride,
padding, chan_in, chan_out,
batch_size, inp_size, dilation,
no_weight, use_cuda=True, dtype=dtype)
self.assertTrue(result,
"Conv double backward test failed with parameters:" +
"\nkern: " + str(kern) +
"\nstride: " + str(stride) +
"\npadding: " + str(padding) +
"\nchan_in: " + str(chan_in) +
"\nchan_out: " + str(chan_out) +
"\nbatch_size: " + str(batch_size) +
"\ninp_size: " + str(inp_size) +
"\ndilation: " + str(dilation))
def test_conv_double_backward_no_bias(self):
kern = 3
stride = 2
chan_in, chan_out = 2, 4
batch_size = 2
inp_size = 5
padding = 1
dilation = 1
no_weight = False
use_bias = True
result = self.run_conv_double_back_test(kern, stride,
padding, chan_in, chan_out,
batch_size, inp_size, dilation,
no_weight, use_bias=use_bias)
self.assertTrue(result,
"Conv double backward test failed with parameters:" +
"\nkern: " + str(kern) +
"\nstride: " + str(stride) +
"\npadding: " + str(padding) +
"\nchan_in: " + str(chan_in) +
"\nchan_out: " + str(chan_out) +
"\nbatch_size: " + str(batch_size) +
"\ninp_size: " + str(inp_size) +
"\ndilation: " + str(dilation))
def test_conv_double_backward_groups(self):
kern = 3
stride = 1
padding = 2
chan_in, chan_out = 2, 4
batch_size = 2
inp_size = 6
dilation = 1
no_weight = False
groups = 2
result = self.run_conv_double_back_test(kern, stride,
padding, chan_in * groups, chan_out * groups,
batch_size, inp_size, dilation,
no_weight, groups=groups)
self.assertTrue(result,
"Conv double backward test failed with parameters:" +
"\nkern: " + str(kern) +
"\nstride: " + str(stride) +
"\npadding: " + str(padding) +
"\nchan_in: " + str(chan_in) +
"\nchan_out: " + str(chan_out) +
"\nbatch_size: " + str(batch_size) +
"\ninp_size: " + str(inp_size) +
"\ndilation: " + str(dilation) +
"\ngroups: " + str(groups))
def test_conv_double_backward_stride(self):
batch_size = 2
# Cannot provide ggW when stride is > 1
for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]:
for stride, padding, chan_in, chan_out, dilation in product([2], [0, 1], [1], [2], dilations):
no_weight = False
self.run_conv_double_back_test(kern, stride,
padding, chan_in, chan_out,
batch_size, inp_size, dilation,
no_weight)
def test_conv1d_same_padding(self, device):
# Test padding='same' outputs the correct shape
test_args = [


@@ -755,14 +755,6 @@ def process_intentional_test_runs(runs: List[TestCase]) -> Tuple[int, int]:
num_pass += 1
REPEAT_TEST_FOR_TYPES_TESTS = [
"test_Conv2d_deterministic_cudnn ",
"test_Conv2d_large_workspace",
"test_ConvTranspose2d_large_output_padding",
"test_Conv2d_depthwise_naive_groups_cuda",
"test_Conv3d_depthwise_naive_groups_cuda",
"test_noncontig_conv_grad_cuda",
"test_batchnorm_large_batch",
"test_conv_double_backward_cuda",
"test_data_parallel_module", "test_data_parallel_module",
"test_data_parallel_module_kwargs_only", "test_data_parallel_module_kwargs_only",
"test_data_parallel_module_kwargs_only_empty_list", "test_data_parallel_module_kwargs_only_empty_list",