Revert "Add torch.float8_e5m2 and torch.float8_e4m3 data types (#104242)"

This reverts commit a9804130e5.

Reverted https://github.com/pytorch/pytorch/pull/104242 on behalf of https://github.com/PaliC due to breaks lint (run lintrunner and remerge) ([comment](https://github.com/pytorch/pytorch/pull/104242#issuecomment-1644150284))
This commit is contained in:
PyTorch MergeBot 2023-07-20 15:37:53 +00:00
parent 02cd971e95
commit f2b15772ff
36 changed files with 98 additions and 1626 deletions

View File

@ -2,8 +2,6 @@
#include <ATen/Config.h>
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Half.h>
// Defines the accumulation type for a scalar type.
@ -69,14 +67,6 @@ struct AccumulateType<Half, true> {
using type = float;
};
template <>
struct AccumulateType<Float8_e5m2, true> {
using type = float;
};
template <>
struct AccumulateType<Float8_e4m3fn, true> {
using type = float;
};
template <>
struct AccumulateType<float, true> {
using type = float;
};
@ -121,14 +111,6 @@ struct AccumulateType<BFloat16, false> {
using type = float;
};
template <>
struct AccumulateType<Float8_e5m2, false> {
using type = float;
};
template <>
struct AccumulateType<Float8_e4m3fn, false> {
using type = float;
};
template <>
struct AccumulateType<c10::complex<Half>, false> {
using type = c10::complex<float>;
};

View File

@ -53,10 +53,6 @@ DLDataType getDLDataType(const Tensor& t) {
case ScalarType::BFloat16:
dtype.code = DLDataTypeCode::kDLBfloat;
break;
case ScalarType::Float8_e5m2:
case ScalarType::Float8_e4m3fn:
TORCH_CHECK(false, "float8 types are not supported by dlpack");
break;
case ScalarType::QInt8:
case ScalarType::QUInt8:
case ScalarType::QInt32:

View File

@ -291,22 +291,6 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
#define AT_DISPATCH_FLOATING_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
#define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
@ -531,73 +515,6 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
TYPE, \
NAME, \
...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
__VA_ARGS__))
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
...) \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
TYPE, \
NAME, \
...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
__VA_ARGS__))
#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \

View File

@ -6,8 +6,6 @@
#include <c10/macros/Macros.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>
@ -64,22 +62,6 @@ inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
return at::_isnan(static_cast<float>(val));
}
template <
typename T,
typename std::enable_if<std::is_same<T, at::Float8_e5m2>::value, int>::
type = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}
template <
typename T,
typename std::enable_if<std::is_same<T, at::Float8_e4m3fn>::value, int>::
type = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}
// std::isinf isn't performant to use on integral types; it will
// (uselessly) convert to floating point and then do the test.
// This function is.
@ -110,14 +92,6 @@ inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
return at::_isinf(static_cast<float>(val));
}
inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
return val.isinf();
}
inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val) {
return false;
}
template <typename T>
C10_HOST_DEVICE inline T exp(T x) {
static_assert(

View File

@ -3,8 +3,6 @@
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Exception.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Half.h>
namespace at {
@ -23,14 +21,6 @@ struct OpMathType<at::BFloat16> {
using type = float;
};
template <>
struct OpMathType<at::Float8_e5m2> {
using type = float;
};
template <>
struct OpMathType<at::Float8_e4m3fn> {
using type = float;
};
template <>
struct OpMathType<c10::complex<Half>> {
using type = c10::complex<float>;
};

View File

@ -72,9 +72,7 @@ struct is_floating_point:
std::integral_constant<bool,
std::is_floating_point<T>::value ||
std::is_same<T, at::Half>::value ||
std::is_same<T, at::BFloat16>::value ||
std::is_same<T, at::Float8_e5m2>::value ||
std::is_same<T, at::Float8_e4m3fn>::value> {
std::is_same<T, at::BFloat16>::value> {
};
template<typename T>

View File

@ -48,18 +48,6 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
self.numel() >= MIN_SZ;
}
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kComplexHalf, kHalf, kBool, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif
// special case copy where tensor is contiguous and src is a transposed matrix
// This can be generalized to most copies, but it's trickier
void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
@ -77,7 +65,7 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
// The code below is implemented with the assumption that sizes are equal
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.sizes().equals(src.sizes()));
_AT_DISPATCH_CP_TYPES(self.scalar_type(), "copy_", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kHalf, kBool, kBFloat16, kComplexHalf, self.scalar_type(), "copy_", [&] {
scalar_t* sp = src.data_ptr<scalar_t>();
scalar_t* rp = self.data_ptr<scalar_t>();
scalar_t* bp = buf.data_ptr<scalar_t>();

View File

@ -1310,18 +1310,6 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {
return self.reshape_symint({self.sym_size(0), 1}) * vec2;
}
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif
static void addmm_impl_cpu_(
Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
@ -1450,7 +1438,9 @@ static void addmm_impl_cpu_(
if(!dispatched) {
// Apply BLAS routine
_AT_DISPATCH_ADDMM_TYPES(result.scalar_type(), "addmm_impl_cpu_", [&]{
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
result.scalar_type(), "addmm_impl_cpu_",
[&]{
using opmath_t = at::opmath_type<scalar_t>;
at::native::cpublas::gemm(
transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,

View File

@ -28,21 +28,10 @@ Scalar item(const Tensor& self) {
}
}
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_SD_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_SD_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kComplexHalf, kHalf, kBool, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif
Scalar _local_scalar_dense_cpu(const Tensor& self) {
Scalar r;
_AT_DISPATCH_SD_TYPES(self.scalar_type(), "_local_scalar_dense_cpu", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf, kHalf, kBool, kBFloat16, self.scalar_type(), "_local_scalar_dense_cpu", [&] {
scalar_t value = *self.data_ptr<scalar_t>();
r = Scalar(value);
});

View File

@ -369,18 +369,6 @@ Tensor isreal(const Tensor& self) {
return at::imag(self) == 0;
}
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_INF_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_FLOATING_TYPES_AND3( kHalf, kBFloat16, kFloat8_e5m2, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_INF_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif
Tensor isinf(const Tensor &self) {
// Note: Integral tensor values are never infinite
if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
@ -393,7 +381,7 @@ Tensor isinf(const Tensor &self) {
(at::isinf(at::imag(self)));
}
return _AT_DISPATCH_INF_TYPES(self.scalar_type(), "isinf", [&]() {
return AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "isinf", [&]() {
return self.abs() == std::numeric_limits<scalar_t>::infinity();
});
}
@ -409,7 +397,7 @@ Tensor isfinite(const Tensor& self) {
return at::isfinite(at::real(self)).__iand__(at::isfinite(at::imag(self)));
}
return _AT_DISPATCH_INF_TYPES(self.scalar_type(), "isfinite", [&]() {
return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "isfinite", [&]() {
return (self == self) * (self.abs() != std::numeric_limits<scalar_t>::infinity());
});
}

View File

@ -61,34 +61,6 @@ void atan2_kernel(TensorIteratorBase& iter) {
});
}
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
kComplexHalf, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kComplexHalf, kHalf, kBool, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
kComplexHalf, kHalf, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
kHalf, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif
void mul_kernel(TensorIteratorBase& iter) {
auto dtype = iter.common_dtype();
if (dtype == ScalarType::Bool) {
@ -113,7 +85,7 @@ void mul_kernel(TensorIteratorBase& iter) {
});
});
} else {
_AT_DISPATCH_MUL_TYPES(dtype, "mul_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, dtype, "mul_cpu", [&]() {
cpu_kernel_vec(iter,
[=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t { return a * b; },
[=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) __ubsan_ignore_undefined__ {
@ -556,14 +528,14 @@ void ge_kernel(TensorIteratorBase& iter) {
void eq_kernel(TensorIteratorBase& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
_AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "eq_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a == b;
});
});
} else {
_AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "eq_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kComplexHalf, kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {
@ -579,14 +551,14 @@ void eq_kernel(TensorIteratorBase& iter) {
void ne_kernel(TensorIteratorBase& iter) {
// See Note [special-case bool outputs]
if (iter.dtype() == ScalarType::Bool) {
_AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "ne_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
cpu_kernel(iter,
[](scalar_t a, scalar_t b) -> bool {
return a != b;
});
});
} else {
_AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "ne_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kComplexHalf, kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t {

View File

@ -263,17 +263,6 @@ void gemm_core_(
}
}
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
kHalf, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif
void cpublas_gemm_impl(
at::ScalarType type,
TransposeType transa, TransposeType transb,
@ -283,7 +272,9 @@ void cpublas_gemm_impl(
const void *b, int64_t ldb,
const Scalar& beta,
void *c, int64_t ldc) {
_AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_impl", [&]{
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::kHalf, at::kBFloat16,
type, "cpublas_gemm_impl",
[&]{
using opmath_t = at::opmath_type<scalar_t>;
gemm_core_(
transa, transb, m, n, k,

View File

@ -164,27 +164,6 @@ static void float_bfloat16_copy_kernel(TensorIteratorBase &iter, bool requires_n
}
}
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, \
ScalarType::BFloat16, ScalarType::Float8_e5m2, ScalarType::Float8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool,ScalarType::BFloat16, \
TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
kBool, kHalf, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif
void direct_copy_kernel(TensorIteratorBase &iter) {
// TODO: we don't actually need separate instantiations per dtype;
// we only need a separate instantiation per dtype size. This would
@ -204,7 +183,8 @@ void direct_copy_kernel(TensorIteratorBase &iter) {
} else if (dtype == ScalarType::ComplexHalf) {
cpu_kernel(iter, [=](c10::complex<at::Half> a) -> c10::complex<at::Half> { return a; });
} else {
_AT_DISPATCH_ALL_TYPES_NO_CF(dtype, "copy_kernel", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kBool, kHalf, kBFloat16, dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return a; },
@ -257,9 +237,9 @@ void copy_kernel(TensorIterator& iter, bool /*non_blocking*/) {
sizeof(BFloat16) == strides_out[0] && (sizeof(float) == strides_in[0] || strides_in[0] == 0)))) {
float_bfloat16_copy_kernel(iter, requires_neg);
} else {
_AT_DISPATCH_ALL_TYPES(dtype, "copy_", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, dtype, "copy_", [&] {
using dest_t = scalar_t;
_AT_DISPATCH_ALL_TYPES(iter.dtype(1), "copy_", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, iter.dtype(1), "copy_", [&] {
if (iter.has_contiguous_first_dim()) {
TORCH_INTERNAL_ASSERT(iter.ninputs() == 1);
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

View File

@ -179,18 +179,6 @@ static void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
});
}
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
kHalf, kBFloat16, \
TYPE, NAME, __VA_ARGS__)
#endif
static void abs_kernel(TensorIteratorBase& iter) {
auto dtype = iter.dtype();
if (dtype == kComplexHalf) {
@ -198,7 +186,7 @@ static void abs_kernel(TensorIteratorBase& iter) {
using opmath_t = at::opmath_type<scalar_t>;
cpu_kernel(iter, [=](scalar_t a) -> scalar_t { return abs_impl(opmath_t{a}); });
} else {
_AT_DISPATCH_ABS_TYPES(iter.dtype(), "abs_cpu", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "abs_cpu", [&]() {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return abs_impl(a); },

View File

@ -186,12 +186,6 @@ template <> inline std::string typeName<at::Half>(){
template <> inline std::string typeName<at::BFloat16>(){
return "at::BFloat16";
}
template <> inline std::string typeName<at::Float8_e5m2>(){
return "at::Float8_e5m2";
}
template <> inline std::string typeName<at::Float8_e4m3fn>(){
return "at::Float8_e4m3fn";
}
#define TYPE_NAME_CASE(ctype, scalartype) \
case ScalarType::scalartype: return typeName<ctype>();

View File

@ -48,13 +48,7 @@ class C10_API Scalar {
#define DEFINE_IMPLICIT_CTOR(type, name) \
Scalar(type vv) : Scalar(vv, true) {}
AT_FORALL_SCALAR_TYPES_AND5(
Half,
BFloat16,
Float8_e5m2,
Float8_e4m3fn,
ComplexHalf,
DEFINE_IMPLICIT_CTOR)
AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR)
AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
#undef DEFINE_IMPLICIT_CTOR

View File

@ -3,8 +3,6 @@
#include <c10/util/BFloat16.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Half.h>
#include <c10/util/bits.h>
#include <c10/util/complex.h>
@ -52,9 +50,7 @@ namespace c10 {
_(c10::bits2x4, Bits2x4) /* 19 */ \
_(c10::bits4x2, Bits4x2) /* 20 */ \
_(c10::bits8, Bits8) /* 21 */ \
_(c10::bits16, Bits16) /* 22 */ \
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */
_(c10::bits16, Bits16) /* 22 */
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
@ -71,9 +67,7 @@ namespace c10 {
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn)
_(at::BFloat16, BFloat16)
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
_(uint8_t, Byte) \
@ -88,9 +82,7 @@ namespace c10 {
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn)
_(at::BFloat16, BFloat16)
enum class ScalarType : int8_t {
#define DEFINE_ENUM(_1, n) n,
@ -209,53 +201,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3)
#define AT_FORALL_SCALAR_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE4>::t), \
SCALARTYPE4)
#define AT_FORALL_SCALAR_TYPES_AND5( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE4>::t), \
SCALARTYPE4) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE5>::t), \
SCALARTYPE5)
#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8) \
_(c10::quint8, QUInt8) \
@ -316,8 +261,7 @@ static inline bool isIntegralType(ScalarType t) {
static inline bool isFloatingType(ScalarType t) {
return (
t == ScalarType::Double || t == ScalarType::Float ||
t == ScalarType::Half || t == ScalarType::BFloat16 ||
t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn);
t == ScalarType::Half || t == ScalarType::BFloat16);
}
static inline bool isReducedFloatingType(ScalarType t) {
@ -390,8 +334,7 @@ static inline bool isSignedType(ScalarType t) {
case ScalarType::ComplexFloat:
case ScalarType::ComplexDouble:
return true;
AT_FORALL_SCALAR_TYPES_AND5(
Half, Bool, BFloat16, Float8_e5m2, Float8_e4m3fn, CASE_SIGNED)
AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
default:
TORCH_CHECK(false, "Unknown ScalarType");
}
@ -482,8 +425,6 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
constexpr auto c8 = ScalarType::ComplexDouble;
constexpr auto b1 = ScalarType::Bool;
constexpr auto bf = ScalarType::BFloat16;
constexpr auto b8 = ScalarType::Float8_e5m2;
constexpr auto h8 = ScalarType::Float8_e4m3fn;
constexpr auto ud = ScalarType::Undefined;
if (a == ud || b == ud) {
return ScalarType::Undefined;
@ -521,25 +462,23 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
// clang-format off
static constexpr ScalarType _promoteTypesLookup[
NUM_PROMOTE_TYPES][NUM_PROMOTE_TYPES] = {
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf b8 h8*/
/* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf, b8, h8},
/* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf, b8, h8},
/* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf, b8, h8},
/* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf, b8, h8},
/* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf, b8, h8},
/* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4, f4, f4},
/* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4, f4, f4},
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8, f8, f8},
/* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4, c4, c4},
/* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4, c4, c4},
/* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8, c8, c8},
/* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf, b8, h8},
/* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf, bf, bf},
/* b8 */ {b8, b8, b8, b8, b8, f4, f4, f8, c4, c4, c8, b8, ud, ud, ud, bf, b8, ud},
/* h8 */ {h8, h8, h8, h8, h8, f4, f4, f8, c4, c4, c8, h8, ud, ud, ud, bf, ud, h8},
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf*/
/* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf},
/* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf},
/* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf},
/* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf},
/* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf},
/* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4},
/* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4},
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8},
/* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4},
/* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
/* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
/* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf},
/* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
/* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf},
};
// clang-format on
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
@ -551,4 +490,8 @@ inline std::ostream& operator<<(
return stream << toString(scalar_type);
}
#define AT_FORAUTOCAST_SCALAR_TYPES(_) \
_(half, Half) /* 0 */ \
_(bfloat16, BFloat16) /* 1 */
} // namespace c10

View File

@ -1,275 +0,0 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cstring>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value) {
x = detail::fp8e4m3fn_from_fp32_value(value);
}
/// Implicit conversions
inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const {
return detail::fp8e4m3fn_to_fp32_value(x);
}
/// Special values helper
inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const {
return (x & 0b01111111) == 0b01111111;
}
/// Arithmetic
inline C10_HOST_DEVICE Float8_e4m3fn
operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn
operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn
operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(
const Float8_e4m3fn& a,
const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE Float8_e4m3fn& operator+=(
Float8_e4m3fn& a,
const Float8_e4m3fn& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fn& operator-=(
Float8_e4m3fn& a,
const Float8_e4m3fn& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fn& operator*=(
Float8_e4m3fn& a,
const Float8_e4m3fn& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fn& operator/=(
Float8_e4m3fn& a,
const Float8_e4m3fn& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) {
return a + static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) {
return a - static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) {
return a * static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) {
return a / static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) + b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) - b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) * b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) {
return a + static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) {
return a - static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) {
return a * static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) {
return a / static_cast<Float8_e4m3fn>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) + b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) - b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) * b;
}
inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) {
return static_cast<Float8_e4m3fn>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e4m3fn to float.
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Float8_e4m3fn> {
public:
static constexpr bool is_specialized = true;
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = false;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = false;
static constexpr auto has_denorm = true;
static constexpr auto has_denorm_loss = true;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 4;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 3;
static constexpr int radix = 2;
static constexpr int min_exponent = -5;
static constexpr int min_exponent10 = -1;
static constexpr int max_exponent = 8;
static constexpr int max_exponent10 = 2;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before = false;
static constexpr c10::Float8_e4m3fn min() {
return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn lowest() {
return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn max() {
return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn epsilon() {
return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn round_error() {
return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn quiet_NaN() {
return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits());
}
static constexpr c10::Float8_e4m3fn denorm_min() {
return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()

View File

@ -1,14 +0,0 @@
#include <c10/util/Float8_e4m3fn.h>
#include <iostream>
namespace c10 {
static_assert(
std::is_standard_layout<Float8_e4m3fn>::value,
"c10::Float8_e4m3fn must be standard layout.");
std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value) {
out << (float)value;
return out;
}
} // namespace c10

View File

@ -1,240 +0,0 @@
#pragma once
/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions
/// to standard C types and basic arithmetic operations. Note that arithmetic
/// operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration:
/// s eeee mmm
/// 1 sign bit
/// 4 exponent bits
/// 3 mantissa bits
/// bias = 7
///
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
/// and inspired by Half implementation from pytorch/c10/util/Half.h
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/TypeSafeSignMath.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>
#if defined(__cplusplus) && (__cplusplus >= 201103L)
#include <cmath>
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif
#ifdef _MSC_VER
#include <intrin.h>
#endif
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <typeinfo> // operator typeid
namespace c10 {
namespace detail {
/*
* Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, in bit representation.
*
* @note The implementation doesn't use any floating-point operations.
*/
inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
/*
* Extend the fp8 E4M3FN number to 32 bits and shift to the
* upper part of the 32-bit word:
* +---+----+---+-----------------------------+
* | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
* +---+----+---+-----------------------------+
* Bits 31 27-30 24-26 0-23
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
const uint32_t w = (uint32_t)input << 24;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = w & UINT32_C(0x80000000);
/*
* Extract mantissa and biased exponent of the input number into the bits 0-30
* of the 32-bit word:
*
* +---+----+---+-----------------------------+
* | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
* +---+----+---+-----------------------------+
* Bits 31 27-30 24-26 0-23
*/
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
/*
* Renorm shift is the number of bits to shift mantissa left to make the
* half-precision number normalized. If the initial number is normalized, some
* of its high 5 bits (sign == 0 and 4-bit exponent) equals one. In this case
* renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note
* that if we shift denormalized nonsign by renorm_shift, the unit bit of
* mantissa will shift into exponent, turning the biased exponent into 1, and
* making mantissa normalized (i.e. without leading 1).
*/
#if defined(__CUDA_ARCH__)
uint32_t renorm_shift = __clz(nonsign);
#elif defined(_MSC_VER)
unsigned long nonsign_bsr;
_BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
#else
uint32_t renorm_shift = __builtin_clz(nonsign);
#endif
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
/*
* Iff fp8e4m3fn number has all exponent and mantissa bits set to 1,
* the addition overflows it into bit 31, and the subsequent shift turns the
* high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn number
* is Nan, 0x00000000 otherwise
*/
const int32_t inf_nan_mask =
((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000);
/*
* Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
* into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
* broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
* 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h)
* 0x00000000 otherwise
*/
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
/*
* 1. Shift nonsign left by renorm_shift to normalize it (if the input
* was denormal)
* 2. Shift nonsign right by 4 so the exponent (4 bits originally)
* becomes an 8-bit field and 3-bit mantissa shifts into the 3 high
* bits of the 23-bit mantissa of IEEE single-precision number.
* 3. Add 0x78 to the exponent (starting at bit 23) to compensate the
* different in exponent bias (0x7F for single-precision number less 0x07
* for fp8e4m3fn number).
* 4. Subtract renorm_shift from the exponent (starting at bit 23) to
* account for renormalization. As renorm_shift is less than 0x78, this
* can be combined with step 3.
* 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
* input was NaN or infinity.
* 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
* into zero if the input was zero.
* 7. Combine with the sign of the input number.
*/
uint32_t result = sign |
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
inf_nan_mask) &
~zero_mask);
return fp32_from_bits(result);
}
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 8-bit floating-point number in fp8 E4M3FN format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) {
/*
* Binary representation of 480.0f, which is the first value
* not representable in fp8e4m3fn range:
* 0 1111 111 - fp8e4m3fn
* 0 10000111 11100000000000000000000 - fp32
*/
constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
/*
* A mask for converting fp32 numbers lower than fp8e4m3fn normal range
* into denorm representation
* magic number: ((127 - 7) + (23 - 3) + 1)
*/
constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
uint32_t f_bits = fp32_to_bits(f);
uint8_t result = 0u;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = f_bits & UINT32_C(0x80000000);
/*
* Set sign bit to 0
*/
f_bits ^= sign;
if (f_bits >= fp8_max) {
// NaN - all exponent and mantissa bits set to 1
result = 0x7f;
} else {
if (f_bits < (UINT32_C(121) << 23)) {
// Input number is smaller than 2^(-6), which is the smallest
// fp8e4m3fn normal number
f_bits =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
// resulting mantissa is odd
uint8_t mant_odd = (f_bits >> 20) & 1;
// update exponent, rounding bias part 1
f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
// rounding bias part 2
f_bits += mant_odd;
// take the bits!
result = static_cast<uint8_t>(f_bits >> 20);
}
}
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
} // namespace detail
struct alignas(1) Float8_e4m3fn {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e4m3fn() = default;
constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t)
: x(bits){};
inline C10_HOST_DEVICE Float8_e4m3fn(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
};
C10_API std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value);
} // namespace c10
#include <c10/util/Float8_e4m3fn-inl.h> // IWYU pragma: keep

View File

@ -1,284 +0,0 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cstring>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif
#define EXP_WIDTH_FP8 5
#define MAN_WIDTH_FP8 2
#define EXP_BIAS_FP8 15
namespace c10 {
/// Constructors
inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value) {
x = detail::fp8e5m2_from_fp32_value(value);
}
/// Implicit conversions
inline C10_HOST_DEVICE Float8_e5m2::operator float() const {
return detail::fp8e5m2_to_fp32_value(x);
}
/// Special values helpers
inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const {
return (x & 0b01111111) > 0b01111100;
}
inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const {
return (x & 0b01111111) == 0b01111100;
}
/// Arithmetic
inline C10_HOST_DEVICE Float8_e5m2
operator+(const Float8_e5m2& a, const Float8_e5m2& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2
operator-(const Float8_e5m2& a, const Float8_e5m2& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2
operator*(const Float8_e5m2& a, const Float8_e5m2& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(
const Float8_e5m2& a,
const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE Float8_e5m2& operator+=(
Float8_e5m2& a,
const Float8_e5m2& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2& operator-=(
Float8_e5m2& a,
const Float8_e5m2& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2& operator*=(
Float8_e5m2& a,
const Float8_e5m2& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2& operator/=(
Float8_e5m2& a,
const Float8_e5m2& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) {
return a + static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) {
return a - static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) {
return a * static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) {
return a / static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) + b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) - b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) * b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) {
return a + static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) {
return a - static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) {
return a * static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) {
return a / static_cast<Float8_e5m2>(b);
}
inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) + b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) - b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) * b;
}
inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) {
return static_cast<Float8_e5m2>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e5m2 to float.
} // namespace c10
namespace std {
template <>
class numeric_limits<c10::Float8_e5m2> {
public:
static constexpr bool is_signed = true;
static constexpr bool is_integer = false;
static constexpr bool is_specialized = true;
static constexpr bool is_exact = false;
static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = false;
static constexpr bool has_signaling_NaN = false;
static constexpr auto has_denorm = true;
static constexpr auto has_denorm_loss = true;
static constexpr auto round_style = numeric_limits<float>::round_style;
static constexpr bool is_iec559 = false;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;
static constexpr int digits = 3;
static constexpr int digits10 = 0;
static constexpr int max_digits10 = 2;
static constexpr int radix = 2;
static constexpr int min_exponent = -13;
static constexpr int min_exponent10 = -4;
static constexpr int max_exponent = 16;
static constexpr int max_exponent10 = 4;
static constexpr auto traps = numeric_limits<float>::traps;
static constexpr auto tinyness_before =
numeric_limits<float>::tinyness_before;
static constexpr c10::Float8_e5m2 min() {
return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 max() {
return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 lowest() {
return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 epsilon() {
return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 round_error() {
return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 infinity() {
return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
}
static constexpr c10::Float8_e5m2 denorm_min() {
return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
}
};
} // namespace std
C10_CLANG_DIAGNOSTIC_POP()

View File

@ -1,14 +0,0 @@
#include <c10/util/Float8_e5m2.h>
#include <iostream>
namespace c10 {
static_assert(
std::is_standard_layout<Float8_e5m2>::value,
"c10::Float8_e5m2 must be standard layout.");
std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value) {
out << (float)value;
return out;
}
} // namespace c10

View File

@ -1,143 +0,0 @@
#pragma once
/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions
/// to standard C types and basic arithmetic operations. Note that arithmetic
/// operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration:
/// s eeeee mm
/// 1 sign bit
/// 5 exponent bits
/// 2 mantissa bits
/// bias = 15
///
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
/// and inspired by Half implementation from pytorch/c10/util/Half.h
#include <c10/util/Half.h>
namespace c10 {
namespace detail {
/*
* Convert a 8-bit floating-point number in fp8 E5M2 format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, in bit representation.
*
* @note The implementation doesn't use any floating-point operations.
*/
inline C10_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) {
/*
* Extend the fp8 E5M2 number to 32 bits and shift to the
* upper part of the 32-bit word:
* +---+----+---+-----------------------------+
* | S |EEEEE|MM|0000 0000 0000 0000 0000 0000|
* +---+----+---+-----------------------------+
* Bits 31 26-30 24-25 0-23
*
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
* - zero bits.
*/
uint16_t half_representation = input;
half_representation <<= 8;
return fp16_ieee_to_fp32_value(half_representation);
}
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 8-bit floating-point number in fp8 E5M2 format, in bit representation.
*/
inline C10_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) {
/*
* Binary representation of fp32 infinity
* 0 11111111 00000000000000000000000
*/
constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
/*
* Binary representation of 65536.0f, which is the first value
* not representable in fp8e5m2 range:
* 0 11111 00 - fp8e5m2
* 0 10001111 00000000000000000000000 - fp32
*/
constexpr uint32_t fp8_max = UINT32_C(143) << 23;
/*
* A mask for converting fp32 numbers lower than fp8e5m2 normal range
* into denorm representation
* magic number: ((127 - 15) + (23 - 2) + 1)
*/
constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
uint32_t f_bits = fp32_to_bits(f);
uint8_t result = 0u;
/*
* Extract the sign of the input number into the high bit of the 32-bit word:
*
* +---+----------------------------------+
* | S |0000000 00000000 00000000 00000000|
* +---+----------------------------------+
* Bits 31 0-31
*/
const uint32_t sign = f_bits & UINT32_C(0x80000000);
/*
* Set sign bit to 0
*/
f_bits ^= sign;
if (f_bits >= fp8_max) {
// NaN - all exponent and mantissa bits set to 1
result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
} else {
if (f_bits < (UINT32_C(113) << 23)) {
// Input number is smaller than 2^(-14), which is the smallest
// fp8e5m2 normal number
f_bits =
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
result = static_cast<uint8_t>(f_bits - denorm_mask);
} else {
// resulting mantissa is odd
uint32_t mant_odd = (f_bits >> 21) & 1;
// update exponent, rounding bias part 1
f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
// rounding bias part 2
f_bits += mant_odd;
// take the bits!
result = static_cast<uint8_t>(f_bits >> 21);
}
}
result |= static_cast<uint8_t>(sign >> 24);
return result;
}
} // namespace detail
struct alignas(1) Float8_e5m2 {
uint8_t x;
struct from_bits_t {};
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
Float8_e5m2() = default;
constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits){};
inline C10_HOST_DEVICE Float8_e5m2(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
inline C10_HOST_DEVICE bool isinf() const;
};
C10_API std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value);
} // namespace c10
#include <c10/util/Float8_e5m2-inl.h> // IWYU pragma: keep

View File

@ -13,7 +13,6 @@
#include <c10/util/C++17.h>
#include <c10/util/TypeSafeSignMath.h>
#include <c10/util/complex.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>
#if defined(__cplusplus) && (__cplusplus >= 201103L)
@ -52,12 +51,51 @@
#include <sycl/sycl.hpp> // for SYCL 2020
#endif
// Standard check for compiling CUDA with clang
#if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
#define C10_DEVICE_HOST_FUNCTION __device__ __host__
#else
#define C10_DEVICE_HOST_FUNCTION
#endif
#include <typeinfo> // operator typeid
namespace c10 {
namespace detail {
C10_DEVICE_HOST_FUNCTION inline float fp32_from_bits(uint32_t w) {
#if defined(__OPENCL_VERSION__)
return as_float(w);
#elif defined(__CUDA_ARCH__)
return __uint_as_float((unsigned int)w);
#elif defined(__INTEL_COMPILER)
return _castu32_f32(w);
#else
union {
uint32_t as_bits;
float as_value;
} fp32 = {w};
return fp32.as_value;
#endif
}
C10_DEVICE_HOST_FUNCTION inline uint32_t fp32_to_bits(float f) {
#if defined(__OPENCL_VERSION__)
return as_uint(f);
#elif defined(__CUDA_ARCH__)
return (uint32_t)__float_as_uint(f);
#elif defined(__INTEL_COMPILER)
return _castf32_u32(f);
#else
union {
float as_value;
uint32_t as_bits;
} fp32 = {f};
return fp32.as_bits;
#endif
}
/*
* Convert a 16-bit floating-point number in IEEE half-precision format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
@ -163,7 +201,7 @@ inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
* mode and no operations on denormals) floating-point operations and bitcasts
* between integer and floating-point variables.
*/
C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
inline float fp16_ieee_to_fp32_value(uint16_t h) {
/*
* Extend the half-precision floating-point number to 32 bits and shift to the
* upper part of the 32-bit word:

View File

@ -1,8 +1,6 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Half.h>
#include <type_traits>
@ -79,26 +77,6 @@ struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::BFloat16> {
}
};
template <>
struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Float8_e5m2> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
c10::Half>
apply(c10::Float8_e5m2 src) {
return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
}
};
template <>
struct static_cast_with_inter_type<
c10::complex<c10::Half>,
c10::Float8_e4m3fn> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
c10::Half>
apply(c10::Float8_e4m3fn src) {
return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
}
};
template <>
struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Half> {
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<

View File

@ -1,39 +0,0 @@
#pragma once
#include <cstdint>
namespace c10::detail {
C10_HOST_DEVICE inline float fp32_from_bits(uint32_t w) {
#if defined(__OPENCL_VERSION__)
return as_float(w);
#elif defined(__CUDA_ARCH__)
return __uint_as_float((unsigned int)w);
#elif defined(__INTEL_COMPILER)
return _castu32_f32(w);
#else
union {
uint32_t as_bits;
float as_value;
} fp32 = {w};
return fp32.as_value;
#endif
}
C10_HOST_DEVICE inline uint32_t fp32_to_bits(float f) {
#if defined(__OPENCL_VERSION__)
return as_uint(f);
#elif defined(__CUDA_ARCH__)
return (uint32_t)__float_as_uint(f);
#elif defined(__INTEL_COMPILER)
return _castf32_u32(f);
#else
union {
float as_value;
uint32_t as_bits;
} fp32 = {f};
return fp32.as_bits;
#endif
}
} // namespace c10::detail

View File

@ -1,154 +0,0 @@
# Owner(s): ["oncall: quantization"]
import torch
from torch.testing._internal.common_utils import (
TestCase,
parametrize,
instantiate_parametrized_tests,
run_tests,
)
# Masks for float8 simulation
# 0 11111111 11000000000000000000000b
MASK_152 = torch.tensor(2145386496, dtype=torch.int)
# 0 11111111 11100000000000000000000b
MASK_143 = torch.tensor(2146435072, dtype=torch.int)
MASK = {
torch.float8_e5m2: MASK_152,
torch.float8_e4m3fn: MASK_143,
}
# 0 00000000 00011111111111111111111b
MASK_ROUND_152 = torch.tensor(1048575, dtype=torch.int)
# 0 00000000 00001111111111111111111b
MASK_ROUND_143 = torch.tensor(524287, dtype=torch.int)
MASK_ROUND = {
torch.float8_e5m2: MASK_ROUND_152,
torch.float8_e4m3fn: MASK_ROUND_143,
}
FP8_MAX_152 = torch.tensor(57344, dtype=torch.float)
FP8_MAX_143 = torch.tensor(448, dtype=torch.float)
FP8_MAX = {torch.float8_e5m2: FP8_MAX_152, torch.float8_e4m3fn: FP8_MAX_143}
SPECIAL_NUMBERS = {
torch.float8_e5m2: [
("01111100", float("inf"), "inf"),
("11111100", -1.0 * float("inf"), "neg_inf"),
("01111101", float("nan"), "nan"),
("11111101", float("nan"), "nan"),
("01111110", float("nan"), "nan"),
("11111110", float("nan"), "nan"),
("01111111", float("nan"), "nan"),
("11111111", float("nan"), "nan"),
("00000000", 0.0, "zero"),
("10000000", -0.0, "neg_zero"),
("01111011", 57344.0, "max_normal"),
("11111011", -57344.0, "neg_max_normal"),
("00000100", 2**-14, "min_normal"),
("10000100", -1 * (2**-14), "neg_min_normal"),
("00000011", 0.75 * (2**-14), "max_subnorm"),
("10000011", -0.75 * (2**-14), "neg_max_subnorm"),
("00000001", 2**-16, "min_subnorm"),
("10000001", -1 * (2**-16), "neg_min_subnorm"),
],
torch.float8_e4m3fn: [
("01111111", float("nan"), "nan"),
("11111111", float("nan"), "nan"),
("00000000", 0.0, "zero"),
("10000000", -0.0, "neg_zero"),
("01111110", 448.0, "max_normal"),
("11111110", -448.0, "neg_max_normal"),
("00001000", 2**-6, "min_normal"),
("10001000", -1 * (2**-6), "neg_min_normal"),
("00000111", 0.875 * (2**-6), "max_subnorm"),
("10000111", -0.875 * (2**-6), "neg_max_subnorm"),
("00000001", 2**-9, "min_subnorm"),
("10000001", -1 * (2**-9), "neg_min_subnorm"),
],
}
def simulateFp8Precision(input, variant):
dtype = torch.float
int_type = torch.int
mask = MASK[variant]
mask_round = MASK_ROUND[variant]
excessive_bits = torch.tensor(21, dtype=int_type)
signs = torch.where(input < 0.0, -1.0, 1.0).to(dtype)
asInt = torch.bitwise_and(input.view(int_type), 2147483647)
mant_odd = torch.bitwise_and(
torch.bitwise_right_shift(asInt, excessive_bits),
torch.tensor(1, dtype=int_type),
)
asInt_masked = asInt + mask_round
asInt_odded = asInt_masked + mant_odd
masked = torch.bitwise_and(asInt_odded, mask)
return masked.view(dtype) * signs
class TestFloat8Dtype(TestCase):
"""
Sanity test for zeros comparison
"""
@parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
def test_creation_with_zeros(self, dtype):
x = torch.zeros(8, dtype=torch.float)
x8 = torch.zeros(8, dtype=dtype)
self.assertEqual(x, x8.float())
"""
Numerical test of float8 conversion
"""
@parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
def test_cast_to_float8(self, dtype):
x = torch.rand((100, 100)) * FP8_MAX[dtype]
x = torch.cat((x, -x))
x8 = x.to(dtype)
x8_simulated = simulateFp8Precision(x, dtype)
self.assertEqual(x8_simulated, x8.float())
"""
Test of mul implementation
"""
@parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
def test_mul(self, dtype):
shape = (10, 10)
a = torch.randn(shape)
a8_simulated = simulateFp8Precision(a, dtype)
a8 = a.to(dtype)
b = torch.randn(shape)
b8_simulated = simulateFp8Precision(b, dtype)
b8 = b.to(dtype)
mul8 = a8 * b8
mul8_simulated = (a8_simulated * b8_simulated).to(dtype)
self.assertEqual(mul8, mul8_simulated)
"""
Test special numbers
"""
@parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
def test_special_numbers(self, dtype):
def compare_binary_with_decimal(binary, decimal, number_name, dtype):
bits_int = int(binary, 2)
tensor_int = torch.tensor([bits_int], dtype=torch.uint8)
tensor_fp8 = tensor_int.view(dtype)
if number_name == "nan":
assert tensor_fp8.isnan()
else:
tensor_fp32 = tensor_fp8.float()
ref_tensor_fp32 = torch.tensor([decimal], dtype=torch.float)
self.assertEqual(tensor_fp32, ref_tensor_fp32)
for number in SPECIAL_NUMBERS[dtype]:
compare_binary_with_decimal(*number, dtype)
instantiate_parametrized_tests(TestFloat8Dtype)
if __name__ == "__main__":
run_tests()

View File

@ -4,8 +4,6 @@
{ include: [ "<c10/util/BFloat16-inl.h>", private, "<c10/util/BFloat16.h>", public ] },
{ include: [ "<c10/util/Half-inl.h>", private, "<c10/util/Half.h>", public ] },
{ include: [ "<c10/util/Float8_e5m2-inl.h>", private, "<c10/util/Float8_e5m2.h>", public ] },
{ include: [ "<c10/util/Float8_e4m3fn-inl.h>", private, "<c10/util/Float8_e4m3fn.h>", public ] },
{ include: [ "<c10/util/complex_math.h>", private, "<c10/util/complex.h>", public ] },
{ include: [ "<c10/util/complex_utils.h>", private, "<c10/util/complex.h>", public ] },

View File

@ -330,12 +330,7 @@ def _tensor_str(self, indent):
if self.is_neg():
self = self.resolve_neg()
if self.dtype in [
torch.float16,
torch.bfloat16,
torch.float8_e5m2,
torch.float8_e4m3fn,
]:
if self.dtype is torch.float16 or self.dtype is torch.bfloat16:
self = self.float()
if self.dtype is torch.complex32:

View File

@ -213,18 +213,15 @@ static PyObject* THPStorage_fromBuffer(
auto dtype = reinterpret_cast<THPDtype*>(dtype_obj);
scalar_type = dtype->scalar_type;
const bool is_endian_independent = (scalar_type == at::kByte) ||
(scalar_type == at::kChar) || (scalar_type == at::kFloat8_e5m2) ||
(scalar_type == at::kFloat8_e4m3fn);
TORCH_CHECK(
is_endian_independent || (byte_order_str != nullptr),
(scalar_type == at::kByte) || (scalar_type == at::kChar) ||
(byte_order_str != nullptr),
"function missing required argument 'byte_order' (pos 2)");
size_t element_size = c10::elementSize(scalar_type);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool do_byte_swap;
if (!is_endian_independent) {
if (scalar_type != at::kByte && scalar_type != at::kChar) {
if (strcmp(byte_order_str, "native") == 0) {
do_byte_swap = false;
} else if (strcmp(byte_order_str, "big") == 0) {
@ -295,7 +292,7 @@ static PyObject* THPStorage_fromBuffer(
c10::GetDefaultCPUAllocator(),
/*resizable=*/true);
if (is_endian_independent) {
if (scalar_type == at::kByte || scalar_type == at::kChar) {
memcpy(storage->mutable_data(), src + offset, count);
} else if (scalar_type == at::kBool) {
// Because of ASAN checks, that are failing whenever

View File

@ -14,13 +14,7 @@ Dtype Dtype::scalar_dtype() const {
// NOLINTNEXTLINE
#define DTYPE_DEFINE(_1, n) TORCH_API Dtype k##n(ScalarType::n, 1);
AT_FORALL_SCALAR_TYPES_AND5(
Bool,
Half,
BFloat16,
Float8_e5m2,
Float8_e4m3fn,
DTYPE_DEFINE)
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DTYPE_DEFINE)
DTYPE_DEFINE(c10::quint8, QUInt8);
DTYPE_DEFINE(c10::qint8, QInt8);
@ -34,8 +28,7 @@ Dtype ToDtype(ScalarType type) {
#define TYPE_CASE(_1, n) \
case ScalarType::n: \
return k##n;
AT_FORALL_SCALAR_TYPES_AND5(
Bool, Half, BFloat16, Float8_e5m2, Float8_e4m3fn, TYPE_CASE)
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
TYPE_CASE(c10::quint8, QUInt8);
TYPE_CASE(c10::qint8, QInt8);
#undef TYPE_CASE
@ -65,8 +58,7 @@ int Dtype::byte_size() const {
scalar_size = sizeof(Type); \
break;
AT_FORALL_SCALAR_TYPES_AND5(
Bool, Half, BFloat16, Float8_e5m2, Float8_e4m3fn, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
TYPE_CASE(c10::quint8, QUInt8);
TYPE_CASE(c10::qint8, QInt8);
#undef TYPE_CASE
@ -91,10 +83,6 @@ std::string Dtype::ToCppString() const {
return "half";
case ScalarType::BFloat16:
return "bfloat16";
case ScalarType::Float8_e5m2:
return "float8_e5m2";
case ScalarType::Float8_e4m3fn:
return "float8_e4m3fn";
case ScalarType::QInt8:
return "qint8";
case ScalarType::QUInt8:

View File

@ -1,8 +1,6 @@
#pragma once
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Half.h>
#include <torch/csrc/Export.h>
#include <cstddef>
@ -156,14 +154,6 @@ TORCH_API void THP_decodeBFloat16Buffer(
const uint8_t* src,
THPByteOrder order,
size_t len);
TORCH_API void THP_decodeFloat8_e5m2Buffer(
at::Float8_e5m2* dst,
const uint8_t* src,
size_t len);
TORCH_API void THP_decodeFloat8_e4m3fnBuffer(
at::Float8_e4m3fn* dst,
const uint8_t* src,
size_t len);
TORCH_API void THP_decodeComplexFloatBuffer(
c10::complex<float>* dst,
const uint8_t* src,

View File

@ -70,14 +70,6 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
*(at::BFloat16*)data =
at::convert<at::BFloat16, double>(THPUtils_unpackDouble(obj));
break;
case at::kFloat8_e5m2:
*(at::Float8_e5m2*)data =
at::convert<at::Float8_e5m2, double>(THPUtils_unpackDouble(obj));
break;
case at::kFloat8_e4m3fn:
*(at::Float8_e4m3fn*)data =
at::convert<at::Float8_e4m3fn, double>(THPUtils_unpackDouble(obj));
break;
default:
throw std::runtime_error("invalid type");
}
@ -118,12 +110,6 @@ inline PyObject* load_scalar(void* data, at::ScalarType scalarType) {
case at::kBFloat16:
return PyFloat_FromDouble(
at::convert<double, at::BFloat16>(*(at::BFloat16*)data));
case at::kFloat8_e5m2:
return PyFloat_FromDouble(
at::convert<double, at::Float8_e5m2>(*(at::Float8_e5m2*)data));
case at::kFloat8_e4m3fn:
return PyFloat_FromDouble(
at::convert<double, at::Float8_e4m3fn>(*(at::Float8_e4m3fn*)data));
default:
throw std::runtime_error("invalid type");
}

View File

@ -62,10 +62,6 @@ std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
return std::make_pair("bits8", "");
case at::ScalarType::Bits16:
return std::make_pair("bits16", "");
case at::ScalarType::Float8_e5m2:
return std::make_pair("float8_e5m2", "");
case at::ScalarType::Float8_e4m3fn:
return std::make_pair("float8_e4m3fn", "");
default:
throw std::runtime_error("Unimplemented scalar type");
}

View File

@ -47,8 +47,6 @@ complexHalfT = BaseCppType(
complexFloatT = BaseCppType("c10", "complex<float>")
complexDoubleT = BaseCppType("c10", "complex<double>")
bfloat16T = BaseCppType("at", "BFloat16")
float8_e5m2T = BaseCppType("at", "Float8_e5m2")
float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
stringT = BaseCppType("c10", "string_view")
generatorT = BaseCppType("at", "Generator")
scalarTypeT = BaseCppType("at", "ScalarType")
@ -95,8 +93,7 @@ ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = {
ScalarType.ComplexFloat: complexFloatT,
ScalarType.ComplexDouble: complexDoubleT,
ScalarType.Bool: boolT,
ScalarType.Float8_e5m2: float8_e5m2T,
ScalarType.Float8_e4m3fn: float8_e4m3fnT,
ScalarType.BFloat16: bfloat16T,
}
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {

View File

@ -298,8 +298,6 @@ class ScalarType(Enum):
ComplexDouble = auto()
Bool = auto()
BFloat16 = auto()
Float8_e5m2 = auto()
Float8_e4m3fn = auto()
def __str__(self) -> str:
return self.name