fix broken nn_convolution test (#166666)

Summary: Broken by oss diff during oncall by third party contributor

Test Plan: buck test 'fbcode//mode/dev-nosan' fbcode//caffe2/test:nn_convolution -- --run-disabled

Differential Revision: D85899891

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166666
Approved by: https://github.com/atalman, https://github.com/seemethere, https://github.com/Skylion007
This commit is contained in:
Camyll Harajli 2025-10-31 19:59:50 +00:00 committed by PyTorch MergeBot
parent d2be06f673
commit ef8d97efcf

View File

@ -12,6 +12,16 @@ import torch.backends.cudnn as cudnn
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.testing import make_tensor from torch.testing import make_tensor
def _get_cudnn_version():
"""Safely get cuDNN version, returning None if unavailable."""
try:
return torch.backends.cudnn.version()
except RuntimeError:
return None
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN, tf32_on_and_off from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN, tf32_on_and_off
from torch.testing._internal.common_device_type import ( from torch.testing._internal.common_device_type import (
disablecuDNN, disablecuDNN,
@ -4210,10 +4220,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
@largeTensorTest("20GB") @largeTensorTest("20GB")
@largeTensorTest("64GB", "cpu") @largeTensorTest("64GB", "cpu")
# TODO(eqy): Remove this once it is fixed in cuDNN and we can dispatch to it again # TODO(eqy): Remove this once it is fixed in cuDNN and we can dispatch to it again
@xfailIf( @xfailIf(_get_cudnn_version() is not None and _get_cudnn_version() > 91000)
torch.backends.cudnn.version() is not None
and torch.backends.cudnn.version() > 91000
)
def test_depthwise_conv_64bit_indexing(self, device): def test_depthwise_conv_64bit_indexing(self, device):
x = torch.randn(1, 2, 32800, 32800, dtype=torch.half).to( x = torch.randn(1, 2, 32800, 32800, dtype=torch.half).to(
memory_format=torch.channels_last memory_format=torch.channels_last