# Owner(s): ["module: nn"] import math import unittest import itertools import warnings from itertools import product import torch import torch.autograd.forward_ad as fwAD import torch.backends.cudnn as cudnn import torch.nn as nn import torch.nn.functional as F from torch.testing._internal.common_dtype import floating_types_and, floating_and_complex_types_and from torch.testing._internal.common_utils import run_tests, \ skipIfRocmVersionLessThan, skipIfNotMiopenSuggestNHWC, TEST_SCIPY, TEST_WITH_ROCM, \ download_file, parametrize as parametrize_test, subtest, \ instantiate_parametrized_tests, set_default_dtype from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN from torch.testing._internal.common_nn import NNTestCase, _test_module_empty_input from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \ dtypesIfCUDA, precisionOverride, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \ skipCUDAIfRocm, skipCUDAIfRocmVersionLessThan, skipCUDAIfNotMiopenSuggestNHWC, \ onlyNativeDeviceTypes, largeTensorTest, skipMeta, \ disableMkldnn, skipCPUIfNoMkldnn, disablecuDNN, skipCUDAIfMiopen, skipCUDAIfNoMiopen from torch.testing import make_tensor from torch.testing._internal.common_utils import gradcheck, gradgradcheck, \ GRADCHECK_NONDET_TOL from torch.testing._internal.common_utils import dtype2prec_DONTUSE from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32 AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32() if TEST_SCIPY: import scipy.signal import scipy.ndimage class TestConvolutionNN(NNTestCase): _do_cuda_memory_leak_check = True _do_cuda_non_default_stream = True def test_conv_backcompat(self): from torch.serialization import SourceChangeWarning # This file was generated by running on PyTorch 1.0.1 on Python 2: # # import torch # from torch import nn # m = nn.Conv2d(1, 1, 1) # torch.save(m, 'legacy_conv2d.pt') # # NB: This Pickle also contains some Unicode data! path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt') with warnings.catch_warnings(): warnings.simplefilter('ignore', SourceChangeWarning) m = torch.load(path, encoding='utf-8') input = torch.randn((1, 1, 1, 1), dtype=torch.float) self.assertEqual(m(input).size(), (1, 1, 1, 1)) def test_invalid_conv1d(self): for dtype in [torch.half, torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]: module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype) input = torch.randn(1, 3, 4).to(dtype) with self.assertRaisesRegex(RuntimeError, r'Calculated padded input size per channel: \(4\). ' + r'Kernel size: \(10\). 
Kernel size can\'t be greater than actual input size'): module(input) # Negative stride check module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype) input = torch.randn(1, 3, 4).to(dtype) with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): module(input) def test_mismatch_shape_conv2d(self): for dtype in (torch.float, torch.cfloat): x = torch.randn(1, 10, 1, 28, 28, dtype=dtype) w = torch.randn(6, 1, 5, 5, dtype=dtype) with self.assertRaisesRegex(RuntimeError, r'Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got ' + r'input of size: \[1, 10, 1, 28, 28\]'): F.conv2d(x, w) def test_conv2d_discontiguous_weight(self): for dtype in (torch.float, torch.cfloat): # Test for https://github.com/pytorch/pytorch/issues/55781 x = torch.ones(64, 16, 16, 16, dtype=dtype) weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2).to(dtype)[:, :, :, ::2] self.assertFalse(weight.is_contiguous()) y = torch.nn.functional.conv2d(x, weight, None) if torch.backends.mkldnn.is_available(): # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used with torch.backends.mkldnn.flags(enabled=False): y_ = torch.nn.functional.conv2d(x, weight, None) self.assertEqual(y, y_) self.assertEqual(y.sum(), 4186112.) def test_invalid_conv2d(self): for dtype in [torch.half, torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]: module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype) input = torch.empty(1, 1, 4, 4).to(dtype) self.assertRaises(RuntimeError, lambda: module(input)) module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True) input = torch.randn(1, 3, 1, 1) with self.assertRaisesRegex(RuntimeError, r'Calculated padded input size per channel: \(1 x 1\). ' + r'Kernel size: \(10 x 10\). 
Kernel size can\'t be greater than actual input size'):
                module(input)

            # Negative stride check
            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

            # Zero stride check
            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

    def test_invalid_conv3d(self):
        for dtype in [torch.half, torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]:
            module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype)
            input = torch.empty(1, 1, 4, 4, 4).to(dtype)
            self.assertRaises(RuntimeError, lambda: module(input))

            # Negative stride check
            module = torch.nn.Conv3d(1, 1, kernel_size=3, stride=-2)
            input = torch.empty(1, 1, 4, 4, 4)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

    def test_conv_invalid_groups(self):
        with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'):
            torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2, stride=2, groups=0)
        with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'):
            torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-1)
        with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'):
            torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-2)

    def test_Conv1d_module_same_padding(self):
        # Compare module against functional: without strides/dilation, asymmetric padding
        x = torch.rand(1, 1, 20)
        module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, padding='same')
        expect = F.conv1d(x, module.weight, module.bias, padding='same')
        self.assertEqual(expect, module(x))

        # Test dilation, symmetric padding
        module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, padding='same', dilation=2)
        expect = F.conv1d(x, module.weight, module.bias, padding='same', dilation=2)
        self.assertEqual(expect, module(x))

        # Test non-zero padding_mode, requiring explicit padding
        module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, padding='same', padding_mode='replicate')
        x_padded = F.pad(x, [4, 5], mode='replicate')
        expect = F.conv1d(x_padded, module.weight, module.bias, padding='valid')
        self.assertEqual(expect, module(x))
        self.assertEqual(x.size(), expect.size())

        # Test construction with invalid padding string raises
        with self.assertRaisesRegex(ValueError, 'Invalid padding string'):
            module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, padding='foo')

        # Test construction with same padding and strides raises
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2)

    def test_Conv2d_module_same_padding(self):
        # Compare module against functional:
        # without strides/dilation, both symmetric and asymmetric padding
        x = torch.rand(1, 1, 9, 20)
        module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(5, 10), padding='same')
        expect = F.conv2d(x, module.weight, module.bias, padding='same')
        self.assertEqual(expect, module(x))

        # with dilation, symmetric padding
        module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 4), padding='same', dilation=(1, 2))
        expect = F.conv2d(x, module.weight, module.bias, padding='same', dilation=(1, 2))
        self.assertEqual(expect, module(x))

        # Test non-zero padding_mode, requiring explicit padding
        module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 4), padding='same', padding_mode='reflect')
        x_padded = F.pad(x, [1, 2, 1, 1], mode='reflect')
        expect = F.conv2d(x_padded, module.weight, module.bias, padding='valid')
        self.assertEqual(expect, module(x))
        self.assertEqual(x.size(), expect.size())

        # Test construction with invalid padding string raises
        with self.assertRaisesRegex(ValueError, 'Invalid padding string'):
            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='foo')

        # Test construction with same padding and strides raises
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2)
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 3))
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(4, 1))

    def test_Conv3d_module_same_padding(self):
        # Compare module against functional:
        x = torch.rand(1, 1, 4, 4, 4)
        # without dilation, both symmetric and asymmetric padding
        module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4), padding='same')
        expect = F.conv3d(x, module.weight, module.bias, padding='same')
        self.assertEqual(expect, module(x))

        # with dilation, both symmetric and asymmetric padding
        module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4), padding='same', dilation=(3, 2, 1))
        expect = F.conv3d(x, module.weight, module.bias, padding='same', dilation=(3, 2, 1))
        self.assertEqual(expect, module(x))

        # Test non-zero padding_mode, requiring explicit padding
        module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4), padding='same', padding_mode='circular')
        x_padded = F.pad(x, [1, 2, 1, 1, 0, 1], mode='circular')
        expect = F.conv3d(x_padded, module.weight, module.bias, padding='valid')
        self.assertEqual(expect, module(x))
        self.assertEqual(x.size(), expect.size())

        # Test construction with invalid padding string raises
        with self.assertRaisesRegex(ValueError, 'Invalid padding string'):
            module = nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding='foo')

        # Test construction with same padding and strides raises
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2)
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 1, 3))
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 4, 1))
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(5, 1, 1))

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    def test_thnn_conv_strided_padded_dilated(self):
        for convfn, dims, transposed in (
                (torch.nn.functional.conv2d, 2, False),
                (torch.nn.functional.conv_transpose2d, 2, True),
                (torch.nn.functional.conv3d, 3, False),
                (torch.nn.functional.conv_transpose3d, 3, True)):
            for stride, padding, dilation in (
                    (2, 0, 1), (1, 1, 1), (2, 1, 1), (1, 0, 2)):
                kwargs = {"stride": stride, "padding": padding, "dilation": dilation}
                inp_shape = (1, 2) + dims * (4,)
                weight_shape = (2, 2) + dims * (1,)
                inputs = torch.randn(inp_shape, dtype=torch.double, device="cuda", requires_grad=True)
                weight = torch.randn(weight_shape, dtype=torch.double, device="cuda", requires_grad=True)
                bias = torch.randn(2, dtype=torch.double, device="cuda", requires_grad=True)
                with torch.backends.cudnn.flags(enabled=False):
                    res = convfn(inputs, weight, bias, **kwargs)
                res_cpu = convfn(inputs.cpu(), weight.cpu(), bias.cpu(), **kwargs)
                self.assertEqual(res, res_cpu)
                with torch.backends.cudnn.flags(enabled=False):
                    torch.autograd.gradcheck(
                        lambda x, w, b: convfn(x, w, b, **kwargs),
                        (inputs, weight, bias)
                    )
                    torch.autograd.gradcheck(
                        lambda x, w, b: convfn(x, w, b, **kwargs),
                        (inputs.cpu(), weight.cpu(), bias.cpu())
                    )

    def test_Conv2d_inconsistent_types(self):
        inputs = torch.randn(4, 1, 7, 7, dtype=torch.float)
        weights = torch.randn(1, 1, 3, 3, dtype=torch.double)
        # inconsistent types should raise an exception
        self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights))
        # but it should work with the same type
        nn.functional.conv2d(inputs.float(), weights.float())

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self):
        inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
        weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
        bias = torch.randn(1, dtype=torch.double, device="cuda")

        with torch.backends.cudnn.flags(enabled=False):
            # inconsistent types should raise an exception
            self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights))
            self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights.float(), bias))

            # but it should work with the same type
            nn.functional.conv2d(inputs.float(), weights.float(), bias.float())

    def test_Conv2d_1x1(self):
        in_channels = 2
        out_channels = 2
        mod = torch.nn.Conv2d(2, 2, 1, bias=False).to(dtype=torch.double)
        input = torch.randn(1, in_channels, 5, 5, requires_grad=True, dtype=torch.double)
        for enabled in (False, True):
            with torch.backends.mkldnn.flags(enabled=enabled):
                gradcheck(F.conv2d, (input, mod.weight))

    def test_Conv2d_OneDNN(self):
        def run_once(group_val=24, dilation=1):
            ifm = torch.ones([1, group_val, 6, 6], dtype=torch.float32)
            weights = torch.ones([group_val, 1, 3, 3], dtype=torch.float32)
            op = torch.nn.Conv2d(
                in_channels=group_val,
                out_channels=group_val,
                kernel_size=[3, 3],
                stride=[2, 2],
                padding=[1, 1],
                dilation=[dilation, dilation],
                groups=group_val,
                bias=False,
                padding_mode='zeros'
            )
            op.weight.data = weights
            res = op(ifm)
            grad_in = torch.ones(res.shape, dtype=torch.float32)
            res.backward(grad_in)
            return op.weight.grad

        for group_val in (24, 48, 23, 25):
            for dilation in (1, 2):
                with torch.backends.mkldnn.flags(enabled=False):
                    without_onednn = run_once(group_val, dilation)

                with torch.backends.mkldnn.flags(enabled=True):
                    with_onednn = run_once(group_val, dilation)

                self.assertEqual(without_onednn, with_onednn)

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
    def test_cudnn_non_contiguous(self):
        x = torch.randn(192, 16, 50).cuda()
        x = x.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        m = torch.nn.Conv1d(
            in_channels=16,
            out_channels=32,
            kernel_size=2,
            bias=True).cuda()
        result = m(x)

    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
    def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self):
        inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
        weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
bias = torch.randn(1, dtype=torch.double, device="cuda") with torch.backends.cudnn.flags(enabled=True): # inconsistent types should raise an exception self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights)) self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights.float(), bias)) # but it should work with the same type nn.functional.conv2d(inputs.float(), weights.float(), bias.float()) def test_Conv2d_missing_argument(self): c = nn.Conv2d(3, 3, 3) self.assertRaises(TypeError, lambda: c(None)) def test_Conv2d_backward_twice(self): input = torch.randn(2, 3, 5, 5) c = nn.Conv2d(3, 3, 3) o1 = c(input) o1.sum().backward() self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True', lambda: o1.sum().backward()) def test_conv_modules_raise_error_on_incorrect_input_size(self): for dtype in [torch.half, torch.bfloat16, torch.double, torch.float]: modules = [nn.Conv1d(3, 8, 3).to(dtype), nn.ConvTranspose1d(3, 8, 3).to(dtype), nn.Conv2d(3, 8, 3).to(dtype), nn.ConvTranspose2d(3, 8, 3).to(dtype), nn.Conv3d(3, 8, 3).to(dtype), nn.ConvTranspose3d(3, 8, 3).to(dtype)] invalid_input_dims = [(1, 4), (1, 4), (2, 5), (2, 5), (3, 6), (3, 6)] for invalid_dims, module in zip(invalid_input_dims, modules): for dims in invalid_dims: input = torch.empty(torch.Size((3, ) * dims)) self.assertRaises(RuntimeError, lambda: module(input)) def test_conv_shapecheck(self): def test(should_raise, module, input_size, dtype): input = torch.empty(3, *input_size).to(dtype) if should_raise: self.assertRaises(RuntimeError, lambda: module(input)) else: # just run it to ensure no exception raised. module(input) for dtype in [torch.half, torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]: # Conv1d test(True, nn.Conv1d(1, 1, 3).to(dtype), (1, 2), dtype) test(True, nn.Conv1d(1, 1, 3, stride=2).to(dtype), (1, 2), dtype) test(False, nn.Conv1d(1, 1, 2).to(dtype), (1, 2), dtype) test(False, nn.Conv1d(1, 1, 2, stride=2).to(dtype), (1, 2), dtype) test(False, nn.Conv1d(1, 1, 3, stride=2, padding=1).to(dtype), (1, 2), dtype) # Conv2d test(True, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 2, 2), dtype) test(False, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 3, 3), dtype) test(False, nn.Conv2d(1, 1, (3, 3), padding=1).to(dtype), (1, 2, 2), dtype) # Conv3D test(True, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 2, 2, 2), dtype) test(False, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 3, 3, 3), dtype) test(False, nn.Conv3d(1, 1, (3, 3, 3), padding=1).to(dtype), (1, 2, 2, 2), dtype) def test_ConvTranspose2d_output_size(self): m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2) i = torch.randn(2, 3, 6, 6) for h in range(15, 22): for w in range(15, 22): if 18 <= h <= 20 and 18 <= w <= 20: output = m(i, output_size=(h, w)) self.assertEqual(output.size()[2:], (h, w)) else: self.assertRaises(ValueError, lambda: m(i, (h, w))) def test_ConvTranspose2d_output_size_downsample_upsample(self): b, c, hid_c = 2, 3, 2 for h in range(13, 24): for w in range(13, 17): for k in range(2, 5): for d in range(1, 5): for s in range(1, 4): for p in range(3): conv = nn.Conv2d( in_channels=c, out_channels=hid_c, kernel_size=k, stride=s, padding=p, dilation=d, ) t_conv = nn.ConvTranspose2d( in_channels=hid_c, out_channels=c, kernel_size=k, stride=s, padding=p, dilation=d, ) i = torch.randn(b, c, h, w) out = t_conv(conv(i), output_size=i.shape) self.assertEqual(out.size()[2:], i.size()[2:]) def test_ConvTranspose3d_correct_output_size(self): # Check that ConvTranspose3d can take a 5d output_size. 
m = nn.ConvTranspose3d(2, 2, 2) i = torch.rand(1, 2, 1, 1, 1) out = m(i, output_size=(1, 2, 2, 2, 2)) @unittest.skipIf(not TEST_CUDA, 'CUDA not available') def test_ConvTranspose2d_half_cublas_gemm(self): with torch.backends.cudnn.flags(enabled=False): inputs = torch.randn(1, 1, 16, 16, device='cuda', dtype=torch.half) deconv = nn.ConvTranspose2d( 1, 1, 3, stride=2, padding=1, output_padding=1).cuda().half() output = deconv(inputs) output.mean().backward() # For https://github.com/pytorch/pytorch/pull/1273 # Almost identical to the above `test_Conv2d_naive_groups` @torch.backends.cudnn.flags(enabled=True, benchmark=False) @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") def test_Conv2d_groups_nobias(self): dev_dtypes = [("cpu", torch.float)] if TEST_CUDA: dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)] if AMPERE_OR_ROCM: dev_dtypes += [("cuda", torch.bfloat16)] for device, dtype in dev_dtypes: m = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False).to(device, dtype) i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True) output = m(i) grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype) output.backward(grad_output) m1 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype) m1.weight.data.copy_(m.weight.data[:2]) i1 = i.data[:, :2].contiguous().requires_grad_(True) output1 = m1(i1) output1.backward(grad_output[:, :2].contiguous()) m2 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype) m2.weight.data.copy_(m.weight.data[2:]) i2 = i.data[:, 2:].contiguous().requires_grad_(True) output2 = m2(i2) output2.backward(grad_output[:, 2:].contiguous()) self.assertEqual(output, torch.cat([output1, output2], 1)) self.assertEqual(i.grad.data, torch.cat([i1.grad.data, i2.grad.data], 1), atol=dtype2prec_DONTUSE[dtype], rtol=0) self.assertEqual(m.weight.grad.data, torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0) # Almost identical to the above `test_Conv2d_naive_groups` # Covering special case when group > 1, input-channel / group < 16 and output-channel is multiple of 16 # See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686 # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024 @torch.backends.cudnn.flags(enabled=True, benchmark=False) @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") def test_Conv2d_groups_nobias_v2(self): torch.manual_seed(123) dev_dtypes = [("cpu", torch.float)] if TEST_CUDA: dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)] if AMPERE_OR_ROCM: dev_dtypes += [("cuda", torch.bfloat16)] for device, dtype in dev_dtypes: m = nn.Conv2d(4, 16, kernel_size=3, groups=2, bias=False).to(device, dtype) i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True) output = m(i) grad_output = torch.randn(2, 16, 4, 4, device=device, dtype=dtype) output.backward(grad_output) m1 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype) m1.weight.data.copy_(m.weight.data[:8]) i1 = i.data[:, :2].contiguous().requires_grad_(True) output1 = m1(i1) output1.backward(grad_output[:, :8].contiguous()) m2 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype) m2.weight.data.copy_(m.weight.data[8:]) i2 = i.data[:, 2:].contiguous().requires_grad_(True) output2 = m2(i2) output2.backward(grad_output[:, 8:].contiguous()) self.assertEqual(output, torch.cat([output1, output2], 1)) self.assertEqual(i.grad.data, 
                             torch.cat([i1.grad.data, i2.grad.data], 1),
                             atol=dtype2prec_DONTUSE[dtype], rtol=0)
            self.assertEqual(m.weight.grad.data,
                             torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                             atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype],
                             rtol=0)

    # CPU-only test for group conv3d fast implementation using bmm
    # See: https://github.com/pytorch/pytorch/pull/36355
    def test_Conv3d_groups_nobias(self):
        torch.manual_seed(123)
        m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=False).to("cpu", torch.float)
        i = torch.randn(2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True)
        output = m(i)
        grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float)
        output.backward(grad_output)

        m1 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float)
        m1.weight.data.copy_(m.weight.data[:8])
        i1 = i.data[:, :2].contiguous().requires_grad_(True)
        output1 = m1(i1)
        output1.backward(grad_output[:, :8].contiguous())

        m2 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float)
        m2.weight.data.copy_(m.weight.data[8:])
        i2 = i.data[:, 2:].contiguous().requires_grad_(True)
        output2 = m2(i2)
        output2.backward(grad_output[:, 8:].contiguous())

        self.assertEqual(output, torch.cat([output1, output2], 1))
        self.assertEqual(i.grad.data,
                         torch.cat([i1.grad.data, i2.grad.data], 1),
                         atol=dtype2prec_DONTUSE[torch.float], rtol=0)
        self.assertEqual(m.weight.grad.data,
                         torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                         atol=dtype2prec_DONTUSE[torch.float],
                         rtol=dtype2prec_DONTUSE[torch.float])

    def test_Conv3d_groups_wbias(self):
        torch.manual_seed(123)
        m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=True).to("cpu", torch.float)
        i = torch.randn(2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True)
        output = m(i)
        grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float)
        output.backward(grad_output)

        m1 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float)
        m1.weight.data.copy_(m.weight.data[:8])
        m1.bias.data.copy_(m.bias.data[:8])
        i1 = i.data[:, :2].contiguous().requires_grad_(True)
        output1 = m1(i1)
        output1.backward(grad_output[:, :8].contiguous())

        m2 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float)
        m2.weight.data.copy_(m.weight.data[8:])
        m2.bias.data.copy_(m.bias.data[8:])
        i2 = i.data[:, 2:].contiguous().requires_grad_(True)
        output2 = m2(i2)
        output2.backward(grad_output[:, 8:].contiguous())

        self.assertEqual(output, torch.cat([output1, output2], 1))
        self.assertEqual(i.grad.data,
                         torch.cat([i1.grad.data, i2.grad.data], 1),
                         atol=dtype2prec_DONTUSE[torch.float],
                         rtol=dtype2prec_DONTUSE[torch.float])
        self.assertEqual(m.weight.grad.data,
                         torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                         atol=dtype2prec_DONTUSE[torch.float],
                         rtol=dtype2prec_DONTUSE[torch.float])
        self.assertEqual(m.bias.grad.data,
                         torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                         atol=dtype2prec_DONTUSE[torch.float],
                         rtol=dtype2prec_DONTUSE[torch.float])

    def test_conv_tbc(self):
        with set_default_dtype(torch.double):
            inp = torch.randn(9, 4, 5, requires_grad=True)
            weight = torch.randn(3, 5, 6, requires_grad=True)
            bias = torch.randn(6, requires_grad=True)

            gradcheck(lambda i, w, b, pad: F.conv_tbc(i, w, b, pad), (inp, weight, bias, 3))

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
    @skipIfRocmVersionLessThan((4, 3))
    @skipIfNotMiopenSuggestNHWC
    def test_grouped_conv_cudnn_nhwc_support(self):
        # in order to catch the holes in grouped convolution in nhwc support for earlier cudnn versions
        input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last)
        weight = torch.randn((8, 4, 3, 3), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last)
        out = torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), False, (0, 0), 4)
        input = torch.randn((16, 8, 8, 8), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last)
        out_transpose = torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), True, (0, 0), 4)

    @unittest.expectedFailure
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
    def test_conv_cudnn_memory_layout_dominance(self):
        # The desired behavior here is to have the memory layout of conv.weight
        # dominate the layout of the output, which is not the current behavior.
        # We'll fix this in follow-up PRs and remove the `expectedFailure` tag.
        input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True)
        conv = nn.Conv2d(8, 4, 3).cuda().float()

        out = conv(input)
        self.assertTrue(out.is_contiguous())

        input = input.contiguous(memory_format=torch.channels_last)
        out = conv(input)
        self.assertTrue(out.is_contiguous())

        conv.weight.data = conv.weight.contiguous(memory_format=torch.channels_last)
        out = conv(input)
        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))

        input = input.contiguous()
        out = conv(input)
        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_cudnn_noncontiguous_weight(self):
        # Noncontiguous weights must be contiguous() before being
        # passed to cuDNN
        input = torch.tensor([1, 1, 1], dtype=torch.double, device="cuda").view(1, 1, 3)
        weights1 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2)
        weights2 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2).contiguous()
        self.assertEqual(F.conv1d(input, weights1, bias=None, stride=2, dilation=2),
                         F.conv1d(input, weights2, bias=None, stride=2, dilation=2))

    def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient='input'):
        for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
            for batch, stride, padding, chan_in, chan_out, dilation in \
                    product([1, 2], [1, 2], [0, 1, 2], [2], [3], [1]):
                for has_bias in [True, False]:
                    input_shape = [batch, chan_in]
                    weight_shape = [chan_out, chan_in]
                    for _ in range(dim):
                        input_shape.append(inp_size)
                        weight_shape.append(kern)

                    input = torch.randn(input_shape, requires_grad=True)
                    weight = torch.randn(weight_shape, requires_grad=True)
                    if has_bias:
                        bias = torch.randn([chan_out], requires_grad=True)
                    output = func_forward(input, weight, stride=stride, padding=padding, dilation=dilation, bias=bias)

                    gradient_o = torch.randn(output.shape)
                    gradient_w = torch.autograd.grad(output, input if (gradient == 'input') else weight, gradient_o)

                    self.assertEqual(gradient_w[0],
                                     func_backward(
                                         input_shape if (gradient == 'input') else input,
                                         weight_shape if (gradient == 'weight') else weight,
                                         gradient_o,
                                         stride=stride,
                                         padding=padding,
                                         dilation=dilation))

    def test_grad_conv1d_input(self):
        self.run_grad_conv_test(F.conv1d, F.grad.conv1d_input, 1, 'input')

    def test_grad_conv1d_weight(self):
        self.run_grad_conv_test(F.conv1d, F.grad.conv1d_weight, 1, 'weight')

    def test_grad_conv2d_input(self):
        self.run_grad_conv_test(F.conv2d, F.grad.conv2d_input, 2, 'input')

    def test_grad_conv2d_weight(self):
        self.run_grad_conv_test(F.conv2d, F.grad.conv2d_weight, 2, 'weight')

    def test_grad_conv3d_input(self):
        self.run_grad_conv_test(F.conv3d,
F.grad.conv3d_input, 3, 'input') def test_grad_conv3d_weight(self): self.run_grad_conv_test(F.conv3d, F.grad.conv3d_weight, 3, 'weight') @unittest.skipIf(not torch._nnpack_available(), "NNPACK unavailable") def test_nnpack_conv(self): for kern, inp_size in [(3, 6), (3, 7), (4, 9)]: for batch, stride, padding, chan_in, chan_out in \ product([1, 2, 3, 4], [1, 2], [0, 1, 2], [2], [3]): for has_bias in [True, False]: input_shape = [batch, chan_in] weight_shape = [chan_out, chan_in] for _ in range(2): input_shape.append(inp_size) weight_shape.append(kern) input = torch.randn(input_shape, requires_grad=True, dtype=torch.float) weight = torch.randn(weight_shape, requires_grad=True, dtype=torch.float) if has_bias: bias = torch.randn([chan_out], requires_grad=True, dtype=torch.float) output = torch._nnpack_spatial_convolution(input, weight, stride=stride, padding=padding, bias=bias) output_expected = torch.nn.functional.conv2d( input, weight, stride=stride, padding=padding, bias=bias) self.assertEqual(output, output_expected, atol=3e-4, rtol=0) gradient_o = torch.randn(output.shape, dtype=torch.float) grads = torch.autograd.grad(output, [input, weight], gradient_o) grads_expected = torch.autograd.grad(output_expected, [input, weight], gradient_o) for gr, gr_expected in zip(grads, grads_expected): self.assertEqual(gr, gr_expected, atol=3e-4, rtol=0) def test_conv_padding_mode(self): with self.assertRaisesRegex(ValueError, "padding_mode must be one of"): nn.Conv2d(3, 3, 3, padding_mode="xyz") with self.assertRaisesRegex(ValueError, "padding_mode must be one of"): nn.Conv2d(3, 3, 3, padding_mode=3) with self.assertRaisesRegex(ValueError, "Only \"zeros\" "): nn.ConvTranspose2d(3, 3, 3, padding_mode="reflect") def test_functional_grad_conv(self): # Conv 1D input = torch.randn(1, 1, 5, requires_grad=True) weight = torch.randn(1, 1, 3, requires_grad=True) output = F.conv1d(input, weight, dilation=2) grad_output = torch.randn(output.shape) grad_input_autograd, grad_weight_autograd = torch.autograd.grad(output, (input, weight), grad_output) grad_input_functional = torch.nn.grad.conv1d_input(input.shape, weight, grad_output, dilation=2) self.assertEqual(grad_input_functional, grad_input_autograd) grad_weight_functional = torch.nn.grad.conv1d_weight(input, weight.shape, grad_output, dilation=2) self.assertEqual(grad_weight_functional, grad_weight_autograd) # Conv 2D input = torch.randn(1, 1, 5, 5, requires_grad=True) weight = torch.randn(1, 1, 3, 3, requires_grad=True) output = F.conv2d(input, weight, dilation=2) grad_output = torch.randn(output.shape) (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output) grad_input_functional = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, dilation=2) self.assertEqual(grad_input_functional, grad_input_autograd) grad_weight_functional = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, dilation=2) self.assertEqual(grad_weight_functional, grad_weight_autograd) # Conv 3D input = torch.randn(1, 1, 5, 5, 5, requires_grad=True) weight = torch.randn(1, 1, 3, 3, 3, requires_grad=True) output = F.conv3d(input, weight, dilation=2) grad_output = torch.randn(output.shape) (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output) grad_input_functional = torch.nn.grad.conv3d_input(input.shape, weight, grad_output, dilation=2) self.assertEqual(grad_input_functional, grad_input_autograd) grad_weight_functional = torch.nn.grad.conv3d_weight(input, weight.shape, 
grad_output, dilation=2) self.assertEqual(grad_weight_functional, grad_weight_autograd) def test_functional_grad_conv2d(self): BATCH_SIZE = 4 IN_CH = 8 OUT_CH = 16 SPATIAL = 32 def _test_conv2d(stride, kernel_size, groups, dilation): padding = kernel_size // 2 input = torch.empty(BATCH_SIZE, IN_CH, SPATIAL, SPATIAL).uniform_(-8.0, 8.0).requires_grad_(True) weight = torch.empty(OUT_CH, IN_CH // groups, kernel_size, kernel_size).uniform_(-4.0, 4.0).requires_grad_(True) output = F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) grad_output = torch.randn(output.shape) (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output) grad_input_functional = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, stride=stride, padding=padding, dilation=dilation, groups=groups) self.assertEqual(grad_input_functional, grad_input_autograd) grad_weight_functional = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, stride=stride, padding=padding, dilation=dilation, groups=groups) self.assertEqual(grad_weight_functional, grad_weight_autograd) strides = [1, 2] kernel_sizes = [1, 3, 5] groups = [1, 2, 4] dilates = [1, 2] for s, k, g, d in product(strides, kernel_sizes, groups, dilates): _test_conv2d(s, k, g, d) class TestConvolutionNNDeviceType(NNTestCase): def run_conv_double_back_test(self, kern, stride, padding, chan_in, chan_out, batch_size, inp_size, dilation, no_weight, groups=1, use_cuda=False, use_bias=True, dtype=torch.double): if use_cuda: device = torch.device("cuda") else: device = torch.device("cpu") x = torch.randn(batch_size, chan_in, inp_size, inp_size, device=device, dtype=dtype, requires_grad=True) weight = torch.randn(chan_out, chan_in // groups, kern, kern, device=device, dtype=dtype, requires_grad=not no_weight) if use_bias: bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True) else: bias = None def func(*inputs): if use_bias: lx, lweight, lbias = inputs else: lx, lweight = inputs lbias = None # We disable cudnn during forward to avoid finite difference imprecision issues with cudnn.flags(enabled=False): out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups) return out if use_bias: inputs = x, weight, bias else: inputs = x, weight dummy_out = func(*inputs) grad_y = torch.randn_like(dummy_out, device=device, dtype=dtype, requires_grad=True) # Issue #15353: test mkldnn double backward, don't run gradgradcheck due # to imprecision issues if dtype == torch.float: g, = torch.autograd.grad(dummy_out.sum(), x, create_graph=True) return g.requires_grad return gradgradcheck(func, inputs, (grad_y,)) @onlyCUDA @skipCUDAIfNoCudnn @dtypes(*floating_and_complex_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])) def test_Conv2d_deterministic_cudnn(self, device, dtype): inputs = torch.randn(2, 3, 5, 5, device=device, dtype=dtype, requires_grad=True) with cudnn.flags(enabled=True, benchmark=True, deterministic=True): conv1 = torch.nn.Conv2d(3, 3, 3).to(device, dtype) conv2 = torch.nn.Conv2d(3, 3, 3).to(device, dtype) conv2.bias.data.copy_(conv1.bias.data) conv2.weight.data.copy_(conv1.weight.data) out1 = conv1(inputs) out2 = conv2(inputs) self.assertEqual(out1, out2, atol=0.0, rtol=0) y = torch.randn(out1.size(), device=device, dtype=dtype) out1.backward(y) out2.backward(y) self.assertEqual(conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0, rtol=0) self.assertEqual(conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0, rtol=0) @onlyCUDA 
@dtypes(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])) def test_Conv2d_large_workspace(self, device, dtype): # These sizes require huge cuDNN workspaces. Make sure we choose a # reasonable algorithm that does not run out of memory sizes = [ (1, 256, 109, 175), (1, 256, 80, 128), (1, 256, 120, 192), ] def run_test(benchmark): with torch.backends.cudnn.flags(enabled=True, benchmark=benchmark): conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to(device, dtype) for size in sizes: x = torch.randn(size, device=device, dtype=dtype) out = conv(x.detach().clone().requires_grad_()) out.backward(torch.ones_like(out)) run_test(benchmark=False) run_test(benchmark=True) @onlyCUDA @dtypes(torch.half, torch.float) def test_ConvTranspose2d_large_output_padding(self, device, dtype): net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\ .to(device=device, dtype=dtype) net2 = torch.nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\ .to(device=device, dtype=dtype) net3 = torch.nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1)\ .to(device=device, dtype=dtype) x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True) x = net1(x) x = net2(x) x = net3(x) x.backward(torch.randn_like(x)) torch.cuda.synchronize() @onlyCUDA @dtypes(torch.float, torch.double, torch.half) # Very similar to test_Conv2d_naive_groups but with special care to handle # the number of groups == number of input channels @torch.backends.cudnn.flags(enabled=True, benchmark=False) @tf32_on_and_off(0.01) def test_Conv2d_depthwise_naive_groups(self, device, dtype): for depth_multiplier in [1, 2]: m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(device, dtype) i = torch.randn(2, 2, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_() output = m(i) grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype) / 2 output.backward(grad_output) offset = 1 * depth_multiplier m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) m1.weight.data = m.weight.data[:offset].clone() m1.bias.data = m.bias.data[:offset].clone() i1 = i.detach()[:, :1].clone().requires_grad_() output1 = m1(i1) output1.backward(grad_output[:, :offset].contiguous()) m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) m2.weight.data.copy_(m.weight.data[offset:]) m2.bias.data.copy_(m.bias.data[offset:]) i2 = i.detach()[:, 1:].clone().requires_grad_() output2 = m2(i2) output2.backward(grad_output[:, offset:].contiguous()) self.assertEqual(output, torch.cat([output1, output2], 1), atol=dtype2prec_DONTUSE[dtype], rtol=0) self.assertEqual(i.grad.data, torch.cat([i1.grad.data, i2.grad.data], 1), atol=dtype2prec_DONTUSE[dtype], rtol=0) self.assertEqual(m.bias.grad.data, torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0), atol=dtype2prec_DONTUSE[dtype], rtol=0) self.assertEqual(m.weight.grad.data, torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), atol=dtype2prec_DONTUSE[dtype], rtol=0) @onlyCUDA @dtypes(torch.float, torch.double, torch.half) @torch.backends.cudnn.flags(enabled=True, benchmark=False) @tf32_on_and_off(0.005) def test_Conv3d_depthwise_naive_groups(self, device, dtype): for depth_multiplier in [1, 2]: m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(device, dtype) i = torch.randn(2, 2, 6, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_() output = m(i) grad_output = torch.randn(2, 2 * 
depth_multiplier, 4, 4, 4, device=device, dtype=dtype) / 2 output.backward(grad_output) offset = 1 * depth_multiplier m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) m1.weight.data = m.weight.data[:offset].clone() m1.bias.data = m.bias.data[:offset].clone() i1 = i.detach()[:, :1].clone().requires_grad_() output1 = m1(i1) output1.backward(grad_output[:, :offset].contiguous()) m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) m2.weight.data.copy_(m.weight.data[offset:]) m2.bias.data.copy_(m.bias.data[offset:]) i2 = i.detach()[:, 1:].clone().requires_grad_() output2 = m2(i2) output2.backward(grad_output[:, offset:].contiguous()) is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(0) == (8, 6) atol, rtol = (3e-4, 3e-2) if dtype == torch.float32 and is_cuda_sm86 else (dtype2prec_DONTUSE[dtype], 0) self.assertEqual(output, torch.cat([output1, output2], 1), atol=atol, rtol=rtol) self.assertEqual(i.grad.data, torch.cat([i1.grad.data, i2.grad.data], 1), atol=dtype2prec_DONTUSE[dtype], rtol=0) self.assertEqual(m.bias.grad.data, torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0), atol=dtype2prec_DONTUSE[dtype], rtol=0) self.assertEqual(m.weight.grad.data, torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), atol=atol, rtol=rtol) @onlyCUDA @dtypes(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])) def test_noncontig_conv_grad(self, device, dtype): # FIXME: remove after adding non-contiguous grad tests for all modules module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype) input = torch.randn(2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True) output = module(input) grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1] assert not grad.is_contiguous() output.backward(grad, retain_graph=True) self.assertIsNotNone(input.grad) result = input.grad.data.clone() input.grad.data.zero_() output.backward(grad.contiguous()) self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0) @onlyCUDA @dtypes(torch.double) def test_conv_double_backward(self, device, dtype): with torch.backends.cudnn.flags(enabled=True, deterministic=True): # Double backward only runs with DoubleTensor due to precision reason batch_size = 1 for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]: for stride, padding, chan_in, chan_out, dilation in product([1], [2], [2], [3], dilations): no_weight = stride == 2 result = self.run_conv_double_back_test(kern, stride, padding, chan_in, chan_out, batch_size, inp_size, dilation, no_weight, use_cuda=True, dtype=dtype) self.assertTrue(result, "Conv double backward test failed with parameters:" + "\nkern: " + str(kern) + "\nstride: " + str(stride) + "\npadding: " + str(padding) + "\nchan_in: " + str(chan_in) + "\nchan_out: " + str(chan_out) + "\nbatch_size: " + str(batch_size) + "\ninp_size: " + str(inp_size) + "\ndilation: " + str(dilation)) def test_conv_double_backward_no_bias(self): kern = 3 stride = 2 chan_in, chan_out = 2, 4 batch_size = 2 inp_size = 5 padding = 1 dilation = 1 no_weight = False use_bias = True result = self.run_conv_double_back_test(kern, stride, padding, chan_in, chan_out, batch_size, inp_size, dilation, no_weight, use_bias=use_bias) self.assertTrue(result, "Conv double backward test failed with parameters:" + "\nkern: " + str(kern) + "\nstride: " + str(stride) + "\npadding: " + str(padding) + "\nchan_in: " + str(chan_in) + "\nchan_out: " + str(chan_out) + "\nbatch_size: " + str(batch_size) + 
"\ninp_size: " + str(inp_size) + "\ndilation: " + str(dilation)) def test_conv_double_backward_groups(self): kern = 3 stride = 1 padding = 2 chan_in, chan_out = 2, 4 batch_size = 2 inp_size = 6 dilation = 1 no_weight = False groups = 2 result = self.run_conv_double_back_test(kern, stride, padding, chan_in * groups, chan_out * groups, batch_size, inp_size, dilation, no_weight, groups=groups) self.assertTrue(result, "Conv double backward test failed with parameters:" + "\nkern: " + str(kern) + "\nstride: " + str(stride) + "\npadding: " + str(padding) + "\nchan_in: " + str(chan_in) + "\nchan_out: " + str(chan_out) + "\nbatch_size: " + str(batch_size) + "\ninp_size: " + str(inp_size) + "\ndilation: " + str(dilation) + "\ngroups: " + str(groups)) def test_conv_double_backward_stride(self): batch_size = 2 # Cannot provide ggW when stride is > 1 for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]: for stride, padding, chan_in, chan_out, dilation in product([2], [0, 1], [1], [2], dilations): no_weight = False self.run_conv_double_back_test(kern, stride, padding, chan_in, chan_out, batch_size, inp_size, dilation, no_weight) @dtypes(torch.float, torch.cfloat) @torch.backends.cudnn.flags(enabled=True, benchmark=False) def test_conv1d_same_padding(self, device, dtype): # Test padding='same' outputs the correct shape test_args = [ # in_size range(50, 55), # kernel_size [1, 2, 3, 8], # dilation range(1, 4), # stride [1], ] for in_size, k_size, dilation, stride in itertools.product(*test_args): x = torch.rand(1, 1, in_size, device=device, dtype=dtype) y = torch.rand(1, 1, k_size, device=device, dtype=dtype) z = F.conv1d(x, y, padding='same', dilation=dilation, stride=stride) self.assertEqual(z.size(2), int(math.ceil(in_size / stride))) # Compare F.conv1d padding='same' output against manual padding # Without strides/dilation x = torch.rand(1, 1, 12, device=device, dtype=dtype) y = torch.rand(1, 1, 3, device=device, dtype=dtype) expect = F.conv1d(x, y, padding=1) actual = F.conv1d(x, y, padding='same') self.assertEqual(expect, actual) # With dilation x = torch.rand(1, 1, 12, device=device, dtype=dtype) y = torch.rand(1, 1, 4, device=device, dtype=dtype) expect = F.conv1d(x, y, padding=3, dilation=2) actual = F.conv1d(x, y, padding='same', dilation=2) self.assertEqual(expect, actual) # Dilation with asymmetric padding expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:] actual = F.conv1d(x, y, padding='same', dilation=3) self.assertEqual(expect, actual) @dtypes(torch.float, torch.cfloat) def test_conv2d_same_padding(self, device, dtype): if dtype is torch.cfloat: rtol, atol = 2e-6, 2e-6 else: rtol, atol = None, None # Compare F.conv2d padding='same' output against manual padding # Without strides/dilation x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype) y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype) expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :] actual = F.conv2d(x, y, padding='same') self.assertEqual(expect, actual, rtol=rtol, atol=atol) # With dilation y = torch.rand(1, 1, 3, 4, device=device, dtype=dtype) expect = F.conv2d(x, y, padding=(2, 3), dilation=2) actual = F.conv2d(x, y, padding='same', dilation=2) self.assertEqual(expect, actual, rtol=rtol, atol=atol) # Dilation with asymmetric padding y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype) expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:] actual = F.conv2d(x, y, padding='same', dilation=3) self.assertEqual(expect, actual, rtol=rtol, atol=atol) @dtypes(torch.float, torch.cfloat) def 
test_conv3d_same_padding(self, device, dtype): if dtype is torch.cfloat: rtol, atol = 2e-6, 2e-6 else: rtol, atol = None, None # Compare F.conv3d padding='same' output against manual padding # Without strides/dilation x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype) y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype) expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :] actual = F.conv3d(x, y, padding='same') self.assertEqual(expect, actual, rtol=rtol, atol=atol) # With dilation expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2) actual = F.conv3d(x, y, padding='same', dilation=2) self.assertEqual(expect, actual, rtol=rtol, atol=atol) # Dilation with asymmetric padding y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype) expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:] actual = F.conv3d(x, y, padding='same', dilation=3) self.assertEqual(expect, actual, rtol=rtol, atol=atol) @dtypes(torch.float, torch.cfloat) def test_conv1d_valid_padding(self, device, dtype): # Test F.conv1d padding='valid' is the same as no padding x = torch.rand(1, 1, 10, device=device, dtype=dtype) y = torch.rand(1, 1, 4, device=device, dtype=dtype) expect = F.conv1d(x, y) actual = F.conv1d(x, y, padding='valid') self.assertEqual(expect, actual) @dtypes(torch.float, torch.cfloat) def test_conv2d_valid_padding(self, device, dtype): # Test F.conv2d padding='valid' is the same as no padding x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype) y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype) expect = F.conv2d(x, y) actual = F.conv2d(x, y, padding='valid') self.assertEqual(expect, actual) @dtypes(torch.float, torch.cfloat) def test_conv3d_valid_padding(self, device, dtype): # Test F.conv3d padding='valid' is the same as no padding x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device) y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device) expect = F.conv3d(x, y) actual = F.conv3d(x, y, padding='valid') self.assertEqual(expect, actual) @dtypes(torch.float, torch.cfloat) def test_conv1d_same_padding_backward(self, device, dtype): # Test F.conv1d gradients work with padding='same' x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True) y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True) # Symmetric padding z = F.conv1d(x, y, padding=3, dilation=2) z.sum().abs().backward() gx_expect, gy_expect = x.grad, y.grad x.grad, y.grad = None, None z = F.conv1d(x, y, padding='same', dilation=2) z.sum().abs().backward() self.assertEqual(gx_expect, x.grad) self.assertEqual(gy_expect, y.grad) x.grad, y.grad = None, None # Asymmetric padding z = F.conv1d(x, y, padding=2)[..., 1:] z.sum().abs().backward() gx_expect, gy_expect = x.grad, y.grad x.grad, y.grad = None, None z = F.conv1d(x, y, padding='same') z.sum().abs().backward() self.assertEqual(gx_expect, x.grad) self.assertEqual(gy_expect, y.grad) @dtypes(torch.float, torch.cfloat) @tf32_on_and_off(0.001) def test_conv2d_same_padding_backward(self, device, dtype): # Test F.conv2d gradients work with padding='same' x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True) y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True) # Symmetric padding z = F.conv2d(x, y, padding=(3, 4), dilation=2) z.sum().abs().backward() gx_expect, gy_expect = x.grad, y.grad x.grad, y.grad = None, None z = F.conv2d(x, y, padding='same', dilation=2) z.sum().abs().backward() self.assertEqual(gx_expect, x.grad) self.assertEqual(gy_expect, y.grad) x.grad, y.grad = None, None # 
Asymmetric padding y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True) z = F.conv2d(x, y, padding=2)[..., 1:, 1:] z.sum().abs().backward() gx_expect, gy_expect = x.grad, y.grad x.grad, y.grad = None, None z = F.conv2d(x, y, padding='same') z.sum().abs().backward() self.assertEqual(gx_expect, x.grad) self.assertEqual(gy_expect, y.grad) @dtypes(torch.double, torch.cdouble) def test_conv3d_same_padding_backward(self, device, dtype): check_forward_ad = torch.device(device).type != 'xla' # Test F.conv3d gradients work with padding='same' x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True) y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True) # Symmetric padding z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2) z.sum().abs().backward() gx_expect, gy_expect = x.grad, y.grad x.grad, y.grad = None, None z = F.conv3d(x, y, padding='same', dilation=2) z.sum().abs().backward() self.assertEqual(gx_expect, x.grad) self.assertEqual(gy_expect, y.grad) x.grad, y.grad = None, None gradcheck(lambda x, y: F.conv3d(x, y, padding='same', dilation=2), (x, y), check_forward_ad=check_forward_ad, nondet_tol=1e-5) if torch.device(device).type != 'cuda': # https://github.com/pytorch/pytorch/issues/70702 gradgradcheck(lambda x, y: F.conv3d(x, y, padding='same', dilation=2), (x, y), check_fwd_over_rev=True) # Asymmetric padding y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True) z = F.conv3d(x, y, padding=2)[..., 1:, 1:] z.sum().abs().backward() gx_expect, gy_expect = x.grad, y.grad x.grad, y.grad = None, None z = F.conv3d(x, y, padding='same') z.sum().abs().backward() self.assertEqual(gx_expect, x.grad) self.assertEqual(gy_expect, y.grad) gradcheck(lambda x, y: F.conv3d(x, y, padding='same'), (x, y), check_forward_ad=check_forward_ad, nondet_tol=1e-5) if torch.device(device).type != 'cuda': # https://github.com/pytorch/pytorch/issues/70702 gradgradcheck(lambda x, y: F.conv3d(x, y, padding='same'), (x, y), check_fwd_over_rev=True) @dtypes(torch.float, torch.cfloat) def test_conv1d_valid_padding_backward(self, device, dtype): # Test F.conv1d gradients work with padding='valid' x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True) y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True) F.conv1d(x, y, padding=0).sum().abs().backward() gx_expect, gy_expect = x.grad, y.grad x.grad, y.grad = None, None F.conv1d(x, y, padding='valid').sum().abs().backward() gx_actual, gy_actual = x.grad, y.grad self.assertEqual(gx_expect, gx_actual) self.assertEqual(gy_expect, gy_actual) @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.") @dtypes(torch.float, torch.cfloat) @parametrize_test("mode", ('valid', 'same')) def test_conv1d_vs_scipy(self, device, dtype, mode): t = make_tensor((1, 10), device=device, dtype=dtype) feat_dim = t.shape[1] weight_even = make_tensor((1, 1, 4), device=device, dtype=dtype) weight_odd = make_tensor((1, 1, 5), device=device, dtype=dtype) def _test(t, weight, mode): # SciPy expects two 1-D inputs. 
t_a = t.view(-1).cpu().numpy() w_a = weight.view(-1).cpu().numpy() expected = scipy.signal.convolve(t_a, w_a, mode=mode) kwargs = {'padding': mode} if mode == 'same': # `same` padding in PyTorch conv1d is different # from SciPy p = weight.shape[2] // 2 t = torch.nn.functional.pad(t, (p, p)) # We have already taken care of padding kwargs.pop("padding") # second input is flipped in SciPy's convolve weight_flipped = torch.flip(weight, (2,)) actual = torch.nn.functional.conv1d(t, weight_flipped, **kwargs).squeeze(0) if mode == 'same': actual = actual[:feat_dim] self.assertEqual(actual, expected, atol=2e-5, rtol=2e-5) # Global dtype for this test suite is torch.double # This leads to change in type-promotion # and conv1d outputs `complex128` for `complex64` input. with set_default_dtype(torch.float): _test(t, weight_even, mode) _test(t, weight_odd, mode) @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.") @dtypes(torch.float, torch.cfloat) @parametrize_test("mode", ('valid', 'same')) def test_conv2d_vs_scipy(self, device, dtype, mode): t = make_tensor((1, 5, 10), device=device, dtype=dtype) weight_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype) weight_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype) def _test(t, weight, mode): # SciPy expects two 2-D inputs. t_a = t.squeeze(0).cpu().numpy() w_a = weight.squeeze(0).squeeze(0).cpu().numpy() expected = scipy.signal.convolve2d(t_a, w_a, mode=mode) kwargs = {'padding': mode} if mode == 'same': # `same` padding in PyTorch conv2d is different # from SciPy left_right_pad = weight.shape[3] // 2 top_bottom_pad = weight.shape[2] // 2 p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad) t = torch.nn.functional.pad(t, p) # We have already taken care of padding kwargs.pop("padding") # second input is flipped in SciPy's convolve2d weight_flipped = torch.flip(weight, (2, 3)) actual = torch.nn.functional.conv2d(t, weight_flipped, **kwargs).squeeze(0) if mode == 'same': actual = actual[:5, :10] self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6) # Global dtype for this test suite is torch.double # This leads to change in type-promotion # and conv1d outputs `complex128` for `complex64` input. with set_default_dtype(torch.float): _test(t, weight_even, mode) _test(t, weight_odd, mode) @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.") @dtypes(torch.float, torch.cfloat) @parametrize_test("mode", ('valid', 'same')) def test_conv3d_vs_scipy(self, device, dtype, mode): t = make_tensor((1, 5, 5, 10), device=device, dtype=dtype) weight_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype) weight_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype) def _test(t, weight, mode): # SciPy expects two 3-D inputs. 
t_a = t.squeeze(0).cpu().numpy() w_a = weight.squeeze(0).squeeze(0).cpu().numpy() expected = scipy.signal.convolve(t_a, w_a, mode=mode) kwargs = {'padding': mode} if mode == 'same': # `same` padding in PyTorch conv3d is different # from SciPy left_right_pad = weight.shape[4] // 2 top_bottom_pad = weight.shape[3] // 2 front_back_pad = weight.shape[2] // 2 p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad, front_back_pad, front_back_pad) t = torch.nn.functional.pad(t, p) # We have already taken care of padding kwargs.pop("padding") # second input is flipped in SciPy's convolve weight_flipped = torch.flip(weight, (2, 3, 4)) actual = torch.nn.functional.conv3d(t, weight_flipped, **kwargs).squeeze(0) if mode == 'same': actual = actual[:5, :5, :10] if tf32_is_not_fp32() and (dtype == torch.float or dtype == torch.complex64): self.assertEqual(actual, expected, atol=0.05, rtol=0.05) else: self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6) # Global dtype for this test suite is torch.double # This leads to change in type-promotion # and conv1d outputs `complex128` for `complex64` input. with set_default_dtype(torch.float): _test(t, weight_even, mode) _test(t, weight_odd, mode) @dtypes(torch.float, torch.complex64) def test_conv2d_valid_padding_backward(self, device, dtype): # Test F.conv2d gradients work with padding='valid' x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True) y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True) F.conv2d(x, y, padding=0).sum().abs().backward() gx_expect, gy_expect = x.grad, y.grad x.grad, y.grad = None, None F.conv2d(x, y, padding='valid').sum().abs().backward() gx_actual, gy_actual = x.grad, y.grad self.assertEqual(gx_expect, gx_actual) self.assertEqual(gy_expect, gy_actual) @dtypes(torch.double, torch.cdouble) def test_conv3d_valid_padding_backward(self, device, dtype): check_forward_ad = torch.device(device).type != 'xla' # Test F.conv3d gradients work with padding='valid' x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True) y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True) F.conv3d(x, y, padding=0).sum().abs().backward() gx_expect, gy_expect = x.grad, y.grad x.grad, y.grad = None, None F.conv3d(x, y, padding='valid').sum().abs().backward() gx_actual, gy_actual = x.grad, y.grad self.assertEqual(gx_expect, gx_actual) self.assertEqual(gy_expect, gy_actual) gradcheck(lambda x, y: F.conv3d(x, y, padding='valid'), (x, y), check_forward_ad=check_forward_ad) gradgradcheck(lambda x, y: F.conv3d(x, y, padding='valid'), (x, y), check_fwd_over_rev=check_forward_ad) @parametrize_test("N", range(2, 4), name_fn=lambda N: f'ConvTranspose{N}d') def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N): # For inputs with no batch dim, verify output is the correct shape when output_size is set. 
# See https://github.com/pytorch/pytorch/issues/75889 inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device) output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200) ConvTransposeNd = getattr(nn, f'ConvTranspose{N}d') m = ConvTransposeNd(1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device) output = m(inp, output_size=output_size) self.assertEqual(output.shape, output_size) @skipMeta @parametrize_test("input_shape,transposed,dilated,groups,layout,backend_expected", [ # === slow === subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Slow2d), decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d'), subtest(((2, 6, 7), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d), decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d_transposed'), subtest(((2, 6, 7), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated2d), decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d_dilated'), subtest(((2, 6, 7), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d), decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d_dilated_transposed'), subtest(((2, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Slow2d), decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d'), subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d), decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d_transposed'), subtest(((2, 6, 7, 8), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated2d), decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d_dilated'), subtest(((2, 6, 7, 8), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d), decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d_dilated_transposed'), subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Slow3d), decorators=[onlyCPU, disableMkldnn], name='slow3d_cpu'), # CUDA doesn't have a slow 3D implementation, so it goes to the dilated 3D implementation instead subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.SlowDilated3d), decorators=[onlyCUDA, disablecuDNN], name='slow3d_cuda'), # FIXME: RuntimeError: CUDA out of memory. # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d), # decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_transposed'), subtest(((2, 6, 7, 8, 9), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated3d), decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated'), # FIXME: RuntimeError: CUDA out of memory. 
# subtest(((2, 6, 7, 8, 9), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d), # decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated_transposed'), subtest(((0, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch1d'), subtest(((2, 0, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_channel1d'), subtest(((0, 0, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch_channel1d'), subtest(((0, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch2d'), subtest(((2, 0, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_channel2d'), subtest(((0, 0, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch_channel2d'), subtest(((0, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch3d'), subtest(((2, 0, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_channel3d'), subtest(((0, 0, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch_channel3d'), # === cuda === # Note that disablecuDNN disables miopen as well. subtest(((2, 6, 7), False, False, 6, torch.strided, torch._C._ConvBackend.CudaDepthwise2d), decorators=[onlyCUDA, disablecuDNN], name='cuda_depthwise1d'), subtest(((2, 6, 7, 8), False, False, 6, torch.strided, torch._C._ConvBackend.CudaDepthwise2d), decorators=[onlyCUDA, disablecuDNN], name='cuda_depthwise2d'), subtest(((2, 6, 7, 8, 9), False, False, 6, torch.strided, torch._C._ConvBackend.CudaDepthwise3d), decorators=[onlyCUDA, disablecuDNN], name='cuda_depthwise3d'), # === cudnn === subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Cudnn), decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn1d'), subtest(((2, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Cudnn), decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn2d'), subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Cudnn), decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d'), subtest(((2, 6, 7), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn1d_transposed'), subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn2d_transposed'), # FIXME: RuntimeError: CUDA out of memory. 
# subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), # decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d_transposed'), # === miopen === subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen), decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen1d'), subtest(((2, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen), decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen2d'), subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen), decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen3d'), subtest(((2, 6, 7), True, False, 3, torch.strided, torch._C._ConvBackend.MiopenTranspose), decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen1d_transposed'), subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.MiopenTranspose), decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen2d_transposed'), subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.MiopenTranspose), decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen3d_transposed'), subtest(((2, 6, 7), False, False, 6, torch.strided, torch._C._ConvBackend.MiopenDepthwise), decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen_depthwise1d'), subtest(((2, 6, 7, 8), False, False, 6, torch.strided, torch._C._ConvBackend.MiopenDepthwise), decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen_depthwise2d'), subtest(((2, 6, 7, 8, 9), False, False, 6, torch.strided, torch._C._ConvBackend.MiopenDepthwise), decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen_depthwise3d'), # === mkldnn === subtest(((2, 6, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn1d'), subtest(((2, 6, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn2d'), subtest(((2, 6, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn3d'), # Transposed convolution is broken for mkldnn. See https://github.com/pytorch/pytorch/issues/68775. 
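# The mkldnn*_transposed subtests below are accordingly marked with
# unittest.expectedFailure, so the suite keeps exercising the mkldnn layout path
# and reports an unexpected success if/when the backend gains transposed-conv
# support.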
subtest(((2, 6, 7), True, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], name='mkldnn1d_transposed'), subtest(((2, 6, 7, 8), True, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], name='mkldnn2d_transposed'), subtest(((2, 6, 7, 8, 9), True, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], name='mkldnn3d_transposed'), subtest(((2, 6, 7), False, True, 3, torch.strided, torch._C._ConvBackend.Mkldnn), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn1d_cpu_input'), subtest(((2, 6, 7, 8), False, True, 3, torch.strided, torch._C._ConvBackend.Mkldnn), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn2d_cpu_input'), subtest(((2, 6, 7, 8, 9), False, True, 3, torch.strided, torch._C._ConvBackend.Mkldnn), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn3d_cpu_input'), subtest(((0, 6, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch1d'), subtest(((2, 0, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_channel1d'), subtest(((0, 0, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch_channel1d'), subtest(((0, 6, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch2d'), subtest(((2, 0, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_channel2d'), subtest(((0, 0, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch_channel2d'), subtest(((0, 6, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch3d'), subtest(((2, 0, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_channel3d'), subtest(((0, 0, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch_channel3d'), # Note: Tests for mobile backends are not currently supported. This comprises # NnpackSpatial, Winograd3x3Depthwise, and Xnnpack2d backends. Testing these # requires the ability to gate tests by whether PyTorch is built with USE_MOBILE=1. ]) # Test with both bias and no bias. @parametrize_test("has_bias", [False, True]) # Test with both stride=1 and stride>1 cases. @parametrize_test("strided", [False, True]) # Test with both contiguous and non-contiguous inputs. @parametrize_test("contiguous", [False, True]) def test_conv_backend( self, device, input_shape, has_bias, strided, contiguous, transposed, dilated, groups, layout, backend_expected): # Build up inputs. 
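# The inputs list assembled below follows the aten::convolution signature:
#   (input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
# The test first asks torch._C._select_conv_backend(*inputs) which backend the
# dispatcher would pick and asserts it matches backend_expected from the
# parametrization above, then runs forward/backward through
# torch.ops.aten.convolution and, for non-mkldnn layouts, gradcheck and
# gradgradcheck in float64.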
dtype = torch.float32 C_in, C_out, dim, kernel_size = input_shape[1], 12, len(input_shape) - 2, 3 x = torch.randn(*input_shape, device=device, dtype=dtype, requires_grad=True) weight = torch.randn(C_in if transposed else C_out, C_out // groups if transposed else C_in // groups, *[kernel_size for _ in range(dim)], device=device, dtype=dtype, requires_grad=True) bias = torch.randn(C_out, device=device, dtype=dtype, requires_grad=True) if has_bias else None def _make_noncontiguous(inp): if inp is None: return None old_requires_grad = inp.requires_grad inp = torch.repeat_interleave(inp, 2, dim=-1) inp = inp[..., ::2].detach().requires_grad_(old_requires_grad) return inp if not contiguous: x = _make_noncontiguous(x) weight = _make_noncontiguous(weight) bias = _make_noncontiguous(bias) if layout is torch._mkldnn: x = x.to_mkldnn() # Note that weight and bias are not supported as mkldnn tensors during training. stride = (2,) * dim if strided else (1,) * dim padding = (0,) * dim dilation = (2,) * dim if dilated else (1,) * dim output_padding = (0,) * dim inputs = [x, weight, bias, stride, padding, dilation, transposed, output_padding, groups] # Ensure correct backend is selected. backend_actual = torch._C._select_conv_backend(*inputs) self.assertEqual(backend_actual, backend_expected) # Ensure backward call succeeds. convolution = torch.ops.aten.convolution output = convolution(*inputs) grad_output = torch.randn(output.shape, device=device, dtype=dtype) if not contiguous: grad_output = _make_noncontiguous(grad_output) if layout is torch._mkldnn: grad_output = grad_output.to_mkldnn() output.backward(grad_output) # mkldnn doesn't support gradcheck :( if layout is torch._mkldnn: return if backend_actual != torch._C._ConvBackend.Empty: # FIXME: forward AD fails # Forward AD and forward-over-reverse AD smoke test in float32 # TODO: remove this if we introduce per-op gradient tests for float32 with fwAD.dual_level(): dual_inputs = [(fwAD.make_dual(i, torch.rand_like(i)) if isinstance(i, torch.Tensor) else i) for i in inputs] # Forward AD output = convolution(*dual_inputs) # Forward over reverse AD grad_output_d = fwAD.make_dual(torch.rand_like(output), torch.rand_like(output)) if has_bias: torch.autograd.grad(output, [x, weight, bias], grad_output_d) else: torch.autograd.grad(output, [x, weight], grad_output_d) # Convert to float64 for gradcheck. x = x.to(torch.float64).detach().requires_grad_(True) weight = weight.to(torch.float64).detach().requires_grad_(True) if bias is not None: bias = bias.to(torch.float64).detach().requires_grad_(True) inputs = [x, weight, bias, stride, padding, dilation, transposed, output_padding, groups] # Set some backend-specific validation settings. gradcheck_nondet_tol = 0.0 if torch.backends.cudnn.is_available(): # cuDNN introduces non-determinism gradcheck_nondet_tol = GRADCHECK_NONDET_TOL self.assertTrue(gradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol)) # double backward doesn't support bias gradients if bias is not None: bias.requires_grad_(False) self.assertTrue(gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol)) @onlyCPU def test_conv_contiguous_for_oneDNN(self): # See https://github.com/pytorch/pytorch/issues/80837. 
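# The input below is built by transposing a 5-D tensor and then dropping the
# last dimension, which produces strides that the referenced issue reports the
# oneDNN path mishandling; the check is simply that conv(x2) agrees with the
# same call run with torch.backends.mkldnn disabled.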
for dtype in [torch.float, torch.bfloat16, torch.half]: conv = nn.Conv2d( 1, 128, kernel_size=(5, 2), stride=(2, 1), padding=(0, 1), dilation=(1, 1), groups=1, bias=True, padding_mode='zeros').to(dtype=dtype) x = torch.rand([1, 2, 321, 201, 1]).to(dtype=dtype) x = torch.transpose(x, 1, 4) x2 = x[..., 0] inputs = [x2, conv.weight, conv.bias, (2, 1), (0, 1), (1, 1), False, (0, 1), 1] if torch.backends.mkldnn.is_available(): y = conv(x2) # Disable MKLDNN explicitly with torch.backends.mkldnn.flags(enabled=False): y_ = conv(x2) self.assertEqual(y, y_) @onlyCPU def test_conv_ic1_channels_last_for_oneDNN(self): # See https://github.com/pytorch/pytorch/issues/82060, N > 1 will dispatch to the oneDNN path. for dtype in [torch.float, torch.bfloat16, torch.half]: conv = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False) conv = conv.to(memory_format=torch.channels_last).to(dtype=dtype) x = torch.rand(2, 1, 100, 100).to(dtype=dtype) if torch.backends.mkldnn.is_available(): y = conv(x) # Disable MKLDNN explicitly with torch.backends.mkldnn.flags(enabled=False): y_ = conv(x) self.assertEqual(y, y_) @dtypes(torch.float, torch.cfloat) def test_conv_empty_channel(self, device, dtype): in_channels = 0 mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2, dtype=dtype).to(device) inp = torch.randn(2, 0, 15, device=device, dtype=dtype) _test_module_empty_input(self, mod, inp, check_size=False) with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"): inp = torch.randn(2, 1, 0, device=device, dtype=dtype) mod(inp) mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2, dtype=dtype).to(device) inp = torch.randn(2, 0, 50, 100, device=device, dtype=dtype) _test_module_empty_input(self, mod, inp, check_size=False) with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"): inp = torch.randn(2, 1, 40, 0, device=device, dtype=dtype) mod(inp) mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2, dtype=dtype).to(device) inp = torch.randn(2, 0, 50, 20, 40, device=device, dtype=dtype) _test_module_empty_input(self, mod, inp, check_size=False) with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"): inp = torch.randn(2, 1, 50, 0, 40, device=device, dtype=dtype) mod(inp) def test_group_conv_empty(self, device): mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(device) inp = torch.randn(0, 4, 4, 4, device=device) _test_module_empty_input(self, mod, inp, check_size=False) if self.device_type == 'cuda' and self.has_cudnn(): with torch.backends.cudnn.flags(enabled=False): _test_module_empty_input(self, mod, inp, check_size=False) def test_group_convTranspose_empty(self, device): mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(device) inp = torch.randn(0, 4, 4, 4, device=device) _test_module_empty_input(self, mod, inp, check_size=False) if self.device_type == 'cuda' and self.has_cudnn(): with torch.backends.cudnn.flags(enabled=False): _test_module_empty_input(self, mod, inp, check_size=False) def test_convTranspose_empty(self, device): mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1).to(device) inp = torch.randn(0, 4, 4, 4, device=device) _test_module_empty_input(self, mod, inp, check_size=False) if self.device_type == 'cuda' and self.has_cudnn(): with torch.backends.cudnn.flags(enabled=False): _test_module_empty_input(self, mod, inp, check_size=False) @onlyCUDA @largeTensorTest('12GB') def test_conv_large_nosplit(self, device): # Here we just test that the convolution correctly routes to the fallback
implementation # that is, it does not crash. The correctness of fallback implementation should be # covered in other tests dtype = torch.half if self.device_type == 'cuda' else torch.float conv1 = nn.Conv2d(2, 2, 8, 8).to(device).to(dtype) input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device) conv1(input_large) conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype) input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device) conv2(input_large) def test_conv_noncontig_weights(self, device): for dim in (1, 2, 3): for grouped in (False, True): nc = 3 groups = 3 if grouped else 1 w = torch.randn([3] * dim, device=device) w = w.expand([nc, int(nc / groups)] + list(w.shape)) w = w.detach().requires_grad_() x = torch.randn([1, nc] + ([5] * dim), device=device, requires_grad=True) y = getattr(F, f'conv{dim}d')(x, w, groups=groups) y.sum().backward() y = getattr(F, f'conv_transpose{dim}d')(x, w, groups=groups) y.sum().backward() def test_conv_noncontig_weights_and_bias(self, device): # need floats to exercise https://github.com/pytorch/pytorch/issues/16018 for bias in [True, False]: conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=bias).to(device, torch.float) input_nc = torch.randn((1, 3, 224, 224, 2), device=device, dtype=torch.float)[:, :, :, :, 1] input_c = input_nc.contiguous() weight_nc = torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[:, :, :, :, 1] conv1.weight = nn.Parameter(weight_nc) weight_c = conv1.weight.contiguous() if bias: bias_nc = torch.randn((64, 2), device=device, dtype=torch.float)[:, 1] conv1.bias = nn.Parameter(bias_nc) bias_c = conv1.bias.contiguous() out1 = conv1(input_nc) conv1.weight = nn.Parameter(weight_c) if bias: conv1.bias = nn.Parameter(bias_c) out2 = conv1(input_c) self.assertEqual(out1, out2) @onlyCUDA @largeTensorTest('12GB') @skipIfRocmVersionLessThan((6, 0)) def test_conv_transposed_large(self, device): dtype = torch.half if self.device_type == 'cuda' else torch.float conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype) input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device) # forward ret = conv(input_large) maxdiff0 = (ret.narrow(0, 0, 1024) - conv(input_large.narrow(0, 0, 1024))).abs_().max().item() maxdiff1 = (ret.narrow(0, 1024, 1024) - conv(input_large.narrow(0, 1024, 1024))).abs_().max().item() maxdiff2 = (ret.narrow(0, 2048, 1024) - conv(input_large.narrow(0, 2048, 1024))).abs_().max().item() maxdiff3 = (ret.narrow(0, 3072, 1024) - conv(input_large.narrow(0, 3072, 1024))).abs_().max().item() if self.device_type == 'cuda': # cuDNN may use algorithms such as FFT that don't guarantee a diff of 0 self.assertEqual(maxdiff0, 0, atol=2e-3, rtol=1e-5) self.assertEqual(maxdiff1, 0, atol=2e-3, rtol=1e-5) self.assertEqual(maxdiff2, 0, atol=2e-3, rtol=1e-5) self.assertEqual(maxdiff3, 0, atol=2e-3, rtol=1e-5) else: self.assertEqual(maxdiff0, 0) self.assertEqual(maxdiff1, 0) self.assertEqual(maxdiff2, 0) self.assertEqual(maxdiff3, 0) @onlyCUDA @skipCUDAIfRocm @largeTensorTest('12GB') def test_conv_large(self, device): dtype = torch.half if self.device_type == 'cuda' else torch.float conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype) input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device) # forward ret = conv(input_large) self.assertEqual(ret[:2048], conv(input_large[:2048])) self.assertEqual(ret[2048:4096], conv(input_large[2048:4096])) self.assertEqual(ret[4096:], conv(input_large[4096:])) # backward conv.zero_grad() # When 
computing the backward, we are using the `max(dim=1)` to create # some sparsity. Without this sparsity, the rounding error would be # too large (as large as 1e-5) to satisfy the criterion (1e-6) of `assertEqual` ret.view(4097, -1).max(dim=1).values.sum().backward() del ret grad1 = conv.weight.grad.detach().clone() conv.zero_grad() conv(input_large[:2048]).view(2048, -1).max(dim=1).values.sum().backward() conv(input_large[2048:4096]).view(2048, -1).max(dim=1).values.sum().backward() conv(input_large[4096:]).view(1, -1).max(dim=1).values.sum().backward() grad2 = conv.weight.grad.detach().clone() # gradients are on the order of hundreds; we need to scale them to # the order of one so that we can compare scale = 1 / grad2.abs().mean() grad1 = grad1 * scale grad2 = grad2 * scale self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3) @onlyCUDA @skipCUDAIfNoCudnn def test_contig_wrong_stride_cudnn(self, device): # x has to have batch_size 1 to test contiguous checks x = torch.randn(1, 16, 5, 5, device=device) stride = list(x.stride()) stride[0] = 20 # change the stride in dimension 0; the tensor is still contiguous because size[0] is 1 x.set_(x.storage(), 0, x.size(), stride) self.assertTrue(x.is_contiguous()) F.conv_transpose2d(x, torch.randn(16, 1, 1, 1, device=device)) F.conv2d(x, torch.randn(1, 16, 1, 1, device=device)) @onlyCUDA def test_Conv2d_size_1_kernel(self, device): x_cpu = torch.randn(2, 3, 5, 5) conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1) y_cpu = conv_cpu(x_cpu) y = torch.rand_like(y_cpu) y_cpu.backward(y) with cudnn.flags(enabled=False): conv_cuda = torch.nn.Conv2d(3, 3, kernel_size=1).to(device) conv_cuda.bias.data.copy_(conv_cpu.bias.data) conv_cuda.weight.data.copy_(conv_cpu.weight.data) y_cuda = conv_cuda(x_cpu.to(device)) y_cuda.backward(y.to(device)) self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False) self.assertEqual(conv_cpu.bias.grad.data, conv_cuda.bias.grad.data, atol=1e-5, rtol=0, exact_device=False) self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data, atol=1e-5, rtol=0, exact_device=False) @onlyCUDA def test_ConvTranspose2d_size_1_kernel(self, device): x_cpu = torch.randn(2, 3, 5, 5) conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1) y_cpu = conv_cpu(x_cpu) y = torch.rand_like(y_cpu) y_cpu.backward(y) with cudnn.flags(enabled=False): conv_cuda = torch.nn.ConvTranspose2d(3, 3, kernel_size=1).to(device) conv_cuda.bias.data.copy_(conv_cpu.bias.data) conv_cuda.weight.data.copy_(conv_cpu.weight.data) y_cuda = conv_cuda(x_cpu.to(device)) y_cuda.backward(y.to(device)) self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False) self.assertEqual(conv_cpu.bias.grad.data, conv_cuda.bias.grad.data, atol=1e-5, rtol=0, exact_device=False) self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data, atol=1e-5, rtol=0, exact_device=False) @onlyCUDA def test_ConvTranspose3d_size_1_kernel(self, device): with set_default_dtype(torch.double): x_cpu = torch.randn(2, 3, 3, 5, 5) conv_cpu = torch.nn.ConvTranspose3d(3, 3, kernel_size=1) y_cpu = conv_cpu(x_cpu) y = torch.rand_like(y_cpu) y_cpu.backward(y) with cudnn.flags(enabled=False): conv_cuda = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device) conv_cuda.bias.data.copy_(conv_cpu.bias.data) conv_cuda.weight.data.copy_(conv_cpu.weight.data) y_cuda = conv_cuda(x_cpu.to(device)) y_cuda.backward(y.to(device)) self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False) self.assertEqual(conv_cpu.bias.grad.data, conv_cuda.bias.grad.data, atol=1e-5, rtol=0,
exact_device=False) self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data, atol=1e-5, rtol=0, exact_device=False) @dtypesIfCUDA(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])) @dtypes(torch.float) @torch.backends.cudnn.flags(enabled=True, benchmark=False) @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") def test_Conv2d_naive_groups(self, device, dtype): # Check that grouped convolutions match two half convolutions m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype) i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True) output = m(i) grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype) output.backward(grad_output) m1 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype) m1.weight.data.copy_(m.weight.data[:2]) m1.bias.data.copy_(m.bias.data[:2]) i1 = i.data[:, :2].contiguous().requires_grad_(True) output1 = m1(i1) output1.backward(grad_output[:, :2].contiguous()) m2 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype) m2.weight.data.copy_(m.weight.data[2:]) m2.bias.data.copy_(m.bias.data[2:]) i2 = i.data[:, 2:].contiguous().requires_grad_(True) output2 = m2(i2) output2.backward(grad_output[:, 2:].contiguous()) self.assertEqual(output, torch.cat([output1, output2], 1)) self.assertEqual(i.grad.data, torch.cat([i1.grad.data, i2.grad.data], 1), atol=dtype2prec_DONTUSE[dtype], rtol=0) self.assertEqual(m.bias.grad.data, torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0), atol=dtype2prec_DONTUSE[dtype], rtol=0) self.assertEqual(m.weight.grad.data, torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), atol=dtype2prec_DONTUSE[dtype], rtol=0) @dtypes(torch.double, torch.cdouble) def test_Conv2d_backward_depthwise(self, device, dtype): x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True) weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True) def conv2d_depthwise(x, weight): return torch.nn.functional.conv2d( x, weight, bias=None, stride=(1, 10), groups=2) for cudnn_enabled in [False, True]: with torch.backends.cudnn.flags(enabled=cudnn_enabled): torch.autograd.gradcheck(conv2d_depthwise, (x, weight)) @onlyCPU @dtypes(torch.float, torch.double) def test_conv_thnn_nhwc(self, device, dtype): def helper(mod, n, c, h, w, out_channels, kernel_size, dilation, groups, input_format, weight_format): input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device)\ .to(memory_format=input_format) input.requires_grad_() conv = mod(c, out_channels, kernel_size, dilation=dilation, groups=groups)\ .to(device='cpu', dtype=dtype, memory_format=weight_format) for p in conv.parameters(): p.data = torch.randint_like(p, -3, 3) ref_input = input.detach().clone().contiguous().requires_grad_() ref_conv = mod(c, out_channels, kernel_size, dilation=dilation, groups=groups) # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
ref_conv.load_state_dict(conv.state_dict()) ref_conv = ref_conv.to(device='cpu', dtype=dtype, memory_format=torch.contiguous_format) out = conv(input) ref_out = ref_conv(ref_input) grad = torch.randint_like(out, -3, 3) ref_grad = grad.detach().clone().contiguous() out.backward(grad) ref_out.backward(ref_grad) self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) self.assertTrue(ref_out.is_contiguous()) self.assertEqual(out, ref_out, exact_dtype=False) self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False) self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False) self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) with torch.backends.mkldnn.flags(enabled=False): formats = [[torch.channels_last, torch.channels_last], [torch.channels_last, torch.contiguous_format], [torch.contiguous_format, torch.channels_last]] for input_format, weight_format in formats: # non-dilated conv: thnn_conv2d normal path (with im2col) helper(nn.Conv2d, 2, 8, 4, 4, out_channels=4, kernel_size=3, dilation=1, groups=1, input_format=input_format, weight_format=weight_format) helper(nn.Conv2d, 2, 8, 4, 4, out_channels=8, kernel_size=3, dilation=1, groups=8, input_format=input_format, weight_format=weight_format) # test when input channels is 1 and not converted to channels last helper(nn.Conv2d, 2, 1, 10, 10, out_channels=8, kernel_size=3, dilation=1, groups=1, input_format=torch.contiguous_format, weight_format=torch.channels_last) # non-dilated conv: thnn_conv2d fast path (skip im2col) helper(nn.Conv2d, 1, 16, 56, 56, out_channels=16, kernel_size=1, dilation=1, groups=1, input_format=input_format, weight_format=weight_format) # ic == oc == 1 here, so we need to force the input to channels-last to activate the channels-last path helper(nn.Conv2d, 1, 16, 56, 56, out_channels=16, kernel_size=1, dilation=1, groups=16, input_format=torch.channels_last, weight_format=weight_format) # dilated conv: slow_conv_dilated2d helper(nn.Conv2d, 2, 8, 11, 13, out_channels=16, kernel_size=3, dilation=2, groups=1, input_format=input_format, weight_format=weight_format) helper(nn.Conv2d, 2, 16, 11, 13, out_channels=32, kernel_size=3, dilation=2, groups=16, input_format=input_format, weight_format=weight_format) # transposed-conv: slow_conv_transpose2d helper(nn.ConvTranspose2d, 2, 8, 4, 4, out_channels=4, kernel_size=3, dilation=1, groups=1, input_format=input_format, weight_format=weight_format) helper(nn.ConvTranspose2d, 2, 8, 4, 4, out_channels=8, kernel_size=3, dilation=1, groups=8, input_format=input_format, weight_format=weight_format) helper(nn.ConvTranspose2d, 1, 16, 56, 56, out_channels=16, kernel_size=1, dilation=1, groups=1, input_format=input_format, weight_format=weight_format) helper(nn.ConvTranspose2d, 1, 16, 56, 56, out_channels=32, kernel_size=1, dilation=1, groups=16, input_format=input_format, weight_format=weight_format) @onlyCUDA @skipCUDAIfRocmVersionLessThan((4, 3)) @skipCUDAIfNotMiopenSuggestNHWC @skipCUDAIfCudnnVersionLessThan(7603) @dtypes(torch.half, torch.float, torch.cfloat) def test_conv_cudnn_nhwc(self, device, dtype): def helper(n, c, h, w, out_channels, kernel_size, groups): input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device)\ .to(memory_format=torch.channels_last) input.requires_grad_() conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)\ .to(device='cuda', dtype=dtype, memory_format=torch.channels_last) for p in conv.parameters(): p.data = torch.randint_like(p, -3, 3) # use FP64 channels-first conv as reference ref_input =
input.detach().clone().contiguous().double().requires_grad_() ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups) # load_state_dict will restore the stride & memory_layout on ref_conv.weight. ref_conv.load_state_dict(conv.state_dict()) ref_conv = ref_conv.to(device='cuda', dtype=torch.double, memory_format=torch.contiguous_format) out = conv(input) ref_out = ref_conv(ref_input) grad = torch.randint_like(out, -3, 3) ref_grad = grad.detach().clone().double().contiguous() out.backward(grad) ref_out.backward(ref_grad) self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last)) self.assertTrue(conv.weight.grad.is_contiguous(memory_format=torch.channels_last)) self.assertTrue(ref_out.is_contiguous()) self.assertTrue(ref_input.grad.is_contiguous()) self.assertTrue(ref_conv.weight.grad.is_contiguous()) self.assertEqual(out, ref_out, exact_dtype=False) self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False) self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False) self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1) helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8) helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1) helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16) @onlyCUDA @skipCUDAIfRocm @skipCUDAIfCudnnVersionLessThan(8005) @dtypes(torch.half, torch.float) def test_conv_cudnn_ndhwc(self, device, dtype): def helper(n, c, d, h, w, out_channels, kernel_size, groups): input = torch.randint(-2, 2, (n, c, d, h, w), dtype=dtype, device=device)\ .to(memory_format=torch.channels_last_3d) input.requires_grad_() conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)\ .to(device='cuda', dtype=dtype, memory_format=torch.channels_last_3d) for p in conv.parameters(): p.data = torch.randint_like(p, -2, 2) # use FP64 channels-first conv as reference ref_input = input.detach().clone().contiguous().double().requires_grad_() ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups) # load_state_dict will restore the stride & memory_layout on ref_conv.weight. 
ref_conv.load_state_dict(conv.state_dict()) ref_conv = ref_conv.to(device='cuda', dtype=torch.double, memory_format=torch.contiguous_format) out = conv(input) ref_out = ref_conv(ref_input) grad = torch.randint_like(out, -2, 2) ref_grad = grad.detach().clone().double().contiguous() out.backward(grad) ref_out.backward(ref_grad) self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d)) self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last_3d)) self.assertTrue(conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d)) self.assertTrue(ref_out.is_contiguous()) self.assertTrue(ref_input.grad.is_contiguous()) self.assertTrue(ref_conv.weight.grad.is_contiguous()) self.assertEqual(out, ref_out, exact_dtype=False) self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False) self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False) self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1) helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8) helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1) helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16) def _run_conv(self, layer, device, inp, grad, ref_conv, ref_input, ref_out, input_format, weight_format, grad_format, output_format): conv = layer(inp.size(1), grad.size(1), ref_conv.weight.size(2)).float().to(device) # load_state_dict will restore the stride & memory_layout on ref_conv.weight. conv.load_state_dict(ref_conv.state_dict()) weight_data = conv.weight.detach().clone().contiguous(memory_format=weight_format) conv.weight.data = weight_data.resize_(weight_data.size(), memory_format=weight_format) input = inp.clone().contiguous(memory_format=input_format) input.resize_(input.size(), memory_format=input_format) input = input.requires_grad_() grad = grad.contiguous(memory_format=grad_format) grad.resize_(grad.size(), memory_format=grad_format) out = conv(input) out.backward(grad) self.assertTrue(out.is_contiguous(memory_format=output_format)) self.assertEqual(out, ref_out) self.assertEqual(conv.weight.grad, ref_conv.weight.grad) self.assertEqual(conv.bias.grad, ref_conv.bias.grad) self.assertEqual(input.grad, ref_input.grad) def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device): data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device) ref_input = data.clone().contiguous().requires_grad_(True) ref_conv = layer(c, k, filter_size).float().to(device) ref_out = ref_conv(ref_input) grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device="cuda") ref_out.backward(grad) for w_f in [torch.contiguous_format, torch.channels_last]: for g_f in [torch.contiguous_format, torch.channels_last]: for input_format in [torch.contiguous_format, torch.channels_last]: output_format = torch.contiguous_format # Older versions of CudNN have Channels Last support disabled if torch.backends.cudnn.version() >= 7603: if input_format == torch.channels_last: output_format = torch.channels_last # This is because we have N111 weight that cannot handle # the ambiguous memory_format if w_f == torch.channels_last: if layer == nn.Conv2d and filter_size * c != 1: output_format = torch.channels_last if layer == nn.ConvTranspose2d and filter_size * k != 1: output_format = torch.channels_last self._run_conv(layer, device, data, grad, ref_conv, ref_input, ref_out, input_format, w_f, g_f, output_format) @onlyCUDA @skipCUDAIfRocmVersionLessThan((4, 3)) 
@skipCUDAIfNotMiopenSuggestNHWC @skipCUDAIfCudnnVersionLessThan(7603) @tf32_on_and_off(0.05) def test_conv_cudnn_mismatch_memory_format(self, device): configs = [ [4, 2, 8, 8, 4, 2], [4, 1, 8, 8, 4, 2], [1, 1, 8, 8, 4, 2], [4, 2, 2, 8, 4, 1], [4, 2, 1, 8, 4, 1], [4, 2, 8, 8, 4, 1], [4, 1, 8, 8, 4, 1], ] for n, c, h, w, k, filter_size in configs: self._test_conv_cudnn_nhwc_nchw(nn.Conv2d, n, c, h, w, k, filter_size, device) self._test_conv_cudnn_nhwc_nchw(nn.ConvTranspose2d, n, c, h, w, k, filter_size, device) # torch.half is erroring out on Windows with CUDA 10.1 + cuDNN 7.6.4 # returning CUDNN_STATUS_BAD_PARAM # Disabling that specific test for now [see issue # 33918] @onlyCUDA @skipCUDAIfNoCudnn @dtypes(torch.float, torch.double) def test_conv_cudnn_nhwc_support(self, device, dtype): input = torch.randn((1, 16, 1, 1), dtype=dtype, device="cuda", requires_grad=True) weight = torch.randn((8, 16, 3, 3), dtype=dtype, device="cuda", requires_grad=True) weight = weight.to(memory_format=torch.channels_last) o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1) self.assertTrue(o.is_contiguous(memory_format=torch.channels_last)) o.sum().backward() # Test that faster algorithms used for inference produce the same results # Validates depthwise3x3 bug reported in https://github.com/pytorch/pytorch/issues/60176 @onlyCPU @dtypes(torch.float) def test_conv2d_no_grad(self, device, dtype): for batch in [1, 2, 3]: for groups in [1, 2, 4]: input = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device) m = nn.Conv2d(groups, 8, kernel_size=(3, 3), groups=groups, dtype=dtype, device=device) with torch.no_grad(): output_ng = m(input) output = m(input) self.assertEqual(output, output_ng, rtol=1e-2, atol=1e-5) @onlyCUDA @skipCUDAIfNoCudnn @dtypes(torch.float, torch.float16) @precisionOverride({torch.half: 0.002, torch.float: 1e-4}) def test_cudnn_convolution_relu(self, device, dtype): for batch, groups, image_size, kernel_size, memory_format in \ product((1, 2, 3), (1, 2, 4), ((1, 1), (8, 8)), ((1, 1), (3, 3)), (torch.channels_last, torch.contiguous_format)): if image_size[0] < kernel_size[0]: continue inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device) w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device) conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1) inp = inp.to(memory_format=memory_format) w = w.to(memory_format=memory_format) if torch.version.hip: cudnn_out = torch.miopen_convolution_relu(inp, w, None, (1, 1), (0, 0), (1, 1), 1) else: cudnn_out = torch.cudnn_convolution_relu(inp, w, None, (1, 1), (0, 0), (1, 1), 1) self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format)) if tf32_is_not_fp32() and dtype == torch.float: self.assertEqual(conv2d_out.relu(), cudnn_out, atol=4e-3, rtol=0.006) else: self.assertEqual(conv2d_out.relu(), cudnn_out) @onlyCUDA @skipCUDAIfNoCudnn @dtypes(torch.float, torch.float16) @precisionOverride({torch.half: 0.002, torch.float: 1e-4}) def test_cudnn_convolution_add_relu(self, device, dtype): for batch, groups, image_size, kernel_size, memory_format in \ product((1, 2, 3), (1, 2, 4), ((1, 1), (8, 8)), ((1, 1), (3, 3)), (torch.channels_last, torch.contiguous_format)): if image_size[0] < kernel_size[0]: continue inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device) w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device) conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1) alpha = 2.0 z = torch.randn_like(conv2d_out) inp = 
inp.to(memory_format=memory_format) w = w.to(memory_format=memory_format) z = z.to(memory_format=memory_format) if torch.version.hip: cudnn_out = torch.miopen_convolution_add_relu(inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1) else: cudnn_out = torch.cudnn_convolution_add_relu(inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1) self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format)) if tf32_is_not_fp32() and dtype == torch.float: self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out, atol=2e-3, rtol=0.006) else: self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out) @onlyCUDA @skipCUDAIfRocm @skipCUDAIfCudnnVersionLessThan(7603) def test_convert_conv2d_weight_memory_format(self, device): input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device) model = nn.Sequential( nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float() for memory_format in [torch.channels_last, torch.contiguous_format]: model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format) out = model(input) self.assertTrue(out.is_contiguous(memory_format=memory_format)) model = nn.Sequential( nn.ConvTranspose2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float() for memory_format in [torch.channels_last, torch.contiguous_format]: model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format) out = model(input) self.assertTrue(out.is_contiguous(memory_format=memory_format)) @onlyCUDA @skipCUDAIfRocm @skipCUDAIfCudnnVersionLessThan(7603) def test_convert_conv3d_weight_memory_format(self, device): input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device) model = nn.Sequential( nn.ConvTranspose3d(8, 4, 3), nn.BatchNorm3d(4)).to(device).float() for memory_format in [torch.channels_last_3d, torch.contiguous_format]: model = nn.utils.convert_conv3d_weight_memory_format(model, memory_format) out = model(input) self.assertTrue(out.is_contiguous(memory_format=memory_format)) def test_conv_double_backward_strided_with_3D_input_and_weight(self, device): # Test that _convolution_double_backward() outputs the correct grad shapes # for 3D input / weight when stride > 1. This is an ad-hoc regression test for a # specific case that was uncovered during the convolution consolidation effort. # The test can be safely deleted if _convolution_double_backward() is removed. input = torch.randn(2, 3, 6, device=device) weight = torch.randn(3, 3, 3, device=device) bias = torch.randn(3, device=device) stride = (2,) padding = (1,) dilation = (1,) transposed = False output_padding = (0,) groups = 1 output = torch.ops.aten.convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) ggI = torch.randn(input.shape, device=device) ggW = torch.randn(weight.shape, device=device) ggB = torch.randn(bias.shape, device=device) gO = torch.randn(output.shape, device=device) output_mask = [True, True, True] grad_grad_output, grad_input, grad_weight = torch.ops.aten._convolution_double_backward( ggI, ggW, ggB, gO, weight, input, stride, padding, dilation, transposed, output_padding, groups, output_mask) # Make sure the correct shapes are computed. 
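# _convolution_double_backward returns gradients for grad_output, input and
# weight (in that order, as selected by output_mask), so each result should
# match the shape of the tensor it corresponds to: gO, input and weight.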
self.assertEqual(grad_grad_output.shape, gO.shape) self.assertEqual(grad_input.shape, input.shape) self.assertEqual(grad_weight.shape, weight.shape) @onlyCUDA @largeTensorTest('40GB') @largeTensorTest('24GB', 'cpu') def test_conv3d_64bit_indexing(self, device): x = torch.rand(1, 32, 512, 512, 256) m = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False) yref = m(x) y = m.to(device=device)(x.to(device=device)) self.assertEqual(yref, y) instantiate_device_type_tests(TestConvolutionNNDeviceType, globals()) instantiate_parametrized_tests(TestConvolutionNN) if __name__ == '__main__': run_tests()