From 28af843ee0ea79867b7fd4ddc5bd0072d6518f3a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 27 Aug 2025 15:38:11 +0000 Subject: [PATCH] Revert "Fix index_add for int64 input + zerodim index (#161511)" This reverts commit d51486616cb3fe54bc298669a88059be56c1fb22. Reverted https://github.com/pytorch/pytorch/pull/161511 on behalf of https://github.com/clee2000 due to broke test_indexing.py::TestIndexingCPU::test_index_add_zerodim_index_floating_alpha_cpu [GH job link](https://github.com/pytorch/pytorch/actions/runs/17257089116/job/48971728595) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/d51486616cb3fe54bc298669a88059be56c1fb22) on dynamo? ([comment](https://github.com/pytorch/pytorch/pull/161511#issuecomment-3228705842)) --- aten/src/ATen/native/mps/operations/Indexing.mm | 10 ++++------ test/test_indexing.py | 12 ------------ 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index 70afff34693..fa19d2f4d12 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -528,12 +528,10 @@ TORCH_IMPL_FUNC(index_add_mps_out) for (const auto i : c10::irange(dim)) { indices.emplace_back(); } - const auto&& index_ = (index.dim() == 0) ? index.view(1).to(at::kLong) : index.to(at::kLong); - indices.emplace_back(index_); - const auto&& result_ = (result.dim() == 0) ? result.view(1) : result; - const auto&& source_ = (source.dim() == 0) ? source.view(1) : source; - const auto&& alpha_ = at::scalar_tensor(alpha, source_.options()); - result_.index_put_(indices, source_.mul(alpha_), true); + indices.emplace_back(index.to(at::kLong)); + const Tensor result_ = (result.dim() == 0) ? result.view(1) : result; + const Tensor source_ = (source.dim() == 0) ? source.view(1) : source; + result_.index_put_(indices, source_.mul(alpha), true); return; } diff --git a/test/test_indexing.py b/test/test_indexing.py index 8b3915685de..7a202efbe08 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -2029,18 +2029,6 @@ class TestIndexing(TestCase): self.assertEqual(output, input_list) - @onlyNativeDeviceTypes - def test_index_add_zerodim_index_floating_alpha(self, device) -> None: - # Regression test for https://github.com/pytorch/pytorch/issues/161446 - x = torch.ones([2, 3], dtype=torch.int64, device=device) - index = torch.tensor(0, dtype=torch.int64, device=device) - src = torch.full([1, 3], 2, dtype=torch.int64, device=device) - alpha = 1.5 - x.index_add_(0, index, src, alpha=alpha) - self.assertEqual( - x, torch.tensor([[3, 3, 3], [1, 1, 1]], dtype=torch.int64, device=device) - ) - @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) @expectedFailureMPS def test_index_fill(self, device, dtype):