[Metal][BE] Move atomic ops to c10/metal/atomic.h (#151868)

To be reused by indexing and the MPSInductor implementation of atomic_add stores.
Added a wrapper for `metal::atomic<int>` (to be used by a follow-up PR).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151868
Approved by: https://github.com/Skylion007
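For context, callers are expected to go through the `AtomicType<T>` trait rather than the raw Metal atomics. Below is a minimal sketch of what an indexing-style kernel built on this header could look like; the kernel name, buffer bindings, and scatter-add use case are illustrative assumptions, not part of this PR:

```metal
#include <c10/metal/atomic.h>
#include <metal_stdlib>

// Hypothetical scatter-add: each thread accumulates values[tid] into
// out[indices[tid]]. The output is bound as AtomicType_t<T>, i.e.
// atomic<float> for float and atomic<int> for half/bfloat (CAS emulation).
template <typename T>
kernel void scatter_add(
    device c10::metal::AtomicType_t<T>* out [[buffer(0)]],
    constant T* values [[buffer(1)]],
    constant long* indices [[buffer(2)]],
    uint tid [[thread_position_in_grid]]) {
  c10::metal::AtomicType<T>::atomic_add(out, indices[tid], values[tid]);
}

template [[host_name("scatter_add_float")]] kernel void scatter_add<float>(
    device c10::metal::AtomicType_t<float>*,
    constant float*,
    constant long*,
    uint);
```

Routing every dtype through the same trait keeps call sites identical, so generated kernels do not need to know which element types fall back to the compare-and-exchange emulation.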
Nikita Shulga 2025-04-21 19:38:25 -07:00 committed by PyTorch MergeBot
parent 159e2f96e3
commit d778c92e16
2 changed files with 77 additions and 53 deletions


@@ -1,58 +1,8 @@
+#include <c10/metal/atomic.h>
 #include <metal_stdlib>
 using namespace metal;
-// Atomic operations helper
-template <typename T>
-struct AtomicType {};
-template <typename T>
-using AtomicType_t = typename AtomicType<T>::type;
-template <>
-struct AtomicType<float> {
-  using type = atomic<float>;
-  static inline void atomic_add(device type* data, long offset, float value) {
-    atomic_fetch_add_explicit(data + offset, value, memory_order_relaxed);
-  }
-};
-// As of Metal3.2 atomic operations are not supported on half-precision floats,
-// so they must be simulated Using atomic compare and exchange over 32-bit
-// atomic type
-template <typename T>
-static inline void atomic_add_helper(
-    device atomic<int>* data,
-    long offset,
-    float value) {
-  auto ptr = data + (offset >> 1);
-  auto old = atomic_load_explicit(ptr, memory_order_relaxed);
-  union {
-    int i;
-    T t[2];
-  } val;
-  do {
-    val.i = old;
-    val.t[offset & 1] += static_cast<T>(value);
-  } while (!atomic_compare_exchange_weak_explicit(
-      ptr, &old, val.i, memory_order_relaxed, memory_order_relaxed));
-}
-template <>
-struct AtomicType<half> {
-  using type = atomic<int>;
-  static inline void atomic_add(device type* data, long offset, float value) {
-    atomic_add_helper<half>(data, offset, value);
-  }
-};
-#if __METAL_VERSION__ >= 310
-template <>
-struct AtomicType<bfloat> {
-  using type = atomic<int>;
-  static inline void atomic_add(device type* data, long offset, float value) {
-    atomic_add_helper<bfloat>(data, offset, value);
-  }
-};
-#endif
+using namespace c10::metal;
 // Based on
 // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm

c10/metal/atomic.h (new file, 74 lines)

@@ -0,0 +1,74 @@
#pragma once
#include <metal_atomic>

namespace c10 {
namespace metal {

// Atomic operations helper
template <typename T>
struct AtomicType {};
template <typename T>
using AtomicType_t = typename AtomicType<T>::type;

template <>
struct AtomicType<float> {
  using type = ::metal::atomic<float>;
  static inline void atomic_add(device type* data, long offset, float value) {
    ::metal::atomic_fetch_add_explicit(
        data + offset, value, ::metal::memory_order_relaxed);
  }
};

template <>
struct AtomicType<int> {
  using type = ::metal::atomic<int>;
  static inline void atomic_add(device type* data, long offset, int value) {
    ::metal::atomic_fetch_add_explicit(
        data + offset, value, ::metal::memory_order_relaxed);
  }
};
// As of Metal 3.2, atomic operations are not supported on half-precision
// floats, so they must be simulated using atomic compare-and-exchange over a
// 32-bit atomic type
template <typename T>
static inline void atomic_add_helper(
    device ::metal::atomic<int>* data,
    long offset,
    float value) {
  auto ptr = data + (offset >> 1);
  auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
  union {
    int i;
    T t[2];
  } val;
  do {
    val.i = old;
    val.t[offset & 1] += static_cast<T>(value);
  } while (!::metal::atomic_compare_exchange_weak_explicit(
      ptr,
      &old,
      val.i,
      ::metal::memory_order_relaxed,
      ::metal::memory_order_relaxed));
}

template <>
struct AtomicType<half> {
  using type = ::metal::atomic<int>;
  static inline void atomic_add(device type* data, long offset, float value) {
    atomic_add_helper<half>(data, offset, value);
  }
};

#if __METAL_VERSION__ >= 310
template <>
struct AtomicType<bfloat> {
  using type = ::metal::atomic<int>;
  static inline void atomic_add(device type* data, long offset, float value) {
    atomic_add_helper<bfloat>(data, offset, value);
  }
};
#endif

} // namespace metal
} // namespace c10
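The `AtomicType<int>` specialization is not exercised by anything in this commit; per the description it targets a follow-up PR. A hedged sketch of the kind of kernel that could sit on top of it (the histogram use case and all names below are assumptions, not taken from this commit):

```metal
#include <c10/metal/atomic.h>
#include <metal_stdlib>

// Assumed follow-up usage: many threads bump the same integer bin.
// Unlike the half/bfloat path, this maps directly onto a hardware
// atomic_fetch_add_explicit on ::metal::atomic<int>, with no CAS loop.
kernel void histogram(
    device c10::metal::AtomicType_t<int>* bins [[buffer(0)]],
    constant int* samples [[buffer(1)]],
    uint tid [[thread_position_in_grid]]) {
  c10::metal::AtomicType<int>::atomic_add(bins, samples[tid], 1);
}
```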