[MPS] Fix index_kernel for large tensors (#158064)

Move the `MetalShaderLibrary::bind_tensors` private method to OperationUtils.h and extract an `iter_tensor_offset` helper that returns the offset from the start of the storage associated with a given tensor inside the iterator.

Migrated `index` and `index_put[_accumulate][_serial]` to the new paradigm, which requires neither an additional tensor for indices nor special handling for 32- vs 64-bit offsets. This resulted in an almost 2x perf gain for 2000x2000 tensors. Results before the change:
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   383.5    |    379.8     |    470.9     |     1232.9     |     4410.3
      __getitem__ (torch.float16, torch.int64)  |   379.6    |    354.5     |    533.2     |     1290.3     |     4442.2
      __getitem__ (torch.float32, torch.int64)  |   360.8    |    338.6     |    478.6     |     1348.9     |     4870.4

Times are in microseconds (us).
```
and after:
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   349.8    |    330.5     |    432.6     |     764.5      |     1961.2
      __getitem__ (torch.float16, torch.int64)  |   342.5    |    330.7     |    434.7     |     741.0      |     1969.4
      __getitem__ (torch.float32, torch.int64)  |   332.2    |    326.1     |    445.4     |     751.3      |     1972.6

Times are in microseconds (us).
```
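
For reference, a rough way to reproduce the numbers above — a sketch along the lines of the profiling script added further down; the shapes, dtypes and use of `torch.utils.benchmark` here are illustrative:
```python
import torch
from torch.utils.benchmark import Timer

B, N = 11, 2000
x = torch.testing.make_tensor((B, N, N), device="mps", dtype=torch.float32)
idx = torch.randint(0, B, (3,), device="mps")
# Synchronize inside the timed statement so kernel completion is included
t = Timer(stmt="x[idx]; torch.mps.synchronize()",
          globals={"x": x, "idx": idx, "torch": torch})
print(t.blocked_autorange())
```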

While migrating, also fixed `index_put_accumulate` for boolean types by using a compare-and-exchange trick over `uint`.
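
A minimal sketch of that case (assumes an MPS device): with `accumulate=True` on a boolean destination, accumulation degenerates to a logical OR, so storing `False` must not clear a previously set element.
```python
import torch

mask = torch.zeros(8, dtype=torch.bool, device="mps")
idx = torch.tensor([1, 3, 3, 5], device="mps")
vals = torch.tensor([True, False, True, True], device="mps")
# Accumulating booleans is an OR: index 3 gets False then True and must end up True
mask.index_put_((idx,), vals, accumulate=True)
print(mask.cpu())  # expected: True at positions 1, 3 and 5
```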

Fixes https://github.com/pytorch/pytorch/issues/153560
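
The linked issue concerns indexing tensors whose storage exceeds the 32-bit byte-offset range; a sketch of that scenario, along the lines of the regression test added below (assumes roughly 16 GB of unified memory is available):
```python
import torch

B, N = 11, 20000
# ~8.8 GB of fp16 storage, so byte offsets no longer fit into 32 bits
x = torch.empty(B, N, N, dtype=torch.float16, device="mps")
for i in range(B):
    x[i] = float(i)
batch_idx = torch.tensor([9], device="mps")
y = x[batch_idx]
assert y[0, 1, 2].item() == 9.0
```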
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158064
Approved by: https://github.com/dcci
Nikita Shulga 2025-07-11 11:40:18 -07:00 committed by PyTorch MergeBot
parent 93854e83b7
commit beed033b6e
9 changed files with 322 additions and 312 deletions

View File

@ -156,7 +156,6 @@ class MetalShaderLibrary {
MTLLibrary_t lib, MTLLibrary_t lib,
const std::string& fname); const std::string& fname);
MTLLibrary_t compileLibrary(const std::string& src); MTLLibrary_t compileLibrary(const std::string& src);
void bind_tensors(MTLComputeCommandEncoder_t, TensorIteratorBase&);
std::string shaderSource; std::string shaderSource;
unsigned nparams; unsigned nparams;
MTLCompileOptions* compile_options; MTLCompileOptions* compile_options;

View File

@ -5,6 +5,7 @@
#include <initializer_list> #include <initializer_list>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Tensor.h> #include <ATen/Tensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/Utils.h> #include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h> #include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/MetalShaderLibrary.h> #include <ATen/native/mps/MetalShaderLibrary.h>
@ -35,10 +36,6 @@
name:(NSString*)name; name:(NSString*)name;
@end @end
// Fwd declarations
namespace at {
struct TensorIteratorBase;
}
using namespace at::mps; using namespace at::mps;
namespace at::native::mps { namespace at::native::mps {
@ -508,6 +505,30 @@ static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const MPSS
[encoder setBytes:&s.value length:s.size atIndex:idx]; [encoder setBytes:&s.value length:s.size atIndex:idx];
} }
static size_t iter_tensor_offset(TensorIteratorBase& iter, unsigned idx) {
// At the moment, MPS storage data is not the real GPU pointer, but rather a pointer to id<MTLBuffer> object
// But TensorIterator constructs data_ptr as if base was just a raw pointer
// Workaround this problem by computing an offset from the start of the tensor, which works for both
// tensor views and sliced 64-bit iterators
return reinterpret_cast<size_t>(iter.data_ptr(idx)) -
reinterpret_cast<size_t>(iter.tensor_base(idx).storage().data());
}
static inline void bind_iter_tensors(id<MTLComputeCommandEncoder> encoder,
TensorIteratorBase& iter,
std::optional<size_t> ntensors = std::nullopt) {
for (auto idx : c10::irange(ntensors.value_or(iter.ntensors()))) {
auto& t = iter.tensor_base(idx);
// Handle CPU scalars
if (C10_UNLIKELY(t.device().type() == kCPU)) {
mtl_setBuffer(encoder, t, idx);
continue;
}
auto offs = iter_tensor_offset(iter, idx);
[encoder setBuffer:getMTLBufferStorage(t) offset:offs atIndex:idx];
}
}
namespace detail { namespace detail {
template <typename T> template <typename T>
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const T& val, unsigned idx) { inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const T& val, unsigned idx) {

View File

@ -971,23 +971,6 @@ class BundledShaderLibary : public MetalShaderLibrary {
} }
}; };
void MetalShaderLibrary::bind_tensors(id<MTLComputeCommandEncoder> encoder, TensorIteratorBase& iter) {
for (auto idx : c10::irange(iter.ntensors())) {
auto& t = iter.tensor_base(idx);
// Handle CPU scalars
if (C10_UNLIKELY(t.device().type() == kCPU)) {
mtl_setBuffer(encoder, t, idx);
continue;
}
// At the moment, MPS storage data is not the real GPU pointer, but rather a pointer to id<MTLBuffer> object
// But TensorIterator constructs data_ptr as if base was just a raw pointer
// Workaround this problem by computing an offset from the start of the tensor, which works for both
// tensor vies and sliced 64-bit iterators
auto offs = reinterpret_cast<size_t>(iter.data_ptr(idx)) - reinterpret_cast<size_t>(t.storage().data());
[encoder setBuffer:getMTLBufferStorage(t) offset:offs atIndex:idx];
}
}
void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter, void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
const std::string& name, const std::string& name,
std::optional<c10::Scalar> alpha, std::optional<c10::Scalar> alpha,
@ -1024,7 +1007,7 @@ void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
getMPSProfiler().beginProfileKernel(cplState, name, {inputTensor}); getMPSProfiler().beginProfileKernel(cplState, name, {inputTensor});
[computeEncoder setComputePipelineState:cplState]; [computeEncoder setComputePipelineState:cplState];
bind_tensors(computeEncoder, iter); bind_iter_tensors(computeEncoder, iter);
if (!iter.is_contiguous()) { if (!iter.is_contiguous()) {
mtl_setArgs<2>(computeEncoder, mtl_setArgs<2>(computeEncoder,
outputTensor.sizes(), outputTensor.sizes(),
@ -1100,7 +1083,7 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter,
getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other}); getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other});
[computeEncoder setComputePipelineState:binaryPSO]; [computeEncoder setComputePipelineState:binaryPSO];
// Set input and output tensors // Set input and output tensors
bind_tensors(computeEncoder, iter); bind_iter_tensors(computeEncoder, iter);
// Iterator is contiguous if all of its elements are dense in storage, // Iterator is contiguous if all of its elements are dense in storage,
// i.e. it's true for both row-first and column-first tensors // i.e. it's true for both row-first and column-first tensors
if (iter.is_contiguous()) { if (iter.is_contiguous()) {

View File

@ -9,164 +9,191 @@ struct IndexAB {
constant int64_t* indexArray; constant int64_t* indexArray;
}; };
template <typename T, typename OffsetsT> template <typename T, typename OffsetT = ulong>
kernel void index_select( kernel void index_select(
constant IndexAB* indexAB [[buffer(0)]], device T* output,
constant void* indexSizes [[buffer(1)]], constant T* input,
constant void* indexStrides [[buffer(2)]], constant IndexAB* indices,
constant OffsetsT* offsets [[buffer(3)]], constant int64_t* sizes,
constant void* inputData [[buffer(4)]], constant int64_t* output_strides,
device void* outputData [[buffer(5)]], constant int64_t* input_strides,
constant uint32_t& num_indices [[buffer(6)]], constant int64_t* indices_strides,
uint thread_index [[thread_position_in_grid]]) {
constant int64_t* index_sizes = (constant int64_t*)indexSizes;
constant int64_t* index_strides = (constant int64_t*)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
constant int64_t* indexArray = indexAB[i].indexArray;
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
offset += index * index_strides[i];
}
device T* out =
(device T*)((device char*)outputData + offsets[thread_index].x);
constant T* in = (constant T*)((constant char*)inputData +
offsets[thread_index].y + offset);
*out = *in;
}
template <typename T, typename OffsetsT>
void index_put_impl(
constant IndexAB* indexAB,
constant int64_t* index_sizes, constant int64_t* index_sizes,
constant int64_t* index_strides, constant int64_t* index_strides,
constant OffsetsT* offsets, constant uint4& ndim_nindices_numel,
constant void* inputData,
device void* outputData,
constant uint32_t& num_indices,
uint thread_index) {
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
constant int64_t* indexArray = indexAB[i].indexArray;
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
offset += index * index_strides[i];
}
device T* out =
(device T*)((device char*)outputData + offsets[thread_index].x + offset);
constant T* in =
(constant T*)((constant char*)inputData + offsets[thread_index].y);
*out = *in;
}
template <typename T, typename OffsetsT>
kernel void index_put_serial(
constant IndexAB* indexAB [[buffer(0)]],
constant void* indexSizes [[buffer(1)]],
constant void* indexStrides [[buffer(2)]],
constant OffsetsT* offsets [[buffer(3)]],
constant void* inputData [[buffer(4)]],
device void* outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
constant uint* numIters [[buffer(7)]]) {
constant int64_t* index_sizes = (constant int64_t*)indexSizes;
constant int64_t* index_strides = (constant int64_t*)indexStrides;
for (uint iter_i = 0; iter_i < *numIters; iter_i++) {
index_put_impl<T>(
indexAB,
index_sizes,
index_strides,
offsets,
inputData,
outputData,
num_indices,
iter_i);
}
}
template <typename T, typename OffsetsT>
kernel void index_put(
constant IndexAB* indexAB [[buffer(0)]],
constant void* indexSizes [[buffer(1)]],
constant void* indexStrides [[buffer(2)]],
constant OffsetsT* offsets [[buffer(3)]],
constant void* inputData [[buffer(4)]],
device void* outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) { uint thread_index [[thread_position_in_grid]]) {
constant int64_t* index_sizes = (constant int64_t*)indexSizes; const auto ndim = ndim_nindices_numel.x;
constant int64_t* index_strides = (constant int64_t*)indexStrides; const auto num_indices = ndim_nindices_numel.y;
index_put_impl<T>( uint pos[max_ndim];
indexAB, pos_from_thread_index(thread_index, pos, sizes, ndim);
const auto output_offs = offset_from_coord(pos, output_strides, ndim);
OffsetT input_offs = offset_from_coord(pos, input_strides, ndim);
const auto indices_offs =
offset_from_coord(pos, indices_strides, ndim) / sizeof(int64_t);
for (uint i = 0; i < num_indices; i++) {
auto idx = indices[i].indexArray[indices_offs];
if (idx < 0) {
idx += index_sizes[i];
}
input_offs += idx * index_strides[i];
}
output[output_offs / sizeof(T)] = input[input_offs / sizeof(T)];
}
template <typename T, typename OffsetT = ulong>
inline void index_put_impl(
device T* output,
constant T* input,
constant IndexAB* indices,
constant int64_t* sizes,
constant int64_t* output_strides,
constant int64_t* input_strides,
constant int64_t* indices_strides,
constant int64_t* index_sizes,
constant int64_t* index_strides,
constant uint4& ndim_nindices_numel,
uint thread_index) {
const auto ndim = ndim_nindices_numel.x;
const auto num_indices = ndim_nindices_numel.y;
uint pos[max_ndim];
pos_from_thread_index(thread_index, pos, sizes, ndim);
OffsetT output_offs = offset_from_coord(pos, output_strides, ndim);
const auto input_offs = offset_from_coord(pos, input_strides, ndim);
const auto indices_offs =
offset_from_coord(pos, indices_strides, ndim) / sizeof(int64_t);
for (uint i = 0; i < num_indices; i++) {
auto idx = indices[i].indexArray[indices_offs];
if (idx < 0) {
idx += index_sizes[i];
}
output_offs += idx * index_strides[i];
}
output[output_offs / sizeof(T)] = input[input_offs / sizeof(T)];
}
template <typename T, typename OffsetT = ulong>
kernel void index_put(
device T* output,
constant T* input,
constant IndexAB* indices,
constant int64_t* sizes,
constant int64_t* output_strides,
constant int64_t* input_strides,
constant int64_t* indices_strides,
constant int64_t* index_sizes,
constant int64_t* index_strides,
constant uint4& ndim_nindices_numel,
uint thread_index [[thread_position_in_grid]]) {
index_put_impl(
output,
input,
indices,
sizes,
output_strides,
input_strides,
indices_strides,
index_sizes, index_sizes,
index_strides, index_strides,
offsets, ndim_nindices_numel,
inputData,
outputData,
num_indices,
thread_index); thread_index);
} }
#define REGISTER_INDEX_OP( \ template <typename T, typename OffsetT = ulong>
DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \ kernel void index_put_serial(
template [[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE \ device T* output,
"_" #IDX_SIZE)]] kernel void \ constant T* input,
index_##INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \ constant IndexAB* indices,
constant IndexAB * indexAB [[buffer(0)]], \ constant int64_t* sizes,
constant void* indexSizes [[buffer(1)]], \ constant int64_t* output_strides,
constant void* indexStrides [[buffer(2)]], \ constant int64_t* input_strides,
constant IDX_DTYPE* offsets [[buffer(3)]], \ constant int64_t* indices_strides,
constant void* inputData [[buffer(4)]], \ constant int64_t* index_sizes,
device void* outputData [[buffer(5)]], \ constant int64_t* index_strides,
constant uint32_t& num_indices [[buffer(6)]], \ constant uint4& ndim_nindices_numel,
uint thread_index [[thread_position_in_grid]]) {
(void)thread_index; // Suppress unused variable warning
for (uint idx = 0; idx < ndim_nindices_numel.z; ++idx) {
index_put_impl(
output,
input,
indices,
sizes,
output_strides,
input_strides,
indices_strides,
index_sizes,
index_strides,
ndim_nindices_numel,
idx);
}
}
template <typename T, typename OffsetT = ulong>
kernel void index_put_accumulate(
device T* output,
constant T* input,
constant IndexAB* indices,
constant int64_t* sizes,
constant int64_t* output_strides,
constant int64_t* input_strides,
constant int64_t* indices_strides,
constant int64_t* index_sizes,
constant int64_t* index_strides,
constant uint4& ndim_nindices_numel,
uint thread_index [[thread_position_in_grid]]) {
const auto ndim = ndim_nindices_numel.x;
const auto num_indices = ndim_nindices_numel.y;
uint pos[max_ndim];
pos_from_thread_index(thread_index, pos, sizes, ndim);
OffsetT output_offs = offset_from_coord(pos, output_strides, ndim);
const auto input_offs = offset_from_coord(pos, input_strides, ndim);
const auto indices_offs =
offset_from_coord(pos, indices_strides, ndim) / sizeof(int64_t);
for (uint i = 0; i < num_indices; i++) {
auto idx = indices[i].indexArray[indices_offs];
if (idx < 0) {
idx += index_sizes[i];
}
output_offs += idx * index_strides[i];
}
AtomicType<T>::atomic_add(
reinterpret_cast<device AtomicType_t<T>*>(output),
output_offs / sizeof(T),
input[input_offs / sizeof(T)]);
}
#define REGISTER_INDEX_OP(OP_NAME, SUFFIX, DTYPE) \
template [[host_name("index_" #OP_NAME "_" #SUFFIX)]] kernel void \
index_##OP_NAME<DTYPE>( \
device DTYPE * output, \
constant DTYPE * input, \
constant IndexAB * indices, \
constant int64_t* sizes, \
constant int64_t* output_strides, \
constant int64_t* input_strides, \
constant int64_t* indices_strides, \
constant int64_t* index_sizes, \
constant int64_t* index_strides, \
constant uint4& ndim_nindices_numel, \
uint thread_index [[thread_position_in_grid]]) uint thread_index [[thread_position_in_grid]])
#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \ #define REGISTER_INDEX_OP_ALL_DTYPES(OP_NAME) \
REGISTER_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \ REGISTER_INDEX_OP(OP_NAME, 8bit, char); \
REGISTER_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \ REGISTER_INDEX_OP(OP_NAME, 16bit, short); \
REGISTER_INDEX_OP(16bit, idx32, short, INDEX_OP_TYPE, uint3); \ REGISTER_INDEX_OP(OP_NAME, 32bit, int); \
REGISTER_INDEX_OP(16bit, idx64, short, INDEX_OP_TYPE, ulong3); \ REGISTER_INDEX_OP(OP_NAME, 64bit, long)
REGISTER_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
REGISTER_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
REGISTER_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
REGISTER_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
REGISTER_INDEX_OP_ALL_DTYPES(select); REGISTER_INDEX_OP_ALL_DTYPES(select);
REGISTER_INDEX_OP_ALL_DTYPES(put); REGISTER_INDEX_OP_ALL_DTYPES(put);
REGISTER_INDEX_OP_ALL_DTYPES(put_serial);
#define REGISTER_SINGLE_THREADED_INDEX_OP( \ REGISTER_INDEX_OP(put_accumulate, float, float);
DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \ REGISTER_INDEX_OP(put_accumulate, half, half);
template [[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE \ REGISTER_INDEX_OP(put_accumulate, int, int);
"_" #IDX_SIZE)]] kernel void \ REGISTER_INDEX_OP(put_accumulate, bool, bool);
index_##INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \ #if __METAL_VERSION__ >= 310
constant IndexAB * indexAB [[buffer(0)]], \ REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
constant void* indexSizes [[buffer(1)]], \ #endif
constant void* indexStrides [[buffer(2)]], \
constant IDX_DTYPE* offsets [[buffer(3)]], \
constant void* inputData [[buffer(4)]], \
device void* outputData [[buffer(5)]], \
constant uint32_t& num_indices [[buffer(6)]], \
constant uint* numIters [[buffer(7)]])
#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx64, char, INDEX_OP_TYPE, ulong3); \
REGISTER_SINGLE_THREADED_INDEX_OP( \
16bit, idx32, short, INDEX_OP_TYPE, uint3); \
REGISTER_SINGLE_THREADED_INDEX_OP( \
16bit, idx64, short, INDEX_OP_TYPE, ulong3); \
REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx32, int, INDEX_OP_TYPE, uint3); \
REGISTER_SINGLE_THREADED_INDEX_OP(32bit, idx64, int, INDEX_OP_TYPE, ulong3); \
REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx32, long, INDEX_OP_TYPE, uint3); \
REGISTER_SINGLE_THREADED_INDEX_OP(64bit, idx64, long, INDEX_OP_TYPE, ulong3);
REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(put_serial);
template <typename StridesT, typename DataT> template <typename StridesT, typename DataT>
kernel void kernel_index_offsets( kernel void kernel_index_offsets(
@ -201,60 +228,6 @@ kernel_index_offsets<packed_uint3, ulong3>(
constant uint& num_dimensions [[buffer(3)]], constant uint& num_dimensions [[buffer(3)]],
uint thread_index [[thread_position_in_grid]]); uint thread_index [[thread_position_in_grid]]);
template <typename T, typename OffsetsT>
kernel void index_put_accumulate(
constant IndexAB* indexAB [[buffer(0)]],
constant void* indexSizes [[buffer(1)]],
constant void* indexStrides [[buffer(2)]],
constant OffsetsT* offsets [[buffer(3)]],
constant void* inputData [[buffer(4)]],
device void* outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {
constant int64_t* index_sizes = (constant int64_t*)indexSizes;
constant int64_t* index_strides = (constant int64_t*)indexStrides;
int64_t offset = offsets[thread_index].x;
for (uint32_t i = 0; i < num_indices; i++) {
constant int64_t* indexArray = indexAB[i].indexArray;
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
offset += index * index_strides[i];
}
const auto in =
*(constant T*)((constant char*)inputData + offsets[thread_index].y);
AtomicType<T>::atomic_add(
reinterpret_cast<device AtomicType_t<T>*>(outputData),
offset / sizeof(T),
in);
}
#define REGISTER_INDEX_PUT_ACCUMULATE(DTS, DTYPE, IDXS, IDX_DTYPE) \
template [[host_name("index_put_accumulate_" #DTS "_" #DTYPE \
"_" #IDXS)]] kernel void \
index_put_accumulate<DTYPE, IDX_DTYPE>( \
constant IndexAB * indexAB [[buffer(0)]], \
constant void* indexSizes [[buffer(1)]], \
constant void* indexStrides [[buffer(2)]], \
constant IDX_DTYPE* offsets [[buffer(3)]], \
constant void* inputData [[buffer(4)]], \
device void* outputData [[buffer(5)]], \
constant uint32_t& num_indices [[buffer(6)]], \
uint thread_index [[thread_position_in_grid]])
REGISTER_INDEX_PUT_ACCUMULATE(32bit, float, idx32, uint3);
REGISTER_INDEX_PUT_ACCUMULATE(32bit, float, idx64, ulong3);
REGISTER_INDEX_PUT_ACCUMULATE(32bit, int, idx32, uint3);
REGISTER_INDEX_PUT_ACCUMULATE(32bit, int, idx64, ulong3);
REGISTER_INDEX_PUT_ACCUMULATE(16bit, half, idx32, uint3);
REGISTER_INDEX_PUT_ACCUMULATE(16bit, half, idx64, ulong3);
#if __METAL_VERSION__ >= 310
REGISTER_INDEX_PUT_ACCUMULATE(16bit, bfloat, idx32, uint3);
REGISTER_INDEX_PUT_ACCUMULATE(16bit, bfloat, idx64, ulong3);
#endif
template <typename T> template <typename T>
kernel void masked_fill_scalar_dense( kernel void masked_fill_scalar_dense(
device T* input, device T* input,

View File

@ -100,91 +100,9 @@ static std::string getBitSizeString(ScalarType scalar_type) {
TORCH_CHECK(scalarBitSize <= 64, "Unsupported data type: ", getMPSTypeString(scalar_type)); TORCH_CHECK(scalarBitSize <= 64, "Unsupported data type: ", getMPSTypeString(scalar_type));
return std::to_string(scalarBitSize) + "bit"; return std::to_string(scalarBitSize) + "bit";
} }
static std::string getIndexFunctionName(ScalarType scalar_type,
bool index_select,
bool accumulate,
bool serial,
bool use_64bit_indexing) {
std::string indexFunction = index_select ? "index_select_"
: (accumulate && (scalar_type != kBool)) ? "index_put_accumulate_"
: (serial ? "index_put_serial_" : "index_put_");
indexFunction.append(getBitSizeString(scalar_type)); static std::string getBitSizeString(const TensorBase& t) {
if (accumulate) { return getBitSizeString(t.scalar_type());
indexFunction.append(1, '_');
indexFunction.append(scalarToMetalTypeString(scalar_type));
}
indexFunction.append(use_64bit_indexing ? "_idx64" : "_idx32");
return indexFunction;
}
static bool dispatchIndexKernel(TensorIteratorBase& iter,
IntArrayRef index_size,
IntArrayRef index_stride,
bool index_select,
bool accumulate) {
using namespace mps;
if (iter.numel() == 0) {
return true;
}
const bool serial_index_put = at::globalContext().deterministicAlgorithms() && !accumulate && !index_select;
const Tensor& inputTensor = iter.tensor(1);
Tensor outputTensor = iter.tensor(0);
MPSStream* mpsStream = getCurrentMPSStream();
id<MTLDevice> device = MPSDevice::getInstance()->device();
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
NSError* error = nil;
const int64_t num_indices = index_size.size();
const uint32_t numIters = serial_index_put ? iter.numel() : 1;
uint32_t numThreads = iter.numel();
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
const bool use_64bit_indexing = !iter.can_use_32bit_indexing();
auto kernelDataOffsets = generateKernelDataOffsets(computeEncoder, iter, use_64bit_indexing);
auto indexFunction = getIndexFunctionName(
inputTensor.scalar_type(), index_select, accumulate, serial_index_put, use_64bit_indexing);
auto indexSelectPSO = lib.getPipelineStateForFunc(indexFunction);
size_t argumentBufferLength = sizeof(uint64_t) * num_indices;
auto indexAB = [[device newBufferWithLength:argumentBufferLength options:0] autorelease];
uint64_t* indexABContents = (uint64_t*)(indexAB.contents);
for (uint32_t idx = 0; idx < num_indices; idx++) {
const Tensor& indexTensor = iter.tensor(idx + 2);
indexABContents[idx] =
getMTLBufferStorage(indexTensor).gpuAddress + (indexTensor.storage_offset() * indexTensor.element_size());
TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index");
[computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead];
}
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(indexSelectPSO, indexFunction, {inputTensor});
[computeEncoder setComputePipelineState:indexSelectPSO];
mtl_setArgs(
computeEncoder, indexAB, index_size, index_stride, kernelDataOffsets, inputTensor, outputTensor, num_indices);
MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
if (serial_index_put) {
mtl_setBytes(computeEncoder, numIters, 7);
gridSize = MTLSizeMake(1, 1, 1);
numThreads = 1;
}
NSUInteger tgSize = indexSelectPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > numThreads) {
tgSize = numThreads;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(indexSelectPSO);
}
});
return true;
} }
static void validateInputData(const TensorIteratorBase& iter, static void validateInputData(const TensorIteratorBase& iter,
@ -235,11 +153,56 @@ static Tensor& masked_select_out_mps_impl(Tensor& result, const Tensor& self, co
return result; return result;
} }
static void index_kernel_mps(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) { static void dispatch_index_kernel(TensorIteratorBase& iter,
@autoreleasepool { IntArrayRef index_size,
validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false); IntArrayRef index_stride,
dispatchIndexKernel(iter, index_size, index_stride, /*index_select=*/true, /*accumulate=*/false); const std::string& kernel_name,
const bool serial = false) {
validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false);
if (iter.numel() == 0)
return;
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
dispatch_index_kernel(sub_iter, index_size, index_stride, kernel_name);
}
return;
} }
const auto mpsStream = getCurrentMPSStream();
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
const int64_t num_indices = index_size.size();
auto indexSelectPSO = lib.getPipelineStateForFunc(kernel_name);
auto computeEncoder = mpsStream->commandEncoder();
size_t argumentBufferLength = sizeof(uint64_t) * num_indices;
std::vector<uint64_t> indexAB;
std::array<uint32_t, 4> ndim_nindiees = {static_cast<uint32_t>(iter.ndim()),
static_cast<uint32_t>(index_size.size()),
static_cast<uint32_t>(iter.numel()),
0};
for (uint32_t idx = 0; idx < num_indices; idx++) {
const auto& indexTensor = iter.tensor_base(idx + 2);
indexAB.push_back(getMTLBufferStorage(indexTensor).gpuAddress + iter_tensor_offset(iter, idx + 2));
TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index");
[computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead];
}
[computeEncoder setComputePipelineState:indexSelectPSO];
bind_iter_tensors(computeEncoder, iter, 2);
mtl_setArgs<2>(computeEncoder,
indexAB,
iter.shape(),
iter.strides(0),
iter.strides(1),
iter.strides(2),
index_size,
index_stride,
ndim_nindiees);
mtl_dispatch1DJob(computeEncoder, indexSelectPSO, serial ? 1 : iter.numel());
});
}
static void index_kernel_mps(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) {
validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false);
dispatch_index_kernel(
iter, index_size, index_stride, fmt::format("index_select_{}", getBitSizeString(iter.tensor_base(0))));
} }
static void index_put_kernel_mps(TensorIterator& iter, static void index_put_kernel_mps(TensorIterator& iter,
@ -248,7 +211,21 @@ static void index_put_kernel_mps(TensorIterator& iter,
bool accumulate) { bool accumulate) {
@autoreleasepool { @autoreleasepool {
validateInputData(iter, index_size, index_stride, "index_put_impl", accumulate); validateInputData(iter, index_size, index_stride, "index_put_impl", accumulate);
dispatchIndexKernel(iter, index_size, index_stride, /*index_select=*/false, accumulate); if (accumulate) {
dispatch_index_kernel(iter,
index_size,
index_stride,
fmt::format("index_put_accumulate_{}", scalarToMetalTypeString(iter.tensor_base(0))));
} else if (at::globalContext().deterministicAlgorithms()) {
dispatch_index_kernel(iter,
index_size,
index_stride,
fmt::format("index_put_serial_{}", getBitSizeString(iter.tensor_base(0))),
true);
} else {
dispatch_index_kernel(
iter, index_size, index_stride, fmt::format("index_put_{}", getBitSizeString(iter.tensor_base(0))));
}
} }
} }
} // namespace mps } // namespace mps

View File

@ -70,5 +70,36 @@ struct AtomicType<bfloat> {
}; };
#endif #endif
// Metal supports atomic_store_explicit for bools, but
// sizeof(::metal::atomic_bool) is 4, so it could not be used to
// atomically modify unaligned memory; fall back to a compare-and-exchange
// trick instead. As accumulation over booleans is just an OR operation, do
// nothing if the value is false
template <>
struct AtomicType<bool> {
using type = ::metal::atomic<uint>;
static inline void atomic_add(device type* data, long offset, bool value) {
if (!value) {
return;
}
auto ptr = data + (offset >> 2);
auto old =
::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
union {
uint i;
bool t[4];
} val;
do {
val.i = old;
val.t[offset & 3] = true;
} while (!::metal::atomic_compare_exchange_weak_explicit(
ptr,
&old,
val.i,
::metal::memory_order_relaxed,
::metal::memory_order_relaxed));
}
};
} // namespace metal } // namespace metal
} // namespace c10 } // namespace c10

View File

@ -158,6 +158,17 @@ def main() -> None:
if torch.backends.mps.is_macos_or_newer(14, 0): if torch.backends.mps.is_macos_or_newer(14, 0):
dtypes.append(torch.bfloat16) dtypes.append(torch.bfloat16)
# Profile index ops
B = 11
rc = []
for dtype, N in itertools.product(
[torch.int8, torch.float16, torch.float32], [50, 100, 500, 1000, 2000]
):
x = torch.testing.make_tensor((B, N, N), device="mps", dtype=dtype)
y = torch.randint(0, B, (3,))
rc.append(bench_binary_op(torch.Tensor.__getitem__, x, y, f"{B}x{N}x{N}"))
Compare(rc).print()
# Profile unary ops # Profile unary ops
rc = [] rc = []
for op, dtype in itertools.product([torch.sqrt, torch.sin], dtypes): for op, dtype in itertools.product([torch.sqrt, torch.sin], dtypes):

View File

@ -7996,6 +7996,23 @@ class TestLargeTensors(TestCaseMPS):
rc_slice_cpu = (a.cpu() + b.cpu()[slice_idx:]).sin() rc_slice_cpu = (a.cpu() + b.cpu()[slice_idx:]).sin()
self.assertEqual(rc_slice, rc_slice_cpu) self.assertEqual(rc_slice, rc_slice_cpu)
@serialTest()
def test_64bit_index_select(self):
if torch.mps.recommended_max_memory() < 16_000_000_000:
raise unittest.SkipTest("Needs at least 16Gb of RAM")
B, N = 11, 20000
x = torch.empty(B, N, N, dtype=torch.float16, device='mps')
for i in range(B):
x[i] = 1.0 * i
batch_idx = torch.tensor([9], device='mps')
y = x[batch_idx]
self.assertEqual(y[0, 1, 2].item(), 9.0)
# Reclaim memory after running the tests
del y
del x
gc.collect()
torch.mps.empty_cache()
class TestLogical(TestCaseMPS): class TestLogical(TestCaseMPS):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False): def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):

View File

@ -543,7 +543,6 @@ if torch.backends.mps.is_available():
"rounddecimals_0": [torch.bfloat16], "rounddecimals_0": [torch.bfloat16],
# atomic operations not supported # atomic operations not supported
"_unsafe_masked_index_put_accumulate": [ "_unsafe_masked_index_put_accumulate": [
torch.bool,
torch.int8, torch.int8,
torch.uint8, torch.uint8,
torch.int16, torch.int16,
@ -644,7 +643,6 @@ if torch.backends.mps.is_available():
torch.bfloat16, torch.bfloat16,
], ],
"index_put": [ "index_put": [
torch.bool,
torch.uint8, torch.uint8,
torch.int8, torch.int8,
torch.int16, torch.int16,