Revert "Fix index_add for int64 input + zerodim index (#161511)"

This reverts commit d51486616c.

Reverted https://github.com/pytorch/pytorch/pull/161511 on behalf of https://github.com/clee2000 due to broke test_indexing.py::TestIndexingCPU::test_index_add_zerodim_index_floating_alpha_cpu [GH job link](https://github.com/pytorch/pytorch/actions/runs/17257089116/job/48971728595) [HUD commit link](d51486616c) on dynamo? ([comment](https://github.com/pytorch/pytorch/pull/161511#issuecomment-3228705842))
This commit is contained in:
PyTorch MergeBot 2025-08-27 15:38:11 +00:00
parent 378edb047f
commit 28af843ee0
2 changed files with 4 additions and 18 deletions

View File

@ -528,12 +528,10 @@ TORCH_IMPL_FUNC(index_add_mps_out)
for (const auto i : c10::irange(dim)) {
indices.emplace_back();
}
const auto&& index_ = (index.dim() == 0) ? index.view(1).to(at::kLong) : index.to(at::kLong);
indices.emplace_back(index_);
const auto&& result_ = (result.dim() == 0) ? result.view(1) : result;
const auto&& source_ = (source.dim() == 0) ? source.view(1) : source;
const auto&& alpha_ = at::scalar_tensor(alpha, source_.options());
result_.index_put_(indices, source_.mul(alpha_), true);
indices.emplace_back(index.to(at::kLong));
const Tensor result_ = (result.dim() == 0) ? result.view(1) : result;
const Tensor source_ = (source.dim() == 0) ? source.view(1) : source;
result_.index_put_(indices, source_.mul(alpha), true);
return;
}

View File

@ -2029,18 +2029,6 @@ class TestIndexing(TestCase):
self.assertEqual(output, input_list)
@onlyNativeDeviceTypes
def test_index_add_zerodim_index_floating_alpha(self, device) -> None:
# Regression test for https://github.com/pytorch/pytorch/issues/161446
x = torch.ones([2, 3], dtype=torch.int64, device=device)
index = torch.tensor(0, dtype=torch.int64, device=device)
src = torch.full([1, 3], 2, dtype=torch.int64, device=device)
alpha = 1.5
x.index_add_(0, index, src, alpha=alpha)
self.assertEqual(
x, torch.tensor([[3, 3, 3], [1, 1, 1]], dtype=torch.int64, device=device)
)
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@expectedFailureMPS
def test_index_fill(self, device, dtype):