Fix 0-dim Index in Index Copy decomp (#117065)

Fix for https://github.com/pytorch/pytorch/issues/115931

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117065
Approved by: https://github.com/jansel, https://github.com/shunting314
This commit is contained in:
Elias Ellison 2024-01-10 00:21:15 +00:00 committed by PyTorch MergeBot
parent b9293e74a2
commit d6540038c0
2 changed files with 15 additions and 0 deletions

View File

@ -535,6 +535,20 @@ class TestDecomp(TestCase):
res = torch._decomp.decompositions.uniform(x, low=low, high=high)
self.assertEqual(ref, res)
def test_broadcasting_index_copy(self, device):
x = torch.zeros([1, 10], device=device)
xs = torch.ones([2, 10], device=device)
def index_copy(xs, x):
torch._decomp.decompositions.index_copy_(xs, 0, torch.tensor(0).to(device), x)
index_copy(xs, x)
xs_two = torch.ones([2, 10], device=device)
xs_two[0] = x
self.assertEqual(xs, xs_two)
def test_rrelu_with_noise(self, device):
# rrelu_with_noise behavior depends on a) whether elements in the input
# are <= 0, and b) whether we're in training mode. Cover all cases:

View File

@ -2333,6 +2333,7 @@ def _index_copy(
# Treat scalars as elements of \R^1
zero_dim = x.ndim == 0
x1 = x.unsqueeze(0) if zero_dim else x
index = index.unsqueeze(0) if index.ndim == 0 else index
idx = (None,) * dim + (index,)
index_put = aten.index_put_ if inplace else aten.index_put
out = index_put(x1, idx, tensor)