Mirror of https://github.com/zebrajr/pytorch.git (last synced 2025-12-07 12:21:27 +01:00).
By reusing `c10/metal/atomic.h`, this also fixes `GPUTests.test_index_put_fallback[12]_mps`, which is unrolled by inductor, so there is no need for dedicated atomic_add support. TODOs: - Get rid of the indexing kernel and compute indices directly when the kernel is run - Simulate atomic_add for int64 types as a series of int32 atomic-add-and-fetch operations - Set up tolerances correctly to pass float16/bfloat16 tests (as CPU always takes the sequential strategy) Pull Request resolved: https://github.com/pytorch/pytorch/pull/151869 Approved by: https://github.com/Skylion007, https://github.com/dcci
75 lines
1.9 KiB
C++
#pragma once
|
|
#include <metal_atomic>
|
|
namespace c10 {
|
|
namespace metal {
|
|
|
|
// Atomic operations helper
//
// Maps an element type T to the Metal atomic storage type used to hold it,
// and (via the specializations below) provides a uniform
// atomic_add(data, offset, value) entry point. The primary template is
// intentionally empty, so using an unspecialized T fails to compile.
template <typename T>
struct AtomicType {};

// Convenience alias for the backing atomic storage type of T.
template <typename T>
using AtomicType_t = typename AtomicType<T>::type;
template <>
|
|
struct AtomicType<float> {
|
|
using type = ::metal::atomic<float>;
|
|
static inline void atomic_add(device type* data, long offset, float value) {
|
|
::metal::atomic_fetch_add_explicit(
|
|
data + offset, value, ::metal::memory_order_relaxed);
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct AtomicType<int> {
|
|
using type = ::metal::atomic<int>;
|
|
static inline void atomic_add(device type* data, long offset, int value) {
|
|
::metal::atomic_fetch_add_explicit(
|
|
data + offset, value, ::metal::memory_order_relaxed);
|
|
}
|
|
};
|
|
|
|
// As of Metal3.2 atomic operations are not supported on half-precision floats,
|
|
// so they must be simulated Using atomic compare and exchange over 32-bit
|
|
// atomic type
|
|
template <typename T>
|
|
static inline void atomic_add_helper(
|
|
device ::metal::atomic<uint>* data,
|
|
long offset,
|
|
T value) {
|
|
auto ptr = data + (offset >> 1);
|
|
auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
|
|
union {
|
|
uint i;
|
|
T t[2];
|
|
} val;
|
|
do {
|
|
val.i = old;
|
|
val.t[offset & 1] += value;
|
|
} while (!::metal::atomic_compare_exchange_weak_explicit(
|
|
ptr,
|
|
&old,
|
|
val.i,
|
|
::metal::memory_order_relaxed,
|
|
::metal::memory_order_relaxed));
|
|
}
|
|
|
|
// half has no native atomic support (see comment above), so the backing
// storage is atomic<uint> and the add is emulated via a 32-bit CAS loop.
template <>
struct AtomicType<half> {
  using type = ::metal::atomic<uint>;
  // offset is in half elements, not in uint words; atomic_add_helper
  // performs the word/lane arithmetic.
  static inline void atomic_add(device type* data, long offset, half value) {
    atomic_add_helper<half>(data, offset, value);
  }
};
|
// bfloat is only available from Metal 3.1 (__METAL_VERSION__ 310) onward;
// like half, it has no native atomics and is emulated via a 32-bit CAS loop.
#if __METAL_VERSION__ >= 310
template <>
struct AtomicType<bfloat> {
  using type = ::metal::atomic<uint>;
  // offset is in bfloat elements; atomic_add_helper performs the
  // word/lane arithmetic.
  static inline void atomic_add(device type* data, long offset, bfloat value) {
    atomic_add_helper<bfloat>(data, offset, value);
  }
};
#endif
|
} // namespace metal
|
|
} // namespace c10
|