[ROCm] Define uint32_t when ROCM_VERSION >= 70000 (#160587)

This PR fixes errors like the one below:
```
[rank3]: RuntimeError: The following operation failed in the TorchScript interpreter.
[rank3]: Traceback of TorchScript (most recent call last):
[rank3]: RuntimeError: /tmp/comgr-28f951/input/CompileSourceACC062:67:7: error: unknown type name 'uint32_t'; did you mean '__hip_internal::uint32_t'?
[rank3]:    67 |       uint32_t int32;
[rank3]:       |       ^~~~~~~~
[rank3]:       |       __hip_internal::uint32_t
```
Previously, the HIP headers defined uint32_t in the std namespace. As of ROCm 7.0, it has been moved to the __hip_internal namespace, so the bfloat16 support source compiled at runtime must now define uint32_t itself.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160587
Approved by: https://github.com/jeffdaily
This commit is contained in:
Ramya Ramineni 2025-09-12 00:13:26 +00:00 committed by PyTorch MergeBot
parent ff6870d134
commit a956066b4e

View File

@ -260,7 +260,7 @@ typedef __half half;
)";
#endif
#if defined(USE_ROCM)
#if defined(USE_ROCM) && ROCM_VERSION < 70000
constexpr auto bfloat16_support_literal =
R"(
#ifndef __align__
@ -317,6 +317,75 @@ __device__ __nv_bfloat16 __float2bfloat16(const float a) {
return val;
}
__device__ float __bfloat162float(const __nv_bfloat16 a) {
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(a.__x) << 16};
return u.fp32;
}
#endif /* defined(__cplusplus) */
)";
#elif defined(USE_ROCM) && ROCM_VERSION >= 70000
constexpr auto bfloat16_support_literal =
R"(
#ifndef __align__
#define __align__(x) __attribute__((aligned(x)))
#endif
typedef unsigned int uint32_t;
typedef struct __align__(2) {
unsigned short x;
}
__nv_bfloat16_raw;
#if defined(__cplusplus)
struct __align__(2) __nv_bfloat16 {
__host__ __device__ __nv_bfloat16() {}
__host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
__x = hr.x;
return *this;
}
unsigned short __x;
};
__device__ unsigned short __internal_float2bfloat16(
const float f,
unsigned int& sign,
unsigned int& remainder) {
unsigned int x;
x = __float_as_uint(f);
if ((x & 0x7fffffffU) > 0x7f800000U) {
sign = 0U;
remainder = 0U;
return static_cast<unsigned short>(0x7fffU);
}
sign = x >> 31;
remainder = x << 16;
return static_cast<unsigned short>(x >> 16);
}
/* Definitions of intrinsics */
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
__nv_bfloat16 val;
__nv_bfloat16_raw r;
unsigned int sign;
unsigned int remainder;
r.x = __internal_float2bfloat16(a, sign, remainder);
if ((remainder > 0x80000000U) ||
((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
r.x++;
}
val = r;
return val;
}
__device__ float __bfloat162float(const __nv_bfloat16 a) {
union
{