# Owner(s): ["module: linear algebra"] import contextlib import json import math import re import tempfile import unittest from itertools import product from functools import partial from typing import Optional import torch from torch.quantization._quantized_conversions import ( pack_int4_to_int8, quantized_weight_reorder_for_mixed_dtypes_linear_cutlass, ) from torch.testing import make_tensor from torch.testing._internal.common_cuda import ( SM53OrLater, SM89OrLater, SM90OrLater, _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8, PLATFORM_SUPPORTS_MX_GEMM ) from torch.testing._internal.common_device_type import ( dtypes, instantiate_device_type_tests, onlyCUDA, tol as xtol, toleranceOverride, ) from torch.testing._internal.common_utils import ( IS_ARM64, IS_JETSON, IS_WINDOWS, parametrize, run_tests, skipIfRocm, skipIfRocmVersionLessThan, TEST_CUDA, TEST_WITH_ROCM, TestCase, ) _IS_SM8X = False if TEST_CUDA: _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8 # Protects against includes accidentally setting the default dtype assert torch.get_default_dtype() is torch.float32 @unittest.skipIf(IS_ARM64, "Issue with numpy version on arm") class TestMatmulCuda(TestCase): def setUp(self): super(self.__class__, self).setUp() torch.backends.cuda.matmul.allow_tf32 = False def tearDown(self): torch.backends.cuda.matmul.allow_tf32 = True super(self.__class__, self).tearDown() def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False): # # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between # results from the CUDA invocation of torch.addmm and the CPU invocation # (which does not use CUDA backend). # # Get dims n, m, p = (size + 1, size, size + 2) # Disable reduced precision reductions in BFloat16 to bypass some kernels # which fail the threshold check orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision torch.backends.cuda.matmul.allow_fp16_accumulation = fp16_accumulate # Make random tensors on CPU (seed set on common_utils.py import) # (Not using numpy because it does not support bfloat16) make_arg = partial(make_tensor, dtype=dtype, device="cpu") m_beta = make_arg(1) m_input = make_arg((n, p)) m_1 = make_arg((n, m)) m_2 = make_arg((m, p)) # scale to abate overflows in fp16 accum if fp16_accumulate: m_1 = m_1 / 100 m_2 = m_2 / 100 # *(B)FLOAT16 Special Handling* # Backend does not tensorize float16 on CPU, # and bloat16 may present accuracy issues, # so convert to float32 for these cases # (but keep same for other types, e.g. 
        if dtype == torch.float16 or dtype == torch.bfloat16:
            m_beta = m_beta.to(dtype=torch.float32)
            m_input = m_input.to(dtype=torch.float32)
            m_1 = m_1.to(dtype=torch.float32)
            m_2 = m_2.to(dtype=torch.float32)
        # Get CPU result
        res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
        # *(B)FLOAT16 Special Handling*
        # Convert back to (b)float16
        if dtype == torch.float16 or dtype == torch.bfloat16:
            m_beta = m_beta.to(dtype=dtype)
            m_input = m_input.to(dtype=dtype)
            m_1 = m_1.to(dtype=dtype)
            m_2 = m_2.to(dtype=dtype)
            res_cpu = res_cpu.to(dtype=dtype)
        # Move arg tensors to CUDA
        m_beta = m_beta.to("cuda")
        m_input = m_input.to("cuda")
        m_1 = m_1.to("cuda")
        m_2 = m_2.to("cuda")
        # Get CUDA result
        res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
        # Move to CPU for comparison
        res_cuda = res_cuda.to("cpu")
        # Compare
        self.assertEqual(res_cpu, res_cuda)
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16
        torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1),
                        torch.bfloat16: xtol(atol=1e-1, rtol=1e-1),
                        torch.float32: xtol(atol=1e-1, rtol=1e-1)})
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, False)

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, True)

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, False, True)

    @onlyCUDA
    @skipIfRocm
    def test_cublas_and_lt_reduced_precision_fp16_accumulate(self):
        orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
        torch.backends.cuda.matmul.allow_fp16_accumulation = True
        x = torch.rand(32, 512, 512, device='cuda', dtype=torch.half)
        w = torch.rand(512, 512, device='cuda', dtype=torch.half)
        b = torch.rand(512, device='cuda', dtype=torch.half)
        out = torch.nn.functional.linear(x, w, b)
        out_cpu = torch.nn.functional.linear(x.cpu(), w.cpu(), b.cpu())
        self.assertEqual(out, out_cpu, atol=5e-3, rtol=8e-3)

        a = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
        b = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
        c = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
        out = torch.baddbmm(a, b, c)
        out_cpu = torch.baddbmm(a.cpu(), b.cpu(), c.cpu())
        self.assertEqual(out, out_cpu, atol=1e-3, rtol=5e-3)
        torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate

    @onlyCUDA
    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
    @dtypes(torch.float16)
    def test_cublas_addmm_alignment(self, dtype):
        device = 'cuda'
        # perturb X, A, or B alignment
        for idx in range(0, 3):
            for offset in range(1, 3):
                offsets = [0, 0, 0]
                offsets[idx] = offset
                x_offset, a_offset, b_offset = offsets
                A = torch.rand((5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device)
                A = A[a_offset:].reshape(5120, 2560)
                X = torch.rand((26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device)
                X = X[x_offset:].reshape(26, 1, 2560)
                B = torch.rand((5120 + b_offset), requires_grad=True, dtype=dtype, device=device)
                B = B[b_offset:].reshape(5120)
                out = torch.nn.functional.linear(X, A, B)
                self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B)

    @onlyCUDA
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)})
    @dtypes(*([torch.float32, torch.float16] +
              ([torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])))
    @parametrize(
        "batch_size, N, M, P",
        [(2, 100, 100, 100),
         (2, 1000, 1000, 1000),
         (1, 10000, 1000, 10000),
         (1, 10000, 10000, 10000)],
        name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}",
    )
    @skipIfRocm
    def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype):
        cpu_dtype = dtype
        if dtype == torch.float16 or dtype == torch.bfloat16:
            cpu_dtype = torch.float32

        M1 = torch.rand((N, M), device=device, dtype=dtype)
        M2 = torch.rand((M, P), device=device, dtype=dtype)
        A = torch.rand((N, P), device=device, dtype=dtype)

        def _convert_to_cpu(t):
            return t.to(device='cpu', dtype=cpu_dtype)
        M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A])

        # linear
        out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype)
        out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu()
        self.assertEqual(out1_cpu, out1_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype)
            out1_eye_gpu = torch.nn.functional.linear(M1, M2_eye.t(), torch.zeros_like(A))
            self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu())

        # baddbmm
        def _expand_to_batch(t: torch.Tensor):
            return t.expand((batch_size, ) + t.size())
        alpha, beta = 1.0, 1.0
        M1, M2, A, M1_cpu, M2_cpu, A_cpu = map(_expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu])

        out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to(dtype=dtype)
        out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu()
        self.assertEqual(out2_cpu, out2_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N)
            out2_eye_gpu = torch.baddbmm(torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha)
            self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu())

        # cross comparison
        self.assertEqual(out1_gpu, out2_gpu[0])


f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"

if torch.version.hip and 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName:
    e4m3_type = torch.float8_e4m3fnuz
    e5m2_type = torch.float8_e5m2fnuz
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
    e4m3_type = torch.float8_e4m3fn
    e5m2_type = torch.float8_e5m2
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max

# avoid division by zero when calculating scale
EPS = 1e-12


def amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
    """Converts the amax value of a tensor to the fp8 scale.

    Args:
        amax: The amax value of the tensor.
        float8_dtype: the float8 dtype.
        orig_dtype: The original dtype of the tensor.
    """
    scale = torch.empty_like(amax, dtype=torch.float32)
    if float8_dtype == e4m3_type:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
    elif float8_dtype == e5m2_type:
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
    # Ensure the scale is representable in float16,
    # this helps when amax is small. We are assuming that we don't need
    # to care about this for float32/bfloat16
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=torch.finfo(torch.float16).max)
    scale.copy_(res)
    return scale


def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
    if dim is None:
        amax = torch.max(torch.abs(x))
    else:
        amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values
    return amax_to_scale(amax, float8_dtype, x.dtype)
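
# Illustrative arithmetic for the two helpers above (comments only, not
# executed; values assume the CUDA e4m3fn variant where E4M3_MAX_POS == 448.0):
#
#   scale = amax_to_scale(torch.tensor(2.0), e4m3_type, torch.float32)
#   # -> 448.0 / 2.0 == 224.0, so multiplying a tensor whose amax is 2.0 by
#   #    this scale maps its largest-magnitude element onto the largest
#   #    representable e4m3 magnitude before the cast to fp8.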
""" scale = torch.empty_like(amax, dtype=torch.float32) if float8_dtype == e4m3_type: res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) elif float8_dtype == e5m2_type: res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") # Ensure the scale is representable in float16, # this helps when amax is small. We are assuming that we don't need # to care about this for float32/bfloat16 if orig_dtype is torch.float16: res = torch.clamp(res, max=torch.finfo(torch.float16).max) scale.copy_(res) return scale def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): if dim is None: amax = torch.max(torch.abs(x)) else: amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values return amax_to_scale(amax, float8_dtype, x.dtype) def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: # naive implementation: dq -> op -> q x_fp32 = x.to(torch.float) / x_scale y_fp32 = y.to(torch.float) / y_scale out_fp32 = torch.mm(x_fp32, y_fp32) return out_fp32.to(out_dtype) def addmm_float8_unwrapped( a_data: torch.Tensor, a_scale: torch.Tensor, b_data: torch.Tensor, b_scale: torch.tensor, output_dtype: torch.dtype, output_scale: Optional[torch.Tensor], bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: a_inverse_scale = a_scale.reciprocal() b_inverse_scale = b_scale.reciprocal() if output_dtype == torch.float32 and bias is not None: # Bias is not supported by _scaled_mm when output is fp32 output = torch._scaled_mm( a_data, b_data, scale_a=a_inverse_scale, scale_b=b_inverse_scale, scale_result=output_scale, out_dtype=output_dtype, ) output += bias return output output = torch._scaled_mm( a_data, b_data, bias=bias, scale_a=a_inverse_scale, scale_b=b_inverse_scale, scale_result=output_scale, out_dtype=output_dtype, ) return output def mm_float8( a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, output_dtype: torch.dtype, # output dtype output_scale: Optional[torch.Tensor] = None, # output scale, precomputed ) -> torch.Tensor: return addmm_float8_unwrapped( a, a_scale, b, b_scale, output_dtype, output_scale ) def to_fp8_saturated( x: torch.Tensor, fp8_dtype: torch.dtype ): if fp8_dtype == e4m3_type: x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) elif fp8_dtype == e5m2_type: x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS) else: raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}") return x.to(fp8_dtype) # copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py def ceil_div(a, b): return (a + b - 1) // b def to_blocked(input_matrix) -> torch.Tensor: """ Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern. 
# copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py
def ceil_div(a, b):
    return (a + b - 1) // b


def to_blocked(input_matrix) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

    See:
        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Input tensor of shape (H, W)

    Returns:
        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
    """
    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Calculate the padded shape
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    padded = input_matrix
    # Ideally we would use torch.nn.functional.pad but it doesn't support float8_e8m0fnu for now
    if (rows, cols) != (padded_rows, padded_cols):
        padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype)
        padded[:rows, :cols] = input_matrix

    # Rearrange the blocks
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    return rearranged.flatten()
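
# Illustrative shape walk-through for to_blocked (comments only): a (130, 5)
# scale matrix pads to (256, 8), i.e. 2 row blocks x 2 col blocks, and each
# 128x4 block is laid out as a (32, 16) tile, so the result flattens to
# 2 * 2 * 32 * 16 == 2048 elements == 32*ceil_div(130, 128) * 16*ceil_div(5, 4).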
def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Computes the error between two tensors in dB.

    For more details see:
        https://en.wikipedia.org/wiki/Signal-to-noise_ratio

    Args:
        x: The original tensor.
        y: The tensor to compare to the original tensor.
    """
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)


# largest power of 2 representable in `torch.float8_e4m3fn`
F8E4M3_LARGEST_POW2 = 8
# max value of `torch.float8_e4m3fn` (448)
F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
# exponent bias of `torch.float8_e8m0fnu`
F8E8M0_EXP_BIAS = 127


def data_to_mx_scale(x, block_size):
    # simple implementation of
    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    # section 6.3, not all edge cases (such as NaN) are handled/tested
    orig_shape = x.shape
    x = x.reshape(-1, block_size)
    max_abs = torch.amax(torch.abs(x), 1)
    largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs))
    scale_e8m0_unbiased = largest_p2_lt_max_abs - F8E4M3_LARGEST_POW2
    scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -1 * F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS)
    scale_e8m0_biased = scale_e8m0_unbiased + F8E8M0_EXP_BIAS
    scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
    scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
    return scale_e8m0_biased.reshape(orig_shape[0], -1)
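
# Worked example for data_to_mx_scale (comments only; numbers assume
# F8E4M3_LARGEST_POW2 == 8 and F8E8M0_EXP_BIAS == 127 as defined above):
# for a block whose max_abs is 1000.0, floor(log2(1000)) == 9, so the
# unbiased e8m0 exponent is 9 - 8 == 1 and the stored biased byte is 128,
# i.e. the block scale is 2**1 and block elements are divided by it before
# the cast to float8_e4m3fn.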
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
class TestFP8MatmulCuda(TestCase):

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def _test_tautological_mm(self, device: str = "cuda",
                              x_dtype: torch.dtype = e4m3_type,
                              y_dtype: torch.dtype = e4m3_type,
                              out_dtype: Optional[torch.dtype] = None,
                              size: int = 16) -> None:
        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
        if out_dtype is not None:
            self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_basics(self, device) -> None:
        self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
        # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
        # supported on ROCm but fails on CUDA
        ctx = self.assertRaises(RuntimeError) if torch.version.hip is None else contextlib.nullcontext()
        with ctx:
            self._test_tautological_mm(device, e5m2_type, e5m2_type)

        self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
        self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)

        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
        self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)

        with self.assertRaises(AssertionError if torch.version.hip else RuntimeError):
            self._test_tautological_mm(device, out_dtype=e5m2_type)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_one = torch.tensor(1.0, device=device)
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_one, scale_b=scale_one)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_vs_emulated(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scale, b_scale=y_scale, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8, x_scale, y_fp8, y_scale, output_dtype
        )

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_change_stride(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
        y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)

        x.normal_()
        y.normal_()

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scale, b_scale=y_scale, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8, x_scale, y_fp8, y_scale, output_dtype
        )

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.ones((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
        # this fails on ROCm currently because hipblaslt doesn't have amax op
        out_fp32 = out_fp8.to(torch.float32)
        outb_fp32 = outb_fp8.to(torch.float32)
        difference = torch.abs(out_fp32 - outb_fp32)
        self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("bias", [True, False])
    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
        x = torch.rand((17, 16), device=device).to(e4m3_type)
        y = torch.rand((16, 16), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        input_bias = None
        if bias:
            input_bias = torch.rand((16,), device=device).to(torch.half)
        _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias_relu_edgecase(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
        y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
        outb_fp32 = outb_fp8.to(torch.float32)
        self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float32_output_errors_with_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
        self.assertRaisesRegex(
            RuntimeError,
            "Bias is not supported when out_dtype is set to Float32",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
        )

    @unittest.skipIf(PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_error_message_fp8_pre_sm89(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.rand((m, l), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        self.assertRaisesRegex(
            RuntimeError,
            r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale_fast_accum(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8, out_fp8_s)
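
    # The rowwise tests below rely on the scale-shape convention of
    # torch._scaled_mm exercised by the error-message checks further down:
    # for an (M, K) x (K, N) product, RowWise scaling takes scale_a of shape
    # (M, 1) and scale_b of shape (1, N), both contiguous float32 tensors,
    # e.g. scale_a = torch.ones((M, 1), device="cuda").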
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
    @parametrize("use_fast_accum", [True, False])
    def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
        y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)

        x_fp8 = x.to(e4m3_type)
        y_fp8 = y.to(e4m3_type).t()

        out_fp8 = torch._scaled_mm(
            x_fp8,
            y_fp8,
            scale_a=x_scales,
            scale_b=y_scales,
            out_dtype=torch.bfloat16,
            use_fast_accum=use_fast_accum,
        )
        self.assertEqual(
            out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    def test_float8_error_messages(self, device) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_fp8 = x.to(e4m3_type)
        y_fp8 = y.to(e4m3_type).t()

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((1, 1), device="cuda"),
                scale_b=torch.ones((1, 2), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N + 1), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M), device="cuda"),
                scale_b=torch.ones((N, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "Both scale_a and scale_b must be contiguous for RowWise scaling."
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
                out_dtype=torch.bfloat16,
            )

        # Note: a raw regex pattern is used here, not re.escape, to accommodate
        # the fn vs fnuz type message.
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8.to(e5m2_type),
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
    @parametrize("base_dtype", [torch.bfloat16])
    def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
        y_scales = tensor_to_scale(y, input_dtype, dim=0).float()

        x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
        y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8, x_scales, y_fp8, y_scales, output_dtype
        )

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 2e-3, 2e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("which_dim_zero", [0, 1, 2])
    @parametrize("use_torch_compile", [False, True])
    def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
        device = "cuda"
        x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn
        out_dtype = torch.bfloat16
        M, K, N = 32, 32, 32
        if which_dim_zero == 0:
            M = 0
        elif which_dim_zero == 1:
            K = 0
        elif which_dim_zero == 2:
            N = 0

        x_fp8 = torch.zeros(M, K, device=device).to(x_dtype)
        y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t()
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        scale_a = torch.tensor(float('-inf'), device=device)
        scale_b = torch.tensor(float('-inf'), device=device)
        f = torch._scaled_mm
        if use_torch_compile:
            f = torch.compile(torch._scaled_mm)
        out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
        self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
    @unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
    def test_honor_sm_carveout(self) -> None:
        torch.manual_seed(42)

        x = torch.randn(8192, 2048, device="cuda", dtype=torch.float32)
        y = torch.randn(8192, 2048, device="cuda", dtype=torch.float32).t()
        x_scales = tensor_to_scale(x, e4m3_type, dim=1).reciprocal()
        y_scales = tensor_to_scale(y, e4m3_type, dim=0).reciprocal()
        x_fp8 = to_fp8_saturated(x / x_scales, e4m3_type)
        y_fp8 = to_fp8_saturated(y / y_scales, e4m3_type)

        with tempfile.NamedTemporaryFile() as f:
            with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
                self.assertIsNone(torch._C._get_sm_carveout_experimental())
                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                torch._C._set_sm_carveout_experimental(0)
                self.assertEqual(torch._C._get_sm_carveout_experimental(), 0)
                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                torch._C._set_sm_carveout_experimental(66)
                self.assertEqual(torch._C._get_sm_carveout_experimental(), 66)
                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                torch._C._set_sm_carveout_experimental(None)
                self.assertIsNone(torch._C._get_sm_carveout_experimental())
                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)

            prof.export_chrome_trace(f.name)

            no_carveout, carveout_0, carveout_66, no_carveout_again = [
                math.prod(evt.get("args", {}).get("grid", []))
                for evt in json.load(open(f.name))["traceEvents"]
                if evt.get("cat", "") == "kernel"
            ]

            self.assertEqual(no_carveout, no_carveout_again)
            self.assertNotEqual(no_carveout, carveout_66)
            self.assertNotEqual(carveout_66, carveout_0)

    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
    @parametrize("test_case_name", [
        "a_eye_b_eye",
        "a_ones_b_ones",
        "a_ones_modified_b_ones",
        "a_ones_b_ones_modified",
        "a_scale_modified_b_ones",
        "a_ones_b_scale_modified",
        "data_random_scales_one",
        "data_random_scales_from_data",
    ])
    @parametrize("fast_accum", [False, True])
    @parametrize("mkn", [
        # Nice shapes
        (128, 128, 128),
        (256, 256, 256),
        (128, 256, 512),
        (256, 512, 128),
        (512, 128, 256),

        # Non block multiples
        (65, 96, 112),
        (197, 224, 272),

        # K not multiple of 32
        (197, 240, 272),

        # Very unbalanced
        (1023, 64, 48),
        (31, 1024, 64),
        (45, 96, 1024),

        # Mixed large and small
        (2, 1024, 128),
        (127, 96, 1024),
        (1025, 128, 96),
    ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
    def test_blockwise_mxfp8_numerics(self, test_case_name, fast_accum, mkn) -> None:
        # inspiration: https://github.com/pytorch/ao/pull/1625
        device = "cuda"
        M, K, N = mkn
        BLOCK_SIZE = 32
        require_exact_match = True

        def ceil_div(a, b):
            return (a + b - 1) // b

        if test_case_name == "a_eye_b_eye":
            if not ((M == K) and (M == N)):
                raise unittest.SkipTest("this test is only defined for M == K == N, skipping")
            A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
            B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)

            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "a_ones_b_ones":
            A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
            B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)

            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "a_ones_modified_b_ones":
            A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
            B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)
            A_ref[1][0:BLOCK_SIZE] = 2
            A[1][0:BLOCK_SIZE] = 2

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)

            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)
        elif test_case_name == "a_ones_b_ones_modified":
            A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
            B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)
            B_ref[1][0:BLOCK_SIZE] = 2
            B[1][0:BLOCK_SIZE] = 2

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)

            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "a_scale_modified_b_ones":
            A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
            B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            A_ref[1][0:BLOCK_SIZE] = 4
            A[1][0:BLOCK_SIZE] = 2
            A_scale[1][0] = 2

            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "a_ones_b_scale_modified":
            A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
            B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_ref[1][0:BLOCK_SIZE] = 4
            B[1][0:BLOCK_SIZE] = 2
            B_scale[1][0] = 2

            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "data_random_scales_one":
            require_exact_match = False
            # scales all-ones, element data random while being exactly representable in float8_e4m3fn
            # generate integers in [0, 255] and interpret as float8_e4m3fn
            A_ref = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.bfloat16)
            B_ref = torch.randint(0, 255, (N, K), device=device, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.bfloat16)
            # modification: don't allow NaN values
            A_ref[torch.isnan(A_ref)] = 0
            B_ref[torch.isnan(B_ref)] = 0
            A = A_ref.to(torch.float8_e4m3fn)
            B = B_ref.to(torch.float8_e4m3fn)

            A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
            B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)

            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        elif test_case_name == "data_random_scales_from_data":
            if K % BLOCK_SIZE != 0:
                raise unittest.SkipTest(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
            require_exact_match = False
            # random data, scales from data
            A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
            B_ref = torch.randn((N, K), device=device, dtype=torch.bfloat16) * 1000

            # Calculate scales based on the inputs
            A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE)
            B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE)

            max_val = F8E4M3_MAX_VAL
            min_val = -1 * max_val

            A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(M * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(M, K)
            A = A.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
            B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
            B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)

            # convert to swizzled format
            A_scale = to_blocked(A_scale)
            B_scale = to_blocked(B_scale)

        C_ref = A_ref @ B_ref.t()
        C = torch._scaled_mm(
            A,
            B.t(),
            A_scale,
            B_scale,
            out_dtype=torch.bfloat16,
            use_fast_accum=fast_accum,
        )

        if require_exact_match:
            torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
        else:
            sqnr = compute_error(C_ref, C)
            assert sqnr.item() > 22.0

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    def test_blockwise_mxfloat8_error_messages(self, device) -> None:
        M, K, N = (1024, 512, 2048)
        BLOCK_SIZE_K = 32
        BLOCK_SIZE_MN = 128
        fill_value = 0.5

        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_fp8 = x.to(e4m3_type)
        y_fp8 = y.to(e4m3_type).t()

        def ceil_div(a, b):
            return (a + b - 1) // b

        num_k_blocks = ceil_div(K, BLOCK_SIZE_K)
        padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4
        expected_a_size = BLOCK_SIZE_MN * ceil_div(M, BLOCK_SIZE_MN) * padded_num_k_blocks
        expected_b_size = BLOCK_SIZE_MN * ceil_div(N, BLOCK_SIZE_MN) * padded_num_k_blocks

        # Test wrong scale tensor size for scale_a with correct dtype
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                f"For BlockWise scaling: Expected scale_a size to be {expected_a_size} "
                f"but got {expected_a_size - 1}"
            ),
        ):
            incorrect_size_a = torch.ones(expected_a_size - 1, device=device, dtype=torch.float8_e8m0fnu)
            correct_size_b = torch.ones(expected_b_size, device=device, dtype=torch.float8_e8m0fnu)
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=incorrect_size_a,
                scale_b=correct_size_b,
                out_dtype=torch.bfloat16,
            )

        # Test wrong scale tensor size for scale_b with correct dtype
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                f"For BlockWise scaling: Expected scale_b size to be {expected_b_size} "
                f"but got {expected_b_size + 1}"
            ),
        ):
            correct_size_a = torch.ones(expected_a_size, device=device, dtype=torch.float8_e8m0fnu)
            incorrect_size_b = torch.ones(expected_b_size + 1, device=device, dtype=torch.float8_e8m0fnu)
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=correct_size_a,
                scale_b=incorrect_size_b,
                out_dtype=torch.bfloat16,
            )

        # Test non-contiguous scale tensors with correct dtype
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For BlockWise scaling: Both scale_a and scale_b must be contiguous"
            ),
        ):
            non_contiguous_a = torch.ones(expected_a_size * 2, device=device, dtype=torch.float8_e8m0fnu)[::2]
            contiguous_b = torch.ones(expected_b_size, device=device, dtype=torch.float8_e8m0fnu)
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=non_contiguous_a,
                scale_b=contiguous_b,
                out_dtype=torch.bfloat16,
            )


@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
class TestMixedDtypesLinearCuda(TestCase):
    @dtypes(torch.float16, torch.bfloat16)
    def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"):
        version = _get_torch_cuda_version()
        if version < (11, 8):
            self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+")
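
        # Illustrative dequantization arithmetic behind run_test below (comments
        # only): an int8 (or packed int4) weight W of shape (n, k) with a
        # per-output-channel fp16/bf16 scale s of shape (n,) stands in for the
        # effective weight W.to(dtype) * s.view(n, 1), so the reference output
        # is torch.nn.functional.linear(input, W.to(dtype) * s.view(n, 1), bias),
        # while _mixed_dtypes_linear consumes the CUTLASS-reordered quantized
        # weight plus the scale directly.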
        def run_test(
            batch_shape,
            m,
            n,
            k,
            add_bias,
            activation,
            dtype,
            dtypeq,
            device,
            rtol,
            atol,
        ):
            if not add_bias and activation != "none":
                return

            val_lo, val_hi = -1, 1
            valq_lo, valq_hi = -2, 2
            input = make_tensor(
                *batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device
            )
            weight = make_tensor(
                n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device
            )
            scale = make_tensor(
                (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
            )
            bias = (
                make_tensor(
                    (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
                )
                if add_bias
                else None
            )

            input_ref = input.reshape(-1, input.shape[-1])

            # First, test plain multiplication.
            weight_ref = weight.T.to(input.dtype) * scale.view(1, n)
            weightq = (
                pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T
            )
            output_ref = torch.mm(input_ref, weight_ref).reshape(*input.shape[:-1], n)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=False
                ),
                scale,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

            # Second, test the linear operator itself.
            weight_ref = weight.to(input.dtype) * scale.view(n, 1)
            weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight
            bias_ref = bias.view(1, n) if add_bias else None
            output_ref = torch.nn.functional.linear(
                input_ref, weight_ref, bias=bias_ref
            ).reshape(*input.shape[:-1], n)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output_ref = relu(output_ref)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output_ref = silu(output_ref)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=True
                ),
                scale,
                bias=bias,
                activation=activation,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

        dtypeqs = [torch.int8, torch.quint4x2]
        batch_shapes = [[], [2], [2, 1]]
        shapes = [
            [8, 64, 64],
            [8, 64, 128],
            [8, 128, 64],
            [8, 128, 128],
            [8, 128, 192],
            [8, 128, 256],
            [8, 256, 128],
            [8, 256, 384],
            [8, 384, 256],
        ]
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 1e-2, 1e-3
        for dtypeq, batch_shape, (m, n, k), add_bias, activation in product(
            dtypeqs, batch_shapes, shapes, (False, True), activations
        ):
            run_test(
                batch_shape,
                m,
                n,
                k,
                add_bias,
                activation,
                dtype,
                dtypeq,
                device,
                rtol,
                atol,
            )


instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()