fix set item to scalar tensor missing gradient info

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78246

Approved by: https://github.com/ngimel
This commit is contained in:
yuguo68 2022-05-25 09:20:34 -07:00 committed by PyTorch MergeBot
parent e01fb9cd07
commit cd41c8f032
2 changed files with 12 additions and 2 deletions

View File

@ -353,8 +353,8 @@ static inline void copy_to(const Tensor& dst, const Tensor& src) {
// appear. Users can workaround that case by dst[index..] = src.reshape(..)
dst.copy_(src);
return;
} else if (src.sizes().size() == 0 && src.device().type() == at::kCPU) {
dst.fill_(src.item());
} else if (src.dim() == 0 && src.device().type() == at::kCPU) {
dst.fill_(src);
return;
}
auto src_view = src.view(slicePrefix1sSize(src.sizes()));

View File

@ -649,6 +649,16 @@ class TestIndexing(TestCase):
self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int))
def test_set_item_to_scalar_tensor(self, device):
m = random.randint(1, 10)
n = random.randint(1, 10)
z = torch.randn([m, n], device=device)
a = 1.0
w = torch.tensor(a, requires_grad=True, device=device)
z[:, 0] = w
z.sum().backward()
self.assertEqual(w.grad, m * a)
def test_single_int(self, device):
v = torch.randn(5, 7, 3, device=device)
self.assertEqual(v[4].shape, (7, 3))