[MPS] enable cat op for sparse (#162007)

Enable cat op for sparse on MPS

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162007
Approved by: https://github.com/malfet
This commit is contained in:
Isalia20 2025-09-03 06:31:35 +00:00 committed by PyTorch MergeBot
parent f8ffa9194e
commit 2c03f0acc5
2 changed files with 2 additions and 2 deletions

View File

@ -1412,7 +1412,7 @@
- func: cat(Tensor[] tensors, int dim=0) -> Tensor
structured_delegate: cat.out
dispatch:
SparseCPU, SparseCUDA: cat_sparse
SparseCPU, SparseCUDA, SparseMPS: cat_sparse
QuantizedCPU: cat_quantized_cpu
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: cat_nested
tags: core

View File

@ -1121,9 +1121,9 @@ class TestSparse(TestSparseBase):
x.sub_(2 * x)
self.assertLessEqual(x._nnz(), 10)
@expectedFailureMPS
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float32, torch.complex64)
def test_cat(self, device, dtype, coalesced):
# shapes: list of tuples (sparse_dims, nnz, sizes)
def test_shapes(shapes, dim, fail_message=None):