[ROCm] Reduce duplication in bfloat16_support_literal definition (#166147)

This PR refactors the bfloat16_support_literal constant in the PyTorch build logic to eliminate duplicated ROCm-specific code.

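Previously there were two nearly identical branches, one for ROCM_VERSION < 70000 and one for ROCM_VERSION >= 70000, that differed only by a single typedef (`typedef unsigned int uint32_t;`, present only in the >= 70000 branch). They are now unified into a single `#if defined(USE_ROCM)` block, with a minimal version guard inside that controls whether that typedef is emitted into the literal. (https://github.com/ROCm/pytorch/pull/2502)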

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166147
Approved by: https://github.com/jerrymannil, https://github.com/jeffdaily
This commit is contained in:
rraminen 2025-10-29 16:59:03 +00:00 committed by PyTorch MergeBot
parent d7040e6d75
commit deb776319b

View File

@ -260,82 +260,20 @@ typedef __half half;
)"; )";
#endif #endif
#if defined(USE_ROCM) && ROCM_VERSION < 70000 #if defined(USE_ROCM)
#if ROCM_VERSION >= 70000
#define BF16_UINT32_DEF "typedef unsigned int uint32_t;\n"
#else
#define BF16_UINT32_DEF ""
#endif
constexpr auto bfloat16_support_literal = constexpr auto bfloat16_support_literal =
R"( R"(
#ifndef __align__ #ifndef __align__
#define __align__(x) __attribute__((aligned(x))) #define __align__(x) __attribute__((aligned(x)))
#endif #endif
)" BF16_UINT32_DEF R"(
typedef struct __align__(2) {
unsigned short x;
}
__nv_bfloat16_raw;
#if defined(__cplusplus)
struct __align__(2) __nv_bfloat16 {
__host__ __device__ __nv_bfloat16() {}
__host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
__x = hr.x;
return *this;
}
unsigned short __x;
};
__device__ unsigned short __internal_float2bfloat16(
const float f,
unsigned int& sign,
unsigned int& remainder) {
unsigned int x;
x = __float_as_uint(f);
if ((x & 0x7fffffffU) > 0x7f800000U) {
sign = 0U;
remainder = 0U;
return static_cast<unsigned short>(0x7fffU);
}
sign = x >> 31;
remainder = x << 16;
return static_cast<unsigned short>(x >> 16);
}
/* Definitions of intrinsics */
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
__nv_bfloat16 val;
__nv_bfloat16_raw r;
unsigned int sign;
unsigned int remainder;
r.x = __internal_float2bfloat16(a, sign, remainder);
if ((remainder > 0x80000000U) ||
((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
r.x++;
}
val = r;
return val;
}
__device__ float __bfloat162float(const __nv_bfloat16 a) {
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(a.__x) << 16};
return u.fp32;
}
#endif /* defined(__cplusplus) */
)";
#elif defined(USE_ROCM) && ROCM_VERSION >= 70000
constexpr auto bfloat16_support_literal =
R"(
#ifndef __align__
#define __align__(x) __attribute__((aligned(x)))
#endif
typedef unsigned int uint32_t;
typedef struct __align__(2) { typedef struct __align__(2) {
unsigned short x; unsigned short x;
} }