mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Newer matmul kernels, e.g. those targeting Hopper GPUs, sometimes use a "persistent" schedule which consists of launching as many CUDA blocks as there are SMs on the GPU, with each such block then working on multiple output tiles in a row. This eliminates the overhead of starting and finishing each tile, effectively doing cross-tile pipelining. In previous generations these latencies could be hidden by having multiple CUDA blocks per SM but, with blocks becoming larger, only one can run at a time per SM and thus this needs to be taken care of in software. Persistent kernels become an issue when other kernels are running concurrently. The classical example is an NCCL communication kernel running in the background. In such cases the matmul expects to be able to use all the SMs but is prevented from doing so because some of them are busy. This can lead to its blocks being scheduled as two separate waves on the available SMs. This "wave quantization" can double the latency of the matmul kernels. While we wait for smarter solutions, such as automatic load balancing among the blocks, an easy way to unblock ourselves is to tell the matmuls to only use a subset of the GPU's SMs. For this, I am introducing a global `sm_carveout` flag which can be used to specify how many SMs should be left available for other kernels. For now I only change the cuBLAS kernels and the scaled-mm CUTLASS kernel. More kernels can be opted in later. I tested this change manually, by using the Kineto profiler to look up the grid size of a scaled-mm kernel with different values of `sm_carveout`, and making sure it changed. Suggestions are welcome for a more automated test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144974 Approved by: https://github.com/eqy, https://github.com/albanD
1298 lines
52 KiB
Python
1298 lines
52 KiB
Python
# Owner(s): ["module: linear algebra"]
|
|
|
|
import contextlib
|
|
import json
|
|
import math
|
|
import re
|
|
import tempfile
|
|
import unittest
|
|
from itertools import product
|
|
from functools import partial
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from torch.quantization._quantized_conversions import (
|
|
pack_int4_to_int8,
|
|
quantized_weight_reorder_for_mixed_dtypes_linear_cutlass,
|
|
)
|
|
|
|
from torch.testing import make_tensor
|
|
from torch.testing._internal.common_cuda import (
|
|
SM53OrLater,
|
|
SM89OrLater,
|
|
SM90OrLater,
|
|
_get_torch_cuda_version,
|
|
PLATFORM_SUPPORTS_FP8,
|
|
PLATFORM_SUPPORTS_MX_GEMM
|
|
)
|
|
from torch.testing._internal.common_device_type import (
|
|
dtypes,
|
|
instantiate_device_type_tests,
|
|
onlyCUDA,
|
|
tol as xtol,
|
|
toleranceOverride,
|
|
)
|
|
|
|
from torch.testing._internal.common_utils import (
|
|
IS_ARM64,
|
|
IS_JETSON,
|
|
IS_WINDOWS,
|
|
parametrize,
|
|
run_tests,
|
|
skipIfRocm,
|
|
skipIfRocmVersionLessThan,
|
|
TEST_CUDA,
|
|
TEST_WITH_ROCM,
|
|
TestCase,
|
|
)
|
|
|
|
# True only when CUDA is available and device 0 reports a major compute
# capability of 8; False on CPU-only runs.
_IS_SM8X = TEST_CUDA and torch.cuda.get_device_capability(0)[0] == 8

# Guard: fail fast if something imported above changed the global default dtype.
assert torch.get_default_dtype() is torch.float32
|
|
|
|
|
|
@unittest.skipIf(IS_ARM64, "Issue with numpy version on arm")
class TestMatmulCuda(TestCase):
    """cuBLAS matmul tests that compare CUDA results against CPU references.

    TF32 matmuls are disabled for the duration of every test; the setting that
    was active before the test is restored afterwards.
    """

    def setUp(self):
        # NOTE(review): ``super(self.__class__, self)`` is kept on purpose. These
        # methods are presumably copied onto per-device classes generated by
        # instantiate_device_type_tests, where zero-argument ``super()`` may not
        # resolve -- confirm before "modernizing".
        super(self.__class__, self).setUp()
        # Save the caller's setting so tearDown restores it instead of forcing
        # a fixed value onto whatever test runs next.
        self._orig_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        torch.backends.cuda.matmul.allow_tf32 = self._orig_allow_tf32
        super(self.__class__, self).tearDown()

    def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False):
        """Check for catastrophic cuBLAS inaccuracy in ``torch.addmm``.

        Computes ``addmm`` from the same random operands on CPU (which does not
        use the CUDA backend) and on CUDA, then compares them with
        ``self.assertEqual`` (tolerances presumably come from the calling
        test's ``@toleranceOverride``).

        Args:
            size: operands are ``(size + 1, size) @ (size, size + 2)`` plus a
                ``(size + 1, size + 2)`` input.
            dtype: dtype of the operands.
            reduced_precision: value assigned to both
                ``allow_bf16_reduced_precision_reduction`` and
                ``allow_fp16_reduced_precision_reduction`` for this check.
            fp16_accumulate: value assigned to ``allow_fp16_accumulation``.

        The three global backend flags are restored on exit, including when the
        comparison fails, so a failure cannot leak them into later tests.
        """
        # Get dims
        n, m, p = (size + 1, size, size + 2)
        # Disable reduced precision reductions in BFloat16 to bypass some kernels
        # which fail the threshold check
        orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
        orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
        try:
            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision
            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision
            torch.backends.cuda.matmul.allow_fp16_accumulation = fp16_accumulate
            # Make random tensors on CPU (seed set on common_utils.py import)
            # (Not using numpy because it does not support bfloat16)
            make_arg = partial(make_tensor, dtype=dtype, device="cpu")
            m_beta = make_arg(1)
            m_input = make_arg((n, p))
            m_1 = make_arg((n, m))
            m_2 = make_arg((m, p))
            # scale to abate overflows in fp16 accum
            if fp16_accumulate:
                m_1 = m_1 / 100
                m_2 = m_2 / 100
            # *(B)FLOAT16 Special Handling*
            # Backend does not tensorize float16 on CPU,
            # and bfloat16 may present accuracy issues,
            # so convert to float32 for these cases
            # (but keep same for other types, e.g. float32 and int*)
            is_half_like = dtype == torch.float16 or dtype == torch.bfloat16
            if is_half_like:
                m_beta = m_beta.to(dtype=torch.float32)
                m_input = m_input.to(dtype=torch.float32)
                m_1 = m_1.to(dtype=torch.float32)
                m_2 = m_2.to(dtype=torch.float32)
            # Get CPU result
            res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
            # *(B)FLOAT16 Special Handling*
            # Convert back to (b)float16
            if is_half_like:
                m_beta = m_beta.to(dtype=dtype)
                m_input = m_input.to(dtype=dtype)
                m_1 = m_1.to(dtype=dtype)
                m_2 = m_2.to(dtype=dtype)
                res_cpu = res_cpu.to(dtype=dtype)
            # Move arg tensors to CUDA
            m_beta = m_beta.to("cuda")
            m_input = m_input.to("cuda")
            m_1 = m_1.to("cuda")
            m_2 = m_2.to("cuda")
            # Get CUDA result, moved to CPU for comparison
            res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item()).to("cpu")
            # Compare
            self.assertEqual(res_cpu, res_cuda)
        finally:
            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16
            torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1),
                        torch.bfloat16: xtol(atol=1e-1, rtol=1e-1),
                        torch.float32: xtol(atol=1e-1, rtol=1e-1)})
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, False)

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, True)

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, False, True)

    @onlyCUDA
    @skipIfRocm
    def test_cublas_and_lt_reduced_precision_fp16_accumulate(self):
        """fp16 accumulation for ``linear`` (with bias) and ``baddbmm`` vs CPU."""
        orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
        try:
            torch.backends.cuda.matmul.allow_fp16_accumulation = True
            x = torch.rand(32, 512, 512, device='cuda', dtype=torch.half)
            w = torch.rand(512, 512, device='cuda', dtype=torch.half)
            b = torch.rand(512, device='cuda', dtype=torch.half)
            out = torch.nn.functional.linear(x, w, b)
            out_cpu = torch.nn.functional.linear(x.cpu(), w.cpu(), b.cpu())
            self.assertEqual(out, out_cpu, atol=5e-3, rtol=8e-3)

            a = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
            b = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
            c = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
            out = torch.baddbmm(a, b, c)
            out_cpu = torch.baddbmm(a.cpu(), b.cpu(), c.cpu())
            self.assertEqual(out, out_cpu, atol=1e-3, rtol=5e-3)
        finally:
            # Restore even on assertion failure so later tests see the original flag.
            torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate

    @onlyCUDA
    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
    @dtypes(torch.float16)
    def test_cublas_addmm_alignment(self, dtype):
        """``linear`` must match ``matmul + bias`` when X, A or B start at a misaligned offset."""
        device = 'cuda'
        # perturb X, A, or B alignment
        for idx in range(3):
            for offset in range(1, 3):
                offsets = [0, 0, 0]
                offsets[idx] = offset
                x_offset, a_offset, b_offset = offsets
                A = torch.rand((5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device)
                A = A[a_offset:].reshape(5120, 2560)
                X = torch.rand((26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device)
                X = X[x_offset:].reshape(26, 1, 2560)
                B = torch.rand((5120 + b_offset), requires_grad=True, dtype=dtype, device=device)
                B = B[b_offset:].reshape(5120)
                out = torch.nn.functional.linear(X, A, B)
                self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B)

    @onlyCUDA
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)})
    # The conditional must only guard bfloat16; without the inner parentheses
    # it would apply to the whole list and drop float32/float16 as well.
    @dtypes(*([torch.float32, torch.float16] +
              ([torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])))
    @parametrize(
        "batch_size, N, M, P",
        [(2, 100, 100, 100),
         (2, 1000, 1000, 1000),
         (1, 10000, 1000, 10000),
         (1, 10000, 10000, 10000)],
        name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}",
    )
    @skipIfRocm
    def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype):
        """Large ``linear`` / ``baddbmm`` on GPU vs CPU, including identity multiplies."""
        cpu_dtype = dtype
        if dtype == torch.float16 or dtype == torch.bfloat16:
            cpu_dtype = torch.float32

        M1 = torch.rand((N, M), device=device, dtype=dtype)
        M2 = torch.rand((M, P), device=device, dtype=dtype)
        A = torch.rand((N, P), device=device, dtype=dtype)

        def _convert_to_cpu(t):
            return t.to(device='cpu', dtype=cpu_dtype)
        M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A])

        # linear
        out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype)
        out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu()
        self.assertEqual(out1_cpu, out1_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype)
            out1_eye_gpu = torch.nn.functional.linear(M1, M2_eye.t(), torch.zeros_like(A))
            self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu())

        # baddbmm
        def _expand_to_batch(t: torch.Tensor):
            return t.expand((batch_size, ) + t.size())
        alpha, beta = 1.0, 1.0
        M1, M2, A, M1_cpu, M2_cpu, A_cpu = map(_expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu])

        out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to(dtype=dtype)
        out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu()
        self.assertEqual(out2_cpu, out2_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N)
            out2_eye_gpu = torch.baddbmm(torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha)
            self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu())

        # cross comparison
        self.assertEqual(out1_gpu, out2_gpu[0])
|
|
|
|
|
|
f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
|
|
mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
|
|
|
|
if torch.version.hip and 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName:
|
|
e4m3_type = torch.float8_e4m3fnuz
|
|
e5m2_type = torch.float8_e5m2fnuz
|
|
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
|
|
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
|
|
else:
|
|
e4m3_type = torch.float8_e4m3fn
|
|
e5m2_type = torch.float8_e5m2
|
|
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
|
|
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
|
|
|
|
# avoid division by zero when calculating scale
|
|
EPS = 1e-12
|
|
|
|
def amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
    """ Converts the amax value of a tensor to the fp8 scale.

    The scale is ``max_pos(float8_dtype) / max(amax, EPS)``, so that multiplying
    the tensor by it maps its largest magnitude onto the largest finite value of
    the target float8 dtype.

    Args:
        amax: The amax value of the tensor.
        float8_dtype: the float8 dtype; must be ``e4m3_type`` or ``e5m2_type``.
        orig_dtype: The original dtype of the tensor.

    Returns:
        A float32 tensor shaped like ``amax``.

    Raises:
        ValueError: if ``float8_dtype`` is not ``e4m3_type`` or ``e5m2_type``.
    """
    scale = torch.empty_like(amax, dtype=torch.float32)
    if float8_dtype == e4m3_type:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
    elif float8_dtype == e5m2_type:
        # Must use the e5m2 range: dividing E4M3_MAX_POS here would leave most
        # of the e5m2 dynamic range unused.
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

    # Ensure the scale is representable in float16,
    # this helps when amax is small. We are assuming that we don't need
    # to care about this for float32/bfloat16
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=torch.finfo(torch.float16).max)

    scale.copy_(res)
    return scale
|
|
|
|
def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
    """Compute the float8 scale of ``x`` from its absolute maximum.

    With ``dim=None`` a single scale is computed over the whole tensor; otherwise
    the absolute maximum is reduced along ``dim`` with the dimension kept, giving
    one scale per slice. The conversion itself is done by ``amax_to_scale``.
    """
    magnitudes = torch.abs(x)
    if dim is not None:
        amax = torch.max(magnitudes, dim=dim, keepdim=True).values
    else:
        amax = torch.max(magnitudes)
    return amax_to_scale(amax, float8_dtype, x.dtype)
|
|
|
|
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
    """Reference fp8 matmul: dequantize both operands, multiply in fp32, cast.

    Each operand is converted to float32 and divided by its scale before a
    plain ``torch.mm``; the product is then cast to ``out_dtype``.
    """
    x_dequant = x.to(torch.float) / x_scale
    y_dequant = y.to(torch.float) / y_scale
    return torch.mm(x_dequant, y_dequant).to(out_dtype)
|
|
|
|
def addmm_float8_unwrapped(
    a_data: torch.Tensor,
    a_scale: torch.Tensor,
    b_data: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,
    output_scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run ``torch._scaled_mm`` on fp8 data quantized with multiplicative scales.

    ``a_scale``/``b_scale`` are the factors the data was multiplied by, so their
    reciprocals are passed as ``scale_a``/``scale_b`` to dequantize.

    Args:
        a_data, b_data: fp8 operands.
        a_scale, b_scale: quantization scales of the operands.
        output_dtype: dtype of the result.
        output_scale: optional ``scale_result`` forwarded to ``_scaled_mm``.
        bias: optional bias added to the product.

    Returns:
        The (optionally biased) matmul result in ``output_dtype``.
    """
    a_inverse_scale = a_scale.reciprocal()
    b_inverse_scale = b_scale.reciprocal()
    # Bias is not supported by _scaled_mm when output is fp32, so in that case
    # it is added to the result afterwards instead of being fused.
    add_bias_after = output_dtype == torch.float32 and bias is not None
    output = torch._scaled_mm(
        a_data,
        b_data,
        bias=None if add_bias_after else bias,
        scale_a=a_inverse_scale,
        scale_b=b_inverse_scale,
        scale_result=output_scale,
        out_dtype=output_dtype,
    )
    if add_bias_after:
        output += bias
    return output
|
|
|
|
def mm_float8(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,  # output dtype
    output_scale: Optional[torch.Tensor] = None,  # output scale, precomputed
) -> torch.Tensor:
    """Bias-free fp8 matmul ``a @ b``; forwards everything to ``addmm_float8_unwrapped``."""
    return addmm_float8_unwrapped(
        a_data=a,
        a_scale=a_scale,
        b_data=b,
        b_scale=b_scale,
        output_dtype=output_dtype,
        output_scale=output_scale,
    )
|
|
|
|
def to_fp8_saturated(
    x: torch.Tensor,
    fp8_dtype: torch.dtype
):
    """Clamp ``x`` to the finite range of ``fp8_dtype`` and cast it.

    Raises:
        ValueError: if ``fp8_dtype`` is neither ``e4m3_type`` nor ``e5m2_type``.
    """
    if fp8_dtype == e4m3_type:
        limit = E4M3_MAX_POS
    elif fp8_dtype == e5m2_type:
        limit = E5M2_MAX_POS
    else:
        raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")
    # Saturate rather than letting out-of-range values overflow on the cast.
    return x.clamp(min=-1 * limit, max=limit).to(fp8_dtype)
|
|
|
|
# copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py
|
|
def ceil_div(a, b):
    """Return ``a / b`` rounded up, for integer ``a`` and positive integer ``b``."""
    numerator = a + b - 1
    return numerator // b
|
|
|
|
def to_blocked(input_matrix) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

    The matrix is zero-padded up to a multiple of 128 rows and 4 columns, cut
    into 128x4 tiles, and each tile is rearranged into a 32x16 tile.

    See:
        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Input tensor of shape (H, W)

    Returns:
        A flattened 1-D tensor of 512 * ceil_div(H,128) * ceil_div(W,4) elements
        (the same element count as a (32*ceil_div(H,128), 16*ceil_div(W,4))
        matrix): one contiguous 32x16 chunk per 128x4 tile, with tiles ordered
        by row block first, then column block.
    """
    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Calculate the padded shape
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    padded = input_matrix
    # Ideally we would use torch.nn.functional.pad but it doesn't support
    # float8_e8m0fnu for now
    if (rows, cols) != (padded_rows, padded_cols):
        padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype)
        padded[:rows, :cols] = input_matrix

    # (n_row_blocks, n_col_blocks, 128, 4): one 128x4 tile per (row block, col block).
    # NOTE: when no padding was needed, `view` requires input_matrix to be contiguous.
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    # Split each tile's 128 rows into 4 groups of 32, swap the group and row
    # axes, then merge (group, col) into 16 columns -> (num_tiles, 32, 16).
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    return rearranged.flatten()
|
|
|
|
def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the signal-to-noise ratio of ``y`` relative to ``x``, in dB.

    The signal is ``||x||`` and the noise is ``||x - y||``; the result is
    ``20 * log10(signal / noise)``. Identical tensors give ``inf``.

    For more details see:
        https://en.wikipedia.org/wiki/Signal-to-noise_ratio

    Args:
        x: The original tensor.
        y: The tensor to compare to the original tensor.
    """
    signal_norm = torch.norm(x)
    noise_norm = torch.norm(x - y)
    ratio = signal_norm / noise_norm
    return 20 * torch.log10(ratio)
|
|
|
|
# largest power of 2 representable in `torch.float8_e4m3fn`
|
|
F8E4M3_LARGEST_POW2 = 8
|
|
# max value of `torch.float8_e4m3fn` (448)
|
|
F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
|
|
# exponent bias of `torch.float8_e8m0fnu`
|
|
F8E8M0_EXP_BIAS = 127
|
|
|
|
def data_to_mx_scale(x, block_size):
    """Compute one e8m0 (MX) scale per ``block_size`` consecutive elements of ``x``.

    Simple implementation of
    https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    section 6.3; not all edge cases (such as NaN) are handled/tested.

    The unbiased exponent is ``floor(log2(block_amax)) - F8E4M3_LARGEST_POW2``,
    clamped to ``[-127, 127]``, then biased by ``F8E8M0_EXP_BIAS`` and
    reinterpreted as ``float8_e8m0fnu``.

    Returns:
        A ``float8_e8m0fnu`` tensor reshaped to ``(x.shape[0], -1)``.
    """
    leading_dim = x.shape[0]
    blocks = x.reshape(-1, block_size)
    block_amax = torch.amax(torch.abs(blocks), 1)
    # floor(log2(amax)) is the exponent of the largest power of two <= amax.
    amax_exponent = torch.floor(torch.log2(block_amax))
    unbiased = torch.clamp(
        amax_exponent - F8E4M3_LARGEST_POW2, -1 * F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS
    )
    biased_bits = (unbiased + F8E8M0_EXP_BIAS).to(torch.uint8)
    # Reinterpret (not convert) the uint8 exponent bits as e8m0.
    return biased_bits.view(torch.float8_e8m0fnu).reshape(leading_dim, -1)
|
|
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
|
|
class TestFP8MatmulCuda(TestCase):
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def _test_tautological_mm(self, device: str = "cuda",
|
|
x_dtype: torch.dtype = e4m3_type,
|
|
y_dtype: torch.dtype = e4m3_type,
|
|
out_dtype: Optional[torch.dtype] = None,
|
|
size: int = 16) -> None:
|
|
x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
|
|
y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
|
|
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
|
|
scale_a = torch.tensor(1.0, device=device)
|
|
scale_b = torch.tensor(1.0, device=device)
|
|
out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
|
|
if out_dtype is not None:
|
|
self.assertEqual(out_dtype, out_fp8.dtype)
|
|
self.assertEqual(out_fp32, out_fp8.to(torch.float))
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_float8_basics(self, device) -> None:
|
|
self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
|
|
# According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
|
|
# supported on ROCm but fails on CUDA
|
|
ctx = self.assertRaises(RuntimeError) if torch.version.hip is None else contextlib.nullcontext()
|
|
with ctx:
|
|
self._test_tautological_mm(device, e5m2_type, e5m2_type)
|
|
|
|
self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
|
|
self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
|
|
|
|
self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
|
|
self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
|
|
self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
|
|
|
|
with self.assertRaises(AssertionError if torch.version.hip else RuntimeError):
|
|
self._test_tautological_mm(device, out_dtype=e5m2_type)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_float8_scale(self, device) -> None:
|
|
size = (16, 16)
|
|
x = torch.full(size, .5, device=device, dtype=e4m3_type)
|
|
# hipblaslt does not yet support mixed e4m3_type input
|
|
y_type = e4m3_type if torch.version.hip else e5m2_type
|
|
y = torch.full(size, .5, device=device, dtype=y_type).t()
|
|
scale_one = torch.tensor(1.0, device=device)
|
|
scale_a = torch.tensor(1.5, device=device)
|
|
scale_b = torch.tensor(0.66, device=device)
|
|
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_one, scale_b=scale_one)
|
|
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
|
|
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
|
|
self.assertEqual(out_fp8, out_fp8_s)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
|
|
def test_scaled_mm_vs_emulated(self, base_dtype):
|
|
torch.manual_seed(42)
|
|
input_dtype = e4m3_type
|
|
output_dtype = base_dtype
|
|
compare_type = torch.float32
|
|
|
|
x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
|
|
y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
|
|
|
|
x_scale = tensor_to_scale(x, input_dtype).float()
|
|
y_scale = tensor_to_scale(y, input_dtype).float()
|
|
|
|
x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
|
|
y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
|
|
|
|
# Calculate actual F8 mm
|
|
out_scaled_mm = mm_float8(
|
|
x_fp8,
|
|
y_fp8,
|
|
a_scale=x_scale,
|
|
b_scale=y_scale,
|
|
output_dtype=output_dtype
|
|
)
|
|
|
|
# Calculate emulated F8 mm
|
|
out_emulated = mm_float8_emulated(
|
|
x_fp8,
|
|
x_scale,
|
|
y_fp8,
|
|
y_scale,
|
|
output_dtype
|
|
)
|
|
|
|
if output_dtype != base_dtype:
|
|
out_scaled_mm = out_scaled_mm.to(compare_type)
|
|
out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)
|
|
|
|
out_emulated = out_emulated.to(compare_type)
|
|
out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)
|
|
|
|
if base_dtype in {torch.bfloat16, torch.float16}:
|
|
atol, rtol = 7e-2, 7e-2
|
|
else:
|
|
atol, rtol = 3e-3, 3e-3
|
|
|
|
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
|
|
def test_scaled_mm_change_stride(self, base_dtype):
|
|
torch.manual_seed(42)
|
|
input_dtype = e4m3_type
|
|
output_dtype = base_dtype
|
|
compare_type = torch.float32
|
|
|
|
x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
|
|
y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)
|
|
|
|
x.normal_()
|
|
y.normal_()
|
|
|
|
x_scale = tensor_to_scale(x, input_dtype).float()
|
|
y_scale = tensor_to_scale(y, input_dtype).float()
|
|
|
|
x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
|
|
y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
|
|
|
|
# Calculate actual F8 mm
|
|
out_scaled_mm = mm_float8(
|
|
x_fp8,
|
|
y_fp8,
|
|
a_scale=x_scale,
|
|
b_scale=y_scale,
|
|
output_dtype=output_dtype
|
|
)
|
|
|
|
# Calculate emulated F8 mm
|
|
out_emulated = mm_float8_emulated(
|
|
x_fp8,
|
|
x_scale,
|
|
y_fp8,
|
|
y_scale,
|
|
output_dtype
|
|
)
|
|
|
|
if output_dtype != base_dtype:
|
|
out_scaled_mm = out_scaled_mm.to(compare_type)
|
|
out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)
|
|
|
|
out_emulated = out_emulated.to(compare_type)
|
|
out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)
|
|
|
|
if base_dtype in {torch.bfloat16, torch.float16}:
|
|
atol, rtol = 7e-2, 7e-2
|
|
else:
|
|
atol, rtol = 3e-3, 3e-3
|
|
|
|
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_float8_bias(self, device) -> None:
|
|
(k, l, m) = (16, 48, 32)
|
|
x = torch.ones((k, l), device=device).to(e4m3_type)
|
|
y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
|
|
bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
|
|
scale_a = torch.tensor(1.0, device=device)
|
|
scale_b = torch.tensor(1.0, device=device)
|
|
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
|
|
outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
|
|
# this fails on ROCm currently because hipblaslt doesn't have amax op
|
|
out_fp32 = out_fp8.to(torch.float32)
|
|
outb_fp32 = outb_fp8.to(torch.float32)
|
|
difference = torch.abs(out_fp32 - outb_fp32)
|
|
self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32))
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
@parametrize("bias", [True, False])
|
|
def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
|
|
x = torch.rand((17, 16), device=device).to(e4m3_type)
|
|
y = torch.rand((16, 16), device=device).to(e4m3_type).t()
|
|
scale_a = torch.tensor(1.0, device=device)
|
|
scale_b = torch.tensor(1.0, device=device)
|
|
input_bias = None
|
|
if bias:
|
|
input_bias = torch.rand((16,), device=device).to(torch.half)
|
|
_ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_float8_bias_relu_edgecase(self, device) -> None:
|
|
(k, l, m) = (16, 48, 32)
|
|
x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
|
|
y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
|
|
bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
|
|
scale_a = torch.tensor(1.0, device=device)
|
|
scale_b = torch.tensor(1.0, device=device)
|
|
outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
|
|
outb_fp32 = outb_fp8.to(torch.float32)
|
|
self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_float32_output_errors_with_bias(self, device) -> None:
|
|
(k, l, m) = (16, 48, 32)
|
|
x = torch.rand((k, l), device=device).to(e4m3_type)
|
|
y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
|
|
scale_a = torch.tensor(1.0, device=device)
|
|
scale_b = torch.tensor(1.0, device=device)
|
|
bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Bias is not supported when out_dtype is set to Float32",
|
|
lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
|
|
)
|
|
|
|
@unittest.skipIf(PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_error_message_fp8_pre_sm89(self, device) -> None:
|
|
(k, l, m) = (16, 48, 32)
|
|
x = torch.rand((k, l), device=device).to(e4m3_type)
|
|
y = torch.rand((m, l), device=device).to(e4m3_type).t()
|
|
scale_a = torch.tensor(1.0, device=device)
|
|
scale_b = torch.tensor(1.0, device=device)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
|
|
lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
|
|
)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_float8_scale_fast_accum(self, device) -> None:
|
|
size = (16, 16)
|
|
x = torch.full(size, .5, device=device, dtype=e4m3_type)
|
|
# hipblaslt does not yet support mixed e4m3_type input
|
|
y_type = e4m3_type if torch.version.hip else e5m2_type
|
|
y = torch.full(size, .5, device=device, dtype=y_type).t()
|
|
scale_a = torch.tensor(1.5, device=device)
|
|
scale_b = torch.tensor(0.66, device=device)
|
|
out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
|
|
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
|
|
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
|
|
self.assertEqual(out_fp8, out_fp8_s)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
|
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
|
|
@parametrize("use_fast_accum", [True, False])
|
|
def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
|
|
M, K, N = (1024, 512, 2048)
|
|
fill_value = 0.5
|
|
x = torch.full((M, K), fill_value, device=device)
|
|
y = torch.full((N, K), fill_value, device=device)
|
|
|
|
x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
|
|
y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)
|
|
|
|
x_fp8 = x.to(e4m3_type)
|
|
y_fp8 = y.to(e4m3_type).t()
|
|
|
|
out_fp8 = torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=x_scales,
|
|
scale_b=y_scales,
|
|
out_dtype=torch.bfloat16,
|
|
use_fast_accum=use_fast_accum,
|
|
)
|
|
self.assertEqual(
|
|
out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
|
|
)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
|
@skipIfRocm()
|
|
def test_float8_error_messages(self, device) -> None:
|
|
M, K, N = (1024, 512, 2048)
|
|
fill_value = 0.5
|
|
x = torch.full((M, K), fill_value, device=device)
|
|
y = torch.full((N, K), fill_value, device=device)
|
|
|
|
x_fp8 = x.to(e4m3_type)
|
|
y_fp8 = y.to(e4m3_type).t()
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
re.escape(
|
|
"For RowWise scaling, scale_a should be (1024, 1) and scale_b "
|
|
"should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
|
|
),
|
|
):
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=torch.ones((1, 1), device="cuda"),
|
|
scale_b=torch.ones((1, 2), device="cuda"),
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
re.escape(
|
|
" For RowWise scaling, scale_a should be (1024, 1) and scale_b "
|
|
"should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
|
|
),
|
|
):
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=torch.ones((M, 1), device="cuda"),
|
|
scale_b=torch.ones((1, N + 1), device="cuda"),
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
|
|
):
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=torch.ones((M), device="cuda"),
|
|
scale_b=torch.ones((N, N), device="cuda"),
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
re.escape(
|
|
"Both scale_a and scale_b must be contiguous for RowWise scaling."
|
|
),
|
|
):
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=torch.ones((M, 1), device="cuda"),
|
|
scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
|
|
# Note re.compile is used, not re.escape. This is to accomodate fn vs fnuz type message.
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
|
|
):
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8.to(e5m2_type),
|
|
scale_a=torch.ones((M, 1), device="cuda"),
|
|
scale_b=torch.ones((1, N), device="cuda"),
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
@parametrize("base_dtype", [torch.bfloat16])
def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
    """Row-wise scaled FP8 matmul must match its emulated reference.

    The left operand gets one scale per row (``dim=1``) and the right operand
    one scale per column (``dim=0``). Both are saturated into ``e4m3_type`` and
    multiplied with ``mm_float8`` (real kernel) and ``mm_float8_emulated``
    (reference); the results must agree within a dtype-dependent tolerance.
    """
    torch.manual_seed(42)
    fp8_dtype = e4m3_type
    out_dtype = base_dtype

    lhs = torch.randn(16, 16, device="cuda", dtype=base_dtype)
    # Built as (32, 16) and transposed, so the right operand is a (16, 32) view.
    rhs = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

    lhs_scale = tensor_to_scale(lhs, fp8_dtype, dim=1).float()
    rhs_scale = tensor_to_scale(rhs, fp8_dtype, dim=0).float()

    lhs_fp8 = to_fp8_saturated(lhs * lhs_scale, fp8_dtype)
    rhs_fp8 = to_fp8_saturated(rhs * rhs_scale, fp8_dtype)

    # Kernel result first, then the emulated reference.
    actual = mm_float8(
        lhs_fp8, rhs_fp8, a_scale=lhs_scale, b_scale=rhs_scale, output_dtype=out_dtype
    )
    expected = mm_float8_emulated(lhs_fp8, lhs_scale, rhs_fp8, rhs_scale, out_dtype)

    # Half-precision outputs need a much looser tolerance than fp32 ones.
    low_precision = base_dtype in {torch.bfloat16, torch.float16}
    tol = 7e-2 if low_precision else 2e-3
    torch.testing.assert_close(actual, expected, atol=tol, rtol=tol)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("which_dim_zero", [0, 1, 2])
@parametrize("use_torch_compile", [False, True])
def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
    """Tensor-wise ``_scaled_mm`` must handle an empty M, K or N dimension.

    ``which_dim_zero`` selects which of (M, K, N) becomes 0; the other two stay
    32. The result (eager or via ``torch.compile``) must have ``out_dtype`` and
    equal the fp32 ``torch.mm`` of the same all-zero operands.
    """
    device = "cuda"
    # Use the platform-selected e4m3 flavour (fn vs fnuz) like the other FP8
    # tests, instead of hard-coding float8_e4m3fn.
    x_dtype = y_dtype = e4m3_type
    out_dtype = torch.bfloat16

    dims = [32, 32, 32]
    dims[which_dim_zero] = 0
    M, K, N = dims

    x_fp8 = torch.zeros(M, K, device=device, dtype=x_dtype)
    y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t()
    out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))

    # NOTE(review): -inf scales presumably make the comparison fail (0 * -inf is
    # NaN) if the scales are ever applied to the empty/zero result.
    scale_a = torch.tensor(float("-inf"), device=device)
    scale_b = torch.tensor(float("-inf"), device=device)

    f = torch.compile(torch._scaled_mm) if use_torch_compile else torch._scaled_mm
    out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
    self.assertEqual(out_dtype, out_fp8.dtype)
    self.assertEqual(out_fp32, out_fp8.to(torch.float))
|
|
|
|
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support row-wise scaling")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
def test_honor_sm_carveout(self) -> None:
    """The launch grid of row-wise ``_scaled_mm`` must react to the SM carveout.

    Four identical matmuls are profiled with the experimental carveout unset,
    set to 0, set to 66, and unset again. The product of each kernel's launch
    grid is read back from the exported Chrome trace. Unset runs must match
    each other, and carveout 66 must differ from both "unset" and 0.
    """
    torch.manual_seed(42)

    x = torch.randn(8192, 2048, device="cuda", dtype=torch.float32)
    y = torch.randn(8192, 2048, device="cuda", dtype=torch.float32).t()
    x_scales = tensor_to_scale(x, e4m3_type, dim=1).reciprocal()
    y_scales = tensor_to_scale(y, e4m3_type, dim=0).reciprocal()
    x_fp8 = to_fp8_saturated(x / x_scales, e4m3_type)
    y_fp8 = to_fp8_saturated(y / y_scales, e4m3_type)

    def run_scaled_mm():
        torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)

    with tempfile.NamedTemporaryFile() as f:
        try:
            with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
                self.assertIsNone(torch._C._get_sm_carveout_experimental())
                run_scaled_mm()
                torch._C._set_sm_carveout_experimental(0)
                self.assertEqual(torch._C._get_sm_carveout_experimental(), 0)
                run_scaled_mm()
                torch._C._set_sm_carveout_experimental(66)
                self.assertEqual(torch._C._get_sm_carveout_experimental(), 66)
                run_scaled_mm()
                torch._C._set_sm_carveout_experimental(None)
                self.assertIsNone(torch._C._get_sm_carveout_experimental())
                run_scaled_mm()
        finally:
            # The carveout is global state: if an assertion above fails, do not
            # leave a non-default value behind for subsequent tests.
            torch._C._set_sm_carveout_experimental(None)

        prof.export_chrome_trace(f.name)
        # Close the trace file deterministically instead of leaking the handle.
        with open(f.name) as trace_file:
            trace_events = json.load(trace_file)["traceEvents"]

    # One kernel event per matmul, in launch order; math.prod folds the grid dims.
    no_carveout, carveout_0, carveout_66, no_carveout_again = [
        math.prod(evt.get("args", {}).get("grid", []))
        for evt in trace_events
        if evt.get("cat", "") == "kernel"
    ]

    self.assertEqual(no_carveout, no_carveout_again)
    self.assertNotEqual(no_carveout, carveout_66)
    self.assertNotEqual(carveout_66, carveout_0)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
@parametrize("test_case_name", [
    "a_eye_b_eye",
    "a_ones_b_ones",
    "a_ones_modified_b_ones",
    "a_ones_b_ones_modified",
    "a_scale_modified_b_ones",
    "a_ones_b_scale_modified",
    "data_random_scales_one",
    "data_random_scales_from_data",
])
@parametrize("fast_accum", [False, True])
@parametrize("mkn", [
    # Nice shapes
    (128, 128, 128),
    (256, 256, 256),
    (128, 256, 512),
    (256, 512, 128),
    (512, 128, 256),

    # Non block multiples
    (65, 96, 112),
    (197, 224, 272),
    # K not multiple of 32
    (197, 240, 272),

    # Very unbalanced
    (1023, 64, 48),
    (31, 1024, 64),
    (45, 96, 1024),

    # Mixed large and small
    (2, 1024, 128),
    (127, 96, 1024),
    (1025, 128, 96)
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
def test_blockwise_mxfp8_numerics(self, test_case_name, fast_accum, mkn) -> None:
    """Check MX block-scaled FP8 ``_scaled_mm`` against a bf16 reference.

    ``A`` is (M, K) and ``B`` is (N, K), both float8_e4m3fn. Each has one
    float8_e8m0fnu scale per row and per ``BLOCK_SIZE``-wide chunk of K; the
    scales are passed through ``to_blocked`` before the call. The cases are:

    * ``a_eye_b_eye``: identity operands (only defined for M == K == N).
    * ``a_ones_b_ones``: all-ones data with unit scales.
    * ``a_ones_modified_b_ones`` / ``a_ones_b_ones_modified``: the first block
      of row 1 of A (resp. B) is set to 2 in both data and reference.
    * ``a_scale_modified_b_ones`` / ``a_ones_b_scale_modified``: that block is
      stored as 2 with block scale 2, and the reference holds 4.
    * ``data_random_scales_one``: random e4m3fn bit patterns (NaNs zeroed)
      with unit scales.
    * ``data_random_scales_from_data``: random bf16 data with scales from
      ``data_to_mx_scale`` (only defined for K a multiple of BLOCK_SIZE).

    The deterministic cases require an exact match. The two random cases
    require ``compute_error(C_ref, C)`` to exceed 22.

    Raises:
        unittest.SkipTest: when the case is not defined for the given shape.
        ValueError: for an unknown ``test_case_name``.
    """
    # inspiration: https://github.com/pytorch/ao/pull/1625

    device = "cuda"
    M, K, N = mkn
    BLOCK_SIZE = 32
    require_exact_match = True

    def ceil_div(a, b):
        return (a + b - 1) // b

    def unit_scale(rows):
        # All-1.0 e8m0 scales: one per (row, BLOCK_SIZE-wide chunk of K).
        return torch.full(
            (rows, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu
        )

    def ones_data():
        # All-ones bf16 references for A (M, K) and B (N, K) plus their fp8 copies.
        a_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
        b_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
        return a_ref, b_ref, a_ref.to(torch.float8_e4m3fn), b_ref.to(torch.float8_e4m3fn)

    if test_case_name == "a_eye_b_eye":
        if not ((M == K) and (M == N)):
            # skipTest raises SkipTest; the former ``return unittest.skip(...)``
            # merely built an unused decorator and reported the test as passed.
            self.skipTest("this test is only defined for M == K == N, skipping")
        A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
        B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)

        A = A_ref.to(torch.float8_e4m3fn)
        B = B_ref.to(torch.float8_e4m3fn)

        A_scale, B_scale = unit_scale(M), unit_scale(N)

    elif test_case_name == "a_ones_b_ones":
        A_ref, B_ref, A, B = ones_data()
        A_scale, B_scale = unit_scale(M), unit_scale(N)

    elif test_case_name == "a_ones_modified_b_ones":
        A_ref, B_ref, A, B = ones_data()
        A_ref[1][0:BLOCK_SIZE] = 2
        A[1][0:BLOCK_SIZE] = 2
        A_scale, B_scale = unit_scale(M), unit_scale(N)

    elif test_case_name == "a_ones_b_ones_modified":
        A_ref, B_ref, A, B = ones_data()
        B_ref[1][0:BLOCK_SIZE] = 2
        B[1][0:BLOCK_SIZE] = 2
        A_scale, B_scale = unit_scale(M), unit_scale(N)

    elif test_case_name == "a_scale_modified_b_ones":
        A_ref, B_ref, A, B = ones_data()
        A_scale, B_scale = unit_scale(M), unit_scale(N)
        # Stored value 2 with block scale 2 must reproduce the reference value 4.
        A_ref[1][0:BLOCK_SIZE] = 4
        A[1][0:BLOCK_SIZE] = 2
        A_scale[1][0] = 2

    elif test_case_name == "a_ones_b_scale_modified":
        A_ref, B_ref, A, B = ones_data()
        A_scale, B_scale = unit_scale(M), unit_scale(N)
        # Stored value 2 with block scale 2 must reproduce the reference value 4.
        B_ref[1][0:BLOCK_SIZE] = 4
        B[1][0:BLOCK_SIZE] = 2
        B_scale[1][0] = 2

    elif test_case_name == "data_random_scales_one":
        require_exact_match = False
        # scales all-ones, element data random while being exactly representable in float8_e4m3fn

        # generate integers in [0, 255] and interpret as float8_e4m3fn
        # (randint's upper bound is exclusive, hence 256)
        A_ref = torch.randint(0, 256, (M, K), device=device, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.bfloat16)
        B_ref = torch.randint(0, 256, (N, K), device=device, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.bfloat16)
        # modification: don't allow NaN values
        A_ref[torch.isnan(A_ref)] = 0
        B_ref[torch.isnan(B_ref)] = 0

        A = A_ref.to(torch.float8_e4m3fn)
        B = B_ref.to(torch.float8_e4m3fn)

        A_scale, B_scale = unit_scale(M), unit_scale(N)

    elif test_case_name == "data_random_scales_from_data":
        if K % BLOCK_SIZE != 0:
            self.skipTest(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
        require_exact_match = False
        # random data, scales from data
        A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
        B_ref = torch.randn((N, K), device=device, dtype=torch.bfloat16) * 1000

        # Calculate scales based on the inputs
        A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE)
        B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE)

        max_val = F8E4M3_MAX_VAL
        min_val = -1 * max_val

        # Divide each BLOCK_SIZE-wide chunk by its scale, then saturate into e4m3fn range.
        A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(M * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(M, K)
        A = A.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
        B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
        B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)

    else:
        raise ValueError(f"unknown test_case_name: {test_case_name}")

    # convert to swizzled format (every case does this last)
    A_scale = to_blocked(A_scale)
    B_scale = to_blocked(B_scale)

    C_ref = A_ref @ B_ref.t()

    C = torch._scaled_mm(
        A,
        B.t(),
        A_scale,
        B_scale,
        out_dtype=torch.bfloat16,
        use_fast_accum=fast_accum,
    )

    if require_exact_match:
        torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
    else:
        sqnr = compute_error(C_ref, C)
        # assertGreater reports the offending value and is not stripped by -O.
        self.assertGreater(sqnr.item(), 22.0)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@skipIfRocm()
def test_blockwise_mxfloat8_error_messages(self, device) -> None:
    """Malformed MX block-scale tensors must be rejected with exact messages.

    Three invalid float8_e8m0fnu scale configurations are tried in turn:
    scale_a one element too small, scale_b one element too large, and a
    non-contiguous scale_a. Each ``_scaled_mm`` call must raise a
    ``RuntimeError`` whose message matches the expected text literally.
    """
    M, K, N = (1024, 512, 2048)
    BLOCK_SIZE_K = 32
    BLOCK_SIZE_MN = 128
    fill_value = 0.5

    lhs = torch.full((M, K), fill_value, device=device)
    rhs = torch.full((N, K), fill_value, device=device)
    x_fp8 = lhs.to(e4m3_type)
    y_fp8 = rhs.to(e4m3_type).t()

    def cdiv(numer, denom):
        # Ceiling division for a positive denominator.
        return -(-numer // denom)

    # Flat scale sizes: K blocks rounded up to a multiple of 4, rows rounded up
    # to whole BLOCK_SIZE_MN tiles.
    k_blocks_padded = cdiv(cdiv(K, BLOCK_SIZE_K), 4) * 4
    expected_a_size = BLOCK_SIZE_MN * cdiv(M, BLOCK_SIZE_MN) * k_blocks_padded
    expected_b_size = BLOCK_SIZE_MN * cdiv(N, BLOCK_SIZE_MN) * k_blocks_padded

    def e8m0_ones(numel):
        return torch.ones(numel, device=device, dtype=torch.float8_e8m0fnu)

    # (expected message, scale_a factory, scale_b factory). Factories are called
    # inside the assertRaisesRegex context, scale_a before scale_b.
    bad_scale_cases = (
        # Wrong scale_a size with the correct dtype.
        (
            f"For BlockWise scaling: Expected scale_a size to be {expected_a_size} "
            f"but got {expected_a_size - 1}",
            lambda: e8m0_ones(expected_a_size - 1),
            lambda: e8m0_ones(expected_b_size),
        ),
        # Wrong scale_b size with the correct dtype.
        (
            f"For BlockWise scaling: Expected scale_b size to be {expected_b_size} "
            f"but got {expected_b_size + 1}",
            lambda: e8m0_ones(expected_a_size),
            lambda: e8m0_ones(expected_b_size + 1),
        ),
        # Non-contiguous scale_a (every other element of a double-length buffer).
        (
            "For BlockWise scaling: Both scale_a and scale_b must be contiguous",
            lambda: e8m0_ones(expected_a_size * 2)[::2],
            lambda: e8m0_ones(expected_b_size),
        ),
    )

    for expected_msg, make_scale_a, make_scale_b in bad_scale_cases:
        with self.assertRaisesRegex(RuntimeError, re.escape(expected_msg)):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=make_scale_a(),
                scale_b=make_scale_b(),
                out_dtype=torch.bfloat16,
            )
|
|
|
|
|
|
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
class TestMixedDtypesLinearCuda(TestCase):
    @dtypes(torch.float16, torch.bfloat16)
    def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"):
        """Compare ``aten._mixed_dtypes_linear`` against dense references.

        The weight is int8 (packed via ``pack_int4_to_int8`` for
        ``torch.quint4x2``) with a per-output-channel scale of ``dtype``. For
        each configuration two paths are checked: a plain scaled matmul
        (weight reordered with ``transpose=False``) and the full linear
        operator with optional bias and activation (``transpose=True``).
        """
        version = _get_torch_cuda_version()
        if version < (11, 8):
            self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+")

        def run_test(
            batch_shape,
            m,
            n,
            k,
            add_bias,
            activation,
            dtype,
            dtypeq,
            device,
            rtol,
            atol,
        ):
            # Activations are only exercised together with a bias; without a
            # bias only the "no activation" case runs. "No activation" may be
            # spelled None (as in ``activations`` below) or "none", so accept
            # both. Previously ``None != "none"`` skipped every no-bias case.
            if not add_bias and activation not in (None, "none"):
                return

            val_lo, val_hi = -1, 1
            valq_lo, valq_hi = -2, 2
            inp = make_tensor(
                *batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device
            )
            weight = make_tensor(
                n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device
            )
            scale = make_tensor(
                (n,), low=val_lo, high=val_hi, dtype=inp.dtype, device=device
            )
            bias = (
                make_tensor(
                    (n,), low=val_lo, high=val_hi, dtype=inp.dtype, device=device
                )
                if add_bias
                else None
            )

            # Collapse batch dims so the references can use 2-D mm/linear.
            input_ref = inp.reshape(-1, inp.shape[-1])

            # First, test plain multiplication.
            weight_ref = weight.T.to(inp.dtype) * scale.view(1, n)
            weightq = (
                pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T
            )
            output_ref = torch.mm(input_ref, weight_ref).reshape(*inp.shape[:-1], n)
            output = torch.ops.aten._mixed_dtypes_linear(
                inp,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=False
                ),
                scale,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

            # Second, test the linear operator itself.
            weight_ref = weight.to(inp.dtype) * scale.view(n, 1)
            weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight
            bias_ref = bias.view(1, n) if add_bias else None
            output_ref = torch.nn.functional.linear(
                input_ref, weight_ref, bias=bias_ref
            ).reshape(*inp.shape[:-1], n)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output_ref = relu(output_ref)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output_ref = silu(output_ref)
            output = torch.ops.aten._mixed_dtypes_linear(
                inp,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=True
                ),
                scale,
                bias=bias,
                activation=activation,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

        dtypeqs = [torch.int8, torch.quint4x2]
        batch_shapes = [[], [2], [2, 1]]
        shapes = [
            [8, 64, 64],
            [8, 64, 128],
            [8, 128, 64],
            [8, 128, 128],
            [8, 128, 192],
            [8, 128, 256],
            [8, 256, 128],
            [8, 256, 384],
            [8, 384, 256],
        ]
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 1e-2, 1e-3
        for dtypeq, batch_shape, (m, n, k), add_bias, activation in product(
            dtypeqs, batch_shapes, shapes, (False, True), activations
        ):
            run_test(
                batch_shape,
                m,
                n,
                k,
                add_bias,
                activation,
                dtype,
                dtypeq,
                device,
                rtol,
                atol,
            )
|
|
|
|
# Generate device-specific variants of each generic test class, excluding CPU.
# globals() is passed so the generated classes are placed in this module's
# namespace (presumably where the test runner discovers them; confirm in
# common_device_type).
instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")

if __name__ == '__main__':
    # NOTE(review): presumably enables TestCase's check that tests leave the
    # default dtype unchanged; verify in common_utils.
    TestCase._default_dtype_check_enabled = True
    run_tests()
|