[CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests (#153655)

Some tests may not set the preferred backend themselves, which leads to unexpected behavior when multiple tests are run in the same process vs. standalone.

Tests that should exercise both backends should explicitly parametrize this setting.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153655
Approved by: https://github.com/ngimel
Eddie Yan 2025-05-20 16:18:35 +00:00 committed by PyTorch MergeBot
parent a7c01d7f13
commit 5163bf0069
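The diff below adds a small context manager that saves and restores the preferred BLAS backend around each parametrized test body. A minimal sketch of that pattern (the helper name and the preferred_blas_library calls mirror the diff; the example test method is hypothetical):

import contextlib

import torch


@contextlib.contextmanager
def blas_library_context(backend):
    # Remember the current preference, switch to the requested backend,
    # then restore the previous setting even if the test body raises.
    prev_backend = torch.backends.cuda.preferred_blas_library()
    torch.backends.cuda.preferred_blas_library(backend)
    try:
        yield
    finally:
        torch.backends.cuda.preferred_blas_library(prev_backend)


# Hypothetical usage inside a parametrized test method:
#
#     @parametrize("backend", ["cublas", "cublaslt"])
#     def test_example(self, backend):
#         with blas_library_context(backend):
#             ...  # exercise the matmul path under the selected backend

Because the restore happens in a finally block, a failing test can no longer leak its backend choice into tests that run after it.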


@@ -62,6 +62,15 @@ if TEST_CUDA:
 assert torch.get_default_dtype() is torch.float32


+@contextlib.contextmanager
+def blas_library_context(backend):
+    prev_backend = torch.backends.cuda.preferred_blas_library()
+    torch.backends.cuda.preferred_blas_library(backend)
+    try:
+        yield
+    finally:
+        torch.backends.cuda.preferred_blas_library(prev_backend)
+
 class TestMatmulCuda(TestCase):
     def setUp(self):
         super().setUp()
@@ -141,8 +150,10 @@ class TestMatmulCuda(TestCase):
                         torch.float32: xtol(atol=1e-1, rtol=1e-1)})
     @dtypes(torch.float16, torch.bfloat16, torch.float32)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, False)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, False)

     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -151,31 +162,31 @@ class TestMatmulCuda(TestCase):
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, True)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, True)

     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
     @dtypes(torch.float16)
     # m == 4 chooses OUTPUT_TYPE reduction on H200
-    # m == 8 chooses OUTOUT_TYPE reduction on A100
+    # m == 8 chooses OUTPUT_TYPE reduction on A100
     @parametrize("small_size", [4, 8])
     @parametrize("size", [32768])
     @parametrize("backend", ["cublaslt", "cublas"])
     def test_cublas_addmm_no_reduced_precision(self, small_size: int, size: int, dtype: torch.dtype, backend):
-        # TODO(eqy): replace with contextlib once that is merged
-        orig = torch.backends.cuda.preferred_blas_library()
-        torch.backends.cuda.preferred_blas_library(backend)
-        orig_precision = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-        m1 = torch.full((small_size, size), 65504.0, dtype=dtype, device='cuda')
-        m2 = torch.ones((size, small_size), dtype=dtype, device='cuda')
-        m2[size // 2:, :] = -1.0
-        b = torch.zeros((small_size,), dtype=dtype, device='cuda')
-        out = torch.addmm(b, m1, m2, beta=1.0)
-        self.assertEqual(out.sum().item(), 0.0)
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_precision
-        torch.backends.cuda.preferred_blas_library(orig)
+        with blas_library_context(backend):
+            torch.backends.cuda.preferred_blas_library(backend)
+            orig_precision = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+            m1 = torch.full((small_size, size), 65504.0, dtype=dtype, device='cuda')
+            m2 = torch.ones((size, small_size), dtype=dtype, device='cuda')
+            m2[size // 2:, :] = -1.0
+            b = torch.zeros((small_size,), dtype=dtype, device='cuda')
+            out = torch.addmm(b, m1, m2, beta=1.0)
+            self.assertEqual(out.sum().item(), 0.0)
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_precision

     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -184,8 +195,10 @@ class TestMatmulCuda(TestCase):
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, False, True)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, False, True)

     @onlyCUDA
     @skipIfRocm
@@ -479,49 +492,48 @@ class TestMatmulCuda(TestCase):
     def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        torch.backends.cuda.preferred_blas_library(backend)
-
-        def create_inputs(B=None):
-            if B is None:
-                a = torch.randn(M, K, device=device, dtype=dtype)
-                b = torch.randn(K, N, device=device, dtype=dtype)
-            else:
-                a = torch.randn(B, M, K, device=device, dtype=dtype)
-                b = torch.randn(B, K, N, device=device, dtype=dtype)
-            return a, b
-
-        a, b = create_inputs(batch_size)
-        a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
-
-        output_dtypes = [torch.float32]
-        if input_dtype != torch.float32:
-            output_dtypes.append(input_dtype)
-
-        for output_dtype in output_dtypes:
-            # Catch edge case of incompat with bfloat16 and major version < 8
-            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                if output_dtype == torch.bfloat16:
-                    continue
-
-                if batch_size:
-                    with self.assertRaises(RuntimeError):
-                        torch.bmm(a, b, out_dtype=output_dtype)
-                else:
-                    with self.assertRaises(RuntimeError):
-                        torch.mm(a, b, out_dtype=output_dtype)
-            else:
-                if batch_size:
-                    out = torch.bmm(a, b, out_dtype=output_dtype)
-                    baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
-                else:
-                    out = torch.mm(a, b, out_dtype=output_dtype)
-                    baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
-
-                self.assertEqual(out.dtype, output_dtype)
-                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+        with blas_library_context(backend):
+            def create_inputs(B=None):
+                if B is None:
+                    a = torch.randn(M, K, device=device, dtype=dtype)
+                    b = torch.randn(K, N, device=device, dtype=dtype)
+                else:
+                    a = torch.randn(B, M, K, device=device, dtype=dtype)
+                    b = torch.randn(B, K, N, device=device, dtype=dtype)
+                return a, b
+
+            a, b = create_inputs(batch_size)
+            a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
+
+            output_dtypes = [torch.float32]
+            if input_dtype != torch.float32:
+                output_dtypes.append(input_dtype)
+
+            for output_dtype in output_dtypes:
+                # Catch edge case of incompat with bfloat16 and major version < 8
+                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                    if output_dtype == torch.bfloat16:
+                        continue
+
+                    if batch_size:
+                        with self.assertRaises(RuntimeError):
+                            torch.bmm(a, b, out_dtype=output_dtype)
+                    else:
+                        with self.assertRaises(RuntimeError):
+                            torch.mm(a, b, out_dtype=output_dtype)
+                else:
+                    if batch_size:
+                        out = torch.bmm(a, b, out_dtype=output_dtype)
+                        baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
+                    else:
+                        out = torch.mm(a, b, out_dtype=output_dtype)
+                        baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
+
+                    self.assertEqual(out.dtype, output_dtype)
+                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)

     @onlyCUDA
@@ -535,51 +547,56 @@ class TestMatmulCuda(TestCase):
     def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        torch.backends.cuda.preferred_blas_library(backend)
-
-        def create_inputs(B=None):
-            if B is None:
-                a = torch.randn(M, K, device=device, dtype=dtype)
-                b = torch.randn(K, N, device=device, dtype=dtype)
-                c = torch.randn(M, N, device=device, dtype=dtype)
-            else:
-                a = torch.randn(B, M, K, device=device, dtype=dtype)
-                b = torch.randn(B, K, N, device=device, dtype=dtype)
-                c = torch.randn(B, M, N, device=device, dtype=dtype)
-            return a, b, c
-
-        a, b, c = create_inputs(batch_size)
-        a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
-
-        output_dtypes = [torch.float32]
-        if input_dtype != torch.float32:
-            output_dtypes.append(input_dtype)
-
-        for output_dtype in output_dtypes:
-            # Catch edge case of incompat with bfloat16 and major version < 8
-            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                if output_dtype == torch.bfloat16:
-                    continue
-
-                if batch_size:
-                    with self.assertRaises(RuntimeError):
-                        torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                else:
-                    with self.assertRaises(RuntimeError):
-                        torch.addmm(c, a, b, out_dtype=output_dtype)
-            else:
-                if batch_size:
-                    out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                    baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.baddbmm(c, a, b)
-                else:
-                    out = torch.addmm(c, a, b, out_dtype=output_dtype)
-                    baseline = torch.addmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.addmm(c, a, b)
-
-                self.assertEqual(out.dtype, output_dtype)
-                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+        with blas_library_context(backend):
+            def create_inputs(B=None):
+                if B is None:
+                    a = torch.randn(M, K, device=device, dtype=dtype)
+                    b = torch.randn(K, N, device=device, dtype=dtype)
+                    c = torch.randn(M, N, device=device, dtype=dtype)
+                else:
+                    a = torch.randn(B, M, K, device=device, dtype=dtype)
+                    b = torch.randn(B, K, N, device=device, dtype=dtype)
+                    c = torch.randn(B, M, N, device=device, dtype=dtype)
+                return a, b, c
+
+            a, b, c = create_inputs(batch_size)
+            a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
+
+            output_dtypes = [torch.float32]
+            if input_dtype != torch.float32:
+                output_dtypes.append(input_dtype)
+
+            for output_dtype in output_dtypes:
+                # Catch edge case of incompat with bfloat16 and major version < 8
+                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                    if output_dtype == torch.bfloat16:
+                        continue
+
+                    if batch_size:
+                        with self.assertRaises(RuntimeError):
+                            torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                    else:
+                        with self.assertRaises(RuntimeError):
+                            torch.addmm(c, a, b, out_dtype=output_dtype)
+                else:
+                    if batch_size:
+                        out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                        if output_dtype == torch.float32:
+                            baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32)
+                        else:
+                            baseline = torch.baddbmm(c, a, b)
+                    else:
+                        out = torch.addmm(c, a, b, out_dtype=output_dtype)
+                        if output_dtype == torch.float32:
+                            baseline = torch.addmm(c_fp32, a_fp32, b_fp32)
+                        else:
+                            baseline = torch.addmm(c, a, b)
+
+                    self.assertEqual(out.dtype, output_dtype)
+                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)

     @onlyCUDA
@@ -590,35 +607,36 @@ class TestMatmulCuda(TestCase):
         M, N, K = 32, 32, 32
         device = "cuda"
         dtype = torch.float16
-        torch.backends.cuda.preferred_blas_library(backend)
-        orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
-        torch.backends.cuda.matmul.allow_fp16_accumulation = True
-
-        def create_inputs():
-            a = torch.randn(M, K, device=device, dtype=dtype)
-            b = torch.randn(K, N, device=device, dtype=dtype)
-            c = torch.randn(M, N, device=device, dtype=dtype)
-            return a, b, c
-
-        def expand(tensor):
-            return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
-
-        a, b, c = create_inputs()
-
-        with self.assertRaises(Exception):
-            torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
-
-        with self.assertRaises(Exception):
-            torch.addmm(c, a, b, out_dtype=torch.float32)
-
-        with self.assertRaises(Exception):
-            torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
-
-        with self.assertRaises(Exception):
-            torch.mm(a, b, out_dtype=torch.float32)
-
-        torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
+        with blas_library_context(backend):
+            torch.backends.cuda.preferred_blas_library(backend)
+            orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
+            torch.backends.cuda.matmul.allow_fp16_accumulation = True
+
+            def create_inputs():
+                a = torch.randn(M, K, device=device, dtype=dtype)
+                b = torch.randn(K, N, device=device, dtype=dtype)
+                c = torch.randn(M, N, device=device, dtype=dtype)
+                return a, b, c
+
+            def expand(tensor):
+                return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
+
+            a, b, c = create_inputs()
+
+            with self.assertRaises(Exception):
+                torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
+
+            with self.assertRaises(Exception):
+                torch.addmm(c, a, b, out_dtype=torch.float32)
+
+            with self.assertRaises(Exception):
+                torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
+
+            with self.assertRaises(Exception):
+                torch.mm(a, b, out_dtype=torch.float32)
+
+            torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum


 f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
 mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"