Add torch._scaled_mm for CPU (#139975)

This PR is to add `torch._scaled_mm` for CPU backend.

`_scaled_mm_out_cpu` and `_scaled_mm_cpu` are new added and included in `torch._scaled_mm` CPU dispatch. We also add `_scaled_mm_out_cpu_emulated` as a fallback function if the current platform cannot run FP8 matmul using oneDNN. And this PR also updates the various UTs related to FP8 to support CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139975
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/malfet
This commit is contained in:
Jiang, Yanbing 2024-12-27 06:48:25 +00:00 committed by PyTorch MergeBot
parent d3e9133ab2
commit cbc4cf3043
11 changed files with 908 additions and 574 deletions

View File

@ -7,6 +7,11 @@
#include <ATen/Config.h>
#include <ATen/native/mkldnn/Matmul.h>
#include <ATen/native/mkldnn/Linear.h>
#include <ATen/native/Resize.h>
#if !defined(__s390x__) && !defined(__powerpc__)
#include <cpuinfo.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/CPUFunctions.h>
@ -24,6 +29,9 @@
#include <ATen/ops/mv_native.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/matmul.h>
#endif
namespace at::meta {
@ -222,4 +230,79 @@ Tensor vdot(const Tensor &self, const Tensor &other){
}
static Tensor&
_scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
" but got ", bias->numel());
// Check types
TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
auto mat1_c = mat1.contiguous();
auto mat2_c = mat2.contiguous();
IntArrayRef mat1_sizes = mat1_c.sizes();
IntArrayRef mat2_sizes = mat2_c.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
float input_scale = scale_a.item<float>();
float weight_scale = scale_b.item<float>();
auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale);
auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale);
auto out_tmp = at::matmul(fp32_mat1, fp32_mat2);
if (bias) {
out_tmp.add_(bias.value());
}
out_tmp = out_tmp.to(out.scalar_type());
out.copy_(out_tmp);
return out;
}
Tensor&
_scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
#if AT_MKLDNN_ENABLED()
if (at::globalContext().userEnabledMkldnn() && cpuinfo_has_x86_amx_int8()) {
return mkldnn_scaled_mm(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
} else
#endif
{
return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
}
}
Tensor
_scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
}
} // namespace at::native

View File

@ -4,6 +4,7 @@
#include <ATen/core/Tensor.h>
#include <torch/library.h>
#include <ATen/native/mkldnn/Linear.h>
#include <ATen/native/Resize.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -46,6 +47,18 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support");
}
Tensor&
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support");
}
} // namespace native
} // namespace at
@ -447,6 +460,119 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
TORCH_FN(mkldnn_linear_pointwise_binary));
}
Tensor&
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
" but got ", bias->numel());
// Check types
TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
// TODO: This check of mat1 and mat2 must have the same data type will be removed after oneDNN v3.6.
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "Expected mat1 and mat2 must have the same data type");
// Validation checks have passed lets resize the output to actual size
auto mat1_c = mat1.contiguous();
auto mat2_c = mat2.contiguous();
IntArrayRef mat1_sizes = mat1_c.sizes();
IntArrayRef mat2_sizes = mat2_c.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
float input_scale = scale_a.item<float>();
float weight_scale = scale_b.item<float>();
auto src = at::native::itensor_view_from_dense(mat1_c);
auto weight_t = at::native::itensor_view_from_dense(mat2_c);
bool with_bias = bias.has_value();
int64_t K = mat1_sizes[1], M = mat1_sizes[0],
N = mat2_sizes[1];
std::vector<int64_t> src_dims = {M, K};
std::vector<int64_t> weight_dims = {K, N};
std::vector<int64_t> dst_dims = {M, N};
ideep::tensor dst = at::native::itensor_view_from_dense(out);
auto src_desc = ideep::tensor::desc(
src_dims,
get_mkldnn_dtype(mat1.scalar_type()),
ideep::format_tag::any);
auto weights_desc = ideep::tensor::desc(
weight_dims,
get_mkldnn_dtype(mat2.scalar_type()),
ideep::format_tag::any);
auto dst_desc = ideep::tensor::desc(
dst_dims,
get_mkldnn_dtype(out.scalar_type()),
ideep::format_tag::any);
ideep::tensor onednn_bias;
if (with_bias) {
auto bias_value = bias.value();
if (bias_value.dim() == 1) {
auto b_reshape = bias_value.reshape({1, bias_value.size(0)});
onednn_bias = at::native::itensor_view_from_dense(b_reshape);
} else {
onednn_bias = at::native::itensor_view_from_dense(bias_value);
}
}
auto bias_desc = ideep::tensor::desc();
if (with_bias) {
bias_desc = ideep::tensor::desc(onednn_bias.get_dims(),
get_mkldnn_dtype(bias.value().scalar_type()),
ideep::format_tag::any);
}
auto op_attr = ideep::attr_t();
if (input_scale != 1.0f) {
op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
}
if (weight_scale != 1.0f) {
op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
}
op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
auto engine = ideep::engine::cpu_engine();
dnnl::matmul::primitive_desc primitive_desc = with_bias
? dnnl::matmul::primitive_desc(
engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr)
: dnnl::matmul::primitive_desc(
engine, src_desc, weights_desc, dst_desc, op_attr);
auto primitive = dnnl::matmul(primitive_desc);
// Prepare args and execute primitive
ideep::tensor scratchpad(primitive_desc.scratchpad_desc());
ideep::exec_args args;
args.insert({DNNL_ARG_SRC, src});
args.insert({DNNL_ARG_WEIGHTS, weight_t});
args.insert({DNNL_ARG_DST, dst});
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
if (with_bias) {
args.insert({DNNL_ARG_BIAS, onednn_bias});
}
ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale));
ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale));
if (input_scale != 1.0f) {
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
}
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});
primitive.execute(ideep::stream::default_stream(), args);
return out;
}
} // namespace at
#endif // AT_MKLDNN_ENABLED

View File

@ -35,3 +35,15 @@ C10_API Tensor mkl_linear(
} // namespace at
#endif // AT_MKLDNN_ENABLED()
namespace at::native {
Tensor&
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out);
} // namespace at::native

View File

@ -57,6 +57,10 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
return ideep::tensor::data_type::bf16;
case ScalarType::Half:
return ideep::tensor::data_type::f16;
case ScalarType::Float8_e4m3fn:
return ideep::tensor::data_type::f8_e4m3;
case ScalarType::Float8_e5m2:
return ideep::tensor::data_type::f8_e5m2;
default:
TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type");
}
@ -161,8 +165,24 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data
const_cast<void*>(tensor.const_data_ptr()) :
tensor.data_ptr()};
}
else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
return {{tensor.sizes().vec(),
ideep::tensor::data_type::f8_e4m3,
tensor.strides().vec()},
from_const_data_ptr ?
const_cast<void*>(tensor.const_data_ptr()) :
tensor.data_ptr()};
}
else if (tensor.scalar_type() == ScalarType::Float8_e5m2) {
return {{tensor.sizes().vec(),
ideep::tensor::data_type::f8_e5m2,
tensor.strides().vec()},
from_const_data_ptr ?
const_cast<void*>(tensor.const_data_ptr()) :
tensor.data_ptr()};
}
else {
TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input");
TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8/fp8 tensor input");
}
}

View File

@ -7071,11 +7071,13 @@
- func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CPU: _scaled_mm_cpu
CUDA: _scaled_mm_cuda
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
CPU: _scaled_mm_out_cpu
CUDA: _scaled_mm_out_cuda
# NOTE [ Sparse: autograd and API ]

View File

@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import (
parametrize,
TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.utils._triton import has_triton_tma_device
@ -89,10 +89,10 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
@instantiate_parametrized_tests
class TestFP8Types(TestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@unittest.skipIf(TEST_WITH_ROCM, "Not supported yet")
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
def test_xblock_for_small_numel(self, float8_dtype: torch.dtype):
@parametrize("device", ("cuda", "cpu"))
def test_xblock_for_small_numel(self, float8_dtype: torch.dtype, device: str):
"""
TritonOverrides.to_dtype will set min_elem_per_thread to 2 or 4
depends on the variant of fp8 type.
@ -101,19 +101,23 @@ class TestFP8Types(TestCase):
We should not pick a XBLOCK larger than xnumel
"""
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
def f(x):
return x.to(dtype=float8_dtype)
x = torch.randn(1, device="cuda")
x = torch.randn(1, device=device)
expected = f(x)
actual = torch.compile(f)(x)
torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@unittest.skipIf(TEST_WITH_ROCM, "Not supported yet")
@parametrize("dtype", (torch.float16, torch.bfloat16))
def test_eager_fallback(self, dtype: torch.dtype):
@parametrize("device", ("cuda", "cpu"))
def test_eager_fallback(self, dtype: torch.dtype, device: torch.device):
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
weight_shape = (32, 16)
e4m3_type = (
@ -121,11 +125,11 @@ class TestFP8Types(TestCase):
)
def fp8_matmul_unwrapped(x):
a_scale = torch.Tensor([1.0]).to(device="cuda")
b_scale = torch.Tensor([1.0]).to(device="cuda")
a_scale = torch.Tensor([1.0]).to(device=device)
b_scale = torch.Tensor([1.0]).to(device=device)
output_scale = None
input_bias = torch.rand(32, device="cuda", dtype=dtype)
weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T.to(
input_bias = torch.rand(32, device=device, dtype=dtype)
weight = torch.rand(*weight_shape, device=device, dtype=dtype).T.to(
e4m3_type
)
a_inverse_scale = 1 / a_scale
@ -146,14 +150,13 @@ class TestFP8Types(TestCase):
)
x_shape = (16, 16)
x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type)
y_fp8 = compiled_fp8_matmul(x) # noqa: F841
x_shape = (15, 16)
x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type)
y_fp8 = compiled_fp8_matmul(x) # noqa: F841
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
@parametrize("shape", ("15,3,13", "4,2048,4096"))
@parametrize(
@ -162,7 +165,12 @@ class TestFP8Types(TestCase):
if torch.version.hip is None
else [(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)],
)
def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple):
@parametrize("device", ("cuda", "cpu"))
def test_valid_cast(
self, dtype: torch.dtype, shape: str, dst_types: tuple, device: torch.device
):
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
e4m3, e5m2 = dst_types
def fp8_cast(x):
@ -173,7 +181,7 @@ class TestFP8Types(TestCase):
compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)
shape = [int(dim) for dim in shape.split(",")]
x = torch.rand(*shape, device="cuda", dtype=dtype)
x = torch.rand(*shape, device=device, dtype=dtype)
y0_fp8, y1_fp8 = compiled_fp8_cast(x)
torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1)
@ -202,7 +210,6 @@ class TestFP8Types(TestCase):
x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2)
compiled_fp8_cast(x, torch.float8_e4m3fn)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
@parametrize(
"dst_dtype",
@ -211,9 +218,17 @@ class TestFP8Types(TestCase):
else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
)
@parametrize("shape", ("16,16,16", "4,2048,4096"))
@parametrize("device", ("cuda", "cpu"))
def test_to_fp8_saturated(
self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
self,
src_dtype: torch.dtype,
dst_dtype: torch.dtype,
shape: str,
device: torch.device,
):
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
def fp8_saturated(x, dtype):
return _to_fp8_saturated(x, dtype)
@ -221,14 +236,13 @@ class TestFP8Types(TestCase):
fp8_saturated, backend="inductor", dynamic=True
)
shape = [int(dim) for dim in shape.split(",")]
x = torch.rand(*shape, device="cuda", dtype=src_dtype)
x = torch.rand(*shape, device=device, dtype=src_dtype)
y_compiled = compiled_fp8_cast(x, dst_dtype)
y = fp8_saturated(x, dst_dtype)
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
@parametrize(
"float8_dtype",
(torch.float8_e4m3fn, torch.float8_e5m2)
@ -236,7 +250,12 @@ class TestFP8Types(TestCase):
else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
)
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
@parametrize("device", ("cuda", "cpu"))
def test_amax_fp8_quant(
self, float8_dtype: torch.dtype, shape: str, device: torch.device
):
if device == "cuda" and not SM90OrLater:
raise unittest.SkipTest("FP8 is only supported on H100+")
shape = [int(dim) for dim in shape.split(",")]
batch_size, sequence_length, hidden_size = shape
@ -249,15 +268,14 @@ class TestFP8Types(TestCase):
compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")
x_shape = (batch_size, sequence_length, hidden_size)
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
x = torch.rand(*x_shape, device=device, dtype=torch.half)
scale = torch.tensor(0.2, device=device, dtype=torch.float)
y_compiled = compiled_amax_fp8_quant(x, scale)
y = amax_fp8(x, scale)
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize(
"float8_dtype",
(torch.float8_e4m3fn, torch.float8_e5m2)
@ -265,7 +283,12 @@ class TestFP8Types(TestCase):
else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
)
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
@parametrize("device", ("cuda", "cpu"))
def test_amax_along_with_fp8_quant(
self, float8_dtype: torch.dtype, shape: str, device: torch.device
):
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
shape = [int(dim) for dim in shape.split(",")]
batch_size, sequence_length, hidden_size = shape
@ -278,12 +301,12 @@ class TestFP8Types(TestCase):
compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")
x_shape = (batch_size, sequence_length, hidden_size)
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
scale = torch.tensor(1.0, device="cuda", dtype=torch.float)
x = torch.rand(*x_shape, device=device, dtype=torch.half)
scale = torch.tensor(1.0, device=device, dtype=torch.float)
amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half)
y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled)
amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
amax_buffer = torch.zeros((1), device=device, dtype=torch.half)
y = amax_fp8(x, scale, amax_buffer)
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
@ -292,7 +315,6 @@ class TestFP8Types(TestCase):
)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
@parametrize(
"float8_dtype",
(torch.float8_e4m3fn, torch.float8_e5m2)
@ -301,9 +323,16 @@ class TestFP8Types(TestCase):
)
@parametrize("amax_keep_dim", (True, False))
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
@parametrize("device", ("cuda", "cpu"))
def test_layernorm_fp8_quant(
self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: str
self,
float8_dtype: torch.dtype,
amax_keep_dim: bool,
shape: str,
device: torch.device,
):
if device == "cuda" and not SM90OrLater:
raise unittest.SkipTest("FP8 is only supported on H100+")
shape = [int(dim) for dim in shape.split(",")]
batch_size, sequence_length, hidden_size = shape
@ -325,12 +354,12 @@ class TestFP8Types(TestCase):
compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")
x_shape = (batch_size, sequence_length, hidden_size)
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
x = torch.rand(*x_shape, device=device, dtype=torch.half)
scale = torch.tensor(0.2, device=device, dtype=torch.float)
amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half)
y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
amax_buffer = torch.zeros((1), device=device, dtype=torch.half)
y = ln_fp8(x, scale, amax_buffer)
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
@ -748,5 +777,5 @@ class TestFP8Lowering(TestCase):
if __name__ == "__main__":
if HAS_CUDA:
if HAS_CUDA or HAS_CPU:
run_tests()

View File

@ -3,8 +3,6 @@
import unittest
from itertools import product
from functools import partial
from typing import Optional
import re
import torch
@ -17,7 +15,6 @@ from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
SM53OrLater,
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8
)
from torch.testing._internal.common_device_type import (
dtypes,
@ -212,524 +209,6 @@ class TestMatmulCuda(TestCase):
self.assertEqual(out1_gpu, out2_gpu[0])
f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
if torch.version.hip:
e4m3_type = torch.float8_e4m3fnuz
e5m2_type = torch.float8_e5m2fnuz
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
e4m3_type = torch.float8_e4m3fn
e5m2_type = torch.float8_e5m2
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
# avoid division by zero when calculating scale
EPS = 1e-12
def amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
""" Converts the amax value of a tensor to the fp8 scale.
Args:
amax: The amax value of the tensor.
float8_dtype: the float8 dtype.
orig_dtype: The original dtype of the tensor.
"""
scale = torch.empty_like(amax, dtype=torch.float32)
if float8_dtype == e4m3_type:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == e5m2_type:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
# Ensure the scale is representable in float16,
# this helps when amax is small. We are assuming that we don't need
# to care about this for float32/bfloat16
if orig_dtype is torch.float16:
res = torch.clamp(res, max=torch.finfo(torch.float16).max)
scale.copy_(res)
return scale
def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
if dim is None:
amax = torch.max(torch.abs(x))
else:
amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values
return amax_to_scale(amax, float8_dtype, x.dtype)
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
# naive implementation: dq -> op -> q
x_fp32 = x.to(torch.float) / x_scale
y_fp32 = y.to(torch.float) / y_scale
out_fp32 = torch.mm(x_fp32, y_fp32)
return out_fp32.to(out_dtype)
def addmm_float8_unwrapped(
a_data: torch.Tensor,
a_scale: torch.Tensor,
b_data: torch.Tensor,
b_scale: torch.tensor,
output_dtype: torch.dtype,
output_scale: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
a_inverse_scale = a_scale.reciprocal()
b_inverse_scale = b_scale.reciprocal()
if output_dtype == torch.float32 and bias is not None:
# Bias is not supported by _scaled_mm when output is fp32
output = torch._scaled_mm(
a_data,
b_data,
scale_a=a_inverse_scale,
scale_b=b_inverse_scale,
scale_result=output_scale,
out_dtype=output_dtype,
)
output += bias
return output
output = torch._scaled_mm(
a_data,
b_data,
bias=bias,
scale_a=a_inverse_scale,
scale_b=b_inverse_scale,
scale_result=output_scale,
out_dtype=output_dtype,
)
return output
def mm_float8(
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor,
b_scale: torch.Tensor,
output_dtype: torch.dtype, # output dtype
output_scale: Optional[torch.Tensor] = None, # output scale, precomputed
) -> torch.Tensor:
return addmm_float8_unwrapped(
a, a_scale, b, b_scale, output_dtype, output_scale
)
def to_fp8_saturated(
x: torch.Tensor,
fp8_dtype: torch.dtype
):
if fp8_dtype == e4m3_type:
x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
elif fp8_dtype == e5m2_type:
x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
else:
raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")
return x.to(fp8_dtype)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
class TestFP8MatmulCuda(TestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def _test_tautological_mm(self, device: str = "cuda",
x_dtype: torch.dtype = e4m3_type,
y_dtype: torch.dtype = e4m3_type,
out_dtype: Optional[torch.dtype] = None,
size: int = 16) -> None:
x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
if out_dtype is not None:
self.assertEqual(out_dtype, out_fp8.dtype)
self.assertEqual(out_fp32, out_fp8.to(torch.float))
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float8_basics(self, device) -> None:
self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
# hipblaslt does not yet support mixed e4m3_type input
if torch.version.hip is None:
self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
# According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
with self.assertRaises(RuntimeError):
self._test_tautological_mm(device, e5m2_type, e5m2_type)
self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
# hipblaslt does not yet support bfloat16 output
if torch.version.hip is None:
self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
with self.assertRaises(RuntimeError):
self._test_tautological_mm(device, out_dtype=e5m2_type)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float8_scale(self, device) -> None:
size = (16, 16)
x = torch.full(size, .5, device=device, dtype=e4m3_type)
# hipblaslt does not yet support mixed e4m3_type input
y_type = e4m3_type if torch.version.hip else e5m2_type
y = torch.full(size, .5, device=device, dtype=y_type).t()
scale_a = torch.tensor(1.5, device=device)
scale_b = torch.tensor(0.66, device=device)
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
self.assertEqual(out_fp8, out_fp8_s)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_scaled_mm_vs_emulated(self, base_dtype):
torch.manual_seed(42)
input_dtype = e4m3_type
output_dtype = base_dtype
compare_type = torch.float32
x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
x_scale = tensor_to_scale(x, input_dtype).float()
y_scale = tensor_to_scale(y, input_dtype).float()
x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
# Calculate actual F8 mm
out_scaled_mm = mm_float8(
x_fp8,
y_fp8,
a_scale=x_scale,
b_scale=y_scale,
output_dtype=output_dtype
)
# Calculate emulated F8 mm
out_emulated = mm_float8_emulated(
x_fp8,
x_scale,
y_fp8,
y_scale,
output_dtype
)
if output_dtype != base_dtype:
out_scaled_mm = out_scaled_mm.to(compare_type)
out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)
out_emulated = out_emulated.to(compare_type)
out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)
if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else:
atol, rtol = 3e-3, 3e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_scaled_mm_change_stride(self, base_dtype):
torch.manual_seed(42)
input_dtype = e4m3_type
output_dtype = base_dtype
compare_type = torch.float32
x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)
x_scale = tensor_to_scale(x, input_dtype).float()
y_scale = tensor_to_scale(y, input_dtype).float()
x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
# Calculate actual F8 mm
out_scaled_mm = mm_float8(
x_fp8,
y_fp8,
a_scale=x_scale,
b_scale=y_scale,
output_dtype=output_dtype
)
# Calculate emulated F8 mm
out_emulated = mm_float8_emulated(
x_fp8,
x_scale,
y_fp8,
y_scale,
output_dtype
)
if output_dtype != base_dtype:
out_scaled_mm = out_scaled_mm.to(compare_type)
out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)
out_emulated = out_emulated.to(compare_type)
out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)
if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else:
atol, rtol = 3e-3, 3e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float8_bias(self, device) -> None:
(k, l, m) = (16, 48, 32)
x = torch.ones((k, l), device=device).to(e4m3_type)
y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
# this fails on ROCm currently because hipblaslt doesn't have amax op
out_fp32 = out_fp8.to(torch.float32)
outb_fp32 = outb_fp8.to(torch.float32)
difference = torch.abs(out_fp32 - outb_fp32)
self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32))
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("bias", [True, False])
def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
x = torch.rand((17, 16), device=device).to(e4m3_type)
y = torch.rand((16, 16), device=device).to(e4m3_type).t()
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
input_bias = None
if bias:
input_bias = torch.rand((16,), device=device).to(torch.half)
_ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float8_bias_relu_edgecase(self, device) -> None:
(k, l, m) = (16, 48, 32)
x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
outb_fp32 = outb_fp8.to(torch.float32)
self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float32_output_errors_with_bias(self, device) -> None:
(k, l, m) = (16, 48, 32)
x = torch.rand((k, l), device=device).to(e4m3_type)
y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
self.assertRaisesRegex(
RuntimeError,
"Bias is not supported when out_dtype is set to Float32",
lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
)
@unittest.skipIf(PLATFORM_SUPPORTS_FP8,
"This test is only for devices with compute capability < 8.9")
def test_error_message_fp8_pre_sm89(self, device) -> None:
(k, l, m) = (16, 48, 32)
x = torch.rand((k, l), device=device).to(e4m3_type)
y = torch.rand((m, l), device=device).to(e4m3_type).t()
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
self.assertRaisesRegex(
RuntimeError,
r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float8_scale_fast_accum(self, device) -> None:
size = (16, 16)
x = torch.full(size, .5, device=device, dtype=e4m3_type)
# hipblaslt does not yet support mixed e4m3_type input
y_type = e4m3_type if torch.version.hip else e5m2_type
y = torch.full(size, .5, device=device, dtype=y_type).t()
scale_a = torch.tensor(1.5, device=device)
scale_b = torch.tensor(0.66, device=device)
out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
self.assertEqual(out_fp8, out_fp8_s)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
@skipIfRocm()
@parametrize("use_fast_accum", [True, False])
def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
M, K, N = (1024, 512, 2048)
fill_value = 0.5
x = torch.full((M, K), fill_value, device=device)
y = torch.full((N, K), fill_value, device=device)
x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)
x_fp8 = x.to(torch.float8_e4m3fn)
y_fp8 = y.to(torch.float8_e4m3fn).t()
out_fp8 = torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=x_scales,
scale_b=y_scales,
out_dtype=torch.bfloat16,
use_fast_accum=use_fast_accum,
)
self.assertEqual(
out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@skipIfRocm()
def test_float8_error_messages(self, device) -> None:
M, K, N = (1024, 512, 2048)
fill_value = 0.5
x = torch.full((M, K), fill_value, device=device)
y = torch.full((N, K), fill_value, device=device)
x_fp8 = x.to(torch.float8_e4m3fn)
y_fp8 = y.to(torch.float8_e4m3fn).t()
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"For RowWise scaling, scale_a should be (1024, 1) and scale_b "
"should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
),
):
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=torch.ones((1, 1), device="cuda"),
scale_b=torch.ones((1, 2), device="cuda"),
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError,
re.escape(
" For RowWise scaling, scale_a should be (1024, 1) and scale_b "
"should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
),
):
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=torch.ones((M, 1), device="cuda"),
scale_b=torch.ones((1, N + 1), device="cuda"),
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError,
re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
):
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=torch.ones((M), device="cuda"),
scale_b=torch.ones((N, N), device="cuda"),
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"Both scale_a and scale_b must be contiguous for RowWise scaling."
),
):
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=torch.ones((M, 1), device="cuda"),
scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError,
re.escape("Expected b.dtype() == at::kFloat8_e4m3fn to be true, but got false."),
):
torch._scaled_mm(
x_fp8,
y_fp8.to(torch.float8_e5m2),
scale_a=torch.ones((M, 1), device="cuda"),
scale_b=torch.ones((1, N), device="cuda"),
out_dtype=torch.bfloat16,
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
@skipIfRocm()
@parametrize("base_dtype", [torch.bfloat16])
def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
torch.manual_seed(42)
input_dtype = e4m3_type
output_dtype = base_dtype
x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
y_scales = tensor_to_scale(y, input_dtype, dim=0).float()
x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)
# Calculate actual F8 mm
out_scaled_mm = mm_float8(
x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
)
# Calculate emulated F8 mm
out_emulated = mm_float8_emulated(
x_fp8, x_scales, y_fp8, y_scales, output_dtype
)
if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else:
atol, rtol = 2e-3, 2e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("which_dim_zero", [0, 1, 2])
@parametrize("use_torch_compile", [False, True])
def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
device = "cuda"
x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn
out_dtype = torch.bfloat16
M, K, N = 32, 32, 32
if which_dim_zero == 0:
M = 0
elif which_dim_zero == 1:
K = 0
elif which_dim_zero == 2:
N = 0
x_fp8 = torch.zeros(M, K, device=device).to(x_dtype)
y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t()
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
scale_a = torch.tensor(float('-inf'), device=device)
scale_b = torch.tensor(float('-inf'), device=device)
f = torch._scaled_mm
if use_torch_compile:
f = torch.compile(torch._scaled_mm)
out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
self.assertEqual(out_dtype, out_fp8.dtype)
self.assertEqual(out_fp32, out_fp8.to(torch.float))
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
@ -851,7 +330,6 @@ class TestMixedDtypesLinearCuda(TestCase):
)
instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")
if __name__ == '__main__':

563
test/test_matmul_fp8.py Normal file
View File

@ -0,0 +1,563 @@
# Owner(s): ["module: linear algebra"]
import re
import unittest
from typing import Optional
import torch
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
IS_WINDOWS,
parametrize,
run_tests,
skipIfRocm,
TEST_CUDA,
TestCase,
)
_IS_SM9X = False
if TEST_CUDA:
_IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9
# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32
f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
if torch.version.hip:
e4m3_type = torch.float8_e4m3fnuz
e5m2_type = torch.float8_e5m2fnuz
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
e4m3_type = torch.float8_e4m3fn
e5m2_type = torch.float8_e5m2
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
# avoid division by zero when calculating scale
EPS = 1e-12
def amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
"""Converts the amax value of a tensor to the fp8 scale.
Args:
amax: The amax value of the tensor.
float8_dtype: the float8 dtype.
orig_dtype: The original dtype of the tensor.
"""
scale = torch.empty_like(amax, dtype=torch.float32)
if float8_dtype == e4m3_type:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == e5m2_type:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
# Ensure the scale is representable in float16,
# this helps when amax is small. We are assuming that we don't need
# to care about this for float32/bfloat16
if orig_dtype is torch.float16:
res = torch.clamp(res, max=torch.finfo(torch.float16).max)
scale.copy_(res)
return scale
def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
if dim is None:
amax = torch.max(torch.abs(x))
else:
amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values
return amax_to_scale(amax, float8_dtype, x.dtype)
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
# naive implementation: dq -> op -> q
x_fp32 = x.to(torch.float) / x_scale
y_fp32 = y.to(torch.float) / y_scale
out_fp32 = torch.mm(x_fp32, y_fp32)
return out_fp32.to(out_dtype)
def addmm_float8_unwrapped(
a_data: torch.Tensor,
a_scale: torch.Tensor,
b_data: torch.Tensor,
b_scale: torch.tensor,
output_dtype: torch.dtype,
output_scale: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
a_inverse_scale = a_scale.reciprocal()
b_inverse_scale = b_scale.reciprocal()
if output_dtype == torch.float32 and bias is not None:
# Bias is not supported by _scaled_mm when output is fp32
output = torch._scaled_mm(
a_data,
b_data,
scale_a=a_inverse_scale,
scale_b=b_inverse_scale,
scale_result=output_scale,
out_dtype=output_dtype,
)
output += bias
return output
output = torch._scaled_mm(
a_data,
b_data,
bias=bias,
scale_a=a_inverse_scale,
scale_b=b_inverse_scale,
scale_result=output_scale,
out_dtype=output_dtype,
)
return output
def mm_float8(
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor,
b_scale: torch.Tensor,
output_dtype: torch.dtype, # output dtype
output_scale: Optional[torch.Tensor] = None, # output scale, precomputed
) -> torch.Tensor:
return addmm_float8_unwrapped(a, a_scale, b, b_scale, output_dtype, output_scale)
def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype):
if fp8_dtype == e4m3_type:
x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
elif fp8_dtype == e5m2_type:
x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
else:
raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")
return x.to(fp8_dtype)
class TestFP8Matmul(TestCase):
def _test_tautological_mm(
self,
device: str = "cuda",
x_dtype: torch.dtype = e4m3_type,
y_dtype: torch.dtype = e4m3_type,
out_dtype: Optional[torch.dtype] = None,
size: int = 16,
) -> None:
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
if out_dtype is not None:
self.assertEqual(out_dtype, out_fp8.dtype)
self.assertEqual(out_fp32, out_fp8.to(torch.float))
def test_float8_basics(self, device) -> None:
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
# hipblaslt does not yet support mixed e4m3_type input
if torch.version.hip is None:
if device != "cpu":
# TODO: The following 2 tests are mixed dtypes between src and weight,
# which will be enabled in oneDNN v3.6 in CPU.
self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
if device == "cpu":
self._test_tautological_mm(device, e5m2_type, e5m2_type)
else:
# According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
with self.assertRaises(RuntimeError):
self._test_tautological_mm(device, e5m2_type, e5m2_type)
self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
# hipblaslt does not yet support bfloat16 output
if torch.version.hip is None:
self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
if device != "cpu":
with self.assertRaises(RuntimeError):
self._test_tautological_mm(device, out_dtype=e5m2_type)
else:
# TODO: e4m3 and e5m2 naturally has numerical gap, maybe relax the tolerance later.
with self.assertRaises(AssertionError):
self._test_tautological_mm(device, out_dtype=e5m2_type)
def test_float8_scale(self, device) -> None:
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
size = (16, 16)
x = torch.full(size, 0.5, device=device, dtype=e4m3_type)
# hipblaslt does not yet support mixed e4m3_type input
# TODO: will use e5m2_type after upgrading oneDNN to v3.6.
y_type = e4m3_type if torch.version.hip or device == "cpu" else e5m2_type
y = torch.full(size, 0.5, device=device, dtype=y_type).t()
scale_a = torch.tensor(1.5, device=device)
scale_b = torch.tensor(0.66, device=device)
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4.0, device=device))
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
self.assertEqual(out_fp8, out_fp8_s)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_scaled_mm_vs_emulated(self, base_dtype):
torch.manual_seed(42)
input_dtype = e4m3_type
output_dtype = base_dtype
compare_type = torch.float32
x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
x_scale = tensor_to_scale(x, input_dtype).float()
y_scale = tensor_to_scale(y, input_dtype).float()
x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
# Calculate actual F8 mm
out_scaled_mm = mm_float8(
x_fp8, y_fp8, a_scale=x_scale, b_scale=y_scale, output_dtype=output_dtype
)
# Calculate emulated F8 mm
out_emulated = mm_float8_emulated(x_fp8, x_scale, y_fp8, y_scale, output_dtype)
if output_dtype != base_dtype:
out_scaled_mm = out_scaled_mm.to(compare_type)
out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)
out_emulated = out_emulated.to(compare_type)
out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)
if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else:
atol, rtol = 3e-3, 3e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_scaled_mm_change_stride(self, base_dtype):
torch.manual_seed(42)
input_dtype = e4m3_type
output_dtype = base_dtype
compare_type = torch.float32
x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)
x_scale = tensor_to_scale(x, input_dtype).float()
y_scale = tensor_to_scale(y, input_dtype).float()
x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
# Calculate actual F8 mm
out_scaled_mm = mm_float8(
x_fp8, y_fp8, a_scale=x_scale, b_scale=y_scale, output_dtype=output_dtype
)
# Calculate emulated F8 mm
out_emulated = mm_float8_emulated(x_fp8, x_scale, y_fp8, y_scale, output_dtype)
if output_dtype != base_dtype:
out_scaled_mm = out_scaled_mm.to(compare_type)
out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)
out_emulated = out_emulated.to(compare_type)
out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)
if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else:
atol, rtol = 3e-3, 3e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
def test_float8_bias(self, device) -> None:
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
(k, l, m) = (16, 48, 32)
x = torch.ones((k, l), device=device).to(e4m3_type)
y = torch.full((m, l), 0.25, device=device, dtype=e4m3_type).t()
bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
# this fails on ROCm currently because hipblaslt doesn't have amax op
out_fp32 = out_fp8.to(torch.float32)
outb_fp32 = outb_fp8.to(torch.float32)
difference = torch.abs(out_fp32 - outb_fp32)
self.assertEqual(
difference, torch.tensor(4.0, device=device).expand_as(out_fp32)
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("bias", [True, False])
def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
x = torch.rand((17, 16), device=device).to(e4m3_type)
y = torch.rand((16, 16), device=device).to(e4m3_type).t()
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
input_bias = None
if bias:
input_bias = torch.rand((16,), device=device).to(torch.half)
_ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float8_bias_relu_edgecase(self, device) -> None:
(k, l, m) = (16, 48, 32)
x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
outb_fp32 = outb_fp8.to(torch.float32)
self.assertEqual(
outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32)
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float32_output_errors_with_bias(self, device) -> None:
(k, l, m) = (16, 48, 32)
x = torch.rand((k, l), device=device).to(e4m3_type)
y = torch.full((m, l), 0.25, device=device, dtype=e4m3_type).t()
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
self.assertRaisesRegex(
RuntimeError,
"Bias is not supported when out_dtype is set to Float32",
lambda: torch._scaled_mm(
x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32
),
)
@unittest.skipIf(
PLATFORM_SUPPORTS_FP8 or not torch.cuda.is_available(),
"This test is only for devices with compute capability < 8.9",
)
def test_error_message_fp8_pre_sm89(self, device) -> None:
(k, l, m) = (16, 48, 32)
x = torch.rand((k, l), device=device).to(e4m3_type)
y = torch.rand((m, l), device=device).to(e4m3_type).t()
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
self.assertRaisesRegex(
RuntimeError,
r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float8_scale_fast_accum(self, device) -> None:
size = (16, 16)
x = torch.full(size, 0.5, device=device, dtype=e4m3_type)
# hipblaslt does not yet support mixed e4m3_type input
y_type = e4m3_type if torch.version.hip else e5m2_type
y = torch.full(size, 0.5, device=device, dtype=y_type).t()
scale_a = torch.tensor(1.5, device=device)
scale_b = torch.tensor(0.66, device=device)
out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4.0, device=device))
out_fp8_s = torch._scaled_mm(
x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True
)
self.assertEqual(out_fp8, out_fp8_s)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
@skipIfRocm()
@parametrize("use_fast_accum", [True, False])
def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
M, K, N = (1024, 512, 2048)
fill_value = 0.5
x = torch.full((M, K), fill_value, device=device)
y = torch.full((N, K), fill_value, device=device)
x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)
x_fp8 = x.to(torch.float8_e4m3fn)
y_fp8 = y.to(torch.float8_e4m3fn).t()
out_fp8 = torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=x_scales,
scale_b=y_scales,
out_dtype=torch.bfloat16,
use_fast_accum=use_fast_accum,
)
self.assertEqual(
out_fp8.to(torch.float32),
torch.full((M, N), K * (fill_value**2), device=device),
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@skipIfRocm()
def test_float8_error_messages(self, device) -> None:
M, K, N = (1024, 512, 2048)
fill_value = 0.5
x = torch.full((M, K), fill_value, device=device)
y = torch.full((N, K), fill_value, device=device)
x_fp8 = x.to(torch.float8_e4m3fn)
y_fp8 = y.to(torch.float8_e4m3fn).t()
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"For RowWise scaling, scale_a should be (1024, 1) and scale_b "
"should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
),
):
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=torch.ones((1, 1), device="cuda"),
scale_b=torch.ones((1, 2), device="cuda"),
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError,
re.escape(
" For RowWise scaling, scale_a should be (1024, 1) and scale_b "
"should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
),
):
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=torch.ones((M, 1), device="cuda"),
scale_b=torch.ones((1, N + 1), device="cuda"),
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"For non-TensorWise scaling, scale tensors must be 2-dimensional"
),
):
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=torch.ones((M), device="cuda"),
scale_b=torch.ones((N, N), device="cuda"),
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"Both scale_a and scale_b must be contiguous for RowWise scaling."
),
):
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=torch.ones((M, 1), device="cuda"),
scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"Expected b.dtype() == at::kFloat8_e4m3fn to be true, but got false."
),
):
torch._scaled_mm(
x_fp8,
y_fp8.to(torch.float8_e5m2),
scale_a=torch.ones((M, 1), device="cuda"),
scale_b=torch.ones((1, N), device="cuda"),
out_dtype=torch.bfloat16,
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
@skipIfRocm()
@parametrize("base_dtype", [torch.bfloat16])
def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
torch.manual_seed(42)
input_dtype = e4m3_type
output_dtype = base_dtype
x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
y_scales = tensor_to_scale(y, input_dtype, dim=0).float()
x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)
# Calculate actual F8 mm
out_scaled_mm = mm_float8(
x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
)
# Calculate emulated F8 mm
out_emulated = mm_float8_emulated(
x_fp8, x_scales, y_fp8, y_scales, output_dtype
)
if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else:
atol, rtol = 2e-3, 2e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("which_dim_zero", [0, 1, 2])
@parametrize("use_torch_compile", [False, True])
def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
device = "cuda"
x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn
out_dtype = torch.bfloat16
M, K, N = 32, 32, 32
if which_dim_zero == 0:
M = 0
elif which_dim_zero == 1:
K = 0
elif which_dim_zero == 2:
N = 0
x_fp8 = torch.zeros(M, K, device=device).to(x_dtype)
y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t()
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
scale_a = torch.tensor(float("-inf"), device=device)
scale_b = torch.tensor(float("-inf"), device=device)
f = torch._scaled_mm
if use_torch_compile:
f = torch.compile(torch._scaled_mm)
out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
self.assertEqual(out_dtype, out_fp8.dtype)
self.assertEqual(out_fp32, out_fp8.to(torch.float))
instantiate_device_type_tests(TestFP8Matmul, globals())
if __name__ == "__main__":
TestCase._default_dtype_check_enabled = True
run_tests()

View File

@ -34,6 +34,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);

View File

@ -1008,6 +1008,8 @@ ANY_DTYPE_ORDER = (
torch.int8,
torch.uint8,
torch.bool,
torch.float8_e4m3fn,
torch.float8_e5m2,
)

View File

@ -21,7 +21,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and,
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types,
)
from torch.testing._internal.common_device_type import \
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@ -8734,18 +8734,27 @@ def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
scale1 = make_scale((1,))
scale2 = make_scale((1,))
samples.append(SampleInput(mat1, mat2, scale1, scale2))
# mat1 e4m3 mat2 e5m2
mat1 = make_mat_e4m3((M, K))
# two e5m2
mat1 = make_mat_e5m2((M, K))
mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
scale1 = make_scale((1,))
scale2 = make_scale((1,))
samples.append(SampleInput(mat1, mat2, scale1, scale2))
# mat1 e5m2 mat2 e4m3
mat1 = make_mat_e5m2((M, K))
mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
scale1 = make_scale((1,))
scale2 = make_scale((1,))
samples.append(SampleInput(mat1, mat2, scale1, scale2))
# TODO: Will remove this after oneDNN v3.6
# now oneDNN v3.5.3 only supports mat1 * mat2 with the same data types.
if device != 'cpu':
# mat1 e4m3 mat2 e5m2
mat1 = make_mat_e4m3((M, K))
mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
scale1 = make_scale((1,))
scale2 = make_scale((1,))
samples.append(SampleInput(mat1, mat2, scale1, scale2))
# mat1 e5m2 mat2 e4m3
mat1 = make_mat_e5m2((M, K))
mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
scale1 = make_scale((1,))
scale2 = make_scale((1,))
samples.append(SampleInput(mat1, mat2, scale1, scale2))
yield from samples
@ -16210,20 +16219,28 @@ op_db: List[OpInfo] = [
OpInfo(
'torch._scaled_mm',
sample_inputs_func=sample_inputs_scaled_mm,
dtypes=empty_types(),
dtypes=float8_types(),
dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
supports_out=True,
supports_forward_ad=False,
supports_autograd=False,
decorators=[skipCUDAIf(not SM90OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 9.0')],
decorators=[skipCUDAIf(not SM90OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 9.0'), ],
skips=(
# Sample inputs isn't really parametrized on dtype
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
device_type='cuda'),
# "mul_cuda" not implemented for float8_e4m3fn
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
# "add_stub" not implemented for 'Float8_e4m3fn'
# https://github.com/pytorch/pytorch/issues/107256
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
dtypes=(torch.float8_e4m3fn,)),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
device_type='cpu'),
# "mul_cuda" not implemented for float8_e4m3fn
# "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
# https://github.com/pytorch/pytorch/issues/107256
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'),
# aten::_scaled_mm hit the vmap fallback which is currently disabled
DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule",
device_type='cpu'),
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)),
)
),
OpInfo(