mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix 0-dim Index in Index Copy decomp (#117065)
Fix for https://github.com/pytorch/pytorch/issues/115931 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117065 Approved by: https://github.com/jansel, https://github.com/shunting314
This commit is contained in:
parent
b9293e74a2
commit
d6540038c0
|
|
@ -535,6 +535,20 @@ class TestDecomp(TestCase):
|
|||
res = torch._decomp.decompositions.uniform(x, low=low, high=high)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_broadcasting_index_copy(self, device):
|
||||
x = torch.zeros([1, 10], device=device)
|
||||
xs = torch.ones([2, 10], device=device)
|
||||
|
||||
def index_copy(xs, x):
|
||||
torch._decomp.decompositions.index_copy_(xs, 0, torch.tensor(0).to(device), x)
|
||||
|
||||
index_copy(xs, x)
|
||||
|
||||
xs_two = torch.ones([2, 10], device=device)
|
||||
xs_two[0] = x
|
||||
|
||||
self.assertEqual(xs, xs_two)
|
||||
|
||||
def test_rrelu_with_noise(self, device):
|
||||
# rrelu_with_noise behavior depends on a) whether elements in the input
|
||||
# are <= 0, and b) whether we're in training mode. Cover all cases:
|
||||
|
|
|
|||
|
|
@ -2333,6 +2333,7 @@ def _index_copy(
|
|||
# Treat scalars as elements of \R^1
|
||||
zero_dim = x.ndim == 0
|
||||
x1 = x.unsqueeze(0) if zero_dim else x
|
||||
index = index.unsqueeze(0) if index.ndim == 0 else index
|
||||
idx = (None,) * dim + (index,)
|
||||
index_put = aten.index_put_ if inplace else aten.index_put
|
||||
out = index_put(x1, idx, tensor)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user