mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add torch.float8_e5m2 and torch.float8_e4m3 data types (#104242)
Proposal of two float8 variants, e5m2 and e4m3, based on https://arxiv.org/pdf/2209.05433.pdf. Hide all Float8 operator implementations behind an `#if !defined(C10_MOBILE)` guard to keep the Android build size almost unchanged. TODO: - Refactor duplicated code - Clean up unbalanced pragma pop in dtype utils - Add native implementation on the CUDA side. Co-authored-by: Nikita Shulga <nshulga@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/104242 Approved by: https://github.com/albanD
This commit is contained in:
parent
803d58a408
commit
b64bd4a5dd
|
|
@ -2,6 +2,8 @@
|
|||
#include <ATen/Config.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <c10/util/Float8_e5m2.h>
|
||||
#include <c10/util/Half.h>
|
||||
|
||||
// Defines the accumulation type for a scalar type.
|
||||
|
|
@ -67,6 +69,14 @@ struct AccumulateType<Half, true> {
|
|||
using type = float;
|
||||
};
|
||||
// Accumulation type for Float8_e5m2 when the boolean template argument is
// true: values are accumulated in float.
// NOTE(review): the meaning of the boolean argument is set by the primary
// template, which is not visible in this hunk.
template <>
|
||||
struct AccumulateType<Float8_e5m2, true> {
|
||||
using type = float;
|
||||
};
|
||||
// Accumulation type for Float8_e4m3fn when the boolean template argument is
// true: values are accumulated in float.
// NOTE(review): the meaning of the boolean argument is set by the primary
// template, which is not visible in this hunk.
template <>
|
||||
struct AccumulateType<Float8_e4m3fn, true> {
|
||||
using type = float;
|
||||
};
|
||||
// Accumulation type for float when the boolean template argument is true:
// float accumulates in float.
template <>
|
||||
struct AccumulateType<float, true> {
|
||||
using type = float;
|
||||
};
|
||||
|
|
@ -111,6 +121,14 @@ struct AccumulateType<BFloat16, false> {
|
|||
using type = float;
|
||||
};
|
||||
// Accumulation type for Float8_e5m2 when the boolean template argument is
// false: values are accumulated in float.
// NOTE(review): the meaning of the boolean argument is set by the primary
// template, which is not visible in this hunk.
template <>
|
||||
struct AccumulateType<Float8_e5m2, false> {
|
||||
using type = float;
|
||||
};
|
||||
// Accumulation type for Float8_e4m3fn when the boolean template argument is
// false: values are accumulated in float.
// NOTE(review): the meaning of the boolean argument is set by the primary
// template, which is not visible in this hunk.
template <>
|
||||
struct AccumulateType<Float8_e4m3fn, false> {
|
||||
using type = float;
|
||||
};
|
||||
// Accumulation type for c10::complex<Half> when the boolean template argument
// is false: values are accumulated in c10::complex<float>.
template <>
|
||||
struct AccumulateType<c10::complex<Half>, false> {
|
||||
using type = c10::complex<float>;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -53,6 +53,10 @@ DLDataType getDLDataType(const Tensor& t) {
|
|||
case ScalarType::BFloat16:
|
||||
dtype.code = DLDataTypeCode::kDLBfloat;
|
||||
break;
|
||||
case ScalarType::Float8_e5m2:
|
||||
case ScalarType::Float8_e4m3fn:
|
||||
TORCH_CHECK(false, "float8 types are not supported by dlpack");
|
||||
break;
|
||||
case ScalarType::QInt8:
|
||||
case ScalarType::QUInt8:
|
||||
case ScalarType::QInt32:
|
||||
|
|
|
|||
|
|
@ -291,6 +291,22 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
|
|||
AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
|
||||
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
|
||||
|
||||
// Case list: expands to the AT_DISPATCH_CASE_FLOATING_TYPES cases followed by
// one AT_DISPATCH_CASE for each of the four extra scalar types. The trailing
// variadic arguments are forwarded unchanged to every case.
#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
|
||||
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
|
||||
AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
|
||||
|
||||
// Dispatch entry point: wraps AT_DISPATCH_CASE_FLOATING_TYPES_AND4 in
// AT_DISPATCH_SWITCH, passing TYPE and NAME through and forwarding the
// trailing variadic arguments to every generated case.
#define AT_DISPATCH_FLOATING_TYPES_AND4( \
|
||||
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
|
||||
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
|
||||
|
||||
// Case list for the complex scalar types: one AT_DISPATCH_CASE each for
// ComplexDouble and ComplexFloat, with the variadic arguments forwarded.
#define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
|
||||
|
|
@ -515,6 +531,73 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
|
|||
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
|
||||
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
|
||||
|
||||
// Case list: expands to the AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX cases
// followed by one AT_DISPATCH_CASE for each of the five extra scalar types.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
|
||||
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
|
||||
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
|
||||
|
||||
// Dispatch entry point: wraps AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5 in
// AT_DISPATCH_SWITCH. TYPE and NAME are passed through; the trailing variadic
// arguments are forwarded to every generated case.
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
|
||||
SCALARTYPE1, \
|
||||
SCALARTYPE2, \
|
||||
SCALARTYPE3, \
|
||||
SCALARTYPE4, \
|
||||
SCALARTYPE5, \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
|
||||
SCALARTYPE1, \
|
||||
SCALARTYPE2, \
|
||||
SCALARTYPE3, \
|
||||
SCALARTYPE4, \
|
||||
SCALARTYPE5, \
|
||||
__VA_ARGS__))
|
||||
|
||||
// Case list: expands to the AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX cases
// followed by one AT_DISPATCH_CASE for each of the six extra scalar types.
#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
|
||||
SCALARTYPE1, \
|
||||
SCALARTYPE2, \
|
||||
SCALARTYPE3, \
|
||||
SCALARTYPE4, \
|
||||
SCALARTYPE5, \
|
||||
SCALARTYPE6, \
|
||||
...) \
|
||||
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
|
||||
|
||||
// Dispatch entry point: wraps AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6 in
// AT_DISPATCH_SWITCH. TYPE and NAME are passed through; the trailing variadic
// arguments are forwarded to every generated case.
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
|
||||
SCALARTYPE1, \
|
||||
SCALARTYPE2, \
|
||||
SCALARTYPE3, \
|
||||
SCALARTYPE4, \
|
||||
SCALARTYPE5, \
|
||||
SCALARTYPE6, \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
|
||||
SCALARTYPE1, \
|
||||
SCALARTYPE2, \
|
||||
SCALARTYPE3, \
|
||||
SCALARTYPE4, \
|
||||
SCALARTYPE5, \
|
||||
SCALARTYPE6, \
|
||||
__VA_ARGS__))
|
||||
|
||||
#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, \
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@
|
|||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <c10/util/Float8_e5m2.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
||||
|
|
@ -62,6 +64,22 @@ inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
|
|||
return at::_isnan(static_cast<float>(val));
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename std::enable_if<std::is_same<T, at::Float8_e5m2>::value, int>::
|
||||
type = 0>
|
||||
inline C10_HOST_DEVICE bool _isnan(T val) {
|
||||
return val.isnan();
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename std::enable_if<std::is_same<T, at::Float8_e4m3fn>::value, int>::
|
||||
type = 0>
|
||||
inline C10_HOST_DEVICE bool _isnan(T val) {
|
||||
return val.isnan();
|
||||
}
|
||||
|
||||
// std::isinf isn't performant to use on integral types; it will
|
||||
// (uselessly) convert to floating point and then do the test.
|
||||
// This function is.
|
||||
|
|
@ -92,6 +110,14 @@ inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
|
|||
return at::_isinf(static_cast<float>(val));
|
||||
}
|
||||
|
||||
// Infinity test for at::Float8_e5m2, forwarded to the type's isinf() member.
inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
  const bool is_infinite = val.isinf();
  return is_infinite;
}
|
||||
|
||||
// Infinity test for at::Float8_e4m3fn. Always returns false, regardless of
// the value passed in.
// NOTE(review): presumably because the e4m3fn ("finite, NaN-only") encoding
// reserves no bit pattern for infinity; confirm against the format spec.
// The parameter is intentionally unnamed so that builds with
// -Wunused-parameter stay clean.
inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn /*val*/) {
  return false;
}
|
||||
|
||||
template <typename T>
|
||||
C10_HOST_DEVICE inline T exp(T x) {
|
||||
static_assert(
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <c10/util/Float8_e5m2.h>
|
||||
#include <c10/util/Half.h>
|
||||
|
||||
namespace at {
|
||||
|
|
@ -21,6 +23,14 @@ struct OpMathType<at::BFloat16> {
|
|||
using type = float;
|
||||
};
|
||||
// Maps at::Float8_e5m2 to float as its op-math type.
template <>
|
||||
struct OpMathType<at::Float8_e5m2> {
|
||||
using type = float;
|
||||
};
|
||||
// Maps at::Float8_e4m3fn to float as its op-math type.
template <>
|
||||
struct OpMathType<at::Float8_e4m3fn> {
|
||||
using type = float;
|
||||
};
|
||||
// Maps c10::complex<Half> to c10::complex<float> as its op-math type.
template <>
|
||||
struct OpMathType<c10::complex<Half>> {
|
||||
using type = c10::complex<float>;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -72,7 +72,9 @@ struct is_floating_point:
|
|||
std::integral_constant<bool,
|
||||
std::is_floating_point<T>::value ||
|
||||
std::is_same<T, at::Half>::value ||
|
||||
std::is_same<T, at::BFloat16>::value> {
|
||||
std::is_same<T, at::BFloat16>::value ||
|
||||
std::is_same<T, at::Float8_e5m2>::value ||
|
||||
std::is_same<T, at::Float8_e4m3fn>::value> {
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
|
|
|
|||
|
|
@ -48,6 +48,18 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
|
|||
self.numel() >= MIN_SZ;
|
||||
}
|
||||
|
||||
// Dtype dispatch list for the same-type transpose copy.
// Mobile builds (C10_MOBILE defined) dispatch over all types and complex plus
// ComplexHalf, Half, Bool and BFloat16; every other build additionally
// dispatches over the two float8 types.
#if defined(C10_MOBILE)
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(      \
      kComplexHalf, kHalf, kBool, kBFloat16,   \
      TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...)                              \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                   \
      kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
      TYPE, NAME, __VA_ARGS__)
#endif
|
||||
|
||||
// special case copy where tensor is contiguous and src is a transposed matrix
|
||||
// This can be generalized to most copies, but it's trickier
|
||||
void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
|
||||
|
|
@ -65,7 +77,7 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
|
|||
// The code below is implemented with the assumption that sizes are equal
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.sizes().equals(src.sizes()));
|
||||
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kHalf, kBool, kBFloat16, kComplexHalf, self.scalar_type(), "copy_", [&] {
|
||||
_AT_DISPATCH_CP_TYPES(self.scalar_type(), "copy_", [&] {
|
||||
scalar_t* sp = src.data_ptr<scalar_t>();
|
||||
scalar_t* rp = self.data_ptr<scalar_t>();
|
||||
scalar_t* bp = buf.data_ptr<scalar_t>();
|
||||
|
|
|
|||
|
|
@ -1310,6 +1310,18 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {
|
|||
return self.reshape_symint({self.sym_size(0), 1}) * vec2;
|
||||
}
|
||||
|
||||
|
||||
// Dtype dispatch list for the addmm CPU implementation.
// Mobile builds (C10_MOBILE defined) dispatch over all types and complex plus
// BFloat16; every other build additionally dispatches over the two float8
// types.
#if defined(C10_MOBILE)
#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(          \
      kBFloat16,                                  \
      TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(         \
      kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn,    \
      TYPE, NAME, __VA_ARGS__)
#endif
|
||||
|
||||
static void addmm_impl_cpu_(
|
||||
Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
|
||||
TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
|
||||
|
|
@ -1438,9 +1450,7 @@ static void addmm_impl_cpu_(
|
|||
|
||||
if(!dispatched) {
|
||||
// Apply BLAS routine
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
|
||||
result.scalar_type(), "addmm_impl_cpu_",
|
||||
[&]{
|
||||
_AT_DISPATCH_ADDMM_TYPES(result.scalar_type(), "addmm_impl_cpu_", [&]{
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
at::native::cpublas::gemm(
|
||||
transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
|
||||
|
|
|
|||
|
|
@ -28,10 +28,21 @@ Scalar item(const Tensor& self) {
|
|||
}
|
||||
}
|
||||
|
||||
// Dtype dispatch list for _local_scalar_dense_cpu.
// Mobile builds (C10_MOBILE defined) dispatch over all types and complex plus
// ComplexHalf, Half, Bool and BFloat16; every other build additionally
// dispatches over the two float8 types.
#if defined(C10_MOBILE)
#define _AT_DISPATCH_SD_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(      \
      kComplexHalf, kHalf, kBool, kBFloat16,   \
      TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_SD_TYPES(TYPE, NAME, ...)                              \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                   \
      kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
      TYPE, NAME, __VA_ARGS__)
#endif
|
||||
|
||||
Scalar _local_scalar_dense_cpu(const Tensor& self) {
|
||||
Scalar r;
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
|
||||
kComplexHalf, kHalf, kBool, kBFloat16, self.scalar_type(), "_local_scalar_dense_cpu", [&] {
|
||||
_AT_DISPATCH_SD_TYPES(self.scalar_type(), "_local_scalar_dense_cpu", [&] {
|
||||
scalar_t value = *self.data_ptr<scalar_t>();
|
||||
r = Scalar(value);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -369,6 +369,18 @@ Tensor isreal(const Tensor& self) {
|
|||
return at::imag(self) == 0;
|
||||
}
|
||||
|
||||
|
||||
// Dtype dispatch list shared by isinf and isfinite.
// Mobile builds (C10_MOBILE defined) dispatch over the floating types plus
// Half and BFloat16; every other build additionally dispatches over
// Float8_e5m2. Float8_e4m3fn is not part of either list.
#if defined(C10_MOBILE)
#define _AT_DISPATCH_INF_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_FLOATING_TYPES_AND2(              \
      kHalf, kBFloat16,                         \
      TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_INF_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_FLOATING_TYPES_AND3(              \
      kHalf, kBFloat16, kFloat8_e5m2,           \
      TYPE, NAME, __VA_ARGS__)
#endif
|
||||
|
||||
|
||||
Tensor isinf(const Tensor &self) {
|
||||
// Note: Integral tensor values are never infinite
|
||||
if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
|
||||
|
|
@ -381,7 +393,7 @@ Tensor isinf(const Tensor &self) {
|
|||
(at::isinf(at::imag(self)));
|
||||
}
|
||||
|
||||
return AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "isinf", [&]() {
|
||||
return _AT_DISPATCH_INF_TYPES(self.scalar_type(), "isinf", [&]() {
|
||||
return self.abs() == std::numeric_limits<scalar_t>::infinity();
|
||||
});
|
||||
}
|
||||
|
|
@ -397,7 +409,7 @@ Tensor isfinite(const Tensor& self) {
|
|||
return at::isfinite(at::real(self)).__iand__(at::isfinite(at::imag(self)));
|
||||
}
|
||||
|
||||
return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "isfinite", [&]() {
|
||||
return _AT_DISPATCH_INF_TYPES(self.scalar_type(), "isfinite", [&]() {
|
||||
return (self == self) * (self.abs() != std::numeric_limits<scalar_t>::infinity());
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -61,6 +61,34 @@ void atan2_kernel(TensorIteratorBase& iter) {
|
|||
});
|
||||
}
|
||||
|
||||
// Dtype dispatch lists for the binary-op CPU kernels in this file:
//   _AT_DISPATCH_ALL_TYPES_AND_BOOL - all types, complex, ComplexHalf, Half,
//                                     Bool, BFloat16
//   _AT_DISPATCH_ALL_TYPES_NO_BOOL  - the same list without Bool
//   _AT_DISPATCH_MUL_TYPES          - all types, complex, Half, BFloat16
// Builds without C10_MOBILE append Float8_e5m2 and Float8_e4m3fn to each
// list; mobile builds omit them.
#if defined(C10_MOBILE)
#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                \
      kComplexHalf, kHalf, kBool, kBFloat16,             \
      TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(               \
      kComplexHalf, kHalf, kBFloat16,                   \
      TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(       \
      kHalf, kBFloat16,                         \
      TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...)                    \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                   \
      kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
      TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...)              \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(                            \
      kComplexHalf, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
      TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...)        \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(              \
      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
      TYPE, NAME, __VA_ARGS__)
#endif
|
||||
|
||||
void mul_kernel(TensorIteratorBase& iter) {
|
||||
auto dtype = iter.common_dtype();
|
||||
if (dtype == ScalarType::Bool) {
|
||||
|
|
@ -85,7 +113,7 @@ void mul_kernel(TensorIteratorBase& iter) {
|
|||
});
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, dtype, "mul_cpu", [&]() {
|
||||
_AT_DISPATCH_MUL_TYPES(dtype, "mul_cpu", [&]() {
|
||||
cpu_kernel_vec(iter,
|
||||
[=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t { return a * b; },
|
||||
[=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) __ubsan_ignore_undefined__ {
|
||||
|
|
@ -528,14 +556,14 @@ void ge_kernel(TensorIteratorBase& iter) {
|
|||
void eq_kernel(TensorIteratorBase& iter) {
|
||||
// See Note [special-case bool outputs]
|
||||
if (iter.dtype() == ScalarType::Bool) {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
|
||||
_AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "eq_cpu", [&]() {
|
||||
cpu_kernel(iter,
|
||||
[](scalar_t a, scalar_t b) -> bool {
|
||||
return a == b;
|
||||
});
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kComplexHalf, kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
|
||||
_AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "eq_cpu", [&]() {
|
||||
cpu_kernel_vec(
|
||||
iter,
|
||||
[](scalar_t a, scalar_t b) -> scalar_t {
|
||||
|
|
@ -551,14 +579,14 @@ void eq_kernel(TensorIteratorBase& iter) {
|
|||
void ne_kernel(TensorIteratorBase& iter) {
|
||||
// See Note [special-case bool outputs]
|
||||
if (iter.dtype() == ScalarType::Bool) {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
|
||||
_AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "ne_cpu", [&]() {
|
||||
cpu_kernel(iter,
|
||||
[](scalar_t a, scalar_t b) -> bool {
|
||||
return a != b;
|
||||
});
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kComplexHalf, kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
|
||||
_AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "ne_cpu", [&]() {
|
||||
cpu_kernel_vec(
|
||||
iter,
|
||||
[](scalar_t a, scalar_t b) -> scalar_t {
|
||||
|
|
|
|||
|
|
@ -263,6 +263,17 @@ void gemm_core_(
|
|||
}
|
||||
}
|
||||
|
||||
// Dtype dispatch list for cpublas_gemm_impl.
// Mobile builds (C10_MOBILE defined) dispatch over all types and complex plus
// Half and BFloat16; every other build additionally dispatches over the two
// float8 types.
#if defined(C10_MOBILE)
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(        \
      kHalf, kBFloat16,                          \
      TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...)       \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(              \
      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
      TYPE, NAME, __VA_ARGS__)
#endif
|
||||
void cpublas_gemm_impl(
|
||||
at::ScalarType type,
|
||||
TransposeType transa, TransposeType transb,
|
||||
|
|
@ -272,9 +283,7 @@ void cpublas_gemm_impl(
|
|||
const void *b, int64_t ldb,
|
||||
const Scalar& beta,
|
||||
void *c, int64_t ldc) {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::kHalf, at::kBFloat16,
|
||||
type, "cpublas_gemm_impl",
|
||||
[&]{
|
||||
_AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_impl", [&]{
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
gemm_core_(
|
||||
transa, transb, m, n, k,
|
||||
|
|
|
|||
|
|
@ -164,6 +164,27 @@ static void float_bfloat16_copy_kernel(TensorIteratorBase &iter, bool requires_n
|
|||
}
|
||||
}
|
||||
|
||||
// Dtype dispatch lists for the CPU copy kernels in this file:
//   _AT_DISPATCH_ALL_TYPES       - all types, complex, ComplexHalf, Half,
//                                  Bool, BFloat16
//   _AT_DISPATCH_ALL_TYPES_NO_CF - all types, complex, Bool, Half, BFloat16
//                                  (no ComplexHalf)
// Builds without C10_MOBILE append Float8_e5m2 and Float8_e4m3fn to each
// list; mobile builds omit them.
#if defined(C10_MOBILE)
#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...)                                          \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                                                \
      ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, \
      TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(             \
      kBool, kHalf, kBFloat16,                        \
      TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...)                                  \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                        \
      ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool,               \
      ScalarType::BFloat16, ScalarType::Float8_e5m2, ScalarType::Float8_e4m3fn,  \
      TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...)          \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(                      \
      kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn,  \
      TYPE, NAME, __VA_ARGS__)
#endif
|
||||
|
||||
void direct_copy_kernel(TensorIteratorBase &iter) {
|
||||
// TODO: we don't actually need separate instantiations per dtype;
|
||||
// we only need a separate instantiation per dtype size. This would
|
||||
|
|
@ -183,8 +204,7 @@ void direct_copy_kernel(TensorIteratorBase &iter) {
|
|||
} else if (dtype == ScalarType::ComplexHalf) {
|
||||
cpu_kernel(iter, [=](c10::complex<at::Half> a) -> c10::complex<at::Half> { return a; });
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
|
||||
kBool, kHalf, kBFloat16, dtype, "copy_kernel", [&] {
|
||||
_AT_DISPATCH_ALL_TYPES_NO_CF(dtype, "copy_kernel", [&] {
|
||||
cpu_kernel_vec(
|
||||
iter,
|
||||
[=](scalar_t a) -> scalar_t { return a; },
|
||||
|
|
@ -237,9 +257,9 @@ void copy_kernel(TensorIterator& iter, bool /*non_blocking*/) {
|
|||
sizeof(BFloat16) == strides_out[0] && (sizeof(float) == strides_in[0] || strides_in[0] == 0)))) {
|
||||
float_bfloat16_copy_kernel(iter, requires_neg);
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, dtype, "copy_", [&] {
|
||||
_AT_DISPATCH_ALL_TYPES(dtype, "copy_", [&] {
|
||||
using dest_t = scalar_t;
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, iter.dtype(1), "copy_", [&] {
|
||||
_AT_DISPATCH_ALL_TYPES(iter.dtype(1), "copy_", [&] {
|
||||
if (iter.has_contiguous_first_dim()) {
|
||||
TORCH_INTERNAL_ASSERT(iter.ninputs() == 1);
|
||||
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
|
||||
|
|
|
|||
|
|
@ -179,6 +179,18 @@ static void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
|
|||
});
|
||||
}
|
||||
|
||||
// Dtype dispatch list for abs_kernel.
// Mobile builds (C10_MOBILE defined) dispatch over all types and complex plus
// Half and BFloat16; every other build additionally dispatches over the two
// float8 types.
#if defined(C10_MOBILE)
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(       \
      kHalf, kBFloat16,                         \
      TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...)        \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(              \
      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
      TYPE, NAME, __VA_ARGS__)
#endif
|
||||
|
||||
static void abs_kernel(TensorIteratorBase& iter) {
|
||||
auto dtype = iter.dtype();
|
||||
if (dtype == kComplexHalf) {
|
||||
|
|
@ -186,7 +198,7 @@ static void abs_kernel(TensorIteratorBase& iter) {
|
|||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
cpu_kernel(iter, [=](scalar_t a) -> scalar_t { return abs_impl(opmath_t{a}); });
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "abs_cpu", [&]() {
|
||||
_AT_DISPATCH_ABS_TYPES(iter.dtype(), "abs_cpu", [&]() {
|
||||
cpu_kernel_vec(
|
||||
iter,
|
||||
[=](scalar_t a) -> scalar_t { return abs_impl(a); },
|
||||
|
|
|
|||
|
|
@ -186,6 +186,12 @@ template <> inline std::string typeName<at::Half>(){
|
|||
// Type-name string reported for at::BFloat16.
template <>
inline std::string typeName<at::BFloat16>() {
  return std::string("at::BFloat16");
}
|
||||
// Type-name string reported for at::Float8_e5m2.
template <>
inline std::string typeName<at::Float8_e5m2>() {
  return std::string("at::Float8_e5m2");
}
|
||||
// Type-name string reported for at::Float8_e4m3fn.
template <>
inline std::string typeName<at::Float8_e4m3fn>() {
  return std::string("at::Float8_e4m3fn");
}
|
||||
|
||||
// Emits one switch case: for ScalarType::scalartype, return the name string
// produced by typeName<ctype>().
#define TYPE_NAME_CASE(ctype, scalartype) \
|
||||
case ScalarType::scalartype: return typeName<ctype>();
|
||||
|
|
|
|||
|
|
@ -48,7 +48,13 @@ class C10_API Scalar {
|
|||
#define DEFINE_IMPLICIT_CTOR(type, name) \
|
||||
Scalar(type vv) : Scalar(vv, true) {}
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR)
|
||||
AT_FORALL_SCALAR_TYPES_AND5(
|
||||
Half,
|
||||
BFloat16,
|
||||
Float8_e5m2,
|
||||
Float8_e4m3fn,
|
||||
ComplexHalf,
|
||||
DEFINE_IMPLICIT_CTOR)
|
||||
AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
|
||||
|
||||
#undef DEFINE_IMPLICIT_CTOR
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Deprecated.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <c10/util/Float8_e5m2.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/util/bits.h>
|
||||
#include <c10/util/complex.h>
|
||||
|
|
@ -50,7 +52,9 @@ namespace c10 {
|
|||
_(c10::bits2x4, Bits2x4) /* 19 */ \
|
||||
_(c10::bits4x2, Bits4x2) /* 20 */ \
|
||||
_(c10::bits8, Bits8) /* 21 */ \
|
||||
_(c10::bits16, Bits16) /* 22 */
|
||||
_(c10::bits16, Bits16) /* 22 */ \
|
||||
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
|
||||
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */
|
||||
|
||||
// If you want to support ComplexHalf for real, add ComplexHalf
|
||||
// into this macro (and change the name). But beware: convert()
|
||||
|
|
@ -67,7 +71,9 @@ namespace c10 {
|
|||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble) \
|
||||
_(bool, Bool) \
|
||||
_(at::BFloat16, BFloat16)
|
||||
_(at::BFloat16, BFloat16) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn)
|
||||
|
||||
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
|
||||
_(uint8_t, Byte) \
|
||||
|
|
@ -82,7 +88,9 @@ namespace c10 {
|
|||
_(c10::complex<float>, ComplexFloat) \
|
||||
_(c10::complex<double>, ComplexDouble) \
|
||||
_(bool, Bool) \
|
||||
_(at::BFloat16, BFloat16)
|
||||
_(at::BFloat16, BFloat16) \
|
||||
_(at::Float8_e5m2, Float8_e5m2) \
|
||||
_(at::Float8_e4m3fn, Float8_e4m3fn)
|
||||
|
||||
enum class ScalarType : int8_t {
|
||||
#define DEFINE_ENUM(_1, n) n,
|
||||
|
|
@ -201,6 +209,53 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
|
|||
::c10::ScalarType::SCALARTYPE3>::t), \
|
||||
SCALARTYPE3)
|
||||
|
||||
// X-macro: invokes _(cpp_type, ScalarTypeName) for Byte, Char, Short, Int,
// Long, Float and Double, then once for each of the four extra scalar types.
// The C++ type of each extra is looked up through
// ::c10::impl::ScalarTypeToCPPType.
#define AT_FORALL_SCALAR_TYPES_AND4( \
|
||||
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE3>::t), \
|
||||
SCALARTYPE3) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE4>::t), \
|
||||
SCALARTYPE4)
|
||||
|
||||
// X-macro: invokes _(cpp_type, ScalarTypeName) for Byte, Char, Short, Int,
// Long, Float and Double, then once for each of the five extra scalar types.
// The C++ type of each extra is looked up through
// ::c10::impl::ScalarTypeToCPPType.
#define AT_FORALL_SCALAR_TYPES_AND5( \
|
||||
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, _) \
|
||||
_(uint8_t, Byte) \
|
||||
_(int8_t, Char) \
|
||||
_(int16_t, Short) \
|
||||
_(int, Int) \
|
||||
_(int64_t, Long) \
|
||||
_(float, Float) \
|
||||
_(double, Double) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE1>::t), \
|
||||
SCALARTYPE1) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE2>::t), \
|
||||
SCALARTYPE2) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE3>::t), \
|
||||
SCALARTYPE3) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE4>::t), \
|
||||
SCALARTYPE4) \
|
||||
_(decltype(::c10::impl::ScalarTypeToCPPType< \
|
||||
::c10::ScalarType::SCALARTYPE5>::t), \
|
||||
SCALARTYPE5)
|
||||
|
||||
#define AT_FORALL_QINT_TYPES(_) \
|
||||
_(c10::qint8, QInt8) \
|
||||
_(c10::quint8, QUInt8) \
|
||||
|
|
@ -261,7 +316,8 @@ static inline bool isIntegralType(ScalarType t) {
|
|||
// Returns true when t is one of the floating-point scalar types: Double,
// Float, Half, BFloat16, Float8_e5m2 or Float8_e4m3fn.
static inline bool isFloatingType(ScalarType t) {
  return (
      t == ScalarType::Double || t == ScalarType::Float ||
      t == ScalarType::Half || t == ScalarType::BFloat16 ||
      t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn);
}
|
||||
|
||||
static inline bool isReducedFloatingType(ScalarType t) {
|
||||
|
|
@ -334,7 +390,8 @@ static inline bool isSignedType(ScalarType t) {
|
|||
case ScalarType::ComplexFloat:
|
||||
case ScalarType::ComplexDouble:
|
||||
return true;
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
|
||||
AT_FORALL_SCALAR_TYPES_AND5(
|
||||
Half, Bool, BFloat16, Float8_e5m2, Float8_e4m3fn, CASE_SIGNED)
|
||||
default:
|
||||
TORCH_CHECK(false, "Unknown ScalarType");
|
||||
}
|
||||
|
|
@ -425,6 +482,8 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
|
|||
constexpr auto c8 = ScalarType::ComplexDouble;
|
||||
constexpr auto b1 = ScalarType::Bool;
|
||||
constexpr auto bf = ScalarType::BFloat16;
|
||||
constexpr auto b8 = ScalarType::Float8_e5m2;
|
||||
constexpr auto h8 = ScalarType::Float8_e4m3fn;
|
||||
constexpr auto ud = ScalarType::Undefined;
|
||||
if (a == ud || b == ud) {
|
||||
return ScalarType::Undefined;
|
||||
|
|
@ -462,23 +521,25 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
|
|||
// clang-format off
|
||||
static constexpr ScalarType _promoteTypesLookup[
|
||||
NUM_PROMOTE_TYPES][NUM_PROMOTE_TYPES] = {
|
||||
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf*/
|
||||
/* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf},
|
||||
/* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf},
|
||||
/* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf},
|
||||
/* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf},
|
||||
/* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf},
|
||||
/* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4},
|
||||
/* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4},
|
||||
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8},
|
||||
/* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4},
|
||||
/* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
|
||||
/* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
|
||||
/* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf},
|
||||
/* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
|
||||
/* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
|
||||
/* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
|
||||
/* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf},
|
||||
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 q1 q2 q3 bf b8 h8*/
|
||||
/* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf, b8, h8},
|
||||
/* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf, b8, h8},
|
||||
/* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf, b8, h8},
|
||||
/* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf, b8, h8},
|
||||
/* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf, b8, h8},
|
||||
/* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4, f4, f4},
|
||||
/* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4, f4, f4},
|
||||
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8, f8, f8},
|
||||
/* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4, c4, c4},
|
||||
/* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4, c4, c4},
|
||||
/* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8, c8, c8},
|
||||
/* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf, b8, h8},
|
||||
/* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
|
||||
/* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
|
||||
/* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
|
||||
/* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf, bf, bf},
|
||||
/* b8 */ {b8, b8, b8, b8, b8, f4, f4, f8, c4, c4, c8, b8, ud, ud, ud, bf, b8, ud},
|
||||
/* h8 */ {h8, h8, h8, h8, h8, f4, f4, f8, c4, c4, c8, h8, ud, ud, ud, bf, ud, h8},
|
||||
};
|
||||
// clang-format on
|
||||
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
|
||||
|
|
@ -490,8 +551,4 @@ inline std::ostream& operator<<(
|
|||
return stream << toString(scalar_type);
|
||||
}
|
||||
|
||||
#define AT_FORAUTOCAST_SCALAR_TYPES(_) \
|
||||
_(half, Half) /* 0 */ \
|
||||
_(bfloat16, BFloat16) /* 1 */
|
||||
|
||||
} // namespace c10
|
||||
|
|
|
|||
275
c10/util/Float8_e4m3fn-inl.h
Normal file
275
c10/util/Float8_e4m3fn-inl.h
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Constructors
|
||||
|
||||
/// Builds a Float8_e4m3fn from a float by storing the encoded fp8 byte.
inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value)
    : x(detail::fp8e4m3fn_from_fp32_value(value)) {}
|
||||
|
||||
/// Implicit conversions
|
||||
|
||||
/// Decodes the stored fp8 byte into a float.
inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const {
  const float widened = detail::fp8e4m3fn_to_fp32_value(x);
  return widened;
}
|
||||
|
||||
/// Special values helper
|
||||
|
||||
/// True when every exponent and mantissa bit is set (the sign bit is ignored).
inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const {
  const uint8_t magnitude = x & 0x7F;
  return magnitude == 0x7F;
}
|
||||
|
||||
/// Arithmetic
|
||||
|
||||
/// Sum of two fp8 values, evaluated in float and converted back to fp8.
inline C10_HOST_DEVICE Float8_e4m3fn
operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
  const float sum = static_cast<float>(a) + static_cast<float>(b);
  return Float8_e4m3fn(sum);
}
|
||||
|
||||
/// Difference of two fp8 values, evaluated in float and converted back to fp8.
inline C10_HOST_DEVICE Float8_e4m3fn
operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
  const float difference = static_cast<float>(a) - static_cast<float>(b);
  return Float8_e4m3fn(difference);
}
|
||||
|
||||
/// Product of two fp8 values, evaluated in float and converted back to fp8.
inline C10_HOST_DEVICE Float8_e4m3fn
operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
  const float product = static_cast<float>(a) * static_cast<float>(b);
  return Float8_e4m3fn(product);
}
|
||||
|
||||
/// Quotient of two fp8 values, evaluated in float and converted back to fp8.
/// Division by zero is deliberately exempt from the UBSan check.
inline C10_HOST_DEVICE Float8_e4m3fn operator/(
    const Float8_e4m3fn& a,
    const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ {
  const float quotient = static_cast<float>(a) / static_cast<float>(b);
  return Float8_e4m3fn(quotient);
}
|
||||
|
||||
/// Unary minus: negates in float and converts the result back to fp8.
inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) {
  const float negated = -static_cast<float>(a);
  return Float8_e4m3fn(negated);
}
|
||||
|
||||
/// In-place addition; stores `a + b` into `a` and returns `a`.
inline C10_HOST_DEVICE Float8_e4m3fn& operator+=(
    Float8_e4m3fn& a,
    const Float8_e4m3fn& b) {
  return a = a + b;
}
|
||||
|
||||
/// In-place subtraction; stores `a - b` into `a` and returns `a`.
inline C10_HOST_DEVICE Float8_e4m3fn& operator-=(
    Float8_e4m3fn& a,
    const Float8_e4m3fn& b) {
  return a = a - b;
}
|
||||
|
||||
/// In-place multiplication; stores `a * b` into `a` and returns `a`.
inline C10_HOST_DEVICE Float8_e4m3fn& operator*=(
    Float8_e4m3fn& a,
    const Float8_e4m3fn& b) {
  return a = a * b;
}
|
||||
|
||||
/// In-place division; stores `a / b` into `a` and returns `a`.
inline C10_HOST_DEVICE Float8_e4m3fn& operator/=(
    Float8_e4m3fn& a,
    const Float8_e4m3fn& b) {
  return a = a / b;
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
/// fp8 + float: the fp8 operand is widened and the result stays a float.
inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs + b;
}
|
||||
/// fp8 - float: the fp8 operand is widened and the result stays a float.
inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs - b;
}
|
||||
/// fp8 * float: the fp8 operand is widened and the result stays a float.
inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs * b;
}
|
||||
/// fp8 / float: the fp8 operand is widened and the result stays a float.
/// Division by zero is deliberately exempt from the UBSan check.
inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b)
    __ubsan_ignore_float_divide_by_zero__ {
  const float lhs = static_cast<float>(a);
  return lhs / b;
}
|
||||
|
||||
/// float + fp8: the fp8 operand is widened and the result stays a float.
inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) {
  const float rhs = static_cast<float>(b);
  return a + rhs;
}
|
||||
/// float - fp8: the fp8 operand is widened and the result stays a float.
inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) {
  const float rhs = static_cast<float>(b);
  return a - rhs;
}
|
||||
/// float * fp8: the fp8 operand is widened and the result stays a float.
inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) {
  const float rhs = static_cast<float>(b);
  return a * rhs;
}
|
||||
/// float / fp8: the fp8 operand is widened and the result stays a float.
/// Division by zero is deliberately exempt from the UBSan check.
inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b)
    __ubsan_ignore_float_divide_by_zero__ {
  const float rhs = static_cast<float>(b);
  return a / rhs;
}
|
||||
|
||||
/// Accumulates an fp8 value into a float and returns the float.
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) {
  a += static_cast<float>(b);
  return a;
}
|
||||
/// Subtracts an fp8 value from a float in place and returns the float.
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) {
  a -= static_cast<float>(b);
  return a;
}
|
||||
/// Multiplies a float by an fp8 value in place and returns the float.
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) {
  a *= static_cast<float>(b);
  return a;
}
|
||||
/// Divides a float by an fp8 value in place and returns the float.
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) {
  a /= static_cast<float>(b);
  return a;
}
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
/// fp8 + double: the fp8 operand is widened and the result stays a double.
inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) {
  const double lhs = static_cast<double>(a);
  return lhs + b;
}
|
||||
/// fp8 - double: the fp8 operand is widened and the result stays a double.
inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) {
  const double lhs = static_cast<double>(a);
  return lhs - b;
}
|
||||
/// fp8 * double: the fp8 operand is widened and the result stays a double.
inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) {
  const double lhs = static_cast<double>(a);
  return lhs * b;
}
|
||||
/// fp8 / double: the fp8 operand is widened and the result stays a double.
/// Division by zero is deliberately exempt from the UBSan check.
inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b)
    __ubsan_ignore_float_divide_by_zero__ {
  const double lhs = static_cast<double>(a);
  return lhs / b;
}
|
||||
|
||||
/// double + fp8: the fp8 operand is widened and the result stays a double.
inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) {
  const double rhs = static_cast<double>(b);
  return a + rhs;
}
|
||||
/// double - fp8: the fp8 operand is widened and the result stays a double.
inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) {
  const double rhs = static_cast<double>(b);
  return a - rhs;
}
|
||||
/// double * fp8: the fp8 operand is widened and the result stays a double.
inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) {
  const double rhs = static_cast<double>(b);
  return a * rhs;
}
|
||||
/// double / fp8: the fp8 operand is widened and the result stays a double.
/// Division by zero is deliberately exempt from the UBSan check.
inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b)
    __ubsan_ignore_float_divide_by_zero__ {
  const double rhs = static_cast<double>(b);
  return a / rhs;
}
|
||||
|
||||
/// Arithmetic with ints
|
||||
|
||||
/// fp8 + int: the int is first converted to fp8, then the fp8 sum is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) {
  const auto rhs = static_cast<Float8_e4m3fn>(b);
  return a + rhs;
}
|
||||
/// fp8 - int: the int is first converted to fp8, then the fp8 difference is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) {
  const auto rhs = static_cast<Float8_e4m3fn>(b);
  return a - rhs;
}
|
||||
/// fp8 * int: the int is first converted to fp8, then the fp8 product is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) {
  const auto rhs = static_cast<Float8_e4m3fn>(b);
  return a * rhs;
}
|
||||
/// fp8 / int: the int is first converted to fp8, then the fp8 quotient is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) {
  const auto rhs = static_cast<Float8_e4m3fn>(b);
  return a / rhs;
}
|
||||
|
||||
/// int + fp8: the int is first converted to fp8, then the fp8 sum is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) {
  const auto lhs = static_cast<Float8_e4m3fn>(a);
  return lhs + b;
}
|
||||
/// int - fp8: the int is first converted to fp8, then the fp8 difference is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) {
  const auto lhs = static_cast<Float8_e4m3fn>(a);
  return lhs - b;
}
|
||||
/// int * fp8: the int is first converted to fp8, then the fp8 product is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) {
  const auto lhs = static_cast<Float8_e4m3fn>(a);
  return lhs * b;
}
|
||||
/// int / fp8: the int is first converted to fp8, then the fp8 quotient is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) {
  const auto lhs = static_cast<Float8_e4m3fn>(a);
  return lhs / b;
}
|
||||
|
||||
/// Arithmetic with int64_t
|
||||
|
||||
/// fp8 + int64_t: the integer is first converted to fp8, then the fp8 sum is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) {
  const auto rhs = static_cast<Float8_e4m3fn>(b);
  return a + rhs;
}
|
||||
/// fp8 - int64_t: the integer is first converted to fp8, then the fp8 difference is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) {
  const auto rhs = static_cast<Float8_e4m3fn>(b);
  return a - rhs;
}
|
||||
/// fp8 * int64_t: the integer is first converted to fp8, then the fp8 product is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) {
  const auto rhs = static_cast<Float8_e4m3fn>(b);
  return a * rhs;
}
|
||||
/// fp8 / int64_t: the integer is first converted to fp8, then the fp8 quotient is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) {
  const auto rhs = static_cast<Float8_e4m3fn>(b);
  return a / rhs;
}
|
||||
|
||||
/// int64_t + fp8: the integer is first converted to fp8, then the fp8 sum is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) {
  const auto lhs = static_cast<Float8_e4m3fn>(a);
  return lhs + b;
}
|
||||
/// int64_t - fp8: the integer is first converted to fp8, then the fp8 difference is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) {
  const auto lhs = static_cast<Float8_e4m3fn>(a);
  return lhs - b;
}
|
||||
/// int64_t * fp8: the integer is first converted to fp8, then the fp8 product is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) {
  const auto lhs = static_cast<Float8_e4m3fn>(a);
  return lhs * b;
}
|
||||
/// int64_t / fp8: the integer is first converted to fp8, then the fp8 quotient is taken.
inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) {
  const auto lhs = static_cast<Float8_e4m3fn>(a);
  return lhs / b;
}
|
||||
|
||||
/// NOTE: we do not define comparisons directly and instead rely on the implicit
|
||||
/// conversion from c10::Float8_e4m3fn to float.
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::Float8_e4m3fn> {
|
||||
public:
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_signed = true;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = false;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = false;
|
||||
static constexpr auto has_denorm = true;
|
||||
static constexpr auto has_denorm_loss = true;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = false;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 4;
|
||||
static constexpr int digits10 = 0;
|
||||
static constexpr int max_digits10 = 3;
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -5;
|
||||
static constexpr int min_exponent10 = -1;
|
||||
static constexpr int max_exponent = 8;
|
||||
static constexpr int max_exponent10 = 2;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before = false;
|
||||
|
||||
static constexpr c10::Float8_e4m3fn min() {
|
||||
return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn lowest() {
|
||||
return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn max() {
|
||||
return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn epsilon() {
|
||||
return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn round_error() {
|
||||
return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn quiet_NaN() {
|
||||
return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn denorm_min() {
|
||||
return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
14
c10/util/Float8_e4m3fn.cpp
Normal file
14
c10/util/Float8_e4m3fn.cpp
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
static_assert(
|
||||
std::is_standard_layout<Float8_e4m3fn>::value,
|
||||
"c10::Float8_e4m3fn must be standard layout.");
|
||||
|
||||
// Streams the value by printing its float conversion.
std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value) {
  return out << static_cast<float>(value);
}
|
||||
} // namespace c10
|
||||
240
c10/util/Float8_e4m3fn.h
Normal file
240
c10/util/Float8_e4m3fn.h
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
#pragma once
|
||||
|
||||
/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions
|
||||
/// to standard C types and basic arithmetic operations. Note that arithmetic
|
||||
/// operations are implemented by converting to floating point and
|
||||
/// performing the operation in float32.
|
||||
/// Binary configuration:
|
||||
/// s eeee mmm
|
||||
/// 1 sign bit
|
||||
/// 4 exponent bits
|
||||
/// 3 mantissa bits
|
||||
/// bias = 7
|
||||
///
|
||||
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
|
||||
/// and inspired by Half implementation from pytorch/c10/util/Half.h
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/C++17.h>
|
||||
#include <c10/util/TypeSafeSignMath.h>
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
#include <type_traits>
|
||||
|
||||
#if defined(__cplusplus) && (__cplusplus >= 201103L)
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#elif !defined(__OPENCL_VERSION__)
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <intrin.h>
|
||||
#endif
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <iosfwd>
|
||||
#include <limits>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include <typeinfo> // operator typeid
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/*
|
||||
* Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit
|
||||
* representation, to a 32-bit floating-point number in IEEE single-precision
|
||||
* format, in bit representation.
|
||||
*
|
||||
* @note The implementation doesn't use any floating-point operations.
|
||||
*/
|
||||
inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
  /*
   * Extend the fp8 E4M3FN number to 32 bits and shift to the
   * upper part of the 32-bit word:
   *      +---+----+---+-----------------------------+
   *      | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
   *      +---+----+---+-----------------------------+
   * Bits  31 27-30 24-26          0-23
   *
   * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa,
   * 0 - zero bits.
   */
  const uint32_t w = static_cast<uint32_t>(input) << 24;
  /*
   * Sign of the input in bit 31, every other bit cleared.
   */
  const uint32_t sign = w & UINT32_C(0x80000000);
  /*
   * Biased exponent and mantissa of the input in bits 24-30, sign cleared.
   */
  const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
  /*
   * Renorm shift is the number of bits to shift the mantissa left to make the
   * fp8 number normalized. If the input is normalized, at least one of its
   * high 5 bits (the cleared sign bit plus the 4-bit exponent) is one, so
   * renorm_shift == 0. If the input is denormal, renorm_shift > 0. Shifting a
   * denormal nonsign by renorm_shift moves the leading mantissa bit into the
   * exponent field, turning the biased exponent into 1 and leaving a
   * normalized mantissa (i.e. without the leading 1).
   *
   * A zero input needs care: __builtin_clz(0) is undefined and
   * _BitScanReverse leaves its index unspecified when no bit is set. The
   * leading-zero count is pinned to 32 in that case so the shifts below stay
   * in range; the value is discarded by zero_mask anyway.
   */
#if defined(__CUDA_ARCH__)
  uint32_t renorm_shift = __clz(nonsign);
#elif defined(_MSC_VER)
  unsigned long nonsign_bsr = 0;
  uint32_t renorm_shift = 32;
  if (_BitScanReverse(&nonsign_bsr, static_cast<unsigned long>(nonsign))) {
    renorm_shift = static_cast<uint32_t>(nonsign_bsr) ^ 31;
  }
#else
  uint32_t renorm_shift = nonsign != 0 ? __builtin_clz(nonsign) : 32;
#endif
  renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
  /*
   * Iff the fp8e4m3fn number has all exponent and mantissa bits set to 1,
   * the addition carries into bit 31, and the arithmetic shift right then
   * fills the high 9 bits with ones. Thus inf_nan_mask == 0x7F800000 if the
   * fp8e4m3fn number is NaN, 0x00000000 otherwise.
   */
  const int32_t inf_nan_mask =
      (static_cast<int32_t>(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000);
  /*
   * Iff nonsign is 0, the subtraction wraps to 0xFFFFFFFF, setting bit 31.
   * Otherwise bit 31 stays 0. The signed shift right by 31 broadcasts bit 31
   * into every bit of zero_mask. Thus zero_mask == 0xFFFFFFFF if the fp8
   * number was zero (+0.0 or -0.0), 0x00000000 otherwise.
   */
  const int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
  /*
   * 1. Shift nonsign left by renorm_shift to normalize it (if the input
   *    was denormal).
   * 2. Shift nonsign right by 4 so the exponent (4 bits originally)
   *    becomes an 8-bit field and the 3-bit mantissa lands in the 3 high
   *    bits of the 23-bit mantissa of the IEEE single-precision number.
   * 3. Add 0x78 to the exponent (starting at bit 23) to compensate for the
   *    difference in exponent bias (0x7F for single precision less 0x07
   *    for fp8e4m3fn).
   * 4. Subtract renorm_shift from the exponent (starting at bit 23) to
   *    account for renormalization. As renorm_shift is less than 0x78, this
   *    is combined with step 3.
   * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
   *    input was NaN (this format has no infinity encoding).
   * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
   *    into zero if the input was zero.
   * 7. Combine with the sign of the input number.
   */
  uint32_t result = sign |
      ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
        inf_nan_mask) &
       ~zero_mask);
  return fp32_from_bits(result);
}
|
||||
|
||||
/*
|
||||
* Convert a 32-bit floating-point number in IEEE single-precision format to a
|
||||
* 8-bit floating-point number in fp8 E4M3FN format, in bit representation.
|
||||
*/
|
||||
inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) {
  /*
   * Bit pattern of 480.0f, the first magnitude past the fp8e4m3fn range
   * (it would need the encoding 0 1111 111, which is reserved for NaN):
   *   0 10000111 11100000000000000000000 - fp32
   */
  constexpr uint32_t fp8_max = UINT32_C(1087) << 20;

  /*
   * Bit pattern of the float added to sub-normal-range inputs so that the
   * fp8 denormal mantissa ends up in the low bits of the sum.
   * magic number: ((127 - 7) + (23 - 3) + 1)
   */
  constexpr uint32_t denorm_mask = UINT32_C(141) << 23;

  /*
   * Bit pattern of 2^(-6), the smallest fp8e4m3fn normal number.
   */
  constexpr uint32_t min_normal = UINT32_C(121) << 23;

  const uint32_t raw = fp32_to_bits(f);

  // Split the input into its sign (bit 31) and its magnitude bits.
  const uint32_t sign = raw & UINT32_C(0x80000000);
  uint32_t magnitude = raw ^ sign;

  // The fp32 sign moved down to bit 7, ready to OR into the fp8 byte.
  const uint8_t sign_bit = static_cast<uint8_t>(sign >> 24);

  if (magnitude >= fp8_max) {
    // Out of range (including fp32 inf/NaN): all exponent and mantissa bits
    // set, which is the fp8e4m3fn NaN.
    return static_cast<uint8_t>(0x7f | sign_bit);
  }

  if (magnitude < min_normal) {
    // Below the normal range: let float addition do the rounding, then read
    // the denormal encoding out of the low bits of the sum.
    magnitude = fp32_to_bits(
        fp32_from_bits(magnitude) + fp32_from_bits(denorm_mask));
    const uint8_t denormal = static_cast<uint8_t>(magnitude - denorm_mask);
    return static_cast<uint8_t>(denormal | sign_bit);
  }

  // Lowest mantissa bit that survives truncation; used to break ties to even.
  const uint8_t mant_odd = (magnitude >> 20) & 1;

  // Re-bias the exponent and add the first part of the rounding bias.
  magnitude += (static_cast<uint32_t>(7 - 127) << 23) + 0x7FFFF;

  // Second part of the rounding bias.
  magnitude += mant_odd;

  // Keep the exponent and the top three mantissa bits.
  const uint8_t normal = static_cast<uint8_t>(magnitude >> 20);
  return static_cast<uint8_t>(normal | sign_bit);
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
struct alignas(1) Float8_e4m3fn {
|
||||
uint8_t x;
|
||||
|
||||
struct from_bits_t {};
|
||||
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
|
||||
Float8_e4m3fn() = default;
|
||||
|
||||
constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t)
|
||||
: x(bits){};
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn(float value);
|
||||
inline C10_HOST_DEVICE operator float() const;
|
||||
inline C10_HOST_DEVICE bool isnan() const;
|
||||
};
|
||||
|
||||
C10_API std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value);
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#include <c10/util/Float8_e4m3fn-inl.h> // IWYU pragma: keep
|
||||
284
c10/util/Float8_e5m2-inl.h
Normal file
284
c10/util/Float8_e5m2-inl.h
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
#define EXP_WIDTH_FP8 5
|
||||
#define MAN_WIDTH_FP8 2
|
||||
#define EXP_BIAS_FP8 15
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Constructors
|
||||
|
||||
/// Builds a Float8_e5m2 from a float by storing the encoded fp8 byte.
inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value)
    : x(detail::fp8e5m2_from_fp32_value(value)) {}
|
||||
|
||||
/// Implicit conversions
|
||||
|
||||
/// Decodes the stored fp8 byte into a float.
inline C10_HOST_DEVICE Float8_e5m2::operator float() const {
  const float widened = detail::fp8e5m2_to_fp32_value(x);
  return widened;
}
|
||||
|
||||
/// Special values helpers
|
||||
|
||||
/// True when the magnitude bits exceed the infinity pattern (all-ones
/// exponent with a non-zero mantissa); the sign bit is ignored.
inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const {
  const uint8_t magnitude = x & 0x7F;
  return magnitude > 0x7C;
}
|
||||
|
||||
/// True when the magnitude bits are exactly the all-ones exponent with a
/// zero mantissa; the sign bit is ignored.
inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const {
  const uint8_t magnitude = x & 0x7F;
  return magnitude == 0x7C;
}
|
||||
|
||||
/// Arithmetic
|
||||
|
||||
/// Sum of two fp8 values, evaluated in float and converted back to fp8.
inline C10_HOST_DEVICE Float8_e5m2
operator+(const Float8_e5m2& a, const Float8_e5m2& b) {
  const float sum = static_cast<float>(a) + static_cast<float>(b);
  return Float8_e5m2(sum);
}
|
||||
|
||||
/// Difference of two fp8 values, evaluated in float and converted back to fp8.
inline C10_HOST_DEVICE Float8_e5m2
operator-(const Float8_e5m2& a, const Float8_e5m2& b) {
  const float difference = static_cast<float>(a) - static_cast<float>(b);
  return Float8_e5m2(difference);
}
|
||||
|
||||
/// Product of two fp8 values, evaluated in float and converted back to fp8.
inline C10_HOST_DEVICE Float8_e5m2
operator*(const Float8_e5m2& a, const Float8_e5m2& b) {
  const float product = static_cast<float>(a) * static_cast<float>(b);
  return Float8_e5m2(product);
}
|
||||
|
||||
/// Quotient of two fp8 values, evaluated in float and converted back to fp8.
/// Division by zero is deliberately exempt from the UBSan check.
inline C10_HOST_DEVICE Float8_e5m2 operator/(
    const Float8_e5m2& a,
    const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ {
  const float quotient = static_cast<float>(a) / static_cast<float>(b);
  return Float8_e5m2(quotient);
}
|
||||
|
||||
/// Unary minus: negates in float and converts the result back to fp8.
inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) {
  const float negated = -static_cast<float>(a);
  return Float8_e5m2(negated);
}
|
||||
|
||||
/// In-place addition; stores `a + b` into `a` and returns `a`.
inline C10_HOST_DEVICE Float8_e5m2& operator+=(
    Float8_e5m2& a,
    const Float8_e5m2& b) {
  return a = a + b;
}
|
||||
|
||||
/// In-place subtraction; stores `a - b` into `a` and returns `a`.
inline C10_HOST_DEVICE Float8_e5m2& operator-=(
    Float8_e5m2& a,
    const Float8_e5m2& b) {
  return a = a - b;
}
|
||||
|
||||
/// In-place multiplication; stores `a * b` into `a` and returns `a`.
inline C10_HOST_DEVICE Float8_e5m2& operator*=(
    Float8_e5m2& a,
    const Float8_e5m2& b) {
  return a = a * b;
}
|
||||
|
||||
/// In-place division; stores `a / b` into `a` and returns `a`.
inline C10_HOST_DEVICE Float8_e5m2& operator/=(
    Float8_e5m2& a,
    const Float8_e5m2& b) {
  return a = a / b;
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
/// fp8 + float: the fp8 operand is widened and the result stays a float.
inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs + b;
}
|
||||
/// fp8 - float: the fp8 operand is widened and the result stays a float.
inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs - b;
}
|
||||
/// fp8 * float: the fp8 operand is widened and the result stays a float.
inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs * b;
}
|
||||
/// fp8 / float: the fp8 operand is widened and the result stays a float.
/// Division by zero is deliberately exempt from the UBSan check.
inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b)
    __ubsan_ignore_float_divide_by_zero__ {
  const float lhs = static_cast<float>(a);
  return lhs / b;
}
|
||||
|
||||
/// float + fp8: the fp8 operand is widened and the result stays a float.
inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) {
  const float rhs = static_cast<float>(b);
  return a + rhs;
}
|
||||
/// float - fp8: the fp8 operand is widened and the result stays a float.
inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) {
  const float rhs = static_cast<float>(b);
  return a - rhs;
}
|
||||
/// float * fp8: the fp8 operand is widened and the result stays a float.
inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) {
  const float rhs = static_cast<float>(b);
  return a * rhs;
}
|
||||
/// float / fp8: the fp8 operand is widened and the result stays a float.
/// Division by zero is deliberately exempt from the UBSan check.
inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b)
    __ubsan_ignore_float_divide_by_zero__ {
  const float rhs = static_cast<float>(b);
  return a / rhs;
}
|
||||
|
||||
/// Accumulates an fp8 value into a float and returns the float.
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) {
  a += static_cast<float>(b);
  return a;
}
|
||||
/// Subtracts an fp8 value from a float in place and returns the float.
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) {
  a -= static_cast<float>(b);
  return a;
}
|
||||
/// Multiplies a float by an fp8 value in place and returns the float.
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) {
  a *= static_cast<float>(b);
  return a;
}
|
||||
/// Divides a float by an fp8 value in place and returns the float.
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) {
  a /= static_cast<float>(b);
  return a;
}
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
/// fp8 + double: the fp8 operand is widened and the result stays a double.
inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) {
  const double lhs = static_cast<double>(a);
  return lhs + b;
}
|
||||
/// fp8 - double: the fp8 operand is widened and the result stays a double.
inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) {
  const double lhs = static_cast<double>(a);
  return lhs - b;
}
|
||||
/// fp8 * double: the fp8 operand is widened and the result stays a double.
inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) {
  const double lhs = static_cast<double>(a);
  return lhs * b;
}
|
||||
/// fp8 / double: the fp8 operand is widened and the result stays a double.
/// Division by zero is deliberately exempt from the UBSan check.
inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b)
    __ubsan_ignore_float_divide_by_zero__ {
  const double lhs = static_cast<double>(a);
  return lhs / b;
}
|
||||
|
||||
/// double + fp8: the fp8 operand is widened and the result stays a double.
inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) {
  const double rhs = static_cast<double>(b);
  return a + rhs;
}
|
||||
/// double - fp8: the fp8 operand is widened and the result stays a double.
inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) {
  const double rhs = static_cast<double>(b);
  return a - rhs;
}
|
||||
/// double * fp8: the fp8 operand is widened and the result stays a double.
inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) {
  const double rhs = static_cast<double>(b);
  return a * rhs;
}
|
||||
/// double / fp8: the fp8 operand is widened and the result stays a double.
/// Division by zero is deliberately exempt from the UBSan check.
inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b)
    __ubsan_ignore_float_divide_by_zero__ {
  const double rhs = static_cast<double>(b);
  return a / rhs;
}
|
||||
|
||||
/// Arithmetic with ints
|
||||
|
||||
/// fp8 + int: the int is first converted to fp8, then the fp8 sum is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) {
  const auto rhs = static_cast<Float8_e5m2>(b);
  return a + rhs;
}
|
||||
/// fp8 - int: the int is first converted to fp8, then the fp8 difference is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) {
  const auto rhs = static_cast<Float8_e5m2>(b);
  return a - rhs;
}
|
||||
/// fp8 * int: the int is first converted to fp8, then the fp8 product is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) {
  const auto rhs = static_cast<Float8_e5m2>(b);
  return a * rhs;
}
|
||||
/// fp8 / int: the int is first converted to fp8, then the fp8 quotient is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) {
  const auto rhs = static_cast<Float8_e5m2>(b);
  return a / rhs;
}
|
||||
|
||||
/// int + fp8: the int is first converted to fp8, then the fp8 sum is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) {
  const auto lhs = static_cast<Float8_e5m2>(a);
  return lhs + b;
}
|
||||
/// int - fp8: the int is first converted to fp8, then the fp8 difference is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) {
  const auto lhs = static_cast<Float8_e5m2>(a);
  return lhs - b;
}
|
||||
/// int * fp8: the int is first converted to fp8, then the fp8 product is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) {
  const auto lhs = static_cast<Float8_e5m2>(a);
  return lhs * b;
}
|
||||
/// int / fp8: the int is first converted to fp8, then the fp8 quotient is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) {
  const auto lhs = static_cast<Float8_e5m2>(a);
  return lhs / b;
}
|
||||
|
||||
/// Arithmetic with int64_t
|
||||
|
||||
/// fp8 + int64_t: the integer is first converted to fp8, then the fp8 sum is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) {
  const auto rhs = static_cast<Float8_e5m2>(b);
  return a + rhs;
}
|
||||
/// fp8 - int64_t: the integer is first converted to fp8, then the fp8 difference is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) {
  const auto rhs = static_cast<Float8_e5m2>(b);
  return a - rhs;
}
|
||||
/// fp8 * int64_t: the integer is first converted to fp8, then the fp8 product is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) {
  const auto rhs = static_cast<Float8_e5m2>(b);
  return a * rhs;
}
|
||||
/// fp8 / int64_t: the integer is first converted to fp8, then the fp8 quotient is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) {
  const auto rhs = static_cast<Float8_e5m2>(b);
  return a / rhs;
}
|
||||
|
||||
/// int64_t + fp8: the integer is first converted to fp8, then the fp8 sum is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) {
  const auto lhs = static_cast<Float8_e5m2>(a);
  return lhs + b;
}
|
||||
/// int64_t - fp8: the integer is first converted to fp8, then the fp8 difference is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) {
  const auto lhs = static_cast<Float8_e5m2>(a);
  return lhs - b;
}
|
||||
/// int64_t * fp8: the integer is first converted to fp8, then the fp8 product is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) {
  const auto lhs = static_cast<Float8_e5m2>(a);
  return lhs * b;
}
|
||||
/// int64_t / fp8: the integer is first converted to fp8, then the fp8 quotient is taken.
inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) {
  const auto lhs = static_cast<Float8_e5m2>(a);
  return lhs / b;
}
|
||||
|
||||
/// NOTE: we do not define comparisons directly and instead rely on the implicit
|
||||
/// conversion from c10::Float8_e5m2 to float.
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::Float8_e5m2> {
|
||||
public:
|
||||
static constexpr bool is_signed = true;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = true;
|
||||
static constexpr bool has_quiet_NaN = false;
|
||||
static constexpr bool has_signaling_NaN = false;
|
||||
static constexpr auto has_denorm = true;
|
||||
static constexpr auto has_denorm_loss = true;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = false;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 3;
|
||||
static constexpr int digits10 = 0;
|
||||
static constexpr int max_digits10 = 2;
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -13;
|
||||
static constexpr int min_exponent10 = -4;
|
||||
static constexpr int max_exponent = 16;
|
||||
static constexpr int max_exponent10 = 4;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before =
|
||||
numeric_limits<float>::tinyness_before;
|
||||
|
||||
static constexpr c10::Float8_e5m2 min() {
|
||||
return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 max() {
|
||||
return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 lowest() {
|
||||
return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 epsilon() {
|
||||
return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 round_error() {
|
||||
return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 infinity() {
|
||||
return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 denorm_min() {
|
||||
return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
14
c10/util/Float8_e5m2.cpp
Normal file
14
c10/util/Float8_e5m2.cpp
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
#include <c10/util/Float8_e5m2.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
static_assert(
|
||||
std::is_standard_layout<Float8_e5m2>::value,
|
||||
"c10::Float8_e5m2 must be standard layout.");
|
||||
|
||||
// Streams the value by printing its float conversion.
std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value) {
  return out << static_cast<float>(value);
}
|
||||
} // namespace c10
|
||||
143
c10/util/Float8_e5m2.h
Normal file
143
c10/util/Float8_e5m2.h
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
#pragma once
|
||||
|
||||
/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions
|
||||
/// to standard C types and basic arithmetic operations. Note that arithmetic
|
||||
/// operations are implemented by converting to floating point and
|
||||
/// performing the operation in float32.
|
||||
/// Binary configuration:
|
||||
/// s eeeee mm
|
||||
/// 1 sign bit
|
||||
/// 5 exponent bits
|
||||
/// 2 mantissa bits
|
||||
/// bias = 15
|
||||
///
|
||||
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
|
||||
/// and inspired by Half implementation from pytorch/c10/util/Half.h
|
||||
|
||||
#include <c10/util/Half.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/*
|
||||
* Convert a 8-bit floating-point number in fp8 E5M2 format, in bit
|
||||
* representation, to a 32-bit floating-point number in IEEE single-precision
|
||||
* format, in bit representation.
|
||||
*
|
||||
* @note The implementation doesn't use any floating-point operations.
|
||||
*/
|
||||
inline C10_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) {
|
||||
/*
|
||||
* Extend the fp8 E5M2 number to 32 bits and shift to the
|
||||
* upper part of the 32-bit word:
|
||||
* +---+----+---+-----------------------------+
|
||||
* | S |EEEEE|MM|0000 0000 0000 0000 0000 0000|
|
||||
* +---+----+---+-----------------------------+
|
||||
* Bits 31 26-30 24-25 0-23
|
||||
*
|
||||
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
|
||||
* - zero bits.
|
||||
*/
|
||||
uint16_t half_representation = input;
|
||||
half_representation <<= 8;
|
||||
return fp16_ieee_to_fp32_value(half_representation);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert a 32-bit floating-point number in IEEE single-precision format to a
|
||||
* 8-bit floating-point number in fp8 E5M2 format, in bit representation.
|
||||
*/
|
||||
inline C10_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) {
|
||||
/*
|
||||
* Binary representation of fp32 infinity
|
||||
* 0 11111111 00000000000000000000000
|
||||
*/
|
||||
constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
|
||||
|
||||
/*
|
||||
* Binary representation of 65536.0f, which is the first value
|
||||
* not representable in fp8e5m2 range:
|
||||
* 0 11111 00 - fp8e5m2
|
||||
* 0 10001111 00000000000000000000000 - fp32
|
||||
*/
|
||||
constexpr uint32_t fp8_max = UINT32_C(143) << 23;
|
||||
|
||||
/*
|
||||
* A mask for converting fp32 numbers lower than fp8e5m2 normal range
|
||||
* into denorm representation
|
||||
* magic number: ((127 - 15) + (23 - 2) + 1)
|
||||
*/
|
||||
constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
|
||||
|
||||
uint32_t f_bits = fp32_to_bits(f);
|
||||
uint8_t result = 0u;
|
||||
|
||||
/*
|
||||
* Extract the sign of the input number into the high bit of the 32-bit word:
|
||||
*
|
||||
* +---+----------------------------------+
|
||||
* | S |0000000 00000000 00000000 00000000|
|
||||
* +---+----------------------------------+
|
||||
* Bits 31 0-31
|
||||
*/
|
||||
const uint32_t sign = f_bits & UINT32_C(0x80000000);
|
||||
|
||||
/*
|
||||
* Set sign bit to 0
|
||||
*/
|
||||
f_bits ^= sign;
|
||||
|
||||
if (f_bits >= fp8_max) {
|
||||
// NaN - all exponent and mantissa bits set to 1
|
||||
result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
|
||||
} else {
|
||||
if (f_bits < (UINT32_C(113) << 23)) {
|
||||
// Input number is smaller than 2^(-14), which is the smallest
|
||||
// fp8e5m2 normal number
|
||||
f_bits =
|
||||
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
||||
result = static_cast<uint8_t>(f_bits - denorm_mask);
|
||||
} else {
|
||||
// resulting mantissa is odd
|
||||
uint32_t mant_odd = (f_bits >> 21) & 1;
|
||||
|
||||
// update exponent, rounding bias part 1
|
||||
f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
|
||||
|
||||
// rounding bias part 2
|
||||
f_bits += mant_odd;
|
||||
|
||||
// take the bits!
|
||||
result = static_cast<uint8_t>(f_bits >> 21);
|
||||
}
|
||||
}
|
||||
|
||||
result |= static_cast<uint8_t>(sign >> 24);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
struct alignas(1) Float8_e5m2 {
|
||||
uint8_t x;
|
||||
|
||||
struct from_bits_t {};
|
||||
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
|
||||
Float8_e5m2() = default;
|
||||
|
||||
constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits){};
|
||||
inline C10_HOST_DEVICE Float8_e5m2(float value);
|
||||
inline C10_HOST_DEVICE operator float() const;
|
||||
inline C10_HOST_DEVICE bool isnan() const;
|
||||
inline C10_HOST_DEVICE bool isinf() const;
|
||||
};
|
||||
|
||||
C10_API std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value);
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#include <c10/util/Float8_e5m2-inl.h> // IWYU pragma: keep
|
||||
|
|
@ -13,6 +13,7 @@
|
|||
#include <c10/util/C++17.h>
|
||||
#include <c10/util/TypeSafeSignMath.h>
|
||||
#include <c10/util/complex.h>
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
#include <type_traits>
|
||||
|
||||
#if defined(__cplusplus) && (__cplusplus >= 201103L)
|
||||
|
|
@ -51,51 +52,12 @@
|
|||
#include <sycl/sycl.hpp> // for SYCL 2020
|
||||
#endif
|
||||
|
||||
// Standard check for compiling CUDA with clang
|
||||
#if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
|
||||
#define C10_DEVICE_HOST_FUNCTION __device__ __host__
|
||||
#else
|
||||
#define C10_DEVICE_HOST_FUNCTION
|
||||
#endif
|
||||
|
||||
#include <typeinfo> // operator typeid
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
|
||||
C10_DEVICE_HOST_FUNCTION inline float fp32_from_bits(uint32_t w) {
|
||||
#if defined(__OPENCL_VERSION__)
|
||||
return as_float(w);
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
return __uint_as_float((unsigned int)w);
|
||||
#elif defined(__INTEL_COMPILER)
|
||||
return _castu32_f32(w);
|
||||
#else
|
||||
union {
|
||||
uint32_t as_bits;
|
||||
float as_value;
|
||||
} fp32 = {w};
|
||||
return fp32.as_value;
|
||||
#endif
|
||||
}
|
||||
|
||||
C10_DEVICE_HOST_FUNCTION inline uint32_t fp32_to_bits(float f) {
|
||||
#if defined(__OPENCL_VERSION__)
|
||||
return as_uint(f);
|
||||
#elif defined(__CUDA_ARCH__)
|
||||
return (uint32_t)__float_as_uint(f);
|
||||
#elif defined(__INTEL_COMPILER)
|
||||
return _castf32_u32(f);
|
||||
#else
|
||||
union {
|
||||
float as_value;
|
||||
uint32_t as_bits;
|
||||
} fp32 = {f};
|
||||
return fp32.as_bits;
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert a 16-bit floating-point number in IEEE half-precision format, in bit
|
||||
* representation, to a 32-bit floating-point number in IEEE single-precision
|
||||
|
|
@ -201,7 +163,7 @@ inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
|
|||
* mode and no operations on denormals) floating-point operations and bitcasts
|
||||
* between integer and floating-point variables.
|
||||
*/
|
||||
inline float fp16_ieee_to_fp32_value(uint16_t h) {
|
||||
C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
|
||||
/*
|
||||
* Extend the half-precision floating-point number to 32 bits and shift to the
|
||||
* upper part of the 32-bit word:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
#pragma once
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <c10/util/Float8_e5m2.h>
|
||||
#include <c10/util/Half.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
|
@ -77,6 +79,26 @@ struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::BFloat16> {
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Float8_e5m2> {
|
||||
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
|
||||
c10::Half>
|
||||
apply(c10::Float8_e5m2 src) {
|
||||
return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_cast_with_inter_type<
|
||||
c10::complex<c10::Half>,
|
||||
c10::Float8_e4m3fn> {
|
||||
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
|
||||
c10::Half>
|
||||
apply(c10::Float8_e4m3fn src) {
|
||||
return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Half> {
|
||||
C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
|
||||
|
|
|
|||
39
c10/util/floating_point_utils.h
Normal file
39
c10/util/floating_point_utils.h
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>

#include <cstdint>
#include <cstring>
|
||||
|
||||
namespace c10::detail {
|
||||
|
||||
/// Reinterprets the 32 bits of `w` as a float without any numeric conversion.
///
/// @param w bit pattern to reinterpret
/// @return the float whose object representation equals `w`
C10_HOST_DEVICE inline float fp32_from_bits(uint32_t w) {
#if defined(__OPENCL_VERSION__)
  return as_float(w);
#elif defined(__CUDA_ARCH__)
  return __uint_as_float(static_cast<unsigned int>(w));
#elif defined(__INTEL_COMPILER)
  return _castu32_f32(w);
#else
  // Copy the object representation. Reading a union member other than the one
  // last written is undefined behavior in C++, whereas memcpy is well defined.
  static_assert(
      sizeof(float) == sizeof(uint32_t),
      "fp32_from_bits requires a 32-bit float");
  float value;
  std::memcpy(&value, &w, sizeof(value));
  return value;
#endif
}
|
||||
|
||||
/// Returns the 32-bit object representation of `f` without any numeric
/// conversion.
///
/// @param f value whose bits are requested
/// @return the bit pattern of `f` as an unsigned 32-bit integer
C10_HOST_DEVICE inline uint32_t fp32_to_bits(float f) {
#if defined(__OPENCL_VERSION__)
  return as_uint(f);
#elif defined(__CUDA_ARCH__)
  return static_cast<uint32_t>(__float_as_uint(f));
#elif defined(__INTEL_COMPILER)
  return _castf32_u32(f);
#else
  // Copy the object representation. Reading a union member other than the one
  // last written is undefined behavior in C++, whereas memcpy is well defined.
  static_assert(
      sizeof(float) == sizeof(uint32_t),
      "fp32_to_bits requires a 32-bit float");
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return bits;
#endif
}
|
||||
|
||||
} // namespace c10::detail
|
||||
154
test/quantization/core/experimental/test_float8.py
Normal file
154
test/quantization/core/experimental/test_float8.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
# Owner(s): ["oncall: quantization"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
parametrize,
|
||||
instantiate_parametrized_tests,
|
||||
run_tests,
|
||||
)
|
||||
|
||||
# Masks for float8 simulation
|
||||
|
||||
# 0 11111111 11000000000000000000000b
|
||||
MASK_152 = torch.tensor(2145386496, dtype=torch.int)
|
||||
# 0 11111111 11100000000000000000000b
|
||||
MASK_143 = torch.tensor(2146435072, dtype=torch.int)
|
||||
MASK = {
|
||||
torch.float8_e5m2: MASK_152,
|
||||
torch.float8_e4m3fn: MASK_143,
|
||||
}
|
||||
|
||||
# 0 00000000 00011111111111111111111b
|
||||
MASK_ROUND_152 = torch.tensor(1048575, dtype=torch.int)
|
||||
# 0 00000000 00001111111111111111111b
|
||||
MASK_ROUND_143 = torch.tensor(524287, dtype=torch.int)
|
||||
MASK_ROUND = {
|
||||
torch.float8_e5m2: MASK_ROUND_152,
|
||||
torch.float8_e4m3fn: MASK_ROUND_143,
|
||||
}
|
||||
|
||||
FP8_MAX_152 = torch.tensor(57344, dtype=torch.float)
|
||||
FP8_MAX_143 = torch.tensor(448, dtype=torch.float)
|
||||
FP8_MAX = {torch.float8_e5m2: FP8_MAX_152, torch.float8_e4m3fn: FP8_MAX_143}
|
||||
|
||||
SPECIAL_NUMBERS = {
|
||||
torch.float8_e5m2: [
|
||||
("01111100", float("inf"), "inf"),
|
||||
("11111100", -1.0 * float("inf"), "neg_inf"),
|
||||
("01111101", float("nan"), "nan"),
|
||||
("11111101", float("nan"), "nan"),
|
||||
("01111110", float("nan"), "nan"),
|
||||
("11111110", float("nan"), "nan"),
|
||||
("01111111", float("nan"), "nan"),
|
||||
("11111111", float("nan"), "nan"),
|
||||
("00000000", 0.0, "zero"),
|
||||
("10000000", -0.0, "neg_zero"),
|
||||
("01111011", 57344.0, "max_normal"),
|
||||
("11111011", -57344.0, "neg_max_normal"),
|
||||
("00000100", 2**-14, "min_normal"),
|
||||
("10000100", -1 * (2**-14), "neg_min_normal"),
|
||||
("00000011", 0.75 * (2**-14), "max_subnorm"),
|
||||
("10000011", -0.75 * (2**-14), "neg_max_subnorm"),
|
||||
("00000001", 2**-16, "min_subnorm"),
|
||||
("10000001", -1 * (2**-16), "neg_min_subnorm"),
|
||||
],
|
||||
torch.float8_e4m3fn: [
|
||||
("01111111", float("nan"), "nan"),
|
||||
("11111111", float("nan"), "nan"),
|
||||
("00000000", 0.0, "zero"),
|
||||
("10000000", -0.0, "neg_zero"),
|
||||
("01111110", 448.0, "max_normal"),
|
||||
("11111110", -448.0, "neg_max_normal"),
|
||||
("00001000", 2**-6, "min_normal"),
|
||||
("10001000", -1 * (2**-6), "neg_min_normal"),
|
||||
("00000111", 0.875 * (2**-6), "max_subnorm"),
|
||||
("10000111", -0.875 * (2**-6), "neg_max_subnorm"),
|
||||
("00000001", 2**-9, "min_subnorm"),
|
||||
("10000001", -1 * (2**-9), "neg_min_subnorm"),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def simulateFp8Precision(input, variant):
|
||||
dtype = torch.float
|
||||
int_type = torch.int
|
||||
mask = MASK[variant]
|
||||
mask_round = MASK_ROUND[variant]
|
||||
excessive_bits = torch.tensor(21, dtype=int_type)
|
||||
|
||||
signs = torch.where(input < 0.0, -1.0, 1.0).to(dtype)
|
||||
asInt = torch.bitwise_and(input.view(int_type), 2147483647)
|
||||
|
||||
mant_odd = torch.bitwise_and(
|
||||
torch.bitwise_right_shift(asInt, excessive_bits),
|
||||
torch.tensor(1, dtype=int_type),
|
||||
)
|
||||
asInt_masked = asInt + mask_round
|
||||
asInt_odded = asInt_masked + mant_odd
|
||||
masked = torch.bitwise_and(asInt_odded, mask)
|
||||
return masked.view(dtype) * signs
|
||||
|
||||
|
||||
class TestFloat8Dtype(TestCase):
    """Creation, conversion, multiplication and special-value tests for the
    float8 dtypes (float8_e5m2 and float8_e4m3fn)."""

    @parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
    def test_creation_with_zeros(self, dtype):
        """Sanity test for zeros comparison."""
        x = torch.zeros(8, dtype=torch.float)
        x8 = torch.zeros(8, dtype=dtype)
        self.assertEqual(x, x8.float())

    @parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
    def test_cast_to_float8(self, dtype):
        """Numerical test of float8 conversion."""
        # Random magnitudes below the format's maximum, with both signs.
        x = torch.rand((100, 100)) * FP8_MAX[dtype]
        x = torch.cat((x, -x))
        x8 = x.to(dtype)
        x8_simulated = simulateFp8Precision(x, dtype)
        self.assertEqual(x8_simulated, x8.float())

    @parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
    def test_mul(self, dtype):
        """Test of mul implementation."""
        shape = (10, 10)
        a = torch.randn(shape)
        a8_simulated = simulateFp8Precision(a, dtype)
        a8 = a.to(dtype)
        b = torch.randn(shape)
        b8_simulated = simulateFp8Precision(b, dtype)
        b8 = b.to(dtype)
        mul8 = a8 * b8
        mul8_simulated = (a8_simulated * b8_simulated).to(dtype)
        self.assertEqual(mul8, mul8_simulated)

    @parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
    def test_special_numbers(self, dtype):
        """Test special numbers."""

        def compare_binary_with_decimal(binary, decimal, number_name, dtype):
            # Reinterpret the raw byte as the float8 dtype under test.
            bits_int = int(binary, 2)
            tensor_int = torch.tensor([bits_int], dtype=torch.uint8)
            tensor_fp8 = tensor_int.view(dtype)
            if number_name == "nan":
                # NaN never compares equal, so check it explicitly. A unittest
                # assertion is used (not a bare assert) so the check survives
                # `python -O` and reports the offending bit pattern.
                self.assertTrue(
                    tensor_fp8.isnan(),
                    f"bit pattern {binary} of {dtype} is expected to be NaN",
                )
            else:
                tensor_fp32 = tensor_fp8.float()
                ref_tensor_fp32 = torch.tensor([decimal], dtype=torch.float)
                self.assertEqual(tensor_fp32, ref_tensor_fp32)

        for number in SPECIAL_NUMBERS[dtype]:
            compare_binary_with_decimal(*number, dtype)
||||
|
||||
instantiate_parametrized_tests(TestFloat8Dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
@ -4,6 +4,8 @@
|
|||
|
||||
{ include: [ "<c10/util/BFloat16-inl.h>", private, "<c10/util/BFloat16.h>", public ] },
|
||||
{ include: [ "<c10/util/Half-inl.h>", private, "<c10/util/Half.h>", public ] },
|
||||
{ include: [ "<c10/util/Float8_e5m2-inl.h>", private, "<c10/util/Float8_e5m2.h>", public ] },
|
||||
{ include: [ "<c10/util/Float8_e4m3fn-inl.h>", private, "<c10/util/Float8_e4m3fn.h>", public ] },
|
||||
|
||||
{ include: [ "<c10/util/complex_math.h>", private, "<c10/util/complex.h>", public ] },
|
||||
{ include: [ "<c10/util/complex_utils.h>", private, "<c10/util/complex.h>", public ] },
|
||||
|
|
|
|||
|
|
@ -330,7 +330,12 @@ def _tensor_str(self, indent):
|
|||
if self.is_neg():
|
||||
self = self.resolve_neg()
|
||||
|
||||
if self.dtype is torch.float16 or self.dtype is torch.bfloat16:
|
||||
if self.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
]:
|
||||
self = self.float()
|
||||
|
||||
if self.dtype is torch.complex32:
|
||||
|
|
|
|||
|
|
@ -213,15 +213,18 @@ static PyObject* THPStorage_fromBuffer(
|
|||
auto dtype = reinterpret_cast<THPDtype*>(dtype_obj);
|
||||
scalar_type = dtype->scalar_type;
|
||||
|
||||
const bool is_endian_independent = (scalar_type == at::kByte) ||
|
||||
(scalar_type == at::kChar) || (scalar_type == at::kFloat8_e5m2) ||
|
||||
(scalar_type == at::kFloat8_e4m3fn);
|
||||
|
||||
TORCH_CHECK(
|
||||
(scalar_type == at::kByte) || (scalar_type == at::kChar) ||
|
||||
(byte_order_str != nullptr),
|
||||
is_endian_independent || (byte_order_str != nullptr),
|
||||
"function missing required argument 'byte_order' (pos 2)");
|
||||
size_t element_size = c10::elementSize(scalar_type);
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
bool do_byte_swap;
|
||||
if (scalar_type != at::kByte && scalar_type != at::kChar) {
|
||||
if (!is_endian_independent) {
|
||||
if (strcmp(byte_order_str, "native") == 0) {
|
||||
do_byte_swap = false;
|
||||
} else if (strcmp(byte_order_str, "big") == 0) {
|
||||
|
|
@ -292,7 +295,7 @@ static PyObject* THPStorage_fromBuffer(
|
|||
c10::GetDefaultCPUAllocator(),
|
||||
/*resizable=*/true);
|
||||
|
||||
if (scalar_type == at::kByte || scalar_type == at::kChar) {
|
||||
if (is_endian_independent) {
|
||||
memcpy(storage->mutable_data(), src + offset, count);
|
||||
} else if (scalar_type == at::kBool) {
|
||||
// Because of ASAN checks, that are failing whenever
|
||||
|
|
|
|||
|
|
@ -14,7 +14,13 @@ Dtype Dtype::scalar_dtype() const {
|
|||
// NOLINTNEXTLINE
|
||||
#define DTYPE_DEFINE(_1, n) TORCH_API Dtype k##n(ScalarType::n, 1);
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DTYPE_DEFINE)
|
||||
AT_FORALL_SCALAR_TYPES_AND5(
|
||||
Bool,
|
||||
Half,
|
||||
BFloat16,
|
||||
Float8_e5m2,
|
||||
Float8_e4m3fn,
|
||||
DTYPE_DEFINE)
|
||||
DTYPE_DEFINE(c10::quint8, QUInt8);
|
||||
DTYPE_DEFINE(c10::qint8, QInt8);
|
||||
|
||||
|
|
@ -28,7 +34,8 @@ Dtype ToDtype(ScalarType type) {
|
|||
#define TYPE_CASE(_1, n) \
|
||||
case ScalarType::n: \
|
||||
return k##n;
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
|
||||
AT_FORALL_SCALAR_TYPES_AND5(
|
||||
Bool, Half, BFloat16, Float8_e5m2, Float8_e4m3fn, TYPE_CASE)
|
||||
TYPE_CASE(c10::quint8, QUInt8);
|
||||
TYPE_CASE(c10::qint8, QInt8);
|
||||
#undef TYPE_CASE
|
||||
|
|
@ -58,7 +65,8 @@ int Dtype::byte_size() const {
|
|||
scalar_size = sizeof(Type); \
|
||||
break;
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
|
||||
AT_FORALL_SCALAR_TYPES_AND5(
|
||||
Bool, Half, BFloat16, Float8_e5m2, Float8_e4m3fn, TYPE_CASE);
|
||||
TYPE_CASE(c10::quint8, QUInt8);
|
||||
TYPE_CASE(c10::qint8, QInt8);
|
||||
#undef TYPE_CASE
|
||||
|
|
@ -83,6 +91,10 @@ std::string Dtype::ToCppString() const {
|
|||
return "half";
|
||||
case ScalarType::BFloat16:
|
||||
return "bfloat16";
|
||||
case ScalarType::Float8_e5m2:
|
||||
return "float8_e5m2";
|
||||
case ScalarType::Float8_e4m3fn:
|
||||
return "float8_e4m3fn";
|
||||
case ScalarType::QInt8:
|
||||
return "qint8";
|
||||
case ScalarType::QUInt8:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <c10/util/Float8_e5m2.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <cstddef>
|
||||
|
|
@ -154,6 +156,14 @@ TORCH_API void THP_decodeBFloat16Buffer(
|
|||
const uint8_t* src,
|
||||
THPByteOrder order,
|
||||
size_t len);
|
||||
TORCH_API void THP_decodeFloat8_e5m2Buffer(
|
||||
at::Float8_e5m2* dst,
|
||||
const uint8_t* src,
|
||||
size_t len);
|
||||
TORCH_API void THP_decodeFloat8_e4m3fnBuffer(
|
||||
at::Float8_e4m3fn* dst,
|
||||
const uint8_t* src,
|
||||
size_t len);
|
||||
TORCH_API void THP_decodeComplexFloatBuffer(
|
||||
c10::complex<float>* dst,
|
||||
const uint8_t* src,
|
||||
|
|
|
|||
|
|
@ -70,6 +70,14 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
|
|||
*(at::BFloat16*)data =
|
||||
at::convert<at::BFloat16, double>(THPUtils_unpackDouble(obj));
|
||||
break;
|
||||
case at::kFloat8_e5m2:
|
||||
*(at::Float8_e5m2*)data =
|
||||
at::convert<at::Float8_e5m2, double>(THPUtils_unpackDouble(obj));
|
||||
break;
|
||||
case at::kFloat8_e4m3fn:
|
||||
*(at::Float8_e4m3fn*)data =
|
||||
at::convert<at::Float8_e4m3fn, double>(THPUtils_unpackDouble(obj));
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("invalid type");
|
||||
}
|
||||
|
|
@ -110,6 +118,12 @@ inline PyObject* load_scalar(void* data, at::ScalarType scalarType) {
|
|||
case at::kBFloat16:
|
||||
return PyFloat_FromDouble(
|
||||
at::convert<double, at::BFloat16>(*(at::BFloat16*)data));
|
||||
case at::kFloat8_e5m2:
|
||||
return PyFloat_FromDouble(
|
||||
at::convert<double, at::Float8_e5m2>(*(at::Float8_e5m2*)data));
|
||||
case at::kFloat8_e4m3fn:
|
||||
return PyFloat_FromDouble(
|
||||
at::convert<double, at::Float8_e4m3fn>(*(at::Float8_e4m3fn*)data));
|
||||
default:
|
||||
throw std::runtime_error("invalid type");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,6 +62,10 @@ std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
|
|||
return std::make_pair("bits8", "");
|
||||
case at::ScalarType::Bits16:
|
||||
return std::make_pair("bits16", "");
|
||||
case at::ScalarType::Float8_e5m2:
|
||||
return std::make_pair("float8_e5m2", "");
|
||||
case at::ScalarType::Float8_e4m3fn:
|
||||
return std::make_pair("float8_e4m3fn", "");
|
||||
default:
|
||||
throw std::runtime_error("Unimplemented scalar type");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ complexHalfT = BaseCppType(
|
|||
complexFloatT = BaseCppType("c10", "complex<float>")
|
||||
complexDoubleT = BaseCppType("c10", "complex<double>")
|
||||
bfloat16T = BaseCppType("at", "BFloat16")
|
||||
float8_e5m2T = BaseCppType("at", "Float8_e5m2")
|
||||
float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
|
||||
stringT = BaseCppType("c10", "string_view")
|
||||
generatorT = BaseCppType("at", "Generator")
|
||||
scalarTypeT = BaseCppType("at", "ScalarType")
|
||||
|
|
@ -93,7 +95,8 @@ ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = {
|
|||
ScalarType.ComplexFloat: complexFloatT,
|
||||
ScalarType.ComplexDouble: complexDoubleT,
|
||||
ScalarType.Bool: boolT,
|
||||
ScalarType.BFloat16: bfloat16T,
|
||||
ScalarType.Float8_e5m2: float8_e5m2T,
|
||||
ScalarType.Float8_e4m3fn: float8_e4m3fnT,
|
||||
}
|
||||
|
||||
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
|
||||
|
|
|
|||
|
|
@ -298,6 +298,8 @@ class ScalarType(Enum):
|
|||
ComplexDouble = auto()
|
||||
Bool = auto()
|
||||
BFloat16 = auto()
|
||||
Float8_e5m2 = auto()
|
||||
Float8_e4m3fn = auto()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user