[MPS] sparse add unary funcs + add for sparse tensors (#160839)
Adds several unary functions and `add` for sparse MPS tensors. Enables the unary-function tests in test_sparse, but does not enable the other tests yet; more ops are needed before we fully migrate to testing SparseMPS with `test_sparse.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160839
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Parent commit: ebfee60101
This commit: 8627a19adf
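For orientation, here is a minimal sketch of what this change enables at the Python level: sparse COO tensors on the `mps` device now support the added unary functions and `add`. This is illustrative only and not part of the commit; it assumes a build where `torch.backends.mps.is_available()` returns True.

import torch

# Sparse COO tensor on the MPS device; the values are arbitrary.
i = torch.tensor([[0, 1, 1], [2, 0, 2]], device="mps")
v = torch.tensor([3.0, -4.0, 5.0], device="mps")
s = torch.sparse_coo_tensor(i, v, (2, 3))

d = torch.zeros(2, 3, device="mps")
print(torch.add(d, s, alpha=2.0))             # dense + sparse -> dense (add_out_dense_sparse_mps)
print(torch.add(s, s).coalesce().to_dense())  # sparse + sparse -> add_out_sparse_mps
print(torch.abs(s).to_dense())                # unary funcs now dispatch through SparseMPS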
@@ -417,6 +417,7 @@ TORCH_IMPL_FUNC(sgn_out_mps)(const Tensor& self, const Tensor& output) {
 Tensor& conj_physical_out_mps(const Tensor& self, Tensor& result) {
   TORCH_CHECK(self.is_complex());
+  TORCH_CHECK(self.dtype() != at::kComplexDouble);
   mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
     return [mpsGraph conjugateWithTensor:inputTensor name:nil];
   });
@@ -340,8 +340,8 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: abs
-    SparseCPU, SparseCUDA: abs_sparse
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr
+    SparseCPU, SparseCUDA, SparseMPS: abs_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: abs_sparse_csr
     NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_abs
   tags: [core, pointwise]

@@ -350,16 +350,16 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: abs_
-    SparseCPU, SparseCUDA: abs_sparse_
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr_
+    SparseCPU, SparseCUDA, SparseMPS: abs_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: abs_sparse_csr_
     NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_abs_

 - func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
   dispatch:
     CPU, CUDA, MPS, MTIA: abs_out
-    SparseCPU, SparseCUDA: abs_sparse_out
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr_out
+    SparseCPU, SparseCUDA, SparseMPS: abs_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: abs_sparse_csr_out
   tags: pointwise

 # Note [Adding an alias]

@@ -476,7 +476,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: _conj_physical
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: conj_physical_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: conj_physical_sparse_csr
   autogen: _conj_physical.out

 - func: conj_physical(Tensor self) -> Tensor

@@ -487,8 +487,8 @@
   dispatch:
     CPU, CUDA: conj_physical_out
     MPS: conj_physical_out_mps
-    SparseCPU, SparseCUDA: conj_physical_out_sparse
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: conj_physical_sparse_csr_out
+    SparseCPU, SparseCUDA, SparseMPS: conj_physical_out_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: conj_physical_sparse_csr_out
   tags: pointwise

 - func: conj_physical_(Tensor(a!) self) -> Tensor(a!)

@@ -554,7 +554,7 @@
   structured_delegate: add.out
   variants: function, method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: add_sparse
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: add_sparse
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: add_sparse_csr
     MkldnnCPU: mkldnn_add
     ZeroTensor: add_zerotensor

@@ -566,7 +566,7 @@
   variants: method
   structured_delegate: add.out
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: add_sparse_
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: add_sparse_
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: add_sparse_csr_
     MkldnnCPU: mkldnn_add_
     NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_add__Tensor

@@ -582,6 +582,7 @@
   dispatch:
     SparseCPU, SparseMeta: add_out_sparse_cpu
     SparseCUDA: add_out_sparse_cuda
+    SparseMPS: add_out_sparse_mps
     SparseCsrCPU, SparseCsrMeta: add_out_sparse_compressed_cpu
     SparseCsrCUDA: add_out_sparse_compressed_cuda
     MkldnnCPU: mkldnn_add_out

@@ -2406,7 +2407,7 @@
     MPS: empty_mps
     Meta: empty_meta_symint
     MkldnnCPU: empty_mkldnn
-    SparseCPU, SparseCUDA: empty_sparse
+    SparseCPU, SparseCUDA, SparseMPS: empty_sparse
     SparseMeta: empty_sparse_symint
     SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
     SparseCsrMeta: empty_sparse_compressed_symint

@@ -6386,8 +6387,8 @@
   device_check: NoCheck # TensorIterator
   variants: function, method
   dispatch:
-    SparseCPU, SparseCUDA: trunc_sparse
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr
+    SparseCPU, SparseCUDA, SparseMPS: trunc_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: trunc_sparse_csr
   tags: [core, pointwise]

 - func: trunc_(Tensor(a!) self) -> Tensor(a!)

@@ -6395,8 +6396,8 @@
   device_check: NoCheck # TensorIterator
   variants: function, method
   dispatch:
-    SparseCPU, SparseCUDA: trunc_sparse_
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr_
+    SparseCPU, SparseCUDA, SparseMPS: trunc_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: trunc_sparse_csr_
   tags: pointwise

 - func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)

@@ -6405,8 +6406,8 @@
   device_check: NoCheck # TensorIterator
   dispatch:
     CPU, CUDA, MPS: trunc_out
-    SparseCPU, SparseCUDA: trunc_sparse_out
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr_out
+    SparseCPU, SparseCUDA, SparseMPS: trunc_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: trunc_sparse_csr_out
   tags: pointwise
 # Alias for trunc

@@ -7368,8 +7369,8 @@
 - func: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA: sparse_to_dense
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_dense
+    SparseCPU, SparseCUDA, SparseMPS: sparse_to_dense
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: sparse_compressed_to_dense
     MkldnnCPU: mkldnn_to_dense
   autogen: _to_dense.out

@@ -7395,8 +7396,8 @@
 - func: dense_dim(Tensor self) -> int
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: dense_dim_sparse
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: dense_dim_sparse_csr
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: dense_dim_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: dense_dim_sparse_csr
     CompositeExplicitAutograd: dense_dim_default
   device_check: NoCheck
   device_guard: False

@@ -7529,7 +7530,7 @@
   device_check: NoCheck # Allows copy into different device
   variants: function
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: copy_sparse_
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: copy_sparse_
   autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out

 # By adding the AutogradNestedTensor this makes this function CompositeImplicit-like for nested tensors
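The dispatch hunks above (from native_functions.yaml) key kernel selection on (layout, device): adding SparseMPS / SparseCsrMPS routes sparse tensors on the `mps` device to the existing device-generic sparse wrappers (abs_sparse, trunc_sparse, sparse_to_dense, dense_dim_sparse, ...), while add.out gets a dedicated MPS kernel, add_out_sparse_mps. A hedged Python probe of a few newly reachable entry points (illustrative only; requires an MPS-enabled build):

import torch

s = torch.eye(3, device="mps").to_sparse()  # coalesced COO tensor on MPS
s.abs_()               # SparseMPS -> abs_sparse_
t = torch.trunc(s)     # SparseMPS -> trunc_sparse
dense = s.to_dense()   # SparseMPS -> sparse_to_dense
nd = s.dense_dim()     # SparseMPS -> dense_dim_sparse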
aten/src/ATen/native/sparse/mps/FlattenIndices.mm (new file, 73 lines)
@@ -0,0 +1,73 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/native/SparseTensorUtils.h>
+#include <ATen/native/mps/OperationUtils.h>
+#include <ATen/native/sparse/SparseStubs.h>
+#include <ATen/native/sparse/FlattenIndicesCommon.h>
+#include <ATen/ExpandUtils.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/_coalesce_native.h>
+#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
+#include <ATen/ops/empty_native.h>
+#include <ATen/ops/zeros_native.h>
+#endif
+
+namespace at::native {
+namespace {
+
+using namespace mps;
+using namespace at::sparse;
+
+#ifndef PYTORCH_JIT_COMPILE_SHADERS
+static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
+#else
+#include <ATen/native/mps/FlattenIndices_metallib.h>
+#endif
+
+Tensor flatten_indices_mps(const Tensor& indices, IntArrayRef size) {
+  TORCH_CHECK(indices.dim() == 2, "flatten_indices: indices must be 2D");
+  TORCH_CHECK(static_cast<size_t>(indices.size(0)) == size.size(),
+              "flatten_indices: indices.size(0) must equal size.size()");
+
+  const int64_t sparse_dim = indices.size(0);
+  const int64_t nnz = indices.size(1);
+
+  if (nnz == 0) {
+    return at::empty({0}, indices.options().dtype(kLong));
+  }
+
+  // Row-major multipliers for flattening: mul[d] = prod_{j>d}(size[j])
+  std::vector<int64_t> row_muls(sparse_dim);
+  row_muls[sparse_dim - 1] = 1;
+  for (int64_t i = sparse_dim - 2; i >= 0; --i) {
+    row_muls[i] = row_muls[i + 1] * size[i + 1];
+  }
+
+  auto flat_indices = at::empty({nnz}, indices.options().dtype(kLong));
+
+  auto stream = getCurrentMPSStream();
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      auto pipeline = lib.getPipelineStateForFunc("flatten_indices_kernel");
+      auto encoder = stream->commandEncoder();
+      [encoder setComputePipelineState:pipeline];
+      mtl_setArgs(encoder,
+                  indices,
+                  row_muls,
+                  flat_indices,
+                  static_cast<uint>(sparse_dim),
+                  indices.strides());
+
+      mtl_dispatch1DJob(encoder, pipeline, nnz);
+    }
+  });
+  return flat_indices;
+}
+
+} // namespace
+
+REGISTER_MPS_DISPATCH(flatten_indices_stub, &flatten_indices_mps)
+} // namespace at::native
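flatten_indices_mps turns each column of the (sparse_dim x nnz) index matrix into a row-major linear index: with multipliers mul[d] = prod_{j>d} size[j], the flat index of nonzero n is sum_d indices[d, n] * mul[d]. A small CPU-side Python sketch of the same arithmetic (illustrative; not the dispatched kernel):

import torch

def flatten_indices_ref(indices: torch.Tensor, size) -> torch.Tensor:
    # indices: (sparse_dim, nnz) int64 -> (nnz,) row-major linear indices
    sparse_dim, nnz = indices.shape
    if nnz == 0:
        return torch.empty(0, dtype=torch.int64)
    muls = torch.ones(sparse_dim, dtype=torch.int64)
    for d in range(sparse_dim - 2, -1, -1):
        muls[d] = muls[d + 1] * size[d + 1]
    return (indices * muls.unsqueeze(1)).sum(dim=0)

idx = torch.tensor([[0, 1, 1], [2, 0, 2]])
assert flatten_indices_ref(idx, (2, 3)).tolist() == [2, 3, 5]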
@@ -20,46 +20,9 @@ using namespace at::sparse;
 #ifndef PYTORCH_JIT_COMPILE_SHADERS
 static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
 #else
-#include <ATen/native/mps/Sparse_metallib.h>
+#include <ATen/native/mps/Coalesce_metallib.h>
 #endif

-static Tensor flatten_indices(const Tensor& indices, IntArrayRef size) {
-  TORCH_CHECK(indices.dim() == 2, "flatten_indices: indices must be 2D");
-  TORCH_CHECK(static_cast<size_t>(indices.size(0)) == size.size(),
-              "flatten_indices: indices.size(0) must equal size.size()");
-
-  int64_t sparse_dim = indices.size(0);
-  int64_t nnz = indices.size(1);
-
-  if (nnz == 0) {
-    return at::empty({0}, indices.options().dtype(kLong));
-  }
-
-  std::vector<int64_t> strides(sparse_dim);
-  strides[sparse_dim - 1] = 1;
-  for (int64_t i = sparse_dim - 2; i >= 0; i--) {
-    strides[i] = strides[i + 1] * size[i + 1];
-  }
-
-  Tensor flat_indices = at::empty({nnz}, indices.options().dtype(kLong));
-
-  auto stream = getCurrentMPSStream();
-  dispatch_sync_with_rethrow(stream->queue(), ^() {
-    @autoreleasepool {
-      auto pipeline = lib.getPipelineStateForFunc("flatten_indices_kernel");
-      auto encoder = stream->commandEncoder();
-      [encoder setComputePipelineState:pipeline];
-
-      mtl_setArgs(encoder, indices, strides, flat_indices, sparse_dim, nnz);
-      mtl_dispatch1DJob(encoder, pipeline, nnz);
-    }
-  });
-
-  return flat_indices;
-}
-
 static Tensor compute_output_positions(const Tensor& is_unique) {
   int64_t nnz = is_unique.size(0);
aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm (new file, 169 lines)
@@ -0,0 +1,169 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/native/SparseTensorUtils.h>
+#include <ATen/native/mps/OperationUtils.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/_coalesce_native.h>
+#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
+#include <ATen/ops/cat.h>
+#include <ATen/ops/add_native.h>
+#include <ATen/ops/empty_native.h>
+#include <ATen/ops/zeros_native.h>
+#include <ATen/ops/result_type.h>
+#include <ATen/ops/copy_sparse_to_sparse.h>
+#include <ATen/ops/mul.h>
+#endif
+
+namespace at::native {
+
+using namespace at::sparse;
+
+Tensor& add_out_dense_sparse_mps(Tensor& out, const Tensor& dense, const SparseTensor& sparse, const Scalar& alpha);
+
+Tensor& add_out_dense_sparse_mps(
+    Tensor& out,
+    const Tensor& dense,
+    const SparseTensor& sparse,
+    const Scalar& alpha) {
+  TORCH_CHECK(dense.is_mps(), "add: expected 'self' to be an MPS tensor, got ", dense.device());
+  TORCH_CHECK(sparse.is_mps(), "add: expected 'other' to be an MPS tensor, got ", sparse.device());
+  TORCH_CHECK(out.is_mps(), "add: expected 'out' to be an MPS tensor, got ", out.device());
+  TORCH_CHECK(dense.sizes().equals(sparse.sizes()),
+              "add: expected 'self' and 'other' to have same size, but self has size ",
+              dense.sizes(), " while other has size ", sparse.sizes(),
+              " (FYI: dense-sparse addition does not currently support broadcasting)");
+
+  const int64_t nnz = sparse._nnz();
+  if (nnz == 0) {
+    out.resize_as_(dense);
+    out.copy_(dense);
+    return out;
+  }
+
+  auto commonDtype = at::result_type(dense, sparse);
+  TORCH_CHECK(canCast(commonDtype, out.scalar_type()),
+              "Can't convert result type ", commonDtype, " to output ", out.scalar_type());
+
+  Tensor r;
+  const bool need_separate_buffer = out.is_same(dense) || (out.scalar_type() != commonDtype);
+  if (need_separate_buffer) {
+    r = at::empty(dense.sizes(), out.options().dtype(commonDtype));
+  } else {
+    r = out;
+    r.resize_as_(dense);
+  }
+
+  Tensor dense_buffer = dense.to(commonDtype);
+  if (!r.is_same(dense_buffer)) {
+    r.copy_(dense_buffer);
+  }
+
+  Tensor indices = sparse._indices();
+  Tensor values = sparse._values().to(commonDtype);
+  if (values.numel() == 0) {
+    if (!out.is_same(r)) {
+      out.resize_as_(dense);
+      out.copy_(r);
+    }
+    return out;
+  }
+
+  const int64_t nDim = r.dim();
+  const int64_t nDimI = sparse.sparse_dim();
+  TORCH_CHECK(nDimI >= 0 && nDimI <= nDim,
+              "Invalid sparse_dim=", nDimI, " for dense tensor of dim ", nDim);
+
+  Tensor indices1D = at::sparse::flatten_indices(indices, sparse.sizes()).contiguous();
+
+  int64_t view_rows = 1;
+  int64_t view_cols = 1;
+  for (int64_t i = 0; i < nDimI; i++) {
+    view_rows *= r.size(i);
+  }
+  for (int64_t i = nDimI; i < nDim; i++) {
+    view_cols *= r.size(i);
+  }
+
+  if (view_cols == 1) {
+    Tensor r_flat = r.reshape({view_rows});
+    Tensor values_1d = values.reshape({nnz});
+    r_flat.index_add_(0, indices1D, values_1d, alpha);
+  } else {
+    Tensor r_view = r.view({view_rows, view_cols});
+    Tensor values_2d = values.reshape({nnz, view_cols});
+    r_view.index_add_(0, indices1D, values_2d, alpha);
+  }
+
+  if (!out.is_same(r)) {
+    out.resize_as_(dense);
+    out.copy_(r);
+  }
+  return out;
+}
+
+SparseTensor& add_out_sparse_mps(const SparseTensor& self,
+                                 const SparseTensor& other,
+                                 const Scalar& alpha,
+                                 SparseTensor& out) {
+  TORCH_CHECK(other.is_sparse(), "add(sparse, dense) is not supported. Use add(dense, sparse) instead.");
+  TORCH_CHECK(self.is_mps(), "add: expected 'self' to be MPS, but got ", self.device());
+  TORCH_CHECK(other.is_mps(), "add: expected 'other' to be MPS, but got ", other.device());
+  TORCH_CHECK(out.is_mps(), "add: expected 'out' to be MPS, but got ", out.device());
+  if (!self.is_sparse()) {
+    return add_out_dense_sparse_mps(out, self, other, alpha);
+  }
+  auto commonDtype = at::result_type(self, other);
+  TORCH_CHECK(canCast(commonDtype, out.scalar_type()),
+              "Can't convert result type ", commonDtype, " to output ", out.scalar_type());
+
+  TORCH_CHECK(self.sizes().equals(other.sizes()),
+              "add: expected 'self' and 'other' to have same size, but ", self.sizes(), " != ", other.sizes());
+
+  TORCH_CHECK(is_same_density(self, other),
+              "add: expected 'self' and 'other' to have same density, but 'self' has ",
+              self.sparse_dim(), " sparse dimensions while 'other' has ", other.sparse_dim(), " sparse dimensions");
+
+  if (other._nnz() == 0) {
+    out.resize_as_(self);
+    Tensor vals = self._values();
+    if (vals.scalar_type() != out.scalar_type()) {
+      vals = vals.to(out.scalar_type());
+    }
+    alias_into_sparse(out, self._indices(), vals);
+    out._coalesced_(self.is_coalesced());
+    return out;
+  }
+
+  Tensor t_indices_ = self._indices();
+  Tensor s_indices_ = other._indices();
+
+  Tensor t_values_ = self._values().to(commonDtype);
+  Tensor s_values_ = other._values().to(commonDtype);
+  if (!alpha.isIntegral(false) || alpha.to<double>() != 1.0) {
+    s_values_ = at::mul(s_values_, alpha);
+  }
+
+  Tensor r_indices_ = at::cat({t_indices_, s_indices_}, 1);
+  Tensor r_values_ = at::cat({t_values_, s_values_}, 0);
+
+  SparseTensor tmp = empty({0}, out.options().dtype(commonDtype));
+  tmp.resize_as_(other);
+  alias_into_sparse(tmp, r_indices_, r_values_);
+  tmp = _coalesce_sparse_mps(tmp);
+
+  out.resize_as_(other);
+  Tensor out_vals = tmp._values();
+  if (out.scalar_type() != commonDtype) {
+    out_vals = out_vals.to(out.scalar_type());
+  }
+  alias_into_sparse(out, tmp._indices(), out_vals);
+  out._coalesced_(tmp.is_coalesced());
+
+  return out;
+}
+
+} // namespace at::native
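The dense + sparse path above flattens the sparse-dimension indices, views the dense result as (rows, cols), and scatters alpha-scaled values with index_add_; the sparse + sparse path concatenates indices/values and coalesces. A rough Python equivalent of the dense + sparse path, for illustration only (not the code that actually gets dispatched):

import torch

def add_dense_sparse_ref(dense: torch.Tensor, sparse: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # dense and sparse must have identical sizes (no broadcasting), mirroring the TORCH_CHECKs above.
    out = dense.contiguous().clone()
    nnz = sparse._nnz()
    if nnz == 0:
        return out
    sd = sparse.sparse_dim()
    sizes = list(dense.shape)
    rows = 1
    for s in sizes[:sd]:
        rows *= s
    cols = out.numel() // rows
    # Row-major multipliers over the sparse dims (same math as flatten_indices).
    muls = [1] * sd
    for d in range(sd - 2, -1, -1):
        muls[d] = muls[d + 1] * sizes[d + 1]
    flat = (sparse._indices() * torch.tensor(muls, device=dense.device).unsqueeze(1)).sum(0)
    vals = sparse._values().to(out.dtype)
    if cols == 1:
        out.view(rows).index_add_(0, flat, vals.reshape(nnz), alpha=alpha)
    else:
        out.view(rows, cols).index_add_(0, flat, vals.reshape(nnz, cols), alpha=alpha)
    return out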
@@ -2,19 +2,6 @@
 #include <metal_stdlib>
 using namespace metal;

-kernel void flatten_indices_kernel(
-    device const int64_t* indices [[buffer(0)]],
-    device const int64_t* strides [[buffer(1)]],
-    device int64_t* flat_indices [[buffer(2)]],
-    constant uint& sparse_dim [[buffer(3)]],
-    constant uint& nnz [[buffer(4)]],
-    uint gid [[thread_position_in_grid]]) {
-  int64_t flat_idx = 0;
-  for (uint d = 0; d < sparse_dim; d++) {
-    flat_idx += indices[d * nnz + gid] * strides[d];
-  }
-  flat_indices[gid] = flat_idx;
-}
-
 kernel void compute_output_positions_kernel(
     device const bool* is_unique [[buffer(0)]],
aten/src/ATen/native/sparse/mps/kernels/FlattenIndices.metal (new file, 19 lines)
@@ -0,0 +1,19 @@
+#include <metal_stdlib>
+using namespace metal;
+
+
+kernel void flatten_indices_kernel(
+    device const long* indices      [[ buffer(0) ]],
+    device const long* row_muls     [[ buffer(1) ]],
+    device long*       flat_indices [[ buffer(2) ]],
+    constant uint&     sparse_dim   [[ buffer(3) ]],
+    constant long2&    idx_strides  [[ buffer(4) ]],
+    uint gid [[ thread_position_in_grid ]]) {
+  long flat = 0;
+  for (uint d = 0; d < sparse_dim; ++d) {
+    long off = (long)d * idx_strides.x + (long)gid * idx_strides.y;
+    long v = indices[off];
+    flat += v * row_muls[d];
+  }
+  flat_indices[gid] = flat;
+}
@@ -12885,6 +12885,100 @@ class TestSparseMPS(TestCaseMPS):
         self.assertEqual(coalesced_mps._indices().cpu(), coalesced_cpu._indices())
         self.assertEqual(coalesced_mps._values().cpu(), coalesced_cpu._values())

+    def test_sparse_add(self):
+        # Basic dense + sparse add
+        dense_mps = torch.zeros((2, 3), device="mps", dtype=torch.float32)
+        sparse_mps = self._get_basic_sparse_coo(device="mps")
+
+        dense_cpu = dense_mps.cpu()
+        sparse_cpu = torch.sparse_coo_tensor(
+            sparse_mps._indices().cpu(), sparse_mps._values().cpu(), sparse_mps.size(), device="cpu"
+        )
+
+        res_mps = torch.add(dense_mps, sparse_mps)
+        res_cpu = torch.add(dense_cpu, sparse_cpu)
+        self.assertEqual(res_mps.cpu(), res_cpu)
+
+        # alpha scaling (integral alpha)
+        res_mps = torch.add(dense_mps, sparse_mps, alpha=2)
+        res_cpu = torch.add(dense_cpu, sparse_cpu, alpha=2)
+        self.assertEqual(res_mps.cpu(), res_cpu)
+
+        # alpha scaling (float alpha) with random dense
+        dense2_mps = torch.randn((2, 3), device="mps", dtype=torch.float32)
+        dense2_cpu = dense2_mps.cpu()
+        res_mps = torch.add(dense2_mps, sparse_mps, alpha=0.5)
+        res_cpu = torch.add(dense2_cpu, sparse_cpu, alpha=0.5)
+        self.assertEqual(res_mps.cpu(), res_cpu)
+
+        # nnz == 0 fast-path
+        empty_indices_mps = torch.zeros((2, 0), dtype=torch.int64, device="mps")
+        empty_values_mps = torch.tensor([], dtype=torch.float32, device="mps")
+        empty_sparse_mps = torch.sparse_coo_tensor(empty_indices_mps, empty_values_mps, (2, 3), device="mps")
+
+        empty_indices_cpu = empty_indices_mps.cpu()
+        empty_values_cpu = empty_values_mps.cpu()
+        empty_sparse_cpu = torch.sparse_coo_tensor(empty_indices_cpu, empty_values_cpu, (2, 3), device="cpu")
+
+        res_mps = torch.add(dense2_mps, empty_sparse_mps)
+        res_cpu = torch.add(dense2_cpu, empty_sparse_cpu)
+        self.assertEqual(res_mps.cpu(), res_cpu)
+
+        # 3D case to exercise view_cols > 1 path (values are 2D)
+        indices3_mps = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64, device="mps")
+        values3_mps = torch.tensor([[1., 2., 3., 4.], [5., 6., 7., 8.]], dtype=torch.float32, device="mps")
+        size3 = (2, 3, 4)
+        sp3_mps = torch.sparse_coo_tensor(indices3_mps, values3_mps, size3, device="mps")
+        dense3_mps = torch.randn(size3, device="mps", dtype=torch.float32)
+
+        indices3_cpu = indices3_mps.cpu()
+        values3_cpu = values3_mps.cpu()
+        sp3_cpu = torch.sparse_coo_tensor(indices3_cpu, values3_cpu, size3, device="cpu")
+        dense3_cpu = dense3_mps.cpu()
+
+        res_mps = torch.add(dense3_mps, sp3_mps, alpha=1.0)
+        res_cpu = torch.add(dense3_cpu, sp3_cpu, alpha=1.0)
+        self.assertEqual(res_mps.cpu(), res_cpu)
+
+        # dtype promotion: dense float32 + sparse float16
+        sparse_f16_mps = torch.sparse_coo_tensor(
+            sparse_mps._indices(),
+            sparse_mps._values().to(torch.float16),
+            sparse_mps.size(),
+            device="mps",
+        )
+        sparse_f16_cpu = torch.sparse_coo_tensor(
+            sparse_f16_mps._indices().cpu(),
+            sparse_f16_mps._values().cpu(),
+            sparse_f16_mps.size(),
+            device="cpu",
+        )
+        res_mps = torch.add(dense2_mps, sparse_f16_mps, alpha=0.25)
+        res_cpu = torch.add(dense2_cpu, sparse_f16_cpu, alpha=0.25)
+        self.assertEqual(res_mps.cpu(), res_cpu)
+
+        # broadcasting not supported: mismatched size should error
+        bad_sparse_mps = torch.sparse_coo_tensor(
+            sparse_mps._indices(), sparse_mps._values(), (2, 4), device="mps"
+        )
+        with self.assertRaisesRegex(RuntimeError, "same size"):
+            torch.add(dense_mps, bad_sparse_mps)
+
+        # sparse + sparse with overlap (tests concatenation + coalesce + alpha)
+        s1_idx = torch.tensor([[0, 0, 1], [0, 0, 2]], dtype=torch.int64)
+        s1_val = torch.tensor([1., 2., 3.], dtype=torch.float32)
+        s2_idx = torch.tensor([[0, 1, 1], [0, 2, 2]], dtype=torch.int64)
+        s2_val = torch.tensor([4., 5., 6.], dtype=torch.float32)
+
+        s1_mps = torch.sparse_coo_tensor(s1_idx.to("mps"), s1_val.to("mps"), (2, 3), device="mps")
+        s2_mps = torch.sparse_coo_tensor(s2_idx.to("mps"), s2_val.to("mps"), (2, 3), device="mps")
+        s1_cpu = torch.sparse_coo_tensor(s1_idx, s1_val, (2, 3), device="cpu")
+        s2_cpu = torch.sparse_coo_tensor(s2_idx, s2_val, (2, 3), device="cpu")
+
+        sp_res_mps = torch.add(s1_mps, s2_mps, alpha=2.0).coalesce()
+        sp_res_cpu = torch.add(s1_cpu, s2_cpu, alpha=2.0).coalesce()
+        self.assertEqual(sp_res_mps.cpu(), sp_res_cpu)
+

 # TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
 # This requires mps to be properly registered in the device generic test framework which is not the
@@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests, do_test_dt
     parametrize, subtest, is_coalesced_indices, suppress_warnings, instantiate_parametrized_tests, \
     skipIfCrossRef
 from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_mps import mps_ops_modifier
 from numbers import Number
 from typing import Any
 from packaging import version

@@ -42,7 +43,6 @@ def _op_supports_any_sparse(op):
             or op.supports_sparse_bsc)

-
 reduction_ops_with_sparse_support = [
     op for op in reduction_ops if 'masked.' not in op.name and
     _op_supports_any_sparse(op) and not isinstance(op, ReductionPythonRefInfo)]

@@ -4126,7 +4126,7 @@ def _sparse_to_dense(tensor):
     return tensor.to(torch.int8).to_dense().to(torch.bool)


-_sparse_unary_ops = ops(sparse_unary_ufuncs, dtypes=OpDTypes.supported,
+_sparse_unary_ops = ops(mps_ops_modifier(sparse_unary_ufuncs, sparse=True), dtypes=OpDTypes.supported,
                         allowed_dtypes=all_types_and_complex())
 class TestSparseUnaryUfuncs(TestCase):
     exact_dtype = True

@@ -4178,8 +4178,8 @@ class TestSparseUnaryUfuncs(TestCase):
     @_sparse_unary_ops
     def test_sparse_zero_dims(self, device, dtype, op):
         # test 0x0 sparse_coo_tensor
-        indices = torch.empty(2, 0, dtype=torch.int64)
-        values = torch.empty(0, dtype=dtype)
+        indices = torch.empty(2, 0, dtype=torch.int64, device=device)
+        values = torch.empty(0, dtype=dtype, device=device)
         sparse_0x0 = torch.sparse_coo_tensor(indices, values, (0, 0))
         expected = torch.sparse_coo_tensor(indices, op(values), (0, 0))
         actual = op(sparse_0x0)

@@ -5526,7 +5526,7 @@ class TestSparseAny(TestCase):

 # e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
-instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')
+instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), allow_mps=True, except_for='meta')

 instantiate_device_type_tests(TestSparseMaskedReductions, globals(), except_for='meta')
@@ -12,8 +12,9 @@ if torch.backends.mps.is_available():

     def mps_ops_modifier(
         ops: Sequence[OpInfo],
-        device_type: Optional[str] = None,
+        device_type: str = "mps",
         xfail_exclusion: Optional[list[str]] = None,
+        sparse: bool = False,
     ) -> Sequence[OpInfo]:
         if xfail_exclusion is None:
             xfail_exclusion = []

@@ -294,7 +295,7 @@ if torch.backends.mps.is_available():
         }

         # Those ops are not expected to work
-        UNIMPLEMENTED_XFAILLIST = {
+        UNIMPLEMENTED_XFAILLIST: dict[str, Optional[list]] = {
             # Failures due to lack of op implementation on MPS backend
             "logspace": None,
             "logspacetensor_overload": None,

@@ -440,6 +441,42 @@ if torch.backends.mps.is_available():
                 torch.int8,
             ],
         }
+        UNIMPLEMENTED_XFAILLIST_SPARSE: dict[str, Optional[list]] = {
+            "logspace": None,
+            "logspacetensor_overload": None,
+            "linalg.eig": None,
+            "linalg.eigvals": None,
+            "put": None,
+            "deg2rad": None,
+            "erf": None,
+            "expm1": None,
+            "floor": None,
+            "frac": None,
+            "isneginf": None,
+            "isposinf": None,
+            "log1p": None,
+            "nan_to_num": None,
+            "neg": None,
+            "rad2deg": None,
+            "round": None,
+            "sgn": None,
+            "sign": None,
+            "signbit": None,
+            "sin": None,
+            "sinh": None,
+            "sqrt": None,
+            "tan": None,
+            "tanh": None,
+            "asinh": None,
+            "asin": None,
+            "isnan": None,
+            "isinf": None,
+            "atan": None,
+            "atanh": None,
+            "ceil": None,
+            "relu": None,
+            "nn.functional.relu": None,
+        }

         if MACOS_VERSION < 15.0:
             UNIMPLEMENTED_XFAILLIST.update(

@@ -448,8 +485,10 @@ if torch.backends.mps.is_available():
                     "nanquantile": None,
                 }
             )
+        if sparse:
+            UNIMPLEMENTED_XFAILLIST.update(UNIMPLEMENTED_XFAILLIST_SPARSE)

-        UNDEFINED_XFAILLIST = {
+        UNDEFINED_XFAILLIST: dict[str, Optional[list]] = {
             # Top 60 operators
             # topk fails with duplicate indices
             "topk": [

@@ -526,7 +565,7 @@ if torch.backends.mps.is_available():
             ],
         }

-        ON_MPS_XFAILLIST = {
+        ON_MPS_XFAILLIST: dict[str, Optional[list]] = {
             # Failures due to lack of implementation of downstream functions on MPS backend
             # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
             "linalg.matrix_rank": None,

@@ -590,6 +629,11 @@ if torch.backends.mps.is_available():
             # precision types. So we have to skip these for now.
             "grid_sampler_3d": [torch.float16, torch.bfloat16],
         }
+        SKIPLIST_SPARSE = {
+            # Skipped due to test_sparse_zero_dims test in test_sparse.py which allocates empty tensor
+            # and does basically a no-op op(positive), which leads to unexpected success
+            "positive": [torch.complex128],
+        }

         def addDecorator(op: OpInfo, d: DecorateInfo) -> None:
             if device_type is not None:

@@ -599,6 +643,28 @@ if torch.backends.mps.is_available():

         for op in ops:
             key = op.name + op.variant_test_name
+            addDecorator(
+                op,
+                DecorateInfo(
+                    unittest.expectedFailure,
+                    dtypes=[
+                        torch.double,
+                        torch.cdouble,
+                    ],
+                ),
+            )
+            if sparse and op.name in SKIPLIST_SPARSE:
+                addDecorator(
+                    op,
+                    DecorateInfo(
+                        unittest.skip(
+                            "Skipped due to MPS not supporting complex128 tensors"
+                        ),
+                        dtypes=[
+                            torch.complex128,
+                        ],
+                    ),
+                )
             if key in EMPTY_OPS_SKIPLIST:
                 addDecorator(
                     op,

@@ -805,3 +871,12 @@ if torch.backends.mps.is_available():
                     addDecorator(op, DecorateInfo(unittest.expectedFailure))

         return ops
+else:
+
+    def mps_ops_modifier(
+        ops: Sequence[OpInfo],
+        device_type: str = "mps",
+        xfail_exclusion: Optional[list[str]] = None,
+        sparse: bool = False,
+    ) -> Sequence[OpInfo]:
+        return ops
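A short sketch (names taken from the diff above) of how the new `sparse=True` switch is meant to be consumed: `test_sparse.py` passes its OpInfo list through `mps_ops_modifier`, which attaches expectedFailure/skip decorators for entries in `UNIMPLEMENTED_XFAILLIST_SPARSE` and `SKIPLIST_SPARSE`, so the device-generic sparse unary tests can then be instantiated with `allow_mps=True`:

from torch.testing._internal.common_mps import mps_ops_modifier
from torch.testing._internal.common_methods_invocations import sparse_unary_ufuncs

# As in test_sparse.py: decorate the op list before handing it to the @ops(...) wrapper.
mps_aware_ops = mps_ops_modifier(sparse_unary_ufuncs, sparse=True)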