Add torch._scaled_mm for CPU (#150410)

This PR is a duplicate of https://github.com/pytorch/pytorch/pull/139975.

This PR adds torch._scaled_mm for the CPU backend.

_scaled_mm_out_cpu and _scaled_mm_cpu are newly added and registered as the CPU dispatch entries for torch._scaled_mm. We also add _scaled_mm_out_cpu_emulated as a fallback for platforms that cannot run the FP8 matmul through oneDNN. This PR also updates the FP8-related unit tests to cover CPU.
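As a quick illustration of the user-facing behavior (not code from this PR), here is a minimal sketch of calling the op on CPU; the shapes, scale values, and bfloat16 output dtype are arbitrary choices:

```python
import torch

# FP8 inputs on CPU; _scaled_mm expects 2-D matrices and, on CPU, per-tensor scales.
x = torch.randn(16, 32).to(torch.float8_e4m3fn)        # mat1: (16, 32)
w = torch.randn(64, 32).to(torch.float8_e4m3fn).t()    # mat2: (32, 64)
scale_a = torch.tensor(1.0)  # per-tensor scale for mat1
scale_b = torch.tensor(1.0)  # per-tensor scale for mat2

# Dispatches to _scaled_mm_cpu; falls back to _scaled_mm_out_cpu_emulated
# when the oneDNN FP8 path is unavailable on the current platform.
out = torch._scaled_mm(x, w, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([16, 64])
```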

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150410
Approved by: https://github.com/atalman
Jiang, Yanbing 2025-04-11 02:23:00 +00:00 committed by PyTorch MergeBot
parent 24ca7e91e6
commit 1e92579126
11 changed files with 364 additions and 61 deletions

View File

@@ -7,6 +7,11 @@
#include <ATen/Config.h>
#include <ATen/native/mkldnn/Matmul.h>
#include <ATen/native/mkldnn/Linear.h>
#include <ATen/native/Resize.h>
#if !defined(__s390x__) && !defined(__powerpc__)
#include <cpuinfo.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/CPUFunctions.h>
@@ -24,6 +29,9 @@
#include <ATen/ops/mv_native.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/matmul.h>
#endif
namespace at::meta {
@@ -222,4 +230,92 @@ Tensor vdot(const Tensor &self, const Tensor &other){
}
static Tensor&
_scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
" but got ", bias->numel());
// Check types
TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
auto mat1_c = mat1.contiguous();
auto mat2_c = mat2.contiguous();
IntArrayRef mat1_sizes = mat1_c.sizes();
IntArrayRef mat2_sizes = mat2_c.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
float input_scale = scale_a.item<float>();
float weight_scale = scale_b.item<float>();
auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale);
auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale);
auto out_tmp = at::matmul(fp32_mat1, fp32_mat2);
if (bias) {
out_tmp.add_(bias.value());
}
out_tmp = out_tmp.to(out.scalar_type());
out.copy_(out_tmp);
return out;
}
Tensor&
_scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
#if AT_MKLDNN_ENABLED()
if (at::globalContext().userEnabledMkldnn()) {
bool mixed_dtype = mat1.scalar_type() != mat2.scalar_type();
if ((!mixed_dtype && cpuinfo_has_x86_amx_int8()) ||
(mixed_dtype && cpuinfo_has_x86_amx_fp16())) {
return mkldnn_scaled_mm(
mat1,
mat2,
scale_a,
scale_b,
bias,
scale_result,
out_dtype,
use_fast_accum,
out);
}
}
#endif
{
return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
}
}
Tensor
_scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
}
} // namespace at::native

View File

@@ -4,6 +4,7 @@
#include <ATen/core/Tensor.h>
#include <torch/library.h>
#include <ATen/native/mkldnn/Linear.h>
#include <ATen/native/Resize.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -46,8 +47,19 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support");
}
} // namespace at::native
Tensor&
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support");
}
} // namespace at::native
#else // AT_MKLDNN_ENABLED
@@ -459,6 +471,118 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
TORCH_FN(mkldnn_linear_pointwise_binary));
}
Tensor&
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out) {
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
" but got ", bias->numel());
// Check types
TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
// Validation checks have passed lets resize the output to actual size
auto mat1_c = mat1.contiguous();
auto mat2_c = mat2.contiguous();
IntArrayRef mat1_sizes = mat1_c.sizes();
IntArrayRef mat2_sizes = mat2_c.sizes();
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
float input_scale = scale_a.item<float>();
float weight_scale = scale_b.item<float>();
auto src = at::native::itensor_view_from_dense(mat1_c);
auto weight_t = at::native::itensor_view_from_dense(mat2_c);
bool with_bias = bias.has_value();
int64_t K = mat1_sizes[1], M = mat1_sizes[0],
N = mat2_sizes[1];
std::vector<int64_t> src_dims = {M, K};
std::vector<int64_t> weight_dims = {K, N};
std::vector<int64_t> dst_dims = {M, N};
ideep::tensor dst = at::native::itensor_view_from_dense(out);
auto src_desc = ideep::tensor::desc(
src_dims,
get_mkldnn_dtype(mat1.scalar_type()),
ideep::format_tag::any);
auto weights_desc = ideep::tensor::desc(
weight_dims,
get_mkldnn_dtype(mat2.scalar_type()),
ideep::format_tag::any);
auto dst_desc = ideep::tensor::desc(
dst_dims,
get_mkldnn_dtype(out.scalar_type()),
ideep::format_tag::any);
ideep::tensor onednn_bias;
if (with_bias) {
auto bias_value = bias.value();
if (bias_value.dim() == 1) {
auto b_reshape = bias_value.reshape({1, bias_value.size(0)});
onednn_bias = at::native::itensor_view_from_dense(b_reshape);
} else {
onednn_bias = at::native::itensor_view_from_dense(bias_value);
}
}
auto bias_desc = ideep::tensor::desc();
if (with_bias) {
bias_desc = ideep::tensor::desc(onednn_bias.get_dims(),
get_mkldnn_dtype(bias.value().scalar_type()),
ideep::format_tag::any);
}
auto op_attr = ideep::attr_t();
if (input_scale != 1.0f) {
op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
}
if (weight_scale != 1.0f) {
op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
}
op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
auto engine = ideep::engine::cpu_engine();
dnnl::matmul::primitive_desc primitive_desc = with_bias
? dnnl::matmul::primitive_desc(
engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr)
: dnnl::matmul::primitive_desc(
engine, src_desc, weights_desc, dst_desc, op_attr);
auto expected_weight = weight_t.reorder_if_differ_in(primitive_desc.weights_desc());
auto primitive = dnnl::matmul(primitive_desc);
// Prepare args and execute primitive
ideep::tensor scratchpad(primitive_desc.scratchpad_desc());
ideep::exec_args args;
args.insert({DNNL_ARG_SRC, src});
args.insert({DNNL_ARG_WEIGHTS, expected_weight});
args.insert({DNNL_ARG_DST, dst});
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
if (with_bias) {
args.insert({DNNL_ARG_BIAS, onednn_bias});
}
ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale));
ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale));
if (input_scale != 1.0f) {
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
}
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});
primitive.execute(ideep::stream::default_stream(), args);
return out;
}
} // namespace at
#endif // AT_MKLDNN_ENABLED

View File

@@ -35,3 +35,15 @@ C10_API Tensor mkl_linear(
} // namespace at
#endif // AT_MKLDNN_ENABLED()
namespace at::native {
Tensor&
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a,
const Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum,
Tensor& out);
} // namespace at::native

View File

@@ -57,6 +57,10 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
return ideep::tensor::data_type::bf16;
case ScalarType::Half:
return ideep::tensor::data_type::f16;
case ScalarType::Float8_e4m3fn:
return ideep::tensor::data_type::f8_e4m3;
case ScalarType::Float8_e5m2:
return ideep::tensor::data_type::f8_e5m2;
default:
TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type");
}
@@ -161,8 +165,24 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data
const_cast<void*>(tensor.const_data_ptr()) :
tensor.data_ptr()};
}
else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
return {{tensor.sizes().vec(),
ideep::tensor::data_type::f8_e4m3,
tensor.strides().vec()},
from_const_data_ptr ?
const_cast<void*>(tensor.const_data_ptr()) :
tensor.data_ptr()};
}
else if (tensor.scalar_type() == ScalarType::Float8_e5m2) {
return {{tensor.sizes().vec(),
ideep::tensor::data_type::f8_e5m2,
tensor.strides().vec()},
from_const_data_ptr ?
const_cast<void*>(tensor.const_data_ptr()) :
tensor.data_ptr()};
}
else {
TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input");
TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8/fp8 tensor input", tensor.scalar_type());
}
}

View File

@@ -7067,11 +7067,13 @@
- func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CPU: _scaled_mm_cpu
CUDA: _scaled_mm_cuda
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
CPU: _scaled_mm_out_cpu
CUDA: _scaled_mm_out_cuda

View File

@@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.utils._triton import has_triton_tma_device
@@ -116,9 +116,9 @@ def _fix_fp8_dtype_for_rocm(
@instantiate_parametrized_tests
class TestFP8Types(TestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
def test_xblock_for_small_numel(self, float8_dtype: torch.dtype): @parametrize("device", ("cuda", "cpu"))
def test_xblock_for_small_numel(self, float8_dtype: torch.dtype, device: str):
""" """
TritonOverrides.to_dtype will set min_elem_per_thread to 2 or 4 TritonOverrides.to_dtype will set min_elem_per_thread to 2 or 4
depends on the variant of fp8 type. depends on the variant of fp8 type.
@@ -127,30 +127,34 @@ class TestFP8Types(TestCase):
We should not pick a XBLOCK larger than xnumel
"""
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device=device)
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
def f(x):
return x.to(dtype=float8_dtype)
x = torch.randn(1, device="cuda")
x = torch.randn(1, device=device)
expected = f(x)
actual = torch.compile(f)(x)
torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("dtype", (torch.float16, torch.bfloat16)) @parametrize("dtype", (torch.float16, torch.bfloat16))
def test_eager_fallback(self, dtype: torch.dtype): @parametrize("device", ("cuda", "cpu"))
def test_eager_fallback(self, dtype: torch.dtype, device: torch.device):
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
weight_shape = (32, 16)
e4m3_type = torch.float8_e4m3fn
e4m3_type = _fix_fp8_dtype_for_rocm(e4m3_type, device="cuda")
e4m3_type = _fix_fp8_dtype_for_rocm(e4m3_type, device=device)
def fp8_matmul_unwrapped(x):
a_scale = torch.Tensor([1.0]).to(device="cuda")
b_scale = torch.Tensor([1.0]).to(device="cuda")
a_scale = torch.Tensor([1.0]).to(device=device)
b_scale = torch.Tensor([1.0]).to(device=device)
output_scale = None
input_bias = torch.rand(32, device="cuda", dtype=dtype)
weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T.to(
input_bias = torch.rand(32, device=device, dtype=dtype)
weight = torch.rand(*weight_shape, device=device, dtype=dtype).T.to(
e4m3_type
)
a_inverse_scale = 1 / a_scale
@@ -171,19 +175,23 @@ class TestFP8Types(TestCase):
)
x_shape = (16, 16)
x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type)
y_fp8 = compiled_fp8_matmul(x) # noqa: F841
x_shape = (15, 16)
x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type)
y_fp8 = compiled_fp8_matmul(x) # noqa: F841
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("dtype", (torch.float16, torch.bfloat16, torch.float)) @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
@parametrize("shape", ("15,3,13", "4,2048,4096")) @parametrize("shape", ("15,3,13", "4,2048,4096"))
@parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)]) @parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)])
def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple): @parametrize("device", ("cuda", "cpu"))
dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda") def test_valid_cast(
self, dtype: torch.dtype, shape: str, dst_types: tuple, device: torch.device
):
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
dst_types = _fix_fp8_dtype_for_rocm(dst_types, device=device)
e4m3, e5m2 = dst_types
def fp8_cast(x):
@@ -194,7 +202,7 @@ class TestFP8Types(TestCase):
compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)
shape = [int(dim) for dim in shape.split(",")]
x = torch.rand(*shape, device="cuda", dtype=dtype)
x = torch.rand(*shape, device=device, dtype=dtype)
y0_fp8, y1_fp8 = compiled_fp8_cast(x)
torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1)
@@ -223,14 +231,20 @@ class TestFP8Types(TestCase):
x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2)
compiled_fp8_cast(x, torch.float8_e4m3fn)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float)) @parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
@parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("shape", ("16,16,16", "4,2048,4096")) @parametrize("shape", ("16,16,16", "4,2048,4096"))
@parametrize("device", ("cuda", "cpu"))
def test_to_fp8_saturated(
self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
self,
src_dtype: torch.dtype,
dst_dtype: torch.dtype,
shape: str,
device: torch.device,
):
dst_dtype = _fix_fp8_dtype_for_rocm(dst_dtype, device="cuda")
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
dst_dtype = _fix_fp8_dtype_for_rocm(dst_dtype, device=device)
def fp8_saturated(x, dtype):
return _to_fp8_saturated(x, dtype)
@@ -239,17 +253,23 @@ class TestFP8Types(TestCase):
fp8_saturated, backend="inductor", dynamic=True
)
shape = [int(dim) for dim in shape.split(",")]
x = torch.rand(*shape, device="cuda", dtype=src_dtype)
x = torch.rand(*shape, device=device, dtype=src_dtype)
y_compiled = compiled_fp8_cast(x, dst_dtype)
y = fp8_saturated(x, dst_dtype)
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str): @parametrize("device", ("cuda", "cpu"))
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") def test_amax_fp8_quant(
self, float8_dtype: torch.dtype, shape: str, device: torch.device
):
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device=device)
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(
"FP8 is only supported on H100+ and sm_89 and MI300+ devices"
)
shape = [int(dim) for dim in shape.split(",")]
batch_size, sequence_length, hidden_size = shape
@@ -262,19 +282,23 @@ class TestFP8Types(TestCase):
compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")
x_shape = (batch_size, sequence_length, hidden_size)
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
x = torch.rand(*x_shape, device=device, dtype=torch.half)
scale = torch.tensor(0.2, device=device, dtype=torch.float)
y_compiled = compiled_amax_fp8_quant(x, scale)
y = amax_fp8(x, scale)
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str): @parametrize("device", ("cuda", "cpu"))
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") def test_amax_along_with_fp8_quant(
self, float8_dtype: torch.dtype, shape: str, device: torch.device
):
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device=device)
shape = [int(dim) for dim in shape.split(",")]
batch_size, sequence_length, hidden_size = shape
@@ -287,12 +311,12 @@ class TestFP8Types(TestCase):
compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")
x_shape = (batch_size, sequence_length, hidden_size)
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
scale = torch.tensor(1.0, device="cuda", dtype=torch.float)
amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
x = torch.rand(*x_shape, device=device, dtype=torch.half)
scale = torch.tensor(1.0, device=device, dtype=torch.float)
amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half)
y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled)
amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
amax_buffer = torch.zeros((1), device=device, dtype=torch.half)
y = amax_fp8(x, scale, amax_buffer)
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
@@ -300,14 +324,22 @@ class TestFP8Types(TestCase):
amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
@parametrize("amax_keep_dim", (True, False)) @parametrize("amax_keep_dim", (True, False))
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
@parametrize("device", ("cuda", "cpu"))
def test_layernorm_fp8_quant(
self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: str
self,
float8_dtype: torch.dtype,
amax_keep_dim: bool,
shape: str,
device: torch.device,
):
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(
"FP8 is only supported on H100+ and sm_89 and MI300+ devices"
)
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device=device)
shape = [int(dim) for dim in shape.split(",")]
batch_size, sequence_length, hidden_size = shape
@@ -329,12 +361,12 @@ class TestFP8Types(TestCase):
compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")
x_shape = (batch_size, sequence_length, hidden_size)
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
x = torch.rand(*x_shape, device=device, dtype=torch.half)
scale = torch.tensor(0.2, device=device, dtype=torch.float)
amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half)
y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
amax_buffer = torch.zeros((1), device=device, dtype=torch.half)
y = ln_fp8(x, scale, amax_buffer)
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
@@ -750,5 +782,5 @@ class TestFP8Lowering(TestCase):
if __name__ == "__main__":
if HAS_CUDA:
if HAS_CUDA or HAS_CPU:
run_tests()

View File

@@ -653,15 +653,15 @@ def _bfloat16_to_float4_e2m1fn_x2(x):
return x
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
class TestFP8Matmul(TestCase):
class TestFP8MatmulCuda(TestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def _test_tautological_mm(self, device: str = "cuda",
x_dtype: torch.dtype = e4m3_type,
y_dtype: torch.dtype = e4m3_type,
out_dtype: Optional[torch.dtype] = None,
size: int = 16) -> None:
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
@@ -672,12 +672,13 @@ class TestFP8MatmulCuda(TestCase):
self.assertEqual(out_dtype, out_fp8.dtype)
self.assertEqual(out_fp32, out_fp8.to(torch.float))
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float8_basics(self, device) -> None:
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
# According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
# supported on ROCm but fails on CUDA
ctx = self.assertRaises(RuntimeError) if torch.version.hip is None else contextlib.nullcontext()
ctx = self.assertRaises(RuntimeError) if torch.version.hip is None and device != "cpu" else contextlib.nullcontext()
with ctx:
self._test_tautological_mm(device, e5m2_type, e5m2_type)
@@ -688,11 +689,12 @@ class TestFP8MatmulCuda(TestCase):
self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
with self.assertRaises(AssertionError if torch.version.hip else RuntimeError):
with self.assertRaises(AssertionError if torch.version.hip or device == "cpu" else RuntimeError):
self._test_tautological_mm(device, out_dtype=e5m2_type)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float8_scale(self, device) -> None:
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
size = (16, 16)
x = torch.full(size, .5, device=device, dtype=e4m3_type)
# hipblaslt does not yet support mixed e4m3_type input
@@ -807,8 +809,9 @@ class TestFP8MatmulCuda(TestCase):
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
def test_float8_bias(self, device) -> None:
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
raise unittest.SkipTest(f8_msg)
(k, l, m) = (16, 48, 32)
x = torch.ones((k, l), device=device).to(e4m3_type)
y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
@@ -861,7 +864,7 @@ class TestFP8MatmulCuda(TestCase):
lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
)
@unittest.skipIf(PLATFORM_SUPPORTS_FP8, f8_msg)
@unittest.skipIf(PLATFORM_SUPPORTS_FP8 or not torch.cuda.is_available(), f8_msg)
def test_error_message_fp8_pre_sm89(self, device) -> None:
(k, l, m) = (16, 48, 32)
x = torch.rand((k, l), device=device).to(e4m3_type)
@@ -1718,8 +1721,8 @@ class TestMixedDtypesLinearCuda(TestCase):
)
instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8Matmul, globals())
if __name__ == '__main__':
TestCase._default_dtype_check_enabled = True

View File

@@ -21,6 +21,8 @@
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/BFloat16.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/generic_math.h>

View File

@@ -37,6 +37,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attent
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);

View File

@@ -1003,6 +1003,8 @@ ANY_DTYPE_ORDER = (
torch.int8,
torch.uint8,
torch.bool,
torch.float8_e4m3fn,
torch.float8_e5m2,
)

View File

@@ -22,7 +22,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and,
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types,
)
from torch.testing._internal.common_device_type import \
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@@ -16221,7 +16221,7 @@ op_db: list[OpInfo] = [
OpInfo(
'torch._scaled_mm',
sample_inputs_func=sample_inputs_scaled_mm,
dtypes=empty_types(),
dtypes=float8_types(),
dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
supports_out=True,
supports_forward_ad=False,
@@ -16229,12 +16229,20 @@ op_db: list[OpInfo] = [
decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
skips=(
# Sample inputs isn't really parametrized on dtype
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
device_type='cuda'),
# "mul_cuda" not implemented for float8_e4m3fn
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
# "add_stub" not implemented for 'Float8_e4m3fn'
# "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn'
# https://github.com/pytorch/pytorch/issues/107256
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
dtypes=(torch.float8_e4m3fn,)),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
# "mul_cuda" not implemented for float8_e4m3fn
# "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
# https://github.com/pytorch/pytorch/issues/107256
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'),
# aten::_scaled_mm hit the vmap fallback which is currently disabled
DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)),
)
),
OpInfo(