mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] opportunistic fastatomics - fix build error with newer compilers (#152841)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/152841 Approved by: https://github.com/jeffdaily
This commit is contained in:
parent
1f4f4a61c2
commit
8faa0b18c3
|
|
@ -243,12 +243,10 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd(
|
|||
{
|
||||
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
|
||||
union ill { unsigned int i[2]; int64_t il; };
|
||||
ill iil_, ill_oneUpDst, ill_oneDnDst = {};
|
||||
ill iil_, ill_oneUpDst = {};
|
||||
iil_.il = (int64_t)dst;
|
||||
ill_oneUpDst.i[0] = __builtin_amdgcn_mov_dpp(iil_.i[0], 0x130, 0xf, 0xf, 0);
|
||||
ill_oneUpDst.i[1] = __builtin_amdgcn_mov_dpp(iil_.i[1], 0x130, 0xf, 0xf, 0);
|
||||
ill_oneDnDst.i[0] = __builtin_amdgcn_mov_dpp(iil_.i[0], 0x138, 0xf, 0xf, 0);
|
||||
ill_oneDnDst.i[1] = __builtin_amdgcn_mov_dpp(iil_.i[1], 0x138, 0xf, 0xf, 0);
|
||||
union bfi {scalar_t bf; short s; } bfi_ = { .bf = value }; bfi bfi_oneUpVal;
|
||||
|
||||
bfi_oneUpVal.s = __builtin_amdgcn_mov_dpp(bfi_.s, 0x130, 0xf, 0xf, 0);
|
||||
|
|
@ -264,7 +262,8 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd(
|
|||
if (__lane_id()%2==0)
|
||||
{
|
||||
if (canCombnUp) {
|
||||
union bfvs { scalar_t bf[2]; vec_short2 vs2; __half2 df16; };
|
||||
typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
|
||||
union bfvs { scalar_t bf[2]; vec_short2 vs2; vec_fp162 df16; };
|
||||
bfvs bfvs_ = {};
|
||||
bfvs_.bf[0] = value;
|
||||
bfvs_.bf[1] = oneUpVal;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user