Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "add unit test for preferred_blas_library settings (#150581)"
This reverts commit 781d28e265.
Reverted https://github.com/pytorch/pytorch/pull/150581 on behalf of https://github.com/clee2000 due to new test broken internally D72395624 ([comment](https://github.com/pytorch/pytorch/pull/150581#issuecomment-2777228731))
This commit is contained in:
parent: 1ab6c4ff04
commit: b0e28f60df
@@ -595,64 +595,6 @@ class TestCuda(TestCase):
        q_copy[1].fill_(10)
        self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

    @setBlasBackendsToDefaultFinally
    def test_preferred_blas_library_settings(self):
        def _check_default():
            default = torch.backends.cuda.preferred_blas_library()
            if torch.version.cuda:
                # CUDA logic is easy, it's always cublas
                self.assertTrue(default == torch._C._BlasBackend.Cublas)
            else:
                # ROCm logic is less so, it's cublaslt for some Instinct, cublas for all else
                gcn_arch = str(
                    torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0]
                )
                if gcn_arch in ["gfx90a", "gfx942", "gfx950"]:
                    self.assertTrue(default == torch._C._BlasBackend.Cublaslt)
                else:
                    self.assertTrue(default == torch._C._BlasBackend.Cublas)

        _check_default()
        # "Default" can be set but is immediately reset internally to the actual default value.
        self.assertTrue(
            torch.backends.cuda.preferred_blas_library("default")
            != torch._C._BlasBackend.Default
        )
        _check_default()
        self.assertTrue(
            torch.backends.cuda.preferred_blas_library("cublas")
            == torch._C._BlasBackend.Cublas
        )
        self.assertTrue(
            torch.backends.cuda.preferred_blas_library("hipblas")
            == torch._C._BlasBackend.Cublas
        )
        # check bad strings
        with self.assertRaisesRegex(
            RuntimeError,
            "Unknown input value. Choose from: default, cublas, hipblas, cublaslt, hipblaslt, ck.",
        ):
            torch.backends.cuda.preferred_blas_library("unknown")
        # check bad input type
        with self.assertRaisesRegex(RuntimeError, "Unknown input value type."):
            torch.backends.cuda.preferred_blas_library(1.0)
        # check env var override
        custom_envs = [
            {"TORCH_BLAS_PREFER_CUBLASLT": "1"},
            {"TORCH_BLAS_PREFER_HIPBLASLT": "1"},
        ]
        test_script = "import torch;print(torch.backends.cuda.preferred_blas_library())"
        for env_config in custom_envs:
            env = os.environ.copy()
            for key, value in env_config.items():
                env[key] = value
            r = (
                subprocess.check_output([sys.executable, "-c", test_script], env=env)
                .decode("ascii")
                .strip()
            )
            self.assertEqual("_BlasBackend.Cublaslt", r)

    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async")
    @setBlasBackendsToDefaultFinally
    def test_cublas_workspace_explicit_allocation(self):
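For reference, the reverted test boils down to the following standalone usage of the preferred-BLAS-library API. This is a minimal sketch derived from the test body above, not part of the commit; it assumes a CUDA or ROCm build of PyTorch with at least one GPU, and the backend actually reported depends on the platform (cuBLAS on CUDA builds, cuBLASLt on some Instinct GPUs under ROCm).

    # Minimal standalone sketch of what the removed test exercises.
    import os
    import subprocess
    import sys

    import torch

    # Query the current default without changing it.
    print(torch.backends.cuda.preferred_blas_library())

    # Request a specific backend by name; the call returns the backend now in effect.
    print(torch.backends.cuda.preferred_blas_library("cublas"))

    # Passing "default" resets to the platform default rather than a literal Default value.
    print(torch.backends.cuda.preferred_blas_library("default"))

    # Environment variables can pre-select a backend for a fresh process.
    env = os.environ.copy()
    env["TORCH_BLAS_PREFER_CUBLASLT"] = "1"  # or TORCH_BLAS_PREFER_HIPBLASLT on ROCm
    out = subprocess.check_output(
        [sys.executable, "-c", "import torch;print(torch.backends.cuda.preferred_blas_library())"],
        env=env,
    )
    print(out.decode("ascii").strip())  # the test expects _BlasBackend.Cublaslt here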