Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 00:20:18 +01:00)
Revert "Fix index_add for int64 input + zerodim index (#161511)"
This reverts commit d51486616c. Reverted https://github.com/pytorch/pytorch/pull/161511 on behalf of https://github.com/clee2000 because it broke test_indexing.py::TestIndexingCPU::test_index_add_zerodim_index_floating_alpha_cpu on dynamo ([GH job link](https://github.com/pytorch/pytorch/actions/runs/17257089116/job/48971728595), [HUD commit link](d51486616c), [comment](https://github.com/pytorch/pytorch/pull/161511#issuecomment-3228705842))
This commit is contained in:
parent 378edb047f
commit 28af843ee0
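For context, the scenario named in the reverted title (int64 input, a zero-dim index tensor, and a floating-point alpha) can be reproduced with a short sketch. The values are taken verbatim from the regression test removed in the second hunk below; only the comments are new, and depending on the build this call may raise or give a different result, which is what that test guarded against.

```python
import torch

# Values mirror the removed regression test
# test_index_add_zerodim_index_floating_alpha (second hunk below).
x = torch.ones([2, 3], dtype=torch.int64)
index = torch.tensor(0, dtype=torch.int64)      # zero-dim index tensor
src = torch.full([1, 3], 2, dtype=torch.int64)  # int64 source rows
x.index_add_(0, index, src, alpha=1.5)          # floating alpha on int64 data
# With the now-reverted fix in place, the test expected:
#   tensor([[3, 3, 3], [1, 1, 1]])
# Without it, this is the combination reported in
# https://github.com/pytorch/pytorch/issues/161446.
```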
@@ -528,12 +528,10 @@ TORCH_IMPL_FUNC(index_add_mps_out)
     for (const auto i : c10::irange(dim)) {
       indices.emplace_back();
     }
-    const auto&& index_ = (index.dim() == 0) ? index.view(1).to(at::kLong) : index.to(at::kLong);
-    indices.emplace_back(index_);
-    const auto&& result_ = (result.dim() == 0) ? result.view(1) : result;
-    const auto&& source_ = (source.dim() == 0) ? source.view(1) : source;
-    const auto&& alpha_ = at::scalar_tensor(alpha, source_.options());
-    result_.index_put_(indices, source_.mul(alpha_), true);
+    indices.emplace_back(index.to(at::kLong));
+    const Tensor result_ = (result.dim() == 0) ? result.view(1) : result;
+    const Tensor source_ = (source.dim() == 0) ? source.view(1) : source;
+    result_.index_put_(indices, source_.mul(alpha), true);
     return;
   }
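For readers skimming the hunk above: the removed lines (the fix being reverted) differ from the restored ones in two ways. They view a zero-dim index as 1-D before using it, and they wrap alpha in a scalar tensor with the source's options so a floating alpha does not promote an integral source. Below is a rough Python sketch of that index_put_-based path; the helper name is hypothetical, and it assumes that None entries in index_put_'s indices act like full slices, mirroring the empty indices.emplace_back() slots in the C++ code.

```python
import torch

# Hypothetical helper (not a PyTorch API): a rough sketch of the
# index_put_-based path in the removed lines above.
def index_add_via_index_put(result, dim, index, source, alpha=1.0):
    indices = [None] * dim  # "take everything" for the leading dimensions
    # The reverted fix viewed a zero-dim index as 1-D before converting to long.
    index_ = index.view(1) if index.dim() == 0 else index
    indices.append(index_.to(torch.long))
    result_ = result.view(1) if result.dim() == 0 else result
    source_ = source.view(1) if source.dim() == 0 else source
    # at::scalar_tensor(alpha, source_.options()) in the C++ keeps alpha in the
    # source's dtype, so a floating alpha does not promote an int64 source.
    alpha_ = torch.scalar_tensor(alpha, dtype=source_.dtype)
    result_.index_put_(indices, source_ * alpha_, accumulate=True)
    return result
```

Under these assumptions, for the shapes in the removed test (dim=0, zero-dim index) the indices list reduces to [index.view(1)], and alpha=1.5 truncates to 1 in int64, which is consistent with the expected tensor([[3, 3, 3], [1, 1, 1]]).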
@@ -2029,18 +2029,6 @@ class TestIndexing(TestCase):
 
         self.assertEqual(output, input_list)
 
-    @onlyNativeDeviceTypes
-    def test_index_add_zerodim_index_floating_alpha(self, device) -> None:
-        # Regression test for https://github.com/pytorch/pytorch/issues/161446
-        x = torch.ones([2, 3], dtype=torch.int64, device=device)
-        index = torch.tensor(0, dtype=torch.int64, device=device)
-        src = torch.full([1, 3], 2, dtype=torch.int64, device=device)
-        alpha = 1.5
-        x.index_add_(0, index, src, alpha=alpha)
-        self.assertEqual(
-            x, torch.tensor([[3, 3, 3], [1, 1, 1]], dtype=torch.int64, device=device)
-        )
-
     @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
     @expectedFailureMPS
     def test_index_fill(self, device, dtype):