additional support for float8_e4m3fnuz and _e5m2fnuz (#115214)

Follow-up to #107586.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115214
Approved by: https://github.com/peterbell10, https://github.com/malfet
Jeff Daily 2024-01-22 18:33:41 +00:00 committed by PyTorch MergeBot
parent 56ef5afdee
commit 01abb5af21
43 changed files with 708 additions and 625 deletions

View File

@@ -4,7 +4,9 @@
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
// Defines the accumulation type for a scalar type.
@@ -87,6 +89,8 @@ MPS_ACC_TYPE(BFloat16, float);
MPS_ACC_TYPE(Half, float);
MPS_ACC_TYPE(Float8_e5m2, float);
MPS_ACC_TYPE(Float8_e4m3fn, float);
MPS_ACC_TYPE(Float8_e5m2fnuz, float);
MPS_ACC_TYPE(Float8_e4m3fnuz, float);
MPS_ACC_TYPE(float, float);
MPS_ACC_TYPE(double, float);
MPS_ACC_TYPE(int8_t, int64_t);
@@ -107,6 +111,8 @@ CUDA_ACC_TYPE(BFloat16, float);
CUDA_ACC_TYPE(Half, float);
CUDA_ACC_TYPE(Float8_e5m2, float);
CUDA_ACC_TYPE(Float8_e4m3fn, float);
CUDA_ACC_TYPE(Float8_e5m2fnuz, float);
CUDA_ACC_TYPE(Float8_e4m3fnuz, float);
CUDA_ACC_TYPE(float, float);
CUDA_ACC_TYPE(double, double);
CUDA_ACC_TYPE(int8_t, int64_t);
@@ -123,8 +129,8 @@ CUDA_ACC_TYPE(c10::complex<double>, c10::complex<double>);
CPU_ACC_TYPE(BFloat16, float);
CPU_ACC_TYPE(Half, float);
CPU_ACC_TYPE(Float8_e5m2, float);
CPU_ACC_TYPE(Float8_e5m2fnuz, float);
CPU_ACC_TYPE(Float8_e4m3fn, float);
CPU_ACC_TYPE(Float8_e5m2fnuz, float);
CPU_ACC_TYPE(Float8_e4m3fnuz, float);
CPU_ACC_TYPE(float, double);
CPU_ACC_TYPE(double, double);
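For context, these ACC_TYPE tables feed the at::acc_type<T, is_cuda> alias, so both fnuz formats accumulate in float on every backend, matching the existing float8 entries. A minimal compile-time sketch of what the new rows imply:

#include <ATen/AccumulateType.h>
#include <type_traits>

// Sketch: the new entries resolve the fnuz dtypes to float for CPU and
// CUDA (and MPS) accumulation alike.
static_assert(
    std::is_same_v<at::acc_type<c10::Float8_e4m3fnuz, /*is_cuda=*/true>, float>);
static_assert(
    std::is_same_v<at::acc_type<c10::Float8_e5m2fnuz, /*is_cuda=*/false>, float>);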

View File

@@ -7,7 +7,9 @@
#include <c10/macros/Macros.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>
@@ -80,6 +82,22 @@ inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}
template <
typename T,
typename std::enable_if<std::is_same<T, at::Float8_e5m2fnuz>::value, int>::
type = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}
template <
typename T,
typename std::enable_if<std::is_same<T, at::Float8_e4m3fnuz>::value, int>::
type = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}
// std::isinf isn't performant to use on integral types; it will
// (uselessly) convert to floating point and then do the test.
// This function is.
@@ -118,6 +136,14 @@ inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val) {
return false;
}
inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val) {
return false;
}
inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val) {
return false;
}
template <typename T>
C10_HOST_DEVICE inline T exp(T x) {
static_assert(
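The new overloads encode the fnuz semantics: the single bit pattern 0b10000000 is NaN, and the formats have no infinities, so _isinf above is unconditionally false. A small host-side sketch using the from_bits constructor these types provide:

#include <c10/util/Float8_e4m3fnuz.h>
#include <cassert>

int main() {
  // 0x80 is the lone NaN encoding; 0x7F is the largest finite value (240).
  c10::Float8_e4m3fnuz nan_v(0x80, c10::Float8_e4m3fnuz::from_bits());
  c10::Float8_e4m3fnuz max_v(0x7F, c10::Float8_e4m3fnuz::from_bits());
  assert(nan_v.isnan());
  assert(!max_v.isnan()); // and at::_isinf(max_v) is false by definition
}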

View File

@@ -4,7 +4,9 @@
#include <c10/util/BFloat16.h>
#include <c10/util/Exception.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
namespace at {
@@ -31,6 +33,14 @@ struct OpMathType<at::Float8_e4m3fn> {
using type = float;
};
template <>
struct OpMathType<at::Float8_e5m2fnuz> {
using type = float;
};
template <>
struct OpMathType<at::Float8_e4m3fnuz> {
using type = float;
};
template <>
struct OpMathType<c10::complex<Half>> {
using type = c10::complex<float>;
};
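As with the accumulation types above, intermediate op math for the fnuz dtypes runs in float; a one-line compile-time sketch:

#include <ATen/OpMathType.h>
#include <type_traits>

// Sketch: at::opmath_type<T> is float for both fnuz formats.
static_assert(std::is_same_v<at::opmath_type<at::Float8_e4m3fnuz>, float>);
static_assert(std::is_same_v<at::opmath_type<at::Float8_e5m2fnuz>, float>);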

View File

@@ -110,6 +110,8 @@
#include <c10/util/Flags.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/Half.h>
#include <c10/util/IdWrapper.h>

View File

@@ -92,6 +92,13 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
case c10::ScalarType::Float8_e5m2:
return CUDA_R_8F_E5M2;
#endif
#else // USE_ROCM
#if ROCM_VERSION >= 60000
case c10::ScalarType::Float8_e4m3fnuz:
return HIP_R_8F_E4M3_FNUZ;
case c10::ScalarType::Float8_e5m2fnuz:
return HIP_R_8F_E5M2_FNUZ;
#endif
#endif
default:
TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.")
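On ROCm 6.0 and newer the fnuz dtypes translate directly to HIP's native 8-bit float enums (hipify maps cudaDataType onto hipDataType), which is what downstream blas paths consume. A hedged usage sketch; the header path is assumed to be the usual ATen one:

#include <ATen/cuda/CUDADataType.h>

// Sketch: on a USE_ROCM build with ROCM_VERSION >= 60000 this returns
// HIP_R_8F_E4M3_FNUZ instead of hitting the TORCH_INTERNAL_ASSERT above.
inline cudaDataType fnuz_gemm_dtype() {
  return at::cuda::ScalarTypeToCudaDataType(c10::ScalarType::Float8_e4m3fnuz);
}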

View File

@@ -1324,9 +1324,9 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kBFloat16, kHalf, kFloat8_e5m2, kFloat8_e4m3fn, \
#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
kBFloat16, kHalf, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \

View File

@@ -90,10 +90,7 @@ void atan2_kernel(TensorIteratorBase& iter) {
kHalf, \
kBool, \
kBFloat16, \
kFloat8_e5m2, \
kFloat8_e5m2fnuz, \
kFloat8_e4m3fn, \
kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) \
AT_DISPATCH_V2( \
TYPE, \
@@ -102,12 +99,10 @@ void atan2_kernel(TensorIteratorBase& iter) {
kComplexHalf, \
kHalf, \
kBFloat16, \
kFloat8_e5m2, \
kFloat8_e4m3fn, \
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#else
#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
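The explicit kFloat8_* lists collapse into AT_FLOAT8_TYPES, which is assumed here to expand to all four float8 dtype tags. A sketch of a standalone dispatch built on it:

#include <ATen/Dispatch_v2.h>
#include <ATen/TensorIterator.h>

// Sketch: one V2 dispatch covers kFloat8_e5m2, kFloat8_e4m3fn,
// kFloat8_e5m2fnuz and kFloat8_e4m3fnuz via AT_FLOAT8_TYPES.
void dispatch_float8(at::TensorIteratorBase& iter) {
  AT_DISPATCH_V2(iter.common_dtype(), "float8_op", AT_WRAP([&]() {
    // scalar_t is one of the four float8 types in this body;
    // a real kernel (hypothetical do_op<scalar_t>) would go here.
  }), AT_EXPAND(AT_FLOAT8_TYPES));
}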

View File

@@ -268,9 +268,9 @@ void gemm_core_(
}
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \

View File

@@ -179,9 +179,9 @@ static void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
}
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \

View File

@@ -33,7 +33,7 @@ C10_NOINLINE void compare_eq_ne_kernel(TensorIteratorBase &iter, EqOpType op) {
AT_DISPATCH_V2(iter.common_dtype(), "compare_eq_ne_cuda", AT_WRAP([&]() {
opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
iter, CompareEqFunctor<scalar_t>(op));
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, kFloat8_e4m3fn, kFloat8_e5m2, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
void eq_kernel_cuda(TensorIteratorBase& iter) {

View File

@@ -35,7 +35,6 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
ScalarType other_dtype = iter.dtype(1);
if (dtype == kFloat8_e4m3fn) {
switch (other_dtype) {
#if !defined(USE_ROCM)
case kFloat:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
return Float8_e4m3fn(value);
@@ -51,14 +50,12 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
return Float8_e4m3fn(value);
});
break;
#endif /* !defined(USE_ROCM) */
default:
gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fn x) { return x; });
break;
}
} else if (dtype == kFloat8_e5m2) {
switch (other_dtype) {
#if !defined(USE_ROCM)
case kFloat:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
#ifdef AT_USE_NV_CVT_INTRINSICS
@@ -89,11 +86,52 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
#endif
});
break;
#endif /* !defined(USE_ROCM) */
default:
gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2 x) { return x; });
break;
}
} else if (dtype == kFloat8_e4m3fnuz) {
switch (other_dtype) {
case kFloat:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
return Float8_e4m3fnuz(value);
});
break;
case kHalf:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
return Float8_e4m3fnuz(value);
});
break;
case kBFloat16:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
return Float8_e4m3fnuz(value);
});
break;
default:
gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fnuz x) { return x; });
break;
}
} else if (dtype == kFloat8_e5m2fnuz) {
switch (other_dtype) {
case kFloat:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
return Float8_e5m2fnuz(value);
});
break;
case kHalf:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
return Float8_e5m2fnuz(value);
});
break;
case kBFloat16:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
return Float8_e5m2fnuz(value);
});
break;
default:
gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; });
break;
}
} else {
TORCH_CHECK(false, "This is supposed to be called only for Float8 types");
}
@@ -107,16 +145,14 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
} else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn) {
} else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) {
float8_copy_kernel_cuda(iter);
#if !defined(USE_ROCM)
} else if (isBitsType(dtype)) {
TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
"bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype);
AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
#endif /* !defined(USE_ROCM) */
} else {
AT_DISPATCH_V2(
dtype, "copy_", AT_WRAP([&] {
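With the USE_ROCM guards removed and the two fnuz branches added, an ordinary dtype conversion now reaches these kernels on both CUDA and ROCm builds. A usage sketch:

#include <ATen/ATen.h>

// Sketch: a plain .to(...) exercises float8_copy_kernel_cuda's fnuz branches.
at::Tensor fnuz_roundtrip() {
  at::Tensor src = at::rand({8}, at::device(at::kCUDA).dtype(at::kFloat));
  at::Tensor f8 = src.to(at::kFloat8_e4m3fnuz); // float -> e4m3fnuz branch
  return f8.to(at::kFloat);                     // widen back for inspection
}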

View File

@@ -22,7 +22,7 @@ struct FillFunctor {
void fill_kernel_cuda(TensorIterator& iter, const Scalar& value) {
AT_DISPATCH_V2(iter.dtype(), "fill_cuda", AT_WRAP([&]() {
gpu_kernel(iter, FillFunctor<scalar_t>(value.to<scalar_t>()));
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}
REGISTER_DISPATCH(fill_stub, &fill_kernel_cuda);

View File

@@ -298,6 +298,33 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) {}
} // namespace modern
template <typename func_t>
void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
using traits = function_traits<func_t>;
using arg0_t = typename traits::result_type;
constexpr int ntensors = traits::arity + 1;
TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
at::detail::Array<char*, ntensors> data;
for (int i = 0; i < ntensors; i++) {
data[i] = (char*)iter.data_ptr(i);
}
int64_t numel = iter.numel();
auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
legacy::launch_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
*out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
});
}
template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
using traits = function_traits<func_t>;
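gpu_kernel_impl_nocast is the non-vectorized path behind gpu_kernel_nocast: it asserts that no dynamic casting is needed, so the functor's argument and result types must match the iterator's dtypes exactly. That is the contract the float8 copy kernels above rely on, sketched here:

// Sketch (mirrors the copy kernels): with output dtype kFloat8_e5m2fnuz and
// input dtype kFloat, the lambda's types match the iterator's dtypes, so no
// dynamic cast is inserted.
gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
  return Float8_e5m2fnuz(value);
});
// A mismatched signature (e.g. taking double here) would trip the
// needs_dynamic_casting assert rather than silently converting.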

View File

@@ -525,8 +525,15 @@ static inline bool isSignedType(ScalarType t) {
case ScalarType::ComplexFloat:
case ScalarType::ComplexDouble:
return true;
AT_FORALL_SCALAR_TYPES_AND5(
Half, Bool, BFloat16, Float8_e5m2, Float8_e4m3fn, CASE_SIGNED)
AT_FORALL_SCALAR_TYPES_AND7(
Half,
Bool,
BFloat16,
Float8_e5m2,
Float8_e4m3fn,
Float8_e5m2fnuz,
Float8_e4m3fnuz,
CASE_SIGNED)
default:
TORCH_CHECK(false, "Unknown ScalarType");
}

View File

@@ -96,7 +96,7 @@ inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
* mantissa will shift into exponent, turning the biased exponent into 1, and
* making mantissa normalized (i.e. without leading 1).
*/
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
uint32_t renorm_shift = __clz(nonsign);
#elif defined(__SYCL_DEVICE_ONLY__)
// Note: zero is not a supported input into `__builtin_clz`

View File

@@ -1,6 +1,8 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Float8_fnuz_cvt.h>
#include <cstring>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
@@ -12,21 +14,208 @@ namespace c10 {
/// Constructors
C10_HOST_DEVICE inline Float8_e4m3fnuz::Float8_e4m3fnuz(float value)
inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value)
: x(detail::fp8e4m3fnuz_from_fp32_value(value)) {}
/// Implicit conversions
C10_HOST_DEVICE inline Float8_e4m3fnuz::operator float() const {
return detail::fp8e4m3fnuz_to_fp32_value(x);
inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const {
return detail::fp8_fnuz_to_fp32_value<4, 3>(x);
}
/// Special values helper
C10_HOST_DEVICE inline bool Float8_e4m3fnuz::isnan() const {
inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const {
return x == 0b10000000;
}
/// Arithmetic
inline C10_HOST_DEVICE Float8_e4m3fnuz
operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz
operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz
operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(
const Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=(
Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=(
Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=(
Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=(
Float8_e4m3fnuz& a,
const Float8_e4m3fnuz& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) {
return a + static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) {
return a - static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) {
return a * static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) {
return a / static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) + b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) - b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) * b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) {
return a + static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) {
return a - static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) {
return a * static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) {
return a / static_cast<Float8_e4m3fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) + b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) - b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) * b;
}
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) {
return static_cast<Float8_e4m3fnuz>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e4m3fnuz to float.
} // namespace c10
namespace std {
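Because comparisons are left to the implicit float conversion, mixed arithmetic and ordering both run in float32. A small sketch (1.5 and 2.0 are exactly representable in e4m3fnuz):

#include <c10/util/Float8_e4m3fnuz.h>
#include <cassert>

int main() {
  c10::Float8_e4m3fnuz a(1.5f), b(2.0f);
  assert(a < b);       // both operands convert to float before comparing
  float s = a + 0.25f; // mixed arithmetic promotes to float32 as well
  assert(s == 1.75f);
}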

View File

@@ -1,276 +1,8 @@
#include <c10/util/Float8_e4m3fnuz.h>
#include <array>
#include <iostream>
namespace c10 {
namespace detail {
C10_HOST_DEVICE float fp8e4m3fnuz_to_fp32_value(uint8_t input) {
constexpr std::array<float, 256> e4m3fnuz_lut = {
0.0f,
0.0009765625f,
0.001953125f,
0.0029296875f,
0.00390625f,
0.0048828125f,
0.005859375f,
0.0068359375f,
0.0078125f,
0.0087890625f,
0.009765625f,
0.0107421875f,
0.01171875f,
0.0126953125f,
0.013671875f,
0.0146484375f,
0.015625f,
0.017578125f,
0.01953125f,
0.021484375f,
0.0234375f,
0.025390625f,
0.02734375f,
0.029296875f,
0.03125f,
0.03515625f,
0.0390625f,
0.04296875f,
0.046875f,
0.05078125f,
0.0546875f,
0.05859375f,
0.0625f,
0.0703125f,
0.078125f,
0.0859375f,
0.09375f,
0.1015625f,
0.109375f,
0.1171875f,
0.125f,
0.140625f,
0.15625f,
0.171875f,
0.1875f,
0.203125f,
0.21875f,
0.234375f,
0.25f,
0.28125f,
0.3125f,
0.34375f,
0.375f,
0.40625f,
0.4375f,
0.46875f,
0.5f,
0.5625f,
0.625f,
0.6875f,
0.75f,
0.8125f,
0.875f,
0.9375f,
1.0f,
1.125f,
1.25f,
1.375f,
1.5f,
1.625f,
1.75f,
1.875f,
2.0f,
2.25f,
2.5f,
2.75f,
3.0f,
3.25f,
3.5f,
3.75f,
4.0f,
4.5f,
5.0f,
5.5f,
6.0f,
6.5f,
7.0f,
7.5f,
8.0f,
9.0f,
10.0f,
11.0f,
12.0f,
13.0f,
14.0f,
15.0f,
16.0f,
18.0f,
20.0f,
22.0f,
24.0f,
26.0f,
28.0f,
30.0f,
32.0f,
36.0f,
40.0f,
44.0f,
48.0f,
52.0f,
56.0f,
60.0f,
64.0f,
72.0f,
80.0f,
88.0f,
96.0f,
104.0f,
112.0f,
120.0f,
128.0f,
144.0f,
160.0f,
176.0f,
192.0f,
208.0f,
224.0f,
240.0f,
std::numeric_limits<float>::signaling_NaN(),
-0.0009765625f,
-0.001953125f,
-0.0029296875f,
-0.00390625f,
-0.0048828125f,
-0.005859375f,
-0.0068359375f,
-0.0078125f,
-0.0087890625f,
-0.009765625f,
-0.0107421875f,
-0.01171875f,
-0.0126953125f,
-0.013671875f,
-0.0146484375f,
-0.015625f,
-0.017578125f,
-0.01953125f,
-0.021484375f,
-0.0234375f,
-0.025390625f,
-0.02734375f,
-0.029296875f,
-0.03125f,
-0.03515625f,
-0.0390625f,
-0.04296875f,
-0.046875f,
-0.05078125f,
-0.0546875f,
-0.05859375f,
-0.0625f,
-0.0703125f,
-0.078125f,
-0.0859375f,
-0.09375f,
-0.1015625f,
-0.109375f,
-0.1171875f,
-0.125f,
-0.140625f,
-0.15625f,
-0.171875f,
-0.1875f,
-0.203125f,
-0.21875f,
-0.234375f,
-0.25f,
-0.28125f,
-0.3125f,
-0.34375f,
-0.375f,
-0.40625f,
-0.4375f,
-0.46875f,
-0.5f,
-0.5625f,
-0.625f,
-0.6875f,
-0.75f,
-0.8125f,
-0.875f,
-0.9375f,
-1.0f,
-1.125f,
-1.25f,
-1.375f,
-1.5f,
-1.625f,
-1.75f,
-1.875f,
-2.0f,
-2.25f,
-2.5f,
-2.75f,
-3.0f,
-3.25f,
-3.5f,
-3.75f,
-4.0f,
-4.5f,
-5.0f,
-5.5f,
-6.0f,
-6.5f,
-7.0f,
-7.5f,
-8.0f,
-9.0f,
-10.0f,
-11.0f,
-12.0f,
-13.0f,
-14.0f,
-15.0f,
-16.0f,
-18.0f,
-20.0f,
-22.0f,
-24.0f,
-26.0f,
-28.0f,
-30.0f,
-32.0f,
-36.0f,
-40.0f,
-44.0f,
-48.0f,
-52.0f,
-56.0f,
-60.0f,
-64.0f,
-72.0f,
-80.0f,
-88.0f,
-96.0f,
-104.0f,
-112.0f,
-120.0f,
-128.0f,
-144.0f,
-160.0f,
-176.0f,
-192.0f,
-208.0f,
-224.0f,
-240.0f,
};
return e4m3fnuz_lut[input];
}
} // namespace detail
static_assert(
std::is_standard_layout_v<Float8_e4m3fnuz>,
"c10::Float8_e4m3fnuz must be standard layout.");

View File

@@ -4,13 +4,11 @@
/// conversions to standard C types and basic arithmetic operations. Note that
/// arithmetic operations are implemented by converting to floating point and
/// performing the operation in float32.
///
/// Binary configuration remains the same as Float8_e4m3fn:
/// s eeee mmm
/// 1 sign bit
/// 4 exponent bits
/// 3 mantissa bits
///
/// The key differences versus Float8_e4m3fn are:
/// bias = 8
/// no infinities or negative zero
@@ -23,6 +21,7 @@
#include <c10/util/C++17.h>
#include <c10/util/TypeSafeSignMath.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>
#if defined(__cplusplus) && (__cplusplus >= 201103L)
#include <cstdint>
@@ -38,27 +37,11 @@ namespace c10 {
namespace detail {
/*
* Convert a 8-bit floating-point number in fp8 E4M3FNUZ format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, in bit representation.
*
* @note The implementation doesn't use any floating-point operations.
*/
#if defined(__CUDA_ARCH__) || defined(__HIP__)
C10_HOST_DEVICE C10_API inline float fp8e4m3fnuz_to_fp32_value(uint8_t) {
CUDA_KERNEL_ASSERT(false && "e4m3fnuz is not supported by CUDA or HIP");
return -1.0;
}
#else
C10_API float fp8e4m3fnuz_to_fp32_value(uint8_t input);
#endif
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation.
*/
C10_HOST_DEVICE inline uint8_t fp8e4m3fnuz_from_fp32_value(float f) {
inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) {
/*
* Binary representation of 256.0f, which is the first value not representable
* (i.e. the first value which would overflow in to the sign bit, resulting in
@@ -70,7 +53,7 @@ C10_HOST_DEVICE inline uint8_t fp8e4m3fnuz_from_fp32_value(float f) {
/*
* A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range
* into denormalized representation.
* into denorm representation
* magic number: ((127 - 8) + (23 - 3) + 1)
*/
constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23;
@@ -123,7 +106,6 @@ C10_HOST_DEVICE inline uint8_t fp8e4m3fnuz_from_fp32_value(float f) {
}
result |= sign >> 24;
return result;
}
@@ -133,7 +115,7 @@ struct alignas(1) Float8_e4m3fnuz {
uint8_t x;
struct from_bits_t {};
static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
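A worked check of the encoder against the layout (s eeee mmm, bias 8): 0.5f is 2^-1, so the biased exponent is 7 (0b0111) with a zero mantissa, giving bit pattern 0x38; the lookup table this commit deletes had exactly 0.5f at index 0x38. A sketch:

#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_fnuz_cvt.h>
#include <cassert>

int main() {
  // 0.5f = 2^-1 -> sign 0, exponent -1 + 8 = 0b0111, mantissa 000 -> 0x38.
  assert(c10::detail::fp8e4m3fnuz_from_fp32_value(0.5f) == 0x38);
  // Decoding goes through the shared fnuz helper introduced in this commit.
  assert(c10::detail::fp8_fnuz_to_fp32_value<4, 3>(0x38) == 0.5f);
}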

View File

@@ -1,6 +1,8 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/Float8_fnuz_cvt.h>
#include <cstring>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
@@ -12,21 +14,212 @@ namespace c10 {
/// Constructors
C10_HOST_DEVICE inline Float8_e5m2fnuz::Float8_e5m2fnuz(float value)
inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value)
: x(detail::fp8e5m2fnuz_from_fp32_value(value)) {}
/// Implicit conversions
C10_HOST_DEVICE inline Float8_e5m2fnuz::operator float() const {
return detail::fp8e5m2fnuz_to_fp32_value(x);
inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const {
return detail::fp8_fnuz_to_fp32_value<5, 2>(x);
}
/// Special values helpers
C10_HOST_DEVICE inline bool Float8_e5m2fnuz::isnan() const {
inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const {
return x == 0b10000000;
}
inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const {
return false;
}
/// Arithmetic
inline C10_HOST_DEVICE Float8_e5m2fnuz
operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
return static_cast<float>(a) + static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz
operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
return static_cast<float>(a) - static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz
operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
return static_cast<float>(a) * static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(
const Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / static_cast<float>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) {
return -static_cast<float>(a);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=(
Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) {
a = a + b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=(
Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) {
a = a - b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=(
Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) {
a = a * b;
return a;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=(
Float8_e5m2fnuz& a,
const Float8_e5m2fnuz& b) {
a = a / b;
return a;
}
/// Arithmetic with floats
inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) {
return static_cast<float>(a) + b;
}
inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) {
return static_cast<float>(a) - b;
}
inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) {
return static_cast<float>(a) * b;
}
inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<float>(a) / b;
}
inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) {
return a + static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) {
return a - static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) {
return a * static_cast<float>(b);
}
inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) {
return a += static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) {
return a -= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) {
return a *= static_cast<float>(b);
}
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) {
return a /= static_cast<float>(b);
}
/// Arithmetic with doubles
inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) {
return static_cast<double>(a) + b;
}
inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) {
return static_cast<double>(a) - b;
}
inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) {
return static_cast<double>(a) * b;
}
inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b)
__ubsan_ignore_float_divide_by_zero__ {
return static_cast<double>(a) / b;
}
inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) {
return a + static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) {
return a - static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) {
return a * static_cast<double>(b);
}
inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b)
__ubsan_ignore_float_divide_by_zero__ {
return a / static_cast<double>(b);
}
/// Arithmetic with ints
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) {
return a + static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) {
return a - static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) {
return a * static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) {
return a / static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) + b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) - b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) * b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) / b;
}
//// Arithmetic with int64_t
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) {
return a + static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) {
return a - static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) {
return a * static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) {
return a / static_cast<Float8_e5m2fnuz>(b);
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) + b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) - b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) * b;
}
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) {
return static_cast<Float8_e5m2fnuz>(a) / b;
}
/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e5m2fnuz to float.
} // namespace c10
namespace std {

View File

@@ -1,276 +1,8 @@
#include <c10/util/Float8_e5m2fnuz.h>
#include <array>
#include <iostream>
namespace c10 {
namespace detail {
C10_HOST_DEVICE float fp8e5m2fnuz_to_fp32_value(uint8_t input) {
constexpr std::array<float, 256> e5m2fnuz_lut = {
0.0f,
7.62939453125e-06f,
1.52587890625e-05f,
2.288818359375e-05f,
3.0517578125e-05f,
3.814697265625e-05f,
4.57763671875e-05f,
5.340576171875e-05f,
6.103515625e-05f,
7.62939453125e-05f,
9.1552734375e-05f,
0.0001068115234375f,
0.0001220703125f,
0.000152587890625f,
0.00018310546875f,
0.000213623046875f,
0.000244140625f,
0.00030517578125f,
0.0003662109375f,
0.00042724609375f,
0.00048828125f,
0.0006103515625f,
0.000732421875f,
0.0008544921875f,
0.0009765625f,
0.001220703125f,
0.00146484375f,
0.001708984375f,
0.001953125f,
0.00244140625f,
0.0029296875f,
0.00341796875f,
0.00390625f,
0.0048828125f,
0.005859375f,
0.0068359375f,
0.0078125f,
0.009765625f,
0.01171875f,
0.013671875f,
0.015625f,
0.01953125f,
0.0234375f,
0.02734375f,
0.03125f,
0.0390625f,
0.046875f,
0.0546875f,
0.0625f,
0.078125f,
0.09375f,
0.109375f,
0.125f,
0.15625f,
0.1875f,
0.21875f,
0.25f,
0.3125f,
0.375f,
0.4375f,
0.5f,
0.625f,
0.75f,
0.875f,
1.0f,
1.25f,
1.5f,
1.75f,
2.0f,
2.5f,
3.0f,
3.5f,
4.0f,
5.0f,
6.0f,
7.0f,
8.0f,
10.0f,
12.0f,
14.0f,
16.0f,
20.0f,
24.0f,
28.0f,
32.0f,
40.0f,
48.0f,
56.0f,
64.0f,
80.0f,
96.0f,
112.0f,
128.0f,
160.0f,
192.0f,
224.0f,
256.0f,
320.0f,
384.0f,
448.0f,
512.0f,
640.0f,
768.0f,
896.0f,
1024.0f,
1280.0f,
1536.0f,
1792.0f,
2048.0f,
2560.0f,
3072.0f,
3584.0f,
4096.0f,
5120.0f,
6144.0f,
7168.0f,
8192.0f,
10240.0f,
12288.0f,
14336.0f,
16384.0f,
20480.0f,
24576.0f,
28672.0f,
32768.0f,
40960.0f,
49152.0f,
57344.0f,
std::numeric_limits<float>::signaling_NaN(),
-7.62939453125e-06f,
-1.52587890625e-05f,
-2.288818359375e-05f,
-3.0517578125e-05f,
-3.814697265625e-05f,
-4.57763671875e-05f,
-5.340576171875e-05f,
-6.103515625e-05f,
-7.62939453125e-05f,
-9.1552734375e-05f,
-0.0001068115234375f,
-0.0001220703125f,
-0.000152587890625f,
-0.00018310546875f,
-0.000213623046875f,
-0.000244140625f,
-0.00030517578125f,
-0.0003662109375f,
-0.00042724609375f,
-0.00048828125f,
-0.0006103515625f,
-0.000732421875f,
-0.0008544921875f,
-0.0009765625f,
-0.001220703125f,
-0.00146484375f,
-0.001708984375f,
-0.001953125f,
-0.00244140625f,
-0.0029296875f,
-0.00341796875f,
-0.00390625f,
-0.0048828125f,
-0.005859375f,
-0.0068359375f,
-0.0078125f,
-0.009765625f,
-0.01171875f,
-0.013671875f,
-0.015625f,
-0.01953125f,
-0.0234375f,
-0.02734375f,
-0.03125f,
-0.0390625f,
-0.046875f,
-0.0546875f,
-0.0625f,
-0.078125f,
-0.09375f,
-0.109375f,
-0.125f,
-0.15625f,
-0.1875f,
-0.21875f,
-0.25f,
-0.3125f,
-0.375f,
-0.4375f,
-0.5f,
-0.625f,
-0.75f,
-0.875f,
-1.0f,
-1.25f,
-1.5f,
-1.75f,
-2.0f,
-2.5f,
-3.0f,
-3.5f,
-4.0f,
-5.0f,
-6.0f,
-7.0f,
-8.0f,
-10.0f,
-12.0f,
-14.0f,
-16.0f,
-20.0f,
-24.0f,
-28.0f,
-32.0f,
-40.0f,
-48.0f,
-56.0f,
-64.0f,
-80.0f,
-96.0f,
-112.0f,
-128.0f,
-160.0f,
-192.0f,
-224.0f,
-256.0f,
-320.0f,
-384.0f,
-448.0f,
-512.0f,
-640.0f,
-768.0f,
-896.0f,
-1024.0f,
-1280.0f,
-1536.0f,
-1792.0f,
-2048.0f,
-2560.0f,
-3072.0f,
-3584.0f,
-4096.0f,
-5120.0f,
-6144.0f,
-7168.0f,
-8192.0f,
-10240.0f,
-12288.0f,
-14336.0f,
-16384.0f,
-20480.0f,
-24576.0f,
-28672.0f,
-32768.0f,
-40960.0f,
-49152.0f,
-57344.0f,
};
return e5m2fnuz_lut[input];
}
} // namespace detail
static_assert(
std::is_standard_layout_v<Float8_e5m2fnuz>,
"c10::Float8_e5m2fnuz must be standard layout.");

View File

@@ -4,13 +4,11 @@
/// conversions to standard C types and basic arithmetic operations. Note that
/// arithmetic operations are implemented by converting to floating point and
/// performing the operation in float32.
///
/// Binary configuration remains the same as e5m2:
/// s eeeee mm
/// 1 sign bit
/// 5 exponent bits
/// 2 mantissa bits
///
/// The key differences that e5m2fnuz brings are:
/// bias = 16
/// no infinities or negative zero
@@ -38,27 +36,11 @@ namespace c10 {
namespace detail {
/*
* Convert a 8-bit floating-point number in fp8 E5M2FNUZ format, in bit
* representation, to a 32-bit floating-point number in IEEE single-precision
* format, in bit representation.
*
* @note The implementation doesn't use any floating-point operations.
*/
#if defined(__CUDA_ARCH__) || defined(__HIP__)
C10_HOST_DEVICE C10_API inline float fp8e5m2fnuz_to_fp32_value(uint8_t) {
CUDA_KERNEL_ASSERT(false && "e5m2fnuz is not supported by CUDA or HIP");
return -1.0;
}
#else
C10_API float fp8e5m2fnuz_to_fp32_value(uint8_t input);
#endif
/*
* Convert a 32-bit floating-point number in IEEE single-precision format to a
* 8-bit floating-point number in fp8 E5M2 format, in bit representation.
*/
C10_HOST_DEVICE inline uint8_t fp8e5m2fnuz_from_fp32_value(float f) {
inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) {
/*
* Binary representation of 65536.0f, which is the first value not
* representable (i.e. the first value which would overflow in to the sign
@@ -76,7 +58,6 @@ C10_HOST_DEVICE inline uint8_t fp8e5m2fnuz_from_fp32_value(float f) {
constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23;
uint32_t f_bits = fp32_to_bits(f);
uint32_t result = 0u;
/*
@@ -132,7 +113,7 @@ struct alignas(1) Float8_e5m2fnuz {
uint8_t x;
struct from_bits_t {};
static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
return from_bits_t();
}
@@ -143,6 +124,7 @@ struct alignas(1) Float8_e5m2fnuz {
inline C10_HOST_DEVICE Float8_e5m2fnuz(float value);
inline C10_HOST_DEVICE operator float() const;
inline C10_HOST_DEVICE bool isnan() const;
inline C10_HOST_DEVICE bool isinf() const;
};
C10_API std::ostream& operator<<(

View File

@@ -0,0 +1,58 @@
#pragma once
#include <c10/util/floating_point_utils.h>
#include <cstdint>
namespace c10::detail {
/*
* Convert an 8-bit floating-point number in either f8 E4M3FNUZ or bf8 E5M2FNUZ
* format, in bit representation, to a 32-bit floating-point number.
*/
template <uint32_t we, uint32_t wm>
inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) {
static_assert((we == 4 && wm == 3) || (we == 5 && wm == 2));
constexpr uint32_t weo = 8;
constexpr uint32_t wmo = 23;
if (x == 0) {
return 0;
}
if (x == 0x80) {
constexpr uint32_t ifNaN = 0x7F800001;
return fp32_from_bits(ifNaN);
}
uint32_t mantissa = x & ((1 << wm) - 1);
uint32_t exponent = (x & 0x7F) >> wm;
// subnormal input
if (exponent == 0) {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
uint32_t renorm_shift = __clz(mantissa);
#elif defined(_MSC_VER)
unsigned long nonsign_bsr;
_BitScanReverse(&nonsign_bsr, (unsigned long)mantissa);
uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
#else
uint32_t renorm_shift = __builtin_clz(mantissa);
#endif
uint32_t sh = 1 + renorm_shift - (32 - wm);
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1 << wm) - 1);
}
const uint32_t exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1));
exponent += exp_low_cutoff - 1;
mantissa <<= wmo - wm;
uint32_t sign = x >> 7;
uint32_t retval = (sign << 31) | (exponent << 23) | mantissa;
return fp32_from_bits(retval);
}
} // namespace c10::detail
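To see the subnormal branch at work, decode 0x01 as e4m3fnuz (we=4, wm=3): mantissa 1 and exponent 0, clz gives renorm_shift 31, so sh = 1 + 31 - 29 = 3, the mantissa renormalizes to zero, and the exponent lands at 117, i.e. 2^(117-127) = 2^-10 = 0.0009765625, exactly the first nonzero entry of the deleted lookup table. A check sketch:

#include <c10/util/Float8_fnuz_cvt.h>
#include <cassert>

int main() {
  // Smallest positive subnormals, matching the removed LUTs' first entries.
  assert(c10::detail::fp8_fnuz_to_fp32_value<4, 3>(0x01) == 0.0009765625f);      // 2^-10
  assert(c10::detail::fp8_fnuz_to_fp32_value<5, 2>(0x01) == 7.62939453125e-06f); // 2^-17
}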

View File

@@ -25,6 +25,8 @@ class TensorProtoDataType(Enum):
BFLOAT16 = ...
FLOAT8E5M2 = ...
FLOAT8E4M3FN = ...
FLOAT8E5M2FNUZ = ...
FLOAT8E4M3FNUZ = ...
class OperatorExportTypes(Enum):
ONNX = ...

View File

@@ -85,6 +85,8 @@ DTYPE_TO_ATEN = {
torch.complex64: "at::kComplexFloat",
torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
torch.float8_e5m2: "at::kFloat8_e5m2",
torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz",
torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz",
}
DEVICE_TO_ATEN = {

View File

@@ -387,6 +387,10 @@ def triton_compute_type(dtype):
triton_type_name = "float8e4nv"
elif triton_type_name == "float8_e5m2":
triton_type_name = "float8e5"
elif triton_type_name == "float8_e4m3fnuz":
triton_type_name = "float8e4b8"
elif triton_type_name == "float8_e5m2fnuz":
triton_type_name = "float8e5b16"
return f"tl.{triton_type_name}"

View File

@@ -18,6 +18,10 @@ def signature_of(arg: Union[TensorArg, SizeArg], *, size_dtype: str) -> str:
tye = "*fp8e4nv"
elif arg.dtype == torch.float8_e5m2:
tye = "*fp8e5"
elif arg.dtype == torch.float8_e4m3fnuz:
tye = "*fp8e4b8"
elif arg.dtype == torch.float8_e5m2fnuz:
tye = "*fp8e5b16"
else:
tye = JITFunction._type_of(arg.dtype)
if V.graph.is_unspec_arg(arg.buffer):

View File

@@ -91,6 +91,8 @@ def supported_dtype_of_cpp_wrapper(dtype, cuda):
if cuda:
supported_dtype.add(torch.float8_e4m3fn)
supported_dtype.add(torch.float8_e5m2)
supported_dtype.add(torch.float8_e4m3fnuz)
supported_dtype.add(torch.float8_e5m2fnuz)
return dtype in supported_dtype

View File

@@ -5520,7 +5520,12 @@ def meta_scaled_mm(
return stride[0] == 1 and stride[1] == shape[0]
def is_fp8_type(dtype):
return dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
return dtype in (
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
)
torch._check(
self.dim() == 2 and mat2.dim() == 2,

View File

@@ -397,6 +397,8 @@ class Tensor(torch._C.TensorBase):
v3_dtypes = [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e5m2fnuz,
torch.float8_e4m3fnuz,
torch.bits8,
torch.bits16,
torch.bits1x8,

View File

@@ -81,6 +81,8 @@ def _get_allowed_globals():
torch.complex128,
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e5m2fnuz,
torch.float8_e4m3fnuz,
torch.float16,
torch.float32,
torch.float64,

View File

@@ -85,6 +85,8 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fn();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2fnuz();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fnuz();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bfloat16();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float16();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float32();

View File

@@ -74,6 +74,14 @@ int32_t aoti_torch_dtype_float8_e4m3fn() {
return (int32_t)c10::ScalarType::Float8_e4m3fn;
}
int32_t aoti_torch_dtype_float8_e5m2fnuz() {
return (int32_t)c10::ScalarType::Float8_e5m2fnuz;
}
int32_t aoti_torch_dtype_float8_e4m3fnuz() {
return (int32_t)c10::ScalarType::Float8_e4m3fnuz;
}
int32_t aoti_torch_dtype_bfloat16() {
return (int32_t)c10::ScalarType::BFloat16;
}

View File

@@ -91,8 +91,12 @@ c10::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type) {
return at::kBFloat16;
case ::torch::onnx::TensorProto_DataType_FLOAT8E5M2:
return at::kFloat8_e5m2;
case ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ:
return at::kFloat8_e5m2fnuz;
case ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN:
return at::kFloat8_e4m3fn;
case ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ:
return at::kFloat8_e4m3fnuz;
default:
TORCH_CHECK(
false,

View File

@@ -31,6 +31,8 @@ static const std::unordered_map<c10::ScalarType, int, ScalarTypeHashFunction>
{c10::kBFloat16, 15},
{c10::kFloat8_e4m3fn, 16},
{c10::kFloat8_e5m2, 17},
{c10::kFloat8_e4m3fnuz, 18},
{c10::kFloat8_e5m2fnuz, 19},
};
static int64_t ScalarTypeToONNXType(const c10::ScalarType& st) {

View File

@@ -469,6 +469,10 @@ onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
return onnx_torch::TensorProto_DataType_FLOAT8E4M3FN;
case at::kFloat8_e5m2:
return onnx_torch::TensorProto_DataType_FLOAT8E5M2;
case at::kFloat8_e4m3fnuz:
return onnx_torch::TensorProto_DataType_FLOAT8E4M3FNUZ;
case at::kFloat8_e5m2fnuz:
return onnx_torch::TensorProto_DataType_FLOAT8E5M2FNUZ;
default:
TORCH_CHECK(
false,

View File

@@ -74,8 +74,15 @@ int Dtype::byte_size() const {
scalar_size = sizeof(Type); \
break;
AT_FORALL_SCALAR_TYPES_AND5(
Bool, Half, BFloat16, Float8_e5m2, Float8_e4m3fn, TYPE_CASE);
AT_FORALL_SCALAR_TYPES_AND7(
Bool,
Half,
BFloat16,
Float8_e5m2,
Float8_e4m3fn,
Float8_e5m2fnuz,
Float8_e4m3fnuz,
TYPE_CASE);
TYPE_CASE(c10::quint8, QUInt8);
TYPE_CASE(c10::qint8, QInt8);
#undef TYPE_CASE

View File

@@ -12,8 +12,14 @@ namespace torch::onnx {
// ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN
constexpr auto TensorProto_DataType_FLOAT8E4M3FN =
static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(17);
// ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ
constexpr auto TensorProto_DataType_FLOAT8E4M3FNUZ =
static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(18);
// ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2
constexpr auto TensorProto_DataType_FLOAT8E5M2 =
static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(19);
// ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ
constexpr auto TensorProto_DataType_FLOAT8E5M2FNUZ =
static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(20);
} // namespace torch::onnx

View File

@@ -274,7 +274,11 @@ void initONNXBindings(PyObject* module) {
.value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128)
.value("BFLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
.value("FLOAT8E4M3FN", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN)
.value("FLOAT8E5M2", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2);
.value(
"FLOAT8E4M3FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ)
.value("FLOAT8E5M2", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2)
.value(
"FLOAT8E5M2FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ);
py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes")
.value("ONNX", OperatorExportTypes::ONNX)

View File

@@ -152,8 +152,8 @@ inline PyObject* load_scalar(void* data, at::ScalarType scalarType) {
return PyFloat_FromDouble(at::convert<double, at::Float8_e5m2fnuz>(
*(at::Float8_e5m2fnuz*)data));
case at::kFloat8_e4m3fnuz:
return PyFloat_FromDouble(at::convert<double, at::Float8_e5m2fnuz>(
*(at::Float8_e5m2fnuz*)data));
return PyFloat_FromDouble(at::convert<double, at::Float8_e4m3fnuz>(
*(at::Float8_e4m3fnuz*)data));
default:
throw std::runtime_error("invalid type");
}

View File

@@ -208,6 +208,8 @@ dtype_abbrs = {
torch.float16: 'f16',
torch.float8_e4m3fn: 'f8e4m3fn',
torch.float8_e5m2: 'f8e5m2',
torch.float8_e4m3fnuz: 'f8e4m3fnuz',
torch.float8_e5m2fnuz: 'f8e5m2fnuz',
torch.complex32: 'c32',
torch.complex64: 'c64',
torch.complex128: 'c128',

View File

@@ -33,6 +33,8 @@ ScalarName = Literal[
"BFloat16",
"Float8E5M2",
"Float8E4M3FN",
"Float8E5M2FNUZ",
"Float8E4M3FNUZ",
"Undefined",
]
@@ -55,6 +57,8 @@ TorchName = Literal[
"bfloat16",
"float8_e5m2",
"float8_e4m3fn",
"float8_e5m2fnuz",
"float8_e4m3fnuz",
]
@@ -96,7 +100,9 @@ class JitScalarType(enum.IntEnum):
BFLOAT16 = enum.auto() # 15
FLOAT8E5M2 = enum.auto() # 16
FLOAT8E4M3FN = enum.auto() # 17
UNDEFINED = enum.auto() # 18
FLOAT8E5M2FNUZ = enum.auto() # 18
FLOAT8E4M3FNUZ = enum.auto() # 19
UNDEFINED = enum.auto() # 20
@classmethod
@_beartype.beartype
@@ -286,6 +292,8 @@ _SCALAR_TYPE_TO_NAME: Dict[JitScalarType, ScalarName] = {
JitScalarType.BFLOAT16: "BFloat16",
JitScalarType.FLOAT8E5M2: "Float8E5M2",
JitScalarType.FLOAT8E4M3FN: "Float8E4M3FN",
JitScalarType.FLOAT8E5M2FNUZ: "Float8E5M2FNUZ",
JitScalarType.FLOAT8E4M3FNUZ: "Float8E4M3FNUZ",
JitScalarType.UNDEFINED: "Undefined",
}
@@ -312,6 +320,8 @@ _SCALAR_TYPE_TO_TORCH_NAME: Dict[JitScalarType, TorchName] = {
JitScalarType.BFLOAT16: "bfloat16",
JitScalarType.FLOAT8E5M2: "float8_e5m2",
JitScalarType.FLOAT8E4M3FN: "float8_e4m3fn",
JitScalarType.FLOAT8E5M2FNUZ: "float8_e5m2fnuz",
JitScalarType.FLOAT8E4M3FNUZ: "float8_e4m3fnuz",
}
_TORCH_NAME_TO_SCALAR_TYPE: Dict[TorchName, JitScalarType] = {
@@ -338,6 +348,8 @@ _SCALAR_TYPE_TO_ONNX = {
JitScalarType.QINT32: _C_onnx.TensorProtoDataType.INT32,
JitScalarType.FLOAT8E5M2: _C_onnx.TensorProtoDataType.FLOAT8E5M2,
JitScalarType.FLOAT8E4M3FN: _C_onnx.TensorProtoDataType.FLOAT8E4M3FN,
JitScalarType.FLOAT8E5M2FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E5M2FNUZ,
JitScalarType.FLOAT8E4M3FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E4M3FNUZ,
}
# source of truth is
@@ -361,6 +373,8 @@ _SCALAR_TYPE_TO_DTYPE = {
JitScalarType.BFLOAT16: torch.bfloat16,
JitScalarType.FLOAT8E5M2: torch.float8_e5m2,
JitScalarType.FLOAT8E4M3FN: torch.float8_e4m3fn,
JitScalarType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
JitScalarType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
}
_DTYPE_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_DTYPE.items()}

View File

@@ -207,6 +207,14 @@ class _StorageBase:
"""Casts this storage to float8_e4m3fn type"""
return self._to(torch.float8_e4m3fn)
def float8_e5m2fnuz(self):
"""Casts this storage to float8_e5m2fnuz type"""
return self._to(torch.float8_e5m2fnuz)
def float8_e4m3fnuz(self):
"""Casts this storage to float8_e4m3fnuz type"""
return self._to(torch.float8_e4m3fnuz)
def is_pinned(self, device: Union[str, torch.device] = 'cuda'):
r"""Determine whether the CPU storage is already pinned on device.
@@ -1070,6 +1078,16 @@ class TypedStorage:
_warn_typed_storage_removal()
return self._to(torch.float8_e4m3fn)
def float8_e5m2fnuz(self):
"""Casts this storage to float8_e5m2fnuz type"""
_warn_typed_storage_removal()
return self._to(torch.float8_e5m2fnuz)
def float8_e4m3fnuz(self):
"""Casts this storage to float8_e4m3fnuz type"""
_warn_typed_storage_removal()
return self._to(torch.float8_e4m3fnuz)
@classmethod
def from_file(cls, filename, shared, size):
"""from_file(filename, shared=False, size=0) -> Storage

View File

@@ -20,7 +20,12 @@ _INTEGRAL_TYPES = [
torch.uint64,
]
_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
_FLOATING_8BIT_TYPES = [torch.float8_e4m3fn, torch.float8_e5m2]
_FLOATING_8BIT_TYPES = [
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]
_COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
_BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
_FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]