CUDA BFloat16 infrastructure (#44925)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44925

Reviewed By: agolynski

Differential Revision: D23783910

Pulled By: ngimel

fbshipit-source-id: dacac2ad87d58056bdc68bfe0b7ab1de5c2af0d8
Xiang Gao authored on 2020-10-02 16:19:14 -07:00, committed by Facebook GitHub Bot
commit 2fa062002e (parent 8cb7280242)
5 changed files with 42 additions and 1 deletion


@@ -723,6 +723,7 @@ torch_cuda_half_options = [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]


@@ -7,15 +7,44 @@ namespace c10 {
/// Constructors
inline C10_HOST_DEVICE BFloat16::BFloat16(float value) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
x = __bfloat16_as_ushort(__float2bfloat16(value));
#else
// RNE by default
x = detail::round_to_nearest_even(value);
#endif
}
/// Implicit conversions
inline C10_HOST_DEVICE BFloat16::operator float() const {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
#else
return detail::f32_from_bits(x);
#endif
}
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
x = *reinterpret_cast<const unsigned short*>(&value);
}
inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
return *reinterpret_cast<const __nv_bfloat16*>(&x);
}
#endif
// CUDA intrinsics
#if defined(__CUDACC__) || defined(__HIPCC__)
inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __ldg(reinterpret_cast<const __nv_bfloat16 *>(ptr));
#else
return *ptr;
#endif
}
#endif
/// Arithmetic
inline C10_HOST_DEVICE BFloat16 operator+(const BFloat16& a, const BFloat16& b) {
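For illustration only (not part of this diff): a minimal sketch of how the new paths might be exercised from a CUDA kernel. The kernel name, launch shape, and parameters are hypothetical; on CUDA 11+ with __CUDA_ARCH__ >= 800 the calls below lower to the __nv_bfloat16 intrinsics added above, and elsewhere they fall back to the emulated bit-level paths.

#include <c10/util/BFloat16.h>

// Hypothetical kernel: out[i] = in[i] * alpha with BFloat16 storage and
// float arithmetic.
__global__ void scale_bf16(const c10::BFloat16* in, c10::BFloat16* out,
                           float alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // c10::__ldg uses the __nv_bfloat16 __ldg intrinsic on sm_80+ (CUDA 11+);
    // otherwise it is a plain load.
    float v = static_cast<float>(c10::__ldg(in + i));
    // The float constructor uses __float2bfloat16 on sm_80+ and
    // round-to-nearest-even emulation everywhere else.
    out[i] = c10::BFloat16(v * alpha);
  }
}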


@@ -7,6 +7,10 @@
#include <cmath>
#include <cstring>
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#include <cuda_bf16.h>
#endif
namespace c10 {
namespace detail {
@@ -84,6 +88,11 @@ struct alignas(2) BFloat16 {
constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) : x(bits){};
inline C10_HOST_DEVICE BFloat16(float value);
inline C10_HOST_DEVICE operator float() const;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
#endif
};
} // namespace c10
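For illustration only (not part of this diff): a sketch of interoperating with the vendor __nv_bfloat16 type through the new constructor and explicit conversion operator. The kernel is hypothetical and assumes CUDA 11+ (so the conversions are compiled in); the native-arithmetic branch additionally assumes sm_80+, where cuda_bf16.h provides __hadd for __nv_bfloat16.

#include <cuda_bf16.h>
#include <c10/util/BFloat16.h>

// Hypothetical kernel: data[i] += 1 using native bfloat16 arithmetic where
// available.
__global__ void add_one_bf16(c10::BFloat16* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // operator __nv_bfloat16 is explicit, so the cast must be spelled out.
    __nv_bfloat16 v = static_cast<__nv_bfloat16>(data[i]);
    // Native add from cuda_bf16.h, then back through the new constructor.
    data[i] = c10::BFloat16(__hadd(v, __float2bfloat16(1.0f)));
#else
    // Pre-sm_80 fallback: route through float.
    data[i] = c10::BFloat16(static_cast<float>(data[i]) + 1.0f);
#endif
  }
}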


@@ -1504,7 +1504,8 @@ if(NOT INTERN_BUILD_MOBILE)
if(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
message(STATUS "Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor")
-list(APPEND CUDA_NVCC_FLAGS "-DCUDA_HAS_FP16=1" "-D__CUDA_NO_HALF_OPERATORS__" "-D__CUDA_NO_HALF_CONVERSIONS__" "-D__CUDA_NO_HALF2_OPERATORS__")
+list(APPEND CUDA_NVCC_FLAGS "-DCUDA_HAS_FP16=1" "-D__CUDA_NO_HALF_OPERATORS__" "-D__CUDA_NO_HALF_CONVERSIONS__"
+"-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "-D__CUDA_NO_HALF2_OPERATORS__")
add_compile_options(-DCUDA_HAS_FP16=1)
else()
message(STATUS "Could not find CUDA with FP16 support, compiling without torch.CudaHalfTensor")


@@ -152,6 +152,7 @@ MSVC_IGNORE_CUDAFE_WARNINGS = [
COMMON_NVCC_FLAGS = [
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
'--expt-relaxed-constexpr'
]
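For illustration only (not from this diff): like the existing __CUDA_NO_HALF_CONVERSIONS__ flag, the new -D__CUDA_NO_BFLOAT16_CONVERSIONS__ define should suppress the implicit conversions that cuda_bf16.h would otherwise attach to __nv_bfloat16, so code built with these flags has to convert explicitly. A rough sketch of the effect:

#include <cuda_bf16.h>

__device__ void convert_explicitly(float f) {
  // __nv_bfloat16 b = f;   // rejected when __CUDA_NO_BFLOAT16_CONVERSIONS__ is set
  __nv_bfloat16 b = __float2bfloat16(f);  // explicit intrinsic still available
  // float g = b;           // likewise rejected
  float g = __bfloat162float(b);          // explicit conversion back
  (void)g;  // unused in this sketch
}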