mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
fix broken nn_convolution test (#166666)
Summary: Broken by oss diff during oncall by third party contributor Test Plan: buck test 'fbcode//mode/dev-nosan' fbcode//caffe2/test:nn_convolution -- --run-disabled Differential Revision: D85899891 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166666 Approved by: https://github.com/atalman, https://github.com/seemethere, https://github.com/Skylion007
This commit is contained in:
parent
d2be06f673
commit
ef8d97efcf
|
|
@ -12,6 +12,16 @@ import torch.backends.cudnn as cudnn
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.testing import make_tensor
|
||||
|
||||
|
||||
def _get_cudnn_version():
|
||||
"""Safely get cuDNN version, returning None if unavailable."""
|
||||
try:
|
||||
return torch.backends.cudnn.version()
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN, tf32_on_and_off
|
||||
from torch.testing._internal.common_device_type import (
|
||||
disablecuDNN,
|
||||
|
|
@ -4210,10 +4220,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
|
|||
@largeTensorTest("20GB")
|
||||
@largeTensorTest("64GB", "cpu")
|
||||
# TODO(eqy): Remove this once it is fixed in cuDNN and we can dispatch to it again
|
||||
@xfailIf(
|
||||
torch.backends.cudnn.version() is not None
|
||||
and torch.backends.cudnn.version() > 91000
|
||||
)
|
||||
@xfailIf(_get_cudnn_version() is not None and _get_cudnn_version() > 91000)
|
||||
def test_depthwise_conv_64bit_indexing(self, device):
|
||||
x = torch.randn(1, 2, 32800, 32800, dtype=torch.half).to(
|
||||
memory_format=torch.channels_last
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user