mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] enable cat op for sparse (#162007)
Enable cat op for sparse on MPS Pull Request resolved: https://github.com/pytorch/pytorch/pull/162007 Approved by: https://github.com/malfet
This commit is contained in:
parent
f8ffa9194e
commit
2c03f0acc5
|
|
@ -1412,7 +1412,7 @@
|
|||
- func: cat(Tensor[] tensors, int dim=0) -> Tensor
|
||||
structured_delegate: cat.out
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: cat_sparse
|
||||
SparseCPU, SparseCUDA, SparseMPS: cat_sparse
|
||||
QuantizedCPU: cat_quantized_cpu
|
||||
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: cat_nested
|
||||
tags: core
|
||||
|
|
|
|||
|
|
@ -1121,9 +1121,9 @@ class TestSparse(TestSparseBase):
|
|||
x.sub_(2 * x)
|
||||
self.assertLessEqual(x._nnz(), 10)
|
||||
|
||||
@expectedFailureMPS
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@dtypesIfMPS(torch.float32, torch.complex64)
|
||||
def test_cat(self, device, dtype, coalesced):
|
||||
# shapes: list of tuples (sparse_dims, nnz, sizes)
|
||||
def test_shapes(shapes, dim, fail_message=None):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user