[CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests (#153655)

Some tests may not set the preferred backend, which leads to unexpected behavior when multiple tests are run together vs. standalone. Tests that should exercise both backends should explicitly parametrize this setting.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153655
Approved by: https://github.com/ngimel

parent 084c4aa614
commit 3bde364996
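The fix adds a small save/switch/restore context manager around torch.backends.cuda.preferred_blas_library, which the diff uses both as a getter (called with no argument) and as a setter (called with a backend name). The sketch below restates that pattern outside the diff; the parametrized test at the bottom is hypothetical and only illustrates how a backend-sensitive test is expected to combine the helper with @parametrize so each backend runs in isolation and the previous setting is restored afterwards.

# Minimal sketch of the pattern this commit introduces; the example test in
# the comment block is a hypothetical usage, not part of the diff.
import contextlib

import torch


@contextlib.contextmanager
def blas_library_context(backend):
    # Called with no argument, preferred_blas_library() returns the current
    # setting; called with an argument, it switches the backend.
    prev_backend = torch.backends.cuda.preferred_blas_library()
    torch.backends.cuda.preferred_blas_library(backend)
    try:
        yield
    finally:
        # Restore the previous backend even if the test body raises, so the
        # setting cannot leak into later tests in the same process.
        torch.backends.cuda.preferred_blas_library(prev_backend)


# Hypothetical usage inside a test class such as TestMatmulCuda:
#
#     @parametrize("backend", ["cublas", "cublaslt"])
#     def test_some_matmul(self, backend):
#         with blas_library_context(backend):
#             ...  # exercise the matmul path under the selected backend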
@@ -62,6 +62,15 @@ if TEST_CUDA:
 assert torch.get_default_dtype() is torch.float32
 
 
+@contextlib.contextmanager
+def blas_library_context(backend):
+    prev_backend = torch.backends.cuda.preferred_blas_library()
+    torch.backends.cuda.preferred_blas_library(backend)
+    try:
+        yield
+    finally:
+        torch.backends.cuda.preferred_blas_library(prev_backend)
+
 class TestMatmulCuda(TestCase):
     def setUp(self):
         super().setUp()
@@ -141,8 +150,10 @@ class TestMatmulCuda(TestCase):
                                     torch.float32: xtol(atol=1e-1, rtol=1e-1)})
     @dtypes(torch.float16, torch.bfloat16, torch.float32)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, False)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, False)
 
     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -151,8 +162,10 @@ class TestMatmulCuda(TestCase):
                                     torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, True)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, True)
 
     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -161,8 +174,10 @@ class TestMatmulCuda(TestCase):
                                     torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, False, True)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, False, True)
 
     @onlyCUDA
     @skipIfRocm
@@ -456,49 +471,48 @@ class TestMatmulCuda(TestCase):
     def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        torch.backends.cuda.preferred_blas_library(backend)
-
-        def create_inputs(B=None):
-            if B is None:
-                a = torch.randn(M, K, device=device, dtype=dtype)
-                b = torch.randn(K, N, device=device, dtype=dtype)
-            else:
-                a = torch.randn(B, M, K, device=device, dtype=dtype)
-                b = torch.randn(B, K, N, device=device, dtype=dtype)
-            return a, b
-
-        a, b = create_inputs(batch_size)
-
-        a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
-
-        output_dtypes = [torch.float32]
-
-        if input_dtype != torch.float32:
-            output_dtypes.append(input_dtype)
-
-        for output_dtype in output_dtypes:
-            # Catch edge case of incompat with bfloat16 and major version < 8
-            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                if output_dtype == torch.bfloat16:
-                    continue
-
-                if batch_size:
-                    with self.assertRaises(RuntimeError):
-                        torch.bmm(a, b, out_dtype=output_dtype)
-                else:
-                    with self.assertRaises(RuntimeError):
-                        torch.mm(a, b, out_dtype=output_dtype)
-            else:
-                if batch_size:
-                    out = torch.bmm(a, b, out_dtype=output_dtype)
-                    baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
-                else:
-                    out = torch.mm(a, b, out_dtype=output_dtype)
-                    baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
-
-            self.assertEqual(out.dtype, output_dtype)
-
-            torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+        with blas_library_context(backend):
+            def create_inputs(B=None):
+                if B is None:
+                    a = torch.randn(M, K, device=device, dtype=dtype)
+                    b = torch.randn(K, N, device=device, dtype=dtype)
+                else:
+                    a = torch.randn(B, M, K, device=device, dtype=dtype)
+                    b = torch.randn(B, K, N, device=device, dtype=dtype)
+                return a, b
+
+            a, b = create_inputs(batch_size)
+
+            a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
+
+            output_dtypes = [torch.float32]
+
+            if input_dtype != torch.float32:
+                output_dtypes.append(input_dtype)
+
+            for output_dtype in output_dtypes:
+                # Catch edge case of incompat with bfloat16 and major version < 8
+                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                    if output_dtype == torch.bfloat16:
+                        continue
+
+                    if batch_size:
+                        with self.assertRaises(RuntimeError):
+                            torch.bmm(a, b, out_dtype=output_dtype)
+                    else:
+                        with self.assertRaises(RuntimeError):
+                            torch.mm(a, b, out_dtype=output_dtype)
+                else:
+                    if batch_size:
+                        out = torch.bmm(a, b, out_dtype=output_dtype)
+                        baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
+                    else:
+                        out = torch.mm(a, b, out_dtype=output_dtype)
+                        baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
+
+                self.assertEqual(out.dtype, output_dtype)
+
+                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
 
 
     @onlyCUDA
@@ -512,51 +526,56 @@ class TestMatmulCuda(TestCase):
     def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        torch.backends.cuda.preferred_blas_library(backend)
-
-        def create_inputs(B=None):
-            if B is None:
-                a = torch.randn(M, K, device=device, dtype=dtype)
-                b = torch.randn(K, N, device=device, dtype=dtype)
-                c = torch.randn(M, N, device=device, dtype=dtype)
-            else:
-                a = torch.randn(B, M, K, device=device, dtype=dtype)
-                b = torch.randn(B, K, N, device=device, dtype=dtype)
-                c = torch.randn(B, M, N, device=device, dtype=dtype)
-
-            return a, b, c
-
-        a, b, c = create_inputs(batch_size)
-
-        a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
-
-        output_dtypes = [torch.float32]
-
-        if input_dtype != torch.float32:
-            output_dtypes.append(input_dtype)
-
-        for output_dtype in output_dtypes:
-            # Catch edge case of incompat with bfloat16 and major version < 8
-            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                if output_dtype == torch.bfloat16:
-                    continue
-
-                if batch_size:
-                    with self.assertRaises(RuntimeError):
-                        torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                else:
-                    with self.assertRaises(RuntimeError):
-                        torch.addmm(c, a, b, out_dtype=output_dtype)
-            else:
-                if batch_size:
-                    out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                    baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.baddbmm(c, a, b)
-                else:
-                    out = torch.addmm(c, a, b, out_dtype=output_dtype)
-                    baseline = torch.addmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.addmm(c, a, b)
-
-            self.assertEqual(out.dtype, output_dtype)
-            torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+        with blas_library_context(backend):
+            def create_inputs(B=None):
+                if B is None:
+                    a = torch.randn(M, K, device=device, dtype=dtype)
+                    b = torch.randn(K, N, device=device, dtype=dtype)
+                    c = torch.randn(M, N, device=device, dtype=dtype)
+                else:
+                    a = torch.randn(B, M, K, device=device, dtype=dtype)
+                    b = torch.randn(B, K, N, device=device, dtype=dtype)
+                    c = torch.randn(B, M, N, device=device, dtype=dtype)
+
+                return a, b, c
+
+            a, b, c = create_inputs(batch_size)
+
+            a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
+
+            output_dtypes = [torch.float32]
+
+            if input_dtype != torch.float32:
+                output_dtypes.append(input_dtype)
+
+            for output_dtype in output_dtypes:
+                # Catch edge case of incompat with bfloat16 and major version < 8
+                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                    if output_dtype == torch.bfloat16:
+                        continue
+
+                    if batch_size:
+                        with self.assertRaises(RuntimeError):
+                            torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                    else:
+                        with self.assertRaises(RuntimeError):
+                            torch.addmm(c, a, b, out_dtype=output_dtype)
+                else:
+                    if batch_size:
+                        out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                        if output_dtype == torch.float32:
+                            baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32)
+                        else:
+                            baseline = torch.baddbmm(c, a, b)
+                    else:
+                        out = torch.addmm(c, a, b, out_dtype=output_dtype)
+                        if output_dtype == torch.float32:
+                            baseline = torch.addmm(c_fp32, a_fp32, b_fp32)
+                        else:
+                            baseline = torch.addmm(c, a, b)
+
+                self.assertEqual(out.dtype, output_dtype)
+                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
 
 
     @onlyCUDA
@@ -567,35 +586,36 @@ class TestMatmulCuda(TestCase):
         M, N, K = 32, 32, 32
         device = "cuda"
         dtype = torch.float16
-        torch.backends.cuda.preferred_blas_library(backend)
-
-        orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
-        torch.backends.cuda.matmul.allow_fp16_accumulation = True
-
-        def create_inputs():
-            a = torch.randn(M, K, device=device, dtype=dtype)
-            b = torch.randn(K, N, device=device, dtype=dtype)
-            c = torch.randn(M, N, device=device, dtype=dtype)
-            return a, b, c
-
-        def expand(tensor):
-            return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
-
-        a, b, c = create_inputs()
-
-        with self.assertRaises(Exception):
-            torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
-
-        with self.assertRaises(Exception):
-            torch.addmm(c, a, b, out_dtype=torch.float32)
-
-        with self.assertRaises(Exception):
-            torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
-
-        with self.assertRaises(Exception):
-            torch.mm(a, b, out_dtype=torch.float32)
-
-        torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
+        with blas_library_context(backend):
+            torch.backends.cuda.preferred_blas_library(backend)
+
+            orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
+            torch.backends.cuda.matmul.allow_fp16_accumulation = True
+
+            def create_inputs():
+                a = torch.randn(M, K, device=device, dtype=dtype)
+                b = torch.randn(K, N, device=device, dtype=dtype)
+                c = torch.randn(M, N, device=device, dtype=dtype)
+                return a, b, c
+
+            def expand(tensor):
+                return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
+
+            a, b, c = create_inputs()
+
+            with self.assertRaises(Exception):
+                torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
+
+            with self.assertRaises(Exception):
+                torch.addmm(c, a, b, out_dtype=torch.float32)
+
+            with self.assertRaises(Exception):
+                torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
+
+            with self.assertRaises(Exception):
+                torch.mm(a, b, out_dtype=torch.float32)
+
+            torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
 
 f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
 mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
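For context on why the restore step matters: without it, a test that switches the preferred BLAS library changes the behavior of every test that happens to run after it in the same process, which is the "multiple tests vs. standalone" mismatch described in the commit message. The hypothetical tests below (not from the diff) sketch that failure mode.

# Hypothetical tests illustrating the cross-test pollution the helper avoids.
import torch


def test_a_forces_lt():
    # Switches the process-wide setting and never restores it.
    torch.backends.cuda.preferred_blas_library("cublaslt")


def test_b_assumes_default():
    # Run standalone, this sees the default backend; run after
    # test_a_forces_lt in the same process, it silently uses cuBLASLt.
    current = torch.backends.cuda.preferred_blas_library()
    print(current)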