mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add torch._scaled_mm for CPU (#150410)
This PR is the duplicated one for https://github.com/pytorch/pytorch/pull/139975. This PR is to add torch._scaled_mm for CPU backend. _scaled_mm_out_cpu and _scaled_mm_cpu are new added and included in torch._scaled_mm CPU dispatch. We also add _scaled_mm_out_cpu_emulated as a fallback function if the current platform cannot run FP8 matmul using oneDNN. And this PR also updates the various UTs related to FP8 to support CPU tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150410 Approved by: https://github.com/atalman
This commit is contained in:
parent
24ca7e91e6
commit
1e92579126
|
|
@ -7,6 +7,11 @@
|
||||||
#include <ATen/Config.h>
|
#include <ATen/Config.h>
|
||||||
|
|
||||||
#include <ATen/native/mkldnn/Matmul.h>
|
#include <ATen/native/mkldnn/Matmul.h>
|
||||||
|
#include <ATen/native/mkldnn/Linear.h>
|
||||||
|
#include <ATen/native/Resize.h>
|
||||||
|
#if !defined(__s390x__) && !defined(__powerpc__)
|
||||||
|
#include <cpuinfo.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifndef AT_PER_OPERATOR_HEADERS
|
#ifndef AT_PER_OPERATOR_HEADERS
|
||||||
#include <ATen/CPUFunctions.h>
|
#include <ATen/CPUFunctions.h>
|
||||||
|
|
@ -24,6 +29,9 @@
|
||||||
#include <ATen/ops/mv_native.h>
|
#include <ATen/ops/mv_native.h>
|
||||||
#include <ATen/ops/scalar_tensor_native.h>
|
#include <ATen/ops/scalar_tensor_native.h>
|
||||||
#include <ATen/ops/vdot_native.h>
|
#include <ATen/ops/vdot_native.h>
|
||||||
|
#include <ATen/ops/_scaled_mm_native.h>
|
||||||
|
#include <ATen/ops/mul.h>
|
||||||
|
#include <ATen/ops/matmul.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace at::meta {
|
namespace at::meta {
|
||||||
|
|
@ -222,4 +230,92 @@ Tensor vdot(const Tensor &self, const Tensor &other){
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Tensor&
|
||||||
|
_scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
|
||||||
|
const Tensor& scale_a,
|
||||||
|
const Tensor& scale_b,
|
||||||
|
const std::optional<at::Tensor>& bias,
|
||||||
|
const std::optional<at::Tensor>& scale_result,
|
||||||
|
std::optional<c10::ScalarType> out_dtype,
|
||||||
|
bool use_fast_accum,
|
||||||
|
Tensor& out) {
|
||||||
|
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
|
||||||
|
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
|
||||||
|
TORCH_CHECK(
|
||||||
|
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
|
||||||
|
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
|
||||||
|
|
||||||
|
TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
|
||||||
|
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
|
||||||
|
" but got ", bias->numel());
|
||||||
|
|
||||||
|
// Check types
|
||||||
|
TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
|
||||||
|
TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
|
||||||
|
TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
|
||||||
|
|
||||||
|
auto mat1_c = mat1.contiguous();
|
||||||
|
auto mat2_c = mat2.contiguous();
|
||||||
|
IntArrayRef mat1_sizes = mat1_c.sizes();
|
||||||
|
IntArrayRef mat2_sizes = mat2_c.sizes();
|
||||||
|
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
|
||||||
|
|
||||||
|
float input_scale = scale_a.item<float>();
|
||||||
|
float weight_scale = scale_b.item<float>();
|
||||||
|
auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale);
|
||||||
|
auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale);
|
||||||
|
auto out_tmp = at::matmul(fp32_mat1, fp32_mat2);
|
||||||
|
if (bias) {
|
||||||
|
out_tmp.add_(bias.value());
|
||||||
|
}
|
||||||
|
out_tmp = out_tmp.to(out.scalar_type());
|
||||||
|
out.copy_(out_tmp);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor&
|
||||||
|
_scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2,
|
||||||
|
const Tensor& scale_a,
|
||||||
|
const Tensor& scale_b,
|
||||||
|
const std::optional<at::Tensor>& bias,
|
||||||
|
const std::optional<at::Tensor>& scale_result,
|
||||||
|
std::optional<c10::ScalarType> out_dtype,
|
||||||
|
bool use_fast_accum,
|
||||||
|
Tensor& out) {
|
||||||
|
#if AT_MKLDNN_ENABLED()
|
||||||
|
if (at::globalContext().userEnabledMkldnn()) {
|
||||||
|
bool mixed_dtype = mat1.scalar_type() != mat2.scalar_type();
|
||||||
|
if ((!mixed_dtype && cpuinfo_has_x86_amx_int8()) ||
|
||||||
|
(mixed_dtype && cpuinfo_has_x86_amx_fp16())) {
|
||||||
|
return mkldnn_scaled_mm(
|
||||||
|
mat1,
|
||||||
|
mat2,
|
||||||
|
scale_a,
|
||||||
|
scale_b,
|
||||||
|
bias,
|
||||||
|
scale_result,
|
||||||
|
out_dtype,
|
||||||
|
use_fast_accum,
|
||||||
|
out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
{
|
||||||
|
return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor
|
||||||
|
_scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b,
|
||||||
|
const Tensor& scale_a,
|
||||||
|
const Tensor& scale_b,
|
||||||
|
const std::optional<at::Tensor>& bias,
|
||||||
|
const std::optional<at::Tensor>& scale_result,
|
||||||
|
std::optional<c10::ScalarType> out_dtype,
|
||||||
|
bool use_fast_accum) {
|
||||||
|
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
|
||||||
|
Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
|
||||||
|
return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace at::native
|
} // namespace at::native
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@
|
||||||
#include <ATen/core/Tensor.h>
|
#include <ATen/core/Tensor.h>
|
||||||
#include <torch/library.h>
|
#include <torch/library.h>
|
||||||
#include <ATen/native/mkldnn/Linear.h>
|
#include <ATen/native/mkldnn/Linear.h>
|
||||||
|
#include <ATen/native/Resize.h>
|
||||||
|
|
||||||
#ifndef AT_PER_OPERATOR_HEADERS
|
#ifndef AT_PER_OPERATOR_HEADERS
|
||||||
#include <ATen/Functions.h>
|
#include <ATen/Functions.h>
|
||||||
|
|
@ -46,8 +47,19 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
|
||||||
TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support");
|
TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace at::native
|
Tensor&
|
||||||
|
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
|
||||||
|
const Tensor& scale_a,
|
||||||
|
const Tensor& scale_b,
|
||||||
|
const std::optional<at::Tensor>& bias,
|
||||||
|
const std::optional<at::Tensor>& scale_result,
|
||||||
|
std::optional<c10::ScalarType> out_dtype,
|
||||||
|
bool use_fast_accum,
|
||||||
|
Tensor& out) {
|
||||||
|
TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace at::native
|
||||||
|
|
||||||
#else // AT_MKLDNN_ENABLED
|
#else // AT_MKLDNN_ENABLED
|
||||||
|
|
||||||
|
|
@ -459,6 +471,118 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
|
||||||
TORCH_FN(mkldnn_linear_pointwise_binary));
|
TORCH_FN(mkldnn_linear_pointwise_binary));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor&
|
||||||
|
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
|
||||||
|
const Tensor& scale_a,
|
||||||
|
const Tensor& scale_b,
|
||||||
|
const std::optional<at::Tensor>& bias,
|
||||||
|
const std::optional<at::Tensor>& scale_result,
|
||||||
|
std::optional<c10::ScalarType> out_dtype,
|
||||||
|
bool use_fast_accum,
|
||||||
|
Tensor& out) {
|
||||||
|
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
|
||||||
|
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
|
||||||
|
TORCH_CHECK(
|
||||||
|
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
|
||||||
|
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
|
||||||
|
|
||||||
|
TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend.");
|
||||||
|
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
|
||||||
|
" but got ", bias->numel());
|
||||||
|
|
||||||
|
// Check types
|
||||||
|
TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
|
||||||
|
TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
|
||||||
|
TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
|
||||||
|
|
||||||
|
// Validation checks have passed lets resize the output to actual size
|
||||||
|
auto mat1_c = mat1.contiguous();
|
||||||
|
auto mat2_c = mat2.contiguous();
|
||||||
|
IntArrayRef mat1_sizes = mat1_c.sizes();
|
||||||
|
IntArrayRef mat2_sizes = mat2_c.sizes();
|
||||||
|
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
|
||||||
|
|
||||||
|
float input_scale = scale_a.item<float>();
|
||||||
|
float weight_scale = scale_b.item<float>();
|
||||||
|
auto src = at::native::itensor_view_from_dense(mat1_c);
|
||||||
|
auto weight_t = at::native::itensor_view_from_dense(mat2_c);
|
||||||
|
bool with_bias = bias.has_value();
|
||||||
|
int64_t K = mat1_sizes[1], M = mat1_sizes[0],
|
||||||
|
N = mat2_sizes[1];
|
||||||
|
|
||||||
|
std::vector<int64_t> src_dims = {M, K};
|
||||||
|
std::vector<int64_t> weight_dims = {K, N};
|
||||||
|
std::vector<int64_t> dst_dims = {M, N};
|
||||||
|
|
||||||
|
ideep::tensor dst = at::native::itensor_view_from_dense(out);
|
||||||
|
auto src_desc = ideep::tensor::desc(
|
||||||
|
src_dims,
|
||||||
|
get_mkldnn_dtype(mat1.scalar_type()),
|
||||||
|
ideep::format_tag::any);
|
||||||
|
auto weights_desc = ideep::tensor::desc(
|
||||||
|
weight_dims,
|
||||||
|
get_mkldnn_dtype(mat2.scalar_type()),
|
||||||
|
ideep::format_tag::any);
|
||||||
|
auto dst_desc = ideep::tensor::desc(
|
||||||
|
dst_dims,
|
||||||
|
get_mkldnn_dtype(out.scalar_type()),
|
||||||
|
ideep::format_tag::any);
|
||||||
|
ideep::tensor onednn_bias;
|
||||||
|
if (with_bias) {
|
||||||
|
auto bias_value = bias.value();
|
||||||
|
if (bias_value.dim() == 1) {
|
||||||
|
auto b_reshape = bias_value.reshape({1, bias_value.size(0)});
|
||||||
|
onednn_bias = at::native::itensor_view_from_dense(b_reshape);
|
||||||
|
} else {
|
||||||
|
onednn_bias = at::native::itensor_view_from_dense(bias_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto bias_desc = ideep::tensor::desc();
|
||||||
|
if (with_bias) {
|
||||||
|
bias_desc = ideep::tensor::desc(onednn_bias.get_dims(),
|
||||||
|
get_mkldnn_dtype(bias.value().scalar_type()),
|
||||||
|
ideep::format_tag::any);
|
||||||
|
}
|
||||||
|
auto op_attr = ideep::attr_t();
|
||||||
|
if (input_scale != 1.0f) {
|
||||||
|
op_attr.set_scales_mask(DNNL_ARG_SRC, 0);
|
||||||
|
}
|
||||||
|
if (weight_scale != 1.0f) {
|
||||||
|
op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||||
|
auto engine = ideep::engine::cpu_engine();
|
||||||
|
dnnl::matmul::primitive_desc primitive_desc = with_bias
|
||||||
|
? dnnl::matmul::primitive_desc(
|
||||||
|
engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr)
|
||||||
|
: dnnl::matmul::primitive_desc(
|
||||||
|
engine, src_desc, weights_desc, dst_desc, op_attr);
|
||||||
|
auto expected_weight = weight_t.reorder_if_differ_in(primitive_desc.weights_desc());
|
||||||
|
auto primitive = dnnl::matmul(primitive_desc);
|
||||||
|
|
||||||
|
// Prepare args and execute primitive
|
||||||
|
ideep::tensor scratchpad(primitive_desc.scratchpad_desc());
|
||||||
|
ideep::exec_args args;
|
||||||
|
args.insert({DNNL_ARG_SRC, src});
|
||||||
|
args.insert({DNNL_ARG_WEIGHTS, expected_weight});
|
||||||
|
args.insert({DNNL_ARG_DST, dst});
|
||||||
|
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
|
||||||
|
if (with_bias) {
|
||||||
|
args.insert({DNNL_ARG_BIAS, onednn_bias});
|
||||||
|
}
|
||||||
|
ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale));
|
||||||
|
ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale));
|
||||||
|
|
||||||
|
if (input_scale != 1.0f) {
|
||||||
|
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t});
|
||||||
|
}
|
||||||
|
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t});
|
||||||
|
|
||||||
|
primitive.execute(ideep::stream::default_stream(), args);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace at
|
} // namespace at
|
||||||
|
|
||||||
#endif // AT_MKLDNN_ENABLED
|
#endif // AT_MKLDNN_ENABLED
|
||||||
|
|
|
||||||
|
|
@ -35,3 +35,15 @@ C10_API Tensor mkl_linear(
|
||||||
} // namespace at
|
} // namespace at
|
||||||
|
|
||||||
#endif // AT_MKLDNN_ENABLED()
|
#endif // AT_MKLDNN_ENABLED()
|
||||||
|
|
||||||
|
namespace at::native {
|
||||||
|
Tensor&
|
||||||
|
mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
|
||||||
|
const Tensor& scale_a,
|
||||||
|
const Tensor& scale_b,
|
||||||
|
const std::optional<at::Tensor>& bias,
|
||||||
|
const std::optional<at::Tensor>& scale_result,
|
||||||
|
std::optional<c10::ScalarType> out_dtype,
|
||||||
|
bool use_fast_accum,
|
||||||
|
Tensor& out);
|
||||||
|
} // namespace at::native
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,10 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) {
|
||||||
return ideep::tensor::data_type::bf16;
|
return ideep::tensor::data_type::bf16;
|
||||||
case ScalarType::Half:
|
case ScalarType::Half:
|
||||||
return ideep::tensor::data_type::f16;
|
return ideep::tensor::data_type::f16;
|
||||||
|
case ScalarType::Float8_e4m3fn:
|
||||||
|
return ideep::tensor::data_type::f8_e4m3;
|
||||||
|
case ScalarType::Float8_e5m2:
|
||||||
|
return ideep::tensor::data_type::f8_e5m2;
|
||||||
default:
|
default:
|
||||||
TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type");
|
TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type");
|
||||||
}
|
}
|
||||||
|
|
@ -161,8 +165,24 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data
|
||||||
const_cast<void*>(tensor.const_data_ptr()) :
|
const_cast<void*>(tensor.const_data_ptr()) :
|
||||||
tensor.data_ptr()};
|
tensor.data_ptr()};
|
||||||
}
|
}
|
||||||
|
else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
|
||||||
|
return {{tensor.sizes().vec(),
|
||||||
|
ideep::tensor::data_type::f8_e4m3,
|
||||||
|
tensor.strides().vec()},
|
||||||
|
from_const_data_ptr ?
|
||||||
|
const_cast<void*>(tensor.const_data_ptr()) :
|
||||||
|
tensor.data_ptr()};
|
||||||
|
}
|
||||||
|
else if (tensor.scalar_type() == ScalarType::Float8_e5m2) {
|
||||||
|
return {{tensor.sizes().vec(),
|
||||||
|
ideep::tensor::data_type::f8_e5m2,
|
||||||
|
tensor.strides().vec()},
|
||||||
|
from_const_data_ptr ?
|
||||||
|
const_cast<void*>(tensor.const_data_ptr()) :
|
||||||
|
tensor.data_ptr()};
|
||||||
|
}
|
||||||
else {
|
else {
|
||||||
TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input");
|
TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8/fp8 tensor input", tensor.scalar_type());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7067,11 +7067,13 @@
|
||||||
- func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
|
- func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
|
||||||
variants: function
|
variants: function
|
||||||
dispatch:
|
dispatch:
|
||||||
|
CPU: _scaled_mm_cpu
|
||||||
CUDA: _scaled_mm_cuda
|
CUDA: _scaled_mm_cuda
|
||||||
|
|
||||||
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
|
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
|
||||||
variants: function
|
variants: function
|
||||||
dispatch:
|
dispatch:
|
||||||
|
CPU: _scaled_mm_out_cpu
|
||||||
CUDA: _scaled_mm_out_cuda
|
CUDA: _scaled_mm_out_cuda
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import (
|
||||||
instantiate_parametrized_tests,
|
instantiate_parametrized_tests,
|
||||||
parametrize,
|
parametrize,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.inductor_utils import HAS_CUDA
|
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
|
||||||
from torch.utils._triton import has_triton_tma_device
|
from torch.utils._triton import has_triton_tma_device
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -116,9 +116,9 @@ def _fix_fp8_dtype_for_rocm(
|
||||||
|
|
||||||
@instantiate_parametrized_tests
|
@instantiate_parametrized_tests
|
||||||
class TestFP8Types(TestCase):
|
class TestFP8Types(TestCase):
|
||||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
||||||
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
|
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
|
||||||
def test_xblock_for_small_numel(self, float8_dtype: torch.dtype):
|
@parametrize("device", ("cuda", "cpu"))
|
||||||
|
def test_xblock_for_small_numel(self, float8_dtype: torch.dtype, device: str):
|
||||||
"""
|
"""
|
||||||
TritonOverrides.to_dtype will set min_elem_per_thread to 2 or 4
|
TritonOverrides.to_dtype will set min_elem_per_thread to 2 or 4
|
||||||
depends on the variant of fp8 type.
|
depends on the variant of fp8 type.
|
||||||
|
|
@ -127,30 +127,34 @@ class TestFP8Types(TestCase):
|
||||||
|
|
||||||
We should not pick a XBLOCK larger than xnumel
|
We should not pick a XBLOCK larger than xnumel
|
||||||
"""
|
"""
|
||||||
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
|
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device=device)
|
||||||
|
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
|
||||||
|
raise unittest.SkipTest(f8_msg)
|
||||||
|
|
||||||
def f(x):
|
def f(x):
|
||||||
return x.to(dtype=float8_dtype)
|
return x.to(dtype=float8_dtype)
|
||||||
|
|
||||||
x = torch.randn(1, device="cuda")
|
x = torch.randn(1, device=device)
|
||||||
expected = f(x)
|
expected = f(x)
|
||||||
actual = torch.compile(f)(x)
|
actual = torch.compile(f)(x)
|
||||||
torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2)
|
torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
||||||
@parametrize("dtype", (torch.float16, torch.bfloat16))
|
@parametrize("dtype", (torch.float16, torch.bfloat16))
|
||||||
def test_eager_fallback(self, dtype: torch.dtype):
|
@parametrize("device", ("cuda", "cpu"))
|
||||||
|
def test_eager_fallback(self, dtype: torch.dtype, device: torch.device):
|
||||||
|
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
|
||||||
|
raise unittest.SkipTest(f8_msg)
|
||||||
weight_shape = (32, 16)
|
weight_shape = (32, 16)
|
||||||
|
|
||||||
e4m3_type = torch.float8_e4m3fn
|
e4m3_type = torch.float8_e4m3fn
|
||||||
e4m3_type = _fix_fp8_dtype_for_rocm(e4m3_type, device="cuda")
|
e4m3_type = _fix_fp8_dtype_for_rocm(e4m3_type, device=device)
|
||||||
|
|
||||||
def fp8_matmul_unwrapped(x):
|
def fp8_matmul_unwrapped(x):
|
||||||
a_scale = torch.Tensor([1.0]).to(device="cuda")
|
a_scale = torch.Tensor([1.0]).to(device=device)
|
||||||
b_scale = torch.Tensor([1.0]).to(device="cuda")
|
b_scale = torch.Tensor([1.0]).to(device=device)
|
||||||
output_scale = None
|
output_scale = None
|
||||||
input_bias = torch.rand(32, device="cuda", dtype=dtype)
|
input_bias = torch.rand(32, device=device, dtype=dtype)
|
||||||
weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T.to(
|
weight = torch.rand(*weight_shape, device=device, dtype=dtype).T.to(
|
||||||
e4m3_type
|
e4m3_type
|
||||||
)
|
)
|
||||||
a_inverse_scale = 1 / a_scale
|
a_inverse_scale = 1 / a_scale
|
||||||
|
|
@ -171,19 +175,23 @@ class TestFP8Types(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
x_shape = (16, 16)
|
x_shape = (16, 16)
|
||||||
x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
|
x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type)
|
||||||
y_fp8 = compiled_fp8_matmul(x) # noqa: F841
|
y_fp8 = compiled_fp8_matmul(x) # noqa: F841
|
||||||
|
|
||||||
x_shape = (15, 16)
|
x_shape = (15, 16)
|
||||||
x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
|
x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type)
|
||||||
y_fp8 = compiled_fp8_matmul(x) # noqa: F841
|
y_fp8 = compiled_fp8_matmul(x) # noqa: F841
|
||||||
|
|
||||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
||||||
@parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
|
@parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
|
||||||
@parametrize("shape", ("15,3,13", "4,2048,4096"))
|
@parametrize("shape", ("15,3,13", "4,2048,4096"))
|
||||||
@parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)])
|
@parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)])
|
||||||
def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple):
|
@parametrize("device", ("cuda", "cpu"))
|
||||||
dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda")
|
def test_valid_cast(
|
||||||
|
self, dtype: torch.dtype, shape: str, dst_types: tuple, device: torch.device
|
||||||
|
):
|
||||||
|
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
|
||||||
|
raise unittest.SkipTest(f8_msg)
|
||||||
|
dst_types = _fix_fp8_dtype_for_rocm(dst_types, device=device)
|
||||||
e4m3, e5m2 = dst_types
|
e4m3, e5m2 = dst_types
|
||||||
|
|
||||||
def fp8_cast(x):
|
def fp8_cast(x):
|
||||||
|
|
@ -194,7 +202,7 @@ class TestFP8Types(TestCase):
|
||||||
compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)
|
compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)
|
||||||
|
|
||||||
shape = [int(dim) for dim in shape.split(",")]
|
shape = [int(dim) for dim in shape.split(",")]
|
||||||
x = torch.rand(*shape, device="cuda", dtype=dtype)
|
x = torch.rand(*shape, device=device, dtype=dtype)
|
||||||
y0_fp8, y1_fp8 = compiled_fp8_cast(x)
|
y0_fp8, y1_fp8 = compiled_fp8_cast(x)
|
||||||
|
|
||||||
torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1)
|
torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1)
|
||||||
|
|
@ -223,14 +231,20 @@ class TestFP8Types(TestCase):
|
||||||
x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2)
|
x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2)
|
||||||
compiled_fp8_cast(x, torch.float8_e4m3fn)
|
compiled_fp8_cast(x, torch.float8_e4m3fn)
|
||||||
|
|
||||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
||||||
@parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
|
@parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
|
||||||
@parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
|
@parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
|
||||||
@parametrize("shape", ("16,16,16", "4,2048,4096"))
|
@parametrize("shape", ("16,16,16", "4,2048,4096"))
|
||||||
|
@parametrize("device", ("cuda", "cpu"))
|
||||||
def test_to_fp8_saturated(
|
def test_to_fp8_saturated(
|
||||||
self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
|
self,
|
||||||
|
src_dtype: torch.dtype,
|
||||||
|
dst_dtype: torch.dtype,
|
||||||
|
shape: str,
|
||||||
|
device: torch.device,
|
||||||
):
|
):
|
||||||
dst_dtype = _fix_fp8_dtype_for_rocm(dst_dtype, device="cuda")
|
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
|
||||||
|
raise unittest.SkipTest(f8_msg)
|
||||||
|
dst_dtype = _fix_fp8_dtype_for_rocm(dst_dtype, device=device)
|
||||||
|
|
||||||
def fp8_saturated(x, dtype):
|
def fp8_saturated(x, dtype):
|
||||||
return _to_fp8_saturated(x, dtype)
|
return _to_fp8_saturated(x, dtype)
|
||||||
|
|
@ -239,17 +253,23 @@ class TestFP8Types(TestCase):
|
||||||
fp8_saturated, backend="inductor", dynamic=True
|
fp8_saturated, backend="inductor", dynamic=True
|
||||||
)
|
)
|
||||||
shape = [int(dim) for dim in shape.split(",")]
|
shape = [int(dim) for dim in shape.split(",")]
|
||||||
x = torch.rand(*shape, device="cuda", dtype=src_dtype)
|
x = torch.rand(*shape, device=device, dtype=src_dtype)
|
||||||
y_compiled = compiled_fp8_cast(x, dst_dtype)
|
y_compiled = compiled_fp8_cast(x, dst_dtype)
|
||||||
y = fp8_saturated(x, dst_dtype)
|
y = fp8_saturated(x, dst_dtype)
|
||||||
|
|
||||||
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1)
|
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1)
|
||||||
|
|
||||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
||||||
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
|
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
|
||||||
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
|
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
|
||||||
def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
|
@parametrize("device", ("cuda", "cpu"))
|
||||||
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
|
def test_amax_fp8_quant(
|
||||||
|
self, float8_dtype: torch.dtype, shape: str, device: torch.device
|
||||||
|
):
|
||||||
|
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device=device)
|
||||||
|
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
|
||||||
|
raise unittest.SkipTest(
|
||||||
|
"FP8 is only supported on H100+ and sm_89 and MI300+ devices"
|
||||||
|
)
|
||||||
shape = [int(dim) for dim in shape.split(",")]
|
shape = [int(dim) for dim in shape.split(",")]
|
||||||
batch_size, sequence_length, hidden_size = shape
|
batch_size, sequence_length, hidden_size = shape
|
||||||
|
|
||||||
|
|
@ -262,19 +282,23 @@ class TestFP8Types(TestCase):
|
||||||
compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")
|
compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")
|
||||||
|
|
||||||
x_shape = (batch_size, sequence_length, hidden_size)
|
x_shape = (batch_size, sequence_length, hidden_size)
|
||||||
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
|
x = torch.rand(*x_shape, device=device, dtype=torch.half)
|
||||||
scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
|
scale = torch.tensor(0.2, device=device, dtype=torch.float)
|
||||||
|
|
||||||
y_compiled = compiled_amax_fp8_quant(x, scale)
|
y_compiled = compiled_amax_fp8_quant(x, scale)
|
||||||
y = amax_fp8(x, scale)
|
y = amax_fp8(x, scale)
|
||||||
|
|
||||||
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)
|
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
||||||
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
|
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
|
||||||
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
|
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
|
||||||
def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
|
@parametrize("device", ("cuda", "cpu"))
|
||||||
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
|
def test_amax_along_with_fp8_quant(
|
||||||
|
self, float8_dtype: torch.dtype, shape: str, device: torch.device
|
||||||
|
):
|
||||||
|
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
|
||||||
|
raise unittest.SkipTest(f8_msg)
|
||||||
|
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device=device)
|
||||||
shape = [int(dim) for dim in shape.split(",")]
|
shape = [int(dim) for dim in shape.split(",")]
|
||||||
batch_size, sequence_length, hidden_size = shape
|
batch_size, sequence_length, hidden_size = shape
|
||||||
|
|
||||||
|
|
@ -287,12 +311,12 @@ class TestFP8Types(TestCase):
|
||||||
compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")
|
compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")
|
||||||
|
|
||||||
x_shape = (batch_size, sequence_length, hidden_size)
|
x_shape = (batch_size, sequence_length, hidden_size)
|
||||||
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
|
x = torch.rand(*x_shape, device=device, dtype=torch.half)
|
||||||
scale = torch.tensor(1.0, device="cuda", dtype=torch.float)
|
scale = torch.tensor(1.0, device=device, dtype=torch.float)
|
||||||
|
|
||||||
amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
|
amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half)
|
||||||
y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled)
|
y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled)
|
||||||
amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
|
amax_buffer = torch.zeros((1), device=device, dtype=torch.half)
|
||||||
y = amax_fp8(x, scale, amax_buffer)
|
y = amax_fp8(x, scale, amax_buffer)
|
||||||
|
|
||||||
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
|
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
|
||||||
|
|
@ -300,14 +324,22 @@ class TestFP8Types(TestCase):
|
||||||
amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2
|
amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
||||||
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
|
@parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2))
|
||||||
@parametrize("amax_keep_dim", (True, False))
|
@parametrize("amax_keep_dim", (True, False))
|
||||||
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
|
@parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
|
||||||
|
@parametrize("device", ("cuda", "cpu"))
|
||||||
def test_layernorm_fp8_quant(
|
def test_layernorm_fp8_quant(
|
||||||
self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: str
|
self,
|
||||||
|
float8_dtype: torch.dtype,
|
||||||
|
amax_keep_dim: bool,
|
||||||
|
shape: str,
|
||||||
|
device: torch.device,
|
||||||
):
|
):
|
||||||
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda")
|
if device == "cuda" and not PLATFORM_SUPPORTS_FP8:
|
||||||
|
raise unittest.SkipTest(
|
||||||
|
"FP8 is only supported on H100+ and sm_89 and MI300+ devices"
|
||||||
|
)
|
||||||
|
float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device=device)
|
||||||
shape = [int(dim) for dim in shape.split(",")]
|
shape = [int(dim) for dim in shape.split(",")]
|
||||||
batch_size, sequence_length, hidden_size = shape
|
batch_size, sequence_length, hidden_size = shape
|
||||||
|
|
||||||
|
|
@ -329,12 +361,12 @@ class TestFP8Types(TestCase):
|
||||||
compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")
|
compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")
|
||||||
|
|
||||||
x_shape = (batch_size, sequence_length, hidden_size)
|
x_shape = (batch_size, sequence_length, hidden_size)
|
||||||
x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
|
x = torch.rand(*x_shape, device=device, dtype=torch.half)
|
||||||
scale = torch.tensor(0.2, device="cuda", dtype=torch.float)
|
scale = torch.tensor(0.2, device=device, dtype=torch.float)
|
||||||
|
|
||||||
amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
|
amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half)
|
||||||
y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
|
y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
|
||||||
amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
|
amax_buffer = torch.zeros((1), device=device, dtype=torch.half)
|
||||||
y = ln_fp8(x, scale, amax_buffer)
|
y = ln_fp8(x, scale, amax_buffer)
|
||||||
|
|
||||||
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
|
torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
|
||||||
|
|
@ -750,5 +782,5 @@ class TestFP8Lowering(TestCase):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if HAS_CUDA:
|
if HAS_CUDA or HAS_CPU:
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|
|
||||||
|
|
@ -653,15 +653,15 @@ def _bfloat16_to_float4_e2m1fn_x2(x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
|
class TestFP8Matmul(TestCase):
|
||||||
class TestFP8MatmulCuda(TestCase):
|
|
||||||
|
|
||||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
||||||
def _test_tautological_mm(self, device: str = "cuda",
|
def _test_tautological_mm(self, device: str = "cuda",
|
||||||
x_dtype: torch.dtype = e4m3_type,
|
x_dtype: torch.dtype = e4m3_type,
|
||||||
y_dtype: torch.dtype = e4m3_type,
|
y_dtype: torch.dtype = e4m3_type,
|
||||||
out_dtype: Optional[torch.dtype] = None,
|
out_dtype: Optional[torch.dtype] = None,
|
||||||
size: int = 16) -> None:
|
size: int = 16) -> None:
|
||||||
|
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
|
||||||
|
raise unittest.SkipTest(f8_msg)
|
||||||
x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
|
x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
|
||||||
y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
|
y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
|
||||||
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
|
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
|
||||||
|
|
@ -672,12 +672,13 @@ class TestFP8MatmulCuda(TestCase):
|
||||||
self.assertEqual(out_dtype, out_fp8.dtype)
|
self.assertEqual(out_dtype, out_fp8.dtype)
|
||||||
self.assertEqual(out_fp32, out_fp8.to(torch.float))
|
self.assertEqual(out_fp32, out_fp8.to(torch.float))
|
||||||
|
|
||||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
||||||
def test_float8_basics(self, device) -> None:
|
def test_float8_basics(self, device) -> None:
|
||||||
|
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
|
||||||
|
raise unittest.SkipTest(f8_msg)
|
||||||
self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
|
self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
|
||||||
# According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
|
# According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
|
||||||
# supported on ROCm but fails on CUDA
|
# supported on ROCm but fails on CUDA
|
||||||
ctx = self.assertRaises(RuntimeError) if torch.version.hip is None else contextlib.nullcontext()
|
ctx = self.assertRaises(RuntimeError) if torch.version.hip is None and device != "cpu" else contextlib.nullcontext()
|
||||||
with ctx:
|
with ctx:
|
||||||
self._test_tautological_mm(device, e5m2_type, e5m2_type)
|
self._test_tautological_mm(device, e5m2_type, e5m2_type)
|
||||||
|
|
||||||
|
|
@ -688,11 +689,12 @@ class TestFP8MatmulCuda(TestCase):
|
||||||
self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
|
self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
|
||||||
self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
|
self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
|
||||||
|
|
||||||
with self.assertRaises(AssertionError if torch.version.hip else RuntimeError):
|
with self.assertRaises(AssertionError if torch.version.hip or device == "cpu" else RuntimeError):
|
||||||
self._test_tautological_mm(device, out_dtype=e5m2_type)
|
self._test_tautological_mm(device, out_dtype=e5m2_type)
|
||||||
|
|
||||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
||||||
def test_float8_scale(self, device) -> None:
|
def test_float8_scale(self, device) -> None:
|
||||||
|
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
|
||||||
|
raise unittest.SkipTest(f8_msg)
|
||||||
size = (16, 16)
|
size = (16, 16)
|
||||||
x = torch.full(size, .5, device=device, dtype=e4m3_type)
|
x = torch.full(size, .5, device=device, dtype=e4m3_type)
|
||||||
# hipblaslt does not yet support mixed e4m3_type input
|
# hipblaslt does not yet support mixed e4m3_type input
|
||||||
|
|
@ -807,8 +809,9 @@ class TestFP8MatmulCuda(TestCase):
|
||||||
|
|
||||||
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
||||||
def test_float8_bias(self, device) -> None:
|
def test_float8_bias(self, device) -> None:
|
||||||
|
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
|
||||||
|
raise unittest.SkipTest(f8_msg)
|
||||||
(k, l, m) = (16, 48, 32)
|
(k, l, m) = (16, 48, 32)
|
||||||
x = torch.ones((k, l), device=device).to(e4m3_type)
|
x = torch.ones((k, l), device=device).to(e4m3_type)
|
||||||
y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
|
y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
|
||||||
|
|
@ -861,7 +864,7 @@ class TestFP8MatmulCuda(TestCase):
|
||||||
lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
|
lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skipIf(PLATFORM_SUPPORTS_FP8, f8_msg)
|
@unittest.skipIf(PLATFORM_SUPPORTS_FP8 or not torch.cuda.is_available(), f8_msg)
|
||||||
def test_error_message_fp8_pre_sm89(self, device) -> None:
|
def test_error_message_fp8_pre_sm89(self, device) -> None:
|
||||||
(k, l, m) = (16, 48, 32)
|
(k, l, m) = (16, 48, 32)
|
||||||
x = torch.rand((k, l), device=device).to(e4m3_type)
|
x = torch.rand((k, l), device=device).to(e4m3_type)
|
||||||
|
|
@ -1718,8 +1721,8 @@ class TestMixedDtypesLinearCuda(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
|
instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
|
||||||
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
|
|
||||||
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")
|
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")
|
||||||
|
instantiate_device_type_tests(TestFP8Matmul, globals())
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
TestCase._default_dtype_check_enabled = True
|
TestCase._default_dtype_check_enabled = True
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,8 @@
|
||||||
|
|
||||||
#include <c10/util/Float8_e4m3fn.h>
|
#include <c10/util/Float8_e4m3fn.h>
|
||||||
#include <c10/util/Float8_e5m2.h>
|
#include <c10/util/Float8_e5m2.h>
|
||||||
|
#include <c10/util/Float8_e4m3fnuz.h>
|
||||||
|
#include <c10/util/Float8_e5m2fnuz.h>
|
||||||
#include <c10/util/BFloat16.h>
|
#include <c10/util/BFloat16.h>
|
||||||
#include <c10/util/BFloat16-math.h>
|
#include <c10/util/BFloat16-math.h>
|
||||||
#include <c10/util/generic_math.h>
|
#include <c10/util/generic_math.h>
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attent
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
|
||||||
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0);
|
||||||
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
|
||||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
|
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
|
||||||
|
|
|
||||||
|
|
@ -1003,6 +1003,8 @@ ANY_DTYPE_ORDER = (
|
||||||
torch.int8,
|
torch.int8,
|
||||||
torch.uint8,
|
torch.uint8,
|
||||||
torch.bool,
|
torch.bool,
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
torch.float8_e5m2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ from torch.testing import make_tensor
|
||||||
from torch.testing._internal.common_dtype import (
|
from torch.testing._internal.common_dtype import (
|
||||||
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
|
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
|
||||||
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
|
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
|
||||||
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and,
|
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.common_device_type import \
|
from torch.testing._internal.common_device_type import \
|
||||||
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
|
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
|
||||||
|
|
@ -16221,7 +16221,7 @@ op_db: list[OpInfo] = [
|
||||||
OpInfo(
|
OpInfo(
|
||||||
'torch._scaled_mm',
|
'torch._scaled_mm',
|
||||||
sample_inputs_func=sample_inputs_scaled_mm,
|
sample_inputs_func=sample_inputs_scaled_mm,
|
||||||
dtypes=empty_types(),
|
dtypes=float8_types(),
|
||||||
dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
|
dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
|
||||||
supports_out=True,
|
supports_out=True,
|
||||||
supports_forward_ad=False,
|
supports_forward_ad=False,
|
||||||
|
|
@ -16229,12 +16229,20 @@ op_db: list[OpInfo] = [
|
||||||
decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
|
decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
|
||||||
skips=(
|
skips=(
|
||||||
# Sample inputs isn't really parametrized on dtype
|
# Sample inputs isn't really parametrized on dtype
|
||||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
|
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
|
||||||
device_type='cuda'),
|
# "add_stub" not implemented for 'Float8_e4m3fn'
|
||||||
# "mul_cuda" not implemented for float8_e4m3fn
|
# "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn'
|
||||||
# https://github.com/pytorch/pytorch/issues/107256
|
# https://github.com/pytorch/pytorch/issues/107256
|
||||||
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
|
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
|
||||||
dtypes=(torch.float8_e4m3fn,)),
|
# "mul_cuda" not implemented for float8_e4m3fn
|
||||||
|
# "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
|
||||||
|
# https://github.com/pytorch/pytorch/issues/107256
|
||||||
|
DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'),
|
||||||
|
# aten::_scaled_mm hit the vmap fallback which is currently disabled
|
||||||
|
DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
|
||||||
|
DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
|
||||||
|
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
|
||||||
|
dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
OpInfo(
|
OpInfo(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user