[MPS] Fix index_kernel for large tensors (#158064)
Move the `MetalShaderLibrary::bind_tensors` private method to OperationUtils.h (as `bind_iter_tensors`) and extract an `iter_tensor_offset` helper that returns the offset from the start of the storage associated with a given tensor inside the iterator.
Migrated `index` and `index_put[_accumulate][_serial]` to the new paradigm, which requires neither an additional tensor for the indices nor special handling for 32- vs 64-bit offsets. This yields an almost 2x perf gain for 2000x2000 tensors; see the results below. Before:
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 383.5 | 379.8 | 470.9 | 1232.9 | 4410.3
__getitem__ (torch.float16, torch.int64) | 379.6 | 354.5 | 533.2 | 1290.3 | 4442.2
__getitem__ (torch.float32, torch.int64) | 360.8 | 338.6 | 478.6 | 1348.9 | 4870.4
Times are in microseconds (us).
```
and after:
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 349.8 | 330.5 | 432.6 | 764.5 | 1961.2
__getitem__ (torch.float16, torch.int64) | 342.5 | 330.7 | 434.7 | 741.0 | 1969.4
__getitem__ (torch.float32, torch.int64) | 332.2 | 326.1 | 445.4 | 751.3 | 1972.6
Times are in microseconds (us).
```
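The timings above come from the MPS ops benchmark script extended further down in this change (the `# Profile index ops` hunk). Below is a minimal standalone sketch of the same measurement using `torch.utils.benchmark` directly; it assumes an MPS-capable machine, and the numbers it prints will not match the helper-based table above exactly.

```
import itertools
import torch
from torch.utils.benchmark import Timer

B = 11
for dtype, N in itertools.product(
    [torch.int8, torch.float16, torch.float32], [50, 100, 500, 1000, 2000]
):
    x = torch.testing.make_tensor((B, N, N), device="mps", dtype=dtype)
    y = torch.randint(0, B, (3,))  # CPU index tensor, as in the benchmark script
    t = Timer(
        stmt="x[y]; torch.mps.synchronize()",
        globals={"torch": torch, "x": x, "y": y},
        label=f"__getitem__ ({dtype}, torch.int64)",
    )
    print(f"{B}x{N}x{N}", dtype, t.blocked_autorange(min_run_time=1).median)
```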
While migrating, also fixed `index_put_accumulate` for boolean types by using a compare-and-exchange trick over `uint`.
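A quick usage illustration of the boolean fix (a hedged sketch assuming an MPS device is available, not part of the patch itself): accumulating into a `bool` tensor is effectively a logical OR, which the new `AtomicType<bool>` specialization implements with a compare-and-exchange loop over a 4-byte `uint`.

```
import torch

mask = torch.zeros(4, dtype=torch.bool, device="mps")
idx = torch.tensor([0, 2, 2], device="mps")
vals = torch.tensor([True, True, False], device="mps")

# accumulate=True on bools behaves like scatter-OR; position 2 sees True or False -> True
mask.index_put_((idx,), vals, accumulate=True)
print(mask.cpu())  # tensor([ True, False,  True, False])
```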
Fixes https://github.com/pytorch/pytorch/issues/153560
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158064
Approved by: https://github.com/dcci
parent 93854e83b7
commit beed033b6e
```
@@ -156,7 +156,6 @@ class MetalShaderLibrary {
       MTLLibrary_t lib,
       const std::string& fname);
   MTLLibrary_t compileLibrary(const std::string& src);
-  void bind_tensors(MTLComputeCommandEncoder_t, TensorIteratorBase&);
   std::string shaderSource;
   unsigned nparams;
   MTLCompileOptions* compile_options;
```
```
@@ -5,6 +5,7 @@
 #include <initializer_list>
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/Tensor.h>
+#include <ATen/TensorIterator.h>
 #include <ATen/Utils.h>
 #include <ATen/mps/MPSStream.h>
 #include <ATen/native/mps/MetalShaderLibrary.h>
```
```
@@ -35,10 +36,6 @@
            name:(NSString*)name;
 @end
 
-// Fwd declarations
-namespace at {
-struct TensorIteratorBase;
-}
 using namespace at::mps;
 
 namespace at::native::mps {
```
```
@@ -508,6 +505,30 @@ static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const MPSS
   [encoder setBytes:&s.value length:s.size atIndex:idx];
 }
 
+static size_t iter_tensor_offset(TensorIteratorBase& iter, unsigned idx) {
+  // At the moment, MPS storage data is not the real GPU pointer, but rather a pointer to id<MTLBuffer> object
+  // But TensorIterator constructs data_ptr as if base was just a raw pointer
+  // Workaround this problem by computing an offset from the start of the tensor, which works for both
+  // tensor views and sliced 64-bit iterators
+  return reinterpret_cast<size_t>(iter.data_ptr(idx)) -
+      reinterpret_cast<size_t>(iter.tensor_base(idx).storage().data());
+}
+
+static inline void bind_iter_tensors(id<MTLComputeCommandEncoder> encoder,
+                                     TensorIteratorBase& iter,
+                                     std::optional<size_t> ntensors = std::nullopt) {
+  for (auto idx : c10::irange(ntensors.value_or(iter.ntensors()))) {
+    auto& t = iter.tensor_base(idx);
+    // Handle CPU scalars
+    if (C10_UNLIKELY(t.device().type() == kCPU)) {
+      mtl_setBuffer(encoder, t, idx);
+      continue;
+    }
+    auto offs = iter_tensor_offset(iter, idx);
+    [encoder setBuffer:getMTLBufferStorage(t) offset:offs atIndex:idx];
+  }
+}
+
 namespace detail {
 template <typename T>
 inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const T& val, unsigned idx) {
```
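The `iter_tensor_offset` helper above boils down to "distance from the start of the tensor's storage". Here is a small Python sketch of the same quantity, computed on CPU where `data_ptr()` really is a raw pointer (on MPS it intentionally is not, which is exactly why the helper exists); the tensor names are illustrative only.

```
import torch

base = torch.arange(12, dtype=torch.float32).reshape(3, 4)
view = base[1:, 2:]  # first element of the view is base[1, 2]

# Byte offset of the view into its storage, the analogue of what
# iter_tensor_offset computes for each operand of a TensorIterator.
byte_offset = view.data_ptr() - base.untyped_storage().data_ptr()
assert byte_offset == 6 * base.element_size()  # flat index 1*4 + 2 = 6
```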
```
@@ -971,23 +971,6 @@ class BundledShaderLibary : public MetalShaderLibrary {
   }
 };
 
-void MetalShaderLibrary::bind_tensors(id<MTLComputeCommandEncoder> encoder, TensorIteratorBase& iter) {
-  for (auto idx : c10::irange(iter.ntensors())) {
-    auto& t = iter.tensor_base(idx);
-    // Handle CPU scalars
-    if (C10_UNLIKELY(t.device().type() == kCPU)) {
-      mtl_setBuffer(encoder, t, idx);
-      continue;
-    }
-    // At the moment, MPS storage data is not the real GPU pointer, but rather a pointer to id<MTLBuffer> object
-    // But TensorIterator constructs data_ptr as if base was just a raw pointer
-    // Workaround this problem by computing an offset from the start of the tensor, which works for both
-    // tensor vies and sliced 64-bit iterators
-    auto offs = reinterpret_cast<size_t>(iter.data_ptr(idx)) - reinterpret_cast<size_t>(t.storage().data());
-    [encoder setBuffer:getMTLBufferStorage(t) offset:offs atIndex:idx];
-  }
-}
-
 void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
                                            const std::string& name,
                                            std::optional<c10::Scalar> alpha,
```
```
@@ -1024,7 +1007,7 @@ void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
     getMPSProfiler().beginProfileKernel(cplState, name, {inputTensor});
 
     [computeEncoder setComputePipelineState:cplState];
-    bind_tensors(computeEncoder, iter);
+    bind_iter_tensors(computeEncoder, iter);
     if (!iter.is_contiguous()) {
       mtl_setArgs<2>(computeEncoder,
                      outputTensor.sizes(),
```
```
@@ -1100,7 +1083,7 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter,
     getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other});
     [computeEncoder setComputePipelineState:binaryPSO];
     // Set input and output tensors
-    bind_tensors(computeEncoder, iter);
+    bind_iter_tensors(computeEncoder, iter);
     // Iterator is contiguous if all of its elements are dense in storage,
     // i.e. it's true for both row-first and column-first tensors
     if (iter.is_contiguous()) {
```
The Metal indexing kernels are rewritten (`@@ -9,164 +9,191 @@ struct IndexAB {`). The old variants, which consumed a buffer of precomputed packed offsets and were instantiated per index width (`idx32`/`idx64`) through `REGISTER_INDEX_OP` and `REGISTER_SINGLE_THREADED_INDEX_OP`, are deleted and replaced by kernels that compute offsets from the iterator shape and strides directly:

```
  constant int64_t* indexArray;
};

template <typename T, typename OffsetT = ulong>
kernel void index_select(
    device T* output,
    constant T* input,
    constant IndexAB* indices,
    constant int64_t* sizes,
    constant int64_t* output_strides,
    constant int64_t* input_strides,
    constant int64_t* indices_strides,
    constant int64_t* index_sizes,
    constant int64_t* index_strides,
    constant uint4& ndim_nindices_numel,
    uint thread_index [[thread_position_in_grid]]) {
  const auto ndim = ndim_nindices_numel.x;
  const auto num_indices = ndim_nindices_numel.y;
  uint pos[max_ndim];
  pos_from_thread_index(thread_index, pos, sizes, ndim);
  const auto output_offs = offset_from_coord(pos, output_strides, ndim);
  OffsetT input_offs = offset_from_coord(pos, input_strides, ndim);
  const auto indices_offs =
      offset_from_coord(pos, indices_strides, ndim) / sizeof(int64_t);
  for (uint i = 0; i < num_indices; i++) {
    auto idx = indices[i].indexArray[indices_offs];
    if (idx < 0) {
      idx += index_sizes[i];
    }
    input_offs += idx * index_strides[i];
  }
  output[output_offs / sizeof(T)] = input[input_offs / sizeof(T)];
}

template <typename T, typename OffsetT = ulong>
inline void index_put_impl(
    device T* output,
    constant T* input,
    constant IndexAB* indices,
    constant int64_t* sizes,
    constant int64_t* output_strides,
    constant int64_t* input_strides,
    constant int64_t* indices_strides,
    constant int64_t* index_sizes,
    constant int64_t* index_strides,
    constant uint4& ndim_nindices_numel,
    uint thread_index) {
  const auto ndim = ndim_nindices_numel.x;
  const auto num_indices = ndim_nindices_numel.y;
  uint pos[max_ndim];
  pos_from_thread_index(thread_index, pos, sizes, ndim);
  OffsetT output_offs = offset_from_coord(pos, output_strides, ndim);
  const auto input_offs = offset_from_coord(pos, input_strides, ndim);
  const auto indices_offs =
      offset_from_coord(pos, indices_strides, ndim) / sizeof(int64_t);
  for (uint i = 0; i < num_indices; i++) {
    auto idx = indices[i].indexArray[indices_offs];
    if (idx < 0) {
      idx += index_sizes[i];
    }
    output_offs += idx * index_strides[i];
  }
  output[output_offs / sizeof(T)] = input[input_offs / sizeof(T)];
}

template <typename T, typename OffsetT = ulong>
kernel void index_put(
    device T* output,
    constant T* input,
    constant IndexAB* indices,
    constant int64_t* sizes,
    constant int64_t* output_strides,
    constant int64_t* input_strides,
    constant int64_t* indices_strides,
    constant int64_t* index_sizes,
    constant int64_t* index_strides,
    constant uint4& ndim_nindices_numel,
    uint thread_index [[thread_position_in_grid]]) {
  index_put_impl(
      output,
      input,
      indices,
      sizes,
      output_strides,
      input_strides,
      indices_strides,
      index_sizes,
      index_strides,
      ndim_nindices_numel,
      thread_index);
}

template <typename T, typename OffsetT = ulong>
kernel void index_put_serial(
    device T* output,
    constant T* input,
    constant IndexAB* indices,
    constant int64_t* sizes,
    constant int64_t* output_strides,
    constant int64_t* input_strides,
    constant int64_t* indices_strides,
    constant int64_t* index_sizes,
    constant int64_t* index_strides,
    constant uint4& ndim_nindices_numel,
    uint thread_index [[thread_position_in_grid]]) {
  (void)thread_index; // Suppress unused variable warning
  for (uint idx = 0; idx < ndim_nindices_numel.z; ++idx) {
    index_put_impl(
        output,
        input,
        indices,
        sizes,
        output_strides,
        input_strides,
        indices_strides,
        index_sizes,
        index_strides,
        ndim_nindices_numel,
        idx);
  }
}

template <typename T, typename OffsetT = ulong>
kernel void index_put_accumulate(
    device T* output,
    constant T* input,
    constant IndexAB* indices,
    constant int64_t* sizes,
    constant int64_t* output_strides,
    constant int64_t* input_strides,
    constant int64_t* indices_strides,
    constant int64_t* index_sizes,
    constant int64_t* index_strides,
    constant uint4& ndim_nindices_numel,
    uint thread_index [[thread_position_in_grid]]) {
  const auto ndim = ndim_nindices_numel.x;
  const auto num_indices = ndim_nindices_numel.y;
  uint pos[max_ndim];
  pos_from_thread_index(thread_index, pos, sizes, ndim);
  OffsetT output_offs = offset_from_coord(pos, output_strides, ndim);
  const auto input_offs = offset_from_coord(pos, input_strides, ndim);
  const auto indices_offs =
      offset_from_coord(pos, indices_strides, ndim) / sizeof(int64_t);
  for (uint i = 0; i < num_indices; i++) {
    auto idx = indices[i].indexArray[indices_offs];
    if (idx < 0) {
      idx += index_sizes[i];
    }
    output_offs += idx * index_strides[i];
  }
  AtomicType<T>::atomic_add(
      reinterpret_cast<device AtomicType_t<T>*>(output),
      output_offs / sizeof(T),
      input[input_offs / sizeof(T)]);
}

#define REGISTER_INDEX_OP(OP_NAME, SUFFIX, DTYPE)                    \
  template [[host_name("index_" #OP_NAME "_" #SUFFIX)]] kernel void \
  index_##OP_NAME<DTYPE>(                                            \
      device DTYPE * output,                                         \
      constant DTYPE * input,                                        \
      constant IndexAB * indices,                                    \
      constant int64_t* sizes,                                       \
      constant int64_t* output_strides,                              \
      constant int64_t* input_strides,                               \
      constant int64_t* indices_strides,                             \
      constant int64_t* index_sizes,                                 \
      constant int64_t* index_strides,                               \
      constant uint4& ndim_nindices_numel,                           \
      uint thread_index [[thread_position_in_grid]])

#define REGISTER_INDEX_OP_ALL_DTYPES(OP_NAME) \
  REGISTER_INDEX_OP(OP_NAME, 8bit, char);     \
  REGISTER_INDEX_OP(OP_NAME, 16bit, short);   \
  REGISTER_INDEX_OP(OP_NAME, 32bit, int);     \
  REGISTER_INDEX_OP(OP_NAME, 64bit, long)

REGISTER_INDEX_OP_ALL_DTYPES(select);
REGISTER_INDEX_OP_ALL_DTYPES(put);
REGISTER_INDEX_OP_ALL_DTYPES(put_serial);

REGISTER_INDEX_OP(put_accumulate, float, float);
REGISTER_INDEX_OP(put_accumulate, half, half);
REGISTER_INDEX_OP(put_accumulate, int, int);
REGISTER_INDEX_OP(put_accumulate, bool, bool);
#if __METAL_VERSION__ >= 310
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
#endif

template <typename StridesT, typename DataT>
kernel void kernel_index_offsets(
```
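For orientation, this is the kind of Python-level operation the kernels above back; a hedged example assuming an MPS device, with the kernel names taken from the registrations above (float32 selects the `*_32bit` instantiations).

```
import torch

x = torch.arange(12, dtype=torch.float32, device="mps").reshape(3, 4)
rows = torch.tensor([2, 0], device="mps")

gathered = x[rows]                          # lowered to the index_select_32bit kernel
x[rows] = torch.zeros(2, 4, device="mps")   # lowered to the index_put_32bit kernel
print(gathered.cpu(), x.cpu())
```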
In the same shader file, hunk `@@ -201,60 +228,6 @@ kernel_index_offsets<packed_uint3, ulong3>(` deletes the old offsets-based `index_put_accumulate` kernel together with its `REGISTER_INDEX_PUT_ACCUMULATE` macro and per-dtype registrations; the surrounding `kernel_index_offsets` instantiations and `masked_fill_scalar_dense` are unchanged.
In the host-side dispatch code, `getIndexFunctionName` and the old `dispatchIndexKernel` (which allocated an argument buffer for the indices and generated per-thread offsets) are deleted, and a `getBitSizeString` overload for tensors is added:

```
@@ -100,91 +100,9 @@ static std::string getBitSizeString(ScalarType scalar_type) {
   TORCH_CHECK(scalarBitSize <= 64, "Unsupported data type: ", getMPSTypeString(scalar_type));
   return std::to_string(scalarBitSize) + "bit";
 }
+
+static std::string getBitSizeString(const TensorBase& t) {
+  return getBitSizeString(t.scalar_type());
+}
 
 static void validateInputData(const TensorIteratorBase& iter,
```
```
@@ -235,11 +153,56 @@ static Tensor& masked_select_out_mps_impl(Tensor& result, const Tensor& self, co
   return result;
 }
 
-static void index_kernel_mps(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  @autoreleasepool {
-    validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false);
-    dispatchIndexKernel(iter, index_size, index_stride, /*index_select=*/true, /*accumulate=*/false);
+static void dispatch_index_kernel(TensorIteratorBase& iter,
+                                  IntArrayRef index_size,
+                                  IntArrayRef index_stride,
+                                  const std::string& kernel_name,
+                                  const bool serial = false) {
+  validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false);
+  if (iter.numel() == 0)
+    return;
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      dispatch_index_kernel(sub_iter, index_size, index_stride, kernel_name);
+    }
+    return;
   }
+  const auto mpsStream = getCurrentMPSStream();
+  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
+    const int64_t num_indices = index_size.size();
+    auto indexSelectPSO = lib.getPipelineStateForFunc(kernel_name);
+    auto computeEncoder = mpsStream->commandEncoder();
+    size_t argumentBufferLength = sizeof(uint64_t) * num_indices;
+    std::vector<uint64_t> indexAB;
+    std::array<uint32_t, 4> ndim_nindiees = {static_cast<uint32_t>(iter.ndim()),
+                                             static_cast<uint32_t>(index_size.size()),
+                                             static_cast<uint32_t>(iter.numel()),
+                                             0};
+    for (uint32_t idx = 0; idx < num_indices; idx++) {
+      const auto& indexTensor = iter.tensor_base(idx + 2);
+      indexAB.push_back(getMTLBufferStorage(indexTensor).gpuAddress + iter_tensor_offset(iter, idx + 2));
+      TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index");
+      [computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead];
+    }
+    [computeEncoder setComputePipelineState:indexSelectPSO];
+    bind_iter_tensors(computeEncoder, iter, 2);
+    mtl_setArgs<2>(computeEncoder,
+                   indexAB,
+                   iter.shape(),
+                   iter.strides(0),
+                   iter.strides(1),
+                   iter.strides(2),
+                   index_size,
+                   index_stride,
+                   ndim_nindiees);
+    mtl_dispatch1DJob(computeEncoder, indexSelectPSO, serial ? 1 : iter.numel());
+  });
+}
+
+static void index_kernel_mps(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) {
+  validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false);
+  dispatch_index_kernel(
+      iter, index_size, index_stride, fmt::format("index_select_{}", getBitSizeString(iter.tensor_base(0))));
 }
 
 static void index_put_kernel_mps(TensorIterator& iter,
```
```
@@ -248,7 +211,21 @@ static void index_put_kernel_mps(TensorIterator& iter,
                                  bool accumulate) {
   @autoreleasepool {
     validateInputData(iter, index_size, index_stride, "index_put_impl", accumulate);
-    dispatchIndexKernel(iter, index_size, index_stride, /*index_select=*/false, accumulate);
+    if (accumulate) {
+      dispatch_index_kernel(iter,
+                            index_size,
+                            index_stride,
+                            fmt::format("index_put_accumulate_{}", scalarToMetalTypeString(iter.tensor_base(0))));
+    } else if (at::globalContext().deterministicAlgorithms()) {
+      dispatch_index_kernel(iter,
+                            index_size,
+                            index_stride,
+                            fmt::format("index_put_serial_{}", getBitSizeString(iter.tensor_base(0))),
+                            true);
+    } else {
+      dispatch_index_kernel(
+          iter, index_size, index_stride, fmt::format("index_put_{}", getBitSizeString(iter.tensor_base(0))));
+    }
   }
 }
 } // namespace mps
```
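The dispatch above picks one of three kernels per call. A small usage sketch (assuming an MPS device) of the deterministic path, which routes non-accumulating `index_put_` through the single-threaded `index_put_serial_*` variant:

```
import torch

x = torch.zeros(8, device="mps")
idx = torch.tensor([1, 1, 3], device="mps")
src = torch.tensor([1.0, 2.0, 3.0], device="mps")

torch.use_deterministic_algorithms(True)
x.index_put_((idx,), src)   # duplicate index 1 resolves deterministically (serial kernel)
torch.use_deterministic_algorithms(False)
print(x.cpu())
```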
```
@@ -70,5 +70,36 @@ struct AtomicType<bfloat> {
 };
 #endif
 
+// Metal supports atomic_store_explicit for bools, but
+// sizeof(::metal::atomic_bool) is 4. Therefore it could not be used to
+// atomically modify unaligned memory, so fall back to the compare and exchange
+// trick. As accumulation over booleans is just an or operation, do nothing if
+// value is false
+template <>
+struct AtomicType<bool> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, bool value) {
+    if (!value) {
+      return;
+    }
+    auto ptr = data + (offset >> 2);
+    auto old =
+        ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
+    union {
+      uint i;
+      bool t[4];
+    } val;
+    do {
+      val.i = old;
+      val.t[offset & 3] = true;
+    } while (!::metal::atomic_compare_exchange_weak_explicit(
+        ptr,
+        &old,
+        val.i,
+        ::metal::memory_order_relaxed,
+        ::metal::memory_order_relaxed));
+  }
+};
+
 } // namespace metal
 } // namespace c10
```
```
@@ -158,6 +158,17 @@ def main() -> None:
     if torch.backends.mps.is_macos_or_newer(14, 0):
         dtypes.append(torch.bfloat16)
 
+    # Profile index ops
+    B = 11
+    rc = []
+    for dtype, N in itertools.product(
+        [torch.int8, torch.float16, torch.float32], [50, 100, 500, 1000, 2000]
+    ):
+        x = torch.testing.make_tensor((B, N, N), device="mps", dtype=dtype)
+        y = torch.randint(0, B, (3,))
+        rc.append(bench_binary_op(torch.Tensor.__getitem__, x, y, f"{B}x{N}x{N}"))
+    Compare(rc).print()
+
     # Profile unary ops
     rc = []
     for op, dtype in itertools.product([torch.sqrt, torch.sin], dtypes):
```
```
@@ -7996,6 +7996,23 @@ class TestLargeTensors(TestCaseMPS):
         rc_slice_cpu = (a.cpu() + b.cpu()[slice_idx:]).sin()
         self.assertEqual(rc_slice, rc_slice_cpu)
 
+    @serialTest()
+    def test_64bit_index_select(self):
+        if torch.mps.recommended_max_memory() < 16_000_000_000:
+            raise unittest.SkipTest("Needs at least 16Gb of RAM")
+        B, N = 11, 20000
+        x = torch.empty(B, N, N, dtype=torch.float16, device='mps')
+        for i in range(B):
+            x[i] = 1.0 * i
+        batch_idx = torch.tensor([9], device='mps')
+        y = x[batch_idx]
+        self.assertEqual(y[0, 1, 2].item(), 9.0)
+        # Reclaim memory after running the tests
+        del y
+        del x
+        gc.collect()
+        torch.mps.empty_cache()
+
 
 class TestLogical(TestCaseMPS):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
```
```
@@ -543,7 +543,6 @@ if torch.backends.mps.is_available():
             "rounddecimals_0": [torch.bfloat16],
             # atomic operations not supported
             "_unsafe_masked_index_put_accumulate": [
-                torch.bool,
                 torch.int8,
                 torch.uint8,
                 torch.int16,
```
```
@@ -644,7 +643,6 @@ if torch.backends.mps.is_available():
                 torch.bfloat16,
             ],
             "index_put": [
-                torch.bool,
                 torch.uint8,
                 torch.int8,
                 torch.int16,
```