diff --git a/test/test_cuda.py b/test/test_cuda.py index 4f4fb5148a7..a3cc62c5e1d 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -595,64 +595,6 @@ class TestCuda(TestCase): q_copy[1].fill_(10) self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) - @setBlasBackendsToDefaultFinally - def test_preferred_blas_library_settings(self): - def _check_default(): - default = torch.backends.cuda.preferred_blas_library() - if torch.version.cuda: - # CUDA logic is easy, it's always cublas - self.assertTrue(default == torch._C._BlasBackend.Cublas) - else: - # ROCm logic is less so, it's cublaslt for some Instinct, cublas for all else - gcn_arch = str( - torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0] - ) - if gcn_arch in ["gfx90a", "gfx942", "gfx950"]: - self.assertTrue(default == torch._C._BlasBackend.Cublaslt) - else: - self.assertTrue(default == torch._C._BlasBackend.Cublas) - - _check_default() - # "Default" can be set but is immediately reset internally to the actual default value. - self.assertTrue( - torch.backends.cuda.preferred_blas_library("default") - != torch._C._BlasBackend.Default - ) - _check_default() - self.assertTrue( - torch.backends.cuda.preferred_blas_library("cublas") - == torch._C._BlasBackend.Cublas - ) - self.assertTrue( - torch.backends.cuda.preferred_blas_library("hipblas") - == torch._C._BlasBackend.Cublas - ) - # check bad strings - with self.assertRaisesRegex( - RuntimeError, - "Unknown input value. Choose from: default, cublas, hipblas, cublaslt, hipblaslt, ck.", - ): - torch.backends.cuda.preferred_blas_library("unknown") - # check bad input type - with self.assertRaisesRegex(RuntimeError, "Unknown input value type."): - torch.backends.cuda.preferred_blas_library(1.0) - # check env var override - custom_envs = [ - {"TORCH_BLAS_PREFER_CUBLASLT": "1"}, - {"TORCH_BLAS_PREFER_HIPBLASLT": "1"}, - ] - test_script = "import torch;print(torch.backends.cuda.preferred_blas_library())" - for env_config in custom_envs: - env = os.environ.copy() - for key, value in env_config.items(): - env[key] = value - r = ( - subprocess.check_output([sys.executable, "-c", test_script], env=env) - .decode("ascii") - .strip() - ) - self.assertEqual("_BlasBackend.Cublaslt", r) - @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async") @setBlasBackendsToDefaultFinally def test_cublas_workspace_explicit_allocation(self):