mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] Reduce duplication in bfloat16_support_literal definition (#166147)
This PR refactors the bfloat16_support_literal constant in the PyTorch build logic to eliminate duplicated ROCm-specific code. Previously, there were two nearly identical branches for ROCM_VERSION < 70000 and ROCM_VERSION >= 70000, differing only by a single typedef. These have been unified into one conditional block with a minimal version guard inside. (https://github.com/ROCm/pytorch/pull/2502) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166147 Approved by: https://github.com/jerrymannil, https://github.com/jeffdaily
This commit is contained in:
parent
d7040e6d75
commit
deb776319b
|
|
@ -260,82 +260,20 @@ typedef __half half;
|
||||||
)";
|
)";
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(USE_ROCM) && ROCM_VERSION < 70000
|
#if defined(USE_ROCM)
|
||||||
|
|
||||||
|
#if ROCM_VERSION >= 70000
|
||||||
|
#define BF16_UINT32_DEF "typedef unsigned int uint32_t;\n"
|
||||||
|
#else
|
||||||
|
#define BF16_UINT32_DEF ""
|
||||||
|
#endif
|
||||||
|
|
||||||
constexpr auto bfloat16_support_literal =
|
constexpr auto bfloat16_support_literal =
|
||||||
R"(
|
R"(
|
||||||
#ifndef __align__
|
#ifndef __align__
|
||||||
#define __align__(x) __attribute__((aligned(x)))
|
#define __align__(x) __attribute__((aligned(x)))
|
||||||
#endif
|
#endif
|
||||||
|
)" BF16_UINT32_DEF R"(
|
||||||
typedef struct __align__(2) {
|
|
||||||
unsigned short x;
|
|
||||||
}
|
|
||||||
__nv_bfloat16_raw;
|
|
||||||
|
|
||||||
#if defined(__cplusplus)
|
|
||||||
struct __align__(2) __nv_bfloat16 {
|
|
||||||
__host__ __device__ __nv_bfloat16() {}
|
|
||||||
|
|
||||||
__host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
|
|
||||||
__x = hr.x;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned short __x;
|
|
||||||
};
|
|
||||||
|
|
||||||
__device__ unsigned short __internal_float2bfloat16(
|
|
||||||
const float f,
|
|
||||||
unsigned int& sign,
|
|
||||||
unsigned int& remainder) {
|
|
||||||
unsigned int x;
|
|
||||||
|
|
||||||
x = __float_as_uint(f);
|
|
||||||
|
|
||||||
if ((x & 0x7fffffffU) > 0x7f800000U) {
|
|
||||||
sign = 0U;
|
|
||||||
remainder = 0U;
|
|
||||||
return static_cast<unsigned short>(0x7fffU);
|
|
||||||
}
|
|
||||||
sign = x >> 31;
|
|
||||||
remainder = x << 16;
|
|
||||||
return static_cast<unsigned short>(x >> 16);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Definitions of intrinsics */
|
|
||||||
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
|
|
||||||
__nv_bfloat16 val;
|
|
||||||
__nv_bfloat16_raw r;
|
|
||||||
unsigned int sign;
|
|
||||||
unsigned int remainder;
|
|
||||||
r.x = __internal_float2bfloat16(a, sign, remainder);
|
|
||||||
if ((remainder > 0x80000000U) ||
|
|
||||||
((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
|
|
||||||
r.x++;
|
|
||||||
}
|
|
||||||
val = r;
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ float __bfloat162float(const __nv_bfloat16 a) {
|
|
||||||
union
|
|
||||||
{
|
|
||||||
uint32_t int32;
|
|
||||||
float fp32;
|
|
||||||
} u = {uint32_t(a.__x) << 16};
|
|
||||||
return u.fp32;
|
|
||||||
}
|
|
||||||
#endif /* defined(__cplusplus) */
|
|
||||||
)";
|
|
||||||
#elif defined(USE_ROCM) && ROCM_VERSION >= 70000
|
|
||||||
constexpr auto bfloat16_support_literal =
|
|
||||||
R"(
|
|
||||||
#ifndef __align__
|
|
||||||
#define __align__(x) __attribute__((aligned(x)))
|
|
||||||
#endif
|
|
||||||
|
|
||||||
typedef unsigned int uint32_t;
|
|
||||||
|
|
||||||
typedef struct __align__(2) {
|
typedef struct __align__(2) {
|
||||||
unsigned short x;
|
unsigned short x;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user