diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index 9cd9af90c97..593c78f74d4 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -62,6 +62,15 @@ if TEST_CUDA:
     assert torch.get_default_dtype() is torch.float32
 
+@contextlib.contextmanager
+def blas_library_context(backend):
+    prev_backend = torch.backends.cuda.preferred_blas_library()
+    torch.backends.cuda.preferred_blas_library(backend)
+    try:
+        yield
+    finally:
+        torch.backends.cuda.preferred_blas_library(prev_backend)
+
 
 class TestMatmulCuda(TestCase):
     def setUp(self):
         super().setUp()
@@ -141,8 +150,10 @@ class TestMatmulCuda(TestCase):
                         torch.float32: xtol(atol=1e-1, rtol=1e-1)})
     @dtypes(torch.float16, torch.bfloat16, torch.float32)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, False)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, False)
 
     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -151,31 +162,31 @@ class TestMatmulCuda(TestCase):
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
     @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, True)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, True)
 
     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
     @dtypes(torch.float16)
     # m == 4 chooses OUTPUT_TYPE reduction on H200
-    # m == 8 chooses OUTOUT_TYPE reduction on A100
+    # m == 8 chooses OUTPUT_TYPE reduction on A100
     @parametrize("small_size", [4, 8])
     @parametrize("size", [32768])
     @parametrize("backend", ["cublaslt", "cublas"])
     def test_cublas_addmm_no_reduced_precision(self, small_size: int, size: int, dtype: torch.dtype, backend):
-        # TODO(eqy): replace with contextlib once that is merged
-        orig = torch.backends.cuda.preferred_blas_library()
-        torch.backends.cuda.preferred_blas_library(backend)
-        orig_precision = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-        m1 = torch.full((small_size, size), 65504.0, dtype=dtype, device='cuda')
-        m2 = torch.ones((size, small_size), dtype=dtype, device='cuda')
-        m2[size // 2:, :] = -1.0
-        b = torch.zeros((small_size,), dtype=dtype, device='cuda')
-        out = torch.addmm(b, m1, m2, beta=1.0)
-        self.assertEqual(out.sum().item(), 0.0)
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_precision
-        torch.backends.cuda.preferred_blas_library(orig)
+        with blas_library_context(backend):
+            torch.backends.cuda.preferred_blas_library(backend)
+            orig_precision = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+            m1 = torch.full((small_size, size), 65504.0, dtype=dtype, device='cuda')
+            m2 = torch.ones((size, small_size), dtype=dtype, device='cuda')
+            m2[size // 2:, :] = -1.0
+            b = torch.zeros((small_size,), dtype=dtype, device='cuda')
+            out = torch.addmm(b, m1, m2, beta=1.0)
+            self.assertEqual(out.sum().item(), 0.0)
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_precision
 
     @onlyCUDA
     @skipIfRocmVersionLessThan((5, 2))
@@ -184,8 +195,10 @@ class TestMatmulCuda(TestCase):
                         torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
     @parametrize("size", [100, 1000, 10000])
-    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
-        self.cublas_addmm(size, dtype, False, True)
+    @parametrize("backend", ["cublas", "cublaslt"])
+    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype, backend):
+        with blas_library_context(backend):
+            self.cublas_addmm(size, dtype, False, True)
 
     @onlyCUDA
     @skipIfRocm
@@ -479,49 +492,48 @@ class TestMatmulCuda(TestCase):
     def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        torch.backends.cuda.preferred_blas_library(backend)
-
-        def create_inputs(B=None):
-            if B is None:
-                a = torch.randn(M, K, device=device, dtype=dtype)
-                b = torch.randn(K, N, device=device, dtype=dtype)
-            else:
-                a = torch.randn(B, M, K, device=device, dtype=dtype)
-                b = torch.randn(B, K, N, device=device, dtype=dtype)
-            return a, b
-
-        a, b = create_inputs(batch_size)
-
-        a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
-
-        output_dtypes = [torch.float32]
-
-        if input_dtype != torch.float32:
-            output_dtypes.append(input_dtype)
-
-        for output_dtype in output_dtypes:
-            # Catch edge case of incompat with bfloat16 and major version < 8
-            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                if output_dtype == torch.bfloat16:
-                    continue
-
-                if batch_size:
-                    with self.assertRaises(RuntimeError):
-                        torch.bmm(a, b, out_dtype=output_dtype)
+        with blas_library_context(backend):
+            def create_inputs(B=None):
+                if B is None:
+                    a = torch.randn(M, K, device=device, dtype=dtype)
+                    b = torch.randn(K, N, device=device, dtype=dtype)
                 else:
-                    with self.assertRaises(RuntimeError):
-                        torch.mm(a, b, out_dtype=output_dtype)
-            else:
-                if batch_size:
-                    out = torch.bmm(a, b, out_dtype=output_dtype)
-                    baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
+                    a = torch.randn(B, M, K, device=device, dtype=dtype)
+                    b = torch.randn(B, K, N, device=device, dtype=dtype)
+                return a, b
+
+            a, b = create_inputs(batch_size)
+
+            a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
+
+            output_dtypes = [torch.float32]
+
+            if input_dtype != torch.float32:
+                output_dtypes.append(input_dtype)
+
+            for output_dtype in output_dtypes:
+                # Catch edge case of incompat with bfloat16 and major version < 8
+                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                    if output_dtype == torch.bfloat16:
+                        continue
+
+                    if batch_size:
+                        with self.assertRaises(RuntimeError):
+                            torch.bmm(a, b, out_dtype=output_dtype)
+                    else:
+                        with self.assertRaises(RuntimeError):
+                            torch.mm(a, b, out_dtype=output_dtype)
                 else:
-                    out = torch.mm(a, b, out_dtype=output_dtype)
-                    baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
+                    if batch_size:
+                        out = torch.bmm(a, b, out_dtype=output_dtype)
+                        baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
+                    else:
+                        out = torch.mm(a, b, out_dtype=output_dtype)
+                        baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
 
-                self.assertEqual(out.dtype, output_dtype)
+                    self.assertEqual(out.dtype, output_dtype)
 
-                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
 
 
     @onlyCUDA
@@ -535,51 +547,56 @@ class TestMatmulCuda(TestCase):
     def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
         device = "cuda"
         dtype = input_dtype
-        torch.backends.cuda.preferred_blas_library(backend)
-
-        def create_inputs(B=None):
-            if B is None:
-                a = torch.randn(M, K, device=device, dtype=dtype)
-                b = torch.randn(K, N, device=device, dtype=dtype)
-                c = torch.randn(M, N, device=device, dtype=dtype)
-            else:
-                a = torch.randn(B, M, K, device=device, dtype=dtype)
-                b = torch.randn(B, K, N, device=device, dtype=dtype)
-                c = torch.randn(B, M, N, device=device, dtype=dtype)
-
-            return a, b, c
-
-        a, b, c = create_inputs(batch_size)
-
-        a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
-
-        output_dtypes = [torch.float32]
-
-        if input_dtype != torch.float32:
-            output_dtypes.append(input_dtype)
-
-        for output_dtype in output_dtypes:
-            # Catch edge case of incompat with bfloat16 and major version < 8
-            if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
-                if output_dtype == torch.bfloat16:
-                    continue
-
-                if batch_size:
-                    with self.assertRaises(RuntimeError):
-                        torch.baddbmm(c, a, b, out_dtype=output_dtype)
+        with blas_library_context(backend):
+            def create_inputs(B=None):
+                if B is None:
+                    a = torch.randn(M, K, device=device, dtype=dtype)
+                    b = torch.randn(K, N, device=device, dtype=dtype)
+                    c = torch.randn(M, N, device=device, dtype=dtype)
                 else:
-                    with self.assertRaises(RuntimeError):
-                        torch.addmm(c, a, b, out_dtype=output_dtype)
-            else:
-                if batch_size:
-                    out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
-                    baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.baddbmm(c, a, b)
-                else:
-                    out = torch.addmm(c, a, b, out_dtype=output_dtype)
-                    baseline = torch.addmm(c_fp32, a_fp32, b_fp32) if output_dtype == torch.float32 else torch.addmm(c, a, b)
+                    a = torch.randn(B, M, K, device=device, dtype=dtype)
+                    b = torch.randn(B, K, N, device=device, dtype=dtype)
+                    c = torch.randn(B, M, N, device=device, dtype=dtype)
 
-                self.assertEqual(out.dtype, output_dtype)
-                torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
+                return a, b, c
+
+            a, b, c = create_inputs(batch_size)
+
+            a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)
+
+            output_dtypes = [torch.float32]
+
+            if input_dtype != torch.float32:
+                output_dtypes.append(input_dtype)
+
+            for output_dtype in output_dtypes:
+                # Catch edge case of incompat with bfloat16 and major version < 8
+                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
+                    if output_dtype == torch.bfloat16:
+                        continue
+
+                    if batch_size:
+                        with self.assertRaises(RuntimeError):
+                            torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                    else:
+                        with self.assertRaises(RuntimeError):
+                            torch.addmm(c, a, b, out_dtype=output_dtype)
+                else:
+                    if batch_size:
+                        out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
+                        if output_dtype == torch.float32:
+                            baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32)
+                        else:
+                            baseline = torch.baddbmm(c, a, b)
+                    else:
+                        out = torch.addmm(c, a, b, out_dtype=output_dtype)
+                        if output_dtype == torch.float32:
+                            baseline = torch.addmm(c_fp32, a_fp32, b_fp32)
+                        else:
+                            baseline = torch.addmm(c, a, b)
+
+                    self.assertEqual(out.dtype, output_dtype)
+                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)
 
 
     @onlyCUDA
@@ -590,35 +607,36 @@ class TestMatmulCuda(TestCase):
         M, N, K = 32, 32, 32
         device = "cuda"
         dtype = torch.float16
-        torch.backends.cuda.preferred_blas_library(backend)
+        with blas_library_context(backend):
+            torch.backends.cuda.preferred_blas_library(backend)
 
-        orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
-        torch.backends.cuda.matmul.allow_fp16_accumulation = True
+            orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
+            torch.backends.cuda.matmul.allow_fp16_accumulation = True
 
-        def create_inputs():
-            a = torch.randn(M, K, device=device, dtype=dtype)
-            b = torch.randn(K, N, device=device, dtype=dtype)
-            c = torch.randn(M, N, device=device, dtype=dtype)
-            return a, b, c
+            def create_inputs():
+                a = torch.randn(M, K, device=device, dtype=dtype)
+                b = torch.randn(K, N, device=device, dtype=dtype)
+                c = torch.randn(M, N, device=device, dtype=dtype)
+                return a, b, c
 
-        def expand(tensor):
-            return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
+            def expand(tensor):
+                return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)
 
-        a, b, c = create_inputs()
+            a, b, c = create_inputs()
 
-        with self.assertRaises(Exception):
-            torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
+            with self.assertRaises(Exception):
+                torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
 
-        with self.assertRaises(Exception):
-            torch.addmm(c, a, b, out_dtype=torch.float32)
+            with self.assertRaises(Exception):
+                torch.addmm(c, a, b, out_dtype=torch.float32)
 
-        with self.assertRaises(Exception):
-            torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
+            with self.assertRaises(Exception):
+                torch.bmm(expand(a,), expand(b), out_dtype=torch.float32)
 
-        with self.assertRaises(Exception):
-            torch.mm(a, b, out_dtype=torch.float32)
+            with self.assertRaises(Exception):
+                torch.mm(a, b, out_dtype=torch.float32)
 
-        torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
+            torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum
 
 f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
 mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
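
Note: the snippet below is a standalone sketch, not part of the patch. It shows how the blas_library_context helper introduced above is intended to be used; the tensor shapes and the choice of "cublaslt" are illustrative assumptions only.

import contextlib

import torch


@contextlib.contextmanager
def blas_library_context(backend):
    # Same pattern as the helper added in the patch: remember the currently
    # preferred BLAS library, switch to the requested backend, and always
    # restore the previous choice when the block exits (even on exceptions).
    prev_backend = torch.backends.cuda.preferred_blas_library()
    torch.backends.cuda.preferred_blas_library(backend)
    try:
        yield
    finally:
        torch.backends.cuda.preferred_blas_library(prev_backend)


if torch.cuda.is_available():
    # Illustrative shapes: (128, 64) @ (64, 32) + (128, 32) bias.
    a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
    b = torch.randn(64, 32, device="cuda", dtype=torch.float16)
    bias = torch.zeros(128, 32, device="cuda", dtype=torch.float16)
    with blas_library_context("cublaslt"):
        out = torch.addmm(bias, a, b)  # GEMM runs with cuBLASLt preferred
    # Outside the block the previously preferred BLAS library is active again.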