mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix TestSparse.test_bmm_windows_error when CUDA is not available (#42626)
Summary:
Refactor comnon pattern of (torch.cuda.version and [int(x) for x in torch.cuda.version.split(".")] >= [a, b]) into `_get_torch_cuda_version()` function
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42626
Reviewed By: seemethere
Differential Revision: D22956149
Pulled By: malfet
fbshipit-source-id: 897c55965e53b477cd20f69e8da15d90489035de
This commit is contained in:
parent
5023995292
commit
aa4e91a6dc
|
|
@ -36,6 +36,9 @@ def cuda_only(inner):
|
|||
inner(self, *args, **kwargs)
|
||||
return outer
|
||||
|
||||
def _get_torch_cuda_version():
|
||||
return [int(x) for x in torch.version.cuda.split(".")] if torch.version.cuda else [0, 0]
|
||||
|
||||
|
||||
class TestSparse(TestCase):
|
||||
|
||||
|
|
@ -928,9 +931,7 @@ class TestSparse(TestCase):
|
|||
"bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
|
||||
)
|
||||
@unittest.skipIf(
|
||||
TEST_CUDA and (
|
||||
not torch.version.cuda
|
||||
or [int(x) for x in torch.version.cuda.split(".")] < [10, 1]),
|
||||
TEST_CUDA and _get_torch_cuda_version() < [10, 1],
|
||||
"bmm sparse-dense requires CUDA 10.1 or greater"
|
||||
)
|
||||
def test_bmm(self):
|
||||
|
|
@ -994,8 +995,7 @@ class TestSparse(TestCase):
|
|||
"bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
|
||||
)
|
||||
@unittest.skipIf(
|
||||
(not torch.version.cuda
|
||||
or [int(x) for x in torch.version.cuda.split(".")] < [10, 1]),
|
||||
_get_torch_cuda_version() < [10, 1],
|
||||
"bmm sparse-dense requires CUDA 10.1 or greater"
|
||||
)
|
||||
def test_bmm_deterministic(self):
|
||||
|
|
@ -1030,7 +1030,7 @@ class TestSparse(TestCase):
|
|||
|
||||
@cuda_only
|
||||
@unittest.skipIf(
|
||||
not IS_WINDOWS or [int(x) for x in torch.version.cuda.split(".")] >= [11, 0],
|
||||
not IS_WINDOWS or _get_torch_cuda_version() >= [11, 0],
|
||||
"this test ensures bmm sparse-dense CUDA gives an error when run on Windows with CUDA < 11.0"
|
||||
)
|
||||
def test_bmm_windows_error(self):
|
||||
|
|
@ -1044,8 +1044,7 @@ class TestSparse(TestCase):
|
|||
@cuda_only
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(
|
||||
(torch.version.cuda
|
||||
and [int(x) for x in torch.version.cuda.split(".")] >= [10, 1]),
|
||||
_get_torch_cuda_version() >= [10, 1],
|
||||
"this test ensures bmm gives error if CUDA version is less than 10.1"
|
||||
)
|
||||
def test_bmm_cuda_version_error(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user