From deb776319b12cad7fffb5cf9d8851a50e2b0e9ea Mon Sep 17 00:00:00 2001 From: rraminen Date: Wed, 29 Oct 2025 16:59:03 +0000 Subject: [PATCH] [ROCm] Reduce duplication in bfloat16_support_literal definition (#166147) This PR refactors the bfloat16_support_literal constant in the PyTorch build logic to eliminate duplicated ROCm-specific code. Previously, there were two nearly identical branches for ROCM_VERSION < 70000 and ROCM_VERSION >= 70000, differing only by a single typedef. These have been unified into one conditional block with a minimal version guard inside. (https://github.com/ROCm/pytorch/pull/2502) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166147 Approved by: https://github.com/jerrymannil, https://github.com/jeffdaily --- .../jit/codegen/fuser/cuda/resource_strings.h | 80 +++---------------- 1 file changed, 9 insertions(+), 71 deletions(-) diff --git a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h index 0ac2c79d1e9..16ccc5002f9 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h +++ b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h @@ -260,82 +260,20 @@ typedef __half half; )"; #endif -#if defined(USE_ROCM) && ROCM_VERSION < 70000 +#if defined(USE_ROCM) + +#if ROCM_VERSION >= 70000 +#define BF16_UINT32_DEF "typedef unsigned int uint32_t;\n" +#else +#define BF16_UINT32_DEF "" +#endif + constexpr auto bfloat16_support_literal = R"( #ifndef __align__ #define __align__(x) __attribute__((aligned(x))) #endif - -typedef struct __align__(2) { - unsigned short x; -} -__nv_bfloat16_raw; - -#if defined(__cplusplus) -struct __align__(2) __nv_bfloat16 { - __host__ __device__ __nv_bfloat16() {} - - __host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) { - __x = hr.x; - return *this; - } - - unsigned short __x; -}; - -__device__ unsigned short __internal_float2bfloat16( - const float f, - unsigned int& sign, - unsigned int& remainder) { - unsigned int x; - - x = __float_as_uint(f); - - if ((x & 0x7fffffffU) > 0x7f800000U) { - sign = 0U; - remainder = 0U; - return static_cast(0x7fffU); - } - sign = x >> 31; - remainder = x << 16; - return static_cast(x >> 16); -} - -/* Definitions of intrinsics */ -__device__ __nv_bfloat16 __float2bfloat16(const float a) { - __nv_bfloat16 val; - __nv_bfloat16_raw r; - unsigned int sign; - unsigned int remainder; - r.x = __internal_float2bfloat16(a, sign, remainder); - if ((remainder > 0x80000000U) || - ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) { - r.x++; - } - val = r; - return val; -} - -__device__ float __bfloat162float(const __nv_bfloat16 a) { - union - { - uint32_t int32; - float fp32; - } u = {uint32_t(a.__x) << 16}; - return u.fp32; -} -#endif /* defined(__cplusplus) */ -)"; -#elif defined(USE_ROCM) && ROCM_VERSION >= 70000 -constexpr auto bfloat16_support_literal = - R"( -#ifndef __align__ -#define __align__(x) __attribute__((aligned(x))) -#endif - -typedef unsigned int uint32_t; - +)" BF16_UINT32_DEF R"( typedef struct __align__(2) { unsigned short x; }