mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Metal][BE] Move atomic ops to c10/metal/atomic.h (#151868)
To be reused from indexing and the MPSInductor implementation of atomic_add stores. Added a wrapper for `metal::atomic<int>` (to be used by a follow-up PR). Pull Request resolved: https://github.com/pytorch/pytorch/pull/151868 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
159e2f96e3
commit
d778c92e16
|
|
@ -1,58 +1,8 @@
|
|||
#include <c10/metal/atomic.h>
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Atomic operations helper
|
||||
// Primary template: deliberately empty. Element types that can be accumulated
// atomically supply an explicit specialization exposing a `type` alias and a
// static `atomic_add` function.
template <class T>
struct AtomicType {};
|
||||
template <typename T>
|
||||
using AtomicType_t = typename AtomicType<T>::type;
|
||||
|
||||
template <>
|
||||
struct AtomicType<float> {
|
||||
using type = atomic<float>;
|
||||
static inline void atomic_add(device type* data, long offset, float value) {
|
||||
atomic_fetch_add_explicit(data + offset, value, memory_order_relaxed);
|
||||
}
|
||||
};
|
||||
|
||||
// As of Metal 3.2, atomic operations are not supported on half-precision floats,
|
||||
// so they must be simulated using atomic compare and exchange over a 32-bit
|
||||
// atomic type
|
||||
template <typename T>
|
||||
static inline void atomic_add_helper(
|
||||
device atomic<int>* data,
|
||||
long offset,
|
||||
float value) {
|
||||
auto ptr = data + (offset >> 1);
|
||||
auto old = atomic_load_explicit(ptr, memory_order_relaxed);
|
||||
union {
|
||||
int i;
|
||||
T t[2];
|
||||
} val;
|
||||
do {
|
||||
val.i = old;
|
||||
val.t[offset & 1] += static_cast<T>(value);
|
||||
} while (!atomic_compare_exchange_weak_explicit(
|
||||
ptr, &old, val.i, memory_order_relaxed, memory_order_relaxed));
|
||||
}
|
||||
|
||||
template <>
|
||||
struct AtomicType<half> {
|
||||
using type = atomic<int>;
|
||||
static inline void atomic_add(device type* data, long offset, float value) {
|
||||
atomic_add_helper<half>(data, offset, value);
|
||||
}
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ >= 310
// bfloat specialization, compiled only when __METAL_VERSION__ >= 310. The
// buffer is addressed as atomic<int> words and the addition is delegated to
// the compare-and-exchange based atomic_add_helper.
template <>
struct AtomicType<bfloat> {
  using type = atomic<int>;

  // Adds `value` to the bfloat element at index `offset` (counted in bfloat
  // elements, not in 32-bit words).
  static inline void atomic_add(device type* data, long offset, float value) {
    using lane_t = bfloat;
    atomic_add_helper<lane_t>(data, offset, value);
  }
};
#endif
|
||||
using namespace c10::metal;
|
||||
|
||||
// Based on
|
||||
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
|
||||
|
|
|
|||
74
c10/metal/atomic.h
Normal file
74
c10/metal/atomic.h
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
#pragma once
|
||||
#include <metal_atomic>
|
||||
namespace c10 {
|
||||
namespace metal {
|
||||
|
||||
// Atomic operations helper
|
||||
// Primary template: deliberately empty. Element types that can be accumulated
// atomically supply an explicit specialization exposing a `type` alias and a
// static `atomic_add` function.
template <class T>
struct AtomicType {};
|
||||
template <typename T>
|
||||
using AtomicType_t = typename AtomicType<T>::type;
|
||||
|
||||
template <>
|
||||
struct AtomicType<float> {
|
||||
using type = ::metal::atomic<float>;
|
||||
static inline void atomic_add(device type* data, long offset, float value) {
|
||||
::metal::atomic_fetch_add_explicit(
|
||||
data + offset, value, ::metal::memory_order_relaxed);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct AtomicType<int> {
|
||||
using type = ::metal::atomic<int>;
|
||||
static inline void atomic_add(device type* data, long offset, int value) {
|
||||
::metal::atomic_fetch_add_explicit(
|
||||
data + offset, value, ::metal::memory_order_relaxed);
|
||||
}
|
||||
};
|
||||
|
||||
// As of Metal 3.2, atomic operations are not supported on half-precision floats,
|
||||
// so they must be simulated using atomic compare and exchange over a 32-bit
|
||||
// atomic type
|
||||
template <typename T>
|
||||
static inline void atomic_add_helper(
|
||||
device ::metal::atomic<int>* data,
|
||||
long offset,
|
||||
float value) {
|
||||
auto ptr = data + (offset >> 1);
|
||||
auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
|
||||
union {
|
||||
int i;
|
||||
T t[2];
|
||||
} val;
|
||||
do {
|
||||
val.i = old;
|
||||
val.t[offset & 1] += static_cast<T>(value);
|
||||
} while (!::metal::atomic_compare_exchange_weak_explicit(
|
||||
ptr,
|
||||
&old,
|
||||
val.i,
|
||||
::metal::memory_order_relaxed,
|
||||
::metal::memory_order_relaxed));
|
||||
}
|
||||
|
||||
template <>
|
||||
struct AtomicType<half> {
|
||||
using type = ::metal::atomic<int>;
|
||||
static inline void atomic_add(device type* data, long offset, float value) {
|
||||
atomic_add_helper<half>(data, offset, value);
|
||||
}
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ >= 310
// bfloat specialization, compiled only when __METAL_VERSION__ >= 310. The
// buffer is addressed as ::metal::atomic<int> words and the addition is
// delegated to the compare-and-exchange based atomic_add_helper.
template <>
struct AtomicType<bfloat> {
  using type = ::metal::atomic<int>;

  // Adds `value` to the bfloat element at index `offset` (counted in bfloat
  // elements, not in 32-bit words).
  static inline void atomic_add(device type* data, long offset, float value) {
    using lane_t = bfloat;
    atomic_add_helper<lane_t>(data, offset, value);
  }
};
#endif
|
||||
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
Loading…
Reference in New Issue
Block a user