pytorch/c10/metal/atomic.h
Nikita Shulga 3aecf2dc52 [MPS] Extend index_put to half precision floats (#151869)
By reusing `c10/metal/atomic.h`
This also fixes `GPUTests.test_index_put_fallback[12]_mps` that is unrolled by inductor, so no need for dedicated atomic_add support

TODOs:
 - Get rid of indexing kernel and compute it directly when kernel is run
 - Simulate atomic_add for int64 types as series of int32 atomic-add-and-fetch
 - Setup tolerances correctly to pass float16/bfloat16 tests (as CPU always takes sequential strategy)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151869
Approved by: https://github.com/Skylion007, https://github.com/dcci
2025-04-22 22:00:08 +00:00

75 lines
1.9 KiB
C++

#pragma once
#include <metal_atomic>
namespace c10 {
namespace metal {
// Atomic operations helper.
//
// AtomicType<T> is a trait keyed on the element type T. The primary template
// below has no members; the explicit specializations in this file each supply
//   - `type`:       the atomic type the destination buffer is accessed as, and
//   - `atomic_add`: a static function that atomically adds a value to the
//                   element at a given offset of such a buffer.
template <typename T>
struct AtomicType {};
// Shorthand for the `type` member of AtomicType<T>. Because the primary
// template declares no `type`, this alias is only usable for element types
// that have an explicit specialization.
template <typename T>
using AtomicType_t = typename AtomicType<T>::type;
template <>
struct AtomicType<float> {
  // Floats are accumulated through the matching Metal atomic type.
  using type = ::metal::atomic<float>;

  // Atomically adds `value` to the element `offset` positions past `data`.
  // The operation uses relaxed memory ordering and the previous value
  // returned by the fetch-add is discarded.
  static inline void atomic_add(device type* data, long offset, float value) {
    device type* const target = data + offset;
    ::metal::atomic_fetch_add_explicit(
        target, value, ::metal::memory_order_relaxed);
  }
};
template <>
struct AtomicType<int> {
  // 32-bit signed integers are accumulated through the matching Metal atomic
  // type.
  using type = ::metal::atomic<int>;

  // Atomically adds `value` to the element `offset` positions past `data`.
  // The operation uses relaxed memory ordering and the previous value
  // returned by the fetch-add is discarded.
  static inline void atomic_add(device type* data, long offset, int value) {
    device type* const target = data + offset;
    ::metal::atomic_fetch_add_explicit(
        target, value, ::metal::memory_order_relaxed);
  }
};
// As of Metal 3.2, atomic operations are not supported on half-precision
// floats, so they must be simulated using an atomic compare-and-exchange over
// a 32-bit atomic type.
// Performs `element[offset] += value` atomically for an element type T that
// is packed into 32-bit words, using a compare-and-exchange loop on the word
// that contains the element.
//
//   data   - destination buffer, accessed as an array of 32-bit atomic words
//   offset - index of the target element, counted in units of T (not words)
//   value  - addend
//
// T must have a size that evenly divides sizeof(uint); this is enforced at
// compile time. All atomic accesses use relaxed memory ordering.
template <typename T>
static inline void atomic_add_helper(
    device ::metal::atomic<uint>* data,
    long offset,
    T value) {
  // Number of T elements stored in one 32-bit word (2 for 16-bit types).
  constexpr long elems_per_word = sizeof(uint) / sizeof(T);
  static_assert(
      elems_per_word > 0 && sizeof(T) * elems_per_word == sizeof(uint),
      "atomic_add_helper requires sizeof(T) to evenly divide sizeof(uint)");

  // Lane of the element inside its word. Removing the lane before dividing
  // makes the division exact, so the word index is rounded down for every
  // offset, exactly like the `offset >> 1` used for 16-bit types.
  const long lane = offset & (elems_per_word - 1);
  auto ptr = data + (offset - lane) / elems_per_word;

  // View of one 32-bit word as its packed T elements.
  union {
    uint i;
    T t[elems_per_word];
  } val;

  auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
  do {
    // Rebuild the desired word from the most recently observed contents:
    // a failed exchange stores the current word into `old`, so each retry
    // starts from fresh data and leaves the other lanes untouched.
    val.i = old;
    val.t[lane] += value;
  } while (!::metal::atomic_compare_exchange_weak_explicit(
      ptr,
      &old,
      val.i,
      ::metal::memory_order_relaxed,
      ::metal::memory_order_relaxed));
}
template <>
struct AtomicType<half> {
  // The destination buffer is accessed as 32-bit atomic words; the addition
  // itself is emulated by atomic_add_helper.
  using type = ::metal::atomic<uint>;

  // Atomically adds `value` to the half-precision element at index `offset`.
  // `offset` is counted in half elements, not in 32-bit words.
  static inline void atomic_add(device type* data, long offset, half value) {
    // The helper's element type is deduced as `half` from `value`.
    atomic_add_helper(data, offset, value);
  }
};
#if __METAL_VERSION__ >= 310
// Only compiled when __METAL_VERSION__ is at least 310.
template <>
struct AtomicType<bfloat> {
  // The destination buffer is accessed as 32-bit atomic words; the addition
  // itself is emulated by atomic_add_helper.
  using type = ::metal::atomic<uint>;

  // Atomically adds `value` to the bfloat element at index `offset`.
  // `offset` is counted in bfloat elements, not in 32-bit words.
  static inline void atomic_add(device type* data, long offset, bfloat value) {
    // The helper's element type is deduced as `bfloat` from `value`.
    atomic_add_helper(data, offset, value);
  }
};
#endif
} // namespace metal
} // namespace c10