pytorch/test/test_matmul_cuda.py
Jeff Daily 96793e0f10 [ROCm] enable scaled_gemm (#117822)
scaled_gemm for ROCm using hipblaslt.  As of ROCm 6.0, HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER is not supported.  A work-around is provided, performing the absmax operation on the output buffer, but this results in some loss of accuracy for the absmax result.  For this reason the feature should be considered beta/preview.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117822
Approved by: https://github.com/jianyuh, https://github.com/xw285cornell
2024-02-29 10:20:48 +00:00

477 lines
20 KiB
Python

# Owner(s): ["module: linear algebra"]
import unittest
from itertools import product
from functools import partial
from typing import Optional
import torch
from torch.quantization._quantized_conversions import (
pack_int4_to_int8,
quantized_weight_reorder_for_mixed_dtypes_linear_cutlass,
)
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import SM53OrLater, _get_torch_cuda_version
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
onlyCUDA,
tol as xtol,
toleranceOverride,
)
from torch.testing._internal.common_utils import (
IS_ARM64,
IS_JETSON,
IS_WINDOWS,
parametrize,
run_tests,
skipIfRocmVersionLessThan,
TEST_WITH_ROCM,
skipIfRocm,
TestCase,
)
# True only when CUDA is usable and device 0 reports a compute-capability
# major version of 8; False otherwise (including when no CUDA device exists).
_IS_SM8X = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] == 8

# Guard against an import having changed the global default dtype as a side effect.
assert torch.get_default_dtype() is torch.float32
@unittest.skipIf(IS_ARM64, "Issue with numpy version on arm")
class TestMatmulCuda(TestCase):
    """Accuracy checks for CUDA ``addmm`` / ``linear`` / ``baddbmm``.

    TF32 matmuls are disabled for the duration of every test and the previous
    setting is put back afterwards.
    """

    def setUp(self):
        # NOTE(review): the two-argument ``super(self.__class__, self)`` form is
        # kept on purpose. Presumably ``instantiate_device_type_tests`` copies
        # this method onto generated device-specific classes, where zero-argument
        # ``super()`` would not resolve -- confirm before "modernizing" it.
        super(self.__class__, self).setUp()
        # Remember the caller's TF32 setting so tearDown can restore it instead
        # of forcing a hard-coded value onto later tests.
        self._orig_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        torch.backends.cuda.matmul.allow_tf32 = self._orig_allow_tf32
        super(self.__class__, self).tearDown()

    def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False):
        """Compare ``torch.addmm`` computed on CUDA against the CPU result.

        Checks for catastrophic cuBLAS inaccuracy by measuring the deviation
        between the CUDA invocation of ``torch.addmm`` and the CPU invocation
        (which does not use the CUDA backend).

        Args:
            size: base dimension; the operands are ``(size + 1, size)`` and
                ``(size, size + 2)``, added to a ``(size + 1, size + 2)`` input.
            dtype: dtype of the operands passed to the CUDA ``addmm``.
            reduced_precision: value assigned to both the fp16 and bf16
                ``allow_*_reduced_precision_reduction`` backend flags while the
                comparison runs.

        The two backend flags are restored to their previous values on exit,
        even when the comparison fails.
        """
        # Get dims
        n, m, p = (size + 1, size, size + 2)
        # Disable reduced precision reductions in BFloat16 to bypass some kernels
        # which fail the threshold check
        orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
        orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision
        try:
            # Make random tensors on CPU (seed set on common_utils.py import)
            # (Not using numpy because it does not support bfloat16)
            make_arg = partial(make_tensor, dtype=dtype, device="cpu")
            m_beta = make_arg(1)
            m_input = make_arg((n, p))
            m_1 = make_arg((n, m))
            m_2 = make_arg((m, p))
            # *(B)FLOAT16 Special Handling*
            # Backend does not tensorize float16 on CPU,
            # and bfloat16 may present accuracy issues,
            # so convert to float32 for these cases
            # (but keep same for other types, e.g. float32 and int*)
            if dtype == torch.float16 or dtype == torch.bfloat16:
                m_beta = m_beta.to(dtype=torch.float32)
                m_input = m_input.to(dtype=torch.float32)
                m_1 = m_1.to(dtype=torch.float32)
                m_2 = m_2.to(dtype=torch.float32)
            # Get CPU result
            res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
            # *(B)FLOAT16 Special Handling*
            # Convert back to (b)float16
            if dtype == torch.float16 or dtype == torch.bfloat16:
                m_beta = m_beta.to(dtype=dtype)
                m_input = m_input.to(dtype=dtype)
                m_1 = m_1.to(dtype=dtype)
                m_2 = m_2.to(dtype=dtype)
                res_cpu = res_cpu.to(dtype=dtype)
            # Move arg tensors to CUDA
            m_beta = m_beta.to("cuda")
            m_input = m_input.to("cuda")
            m_1 = m_1.to("cuda")
            m_2 = m_2.to("cuda")
            # Get CUDA result
            res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
            # Move to CPU for comparison
            res_cuda = res_cuda.to("cpu")
            # Compare
            self.assertEqual(res_cpu, res_cuda)
        finally:
            # Always undo the global flag changes so a failing comparison does
            # not leak modified backend state into later tests.
            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1),
                        torch.bfloat16: xtol(atol=1e-1, rtol=1e-1),
                        torch.float32: xtol(atol=1e-1, rtol=1e-1)})
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
        """Run the CPU-vs-CUDA addmm comparison with reduced-precision reductions off."""
        self.cublas_addmm(size, dtype, False)

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
        """Run the CPU-vs-CUDA addmm comparison with reduced-precision reductions on."""
        self.cublas_addmm(size, dtype, True)

    @onlyCUDA
    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
    @dtypes(torch.float16)
    def test_cublas_addmm_alignment(self, dtype):
        """Check ``linear`` against ``matmul + bias`` when one operand starts at an offset.

        For each of X, A and B in turn, the tensor is carved out of a slightly
        larger flat buffer starting 1 or 2 elements in, so its storage offset
        differs from that of a freshly allocated tensor.
        """
        device = 'cuda'
        # perturb X, A, or B alignment
        for idx in range(0, 3):
            for offset in range(1, 3):
                offsets = [0, 0, 0]
                offsets[idx] = offset
                x_offset, a_offset, b_offset = offsets
                A = torch.rand((5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device)
                A = A[a_offset:].reshape(5120, 2560)
                X = torch.rand((26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device)
                X = X[x_offset:].reshape(26, 1, 2560)
                B = torch.rand((5120 + b_offset), requires_grad=True, dtype=dtype, device=device)
                B = B[b_offset:].reshape(5120)
                out = torch.nn.functional.linear(X, A, B)
                self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B)

    @onlyCUDA
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)})
    # The conditional must be parenthesised: ``+`` binds tighter than
    # ``... if ... else ...``, so without the parentheses the whole dtype list
    # would collapse to ``[]`` whenever the condition is false.
    @dtypes(*([torch.float32, torch.float16] +
              ([torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])))
    @parametrize(
        "batch_size, N, M, P",
        [(2, 100, 100, 100),
         (2, 1000, 1000, 1000),
         (1, 10000, 1000, 10000),
         (1, 10000, 10000, 10000)],
        name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}",
    )
    @skipIfRocm
    def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype):
        """Compare ``linear`` and ``baddbmm`` on the device against CPU references.

        For float16/bfloat16 the CPU reference is computed in float32 and cast
        back to ``dtype`` before comparison. When the matrices are square the
        product with an identity matrix is additionally checked to reproduce
        the left operand.
        """
        cpu_dtype = dtype
        if dtype == torch.float16 or dtype == torch.bfloat16:
            cpu_dtype = torch.float32

        M1 = torch.rand((N, M), device=device, dtype=dtype)
        M2 = torch.rand((M, P), device=device, dtype=dtype)
        A = torch.rand((N, P), device=device, dtype=dtype)

        def _convert_to_cpu(t):
            return t.to(device='cpu', dtype=cpu_dtype)
        M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A])

        # linear
        out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype)
        out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu()
        self.assertEqual(out1_cpu, out1_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype)
            out1_eye_gpu = torch.nn.functional.linear(M1, M2_eye.t(), torch.zeros_like(A))
            self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu())

        # baddbmm
        def _expand_to_batch(t: torch.Tensor):
            return t.expand((batch_size, ) + t.size())
        alpha, beta = 1.0, 1.0
        M1, M2, A, M1_cpu, M2_cpu, A_cpu = map(_expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu])

        out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to(dtype=dtype)
        out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu()
        self.assertEqual(out2_cpu, out2_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N)
            out2_eye_gpu = torch.baddbmm(torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha)
            self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu())

        # cross comparison
        self.assertEqual(out1_gpu, out2_gpu[0])
f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
if torch.version.hip:
e4m3_type = torch.float8_e4m3fnuz
e5m2_type = torch.float8_e5m2fnuz
else:
e4m3_type = torch.float8_e4m3fn
e5m2_type = torch.float8_e5m2
def scaled_mm_supported_device():
    """Return whether the current CUDA/ROCm device qualifies for the FP8 tests.

    Returns ``False`` when CUDA is unavailable. On HIP builds the answer is
    whether device 0's ``gcnArchName`` contains ``'gfx94'``; otherwise it is
    whether the device capability is at least (9, 0) or exactly (8, 9).
    """
    if not torch.cuda.is_available():
        return False
    if torch.version.hip:
        arch_name = torch.cuda.get_device_properties(0).gcnArchName
        return 'gfx94' in arch_name
    capability = torch.cuda.get_device_capability()
    return capability >= (9, 0) or capability == (8, 9)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
class TestFP8MatmulCuda(TestCase):
    """Tests for ``torch._scaled_mm`` with float8 operands.

    Most tests are skipped unless ``scaled_mm_supported_device()`` is true;
    ``test_error_message_fp8_pre_sm89`` runs only when it is false.
    """

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def _test_tautological_mm(self, device: str = "cuda",
                              x_dtype: torch.dtype = e4m3_type,
                              y_dtype: torch.dtype = e4m3_type,
                              out_dtype: Optional[torch.dtype] = None,
                              size: int = 16) -> None:
        """Multiply a random fp8 matrix by an fp8 identity and compare to float32.

        Args:
            device: device on which the operands are created.
            x_dtype: dtype the random left operand is cast to.
            y_dtype: dtype of the identity right operand.
            out_dtype: forwarded to ``torch._scaled_mm``; when not ``None`` the
                output dtype is asserted to match it.
            size: side length of the square operands.
        """
        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
        # Reference product computed in float32 from the already-quantized operands.
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        (out_fp8, amax_fp8) = torch._scaled_mm(x_fp8, y_fp8, out_dtype=out_dtype)
        if out_dtype is not None:
            self.assertEqual(out_dtype, out_fp8.dtype)
        # The returned amax is only compared when the output is not one of
        # float16 / bfloat16 / float32.
        if out_dtype not in [torch.float16, torch.bfloat16, torch.float]:
            self.assertEqual(out_fp32.amax(), amax_fp8)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def test_float8_basics(self, device) -> None:
        """Exercise supported and unsupported input/output dtype combinations."""
        self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
        # hipblaslt does not yet support mixed e4m3_type input
        if torch.version.hip is None:
            self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
            self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
        # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, e5m2_type, e5m2_type)

        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
        # hipblaslt does not yet support bfloat16 output
        if torch.version.hip is None:
            self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, out_dtype=e5m2_type)

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def test_float8_scale(self, device) -> None:
        """Check a constant 0.5 x 0.5 product, with and without input scales."""
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8, amax_fp8 = torch._scaled_mm(x, y)
        # 16 terms of 0.5 * 0.5 sum to 4.0 in every output element.
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s, amax_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def test_float8_bias(self, device) -> None:
        """Check that adding a constant bias of 4.0 shifts the reported amax by 4.0."""
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
        out_fp8, amax_fp8 = torch._scaled_mm(x, y)
        outb_fp8, amaxb_fp8 = torch._scaled_mm(x, y, bias=bias)
        # this fails on ROCm currently because hipblaslt doesn't have amax op
        if torch.version.hip is None:
            self.assertEqual((amaxb_fp8 - amax_fp8).item(), 4.0)

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    @parametrize("bias", [True, False])
    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
        """Smoke test: a (17, 16) left operand must be accepted, with or without bias.

        No values are asserted; the test passes when the call completes and
        returns a two-element result.
        """
        x = torch.rand((17, 16), device=device).to(e4m3_type)
        y = torch.rand((16, 16), device=device).to(e4m3_type).t()
        input_bias = None
        if bias:
            input_bias = torch.rand((16,), device=device).to(torch.half)
        out_fp8, amax_fp8 = torch._scaled_mm(x, y, bias=input_bias)

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def test_float8_bias_relu_edgecase(self, device) -> None:
        """With an all-zero left operand and a bias of -3.0, the reported amax must be 3.0."""
        (k, l, m) = (16, 48, 32)
        x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
        y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
        outb_fp8, amaxb_fp8 = torch._scaled_mm(x, y, bias=bias)
        self.assertEqual(amaxb_fp8.item(), 3.0)

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def test_float32_output_errors_with_bias(self, device) -> None:
        """Requesting float32 output together with a bias must raise ``RuntimeError``."""
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
        with self.assertRaisesRegex(
            RuntimeError,
            "Bias is not supported when out_dtype is set to Float32",
        ):
            torch._scaled_mm(x, y, bias=bias, out_dtype=torch.float32)

    @unittest.skipIf(scaled_mm_supported_device(),
                     "This test is only for devices with compute capability < 8.9")
    def test_error_message_fp8_pre_sm89(self, device) -> None:
        """On devices that do not qualify, ``_scaled_mm`` must raise the documented error."""
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.rand((m, l), device=device).to(e4m3_type).t()
        with self.assertRaisesRegex(
            RuntimeError,
            r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
        ):
            torch._scaled_mm(x, y, out_dtype=torch.float32)

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def test_float8_scale_fast_accum(self, device) -> None:
        """Same checks as ``test_float8_scale`` but with ``use_fast_accum=True``."""
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8, amax_fp8 = torch._scaled_mm(x, y, use_fast_accum=True)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s, amax_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8, out_fp8_s)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
class TestMixedDtypesLinearCuda(TestCase):
    """Tests for ``torch.ops.aten._mixed_dtypes_linear`` with int8 / packed-int4 weights."""

    @dtypes(torch.float16, torch.bfloat16)
    def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"):
        """Compare ``_mixed_dtypes_linear`` against a dense reference.

        Sweeps quantized weight dtype, batch shape, (m, n, k), bias on/off and
        activation. The test is skipped when the CUDA version is below 11.8.
        """
        version = _get_torch_cuda_version()
        if version < (11, 8):
            self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+")

        def run_test(
            batch_shape,
            m,
            n,
            k,
            add_bias,
            activation,
            dtype,
            dtypeq,
            device,
            rtol,
            atol,
        ):
            # An activation is only exercised together with a bias. "No
            # activation" is spelled ``None`` in the sweep below, so the guard
            # must test for ``None``; comparing against the string "none" made
            # every bias-free case return early and left that path untested.
            if not add_bias and activation is not None:
                return

            val_lo, val_hi = -1, 1
            valq_lo, valq_hi = -2, 2
            inp = make_tensor(
                *batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device
            )
            weight = make_tensor(
                n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device
            )
            scale = make_tensor(
                (n,), low=val_lo, high=val_hi, dtype=inp.dtype, device=device
            )
            bias = (
                make_tensor(
                    (n,), low=val_lo, high=val_hi, dtype=inp.dtype, device=device
                )
                if add_bias
                else None
            )

            # Flatten all leading (batch) dimensions for the dense reference.
            inp_ref = inp.reshape(-1, inp.shape[-1])

            # First, test plain multiplication.
            weight_ref = weight.T.to(inp.dtype) * scale.view(1, n)
            weightq = (
                pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T
            )
            output_ref = torch.mm(inp_ref, weight_ref).reshape(*inp.shape[:-1], n)
            output = torch.ops.aten._mixed_dtypes_linear(
                inp,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=False
                ),
                scale,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

            # Second, test the linear operator itself.
            weight_ref = weight.to(inp.dtype) * scale.view(n, 1)
            weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight
            bias_ref = bias.view(1, n) if add_bias else None
            output_ref = torch.nn.functional.linear(
                inp_ref, weight_ref, bias=bias_ref
            ).reshape(*inp.shape[:-1], n)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output_ref = relu(output_ref)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output_ref = silu(output_ref)
            output = torch.ops.aten._mixed_dtypes_linear(
                inp,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=True
                ),
                scale,
                bias=bias,
                activation=activation,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

        dtypeqs = [torch.int8, torch.quint4x2]
        batch_shapes = [[], [2], [2, 1]]
        shapes = [
            [8, 64, 64],
            [8, 64, 128],
            [8, 128, 64],
            [8, 128, 128],
            [8, 128, 192],
            [8, 128, 256],
            [8, 256, 128],
            [8, 256, 384],
            [8, 384, 256],
        ]
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            # bfloat16 uses a looser relative tolerance than float16.
            rtol, atol = 1e-2, 1e-3
        for dtypeq, batch_shape, (m, n, k), add_bias, activation in product(
            dtypeqs, batch_shapes, shapes, (False, True), activations
        ):
            run_test(
                batch_shape,
                m,
                n,
                k,
                add_bias,
                activation,
                dtype,
                dtypeq,
                device,
                rtol,
                atol,
            )
# Generate the device-specific variants of each generic test class, in the
# same order as before and with the CPU variant excluded.
for _generic_cls in (TestMatmulCuda, TestFP8MatmulCuda, TestMixedDtypesLinearCuda):
    instantiate_device_type_tests(_generic_cls, globals(), except_for="cpu")
# Remove the loop variable so no extra module-level name is left pointing at a
# test class.
del _generic_cls

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()