Add torch._scaled_mm for CPU (#139975)
This PR adds `torch._scaled_mm` for the CPU backend. `_scaled_mm_out_cpu` and `_scaled_mm_cpu` are newly added and registered in the `torch._scaled_mm` CPU dispatch. We also add `_scaled_mm_out_cpu_emulated` as a fallback for platforms that cannot run FP8 matmul through oneDNN. The PR also updates the various FP8-related UTs to support CPU tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/139975 Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/malfet
parent d3e9133ab2
commit cbc4cf3043
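For reference, with a build that includes this change the CPU path is driven through the same Python API as the CUDA one; a minimal sketch (shapes and scale values below are illustrative only):

    import torch

    # Per-tensor scaled FP8 matmul on CPU (oneDNN when available, emulated otherwise).
    x = torch.randn(16, 32).to(torch.float8_e4m3fn)   # fp8 activations, M x K
    w = torch.randn(32, 64).to(torch.float8_e4m3fn)   # fp8 weights, K x N
    scale_a = torch.tensor(1.0)                       # per-tensor scales (numel() == 1)
    scale_b = torch.tensor(1.0)
    bias = torch.randn(64, dtype=torch.bfloat16)

    out = torch._scaled_mm(x, w, scale_a=scale_a, scale_b=scale_b,
                           bias=bias, out_dtype=torch.bfloat16)
    print(out.shape, out.dtype)                       # torch.Size([16, 64]) torch.bfloat16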
@@ -7,6 +7,11 @@
 #include <ATen/Config.h>

 #include <ATen/native/mkldnn/Matmul.h>
+#include <ATen/native/mkldnn/Linear.h>
+#include <ATen/native/Resize.h>
+#if !defined(__s390x__) && !defined(__powerpc__)
+#include <cpuinfo.h>
+#endif

 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/CPUFunctions.h>
@@ -24,6 +29,9 @@
 #include <ATen/ops/mv_native.h>
 #include <ATen/ops/scalar_tensor_native.h>
 #include <ATen/ops/vdot_native.h>
+#include <ATen/ops/_scaled_mm_native.h>
+#include <ATen/ops/mul.h>
+#include <ATen/ops/matmul.h>
 #endif

 namespace at::meta {
@@ -222,4 +230,79 @@ Tensor vdot(const Tensor &self, const Tensor &other){

 }

+static Tensor&
+_scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
+          const Tensor& scale_a,
+          const Tensor& scale_b,
+          const std::optional<at::Tensor>& bias,
+          const std::optional<at::Tensor>& scale_result,
+          std::optional<c10::ScalarType> out_dtype,
+          bool use_fast_accum,
+          Tensor& out) {
+  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
+  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
+  TORCH_CHECK(
+      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
+      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
+
+  TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
+  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
+       " but got ", bias->numel());
+
+  // Check types
+  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
+  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
+  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
+
+  auto mat1_c = mat1.contiguous();
+  auto mat2_c = mat2.contiguous();
+  IntArrayRef mat1_sizes = mat1_c.sizes();
+  IntArrayRef mat2_sizes = mat2_c.sizes();
+  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
+
+  float input_scale = scale_a.item<float>();
+  float weight_scale = scale_b.item<float>();
+  auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale);
+  auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale);
+  auto out_tmp = at::matmul(fp32_mat1, fp32_mat2);
+  if (bias) {
+    out_tmp.add_(bias.value());
+  }
+  out_tmp = out_tmp.to(out.scalar_type());
+  out.copy_(out_tmp);
+  return out;
+}
+
+Tensor&
+_scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2,
+          const Tensor& scale_a,
+          const Tensor& scale_b,
+          const std::optional<at::Tensor>& bias,
+          const std::optional<at::Tensor>& scale_result,
+          std::optional<c10::ScalarType> out_dtype,
+          bool use_fast_accum,
+          Tensor& out) {
+#if AT_MKLDNN_ENABLED()
+  if (at::globalContext().userEnabledMkldnn() && cpuinfo_has_x86_amx_int8()) {
+    return mkldnn_scaled_mm(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
+  } else
+#endif
+  {
+    return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
+  }
+}
+
+Tensor
+_scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b,
+          const Tensor& scale_a,
+          const Tensor& scale_b,
+          const std::optional<at::Tensor>& bias,
+          const std::optional<at::Tensor>& scale_result,
+          std::optional<c10::ScalarType> out_dtype,
+          bool use_fast_accum) {
+  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
+  Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
+  return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
+}
+
 } // namespace at::native
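The emulated fallback above is just per-tensor dequantization followed by an fp32 matmul; in Python terms it behaves roughly like the sketch below (a paraphrase for clarity, not code taken from the PR):

    import torch

    def scaled_mm_emulated_reference(mat1, mat2, scale_a, scale_b, bias=None, out_dtype=None):
        # Mirrors _scaled_mm_out_cpu_emulated: dequantize with the per-tensor
        # scales, matmul in fp32, add the optional bias, cast to out_dtype.
        out_dtype = out_dtype or mat1.dtype
        fp32_mat1 = mat1.to(torch.float) * scale_a.item()
        fp32_mat2 = mat2.to(torch.float) * scale_b.item()
        out = fp32_mat1 @ fp32_mat2
        if bias is not None:
            out = out + bias
        return out.to(out_dtype)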
@@ -4,6 +4,7 @@
 #include <ATen/core/Tensor.h>
 #include <torch/library.h>
 #include <ATen/native/mkldnn/Linear.h>
+#include <ATen/native/Resize.h>

 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -46,6 +47,18 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
   TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support");
 }

+Tensor&
+mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
+          const Tensor& scale_a,
+          const Tensor& scale_b,
+          const std::optional<at::Tensor>& bias,
+          const std::optional<at::Tensor>& scale_result,
+          std::optional<c10::ScalarType> out_dtype,
+          bool use_fast_accum,
+          Tensor& out) {
+  TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support");
+}
+
 } // namespace native
 } // namespace at

@@ -447,6 +460,119 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
       TORCH_FN(mkldnn_linear_pointwise_binary));
 }

+Tensor&
+mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
+          const Tensor& scale_a,
+          const Tensor& scale_b,
+          const std::optional<at::Tensor>& bias,
+          const std::optional<at::Tensor>& scale_result,
+          std::optional<c10::ScalarType> out_dtype,
+          bool use_fast_accum,
+          Tensor& out) {
+  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
+  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
+  TORCH_CHECK(
+      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
+      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
+
+  TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
+  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
+       " but got ", bias->numel());
+
+  // Check types
+  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
+  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
+  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
+  // TODO: This check that mat1 and mat2 have the same data type will be removed after oneDNN v3.6.
+  TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "Expected mat1 and mat2 must have the same data type");
+
+  // Validation checks have passed; resize the output to the actual size.
+  auto mat1_c = mat1.contiguous();
+  auto mat2_c = mat2.contiguous();
+  IntArrayRef mat1_sizes = mat1_c.sizes();
+  IntArrayRef mat2_sizes = mat2_c.sizes();
+  at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
+
+  float input_scale = scale_a.item<float>();
+  float weight_scale = scale_b.item<float>();
+  auto src = at::native::itensor_view_from_dense(mat1_c);
+  auto weight_t = at::native::itensor_view_from_dense(mat2_c);
+  bool with_bias = bias.has_value();
+  int64_t K = mat1_sizes[1], M = mat1_sizes[0],
+          N = mat2_sizes[1];
+
+  std::vector<int64_t> src_dims = {M, K};
+  std::vector<int64_t> weight_dims = {K, N};
+  std::vector<int64_t> dst_dims = {M, N};
+
+  ideep::tensor dst = at::native::itensor_view_from_dense(out);
+  auto src_desc = ideep::tensor::desc(
+      src_dims,
+      get_mkldnn_dtype(mat1.scalar_type()),
+      ideep::format_tag::any);
+  auto weights_desc = ideep::tensor::desc(
+      weight_dims,
+      get_mkldnn_dtype(mat2.scalar_type()),
+      ideep::format_tag::any);
+  auto dst_desc = ideep::tensor::desc(
+      dst_dims,
+      get_mkldnn_dtype(out.scalar_type()),
+      ideep::format_tag::any);
+  ideep::tensor onednn_bias;
+  if (with_bias) {
+    auto bias_value = bias.value();
+    if (bias_value.dim() == 1) {
+      auto b_reshape = bias_value.reshape({1, bias_value.size(0)});
+      onednn_bias = at::native::itensor_view_from_dense(b_reshape);
+    } else {
+      onednn_bias = at::native::itensor_view_from_dense(bias_value);
+    }
+  }
+  auto bias_desc = ideep::tensor::desc();
+  if (with_bias) {
+    bias_desc = ideep::tensor::desc(onednn_bias.get_dims(),
+                    get_mkldnn_dtype(bias.value().scalar_type()),
+                    ideep::format_tag::any);
+  }
+  auto op_attr = ideep::attr_t();
+  if (input_scale != 1.0f) {
+    op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
+  }
+  if (weight_scale != 1.0f) {
+    op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
+  }
+
+  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+  auto engine = ideep::engine::cpu_engine();
+  dnnl::matmul::primitive_desc primitive_desc = with_bias
+      ? dnnl::matmul::primitive_desc(
+            engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr)
+      : dnnl::matmul::primitive_desc(
+            engine, src_desc, weights_desc, dst_desc, op_attr);
+  auto primitive = dnnl::matmul(primitive_desc);
+
+  // Prepare args and execute primitive
+  ideep::tensor scratchpad(primitive_desc.scratchpad_desc());
+  ideep::exec_args args;
+  args.insert({DNNL_ARG_SRC, src});
+  args.insert({DNNL_ARG_WEIGHTS, weight_t});
+  args.insert({DNNL_ARG_DST, dst});
+  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
+  if (with_bias) {
+    args.insert({DNNL_ARG_BIAS, onednn_bias});
+  }
+  ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale));
+  ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale));
+
+  if (input_scale != 1.0f) {
+    args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
+  }
+  args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});
+
+  primitive.execute(ideep::stream::default_stream(), args);
+  return out;
+}
+
 } // namespace at

 #endif // AT_MKLDNN_ENABLED
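_scaled_mm_out_cpu only takes this oneDNN path when MKLDNN is enabled and cpuinfo reports AMX (int8) support; from Python the emulated fallback can be forced and compared against the default path, e.g. (assuming a build that includes this change):

    import torch

    x = torch.randn(64, 128).to(torch.float8_e4m3fn)
    w = torch.randn(128, 64).to(torch.float8_e4m3fn)
    scale = torch.tensor(1.0)

    out_default = torch._scaled_mm(x, w, scale, scale, out_dtype=torch.float32)

    # Disabling oneDNN routes the call through _scaled_mm_out_cpu_emulated.
    with torch.backends.mkldnn.flags(enabled=False):
        out_emulated = torch._scaled_mm(x, w, scale, scale, out_dtype=torch.float32)

    torch.testing.assert_close(out_default, out_emulated, rtol=1e-2, atol=1e-2)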
@@ -35,3 +35,15 @@ C10_API Tensor mkl_linear(
 } // namespace at

 #endif // AT_MKLDNN_ENABLED()
+
+namespace at::native {
+Tensor&
+mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
+          const Tensor& scale_a,
+          const Tensor& scale_b,
+          const std::optional<at::Tensor>& bias,
+          const std::optional<at::Tensor>& scale_result,
+          std::optional<c10::ScalarType> out_dtype,
+          bool use_fast_accum,
+          Tensor& out);
+} // namespace at::native
@@ -57,6 +57,10 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
       return ideep::tensor::data_type::bf16;
     case ScalarType::Half:
       return ideep::tensor::data_type::f16;
+    case ScalarType::Float8_e4m3fn:
+      return ideep::tensor::data_type::f8_e4m3;
+    case ScalarType::Float8_e5m2:
+      return ideep::tensor::data_type::f8_e5m2;
     default:
       TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type");
   }
@@ -161,8 +165,24 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data
              const_cast<void*>(tensor.const_data_ptr()) :
              tensor.data_ptr()};
   }
+  else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
+    return {{tensor.sizes().vec(),
+             ideep::tensor::data_type::f8_e4m3,
+             tensor.strides().vec()},
+            from_const_data_ptr ?
+              const_cast<void*>(tensor.const_data_ptr()) :
+              tensor.data_ptr()};
+  }
+  else if (tensor.scalar_type() == ScalarType::Float8_e5m2) {
+    return {{tensor.sizes().vec(),
+             ideep::tensor::data_type::f8_e5m2,
+             tensor.strides().vec()},
+            from_const_data_ptr ?
+              const_cast<void*>(tensor.const_data_ptr()) :
+              tensor.data_ptr()};
+  }
   else {
-    TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input");
+    TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8/fp8 tensor input");
   }
 }
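These additions let ideep alias Float8 CPU tensors in place; on the Python side the dtypes are one byte per element with ordinary strides, which is what makes the zero-copy view possible:

    import torch

    a = torch.randn(4, 4).to(torch.float8_e4m3fn)
    b = torch.randn(4, 4).to(torch.float8_e5m2)
    print(a.element_size(), b.element_size())   # 1 1
    print(a.stride())                           # (4, 1), carried into the ideep view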
@@ -7071,11 +7071,13 @@
 - func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
   variants: function
   dispatch:
+    CPU: _scaled_mm_cpu
     CUDA: _scaled_mm_cuda

 - func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
+    CPU: _scaled_mm_out_cpu
     CUDA: _scaled_mm_out_cuda

 # NOTE [ Sparse: autograd and API ]
@@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import (
     parametrize,
     TEST_WITH_ROCM,
 )
-from torch.testing._internal.inductor_utils import HAS_CUDA
+from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 from torch.utils._triton import has_triton_tma_device

@@ -89,10 +89,10 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):

 @instantiate_parametrized_tests
 class TestFP8Types(TestCase):
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet")
     @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
-    def test_xblock_for_small_numel(self, float8_dtype: torch.dtype):
+    @parametrize("device", ("cuda", "cpu"))
+    def test_xblock_for_small_numel(self, float8_dtype: torch.dtype, device: str):
         """
         TritonOverrides.to_dtype will set min_elem_per_thread to 2 or 4
         depends on the variant of fp8 type.
@@ -101,19 +101,23 @@ class TestFP8Types(TestCase):

         We should not pick a XBLOCK larger than xnumel
         """
+        if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
+            raise unittest.SkipTest(f8_msg)
+
         def f(x):
             return x.to(dtype=float8_dtype)

-        x = torch.randn(1, device="cuda")
+        x = torch.randn(1, device=device)
         expected = f(x)
         actual = torch.compile(f)(x)
         torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2)

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet")
     @parametrize("dtype", (torch.float16, torch.bfloat16))
-    def test_eager_fallback(self, dtype: torch.dtype):
+    @parametrize("device", ("cuda", "cpu"))
+    def test_eager_fallback(self, dtype: torch.dtype, device: torch.device):
+        if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
+            raise unittest.SkipTest(f8_msg)
         weight_shape = (32, 16)

         e4m3_type = (
@@ -121,11 +125,11 @@ class TestFP8Types(TestCase):
         )

         def fp8_matmul_unwrapped(x):
-            a_scale = torch.Tensor([1.0]).to(device="cuda")
-            b_scale = torch.Tensor([1.0]).to(device="cuda")
+            a_scale = torch.Tensor([1.0]).to(device=device)
+            b_scale = torch.Tensor([1.0]).to(device=device)
             output_scale = None
-            input_bias = torch.rand(32, device="cuda", dtype=dtype)
-            weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T.to(
+            input_bias = torch.rand(32, device=device, dtype=dtype)
+            weight = torch.rand(*weight_shape, device=device, dtype=dtype).T.to(
                 e4m3_type
             )
             a_inverse_scale = 1 / a_scale
@@ -146,14 +150,13 @@ class TestFP8Types(TestCase):
         )

         x_shape = (16, 16)
-        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
+        x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type)
         y_fp8 = compiled_fp8_matmul(x)  # noqa: F841

         x_shape = (15, 16)
-        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
+        x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type)
         y_fp8 = compiled_fp8_matmul(x)  # noqa: F841

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
     @parametrize("shape", ("15,3,13", "4,2048,4096"))
     @parametrize(
@@ -162,7 +165,12 @@ class TestFP8Types(TestCase):
         if torch.version.hip is None
         else [(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)],
     )
-    def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple):
+    @parametrize("device", ("cuda", "cpu"))
+    def test_valid_cast(
+        self, dtype: torch.dtype, shape: str, dst_types: tuple, device: torch.device
+    ):
+        if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
+            raise unittest.SkipTest(f8_msg)
         e4m3, e5m2 = dst_types

         def fp8_cast(x):
@@ -173,7 +181,7 @@ class TestFP8Types(TestCase):
         compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)

         shape = [int(dim) for dim in shape.split(",")]
-        x = torch.rand(*shape, device="cuda", dtype=dtype)
+        x = torch.rand(*shape, device=device, dtype=dtype)
         y0_fp8, y1_fp8 = compiled_fp8_cast(x)

         torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1)
@@ -202,7 +210,6 @@ class TestFP8Types(TestCase):
             x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2)
             compiled_fp8_cast(x, torch.float8_e4m3fn)

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
     @parametrize(
         "dst_dtype",
@@ -211,9 +218,17 @@ class TestFP8Types(TestCase):
         else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
     )
     @parametrize("shape", ("16,16,16", "4,2048,4096"))
+    @parametrize("device", ("cuda", "cpu"))
     def test_to_fp8_saturated(
-        self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
+        self,
+        src_dtype: torch.dtype,
+        dst_dtype: torch.dtype,
+        shape: str,
+        device: torch.device,
     ):
+        if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
+            raise unittest.SkipTest(f8_msg)
+
         def fp8_saturated(x, dtype):
             return _to_fp8_saturated(x, dtype)

@@ -221,14 +236,13 @@ class TestFP8Types(TestCase):
             fp8_saturated, backend="inductor", dynamic=True
         )
         shape = [int(dim) for dim in shape.split(",")]
-        x = torch.rand(*shape, device="cuda", dtype=src_dtype)
+        x = torch.rand(*shape, device=device, dtype=src_dtype)
         y_compiled = compiled_fp8_cast(x, dst_dtype)
         y = fp8_saturated(x, dst_dtype)

         torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1)

     @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
-    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
     @parametrize(
         "float8_dtype",
         (torch.float8_e4m3fn, torch.float8_e5m2)
@@ -236,7 +250,12 @@ class TestFP8Types(TestCase):
         else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
     )
     @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
-    def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
+    @parametrize("device", ("cuda", "cpu"))
+    def test_amax_fp8_quant(
+        self, float8_dtype: torch.dtype, shape: str, device: torch.device
+    ):
+        if device == "cuda" and not SM90OrLater:
+            raise unittest.SkipTest("FP8 is only supported on H100+")
         shape = [int(dim) for dim in shape.split(",")]
         batch_size, sequence_length, hidden_size = shape

@@ -249,15 +268,14 @@ class TestFP8Types(TestCase):
         compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")

         x_shape = (batch_size, sequence_length, hidden_size)
-        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
-        scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
+        x = torch.rand(*x_shape, device=device, dtype=torch.half)
+        scale = torch.tensor(0.2, device=device, dtype=torch.float)

         y_compiled = compiled_amax_fp8_quant(x, scale)
         y = amax_fp8(x, scale)

         torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize(
         "float8_dtype",
         (torch.float8_e4m3fn, torch.float8_e5m2)
@@ -265,7 +283,12 @@ class TestFP8Types(TestCase):
         else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
     )
     @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
-    def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
+    @parametrize("device", ("cuda", "cpu"))
+    def test_amax_along_with_fp8_quant(
+        self, float8_dtype: torch.dtype, shape: str, device: torch.device
+    ):
+        if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
+            raise unittest.SkipTest(f8_msg)
         shape = [int(dim) for dim in shape.split(",")]
         batch_size, sequence_length, hidden_size = shape

@@ -278,12 +301,12 @@ class TestFP8Types(TestCase):
         compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")

         x_shape = (batch_size, sequence_length, hidden_size)
-        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
-        scale = torch.tensor(1.0, device="cuda", dtype=torch.float)
+        x = torch.rand(*x_shape, device=device, dtype=torch.half)
+        scale = torch.tensor(1.0, device=device, dtype=torch.float)

-        amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
+        amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half)
         y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled)
-        amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
+        amax_buffer = torch.zeros((1), device=device, dtype=torch.half)
         y = amax_fp8(x, scale, amax_buffer)

         torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
@@ -292,7 +315,6 @@ class TestFP8Types(TestCase):
     )

     @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
-    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
     @parametrize(
         "float8_dtype",
         (torch.float8_e4m3fn, torch.float8_e5m2)
@@ -301,9 +323,16 @@ class TestFP8Types(TestCase):
     )
     @parametrize("amax_keep_dim", (True, False))
     @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
+    @parametrize("device", ("cuda", "cpu"))
     def test_layernorm_fp8_quant(
-        self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: str
+        self,
+        float8_dtype: torch.dtype,
+        amax_keep_dim: bool,
+        shape: str,
+        device: torch.device,
     ):
+        if device == "cuda" and not SM90OrLater:
+            raise unittest.SkipTest("FP8 is only supported on H100+")
         shape = [int(dim) for dim in shape.split(",")]
         batch_size, sequence_length, hidden_size = shape

@@ -325,12 +354,12 @@ class TestFP8Types(TestCase):
         compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")

         x_shape = (batch_size, sequence_length, hidden_size)
-        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
-        scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
+        x = torch.rand(*x_shape, device=device, dtype=torch.half)
+        scale = torch.tensor(0.2, device=device, dtype=torch.float)

-        amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
+        amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half)
         y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
-        amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
+        amax_buffer = torch.zeros((1), device=device, dtype=torch.half)
         y = ln_fp8(x, scale, amax_buffer)

         torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
@@ -748,5 +777,5 @@ class TestFP8Lowering(TestCase):


 if __name__ == "__main__":
-    if HAS_CUDA:
+    if HAS_CUDA or HAS_CPU:
         run_tests()
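All of the test changes above follow one pattern: parametrize over ("cuda", "cpu") and skip the CUDA leg when the platform lacks FP8 support. What the new CPU leg boils down to, condensed (assuming an Inductor build where fp8 casts are supported on CPU):

    import torch

    def to_fp8(x):
        return x.to(torch.float8_e4m3fn)

    compiled = torch.compile(to_fp8, backend="inductor")

    x = torch.randn(16, 16)            # CPU input
    expected = to_fp8(x)
    actual = compiled(x)
    torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2)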
@@ -3,8 +3,6 @@
 import unittest
 from itertools import product
 from functools import partial
-from typing import Optional
-import re

 import torch

@@ -17,7 +15,6 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import (
     SM53OrLater,
     _get_torch_cuda_version,
-    PLATFORM_SUPPORTS_FP8
 )
 from torch.testing._internal.common_device_type import (
     dtypes,
@@ -212,524 +209,6 @@ class TestMatmulCuda(TestCase):
         self.assertEqual(out1_gpu, out2_gpu[0])


-f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
-
-if torch.version.hip:
-    e4m3_type = torch.float8_e4m3fnuz
-    e5m2_type = torch.float8_e5m2fnuz
-    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
-    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
-else:
-    e4m3_type = torch.float8_e4m3fn
-    e5m2_type = torch.float8_e5m2
-    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
-    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
-
-# avoid division by zero when calculating scale
-EPS = 1e-12
-
-def amax_to_scale(
-    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
-):
-    """ Converts the amax value of a tensor to the fp8 scale.
-    Args:
-        amax: The amax value of the tensor.
-        float8_dtype: the float8 dtype.
-        orig_dtype: The original dtype of the tensor.
-    """
-    scale = torch.empty_like(amax, dtype=torch.float32)
-    if float8_dtype == e4m3_type:
-        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
-    elif float8_dtype == e5m2_type:
-        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
-    else:
-        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
-
-    # Ensure the scale is representable in float16,
-    # this helps when amax is small. We are assuming that we don't need
-    # to care about this for float32/bfloat16
-    if orig_dtype is torch.float16:
-        res = torch.clamp(res, max=torch.finfo(torch.float16).max)
-
-    scale.copy_(res)
-    return scale
-
-def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
-    if dim is None:
-        amax = torch.max(torch.abs(x))
-    else:
-        amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values
-
-    return amax_to_scale(amax, float8_dtype, x.dtype)
-
-def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
-    # naive implementation: dq -> op -> q
-    x_fp32 = x.to(torch.float) / x_scale
-    y_fp32 = y.to(torch.float) / y_scale
-    out_fp32 = torch.mm(x_fp32, y_fp32)
-
-    return out_fp32.to(out_dtype)
-
-def addmm_float8_unwrapped(
-    a_data: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_data: torch.Tensor,
-    b_scale: torch.tensor,
-    output_dtype: torch.dtype,
-    output_scale: Optional[torch.Tensor],
-    bias: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    a_inverse_scale = a_scale.reciprocal()
-    b_inverse_scale = b_scale.reciprocal()
-    if output_dtype == torch.float32 and bias is not None:
-        # Bias is not supported by _scaled_mm when output is fp32
-        output = torch._scaled_mm(
-            a_data,
-            b_data,
-            scale_a=a_inverse_scale,
-            scale_b=b_inverse_scale,
-            scale_result=output_scale,
-            out_dtype=output_dtype,
-        )
-        output += bias
-        return output
-    output = torch._scaled_mm(
-        a_data,
-        b_data,
-        bias=bias,
-        scale_a=a_inverse_scale,
-        scale_b=b_inverse_scale,
-        scale_result=output_scale,
-        out_dtype=output_dtype,
-    )
-    return output
-
-def mm_float8(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    output_dtype: torch.dtype,  # output dtype
-    output_scale: Optional[torch.Tensor] = None,  # output scale, precomputed
-) -> torch.Tensor:
-    return addmm_float8_unwrapped(
-        a, a_scale, b, b_scale, output_dtype, output_scale
-    )
-
-def to_fp8_saturated(
-    x: torch.Tensor,
-    fp8_dtype: torch.dtype
-):
-    if fp8_dtype == e4m3_type:
-        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
-    elif fp8_dtype == e5m2_type:
-        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
-    else:
-        raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")
-
-    return x.to(fp8_dtype)
-
-@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
-class TestFP8MatmulCuda(TestCase):
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    def _test_tautological_mm(self, device: str = "cuda",
-                              x_dtype: torch.dtype = e4m3_type,
-                              y_dtype: torch.dtype = e4m3_type,
-                              out_dtype: Optional[torch.dtype] = None,
-                              size: int = 16) -> None:
-        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
-        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
-        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
-        scale_a = torch.tensor(1.0, device=device)
-        scale_b = torch.tensor(1.0, device=device)
-        out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
-        if out_dtype is not None:
-            self.assertEqual(out_dtype, out_fp8.dtype)
-        self.assertEqual(out_fp32, out_fp8.to(torch.float))
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    def test_float8_basics(self, device) -> None:
-        self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
-        # hipblaslt does not yet support mixed e4m3_type input
-        if torch.version.hip is None:
-            self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
-            self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
-        # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
-        with self.assertRaises(RuntimeError):
-            self._test_tautological_mm(device, e5m2_type, e5m2_type)
-
-        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
-        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
-        # hipblaslt does not yet support bfloat16 output
-        if torch.version.hip is None:
-            self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
-        with self.assertRaises(RuntimeError):
-            self._test_tautological_mm(device, out_dtype=e5m2_type)
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    def test_float8_scale(self, device) -> None:
-        size = (16, 16)
-        x = torch.full(size, .5, device=device, dtype=e4m3_type)
-        # hipblaslt does not yet support mixed e4m3_type input
-        y_type = e4m3_type if torch.version.hip else e5m2_type
-        y = torch.full(size, .5, device=device, dtype=y_type).t()
-        scale_a = torch.tensor(1.5, device=device)
-        scale_b = torch.tensor(0.66, device=device)
-        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
-        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
-        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
-        self.assertEqual(out_fp8, out_fp8_s)
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
-    def test_scaled_mm_vs_emulated(self, base_dtype):
-        torch.manual_seed(42)
-        input_dtype = e4m3_type
-        output_dtype = base_dtype
-        compare_type = torch.float32
-
-        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
-        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
-
-        x_scale = tensor_to_scale(x, input_dtype).float()
-        y_scale = tensor_to_scale(y, input_dtype).float()
-
-        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
-        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
-
-        # Calculate actual F8 mm
-        out_scaled_mm = mm_float8(
-            x_fp8,
-            y_fp8,
-            a_scale=x_scale,
-            b_scale=y_scale,
-            output_dtype=output_dtype
-        )
-
-        # Calculate emulated F8 mm
-        out_emulated = mm_float8_emulated(
-            x_fp8,
-            x_scale,
-            y_fp8,
-            y_scale,
-            output_dtype
-        )
-
-        if output_dtype != base_dtype:
-            out_scaled_mm = out_scaled_mm.to(compare_type)
-            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)
-
-            out_emulated = out_emulated.to(compare_type)
-            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)
-
-        if base_dtype in {torch.bfloat16, torch.float16}:
-            atol, rtol = 7e-2, 7e-2
-        else:
-            atol, rtol = 3e-3, 3e-3
-
-        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
-    def test_scaled_mm_change_stride(self, base_dtype):
-        torch.manual_seed(42)
-        input_dtype = e4m3_type
-        output_dtype = base_dtype
-        compare_type = torch.float32
-
-        x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
-        y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)
-
-        x_scale = tensor_to_scale(x, input_dtype).float()
-        y_scale = tensor_to_scale(y, input_dtype).float()
-
-        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
-        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
-
-        # Calculate actual F8 mm
-        out_scaled_mm = mm_float8(
-            x_fp8,
-            y_fp8,
-            a_scale=x_scale,
-            b_scale=y_scale,
-            output_dtype=output_dtype
-        )
-
-        # Calculate emulated F8 mm
-        out_emulated = mm_float8_emulated(
-            x_fp8,
-            x_scale,
-            y_fp8,
-            y_scale,
-            output_dtype
-        )
-
-        if output_dtype != base_dtype:
-            out_scaled_mm = out_scaled_mm.to(compare_type)
-            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)
-
-            out_emulated = out_emulated.to(compare_type)
-            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)
-
-        if base_dtype in {torch.bfloat16, torch.float16}:
-            atol, rtol = 7e-2, 7e-2
-        else:
-            atol, rtol = 3e-3, 3e-3
-
-        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    def test_float8_bias(self, device) -> None:
-        (k, l, m) = (16, 48, 32)
-        x = torch.ones((k, l), device=device).to(e4m3_type)
-        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
-        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
-        scale_a = torch.tensor(1.0, device=device)
-        scale_b = torch.tensor(1.0, device=device)
-        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
-        outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
-        # this fails on ROCm currently because hipblaslt doesn't have amax op
-        out_fp32 = out_fp8.to(torch.float32)
-        outb_fp32 = outb_fp8.to(torch.float32)
-        difference = torch.abs(out_fp32 - outb_fp32)
-        self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32))
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("bias", [True, False])
-    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
-        x = torch.rand((17, 16), device=device).to(e4m3_type)
-        y = torch.rand((16, 16), device=device).to(e4m3_type).t()
-        scale_a = torch.tensor(1.0, device=device)
-        scale_b = torch.tensor(1.0, device=device)
-        input_bias = None
-        if bias:
-            input_bias = torch.rand((16,), device=device).to(torch.half)
-        _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    def test_float8_bias_relu_edgecase(self, device) -> None:
-        (k, l, m) = (16, 48, 32)
-        x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
-        y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
-        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
-        scale_a = torch.tensor(1.0, device=device)
-        scale_b = torch.tensor(1.0, device=device)
-        outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
-        outb_fp32 = outb_fp8.to(torch.float32)
-        self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    def test_float32_output_errors_with_bias(self, device) -> None:
-        (k, l, m) = (16, 48, 32)
-        x = torch.rand((k, l), device=device).to(e4m3_type)
-        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
-        scale_a = torch.tensor(1.0, device=device)
-        scale_b = torch.tensor(1.0, device=device)
-        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
-        self.assertRaisesRegex(
-            RuntimeError,
-            "Bias is not supported when out_dtype is set to Float32",
-            lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
-        )
-
-    @unittest.skipIf(PLATFORM_SUPPORTS_FP8,
-                     "This test is only for devices with compute capability < 8.9")
-    def test_error_message_fp8_pre_sm89(self, device) -> None:
-        (k, l, m) = (16, 48, 32)
-        x = torch.rand((k, l), device=device).to(e4m3_type)
-        y = torch.rand((m, l), device=device).to(e4m3_type).t()
-        scale_a = torch.tensor(1.0, device=device)
-        scale_b = torch.tensor(1.0, device=device)
-        self.assertRaisesRegex(
-            RuntimeError,
-            r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
-            lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
-        )
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    def test_float8_scale_fast_accum(self, device) -> None:
-        size = (16, 16)
-        x = torch.full(size, .5, device=device, dtype=e4m3_type)
-        # hipblaslt does not yet support mixed e4m3_type input
-        y_type = e4m3_type if torch.version.hip else e5m2_type
-        y = torch.full(size, .5, device=device, dtype=y_type).t()
-        scale_a = torch.tensor(1.5, device=device)
-        scale_b = torch.tensor(0.66, device=device)
-        out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
-        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
-        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
-        self.assertEqual(out_fp8, out_fp8_s)
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
-    @skipIfRocm()
-    @parametrize("use_fast_accum", [True, False])
-    def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
-        M, K, N = (1024, 512, 2048)
-        fill_value = 0.5
-        x = torch.full((M, K), fill_value, device=device)
-        y = torch.full((N, K), fill_value, device=device)
-
-        x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
-        y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)
-
-        x_fp8 = x.to(torch.float8_e4m3fn)
-        y_fp8 = y.to(torch.float8_e4m3fn).t()
-
-        out_fp8 = torch._scaled_mm(
-            x_fp8,
-            y_fp8,
-            scale_a=x_scales,
-            scale_b=y_scales,
-            out_dtype=torch.bfloat16,
-            use_fast_accum=use_fast_accum,
-        )
-        self.assertEqual(
-            out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
-        )
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @skipIfRocm()
-    def test_float8_error_messages(self, device) -> None:
-        M, K, N = (1024, 512, 2048)
-        fill_value = 0.5
-        x = torch.full((M, K), fill_value, device=device)
-        y = torch.full((N, K), fill_value, device=device)
-
-        x_fp8 = x.to(torch.float8_e4m3fn)
-        y_fp8 = y.to(torch.float8_e4m3fn).t()
-
-        with self.assertRaisesRegex(
-            RuntimeError,
-            re.escape(
-                "For RowWise scaling, scale_a should be (1024, 1) and scale_b "
-                "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
-            ),
-        ):
-            torch._scaled_mm(
-                x_fp8,
-                y_fp8,
-                scale_a=torch.ones((1, 1), device="cuda"),
-                scale_b=torch.ones((1, 2), device="cuda"),
-                out_dtype=torch.bfloat16,
-            )
-
-        with self.assertRaisesRegex(
-            RuntimeError,
-            re.escape(
-                " For RowWise scaling, scale_a should be (1024, 1) and scale_b "
-                "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
-            ),
-        ):
-            torch._scaled_mm(
-                x_fp8,
-                y_fp8,
-                scale_a=torch.ones((M, 1), device="cuda"),
-                scale_b=torch.ones((1, N + 1), device="cuda"),
-                out_dtype=torch.bfloat16,
-            )
-        with self.assertRaisesRegex(
-            RuntimeError,
-            re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
-        ):
-            torch._scaled_mm(
-                x_fp8,
-                y_fp8,
-                scale_a=torch.ones((M), device="cuda"),
-                scale_b=torch.ones((N, N), device="cuda"),
-                out_dtype=torch.bfloat16,
-            )
-
-        with self.assertRaisesRegex(
-            RuntimeError,
-            re.escape(
-                "Both scale_a and scale_b must be contiguous for RowWise scaling."
-            ),
-        ):
-            torch._scaled_mm(
-                x_fp8,
-                y_fp8,
-                scale_a=torch.ones((M, 1), device="cuda"),
-                scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
-                out_dtype=torch.bfloat16,
-            )
-
-        with self.assertRaisesRegex(
-            RuntimeError,
-            re.escape("Expected b.dtype() == at::kFloat8_e4m3fn to be true, but got false."),
-        ):
-            torch._scaled_mm(
-                x_fp8,
-                y_fp8.to(torch.float8_e5m2),
-                scale_a=torch.ones((M, 1), device="cuda"),
-                scale_b=torch.ones((1, N), device="cuda"),
-                out_dtype=torch.bfloat16,
-            )
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
-    @skipIfRocm()
-    @parametrize("base_dtype", [torch.bfloat16])
-    def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
-        torch.manual_seed(42)
-        input_dtype = e4m3_type
-        output_dtype = base_dtype
-
-        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
-        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
-
-        x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
-        y_scales = tensor_to_scale(y, input_dtype, dim=0).float()
-
-        x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
-        y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)
-
-        # Calculate actual F8 mm
-        out_scaled_mm = mm_float8(
-            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
-        )
-
-        # Calculate emulated F8 mm
-        out_emulated = mm_float8_emulated(
-            x_fp8, x_scales, y_fp8, y_scales, output_dtype
-        )
-
-        if base_dtype in {torch.bfloat16, torch.float16}:
-            atol, rtol = 7e-2, 7e-2
-        else:
-            atol, rtol = 2e-3, 2e-3
-
-        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
-
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("which_dim_zero", [0, 1, 2])
-    @parametrize("use_torch_compile", [False, True])
-    def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
-        device = "cuda"
-        x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn
-        out_dtype = torch.bfloat16
-        M, K, N = 32, 32, 32
-        if which_dim_zero == 0:
-            M = 0
-        elif which_dim_zero == 1:
-            K = 0
-        elif which_dim_zero == 2:
-            N = 0
-
-        x_fp8 = torch.zeros(M, K, device=device).to(x_dtype)
-        y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t()
-        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
-        scale_a = torch.tensor(float('-inf'), device=device)
-        scale_b = torch.tensor(float('-inf'), device=device)
-        f = torch._scaled_mm
-        if use_torch_compile:
-            f = torch.compile(torch._scaled_mm)
-        out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
-        self.assertEqual(out_dtype, out_fp8.dtype)
-        self.assertEqual(out_fp32, out_fp8.to(torch.float))
-
-
 @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
 @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
 @unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
@@ -851,7 +330,6 @@ class TestMixedDtypesLinearCuda(TestCase):
     )

 instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
-instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
 instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")

 if __name__ == '__main__':
563
test/test_matmul_fp8.py
Normal file
563
test/test_matmul_fp8.py
Normal file
|
|
@ -0,0 +1,563 @@
# Owner(s): ["module: linear algebra"]

import re
import unittest
from typing import Optional

import torch
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfRocm,
    TEST_CUDA,
    TestCase,
)


_IS_SM9X = False
if TEST_CUDA:
    _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32

f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"

if torch.version.hip:
    e4m3_type = torch.float8_e4m3fnuz
    e5m2_type = torch.float8_e5m2fnuz
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
    e4m3_type = torch.float8_e4m3fn
    e5m2_type = torch.float8_e5m2
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max

# avoid division by zero when calculating scale
EPS = 1e-12


def amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
    """Converts the amax value of a tensor to the fp8 scale.
    Args:
        amax: The amax value of the tensor.
        float8_dtype: the float8 dtype.
        orig_dtype: The original dtype of the tensor.
    """
    scale = torch.empty_like(amax, dtype=torch.float32)
    if float8_dtype == e4m3_type:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
    elif float8_dtype == e5m2_type:
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

    # Ensure the scale is representable in float16,
    # this helps when amax is small. We are assuming that we don't need
    # to care about this for float32/bfloat16
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=torch.finfo(torch.float16).max)

    scale.copy_(res)
    return scale


def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
    if dim is None:
        amax = torch.max(torch.abs(x))
    else:
        amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values

    return amax_to_scale(amax, float8_dtype, x.dtype)


def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
    # naive implementation: dq -> op -> q
    x_fp32 = x.to(torch.float) / x_scale
    y_fp32 = y.to(torch.float) / y_scale
    out_fp32 = torch.mm(x_fp32, y_fp32)

    return out_fp32.to(out_dtype)


def addmm_float8_unwrapped(
    a_data: torch.Tensor,
    a_scale: torch.Tensor,
    b_data: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,
    output_scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    a_inverse_scale = a_scale.reciprocal()
    b_inverse_scale = b_scale.reciprocal()
    if output_dtype == torch.float32 and bias is not None:
        # Bias is not supported by _scaled_mm when output is fp32
        output = torch._scaled_mm(
            a_data,
            b_data,
            scale_a=a_inverse_scale,
            scale_b=b_inverse_scale,
            scale_result=output_scale,
            out_dtype=output_dtype,
        )
        output += bias
        return output
    output = torch._scaled_mm(
        a_data,
        b_data,
        bias=bias,
        scale_a=a_inverse_scale,
        scale_b=b_inverse_scale,
        scale_result=output_scale,
        out_dtype=output_dtype,
    )
    return output


def mm_float8(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,  # output dtype
    output_scale: Optional[torch.Tensor] = None,  # output scale, precomputed
) -> torch.Tensor:
    return addmm_float8_unwrapped(a, a_scale, b, b_scale, output_dtype, output_scale)


def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype):
    if fp8_dtype == e4m3_type:
        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif fp8_dtype == e5m2_type:
        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
    else:
        raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")

    return x.to(fp8_dtype)

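# Illustrative sketch, not exercised by the tests below: how the helpers above
# are meant to compose for a per-tensor scaled matmul. The function name and
# shapes here are hypothetical; the flow mirrors what the tests do with
# tensor_to_scale / to_fp8_saturated / mm_float8_emulated.
def _example_per_tensor_fp8_mm(device: str = "cpu") -> torch.Tensor:
    a = torch.randn(16, 32, device=device)
    b = torch.randn(32, 16, device=device)
    # scale each operand by its amax-derived scale, then saturate-cast to fp8
    a_scale = tensor_to_scale(a, e4m3_type).float()
    b_scale = tensor_to_scale(b, e4m3_type).float()
    a_fp8 = to_fp8_saturated(a * a_scale, e4m3_type)
    b_fp8 = to_fp8_saturated(b * b_scale, e4m3_type)
    # reference result: dequantize, matmul in fp32, cast to the output dtype
    return mm_float8_emulated(a_fp8, a_scale, b_fp8, b_scale, torch.float32)
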
class TestFP8Matmul(TestCase):
    def _test_tautological_mm(
        self,
        device: str = "cuda",
        x_dtype: torch.dtype = e4m3_type,
        y_dtype: torch.dtype = e4m3_type,
        out_dtype: Optional[torch.dtype] = None,
        size: int = 16,
    ) -> None:
        if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
            raise unittest.SkipTest(f8_msg)
        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
        if out_dtype is not None:
            self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    def test_float8_basics(self, device) -> None:
        if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
            raise unittest.SkipTest(f8_msg)
        self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
        # hipblaslt does not yet support mixed e4m3_type input
        if torch.version.hip is None:
            if device != "cpu":
                # TODO: The following 2 tests are mixed dtypes between src and weight,
                # which will be enabled in oneDNN v3.6 in CPU.
                self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
                self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
        if device == "cpu":
            self._test_tautological_mm(device, e5m2_type, e5m2_type)
        else:
            # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
            with self.assertRaises(RuntimeError):
                self._test_tautological_mm(device, e5m2_type, e5m2_type)

        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
        # hipblaslt does not yet support bfloat16 output
        if torch.version.hip is None:
            self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
        if device != "cpu":
            with self.assertRaises(RuntimeError):
                self._test_tautological_mm(device, out_dtype=e5m2_type)
        else:
            # TODO: e4m3 and e5m2 naturally has numerical gap, maybe relax the tolerance later.
            with self.assertRaises(AssertionError):
                self._test_tautological_mm(device, out_dtype=e5m2_type)

    def test_float8_scale(self, device) -> None:
        if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
            raise unittest.SkipTest(f8_msg)
        size = (16, 16)
        x = torch.full(size, 0.5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        # TODO: will use e5m2_type after upgrading oneDNN to v3.6.
        y_type = e4m3_type if torch.version.hip or device == "cpu" else e5m2_type
        y = torch.full(size, 0.5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4.0, device=device))
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_vs_emulated(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scale, b_scale=y_scale, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(x_fp8, x_scale, y_fp8, y_scale, output_dtype)

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_change_stride(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
        y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scale, b_scale=y_scale, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(x_fp8, x_scale, y_fp8, y_scale, output_dtype)

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    def test_float8_bias(self, device) -> None:
        if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
            raise unittest.SkipTest(f8_msg)
        (k, l, m) = (16, 48, 32)
        x = torch.ones((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), 0.25, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
        # this fails on ROCm currently because hipblaslt doesn't have amax op
        out_fp32 = out_fp8.to(torch.float32)
        outb_fp32 = outb_fp8.to(torch.float32)
        difference = torch.abs(out_fp32 - outb_fp32)
        self.assertEqual(
            difference, torch.tensor(4.0, device=device).expand_as(out_fp32)
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("bias", [True, False])
    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
        x = torch.rand((17, 16), device=device).to(e4m3_type)
        y = torch.rand((16, 16), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        input_bias = None
        if bias:
            input_bias = torch.rand((16,), device=device).to(torch.half)
        _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias_relu_edgecase(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
        y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
        outb_fp32 = outb_fp8.to(torch.float32)
        self.assertEqual(
            outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32)
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float32_output_errors_with_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), 0.25, device=device, dtype=e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
        self.assertRaisesRegex(
            RuntimeError,
            "Bias is not supported when out_dtype is set to Float32",
            lambda: torch._scaled_mm(
                x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32
            ),
        )

    @unittest.skipIf(
        PLATFORM_SUPPORTS_FP8 or not torch.cuda.is_available(),
        "This test is only for devices with compute capability < 8.9",
    )
    def test_error_message_fp8_pre_sm89(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.rand((m, l), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        self.assertRaisesRegex(
            RuntimeError,
            r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale_fast_accum(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, 0.5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, 0.5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4.0, device=device))
        out_fp8_s = torch._scaled_mm(
            x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True
        )
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
    @skipIfRocm()
    @parametrize("use_fast_accum", [True, False])
    def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
        y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        out_fp8 = torch._scaled_mm(
            x_fp8,
            y_fp8,
            scale_a=x_scales,
            scale_b=y_scales,
            out_dtype=torch.bfloat16,
            use_fast_accum=use_fast_accum,
        )
        self.assertEqual(
            out_fp8.to(torch.float32),
            torch.full((M, N), K * (fill_value**2), device=device),
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    def test_float8_error_messages(self, device) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((1, 1), device="cuda"),
                scale_b=torch.ones((1, 2), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                " For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N + 1), device="cuda"),
                out_dtype=torch.bfloat16,
            )
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For non-TensorWise scaling, scale tensors must be 2-dimensional"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M), device="cuda"),
                scale_b=torch.ones((N, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "Both scale_a and scale_b must be contiguous for RowWise scaling."
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "Expected b.dtype() == at::kFloat8_e4m3fn to be true, but got false."
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8.to(torch.float8_e5m2),
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
    @skipIfRocm()
    @parametrize("base_dtype", [torch.bfloat16])
    def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
        y_scales = tensor_to_scale(y, input_dtype, dim=0).float()

        x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
        y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8, x_scales, y_fp8, y_scales, output_dtype
        )

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 2e-3, 2e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("which_dim_zero", [0, 1, 2])
    @parametrize("use_torch_compile", [False, True])
    def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
        device = "cuda"
        x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn
        out_dtype = torch.bfloat16
        M, K, N = 32, 32, 32
        if which_dim_zero == 0:
            M = 0
        elif which_dim_zero == 1:
            K = 0
        elif which_dim_zero == 2:
            N = 0

        x_fp8 = torch.zeros(M, K, device=device).to(x_dtype)
        y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t()
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        scale_a = torch.tensor(float("-inf"), device=device)
        scale_b = torch.tensor(float("-inf"), device=device)
        f = torch._scaled_mm
        if use_torch_compile:
            f = torch.compile(torch._scaled_mm)
        out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
        self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))


instantiate_device_type_tests(TestFP8Matmul, globals())

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()

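The new tests above reduce to one call pattern. Below is a minimal sketch (not part of the diff) of a per-tensor scaled FP8 matmul on CPU, assuming a build that includes the CPU kernel added here and float8 dtypes; the values mirror _test_tautological_mm, and mat2 is handed over in column-major layout:

import torch

# Per-tensor scaled FP8 matmul on CPU (minimal sketch).
x_fp8 = torch.rand(16, 16).to(torch.float8_e4m3fn)   # mat1, row-major
y_fp8 = torch.eye(16).to(torch.float8_e4m3fn).t()    # mat2, column-major
scale_a = torch.tensor(1.0)                          # per-tensor scales
scale_b = torch.tensor(1.0)
out = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.dtype)  # torch.bfloat16
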
@ -34,6 +34,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);

@ -1008,6 +1008,8 @@ ANY_DTYPE_ORDER = (
    torch.int8,
    torch.uint8,
    torch.bool,
    torch.float8_e4m3fn,
    torch.float8_e5m2,
)

@ -21,7 +21,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
    floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
-   empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and,
+   empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types,
)
from torch.testing._internal.common_device_type import \
    (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,

@ -8734,18 +8734,27 @@ def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
    scale1 = make_scale((1,))
    scale2 = make_scale((1,))
    samples.append(SampleInput(mat1, mat2, scale1, scale2))
-   # mat1 e4m3 mat2 e5m2
-   mat1 = make_mat_e4m3((M, K))
+   # two e5m2
+   mat1 = make_mat_e5m2((M, K))
    mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
    scale1 = make_scale((1,))
    scale2 = make_scale((1,))
    samples.append(SampleInput(mat1, mat2, scale1, scale2))
-   # mat1 e5m2 mat2 e4m3
-   mat1 = make_mat_e5m2((M, K))
-   mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
-   scale1 = make_scale((1,))
-   scale2 = make_scale((1,))
-   samples.append(SampleInput(mat1, mat2, scale1, scale2))
+   # TODO: Will remove this after oneDNN v3.6
+   # now oneDNN v3.5.3 only supports mat1 * mat2 with the same data types.
+   if device != 'cpu':
+       # mat1 e4m3 mat2 e5m2
+       mat1 = make_mat_e4m3((M, K))
+       mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
+       scale1 = make_scale((1,))
+       scale2 = make_scale((1,))
+       samples.append(SampleInput(mat1, mat2, scale1, scale2))
+       # mat1 e5m2 mat2 e4m3
+       mat1 = make_mat_e5m2((M, K))
+       mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
+       scale1 = make_scale((1,))
+       scale2 = make_scale((1,))
+       samples.append(SampleInput(mat1, mat2, scale1, scale2))

    yield from samples

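The .t().contiguous().t() chain in the sample generator hands _scaled_mm a column-major second operand, which the fp8 kernels generally require. A small sketch of the layout trick (tensor names and sizes here are illustrative, not taken from the diff):

import torch

K, N = 32, 16
# Start row-major, transpose, force contiguity, transpose back:
# the result keeps the (K, N) shape but has column-major strides (1, K).
mat2 = torch.rand(K, N).to(torch.float8_e4m3fn).t().contiguous().t()
assert mat2.shape == (K, N)
assert mat2.stride() == (1, K)
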
@ -16210,20 +16219,28 @@ op_db: List[OpInfo] = [
    OpInfo(
        'torch._scaled_mm',
        sample_inputs_func=sample_inputs_scaled_mm,
-       dtypes=empty_types(),
+       dtypes=float8_types(),
        dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
        supports_out=True,
        supports_forward_ad=False,
        supports_autograd=False,
-       decorators=[skipCUDAIf(not SM90OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 9.0')],
+       decorators=[skipCUDAIf(not SM90OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 9.0'), ],
        skips=(
            # Sample inputs isn't really parametrized on dtype
-           DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
-                        device_type='cuda'),
-           # "mul_cuda" not implemented for float8_e4m3fn
+           DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
+           # "add_stub" not implemented for 'Float8_e4m3fn'
            # https://github.com/pytorch/pytorch/issues/107256
-           DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
-                        dtypes=(torch.float8_e4m3fn,)),
+           DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
+                        device_type='cpu'),
+           # "mul_cuda" not implemented for float8_e4m3fn
+           # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
+           # https://github.com/pytorch/pytorch/issues/107256
+           DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'),
+           # aten::_scaled_mm hit the vmap fallback which is currently disabled
+           DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule",
+                        device_type='cpu'),
+           DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
+                        dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)),
        )
    ),
    OpInfo(