mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] Define uint32_t when ROCM_VERSION >= 70000 (#160587)
This PR fixes errors like the one below:

```
[rank3]: RuntimeError: The following operation failed in the TorchScript interpreter.
[rank3]: Traceback of TorchScript (most recent call last):
[rank3]: RuntimeError: /tmp/comgr-28f951/input/CompileSourceACC062:67:7: error: unknown type name 'uint32_t'; did you mean '__hip_internal::uint32_t'?
[rank3]:    67 |   uint32_t int32;
[rank3]:       |   ^~~~~~~~
[rank3]:       |   __hip_internal::uint32_t
```

Previously, the HIP headers defined `uint32_t` in the `std` namespace. As of ROCm 7.0, the HIP headers define it in the `__hip_internal` namespace instead, so the unqualified name `uint32_t` is no longer found. This change adds `typedef unsigned int uint32_t;` to the bfloat16 support literal that is used when `ROCM_VERSION >= 70000`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160587
Approved by: https://github.com/jeffdaily
This commit is contained in:
parent
ff6870d134
commit
a956066b4e
|
|
@ -260,7 +260,7 @@ typedef __half half;
|
|||
)";
|
||||
#endif
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#if defined(USE_ROCM) && ROCM_VERSION < 70000
|
||||
constexpr auto bfloat16_support_literal =
|
||||
R"(
|
||||
#ifndef __align__
|
||||
|
|
@ -317,6 +317,75 @@ __device__ __nv_bfloat16 __float2bfloat16(const float a) {
|
|||
return val;
|
||||
}
|
||||
|
||||
__device__ float __bfloat162float(const __nv_bfloat16 a) {
|
||||
union
|
||||
{
|
||||
uint32_t int32;
|
||||
float fp32;
|
||||
} u = {uint32_t(a.__x) << 16};
|
||||
return u.fp32;
|
||||
}
|
||||
#endif /* defined(__cplusplus) */
|
||||
)";
|
||||
#elif defined(USE_ROCM) && ROCM_VERSION >= 70000
|
||||
constexpr auto bfloat16_support_literal =
|
||||
R"(
|
||||
#ifndef __align__
|
||||
#define __align__(x) __attribute__((aligned(x)))
|
||||
#endif
|
||||
|
||||
typedef unsigned int uint32_t;
|
||||
|
||||
typedef struct __align__(2) {
|
||||
unsigned short x;
|
||||
}
|
||||
__nv_bfloat16_raw;
|
||||
|
||||
#if defined(__cplusplus)
|
||||
struct __align__(2) __nv_bfloat16 {
|
||||
__host__ __device__ __nv_bfloat16() {}
|
||||
|
||||
__host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
|
||||
__x = hr.x;
|
||||
return *this;
|
||||
}
|
||||
|
||||
unsigned short __x;
|
||||
};
|
||||
|
||||
__device__ unsigned short __internal_float2bfloat16(
|
||||
const float f,
|
||||
unsigned int& sign,
|
||||
unsigned int& remainder) {
|
||||
unsigned int x;
|
||||
|
||||
x = __float_as_uint(f);
|
||||
|
||||
if ((x & 0x7fffffffU) > 0x7f800000U) {
|
||||
sign = 0U;
|
||||
remainder = 0U;
|
||||
return static_cast<unsigned short>(0x7fffU);
|
||||
}
|
||||
sign = x >> 31;
|
||||
remainder = x << 16;
|
||||
return static_cast<unsigned short>(x >> 16);
|
||||
}
|
||||
|
||||
/* Definitions of intrinsics */
|
||||
__device__ __nv_bfloat16 __float2bfloat16(const float a) {
|
||||
__nv_bfloat16 val;
|
||||
__nv_bfloat16_raw r;
|
||||
unsigned int sign;
|
||||
unsigned int remainder;
|
||||
r.x = __internal_float2bfloat16(a, sign, remainder);
|
||||
if ((remainder > 0x80000000U) ||
|
||||
((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
|
||||
r.x++;
|
||||
}
|
||||
val = r;
|
||||
return val;
|
||||
}
|
||||
|
||||
__device__ float __bfloat162float(const __nv_bfloat16 a) {
|
||||
union
|
||||
{
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user