From cd41c8f032dd06c445bf97fc76fb82008b19afcb Mon Sep 17 00:00:00 2001 From: yuguo68 Date: Wed, 25 May 2022 09:20:34 -0700 Subject: [PATCH] fix set item to scalar tensor missing gradient info Pull Request resolved: https://github.com/pytorch/pytorch/pull/78246 Approved by: https://github.com/ngimel --- aten/src/ATen/TensorIndexing.h | 4 ++-- test/test_indexing.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index 8352b510f60..43589580c53 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -353,8 +353,8 @@ static inline void copy_to(const Tensor& dst, const Tensor& src) { // appear. Users can workaround that case by dst[index..] = src.reshape(..) dst.copy_(src); return; - } else if (src.sizes().size() == 0 && src.device().type() == at::kCPU) { - dst.fill_(src.item()); + } else if (src.dim() == 0 && src.device().type() == at::kCPU) { + dst.fill_(src); return; } auto src_view = src.view(slicePrefix1sSize(src.sizes())); diff --git a/test/test_indexing.py b/test/test_indexing.py index 9d2d82e9f12..0d1022bc24e 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -649,6 +649,16 @@ class TestIndexing(TestCase): self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ], torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int)) + def test_set_item_to_scalar_tensor(self, device): + m = random.randint(1, 10) + n = random.randint(1, 10) + z = torch.randn([m, n], device=device) + a = 1.0 + w = torch.tensor(a, requires_grad=True, device=device) + z[:, 0] = w + z.sum().backward() + self.assertEqual(w.grad, m * a) + def test_single_int(self, device): v = torch.randn(5, 7, 3, device=device) self.assertEqual(v[4].shape, (7, 3))