additional support for float8_e4m3fnuz and _e5m2fnuz (#115214)
Follow-up to #107586.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115214
Approved by: https://github.com/peterbell10, https://github.com/malfet
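For orientation, a minimal sketch, assuming a build that contains this commit, of the two dtypes it wires through casting, printing, and dispatch (240 is the e4m3fnuz maximum normal value):

    import torch

    x = torch.tensor([0.5, 1.5, 240.0])
    y = x.to(torch.float8_e4m3fnuz)   # fnuz: no infinities, single NaN (0x80)
    z = x.to(torch.float8_e5m2fnuz)
    print(y.float(), z.float())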
This commit is contained in:
  parent 56ef5afdee
  commit 01abb5af21
@@ -4,7 +4,9 @@
 #include <c10/core/ScalarType.h>
 #include <c10/util/BFloat16.h>
 #include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e5m2.h>
+#include <c10/util/Float8_e5m2fnuz.h>
 #include <c10/util/Half.h>

 // Defines the accumulation type for a scalar type.
@@ -87,6 +89,8 @@ MPS_ACC_TYPE(BFloat16, float);
 MPS_ACC_TYPE(Half, float);
 MPS_ACC_TYPE(Float8_e5m2, float);
 MPS_ACC_TYPE(Float8_e4m3fn, float);
+MPS_ACC_TYPE(Float8_e5m2fnuz, float);
+MPS_ACC_TYPE(Float8_e4m3fnuz, float);
 MPS_ACC_TYPE(float, float);
 MPS_ACC_TYPE(double, float);
 MPS_ACC_TYPE(int8_t, int64_t);
@@ -107,6 +111,8 @@ CUDA_ACC_TYPE(BFloat16, float);
 CUDA_ACC_TYPE(Half, float);
 CUDA_ACC_TYPE(Float8_e5m2, float);
 CUDA_ACC_TYPE(Float8_e4m3fn, float);
+CUDA_ACC_TYPE(Float8_e5m2fnuz, float);
+CUDA_ACC_TYPE(Float8_e4m3fnuz, float);
 CUDA_ACC_TYPE(float, float);
 CUDA_ACC_TYPE(double, double);
 CUDA_ACC_TYPE(int8_t, int64_t);
@@ -123,8 +129,8 @@ CUDA_ACC_TYPE(c10::complex<double>, c10::complex<double>);
 CPU_ACC_TYPE(BFloat16, float);
 CPU_ACC_TYPE(Half, float);
 CPU_ACC_TYPE(Float8_e5m2, float);
-CPU_ACC_TYPE(Float8_e5m2fnuz, float);
 CPU_ACC_TYPE(Float8_e4m3fn, float);
+CPU_ACC_TYPE(Float8_e5m2fnuz, float);
+CPU_ACC_TYPE(Float8_e4m3fnuz, float);
 CPU_ACC_TYPE(float, double);
 CPU_ACC_TYPE(double, double);
@@ -7,7 +7,9 @@
 #include <c10/macros/Macros.h>
 #include <c10/util/BFloat16.h>
 #include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e5m2.h>
+#include <c10/util/Float8_e5m2fnuz.h>
 #include <c10/util/Half.h>
 #include <c10/util/complex.h>

@@ -80,6 +82,22 @@ inline C10_HOST_DEVICE bool _isnan(T val) {
   return val.isnan();
 }

+template <
+    typename T,
+    typename std::enable_if<std::is_same<T, at::Float8_e5m2fnuz>::value, int>::
+        type = 0>
+inline C10_HOST_DEVICE bool _isnan(T val) {
+  return val.isnan();
+}
+
+template <
+    typename T,
+    typename std::enable_if<std::is_same<T, at::Float8_e4m3fnuz>::value, int>::
+        type = 0>
+inline C10_HOST_DEVICE bool _isnan(T val) {
+  return val.isnan();
+}
+
 // std::isinf isn't performant to use on integral types; it will
 // (uselessly) convert to floating point and then do the test.
 // This function is.
@@ -118,6 +136,14 @@ inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val) {
   return false;
 }

+inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val) {
+  return false;
+}
+
+inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val) {
+  return false;
+}
+
 template <typename T>
 C10_HOST_DEVICE inline T exp(T x) {
   static_assert(
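Both fnuz formats reserve the single byte 0x80 for NaN and encode no infinities, which is why the new _isnan overloads test one bit pattern and the _isinf overloads return false unconditionally. A hedged Python restatement (helper names are illustrative):

    import torch

    def fnuz_isnan(byte: int) -> bool:
        return byte == 0x80   # sign bit set, all other bits zero

    def fnuz_isinf(byte: int) -> bool:
        return False          # neither fnuz format encodes infinity

    t = torch.tensor([float("nan")]).to(torch.float8_e4m3fnuz)
    assert fnuz_isnan(t.view(torch.uint8).item())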
@@ -4,7 +4,9 @@
 #include <c10/util/BFloat16.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e5m2.h>
+#include <c10/util/Float8_e5m2fnuz.h>
 #include <c10/util/Half.h>

 namespace at {

@@ -31,6 +33,14 @@ struct OpMathType<at::Float8_e4m3fn> {
   using type = float;
 };
+template <>
+struct OpMathType<at::Float8_e5m2fnuz> {
+  using type = float;
+};
+template <>
+struct OpMathType<at::Float8_e4m3fnuz> {
+  using type = float;
+};
 template <>
 struct OpMathType<c10::complex<Half>> {
   using type = c10::complex<float>;
 };
@@ -110,6 +110,8 @@
 #include <c10/util/Flags.h>
 #include <c10/util/Float8_e4m3fn.h>
 #include <c10/util/Float8_e5m2.h>
+#include <c10/util/Float8_e4m3fnuz.h>
+#include <c10/util/Float8_e5m2fnuz.h>
 #include <c10/util/FunctionRef.h>
 #include <c10/util/Half.h>
 #include <c10/util/IdWrapper.h>
@@ -92,6 +92,13 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
     case c10::ScalarType::Float8_e5m2:
       return CUDA_R_8F_E5M2;
 #endif
+#else // USE_ROCM
+#if ROCM_VERSION >= 60000
+    case c10::ScalarType::Float8_e4m3fnuz:
+      return HIP_R_8F_E4M3_FNUZ;
+    case c10::ScalarType::Float8_e5m2fnuz:
+      return HIP_R_8F_E5M2_FNUZ;
+#endif
 #endif
     default:
       TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.")
@@ -1325,8 +1325,8 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {

 #if !defined(C10_MOBILE)
 #define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
-      kBFloat16, kHalf, kFloat8_e5m2, kFloat8_e4m3fn, \
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
+      kBFloat16, kHalf, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
       TYPE, NAME, __VA_ARGS__)
 #else
 #define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
@@ -90,10 +90,7 @@ void atan2_kernel(TensorIteratorBase& iter) {
       kHalf, \
       kBool, \
       kBFloat16, \
-      kFloat8_e5m2, \
-      kFloat8_e5m2fnuz, \
-      kFloat8_e4m3fn, \
-      kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) \
   AT_DISPATCH_V2( \
       TYPE, \
@@ -102,12 +99,10 @@ void atan2_kernel(TensorIteratorBase& iter) {
       kComplexHalf, \
       kHalf, \
       kBFloat16, \
-      kFloat8_e5m2, \
-      kFloat8_e4m3fn, \
-      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
-      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #else
 #define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
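Several hunks here and below collapse the explicit float8 lists into AT_EXPAND(AT_FLOAT8_TYPES). The bundle is assumed to cover the four float8 dtypes; in Python terms:

    import torch

    # Assumed expansion of AT_FLOAT8_TYPES (kFloat8_e5m2, kFloat8_e5m2fnuz,
    # kFloat8_e4m3fn, kFloat8_e4m3fnuz), expressed as torch dtypes:
    FLOAT8_DTYPES = (
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
    )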
@@ -269,8 +269,8 @@ void gemm_core_(

 #if !defined(C10_MOBILE)
 #define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
-      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
+      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
       TYPE, NAME, __VA_ARGS__)
 #else
 #define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \
@@ -180,8 +180,8 @@ static void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {

 #if !defined(C10_MOBILE)
 #define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
-      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
+      kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
       TYPE, NAME, __VA_ARGS__)
 #else
 #define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \
@@ -33,7 +33,7 @@ C10_NOINLINE void compare_eq_ne_kernel(TensorIteratorBase &iter, EqOpType op) {
   AT_DISPATCH_V2(iter.common_dtype(), "compare_eq_ne_cuda", AT_WRAP([&]() {
     opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
         iter, CompareEqFunctor<scalar_t>(op));
-  }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, kFloat8_e4m3fn, kFloat8_e5m2, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+  }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
 }

 void eq_kernel_cuda(TensorIteratorBase& iter) {
@@ -35,7 +35,6 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
   ScalarType other_dtype = iter.dtype(1);
   if (dtype == kFloat8_e4m3fn) {
     switch (other_dtype) {
-#if !defined(USE_ROCM)
       case kFloat:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
           return Float8_e4m3fn(value);
@@ -51,14 +50,12 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
           return Float8_e4m3fn(value);
         });
         break;
-#endif /* !defined(USE_ROCM) */
       default:
         gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fn x) { return x; });
         break;
     }
   } else if (dtype == kFloat8_e5m2) {
     switch (other_dtype) {
-#if !defined(USE_ROCM)
       case kFloat:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
 #ifdef AT_USE_NV_CVT_INTRINSICS
@@ -89,11 +86,52 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
 #endif
         });
         break;
-#endif /* !defined(USE_ROCM) */
       default:
         gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2 x) { return x; });
         break;
     }
+  } else if (dtype == kFloat8_e4m3fnuz) {
+    switch (other_dtype) {
+      case kFloat:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
+          return Float8_e4m3fnuz(value);
+        });
+        break;
+      case kHalf:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
+          return Float8_e4m3fnuz(value);
+        });
+        break;
+      case kBFloat16:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
+          return Float8_e4m3fnuz(value);
+        });
+        break;
+      default:
+        gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fnuz x) { return x; });
+        break;
+    }
+  } else if (dtype == kFloat8_e5m2fnuz) {
+    switch (other_dtype) {
+      case kFloat:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
+          return Float8_e5m2fnuz(value);
+        });
+        break;
+      case kHalf:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
+          return Float8_e5m2fnuz(value);
+        });
+        break;
+      case kBFloat16:
+        gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
+          return Float8_e5m2fnuz(value);
+        });
+        break;
+      default:
+        gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; });
+        break;
+    }
   } else {
     TORCH_CHECK(false, "This is supposed to be called only for Float8 types");
   }
@@ -107,16 +145,14 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
       gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
     });
-  } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn) {
+  } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) {
     float8_copy_kernel_cuda(iter);
-#if !defined(USE_ROCM)
   } else if (isBitsType(dtype)) {
     TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
       "bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype);
     AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
       gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
     });
-#endif /* !defined(USE_ROCM) */
   } else {
     AT_DISPATCH_V2(
         dtype, "copy_", AT_WRAP([&] {
@@ -22,7 +22,7 @@ struct FillFunctor {
 void fill_kernel_cuda(TensorIterator& iter, const Scalar& value) {
   AT_DISPATCH_V2(iter.dtype(), "fill_cuda", AT_WRAP([&]() {
     gpu_kernel(iter, FillFunctor<scalar_t>(value.to<scalar_t>()));
-  }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+  }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
 }

 REGISTER_DISPATCH(fill_stub, &fill_kernel_cuda);
@@ -298,6 +298,33 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) {}
 } // namespace modern


+template <typename func_t>
+void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  using arg0_t = typename traits::result_type;
+  constexpr int ntensors = traits::arity + 1;
+
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
+  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
+  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
+
+  at::detail::Array<char*, ntensors> data;
+  for (int i = 0; i < ntensors; i++) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  int64_t numel = iter.numel();
+
+  auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
+  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
+  legacy::launch_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
+    auto offsets = offset_calc.get(idx);
+    arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
+    *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
+  });
+}
+
 template <typename func_t>
 void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
   using traits = function_traits<func_t>;
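The nocast variant asserts that no dynamic casting is needed and invokes the functor directly, which is what lets the float8 copy lambdas above accept a float/Half/BFloat16 argument and return a float8 value in one kernel. A hedged Python-level illustration of the path it serves:

    import torch

    # Hypothetical check, assuming a CUDA/ROCm build with this commit: the
    # dtype-crossing copy below is served by the nocast kernel rather than a
    # cast-then-copy pair, since the lambda itself performs the conversion.
    src = torch.randn(1024, device="cuda")                  # float32
    dst = torch.empty(1024, device="cuda", dtype=torch.float8_e4m3fnuz)
    dst.copy_(src)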
@@ -525,8 +525,15 @@ static inline bool isSignedType(ScalarType t) {
     case ScalarType::ComplexFloat:
     case ScalarType::ComplexDouble:
       return true;
-    AT_FORALL_SCALAR_TYPES_AND5(
-        Half, Bool, BFloat16, Float8_e5m2, Float8_e4m3fn, CASE_SIGNED)
+    AT_FORALL_SCALAR_TYPES_AND7(
+        Half,
+        Bool,
+        BFloat16,
+        Float8_e5m2,
+        Float8_e4m3fn,
+        Float8_e5m2fnuz,
+        Float8_e4m3fnuz,
+        CASE_SIGNED)
     default:
       TORCH_CHECK(false, "Unknown ScalarType");
   }
@@ -96,7 +96,7 @@ inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
    * mantissa will shift into exponent, turning the biased exponent into 1, and
    * making mantissa normalized (i.e. without leading 1).
    */
-#if defined(__CUDA_ARCH__)
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
   uint32_t renorm_shift = __clz(nonsign);
 #elif defined(__SYCL_DEVICE_ONLY__)
   // Note: zero is not a supported input into `__builtin_clz`
@@ -1,6 +1,8 @@
 #pragma once

 #include <c10/macros/Macros.h>
+#include <c10/util/Float8_fnuz_cvt.h>
 #include <cstring>
+#include <limits>

 C10_CLANG_DIAGNOSTIC_PUSH()
@@ -12,21 +14,208 @@ namespace c10 {

 /// Constructors

-C10_HOST_DEVICE inline Float8_e4m3fnuz::Float8_e4m3fnuz(float value)
+inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value)
     : x(detail::fp8e4m3fnuz_from_fp32_value(value)) {}

 /// Implicit conversions

-C10_HOST_DEVICE inline Float8_e4m3fnuz::operator float() const {
-  return detail::fp8e4m3fnuz_to_fp32_value(x);
+inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const {
+  return detail::fp8_fnuz_to_fp32_value<4, 3>(x);
 }

 /// Special values helper

-C10_HOST_DEVICE inline bool Float8_e4m3fnuz::isnan() const {
+inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const {
   return x == 0b10000000;
 }

+/// Arithmetic
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz
+operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
+  return static_cast<float>(a) + static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz
+operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
+  return static_cast<float>(a) - static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz
+operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
+  return static_cast<float>(a) * static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(
+    const Float8_e4m3fnuz& a,
+    const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
+  return static_cast<float>(a) / static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) {
+  return -static_cast<float>(a);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=(
+    Float8_e4m3fnuz& a,
+    const Float8_e4m3fnuz& b) {
+  a = a + b;
+  return a;
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=(
+    Float8_e4m3fnuz& a,
+    const Float8_e4m3fnuz& b) {
+  a = a - b;
+  return a;
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=(
+    Float8_e4m3fnuz& a,
+    const Float8_e4m3fnuz& b) {
+  a = a * b;
+  return a;
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=(
+    Float8_e4m3fnuz& a,
+    const Float8_e4m3fnuz& b) {
+  a = a / b;
+  return a;
+}
+
+/// Arithmetic with floats
+
+inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) {
+  return static_cast<float>(a) + b;
+}
+inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) {
+  return static_cast<float>(a) - b;
+}
+inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) {
+  return static_cast<float>(a) * b;
+}
+inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return static_cast<float>(a) / b;
+}
+
+inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) {
+  return a + static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) {
+  return a - static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) {
+  return a * static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return a / static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) {
+  return a += static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) {
+  return a -= static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) {
+  return a *= static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) {
+  return a /= static_cast<float>(b);
+}
+
+/// Arithmetic with doubles
+
+inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) {
+  return static_cast<double>(a) + b;
+}
+inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) {
+  return static_cast<double>(a) - b;
+}
+inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) {
+  return static_cast<double>(a) * b;
+}
+inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return static_cast<double>(a) / b;
+}
+
+inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) {
+  return a + static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) {
+  return a - static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) {
+  return a * static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return a / static_cast<double>(b);
+}
+
+/// Arithmetic with ints
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) {
+  return a + static_cast<Float8_e4m3fnuz>(b);
+}
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) {
+  return a - static_cast<Float8_e4m3fnuz>(b);
+}
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) {
+  return a * static_cast<Float8_e4m3fnuz>(b);
+}
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) {
+  return a / static_cast<Float8_e4m3fnuz>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) {
+  return static_cast<Float8_e4m3fnuz>(a) + b;
+}
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) {
+  return static_cast<Float8_e4m3fnuz>(a) - b;
+}
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) {
+  return static_cast<Float8_e4m3fnuz>(a) * b;
+}
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) {
+  return static_cast<Float8_e4m3fnuz>(a) / b;
+}
+
+//// Arithmetic with int64_t
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) {
+  return a + static_cast<Float8_e4m3fnuz>(b);
+}
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) {
+  return a - static_cast<Float8_e4m3fnuz>(b);
+}
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) {
+  return a * static_cast<Float8_e4m3fnuz>(b);
+}
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) {
+  return a / static_cast<Float8_e4m3fnuz>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) {
+  return static_cast<Float8_e4m3fnuz>(a) + b;
+}
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) {
+  return static_cast<Float8_e4m3fnuz>(a) - b;
+}
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) {
+  return static_cast<Float8_e4m3fnuz>(a) * b;
+}
+inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) {
+  return static_cast<Float8_e4m3fnuz>(a) / b;
+}
+
+/// NOTE: we do not define comparisons directly and instead rely on the implicit
+/// conversion from c10::Float8_e4m3fnuz to float.
+
 } // namespace c10

 namespace std {
@@ -1,276 +1,8 @@
 #include <c10/util/Float8_e4m3fnuz.h>
-#include <array>
 #include <iostream>

 namespace c10 {

-namespace detail {
-
-C10_HOST_DEVICE float fp8e4m3fnuz_to_fp32_value(uint8_t input) {
-  constexpr std::array<float, 256> e4m3fnuz_lut = {
-      0.0f, 0.0009765625f, 0.001953125f, 0.0029296875f, 0.00390625f, 0.0048828125f, 0.005859375f, 0.0068359375f,
-      0.0078125f, 0.0087890625f, 0.009765625f, 0.0107421875f, 0.01171875f, 0.0126953125f, 0.013671875f, 0.0146484375f,
-      0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f,
-      0.03125f, 0.03515625f, 0.0390625f, 0.04296875f, 0.046875f, 0.05078125f, 0.0546875f, 0.05859375f,
-      0.0625f, 0.0703125f, 0.078125f, 0.0859375f, 0.09375f, 0.1015625f, 0.109375f, 0.1171875f,
-      0.125f, 0.140625f, 0.15625f, 0.171875f, 0.1875f, 0.203125f, 0.21875f, 0.234375f,
-      0.25f, 0.28125f, 0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f, 0.46875f,
-      0.5f, 0.5625f, 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f, 0.9375f,
-      1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f,
-      2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f,
-      4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f,
-      8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
-      16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f,
-      32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f,
-      64.0f, 72.0f, 80.0f, 88.0f, 96.0f, 104.0f, 112.0f, 120.0f,
-      128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f,
-      std::numeric_limits<float>::signaling_NaN(),
-      -0.0009765625f, -0.001953125f, -0.0029296875f, -0.00390625f, -0.0048828125f, -0.005859375f, -0.0068359375f, -0.0078125f,
-      -0.0087890625f, -0.009765625f, -0.0107421875f, -0.01171875f, -0.0126953125f, -0.013671875f, -0.0146484375f, -0.015625f,
-      -0.017578125f, -0.01953125f, -0.021484375f, -0.0234375f, -0.025390625f, -0.02734375f, -0.029296875f, -0.03125f,
-      -0.03515625f, -0.0390625f, -0.04296875f, -0.046875f, -0.05078125f, -0.0546875f, -0.05859375f, -0.0625f,
-      -0.0703125f, -0.078125f, -0.0859375f, -0.09375f, -0.1015625f, -0.109375f, -0.1171875f, -0.125f,
-      -0.140625f, -0.15625f, -0.171875f, -0.1875f, -0.203125f, -0.21875f, -0.234375f, -0.25f,
-      -0.28125f, -0.3125f, -0.34375f, -0.375f, -0.40625f, -0.4375f, -0.46875f, -0.5f,
-      -0.5625f, -0.625f, -0.6875f, -0.75f, -0.8125f, -0.875f, -0.9375f, -1.0f,
-      -1.125f, -1.25f, -1.375f, -1.5f, -1.625f, -1.75f, -1.875f, -2.0f,
-      -2.25f, -2.5f, -2.75f, -3.0f, -3.25f, -3.5f, -3.75f, -4.0f,
-      -4.5f, -5.0f, -5.5f, -6.0f, -6.5f, -7.0f, -7.5f, -8.0f,
-      -9.0f, -10.0f, -11.0f, -12.0f, -13.0f, -14.0f, -15.0f, -16.0f,
-      -18.0f, -20.0f, -22.0f, -24.0f, -26.0f, -28.0f, -30.0f, -32.0f,
-      -36.0f, -40.0f, -44.0f, -48.0f, -52.0f, -56.0f, -60.0f, -64.0f,
-      -72.0f, -80.0f, -88.0f, -96.0f, -104.0f, -112.0f, -120.0f, -128.0f,
-      -144.0f, -160.0f, -176.0f, -192.0f, -208.0f, -224.0f, -240.0f,
-  };
-
-  return e4m3fnuz_lut[input];
-}
-
-} // namespace detail
-
 static_assert(
     std::is_standard_layout_v<Float8_e4m3fnuz>,
     "c10::Float8_e4m3fnuz must be standard layout.");
@@ -4,13 +4,11 @@
 /// conversions to standard C types and basic arithmetic operations. Note that
 /// arithmetic operations are implemented by converting to floating point and
 /// performing the operation in float32.
-///
 /// Binary configuration remains the same as Float8_e4m3fn:
 /// s eeee mmm
 /// 1 sign bit
 /// 4 exponent bits
 /// 3 mantissa bits
-///
 /// The key differences versus Float8_e4m3fn are:
 /// bias = 8
 /// no infinities or negative zero
@@ -23,6 +21,7 @@
 #include <c10/util/C++17.h>
 #include <c10/util/TypeSafeSignMath.h>
 #include <c10/util/floating_point_utils.h>
+#include <type_traits>

 #if defined(__cplusplus) && (__cplusplus >= 201103L)
 #include <cstdint>
@@ -38,27 +37,11 @@ namespace c10 {

 namespace detail {

-/*
- * Convert a 8-bit floating-point number in fp8 E4M3FNUZ format, in bit
- * representation, to a 32-bit floating-point number in IEEE single-precision
- * format, in bit representation.
- *
- * @note The implementation doesn't use any floating-point operations.
- */
-#if defined(__CUDA_ARCH__) || defined(__HIP__)
-C10_HOST_DEVICE C10_API inline float fp8e4m3fnuz_to_fp32_value(uint8_t) {
-  CUDA_KERNEL_ASSERT(false && "e4m3fnuz is not supported by CUDA or HIP");
-  return -1.0;
-}
-#else
-C10_API float fp8e4m3fnuz_to_fp32_value(uint8_t input);
-#endif
-
 /*
  * Convert a 32-bit floating-point number in IEEE single-precision format to a
  * 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation.
  */
-C10_HOST_DEVICE inline uint8_t fp8e4m3fnuz_from_fp32_value(float f) {
+inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) {
   /*
    * Binary representation of 256.0f, which is the first value not representable
    * (i.e. the first value which would overflow in to the sign bit, resulting in
@@ -70,7 +53,7 @@ inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) {

   /*
    * A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range
-   * into denormalized representation.
+   * into denorm representation
    * magic number: ((127 - 8) + (23 - 3) + 1)
    */
   constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23;
@@ -123,7 +106,6 @@ inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) {
   }

   result |= sign >> 24;
-
   return result;
 }

@@ -133,7 +115,7 @@ struct alignas(1) Float8_e4m3fnuz {
   uint8_t x;

   struct from_bits_t {};
-  static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
+  C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
     return from_bits_t();
   }
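To make the e4m3fnuz layout above concrete, a small worked decode, cross-checked against the lookup table this commit removes: byte 0x48 is sign 0, biased exponent 0b1001 = 9, mantissa 0, and with bias 8 decodes to 2^(9-8) = 2.0.

    # Worked example for the e4m3fnuz layout (s eeee mmm, bias = 8).
    b = 0x48                          # 0 1001 000
    sign = -1.0 if b >> 7 else 1.0
    exp = (b >> 3) & 0xF              # 9
    man = b & 0x7                     # 0
    if exp:
        val = sign * (1 + man / 8) * 2.0 ** (exp - 8)   # normal
    else:
        val = sign * (man / 8) * 2.0 ** (1 - 8)         # subnormal
    assert val == 2.0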
@@ -1,6 +1,8 @@
 #pragma once

 #include <c10/macros/Macros.h>
+#include <c10/util/Float8_fnuz_cvt.h>
 #include <cstring>
+#include <limits>

 C10_CLANG_DIAGNOSTIC_PUSH()
@@ -12,21 +14,212 @@ namespace c10 {

 /// Constructors

-C10_HOST_DEVICE inline Float8_e5m2fnuz::Float8_e5m2fnuz(float value)
+inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value)
     : x(detail::fp8e5m2fnuz_from_fp32_value(value)) {}

 /// Implicit conversions

-C10_HOST_DEVICE inline Float8_e5m2fnuz::operator float() const {
-  return detail::fp8e5m2fnuz_to_fp32_value(x);
+inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const {
+  return detail::fp8_fnuz_to_fp32_value<5, 2>(x);
 }

 /// Special values helpers

-C10_HOST_DEVICE inline bool Float8_e5m2fnuz::isnan() const {
+inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const {
   return x == 0b10000000;
 }

+inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const {
+  return false;
+}
+
+/// Arithmetic
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz
+operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
+  return static_cast<float>(a) + static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz
+operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
+  return static_cast<float>(a) - static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz
+operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
+  return static_cast<float>(a) * static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(
+    const Float8_e5m2fnuz& a,
+    const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
+  return static_cast<float>(a) / static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) {
+  return -static_cast<float>(a);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=(
+    Float8_e5m2fnuz& a,
+    const Float8_e5m2fnuz& b) {
+  a = a + b;
+  return a;
+}
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=(
+    Float8_e5m2fnuz& a,
+    const Float8_e5m2fnuz& b) {
+  a = a - b;
+  return a;
+}
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=(
+    Float8_e5m2fnuz& a,
+    const Float8_e5m2fnuz& b) {
+  a = a * b;
+  return a;
+}
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=(
+    Float8_e5m2fnuz& a,
+    const Float8_e5m2fnuz& b) {
+  a = a / b;
+  return a;
+}
+
+/// Arithmetic with floats
+
+inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) {
+  return static_cast<float>(a) + b;
+}
+inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) {
+  return static_cast<float>(a) - b;
+}
+inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) {
+  return static_cast<float>(a) * b;
+}
+inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return static_cast<float>(a) / b;
+}
+
+inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) {
+  return a + static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) {
+  return a - static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) {
+  return a * static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return a / static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) {
+  return a += static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) {
+  return a -= static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) {
+  return a *= static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) {
+  return a /= static_cast<float>(b);
+}
+
+/// Arithmetic with doubles
+
+inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) {
+  return static_cast<double>(a) + b;
+}
+inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) {
+  return static_cast<double>(a) - b;
+}
+inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) {
+  return static_cast<double>(a) * b;
+}
+inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return static_cast<double>(a) / b;
+}
+
+inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) {
+  return a + static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) {
+  return a - static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) {
+  return a * static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return a / static_cast<double>(b);
+}
+
+/// Arithmetic with ints
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) {
+  return a + static_cast<Float8_e5m2fnuz>(b);
+}
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) {
+  return a - static_cast<Float8_e5m2fnuz>(b);
+}
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) {
+  return a * static_cast<Float8_e5m2fnuz>(b);
+}
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) {
+  return a / static_cast<Float8_e5m2fnuz>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) {
+  return static_cast<Float8_e5m2fnuz>(a) + b;
+}
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) {
+  return static_cast<Float8_e5m2fnuz>(a) - b;
+}
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) {
+  return static_cast<Float8_e5m2fnuz>(a) * b;
+}
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) {
+  return static_cast<Float8_e5m2fnuz>(a) / b;
+}
+
+//// Arithmetic with int64_t
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) {
+  return a + static_cast<Float8_e5m2fnuz>(b);
+}
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) {
+  return a - static_cast<Float8_e5m2fnuz>(b);
+}
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) {
+  return a * static_cast<Float8_e5m2fnuz>(b);
+}
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) {
+  return a / static_cast<Float8_e5m2fnuz>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) {
+  return static_cast<Float8_e5m2fnuz>(a) + b;
+}
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) {
+  return static_cast<Float8_e5m2fnuz>(a) - b;
+}
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) {
+  return static_cast<Float8_e5m2fnuz>(a) * b;
+}
+inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) {
+  return static_cast<Float8_e5m2fnuz>(a) / b;
+}
+
+/// NOTE: we do not define comparisons directly and instead rely on the implicit
+/// conversion from c10::Float8_e5m2fnuz to float.
+
 } // namespace c10

 namespace std {
@@ -1,276 +1,8 @@
 #include <c10/util/Float8_e5m2fnuz.h>
-#include <array>
 #include <iostream>

 namespace c10 {

-namespace detail {
-
-C10_HOST_DEVICE float fp8e5m2fnuz_to_fp32_value(uint8_t input) {
-  constexpr std::array<float, 256> e5m2fnuz_lut = {
-      0.0f, 7.62939453125e-06f, 1.52587890625e-05f, 2.288818359375e-05f, 3.0517578125e-05f, 3.814697265625e-05f, 4.57763671875e-05f, 5.340576171875e-05f,
-      6.103515625e-05f, 7.62939453125e-05f, 9.1552734375e-05f, 0.0001068115234375f, 0.0001220703125f, 0.000152587890625f, 0.00018310546875f, 0.000213623046875f,
-      0.000244140625f, 0.00030517578125f, 0.0003662109375f, 0.00042724609375f, 0.00048828125f, 0.0006103515625f, 0.000732421875f, 0.0008544921875f,
-      0.0009765625f, 0.001220703125f, 0.00146484375f, 0.001708984375f, 0.001953125f, 0.00244140625f, 0.0029296875f, 0.00341796875f,
-      0.00390625f, 0.0048828125f, 0.005859375f, 0.0068359375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
-      0.015625f, 0.01953125f, 0.0234375f, 0.02734375f, 0.03125f, 0.0390625f, 0.046875f, 0.0546875f,
-      0.0625f, 0.078125f, 0.09375f, 0.109375f, 0.125f, 0.15625f, 0.1875f, 0.21875f,
-      0.25f, 0.3125f, 0.375f, 0.4375f, 0.5f, 0.625f, 0.75f, 0.875f,
-      1.0f, 1.25f, 1.5f, 1.75f, 2.0f, 2.5f, 3.0f, 3.5f,
-      4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 10.0f, 12.0f, 14.0f,
-      16.0f, 20.0f, 24.0f, 28.0f, 32.0f, 40.0f, 48.0f, 56.0f,
-      64.0f, 80.0f, 96.0f, 112.0f, 128.0f, 160.0f, 192.0f, 224.0f,
-      256.0f, 320.0f, 384.0f, 448.0f, 512.0f, 640.0f, 768.0f, 896.0f,
-      1024.0f, 1280.0f, 1536.0f, 1792.0f, 2048.0f, 2560.0f, 3072.0f, 3584.0f,
-      4096.0f, 5120.0f, 6144.0f, 7168.0f, 8192.0f, 10240.0f, 12288.0f, 14336.0f,
-      16384.0f, 20480.0f, 24576.0f, 28672.0f, 32768.0f, 40960.0f, 49152.0f, 57344.0f,
-      std::numeric_limits<float>::signaling_NaN(),
-      -7.62939453125e-06f, -1.52587890625e-05f, -2.288818359375e-05f, -3.0517578125e-05f, -3.814697265625e-05f, -4.57763671875e-05f, -5.340576171875e-05f, -6.103515625e-05f,
-      -7.62939453125e-05f, -9.1552734375e-05f, -0.0001068115234375f, -0.0001220703125f, -0.000152587890625f, -0.00018310546875f, -0.000213623046875f, -0.000244140625f,
-      -0.00030517578125f, -0.0003662109375f, -0.00042724609375f, -0.00048828125f, -0.0006103515625f, -0.000732421875f, -0.0008544921875f, -0.0009765625f,
-      -0.001220703125f, -0.00146484375f, -0.001708984375f, -0.001953125f, -0.00244140625f, -0.0029296875f, -0.00341796875f, -0.00390625f,
-      -0.0048828125f, -0.005859375f, -0.0068359375f, -0.0078125f, -0.009765625f, -0.01171875f, -0.013671875f, -0.015625f,
-      -0.01953125f, -0.0234375f, -0.02734375f, -0.03125f, -0.0390625f, -0.046875f, -0.0546875f, -0.0625f,
-      -0.078125f, -0.09375f, -0.109375f, -0.125f, -0.15625f, -0.1875f, -0.21875f, -0.25f,
-      -0.3125f, -0.375f, -0.4375f, -0.5f, -0.625f, -0.75f, -0.875f, -1.0f,
-      -1.25f, -1.5f, -1.75f, -2.0f, -2.5f, -3.0f, -3.5f, -4.0f,
-      -5.0f, -6.0f, -7.0f, -8.0f, -10.0f, -12.0f, -14.0f, -16.0f,
-      -20.0f, -24.0f, -28.0f, -32.0f, -40.0f, -48.0f, -56.0f, -64.0f,
-      -80.0f, -96.0f, -112.0f, -128.0f, -160.0f, -192.0f, -224.0f, -256.0f,
-      -320.0f, -384.0f, -448.0f, -512.0f, -640.0f, -768.0f, -896.0f, -1024.0f,
-      -1280.0f, -1536.0f, -1792.0f, -2048.0f, -2560.0f, -3072.0f, -3584.0f, -4096.0f,
-      -5120.0f, -6144.0f, -7168.0f, -8192.0f, -10240.0f, -12288.0f, -14336.0f, -16384.0f,
-      -20480.0f, -24576.0f, -28672.0f, -32768.0f, -40960.0f, -49152.0f, -57344.0f,
-  };
-
-  return e5m2fnuz_lut[input];
-}
-
-} // namespace detail
-
 static_assert(
     std::is_standard_layout_v<Float8_e5m2fnuz>,
     "c10::Float8_e5m2fnuz must be standard layout.");
@@ -4,13 +4,11 @@
 /// conversions to standard C types and basic arithmetic operations. Note that
 /// arithmetic operations are implemented by converting to floating point and
 /// performing the operation in float32.
-///
 /// Binary configuration remains the same as e5m2:
 /// s eeeee mm
 /// 1 sign bit
 /// 5 exponent bits
 /// 2 mantissa bits
-///
 /// The key differences that e5m2fnuz brings are:
 /// bias = 16
 /// no infinities or negative zero
@@ -38,27 +36,11 @@ namespace c10 {

 namespace detail {

-/*
- * Convert a 8-bit floating-point number in fp8 E5M2FNUZ format, in bit
- * representation, to a 32-bit floating-point number in IEEE single-precision
- * format, in bit representation.
- *
- * @note The implementation doesn't use any floating-point operations.
- */
-#if defined(__CUDA_ARCH__) || defined(__HIP__)
-C10_HOST_DEVICE C10_API inline float fp8e5m2fnuz_to_fp32_value(uint8_t) {
-  CUDA_KERNEL_ASSERT(false && "e5m2fnuz is not supported by CUDA or HIP");
-  return -1.0;
-}
-#else
-C10_API float fp8e5m2fnuz_to_fp32_value(uint8_t input);
-#endif
-
 /*
  * Convert a 32-bit floating-point number in IEEE single-precision format to a
  * 8-bit floating-point number in fp8 E5M2FNUZ format, in bit representation.
  */
-C10_HOST_DEVICE inline uint8_t fp8e5m2fnuz_from_fp32_value(float f) {
+inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) {
   /*
    * Binary representation of 65536.0f, which is the first value not
    * representable (i.e. the first value which would overflow in to the sign
@@ -76,7 +58,6 @@ inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) {
   constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23;

   uint32_t f_bits = fp32_to_bits(f);
-
   uint32_t result = 0u;

   /*
@@ -132,7 +113,7 @@ struct alignas(1) Float8_e5m2fnuz {
   uint8_t x;

   struct from_bits_t {};
-  static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
+  C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
     return from_bits_t();
   }

@@ -143,6 +124,7 @@ struct alignas(1) Float8_e5m2fnuz {
   inline C10_HOST_DEVICE Float8_e5m2fnuz(float value);
   inline C10_HOST_DEVICE operator float() const;
   inline C10_HOST_DEVICE bool isnan() const;
+  inline C10_HOST_DEVICE bool isinf() const;
 };

 C10_API std::ostream& operator<<(
c10/util/Float8_fnuz_cvt.h (new file, 58 lines)
@@ -0,0 +1,58 @@
+#pragma once
+
+#include <c10/util/floating_point_utils.h>
+
+#include <cstdint>
+
+namespace c10::detail {
+
+/*
+ * Convert a 8-bit floating-point number in either f8 E4M3FNUZ or bf8 E5M2FNUZ
+ * format, in bit representation, to a 32-bit floating-point number.
+ */
+template <uint32_t we, uint32_t wm>
+inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) {
+  static_assert((we == 4 && wm == 3) || (we == 5 && wm == 2));
+  constexpr uint32_t weo = 8;
+  constexpr uint32_t wmo = 23;
+
+  if (x == 0) {
+    return 0;
+  }
+
+  if (x == 0x80) {
+    constexpr uint32_t ifNaN = 0x7F800001;
+    return fp32_from_bits(ifNaN);
+  }
+
+  uint32_t mantissa = x & ((1 << wm) - 1);
+  uint32_t exponent = (x & 0x7F) >> wm;
+
+  // subnormal input
+  if (exponent == 0) {
+    // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
+    uint32_t renorm_shift = __clz(mantissa);
+#elif defined(_MSC_VER)
+    unsigned long nonsign_bsr;
+    _BitScanReverse(&nonsign_bsr, (unsigned long)mantissa);
+    uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
+#else
+    uint32_t renorm_shift = __builtin_clz(mantissa);
+#endif
+    uint32_t sh = 1 + renorm_shift - (32 - wm);
+    mantissa <<= sh;
+    exponent += 1 - sh;
+    mantissa &= ((1 << wm) - 1);
+  }
+
+  const uint32_t exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1));
+  exponent += exp_low_cutoff - 1;
+  mantissa <<= wmo - wm;
+
+  uint32_t sign = x >> 7;
+  uint32_t retval = (sign << 31) | (exponent << 23) | mantissa;
+  return fp32_from_bits(retval);
+}
+
+} // namespace c10::detail
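As a sanity check on fp8_fnuz_to_fp32_value, a hedged Python re-implementation for the e4m3fnuz instantiation (we = 4, wm = 3); the two asserted values match entries of the lookup tables deleted above:

    import struct

    def fp8_fnuz_to_fp32(x: int, we: int = 4, wm: int = 3) -> float:
        if x == 0:
            return 0.0
        if x == 0x80:
            return float("nan")              # the single fnuz NaN
        mantissa = x & ((1 << wm) - 1)
        exponent = (x & 0x7F) >> wm
        if exponent == 0:                    # subnormal: renormalize
            renorm_shift = 32 - mantissa.bit_length()
            sh = 1 + renorm_shift - (32 - wm)
            mantissa = (mantissa << sh) & ((1 << wm) - 1)
            exponent += 1 - sh
        exponent += (1 << 7) - (1 << (we - 1)) - 1   # rebias for fp32 (weo = 8)
        bits = ((x >> 7) << 31) | (exponent << 23) | (mantissa << (23 - wm))
        return struct.unpack("<f", struct.pack("<I", bits))[0]

    assert fp8_fnuz_to_fp32(0x01) == 0.0009765625   # smallest subnormal, 2**-10
    assert fp8_fnuz_to_fp32(0x7F) == 240.0          # largest normal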
@@ -25,6 +25,8 @@ class TensorProtoDataType(Enum):
     BFLOAT16 = ...
     FLOAT8E5M2 = ...
     FLOAT8E4M3FN = ...
+    FLOAT8E5M2FNUZ = ...
+    FLOAT8E4M3FNUZ = ...

 class OperatorExportTypes(Enum):
     ONNX = ...
@@ -85,6 +85,8 @@ DTYPE_TO_ATEN = {
     torch.complex64: "at::kComplexFloat",
     torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
     torch.float8_e5m2: "at::kFloat8_e5m2",
+    torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz",
+    torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz",
 }

 DEVICE_TO_ATEN = {
@@ -387,6 +387,10 @@ def triton_compute_type(dtype):
         triton_type_name = "float8e4nv"
     elif triton_type_name == "float8_e5m2":
         triton_type_name = "float8e5"
+    elif triton_type_name == "float8_e4m3fnuz":
+        triton_type_name = "float8e4b8"
+    elif triton_type_name == "float8_e5m2fnuz":
+        triton_type_name = "float8e5b16"
     return f"tl.{triton_type_name}"
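A hedged summary of the resulting inductor-to-Triton dtype mapping; the b8/b16 suffixes are assumed to denote the fnuz exponent biases 8 and 16:

    assumed_triton_types = {
        "float8_e4m3fn": "tl.float8e4nv",     # NVIDIA e4m3
        "float8_e5m2": "tl.float8e5",
        "float8_e4m3fnuz": "tl.float8e4b8",   # bias-8 fnuz
        "float8_e5m2fnuz": "tl.float8e5b16",  # bias-16 fnuz
    }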
@@ -18,6 +18,10 @@ def signature_of(arg: Union[TensorArg, SizeArg], *, size_dtype: str) -> str:
         tye = "*fp8e4nv"
     elif arg.dtype == torch.float8_e5m2:
         tye = "*fp8e5"
+    elif arg.dtype == torch.float8_e4m3fnuz:
+        tye = "*fp8e4b8"
+    elif arg.dtype == torch.float8_e5m2fnuz:
+        tye = "*fp8e5b16"
     else:
         tye = JITFunction._type_of(arg.dtype)
     if V.graph.is_unspec_arg(arg.buffer):
@@ -91,6 +91,8 @@ def supported_dtype_of_cpp_wrapper(dtype, cuda):
     if cuda:
         supported_dtype.add(torch.float8_e4m3fn)
        supported_dtype.add(torch.float8_e5m2)
+        supported_dtype.add(torch.float8_e4m3fnuz)
+        supported_dtype.add(torch.float8_e5m2fnuz)

     return dtype in supported_dtype
|
@ -5520,7 +5520,12 @@ def meta_scaled_mm(
|
|||
return stride[0] == 1 and stride[1] == shape[0]
|
||||
|
||||
def is_fp8_type(dtype):
|
||||
return dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
|
||||
return dtype in (
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.float8_e5m2fnuz,
|
||||
)
|
||||
|
||||
torch._check(
|
||||
self.dim() == 2 and mat2.dim() == 2,
|
||||
|
|
|
|||
|
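With the meta check widened, fnuz inputs can reach the scaled-matmul entry point; a hedged sketch (the torch._scaled_mm signature of this era returned an (out, amax) pair and expected a column-major second operand; both details may have changed since):

    import torch

    a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fnuz)
    b = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fnuz).t()
    out, amax = torch._scaled_mm(a, b)   # hypothetical call on ROCm gfx94x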
|
@ -397,6 +397,8 @@ class Tensor(torch._C.TensorBase):
|
|||
v3_dtypes = [
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2fnuz,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.bits8,
|
||||
torch.bits16,
|
||||
torch.bits1x8,
|
||||
|
|
|
|||
|
|
@ -81,6 +81,8 @@ def _get_allowed_globals():
|
|||
torch.complex128,
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e5m2fnuz,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
torch.float64,
|
||||
|
|
|
|||
|
|
@ -85,6 +85,8 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda();
|
|||
|
||||
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2();
|
||||
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fn();
|
||||
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2fnuz();
|
||||
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fnuz();
|
||||
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bfloat16();
|
||||
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float16();
|
||||
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float32();
|
||||
|
|
|
|||
|
|
@ -74,6 +74,14 @@ int32_t aoti_torch_dtype_float8_e4m3fn() {
|
|||
return (int32_t)c10::ScalarType::Float8_e4m3fn;
|
||||
}
|
||||
|
||||
int32_t aoti_torch_dtype_float8_e5m2fnuz() {
|
||||
return (int32_t)c10::ScalarType::Float8_e5m2fnuz;
|
||||
}
|
||||
|
||||
int32_t aoti_torch_dtype_float8_e4m3fnuz() {
|
||||
return (int32_t)c10::ScalarType::Float8_e4m3fnuz;
|
||||
}
|
||||
|
||||
int32_t aoti_torch_dtype_bfloat16() {
|
||||
return (int32_t)c10::ScalarType::BFloat16;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -91,8 +91,12 @@ c10::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type) {
|
|||
return at::kBFloat16;
|
||||
case ::torch::onnx::TensorProto_DataType_FLOAT8E5M2:
|
||||
return at::kFloat8_e5m2;
|
||||
case ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ:
|
||||
return at::kFloat8_e5m2fnuz;
|
||||
case ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN:
|
||||
return at::kFloat8_e4m3fn;
|
||||
case ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ:
|
||||
return at::kFloat8_e4m3fnuz;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ static const std::unordered_map<c10::ScalarType, int, ScalarTypeHashFunction>
|
|||
{c10::kBFloat16, 15},
|
||||
{c10::kFloat8_e4m3fn, 16},
|
||||
{c10::kFloat8_e5m2, 17},
|
||||
{c10::kFloat8_e4m3fnuz, 18},
|
||||
{c10::kFloat8_e5m2fnuz, 19},
|
||||
};
|
||||
|
||||
static int64_t ScalarTypeToONNXType(const c10::ScalarType& st) {
|
||||
|
|
|
|||
|
|
@ -469,6 +469,10 @@ onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
|
|||
return onnx_torch::TensorProto_DataType_FLOAT8E4M3FN;
|
||||
case at::kFloat8_e5m2:
|
||||
return onnx_torch::TensorProto_DataType_FLOAT8E5M2;
|
||||
case at::kFloat8_e4m3fnuz:
|
||||
return onnx_torch::TensorProto_DataType_FLOAT8E4M3FNUZ;
|
||||
case at::kFloat8_e5m2fnuz:
|
||||
return onnx_torch::TensorProto_DataType_FLOAT8E5M2FNUZ;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
|
|
|
|||
|
|
@ -74,8 +74,15 @@ int Dtype::byte_size() const {
|
|||
scalar_size = sizeof(Type); \
|
||||
break;
|
||||
|
||||
AT_FORALL_SCALAR_TYPES_AND5(
|
||||
Bool, Half, BFloat16, Float8_e5m2, Float8_e4m3fn, TYPE_CASE);
|
||||
AT_FORALL_SCALAR_TYPES_AND7(
|
||||
Bool,
|
||||
Half,
|
||||
BFloat16,
|
||||
Float8_e5m2,
|
||||
Float8_e4m3fn,
|
||||
Float8_e5m2fnuz,
|
||||
Float8_e4m3fnuz,
|
||||
TYPE_CASE);
|
||||
TYPE_CASE(c10::quint8, QUInt8);
|
||||
TYPE_CASE(c10::qint8, QInt8);
|
||||
#undef TYPE_CASE
|
||||
|
|
|
|||
|
|
@ -12,8 +12,14 @@ namespace torch::onnx {
|
|||
// ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN
|
||||
constexpr auto TensorProto_DataType_FLOAT8E4M3FN =
|
||||
static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(17);
|
||||
// ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ
|
||||
constexpr auto TensorProto_DataType_FLOAT8E4M3FNUZ =
|
||||
static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(18);
|
||||
// ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2
|
||||
constexpr auto TensorProto_DataType_FLOAT8E5M2 =
|
||||
static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(19);
|
||||
// ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ
|
||||
constexpr auto TensorProto_DataType_FLOAT8E5M2FNUZ =
|
||||
static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(20);
|
||||
|
||||
} // namespace torch::onnx
|
||||
|
|
|
|||
|
|
@ -274,7 +274,11 @@ void initONNXBindings(PyObject* module) {
|
|||
.value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128)
|
||||
.value("BFLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
|
||||
.value("FLOAT8E4M3FN", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN)
|
||||
.value("FLOAT8E5M2", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2);
|
||||
.value(
|
||||
"FLOAT8E4M3FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ)
|
||||
.value("FLOAT8E5M2", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2)
|
||||
.value(
|
||||
"FLOAT8E5M2FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ);
|
||||
|
||||
py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes")
|
||||
.value("ONNX", OperatorExportTypes::ONNX)
|
||||
|
|
|
|||
|
|
@ -152,8 +152,8 @@ inline PyObject* load_scalar(void* data, at::ScalarType scalarType) {
|
|||
return PyFloat_FromDouble(at::convert<double, at::Float8_e5m2fnuz>(
|
||||
*(at::Float8_e5m2fnuz*)data));
|
||||
case at::kFloat8_e4m3fnuz:
|
||||
return PyFloat_FromDouble(at::convert<double, at::Float8_e5m2fnuz>(
|
||||
*(at::Float8_e5m2fnuz*)data));
|
||||
return PyFloat_FromDouble(at::convert<double, at::Float8_e4m3fnuz>(
|
||||
*(at::Float8_e4m3fnuz*)data));
|
||||
default:
|
||||
throw std::runtime_error("invalid type");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -208,6 +208,8 @@ dtype_abbrs = {
|
|||
torch.float16: 'f16',
|
||||
torch.float8_e4m3fn: 'f8e4m3fn',
|
||||
torch.float8_e5m2: 'f8e5m2',
|
||||
torch.float8_e4m3fnuz: 'f8e4m3fnuz',
|
||||
torch.float8_e5m2fnuz: 'f8e5m2fnuz',
|
||||
torch.complex32: 'c32',
|
||||
torch.complex64: 'c64',
|
||||
torch.complex128: 'c128',
|
||||
|
|
|
|||
|
|
@@ -33,6 +33,8 @@ ScalarName = Literal[
     "BFloat16",
     "Float8E5M2",
     "Float8E4M3FN",
+    "Float8E5M2FNUZ",
+    "Float8E4M3FNUZ",
     "Undefined",
 ]

@@ -55,6 +57,8 @@ TorchName = Literal[
     "bfloat16",
     "float8_e5m2",
     "float8_e4m3fn",
+    "float8_e5m2fnuz",
+    "float8_e4m3fnuz",
 ]


@@ -96,7 +100,9 @@ class JitScalarType(enum.IntEnum):
     BFLOAT16 = enum.auto()  # 15
     FLOAT8E5M2 = enum.auto()  # 16
     FLOAT8E4M3FN = enum.auto()  # 17
-    UNDEFINED = enum.auto()  # 18
+    FLOAT8E5M2FNUZ = enum.auto()  # 18
+    FLOAT8E4M3FNUZ = enum.auto()  # 19
+    UNDEFINED = enum.auto()  # 20

     @classmethod
     @_beartype.beartype
@@ -286,6 +292,8 @@ _SCALAR_TYPE_TO_NAME: Dict[JitScalarType, ScalarName] = {
     JitScalarType.BFLOAT16: "BFloat16",
     JitScalarType.FLOAT8E5M2: "Float8E5M2",
     JitScalarType.FLOAT8E4M3FN: "Float8E4M3FN",
+    JitScalarType.FLOAT8E5M2FNUZ: "Float8E5M2FNUZ",
+    JitScalarType.FLOAT8E4M3FNUZ: "Float8E4M3FNUZ",
     JitScalarType.UNDEFINED: "Undefined",
 }

@@ -312,6 +320,8 @@ _SCALAR_TYPE_TO_TORCH_NAME: Dict[JitScalarType, TorchName] = {
     JitScalarType.BFLOAT16: "bfloat16",
     JitScalarType.FLOAT8E5M2: "float8_e5m2",
     JitScalarType.FLOAT8E4M3FN: "float8_e4m3fn",
+    JitScalarType.FLOAT8E5M2FNUZ: "float8_e5m2fnuz",
+    JitScalarType.FLOAT8E4M3FNUZ: "float8_e4m3fnuz",
 }

 _TORCH_NAME_TO_SCALAR_TYPE: Dict[TorchName, JitScalarType] = {
@@ -338,6 +348,8 @@ _SCALAR_TYPE_TO_ONNX = {
     JitScalarType.QINT32: _C_onnx.TensorProtoDataType.INT32,
     JitScalarType.FLOAT8E5M2: _C_onnx.TensorProtoDataType.FLOAT8E5M2,
     JitScalarType.FLOAT8E4M3FN: _C_onnx.TensorProtoDataType.FLOAT8E4M3FN,
+    JitScalarType.FLOAT8E5M2FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E5M2FNUZ,
+    JitScalarType.FLOAT8E4M3FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E4M3FNUZ,
 }

 # source of truth is
@@ -361,6 +373,8 @@ _SCALAR_TYPE_TO_DTYPE = {
     JitScalarType.BFLOAT16: torch.bfloat16,
     JitScalarType.FLOAT8E5M2: torch.float8_e5m2,
     JitScalarType.FLOAT8E4M3FN: torch.float8_e4m3fn,
+    JitScalarType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
+    JitScalarType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
 }

 _DTYPE_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_DTYPE.items()}
@@ -207,6 +207,14 @@ class _StorageBase:
         """Casts this storage to float8_e4m3fn type"""
         return self._to(torch.float8_e4m3fn)

+    def float8_e5m2fnuz(self):
+        """Casts this storage to float8_e5m2fnuz type"""
+        return self._to(torch.float8_e5m2fnuz)
+
+    def float8_e4m3fnuz(self):
+        """Casts this storage to float8_e4m3fnuz type"""
+        return self._to(torch.float8_e4m3fnuz)
+
     def is_pinned(self, device: Union[str, torch.device] = 'cuda'):
         r"""Determine whether the CPU storage is already pinned on device.

@@ -1070,6 +1078,16 @@ class TypedStorage:
         _warn_typed_storage_removal()
         return self._to(torch.float8_e4m3fn)

+    def float8_e5m2fnuz(self):
+        """Casts this storage to float8_e5m2fnuz type"""
+        _warn_typed_storage_removal()
+        return self._to(torch.float8_e5m2fnuz)
+
+    def float8_e4m3fnuz(self):
+        """Casts this storage to float8_e4m3fnuz type"""
+        _warn_typed_storage_removal()
+        return self._to(torch.float8_e4m3fnuz)
+
     @classmethod
     def from_file(cls, filename, shared, size):
         """from_file(filename, shared=False, size=0) -> Storage
@@ -20,7 +20,12 @@ _INTEGRAL_TYPES = [
     torch.uint64,
 ]
 _FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
-_FLOATING_8BIT_TYPES = [torch.float8_e4m3fn, torch.float8_e5m2]
+_FLOATING_8BIT_TYPES = [
+    torch.float8_e4m3fn,
+    torch.float8_e5m2,
+    torch.float8_e4m3fnuz,
+    torch.float8_e5m2fnuz,
+]
 _COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
 _BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
 _FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]