diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index b5613c5dcac..f6d03554615 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -12,6 +12,16 @@ import torch.backends.cudnn as cudnn import torch.nn as nn import torch.nn.functional as F from torch.testing import make_tensor + + +def _get_cudnn_version(): + """Safely get cuDNN version, returning None if unavailable.""" + try: + return torch.backends.cudnn.version() + except RuntimeError: + return None + + from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN, tf32_on_and_off from torch.testing._internal.common_device_type import ( disablecuDNN, @@ -4210,10 +4220,7 @@ class TestConvolutionNNDeviceType(NNTestCase): @largeTensorTest("20GB") @largeTensorTest("64GB", "cpu") # TODO(eqy): Remove this once it is fixed in cuDNN and we can dispatch to it again - @xfailIf( - torch.backends.cudnn.version() is not None - and torch.backends.cudnn.version() > 91000 - ) + @xfailIf(_get_cudnn_version() is not None and _get_cudnn_version() > 91000) def test_depthwise_conv_64bit_indexing(self, device): x = torch.randn(1, 2, 32800, 32800, dtype=torch.half).to( memory_format=torch.channels_last