Enable half precision types on test_conv_cudnn_nhwc_support (#163444)
This PR adds float16 and bfloat16 cases to `test_conv_cudnn_nhwc_support` and removes outdated comments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163444
Approved by: https://github.com/Skylion007
This commit is contained in:
parent 01f927eb40
commit 281bb56cc5
@@ -3823,12 +3823,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
                 nn.ConvTranspose2d, n, c, h, w, k, filter_size, device
             )
 
-    # torch.half is erroring out on Windows with CUDA 10.1 + cuDNN 7.6.4
-    # returning CUDNN_STATUS_BAD_PARAM
-    # Disabling that specific test for now [see issue # 33918]
     @onlyCUDA
     @skipCUDAIfNoCudnn
-    @dtypes(torch.float, torch.double)
+    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
     def test_conv_cudnn_nhwc_support(self, device, dtype):
         input = torch.randn(
             (1, 16, 1, 1), dtype=dtype, device="cuda", requires_grad=True
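For context, extending the `@dtypes` decorator means the device-generic test machinery now instantiates `test_conv_cudnn_nhwc_support` once per listed dtype, including the two half-precision types. The sketch below is a minimal standalone illustration of what such a case exercises, not the actual PyTorch test body: a channels_last (NHWC) convolution on CUDA in float16 and bfloat16. The helper name `check_nhwc_conv` and the tensor shapes are illustrative assumptions.

# Minimal standalone sketch (not the PyTorch test itself): run a channels_last
# (NHWC) convolution on CUDA in each half-precision dtype the decorator now covers.
# bfloat16 assumes a sufficiently recent GPU/cuDNN.
import torch
import torch.nn as nn

def check_nhwc_conv(dtype):
    # Input and weights in channels_last memory format, which maps to cuDNN's NHWC path.
    x = torch.randn(1, 16, 8, 8, device="cuda", dtype=dtype)
    x = x.to(memory_format=torch.channels_last).requires_grad_(True)
    conv = nn.Conv2d(16, 16, kernel_size=3, padding=1).to(device="cuda", dtype=dtype)
    conv = conv.to(memory_format=torch.channels_last)
    out = conv(x)
    # The output should stay channels_last, and backward should not raise a cuDNN error.
    assert out.is_contiguous(memory_format=torch.channels_last)
    out.sum().backward()

if torch.cuda.is_available() and torch.backends.cudnn.is_available():
    for dtype in (torch.float16, torch.bfloat16):
        check_nhwc_conv(dtype)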