Enable half precision types on test_conv_cudnn_nhwc_support (#163444)

This PR adds flaot16 and bfloat16 cases to `test_conv_cudnn_nhwc_support` and removes outdated comments.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163444
Approved by: https://github.com/Skylion007
This commit is contained in:
Yuanyuan Chen 2025-09-22 04:11:18 +00:00 committed by PyTorch MergeBot
parent 01f927eb40
commit 281bb56cc5

View File

@ -3823,12 +3823,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
nn.ConvTranspose2d, n, c, h, w, k, filter_size, device
)
# torch.half is erroring out on Windows with CUDA 10.1 + cuDNN 7.6.4
# returning CUDNN_STATUS_BAD_PARAM
# Disabling that specific test for now [see issue # 33918]
@onlyCUDA
@skipCUDAIfNoCudnn
@dtypes(torch.float, torch.double)
@dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
def test_conv_cudnn_nhwc_support(self, device, dtype):
input = torch.randn(
(1, 16, 1, 1), dtype=dtype, device="cuda", requires_grad=True