From c8bb0e4720ddddf3cd1b0b48b336978f763c71ca Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 22 Aug 2025 07:10:56 -0700 Subject: [PATCH] [MPS] Fix `index_copy` for scalars (#161267) By `squeezing the input` when copying into scalar tensor from a 1d one And enable `test_index_copy_scalars_mps` Fixes https://github.com/pytorch/pytorch/issues/160737 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161267 Approved by: https://github.com/manuelcandales, https://github.com/Skylion007, https://github.com/dcci ghstack dependencies: #161206 --- aten/src/ATen/native/mps/operations/Indexing.mm | 7 ++++--- test/test_indexing.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index c48fc5fc2aa..3ab0cd95346 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -230,7 +230,7 @@ TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self, index.numel()); int64_t idx = index.item(); TORCH_CHECK(idx == 0, "index_copy_(): the only valid index for a 0-dim tensor is 0, but got ", idx); - result.copy_(source); + result.copy_(source.squeeze()); return; } @@ -254,11 +254,12 @@ TORCH_IMPL_FUNC(index_copy_out_mps)(const Tensor& self, } } - TORCH_CHECK(source.size(dim) == index.numel(), + const auto source_size_dim = source.dim() > 0 ? source.size(dim) : 1; + TORCH_CHECK(index.numel() == source_size_dim, "index_copy_(): Number of indices (", index.numel(), ") should be equal to source.size(dim) (", - source.size(dim), + source_size_dim, ")"); auto stream = getCurrentMPSStream(); diff --git a/test/test_indexing.py b/test/test_indexing.py index 488ecae59c0..00b539d069f 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -1913,8 +1913,8 @@ class TestIndexing(TestCase): # onlyNativeDeviceTypes due to an XLA error: # https://github.com/pytorch/pytorch/issues/53256 @onlyNativeDeviceTypes - @expectedFailureMPS # See https://github.com/pytorch/pytorch/issues/160737 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @dtypesIfMPS(*all_mps_types_and(torch.bool, torch.cfloat)) def test_index_copy_scalars(self, device, dtype): # Create the 8 possible combinations of scalar sizes for target / index / source scalars = (