Fix TestSparse.test_bmm_windows_error when CUDA is not available (#42626)

Summary:
Refactor comnon pattern of (torch.cuda.version and [int(x) for x in torch.cuda.version.split(".")] >= [a, b]) into `_get_torch_cuda_version()` function

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42626

Reviewed By: seemethere

Differential Revision: D22956149

Pulled By: malfet

fbshipit-source-id: 897c55965e53b477cd20f69e8da15d90489035de
This commit is contained in:
Nikita Shulga 2020-08-05 16:05:51 -07:00 committed by Facebook GitHub Bot
parent 5023995292
commit aa4e91a6dc

View File

@ -36,6 +36,9 @@ def cuda_only(inner):
inner(self, *args, **kwargs)
return outer
def _get_torch_cuda_version():
return [int(x) for x in torch.version.cuda.split(".")] if torch.version.cuda else [0, 0]
class TestSparse(TestCase):
@ -928,9 +931,7 @@ class TestSparse(TestCase):
"bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
)
@unittest.skipIf(
TEST_CUDA and (
not torch.version.cuda
or [int(x) for x in torch.version.cuda.split(".")] < [10, 1]),
TEST_CUDA and _get_torch_cuda_version() < [10, 1],
"bmm sparse-dense requires CUDA 10.1 or greater"
)
def test_bmm(self):
@ -994,8 +995,7 @@ class TestSparse(TestCase):
"bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
)
@unittest.skipIf(
(not torch.version.cuda
or [int(x) for x in torch.version.cuda.split(".")] < [10, 1]),
_get_torch_cuda_version() < [10, 1],
"bmm sparse-dense requires CUDA 10.1 or greater"
)
def test_bmm_deterministic(self):
@ -1030,7 +1030,7 @@ class TestSparse(TestCase):
@cuda_only
@unittest.skipIf(
not IS_WINDOWS or [int(x) for x in torch.version.cuda.split(".")] >= [11, 0],
not IS_WINDOWS or _get_torch_cuda_version() >= [11, 0],
"this test ensures bmm sparse-dense CUDA gives an error when run on Windows with CUDA < 11.0"
)
def test_bmm_windows_error(self):
@ -1044,8 +1044,7 @@ class TestSparse(TestCase):
@cuda_only
@skipIfRocm
@unittest.skipIf(
(torch.version.cuda
and [int(x) for x in torch.version.cuda.split(".")] >= [10, 1]),
_get_torch_cuda_version() >= [10, 1],
"this test ensures bmm gives error if CUDA version is less than 10.1"
)
def test_bmm_cuda_version_error(self):