Add torch._scaled_mm for CPU (#139975)

This PR adds `torch._scaled_mm` for the CPU backend. `_scaled_mm_out_cpu` and `_scaled_mm_cpu` are newly added and registered as the CPU dispatch entries for `torch._scaled_mm`. We also add `_scaled_mm_out_cpu_emulated` as a fallback when the current platform cannot run FP8 matmul through oneDNN. This PR also updates the various FP8-related unit tests to cover the CPU backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139975
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/malfet

This commit is contained in:
parent d3e9133ab2
commit cbc4cf3043
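For orientation, here is a minimal Python sketch of what the new CPU path enables. The shapes, scales, and output dtype below are illustrative only and not taken from the PR.

```python
import torch

# Per-tensor scaled FP8 matmul on CPU; this routes to the oneDNN kernel when
# available and otherwise to the emulated fallback added in this PR.
x = torch.randn(16, 32).to(torch.float8_e4m3fn)
y = torch.randn(64, 32).to(torch.float8_e4m3fn).t()   # (32, 64) after transpose
scale_a = torch.tensor(1.0)                            # per-tensor scale for x
scale_b = torch.tensor(1.0)                            # per-tensor scale for y
out = torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([16, 64])
```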
@@ -7,6 +7,11 @@
#include <ATen/Config.h>
#include <ATen/native/mkldnn/Matmul.h>
#include <ATen/native/mkldnn/Linear.h>
#include <ATen/native/Resize.h>
#if !defined(__s390x__) && !defined(__powerpc__)
#include <cpuinfo.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/CPUFunctions.h>

@@ -24,6 +29,9 @@
#include <ATen/ops/mv_native.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#include <ATen/ops/_scaled_mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/matmul.h>
#endif

namespace at::meta {
@@ -222,4 +230,79 @@ Tensor vdot(const Tensor &self, const Tensor &other){

}

static Tensor&
_scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum,
    Tensor& out) {
  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
  TORCH_CHECK(
      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

  TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
       " but got ", bias->numel());

  // Check types
  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());

  auto mat1_c = mat1.contiguous();
  auto mat2_c = mat2.contiguous();
  IntArrayRef mat1_sizes = mat1_c.sizes();
  IntArrayRef mat2_sizes = mat2_c.sizes();
  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});

  float input_scale = scale_a.item<float>();
  float weight_scale = scale_b.item<float>();
  auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale);
  auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale);
  auto out_tmp = at::matmul(fp32_mat1, fp32_mat2);
  if (bias) {
    out_tmp.add_(bias.value());
  }
  out_tmp = out_tmp.to(out.scalar_type());
  out.copy_(out_tmp);
  return out;
}
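The emulated fallback above is simply dequantize, matmul, cast. A Python rendering of the same computation (an illustrative helper, not part of the commit):

```python
import torch

def scaled_mm_emulated(mat1, mat2, scale_a, scale_b, bias=None, out_dtype=torch.float32):
    # Mirrors _scaled_mm_out_cpu_emulated: upcast the FP8 inputs to float32,
    # apply the per-tensor scales, run a plain matmul, then cast the result.
    fp32_mat1 = mat1.to(torch.float32) * scale_a.item()
    fp32_mat2 = mat2.to(torch.float32) * scale_b.item()
    out = fp32_mat1 @ fp32_mat2
    if bias is not None:
        out = out + bias          # bias broadcasts along the rows
    return out.to(out_dtype)
```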
Tensor&
_scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum,
    Tensor& out) {
#if AT_MKLDNN_ENABLED()
  if (at::globalContext().userEnabledMkldnn() && cpuinfo_has_x86_amx_int8()) {
    return mkldnn_scaled_mm(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
  } else
#endif
  {
    return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
  }
}
Tensor
_scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum) {
  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
  Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
  return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
}

} // namespace at::native
@@ -4,6 +4,7 @@
#include <ATen/core/Tensor.h>
#include <torch/library.h>
#include <ATen/native/mkldnn/Linear.h>
#include <ATen/native/Resize.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>

@@ -46,6 +47,18 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
  TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support");
}

Tensor&
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum,
    Tensor& out) {
  TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at
@@ -447,6 +460,119 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
      TORCH_FN(mkldnn_linear_pointwise_binary));
}

Tensor&
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum,
    Tensor& out) {
  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
  TORCH_CHECK(
      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

  TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
       " but got ", bias->numel());

  // Check types
  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
  // TODO: This check of mat1 and mat2 must have the same data type will be removed after oneDNN v3.6.
  TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "Expected mat1 and mat2 must have the same data type");

  // Validation checks have passed lets resize the output to actual size
  auto mat1_c = mat1.contiguous();
  auto mat2_c = mat2.contiguous();
  IntArrayRef mat1_sizes = mat1_c.sizes();
  IntArrayRef mat2_sizes = mat2_c.sizes();
  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});

  float input_scale = scale_a.item<float>();
  float weight_scale = scale_b.item<float>();
  auto src = at::native::itensor_view_from_dense(mat1_c);
  auto weight_t = at::native::itensor_view_from_dense(mat2_c);
  bool with_bias = bias.has_value();
  int64_t K = mat1_sizes[1], M = mat1_sizes[0],
      N = mat2_sizes[1];

  std::vector<int64_t> src_dims = {M, K};
  std::vector<int64_t> weight_dims = {K, N};
  std::vector<int64_t> dst_dims = {M, N};

  ideep::tensor dst = at::native::itensor_view_from_dense(out);
  auto src_desc = ideep::tensor::desc(
      src_dims,
      get_mkldnn_dtype(mat1.scalar_type()),
      ideep::format_tag::any);
  auto weights_desc = ideep::tensor::desc(
      weight_dims,
      get_mkldnn_dtype(mat2.scalar_type()),
      ideep::format_tag::any);
  auto dst_desc = ideep::tensor::desc(
      dst_dims,
      get_mkldnn_dtype(out.scalar_type()),
      ideep::format_tag::any);
  ideep::tensor onednn_bias;
  if (with_bias) {
    auto bias_value = bias.value();
    if (bias_value.dim() == 1) {
      auto b_reshape = bias_value.reshape({1, bias_value.size(0)});
      onednn_bias = at::native::itensor_view_from_dense(b_reshape);
    } else {
      onednn_bias = at::native::itensor_view_from_dense(bias_value);
    }
  }
  auto bias_desc = ideep::tensor::desc();
  if (with_bias) {
    bias_desc = ideep::tensor::desc(onednn_bias.get_dims(),
        get_mkldnn_dtype(bias.value().scalar_type()),
        ideep::format_tag::any);
  }
  auto op_attr = ideep::attr_t();
  if (input_scale != 1.0f) {
    op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
  }
  if (weight_scale != 1.0f) {
    op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
  }

  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
  auto engine = ideep::engine::cpu_engine();
  dnnl::matmul::primitive_desc primitive_desc = with_bias
      ? dnnl::matmul::primitive_desc(
            engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr)
      : dnnl::matmul::primitive_desc(
            engine, src_desc, weights_desc, dst_desc, op_attr);
  auto primitive = dnnl::matmul(primitive_desc);

  // Prepare args and execute primitive
  ideep::tensor scratchpad(primitive_desc.scratchpad_desc());
  ideep::exec_args args;
  args.insert({DNNL_ARG_SRC, src});
  args.insert({DNNL_ARG_WEIGHTS, weight_t});
  args.insert({DNNL_ARG_DST, dst});
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
  if (with_bias) {
    args.insert({DNNL_ARG_BIAS, onednn_bias});
  }
  ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale));
  ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale));

  if (input_scale != 1.0f) {
    args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
  }
  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});

  primitive.execute(ideep::stream::default_stream(), args);
  return out;
}

} // namespace at

#endif // AT_MKLDNN_ENABLED
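Whichever branch runs, the oneDNN primitive and the emulated kernel should agree with a straightforward dequantized matmul. A quick self-check along those lines (the tolerances are a guess, since the two paths accumulate differently):

```python
import torch

x = torch.randn(16, 16).to(torch.float8_e4m3fn)
y = torch.randn(16, 16).to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.5)
scale_b = torch.tensor(0.5)

out = torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32)
# Reference: dequantize with the same per-tensor scales, then float32 matmul.
ref = (x.to(torch.float32) * scale_a) @ (y.to(torch.float32) * scale_b)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```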
@@ -35,3 +35,15 @@ C10_API Tensor mkl_linear(
} // namespace at

#endif // AT_MKLDNN_ENABLED()

namespace at::native {
Tensor&
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
    const Tensor& scale_a,
    const Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const std::optional<at::Tensor>& scale_result,
    std::optional<c10::ScalarType> out_dtype,
    bool use_fast_accum,
    Tensor& out);
} // namespace at::native
@@ -57,6 +57,10 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
      return ideep::tensor::data_type::bf16;
    case ScalarType::Half:
      return ideep::tensor::data_type::f16;
    case ScalarType::Float8_e4m3fn:
      return ideep::tensor::data_type::f8_e4m3;
    case ScalarType::Float8_e5m2:
      return ideep::tensor::data_type::f8_e5m2;
    default:
      TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type");
  }

@@ -161,8 +165,24 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data
            const_cast<void*>(tensor.const_data_ptr()) :
            tensor.data_ptr()};
  }
  else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
    return {{tensor.sizes().vec(),
             ideep::tensor::data_type::f8_e4m3,
             tensor.strides().vec()},
            from_const_data_ptr ?
            const_cast<void*>(tensor.const_data_ptr()) :
            tensor.data_ptr()};
  }
  else if (tensor.scalar_type() == ScalarType::Float8_e5m2) {
    return {{tensor.sizes().vec(),
             ideep::tensor::data_type::f8_e5m2,
             tensor.strides().vec()},
            from_const_data_ptr ?
            const_cast<void*>(tensor.const_data_ptr()) :
            tensor.data_ptr()};
  }
  else {
    TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input");
    TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8/fp8 tensor input");
  }
}
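The two dtype cases added above correspond to `torch.float8_e4m3fn` and `torch.float8_e5m2`. Their numeric ranges can be inspected from Python, and they are the basis of the per-tensor scaling used in the FP8 tests; a worked example with made-up numbers:

```python
import torch

# Ranges of the two FP8 formats the oneDNN bridge now maps.
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max   # 448.0
E5M2_MAX = torch.finfo(torch.float8_e5m2).max     # 57344.0

# Per-tensor scaling as used by the FP8 test helpers: scale = FP8_MAX / amax.
x = torch.full((4, 4), 2.0)
amax = x.abs().max()                              # 2.0
scale = E4M3_MAX / torch.clamp(amax, min=1e-12)   # 224.0
x_fp8 = (x * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
print(x_fp8.to(torch.float32)[0, 0] / scale)      # tensor(2.) -- round trip
```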
@@ -7071,11 +7071,13 @@
- func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
  variants: function
  dispatch:
    CPU: _scaled_mm_cpu
    CUDA: _scaled_mm_cuda

- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
  variants: function
  dispatch:
    CPU: _scaled_mm_out_cpu
    CUDA: _scaled_mm_out_cuda

# NOTE [ Sparse: autograd and API ]
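With these dispatch entries the same Python call routes to the CPU or CUDA kernel based on the input device. A small sketch of both variants (it assumes the `out=` keyword is exposed for the `.out` overload, as it is for other out-variant ops):

```python
import torch

x = torch.ones(8, 8).to(torch.float8_e4m3fn)
y = torch.ones(8, 8).to(torch.float8_e4m3fn).t()
one = torch.tensor(1.0)

# Functional variant: CPU tensors -> _scaled_mm_cpu, CUDA tensors -> _scaled_mm_cuda.
out = torch._scaled_mm(x, y, one, one, out_dtype=torch.float32)

# out= variant: routes to _scaled_mm_out_cpu / _scaled_mm_out_cuda.
buf = torch.empty(8, 8, dtype=torch.float32)
torch._scaled_mm(x, y, one, one, out_dtype=torch.float32, out=buf)
```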
@@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import (
    parametrize,
    TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.utils._triton import has_triton_tma_device

@@ -89,10 +89,10 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):

@instantiate_parametrized_tests
class TestFP8Types(TestCase):
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet")
    @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
    def test_xblock_for_small_numel(self, float8_dtype: torch.dtype):
    @parametrize("device", ("cuda", "cpu"))
    def test_xblock_for_small_numel(self, float8_dtype: torch.dtype, device: str):
        """
        TritonOverrides.to_dtype will set min_elem_per_thread to 2 or 4
        depends on the variant of fp8 type.

@@ -101,19 +101,23 @@ class TestFP8Types(TestCase):

        We should not pick a XBLOCK larger than xnumel
        """
        if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
            raise unittest.SkipTest(f8_msg)

        def f(x):
            return x.to(dtype=float8_dtype)

        x = torch.randn(1, device="cuda")
        x = torch.randn(1, device=device)
        expected = f(x)
        actual = torch.compile(f)(x)
        torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet")
    @parametrize("dtype", (torch.float16, torch.bfloat16))
    def test_eager_fallback(self, dtype: torch.dtype):
    @parametrize("device", ("cuda", "cpu"))
    def test_eager_fallback(self, dtype: torch.dtype, device: torch.device):
        if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
            raise unittest.SkipTest(f8_msg)
        weight_shape = (32, 16)

        e4m3_type = (

@@ -121,11 +125,11 @@ class TestFP8Types(TestCase):
        )

        def fp8_matmul_unwrapped(x):
            a_scale = torch.Tensor([1.0]).to(device="cuda")
            b_scale = torch.Tensor([1.0]).to(device="cuda")
            a_scale = torch.Tensor([1.0]).to(device=device)
            b_scale = torch.Tensor([1.0]).to(device=device)
            output_scale = None
            input_bias = torch.rand(32, device="cuda", dtype=dtype)
            weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T.to(
            input_bias = torch.rand(32, device=device, dtype=dtype)
            weight = torch.rand(*weight_shape, device=device, dtype=dtype).T.to(
                e4m3_type
            )
            a_inverse_scale = 1 / a_scale

@@ -146,14 +150,13 @@ class TestFP8Types(TestCase):
        )

        x_shape = (16, 16)
        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
        x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type)
        y_fp8 = compiled_fp8_matmul(x)  # noqa: F841

        x_shape = (15, 16)
        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
        x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type)
        y_fp8 = compiled_fp8_matmul(x)  # noqa: F841

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
    @parametrize("shape", ("15,3,13", "4,2048,4096"))
    @parametrize(
@ -162,7 +165,12 @@ class TestFP8Types(TestCase):
|
|||
if torch.version.hip is None
|
||||
else [(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)],
|
||||
)
|
||||
def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple):
|
||||
@parametrize("device", ("cuda", "cpu"))
|
||||
def test_valid_cast(
|
||||
self, dtype: torch.dtype, shape: str, dst_types: tuple, device: torch.device
|
||||
):
|
||||
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
|
||||
raise unittest.SkipTest(f8_msg)
|
||||
e4m3, e5m2 = dst_types
|
||||
|
||||
def fp8_cast(x):
|
||||
|
|
@ -173,7 +181,7 @@ class TestFP8Types(TestCase):
|
|||
compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)
|
||||
|
||||
shape = [int(dim) for dim in shape.split(",")]
|
||||
x = torch.rand(*shape, device="cuda", dtype=dtype)
|
||||
x = torch.rand(*shape, device=device, dtype=dtype)
|
||||
y0_fp8, y1_fp8 = compiled_fp8_cast(x)
|
||||
|
||||
torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1)
|
||||
|
|
@ -202,7 +210,6 @@ class TestFP8Types(TestCase):
|
|||
x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2)
|
||||
compiled_fp8_cast(x, torch.float8_e4m3fn)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
@parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
|
||||
@parametrize(
|
||||
"dst_dtype",
|
||||
|
|
@ -211,9 +218,17 @@ class TestFP8Types(TestCase):
|
|||
else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
|
||||
)
|
||||
@parametrize("shape", ("16,16,16", "4,2048,4096"))
|
||||
@parametrize("device", ("cuda", "cpu"))
|
||||
def test_to_fp8_saturated(
|
||||
self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
|
||||
self,
|
||||
src_dtype: torch.dtype,
|
||||
dst_dtype: torch.dtype,
|
||||
shape: str,
|
||||
device: torch.device,
|
||||
):
|
||||
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
|
||||
raise unittest.SkipTest(f8_msg)
|
||||
|
||||
def fp8_saturated(x, dtype):
|
||||
return _to_fp8_saturated(x, dtype)
|
||||
|
||||
|
|
@ -221,14 +236,13 @@ class TestFP8Types(TestCase):
|
|||
fp8_saturated, backend="inductor", dynamic=True
|
||||
)
|
||||
shape = [int(dim) for dim in shape.split(",")]
|
||||
x = torch.rand(*shape, device="cuda", dtype=src_dtype)
|
||||
x = torch.rand(*shape, device=device, dtype=src_dtype)
|
||||
y_compiled = compiled_fp8_cast(x, dst_dtype)
|
||||
y = fp8_saturated(x, dst_dtype)
|
||||
|
||||
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1)
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
|
||||
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
|
||||
@parametrize(
|
||||
"float8_dtype",
|
||||
(torch.float8_e4m3fn, torch.float8_e5m2)
|
||||
|
|
@ -236,7 +250,12 @@ class TestFP8Types(TestCase):
|
|||
else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
|
||||
)
|
||||
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
|
||||
def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
|
||||
@parametrize("device", ("cuda", "cpu"))
|
||||
def test_amax_fp8_quant(
|
||||
self, float8_dtype: torch.dtype, shape: str, device: torch.device
|
||||
):
|
||||
if device == "cuda" and not SM90OrLater:
|
||||
raise unittest.SkipTest("FP8 is only supported on H100+")
|
||||
shape = [int(dim) for dim in shape.split(",")]
|
||||
batch_size, sequence_length, hidden_size = shape
|
||||
|
||||
|
|
@ -249,15 +268,14 @@ class TestFP8Types(TestCase):
|
|||
compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")
|
||||
|
||||
x_shape = (batch_size, sequence_length, hidden_size)
|
||||
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
|
||||
scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
|
||||
x = torch.rand(*x_shape, device=device, dtype=torch.half)
|
||||
scale = torch.tensor(0.2, device=device, dtype=torch.float)
|
||||
|
||||
y_compiled = compiled_amax_fp8_quant(x, scale)
|
||||
y = amax_fp8(x, scale)
|
||||
|
||||
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
@parametrize(
|
||||
"float8_dtype",
|
||||
(torch.float8_e4m3fn, torch.float8_e5m2)
|
||||
|
|
@ -265,7 +283,12 @@ class TestFP8Types(TestCase):
|
|||
else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
|
||||
)
|
||||
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
|
||||
def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
|
||||
@parametrize("device", ("cuda", "cpu"))
|
||||
def test_amax_along_with_fp8_quant(
|
||||
self, float8_dtype: torch.dtype, shape: str, device: torch.device
|
||||
):
|
||||
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
|
||||
raise unittest.SkipTest(f8_msg)
|
||||
shape = [int(dim) for dim in shape.split(",")]
|
||||
batch_size, sequence_length, hidden_size = shape
|
||||
|
||||
|
|
@ -278,12 +301,12 @@ class TestFP8Types(TestCase):
|
|||
compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")
|
||||
|
||||
x_shape = (batch_size, sequence_length, hidden_size)
|
||||
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
|
||||
scale = torch.tensor(1.0, device="cuda", dtype=torch.float)
|
||||
x = torch.rand(*x_shape, device=device, dtype=torch.half)
|
||||
scale = torch.tensor(1.0, device=device, dtype=torch.float)
|
||||
|
||||
amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
|
||||
amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half)
|
||||
y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled)
|
||||
amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
|
||||
amax_buffer = torch.zeros((1), device=device, dtype=torch.half)
|
||||
y = amax_fp8(x, scale, amax_buffer)
|
||||
|
||||
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
|
||||
|
|
@ -292,7 +315,6 @@ class TestFP8Types(TestCase):
|
|||
)
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
|
||||
@unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
|
||||
@parametrize(
|
||||
"float8_dtype",
|
||||
(torch.float8_e4m3fn, torch.float8_e5m2)
|
||||
|
|
@ -301,9 +323,16 @@ class TestFP8Types(TestCase):
|
|||
)
|
||||
@parametrize("amax_keep_dim", (True, False))
|
||||
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
|
||||
@parametrize("device", ("cuda", "cpu"))
|
||||
def test_layernorm_fp8_quant(
|
||||
self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: str
|
||||
self,
|
||||
float8_dtype: torch.dtype,
|
||||
amax_keep_dim: bool,
|
||||
shape: str,
|
||||
device: torch.device,
|
||||
):
|
||||
if device == "cuda" and not SM90OrLater:
|
||||
raise unittest.SkipTest("FP8 is only supported on H100+")
|
||||
shape = [int(dim) for dim in shape.split(",")]
|
||||
batch_size, sequence_length, hidden_size = shape
|
||||
|
||||
|
|
@ -325,12 +354,12 @@ class TestFP8Types(TestCase):
|
|||
compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")
|
||||
|
||||
x_shape = (batch_size, sequence_length, hidden_size)
|
||||
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
|
||||
scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
|
||||
x = torch.rand(*x_shape, device=device, dtype=torch.half)
|
||||
scale = torch.tensor(0.2, device=device, dtype=torch.float)
|
||||
|
||||
amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
|
||||
amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half)
|
||||
y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
|
||||
amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
|
||||
amax_buffer = torch.zeros((1), device=device, dtype=torch.half)
|
||||
y = ln_fp8(x, scale, amax_buffer)
|
||||
|
||||
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
|
||||
|
|
@@ -748,5 +777,5 @@ class TestFP8Lowering(TestCase):


if __name__ == "__main__":
    if HAS_CUDA:
    if HAS_CUDA or HAS_CPU:
        run_tests()
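The pattern in the hunks above recurs throughout the file: tests gain a `device` parameter and replace the decorator-level CUDA skip with an in-test skip. A self-contained miniature of that pattern (hypothetical test, not part of the PR):

```python
import unittest

import torch
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


@instantiate_parametrized_tests
class TinyFP8DeviceTest(TestCase):
    @parametrize("device", ("cuda", "cpu"))
    def test_fp8_cast_roundtrip(self, device: str):
        # Skip the CUDA leg when no GPU is present; the CPU leg always runs.
        if device == "cuda" and not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available")
        x = torch.randn(8, device=device)
        y = x.to(torch.float8_e4m3fn).to(torch.float32)
        torch.testing.assert_close(x, y, rtol=0.15, atol=0.05)


if __name__ == "__main__":
    run_tests()
```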
@@ -3,8 +3,6 @@
import unittest
from itertools import product
from functools import partial
from typing import Optional
import re

import torch

@@ -17,7 +15,6 @@ from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
    SM53OrLater,
    _get_torch_cuda_version,
    PLATFORM_SUPPORTS_FP8
)
from torch.testing._internal.common_device_type import (
    dtypes,
@ -212,524 +209,6 @@ class TestMatmulCuda(TestCase):
|
|||
self.assertEqual(out1_gpu, out2_gpu[0])
|
||||
|
||||
|
||||
f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
|
||||
|
||||
if torch.version.hip:
|
||||
e4m3_type = torch.float8_e4m3fnuz
|
||||
e5m2_type = torch.float8_e5m2fnuz
|
||||
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
|
||||
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
|
||||
else:
|
||||
e4m3_type = torch.float8_e4m3fn
|
||||
e5m2_type = torch.float8_e5m2
|
||||
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
|
||||
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
|
||||
|
||||
# avoid division by zero when calculating scale
|
||||
EPS = 1e-12
|
||||
|
||||
def amax_to_scale(
|
||||
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
|
||||
):
|
||||
""" Converts the amax value of a tensor to the fp8 scale.
|
||||
Args:
|
||||
amax: The amax value of the tensor.
|
||||
float8_dtype: the float8 dtype.
|
||||
orig_dtype: The original dtype of the tensor.
|
||||
"""
|
||||
scale = torch.empty_like(amax, dtype=torch.float32)
|
||||
if float8_dtype == e4m3_type:
|
||||
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
|
||||
elif float8_dtype == e5m2_type:
|
||||
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
|
||||
else:
|
||||
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
|
||||
|
||||
# Ensure the scale is representable in float16,
|
||||
# this helps when amax is small. We are assuming that we don't need
|
||||
# to care about this for float32/bfloat16
|
||||
if orig_dtype is torch.float16:
|
||||
res = torch.clamp(res, max=torch.finfo(torch.float16).max)
|
||||
|
||||
scale.copy_(res)
|
||||
return scale
|
||||
|
||||
def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
|
||||
if dim is None:
|
||||
amax = torch.max(torch.abs(x))
|
||||
else:
|
||||
amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values
|
||||
|
||||
return amax_to_scale(amax, float8_dtype, x.dtype)
|
||||
|
||||
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
|
||||
# naive implementation: dq -> op -> q
|
||||
x_fp32 = x.to(torch.float) / x_scale
|
||||
y_fp32 = y.to(torch.float) / y_scale
|
||||
out_fp32 = torch.mm(x_fp32, y_fp32)
|
||||
|
||||
return out_fp32.to(out_dtype)
|
||||
|
||||
def addmm_float8_unwrapped(
|
||||
a_data: torch.Tensor,
|
||||
a_scale: torch.Tensor,
|
||||
b_data: torch.Tensor,
|
||||
b_scale: torch.tensor,
|
||||
output_dtype: torch.dtype,
|
||||
output_scale: Optional[torch.Tensor],
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
a_inverse_scale = a_scale.reciprocal()
|
||||
b_inverse_scale = b_scale.reciprocal()
|
||||
if output_dtype == torch.float32 and bias is not None:
|
||||
# Bias is not supported by _scaled_mm when output is fp32
|
||||
output = torch._scaled_mm(
|
||||
a_data,
|
||||
b_data,
|
||||
scale_a=a_inverse_scale,
|
||||
scale_b=b_inverse_scale,
|
||||
scale_result=output_scale,
|
||||
out_dtype=output_dtype,
|
||||
)
|
||||
output += bias
|
||||
return output
|
||||
output = torch._scaled_mm(
|
||||
a_data,
|
||||
b_data,
|
||||
bias=bias,
|
||||
scale_a=a_inverse_scale,
|
||||
scale_b=b_inverse_scale,
|
||||
scale_result=output_scale,
|
||||
out_dtype=output_dtype,
|
||||
)
|
||||
return output
|
||||
|
||||
def mm_float8(
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
a_scale: torch.Tensor,
|
||||
b_scale: torch.Tensor,
|
||||
output_dtype: torch.dtype, # output dtype
|
||||
output_scale: Optional[torch.Tensor] = None, # output scale, precomputed
|
||||
) -> torch.Tensor:
|
||||
return addmm_float8_unwrapped(
|
||||
a, a_scale, b, b_scale, output_dtype, output_scale
|
||||
)
|
||||
|
||||
def to_fp8_saturated(
|
||||
x: torch.Tensor,
|
||||
fp8_dtype: torch.dtype
|
||||
):
|
||||
if fp8_dtype == e4m3_type:
|
||||
x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
|
||||
elif fp8_dtype == e5m2_type:
|
||||
x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
|
||||
else:
|
||||
raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")
|
||||
|
||||
return x.to(fp8_dtype)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
|
||||
class TestFP8MatmulCuda(TestCase):
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
def _test_tautological_mm(self, device: str = "cuda",
|
||||
x_dtype: torch.dtype = e4m3_type,
|
||||
y_dtype: torch.dtype = e4m3_type,
|
||||
out_dtype: Optional[torch.dtype] = None,
|
||||
size: int = 16) -> None:
|
||||
x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
|
||||
y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
|
||||
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
|
||||
scale_a = torch.tensor(1.0, device=device)
|
||||
scale_b = torch.tensor(1.0, device=device)
|
||||
out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
|
||||
if out_dtype is not None:
|
||||
self.assertEqual(out_dtype, out_fp8.dtype)
|
||||
self.assertEqual(out_fp32, out_fp8.to(torch.float))
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
def test_float8_basics(self, device) -> None:
|
||||
self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
|
||||
# hipblaslt does not yet support mixed e4m3_type input
|
||||
if torch.version.hip is None:
|
||||
self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
|
||||
self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
|
||||
# According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
|
||||
with self.assertRaises(RuntimeError):
|
||||
self._test_tautological_mm(device, e5m2_type, e5m2_type)
|
||||
|
||||
self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
|
||||
self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
|
||||
# hipblaslt does not yet support bfloat16 output
|
||||
if torch.version.hip is None:
|
||||
self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
|
||||
with self.assertRaises(RuntimeError):
|
||||
self._test_tautological_mm(device, out_dtype=e5m2_type)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
def test_float8_scale(self, device) -> None:
|
||||
size = (16, 16)
|
||||
x = torch.full(size, .5, device=device, dtype=e4m3_type)
|
||||
# hipblaslt does not yet support mixed e4m3_type input
|
||||
y_type = e4m3_type if torch.version.hip else e5m2_type
|
||||
y = torch.full(size, .5, device=device, dtype=y_type).t()
|
||||
scale_a = torch.tensor(1.5, device=device)
|
||||
scale_b = torch.tensor(0.66, device=device)
|
||||
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
|
||||
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
|
||||
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
|
||||
self.assertEqual(out_fp8, out_fp8_s)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
def test_scaled_mm_vs_emulated(self, base_dtype):
|
||||
torch.manual_seed(42)
|
||||
input_dtype = e4m3_type
|
||||
output_dtype = base_dtype
|
||||
compare_type = torch.float32
|
||||
|
||||
x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
|
||||
y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
|
||||
|
||||
x_scale = tensor_to_scale(x, input_dtype).float()
|
||||
y_scale = tensor_to_scale(y, input_dtype).float()
|
||||
|
||||
x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
|
||||
y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
|
||||
|
||||
# Calculate actual F8 mm
|
||||
out_scaled_mm = mm_float8(
|
||||
x_fp8,
|
||||
y_fp8,
|
||||
a_scale=x_scale,
|
||||
b_scale=y_scale,
|
||||
output_dtype=output_dtype
|
||||
)
|
||||
|
||||
# Calculate emulated F8 mm
|
||||
out_emulated = mm_float8_emulated(
|
||||
x_fp8,
|
||||
x_scale,
|
||||
y_fp8,
|
||||
y_scale,
|
||||
output_dtype
|
||||
)
|
||||
|
||||
if output_dtype != base_dtype:
|
||||
out_scaled_mm = out_scaled_mm.to(compare_type)
|
||||
out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)
|
||||
|
||||
out_emulated = out_emulated.to(compare_type)
|
||||
out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)
|
||||
|
||||
if base_dtype in {torch.bfloat16, torch.float16}:
|
||||
atol, rtol = 7e-2, 7e-2
|
||||
else:
|
||||
atol, rtol = 3e-3, 3e-3
|
||||
|
||||
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
def test_scaled_mm_change_stride(self, base_dtype):
|
||||
torch.manual_seed(42)
|
||||
input_dtype = e4m3_type
|
||||
output_dtype = base_dtype
|
||||
compare_type = torch.float32
|
||||
|
||||
x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
|
||||
y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)
|
||||
|
||||
x_scale = tensor_to_scale(x, input_dtype).float()
|
||||
y_scale = tensor_to_scale(y, input_dtype).float()
|
||||
|
||||
x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
|
||||
y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
|
||||
|
||||
# Calculate actual F8 mm
|
||||
out_scaled_mm = mm_float8(
|
||||
x_fp8,
|
||||
y_fp8,
|
||||
a_scale=x_scale,
|
||||
b_scale=y_scale,
|
||||
output_dtype=output_dtype
|
||||
)
|
||||
|
||||
# Calculate emulated F8 mm
|
||||
out_emulated = mm_float8_emulated(
|
||||
x_fp8,
|
||||
x_scale,
|
||||
y_fp8,
|
||||
y_scale,
|
||||
output_dtype
|
||||
)
|
||||
|
||||
if output_dtype != base_dtype:
|
||||
out_scaled_mm = out_scaled_mm.to(compare_type)
|
||||
out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)
|
||||
|
||||
out_emulated = out_emulated.to(compare_type)
|
||||
out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)
|
||||
|
||||
if base_dtype in {torch.bfloat16, torch.float16}:
|
||||
atol, rtol = 7e-2, 7e-2
|
||||
else:
|
||||
atol, rtol = 3e-3, 3e-3
|
||||
|
||||
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
def test_float8_bias(self, device) -> None:
|
||||
(k, l, m) = (16, 48, 32)
|
||||
x = torch.ones((k, l), device=device).to(e4m3_type)
|
||||
y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
|
||||
bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
|
||||
scale_a = torch.tensor(1.0, device=device)
|
||||
scale_b = torch.tensor(1.0, device=device)
|
||||
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
|
||||
outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
|
||||
# this fails on ROCm currently because hipblaslt doesn't have amax op
|
||||
out_fp32 = out_fp8.to(torch.float32)
|
||||
outb_fp32 = outb_fp8.to(torch.float32)
|
||||
difference = torch.abs(out_fp32 - outb_fp32)
|
||||
self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32))
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
@parametrize("bias", [True, False])
|
||||
def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
|
||||
x = torch.rand((17, 16), device=device).to(e4m3_type)
|
||||
y = torch.rand((16, 16), device=device).to(e4m3_type).t()
|
||||
scale_a = torch.tensor(1.0, device=device)
|
||||
scale_b = torch.tensor(1.0, device=device)
|
||||
input_bias = None
|
||||
if bias:
|
||||
input_bias = torch.rand((16,), device=device).to(torch.half)
|
||||
_ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
def test_float8_bias_relu_edgecase(self, device) -> None:
|
||||
(k, l, m) = (16, 48, 32)
|
||||
x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
|
||||
y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
|
||||
bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
|
||||
scale_a = torch.tensor(1.0, device=device)
|
||||
scale_b = torch.tensor(1.0, device=device)
|
||||
outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
|
||||
outb_fp32 = outb_fp8.to(torch.float32)
|
||||
self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
def test_float32_output_errors_with_bias(self, device) -> None:
|
||||
(k, l, m) = (16, 48, 32)
|
||||
x = torch.rand((k, l), device=device).to(e4m3_type)
|
||||
y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
|
||||
scale_a = torch.tensor(1.0, device=device)
|
||||
scale_b = torch.tensor(1.0, device=device)
|
||||
bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Bias is not supported when out_dtype is set to Float32",
|
||||
lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
|
||||
)
|
||||
|
||||
@unittest.skipIf(PLATFORM_SUPPORTS_FP8,
|
||||
"This test is only for devices with compute capability < 8.9")
|
||||
def test_error_message_fp8_pre_sm89(self, device) -> None:
|
||||
(k, l, m) = (16, 48, 32)
|
||||
x = torch.rand((k, l), device=device).to(e4m3_type)
|
||||
y = torch.rand((m, l), device=device).to(e4m3_type).t()
|
||||
scale_a = torch.tensor(1.0, device=device)
|
||||
scale_b = torch.tensor(1.0, device=device)
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
|
||||
lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
def test_float8_scale_fast_accum(self, device) -> None:
|
||||
size = (16, 16)
|
||||
x = torch.full(size, .5, device=device, dtype=e4m3_type)
|
||||
# hipblaslt does not yet support mixed e4m3_type input
|
||||
y_type = e4m3_type if torch.version.hip else e5m2_type
|
||||
y = torch.full(size, .5, device=device, dtype=y_type).t()
|
||||
scale_a = torch.tensor(1.5, device=device)
|
||||
scale_b = torch.tensor(0.66, device=device)
|
||||
out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
|
||||
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
|
||||
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
|
||||
self.assertEqual(out_fp8, out_fp8_s)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
||||
@unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
|
||||
@skipIfRocm()
|
||||
@parametrize("use_fast_accum", [True, False])
|
||||
def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
|
||||
M, K, N = (1024, 512, 2048)
|
||||
fill_value = 0.5
|
||||
x = torch.full((M, K), fill_value, device=device)
|
||||
y = torch.full((N, K), fill_value, device=device)
|
||||
|
||||
x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
|
||||
y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)
|
||||
|
||||
x_fp8 = x.to(torch.float8_e4m3fn)
|
||||
y_fp8 = y.to(torch.float8_e4m3fn).t()
|
||||
|
||||
out_fp8 = torch._scaled_mm(
|
||||
x_fp8,
|
||||
y_fp8,
|
||||
scale_a=x_scales,
|
||||
scale_b=y_scales,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=use_fast_accum,
|
||||
)
|
||||
self.assertEqual(
|
||||
out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
|
||||
)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
||||
@skipIfRocm()
|
||||
def test_float8_error_messages(self, device) -> None:
|
||||
M, K, N = (1024, 512, 2048)
|
||||
fill_value = 0.5
|
||||
x = torch.full((M, K), fill_value, device=device)
|
||||
y = torch.full((N, K), fill_value, device=device)
|
||||
|
||||
x_fp8 = x.to(torch.float8_e4m3fn)
|
||||
y_fp8 = y.to(torch.float8_e4m3fn).t()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
re.escape(
|
||||
"For RowWise scaling, scale_a should be (1024, 1) and scale_b "
|
||||
"should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
|
||||
),
|
||||
):
|
||||
torch._scaled_mm(
|
||||
x_fp8,
|
||||
y_fp8,
|
||||
scale_a=torch.ones((1, 1), device="cuda"),
|
||||
scale_b=torch.ones((1, 2), device="cuda"),
|
||||
out_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
re.escape(
|
||||
" For RowWise scaling, scale_a should be (1024, 1) and scale_b "
|
||||
"should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
|
||||
),
|
||||
):
|
||||
torch._scaled_mm(
|
||||
x_fp8,
|
||||
y_fp8,
|
||||
scale_a=torch.ones((M, 1), device="cuda"),
|
||||
scale_b=torch.ones((1, N + 1), device="cuda"),
|
||||
out_dtype=torch.bfloat16,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
|
||||
):
|
||||
torch._scaled_mm(
|
||||
x_fp8,
|
||||
y_fp8,
|
||||
scale_a=torch.ones((M), device="cuda"),
|
||||
scale_b=torch.ones((N, N), device="cuda"),
|
||||
out_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
re.escape(
|
||||
"Both scale_a and scale_b must be contiguous for RowWise scaling."
|
||||
),
|
||||
):
|
||||
torch._scaled_mm(
|
||||
x_fp8,
|
||||
y_fp8,
|
||||
scale_a=torch.ones((M, 1), device="cuda"),
|
||||
scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
|
||||
out_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
re.escape("Expected b.dtype() == at::kFloat8_e4m3fn to be true, but got false."),
|
||||
):
|
||||
torch._scaled_mm(
|
||||
x_fp8,
|
||||
y_fp8.to(torch.float8_e5m2),
|
||||
scale_a=torch.ones((M, 1), device="cuda"),
|
||||
scale_b=torch.ones((1, N), device="cuda"),
|
||||
out_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
||||
@unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
|
||||
@skipIfRocm()
|
||||
@parametrize("base_dtype", [torch.bfloat16])
|
||||
def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
|
||||
torch.manual_seed(42)
|
||||
input_dtype = e4m3_type
|
||||
output_dtype = base_dtype
|
||||
|
||||
x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
|
||||
y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
|
||||
|
||||
x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
|
||||
y_scales = tensor_to_scale(y, input_dtype, dim=0).float()
|
||||
|
||||
x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
|
||||
y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)
|
||||
|
||||
# Calculate actual F8 mm
|
||||
out_scaled_mm = mm_float8(
|
||||
x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
|
||||
)
|
||||
|
||||
# Calculate emulated F8 mm
|
||||
out_emulated = mm_float8_emulated(
|
||||
x_fp8, x_scales, y_fp8, y_scales, output_dtype
|
||||
)
|
||||
|
||||
if base_dtype in {torch.bfloat16, torch.float16}:
|
||||
atol, rtol = 7e-2, 7e-2
|
||||
else:
|
||||
atol, rtol = 2e-3, 2e-3
|
||||
|
||||
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
@parametrize("which_dim_zero", [0, 1, 2])
|
||||
@parametrize("use_torch_compile", [False, True])
|
||||
def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
|
||||
device = "cuda"
|
||||
x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn
|
||||
out_dtype = torch.bfloat16
|
||||
M, K, N = 32, 32, 32
|
||||
if which_dim_zero == 0:
|
||||
M = 0
|
||||
elif which_dim_zero == 1:
|
||||
K = 0
|
||||
elif which_dim_zero == 2:
|
||||
N = 0
|
||||
|
||||
x_fp8 = torch.zeros(M, K, device=device).to(x_dtype)
|
||||
y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t()
|
||||
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
|
||||
scale_a = torch.tensor(float('-inf'), device=device)
|
||||
scale_b = torch.tensor(float('-inf'), device=device)
|
||||
f = torch._scaled_mm
|
||||
if use_torch_compile:
|
||||
f = torch.compile(torch._scaled_mm)
|
||||
out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
|
||||
self.assertEqual(out_dtype, out_fp8.dtype)
|
||||
self.assertEqual(out_fp32, out_fp8.to(torch.float))
|
||||
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
||||
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
|
||||
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
|
||||
|
|
@@ -851,7 +330,6 @@ class TestMixedDtypesLinearCuda(TestCase):
    )

instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")

if __name__ == '__main__':

test/test_matmul_fp8.py (new file, 563 lines)
@@ -0,0 +1,563 @@
# Owner(s): ["module: linear algebra"]

import re
import unittest
from typing import Optional

import torch
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfRocm,
    TEST_CUDA,
    TestCase,
)


_IS_SM9X = False
if TEST_CUDA:
    _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32

f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"

if torch.version.hip:
    e4m3_type = torch.float8_e4m3fnuz
    e5m2_type = torch.float8_e5m2fnuz
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
    e4m3_type = torch.float8_e4m3fn
    e5m2_type = torch.float8_e5m2
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max

# avoid division by zero when calculating scale
EPS = 1e-12
def amax_to_scale(
|
||||
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
|
||||
):
|
||||
"""Converts the amax value of a tensor to the fp8 scale.
|
||||
Args:
|
||||
amax: The amax value of the tensor.
|
||||
float8_dtype: the float8 dtype.
|
||||
orig_dtype: The original dtype of the tensor.
|
||||
"""
|
||||
scale = torch.empty_like(amax, dtype=torch.float32)
|
||||
if float8_dtype == e4m3_type:
|
||||
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
|
||||
elif float8_dtype == e5m2_type:
|
||||
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
|
||||
else:
|
||||
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
|
||||
|
||||
# Ensure the scale is representable in float16,
|
||||
# this helps when amax is small. We are assuming that we don't need
|
||||
# to care about this for float32/bfloat16
|
||||
if orig_dtype is torch.float16:
|
||||
res = torch.clamp(res, max=torch.finfo(torch.float16).max)
|
||||
|
||||
scale.copy_(res)
|
||||
return scale
|
||||
|
||||
|
||||
def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
|
||||
if dim is None:
|
||||
amax = torch.max(torch.abs(x))
|
||||
else:
|
||||
amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values
|
||||
|
||||
return amax_to_scale(amax, float8_dtype, x.dtype)
|
||||
|
||||
|
||||
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
|
||||
# naive implementation: dq -> op -> q
|
||||
x_fp32 = x.to(torch.float) / x_scale
|
||||
y_fp32 = y.to(torch.float) / y_scale
|
||||
out_fp32 = torch.mm(x_fp32, y_fp32)
|
||||
|
||||
return out_fp32.to(out_dtype)
|
||||
|
||||
|
||||
def addmm_float8_unwrapped(
|
||||
a_data: torch.Tensor,
|
||||
a_scale: torch.Tensor,
|
||||
b_data: torch.Tensor,
|
||||
b_scale: torch.tensor,
|
||||
output_dtype: torch.dtype,
|
||||
output_scale: Optional[torch.Tensor],
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
a_inverse_scale = a_scale.reciprocal()
|
||||
b_inverse_scale = b_scale.reciprocal()
|
||||
if output_dtype == torch.float32 and bias is not None:
|
||||
# Bias is not supported by _scaled_mm when output is fp32
|
||||
output = torch._scaled_mm(
|
||||
a_data,
|
||||
b_data,
|
||||
scale_a=a_inverse_scale,
|
||||
scale_b=b_inverse_scale,
|
||||
scale_result=output_scale,
|
||||
out_dtype=output_dtype,
|
||||
)
|
||||
output += bias
|
||||
return output
|
||||
output = torch._scaled_mm(
|
||||
a_data,
|
||||
b_data,
|
||||
bias=bias,
|
||||
scale_a=a_inverse_scale,
|
||||
scale_b=b_inverse_scale,
|
||||
scale_result=output_scale,
|
||||
out_dtype=output_dtype,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def mm_float8(
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
a_scale: torch.Tensor,
|
||||
b_scale: torch.Tensor,
|
||||
output_dtype: torch.dtype, # output dtype
|
||||
output_scale: Optional[torch.Tensor] = None, # output scale, precomputed
|
||||
) -> torch.Tensor:
|
||||
return addmm_float8_unwrapped(a, a_scale, b, b_scale, output_dtype, output_scale)
|
||||
|
||||
|
||||
def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype):
|
||||
if fp8_dtype == e4m3_type:
|
||||
x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
|
||||
elif fp8_dtype == e5m2_type:
|
||||
x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
|
||||
else:
|
||||
raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")
|
||||
|
||||
return x.to(fp8_dtype)
|
||||
|
||||
|
||||
class TestFP8Matmul(TestCase):
    def _test_tautological_mm(
        self,
        device: str = "cuda",
        x_dtype: torch.dtype = e4m3_type,
        y_dtype: torch.dtype = e4m3_type,
        out_dtype: Optional[torch.dtype] = None,
        size: int = 16,
    ) -> None:
        if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
            raise unittest.SkipTest(f8_msg)
        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
        if out_dtype is not None:
            self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    def test_float8_basics(self, device) -> None:
        if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
            raise unittest.SkipTest(f8_msg)
        self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
        # hipblaslt does not yet support mixed e4m3_type input
        if torch.version.hip is None:
            if device != "cpu":
                # TODO: The following 2 tests are mixed dtypes between src and weight,
                # which will be enabled in oneDNN v3.6 in CPU.
                self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
                self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
        if device == "cpu":
            self._test_tautological_mm(device, e5m2_type, e5m2_type)
        else:
            # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
            with self.assertRaises(RuntimeError):
                self._test_tautological_mm(device, e5m2_type, e5m2_type)

        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
        # hipblaslt does not yet support bfloat16 output
        if torch.version.hip is None:
            self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
        if device != "cpu":
            with self.assertRaises(RuntimeError):
                self._test_tautological_mm(device, out_dtype=e5m2_type)
        else:
            # TODO: e4m3 and e5m2 naturally has numerical gap, maybe relax the tolerance later.
            with self.assertRaises(AssertionError):
                self._test_tautological_mm(device, out_dtype=e5m2_type)
def test_float8_scale(self, device) -> None:
|
||||
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
|
||||
raise unittest.SkipTest(f8_msg)
|
||||
size = (16, 16)
|
||||
x = torch.full(size, 0.5, device=device, dtype=e4m3_type)
|
||||
# hipblaslt does not yet support mixed e4m3_type input
|
||||
# TODO: will use e5m2_type after upgrading oneDNN to v3.6.
|
||||
y_type = e4m3_type if torch.version.hip or device == "cpu" else e5m2_type
|
||||
y = torch.full(size, 0.5, device=device, dtype=y_type).t()
|
||||
scale_a = torch.tensor(1.5, device=device)
|
||||
scale_b = torch.tensor(0.66, device=device)
|
||||
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
|
||||
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4.0, device=device))
|
||||
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
|
||||
self.assertEqual(out_fp8, out_fp8_s)
|
||||
|
||||
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_vs_emulated(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scale, b_scale=y_scale, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(x_fp8, x_scale, y_fp8, y_scale, output_dtype)

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

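For context on what this comparison checks: the emulated baseline amounts to dequantizing each operand with its scale, multiplying in float32, and casting to the requested output dtype. Below is a minimal sketch under that assumption; `mm_float8_emulated_sketch` is a hypothetical stand-in, since the real `mm_float8_emulated` helper lives in the shared test utilities and may differ in detail.

import torch

def mm_float8_emulated_sketch(x_fp8, x_scale, y_fp8, y_scale, out_dtype):
    # Assumed behavior: undo the quantization scales that produced the fp8 tensors,
    # run the matmul in float32, then cast to the requested output dtype.
    x_fp32 = x_fp8.to(torch.float32) / x_scale
    y_fp32 = y_fp8.to(torch.float32) / y_scale
    return torch.mm(x_fp32, y_fp32).to(out_dtype)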
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_change_stride(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
        y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scale, b_scale=y_scale, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(x_fp8, x_scale, y_fp8, y_scale, output_dtype)

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    def test_float8_bias(self, device) -> None:
        if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
            raise unittest.SkipTest(f8_msg)
        (k, l, m) = (16, 48, 32)
        x = torch.ones((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), 0.25, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
        # this fails on ROCm currently because hipblaslt doesn't have amax op
        out_fp32 = out_fp8.to(torch.float32)
        outb_fp32 = outb_fp8.to(torch.float32)
        difference = torch.abs(out_fp32 - outb_fp32)
        self.assertEqual(
            difference, torch.tensor(4.0, device=device).expand_as(out_fp32)
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("bias", [True, False])
    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
        x = torch.rand((17, 16), device=device).to(e4m3_type)
        y = torch.rand((16, 16), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        input_bias = None
        if bias:
            input_bias = torch.rand((16,), device=device).to(torch.half)
        _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias_relu_edgecase(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
        y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
        outb_fp32 = outb_fp8.to(torch.float32)
        self.assertEqual(
            outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32)
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float32_output_errors_with_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), 0.25, device=device, dtype=e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
        self.assertRaisesRegex(
            RuntimeError,
            "Bias is not supported when out_dtype is set to Float32",
            lambda: torch._scaled_mm(
                x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32
            ),
        )

    @unittest.skipIf(
        PLATFORM_SUPPORTS_FP8 or not torch.cuda.is_available(),
        "This test is only for devices with compute capability < 8.9",
    )
    def test_error_message_fp8_pre_sm89(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.rand((m, l), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        self.assertRaisesRegex(
            RuntimeError,
            r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale_fast_accum(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, 0.5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, 0.5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4.0, device=device))
        out_fp8_s = torch._scaled_mm(
            x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True
        )
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
    @skipIfRocm()
    @parametrize("use_fast_accum", [True, False])
    def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
        y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        out_fp8 = torch._scaled_mm(
            x_fp8,
            y_fp8,
            scale_a=x_scales,
            scale_b=y_scales,
            out_dtype=torch.bfloat16,
            use_fast_accum=use_fast_accum,
        )
        self.assertEqual(
            out_fp8.to(torch.float32),
            torch.full((M, N), K * (fill_value**2), device=device),
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    def test_float8_error_messages(self, device) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((1, 1), device="cuda"),
                scale_b=torch.ones((1, 2), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                " For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N + 1), device="cuda"),
                out_dtype=torch.bfloat16,
            )
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For non-TensorWise scaling, scale tensors must be 2-dimensional"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M), device="cuda"),
                scale_b=torch.ones((N, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "Both scale_a and scale_b must be contiguous for RowWise scaling."
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "Expected b.dtype() == at::kFloat8_e4m3fn to be true, but got false."
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8.to(torch.float8_e5m2),
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
    @skipIfRocm()
    @parametrize("base_dtype", [torch.bfloat16])
    def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
        y_scales = tensor_to_scale(y, input_dtype, dim=0).float()

        x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
        y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8, x_scales, y_fp8, y_scales, output_dtype
        )

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 2e-3, 2e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("which_dim_zero", [0, 1, 2])
    @parametrize("use_torch_compile", [False, True])
    def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
        device = "cuda"
        x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn
        out_dtype = torch.bfloat16
        M, K, N = 32, 32, 32
        if which_dim_zero == 0:
            M = 0
        elif which_dim_zero == 1:
            K = 0
        elif which_dim_zero == 2:
            N = 0

        x_fp8 = torch.zeros(M, K, device=device).to(x_dtype)
        y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t()
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        scale_a = torch.tensor(float("-inf"), device=device)
        scale_b = torch.tensor(float("-inf"), device=device)
        f = torch._scaled_mm
        if use_torch_compile:
            f = torch.compile(torch._scaled_mm)
        out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
        self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))


instantiate_device_type_tests(TestFP8Matmul, globals())

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()
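For readers who want to exercise the new CPU path by hand, here is a minimal usage sketch of per-tensor scaling with `torch._scaled_mm`, mirroring `test_float8_scale` above. It assumes a PyTorch build that already contains the CPU support added in this PR; the fill values and scales are illustrative.

import torch

size = (16, 16)
x = torch.full(size, 0.5, device="cpu", dtype=torch.float8_e4m3fn)
# Transpose the second operand, as the tests above do.
y = torch.full(size, 0.5, device="cpu", dtype=torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.5)   # per-tensor scale applied to x
scale_b = torch.tensor(0.66)  # per-tensor scale applied to y
out = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float32)
# Each output entry accumulates 16 * (0.5 * 1.5) * (0.5 * 0.66) = 3.96.
print(out)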
@@ -34,6 +34,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
@@ -1008,6 +1008,8 @@ ANY_DTYPE_ORDER = (
    torch.int8,
    torch.uint8,
    torch.bool,
    torch.float8_e4m3fn,
    torch.float8_e5m2,
)
@@ -21,7 +21,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
    floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
    empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and,
    empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types,
)
from torch.testing._internal.common_device_type import \
    (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@@ -8734,18 +8734,27 @@ def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
    scale1 = make_scale((1,))
    scale2 = make_scale((1,))
    samples.append(SampleInput(mat1, mat2, scale1, scale2))
    # mat1 e4m3 mat2 e5m2
    mat1 = make_mat_e4m3((M, K))
    # two e5m2
    mat1 = make_mat_e5m2((M, K))
    mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
    scale1 = make_scale((1,))
    scale2 = make_scale((1,))
    samples.append(SampleInput(mat1, mat2, scale1, scale2))
    # mat1 e5m2 mat2 e4m3
    mat1 = make_mat_e5m2((M, K))
    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
    scale1 = make_scale((1,))
    scale2 = make_scale((1,))
    samples.append(SampleInput(mat1, mat2, scale1, scale2))
    # TODO: Will remove this after oneDNN v3.6
    # now oneDNN v3.5.3 only supports mat1 * mat2 with the same data types.
    if device != 'cpu':
        # mat1 e4m3 mat2 e5m2
        mat1 = make_mat_e4m3((M, K))
        mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
        scale1 = make_scale((1,))
        scale2 = make_scale((1,))
        samples.append(SampleInput(mat1, mat2, scale1, scale2))
        # mat1 e5m2 mat2 e4m3
        mat1 = make_mat_e5m2((M, K))
        mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
        scale1 = make_scale((1,))
        scale2 = make_scale((1,))
        samples.append(SampleInput(mat1, mat2, scale1, scale2))

    yield from samples
@@ -16210,20 +16219,28 @@ op_db: List[OpInfo] = [
    OpInfo(
        'torch._scaled_mm',
        sample_inputs_func=sample_inputs_scaled_mm,
        dtypes=empty_types(),
        dtypes=float8_types(),
        dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
        supports_out=True,
        supports_forward_ad=False,
        supports_autograd=False,
        decorators=[skipCUDAIf(not SM90OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 9.0')],
        decorators=[skipCUDAIf(not SM90OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 9.0'), ],
        skips=(
            # Sample inputs aren't really parametrized on dtype
            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
                         device_type='cuda'),
            # "mul_cuda" not implemented for float8_e4m3fn
            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
            # "add_stub" not implemented for 'Float8_e4m3fn'
            # https://github.com/pytorch/pytorch/issues/107256
            DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
                         dtypes=(torch.float8_e4m3fn,)),
            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
                         device_type='cpu'),
            # "mul_cuda" not implemented for float8_e4m3fn
            # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
            # https://github.com/pytorch/pytorch/issues/107256
            DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'),
            # aten::_scaled_mm hit the vmap fallback which is currently disabled
            DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule",
                         device_type='cpu'),
            DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
                         dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)),
        )
    ),
    OpInfo(