[CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests (#153655)

Some tests do not set the preferred backend, so a preference left behind by an earlier test leads to unexpected behavior when multiple tests run in the same process versus standalone.

Tests that are meant to exercise both backends should parametrize this setting explicitly.
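
For illustration, a minimal sketch of the pattern: blas_library_context here mirrors the helper added in the diff below, torch.backends.cuda.preferred_blas_library() is the existing getter/setter it wraps, and the 64x64 matmuls are placeholder workloads.

import contextlib

import torch


@contextlib.contextmanager
def blas_library_context(backend):
    # Remember the current preference, switch to the requested backend,
    # and always restore the previous preference on exit.
    prev_backend = torch.backends.cuda.preferred_blas_library()
    torch.backends.cuda.preferred_blas_library(backend)
    try:
        yield
    finally:
        torch.backends.cuda.preferred_blas_library(prev_backend)


if __name__ == "__main__":
    a = torch.randn(64, 64, device="cuda")
    b = torch.randn(64, 64, device="cuda")
    for backend in ("cublas", "cublaslt"):
        with blas_library_context(backend):
            torch.mm(a, b)  # matmul runs with the requested backend preferred
    # whichever backend was preferred before the loop is active again here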

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153655
Approved by: https://github.com/ngimel
Authored by Eddie Yan on 2025-05-16 21:31:09 +00:00; committed by PyTorch MergeBot
parent 084c4aa614
commit 3bde364996

@@ -62,6 +62,15 @@ if TEST_CUDA:
assert torch.get_default_dtype() is torch.float32

@contextlib.contextmanager
def blas_library_context(backend):
    prev_backend = torch.backends.cuda.preferred_blas_library()
    torch.backends.cuda.preferred_blas_library(backend)
    try:
        yield
    finally:
        torch.backends.cuda.preferred_blas_library(prev_backend)

class TestMatmulCuda(TestCase):
    def setUp(self):
        super().setUp()
@@ -141,7 +150,9 @@ class TestMatmulCuda(TestCase):
                        torch.float32: xtol(atol=1e-1, rtol=1e-1)})
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
    @parametrize("backend", ["cublas", "cublaslt"])
    def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
        with blas_library_context(backend):
            self.cublas_addmm(size, dtype, False)

    @onlyCUDA
@@ -151,7 +162,9 @@ class TestMatmulCuda(TestCase):
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
    @parametrize("backend", ["cublas", "cublaslt"])
    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, backend):
        with blas_library_context(backend):
            self.cublas_addmm(size, dtype, True)

    @onlyCUDA
@@ -161,7 +174,9 @@ class TestMatmulCuda(TestCase):
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
    @parametrize("backend", ["cublas", "cublaslt"])
    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype, backend):
        with blas_library_context(backend):
            self.cublas_addmm(size, dtype, False, True)

    @onlyCUDA
@@ -456,8 +471,7 @@ class TestMatmulCuda(TestCase):
    def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
        device = "cuda"
        dtype = input_dtype
        torch.backends.cuda.preferred_blas_library(backend)
        with blas_library_context(backend):
            def create_inputs(B=None):
                if B is None:
                    a = torch.randn(M, K, device=device, dtype=dtype)
@@ -512,8 +526,7 @@ class TestMatmulCuda(TestCase):
    def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
        device = "cuda"
        dtype = input_dtype
        torch.backends.cuda.preferred_blas_library(backend)
        with blas_library_context(backend):
            def create_inputs(B=None):
                if B is None:
                    a = torch.randn(M, K, device=device, dtype=dtype)
@@ -550,10 +563,16 @@ class TestMatmulCuda(TestCase):
            else:
                if batch_size:
                    out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
                    baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.baddbmm(c, a, b)
                    if output_dtype == torch.float32:
                        baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32)
                    else:
                        baseline = torch.baddbmm(c, a, b)
                else:
                    out = torch.addmm(c, a, b, out_dtype=output_dtype)
                    baseline = torch.addmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.addmm(c, a, b)
                    if output_dtype == torch.float32:
                        baseline = torch.addmm(c_fp32, a_fp32, b_fp32)
                    else:
                        baseline = torch.addmm(c, a, b)
                self.assertEqual(out.dtype, output_dtype)
                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
@@ -567,6 +586,7 @@ class TestMatmulCuda(TestCase):
        M, N, K = 32, 32, 32
        device = "cuda"
        dtype = torch.float16
        with blas_library_context(backend):
            torch.backends.cuda.preferred_blas_library(backend)
            orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
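
A note on the design choice: setting the preference inside a context manager means the previous value is restored in the finally block even if the wrapped test body raises, whereas the bare torch.backends.cuda.preferred_blas_library(backend) calls removed above never restored it at all. A small illustrative sketch, assuming the blas_library_context helper from the diff is in scope:

import torch

# Even an exception inside the block cannot leak the preference.
original = torch.backends.cuda.preferred_blas_library()
try:
    with blas_library_context("cublaslt"):
        raise AssertionError("simulated test failure")
except AssertionError:
    pass
assert torch.backends.cuda.preferred_blas_library() == original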