CUDA BFloat16 infrastructure (#44925)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44925

Reviewed By: agolynski

Differential Revision: D23783910

Pulled By: ngimel

fbshipit-source-id: dacac2ad87d58056bdc68bfe0b7ab1de5c2af0d8
This commit is contained in:
parent
8cb7280242
commit
2fa062002e
@@ -723,6 +723,7 @@ torch_cuda_half_options = [
     "-DCUDA_HAS_FP16=1",
     "-D__CUDA_NO_HALF_OPERATORS__",
     "-D__CUDA_NO_HALF_CONVERSIONS__",
+    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
     "-D__CUDA_NO_HALF2_OPERATORS__",
 ]
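The new -D__CUDA_NO_BFLOAT16_CONVERSIONS__ define mirrors the existing half-precision defines: cuda_bf16.h only emits implicit conversions on __nv_bfloat16 when this macro is absent, so setting it leaves c10::BFloat16's own conversion operators as the single conversion path. The same flag recurs in the CMake and cpp_extension hunks below. A rough illustration of the guard pattern (my sketch, not NVIDIA's actual header):

    #include <cstdint>
    #include <cstring>

    // Sketch of the guard style used by cuda_bf16.h (member list is
    // illustrative only; the real type is NVIDIA's __nv_bfloat16).
    struct nv_bfloat16_sketch {
      std::uint16_t x;
    #if !defined(__CUDA_NO_BFLOAT16_CONVERSIONS__)
      // Only compiled when the flag is NOT passed: an implicit float
      // conversion that would compete with c10::BFloat16's operator float().
      operator float() const {
        std::uint32_t bits = static_cast<std::uint32_t>(x) << 16;
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
      }
    #endif
    };

With the define in place, mixed expressions like nv_bfloat16_sketch{} + 1.0f stop compiling, which is the point: any ambiguity between the vendor conversion and c10's is resolved by removing one of them.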
@@ -7,15 +7,44 @@ namespace c10 {
 
 /// Constructors
 inline C10_HOST_DEVICE BFloat16::BFloat16(float value) {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  x = __bfloat16_as_ushort(__float2bfloat16(value));
+#else
   // RNE by default
   x = detail::round_to_nearest_even(value);
+#endif
 }
 
 /// Implicit conversions
 inline C10_HOST_DEVICE BFloat16::operator float() const {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+  return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
+#else
   return detail::f32_from_bits(x);
+#endif
 }
 
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
+  x = *reinterpret_cast<const unsigned short*>(&value);
+}
+inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
+  return *reinterpret_cast<const __nv_bfloat16*>(&x);
+}
+#endif
+
+// CUDA intrinsics
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return __ldg(reinterpret_cast<const __nv_bfloat16 *>(ptr));
+#else
+  return *ptr;
+#endif
+}
+#endif
+
 /// Arithmetic
 
 inline C10_HOST_DEVICE BFloat16 operator+(const BFloat16& a, const BFloat16& b) {
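On devices older than sm_80, and whenever CUDA 11's cuda_bf16.h is unavailable, the constructor keeps falling back to detail::round_to_nearest_even. A self-contained sketch of that float-to-bfloat16 rounding (a plausible reconstruction of the helper, not the verbatim c10 code):

    #include <cmath>
    #include <cstdint>
    #include <cstring>

    // Round a float to bfloat16 storage bits with round-to-nearest-even:
    // bfloat16 keeps the top 16 bits of the IEEE-754 float encoding.
    inline std::uint16_t round_to_nearest_even(float src) {
      if (std::isnan(src)) {
        return UINT16_C(0x7FC0);  // canonical bfloat16 quiet NaN
      }
      std::uint32_t bits;
      std::memcpy(&bits, &src, sizeof(bits));
      // If the lowest surviving bit is 1, the bias also rounds an exact
      // tie upward, which is round-half-to-even after truncation.
      std::uint32_t rounding_bias = ((bits >> 16) & 1) + UINT32_C(0x7FFF);
      return static_cast<std::uint16_t>((bits + rounding_bias) >> 16);
    }

On sm_80 and newer, __float2bfloat16 performs the equivalent round-to-nearest conversion in hardware, and __bfloat16_as_ushort reinterprets the result into the two-byte storage field.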
@@ -7,6 +7,10 @@
 #include <cmath>
 #include <cstring>
 
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+#include <cuda_bf16.h>
+#endif
+
 namespace c10 {
 
 namespace detail {
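This conditional include is what makes the reinterpret_casts in the inline definitions above legal to even spell: __nv_bfloat16 only exists once cuda_bf16.h is in scope. Those casts also assume the vendor type is two bytes and layout-compatible with the unsigned short storage; a compile-time check making that assumption explicit (my addition, not part of the commit):

    #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
    #include <cuda_bf16.h>

    // The bitwise reinterpretation between BFloat16::x and __nv_bfloat16
    // is only sound if the two representations have identical size.
    static_assert(sizeof(__nv_bfloat16) == sizeof(unsigned short),
                  "__nv_bfloat16 must be two bytes to alias BFloat16 storage");
    #endif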
@@ -84,6 +88,11 @@ struct alignas(2) BFloat16 {
   constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) : x(bits){};
   inline C10_HOST_DEVICE BFloat16(float value);
   inline C10_HOST_DEVICE operator float() const;
+
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+  inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
+  explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
+#endif
 };
 
 } // namespace c10
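Note the asymmetry in the declarations: construction from __nv_bfloat16 is implicit, but the conversion back is explicit, so code never drifts into vendor-type arithmetic by accident. A hypothetical device-side helper showing the intended usage (assumes a CUDA 11+ build; to_float is my name, not part of the commit):

    #include <c10/util/BFloat16.h>
    #include <cuda_bf16.h>

    __device__ float to_float(const c10::BFloat16& b) {
      // Explicit cast required: operator __nv_bfloat16() is marked explicit.
      __nv_bfloat16 nv = static_cast<__nv_bfloat16>(b);
      // The reverse direction is implicit, so round-tripping is terse.
      c10::BFloat16 round_trip = nv;
      (void)round_trip;
      return __bfloat162float(nv);  // CUDA intrinsic from cuda_bf16.h
    }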
@@ -1504,7 +1504,8 @@ if(NOT INTERN_BUILD_MOBILE)
 
   if(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
     message(STATUS "Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor")
-    list(APPEND CUDA_NVCC_FLAGS "-DCUDA_HAS_FP16=1" "-D__CUDA_NO_HALF_OPERATORS__" "-D__CUDA_NO_HALF_CONVERSIONS__" "-D__CUDA_NO_HALF2_OPERATORS__")
+    list(APPEND CUDA_NVCC_FLAGS "-DCUDA_HAS_FP16=1" "-D__CUDA_NO_HALF_OPERATORS__" "-D__CUDA_NO_HALF_CONVERSIONS__"
+                                "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "-D__CUDA_NO_HALF2_OPERATORS__")
     add_compile_options(-DCUDA_HAS_FP16=1)
   else()
     message(STATUS "Could not find CUDA with FP16 support, compiling without torch.CudaHalfTensor")
@@ -152,6 +152,7 @@ MSVC_IGNORE_CUDAFE_WARNINGS = [
 COMMON_NVCC_FLAGS = [
     '-D__CUDA_NO_HALF_OPERATORS__',
     '-D__CUDA_NO_HALF_CONVERSIONS__',
+    '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
     '-D__CUDA_NO_HALF2_OPERATORS__',
     '--expt-relaxed-constexpr'
 ]
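COMMON_NVCC_FLAGS is injected into the nvcc command line for every extension built through torch.utils.cpp_extension, so out-of-tree CUDA code inherits the same conversion restrictions as core. A hypothetical extension kernel that stays on the sanctioned path, converting through c10::BFloat16 rather than relying on vendor-type implicit conversions (scale_bf16 is my example, not part of the commit):

    #include <c10/util/BFloat16.h>

    __global__ void scale_bf16(c10::BFloat16* out, const c10::BFloat16* in,
                               float s, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        // operator float() widens, the math happens in float, and the
        // BFloat16(float) constructor narrows with round-to-nearest-even.
        out[i] = static_cast<float>(in[i]) * s;
      }
    }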