mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
fix set item to scalar tensor missing gradient info
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78246 Approved by: https://github.com/ngimel
This commit is contained in:
parent
e01fb9cd07
commit
cd41c8f032
|
|
@ -353,8 +353,8 @@ static inline void copy_to(const Tensor& dst, const Tensor& src) {
|
|||
// appear. Users can workaround that case by dst[index..] = src.reshape(..)
|
||||
dst.copy_(src);
|
||||
return;
|
||||
} else if (src.sizes().size() == 0 && src.device().type() == at::kCPU) {
|
||||
dst.fill_(src.item());
|
||||
} else if (src.dim() == 0 && src.device().type() == at::kCPU) {
|
||||
dst.fill_(src);
|
||||
return;
|
||||
}
|
||||
auto src_view = src.view(slicePrefix1sSize(src.sizes()));
|
||||
|
|
|
|||
|
|
@ -649,6 +649,16 @@ class TestIndexing(TestCase):
|
|||
self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
|
||||
torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int))
|
||||
|
||||
def test_set_item_to_scalar_tensor(self, device):
|
||||
m = random.randint(1, 10)
|
||||
n = random.randint(1, 10)
|
||||
z = torch.randn([m, n], device=device)
|
||||
a = 1.0
|
||||
w = torch.tensor(a, requires_grad=True, device=device)
|
||||
z[:, 0] = w
|
||||
z.sum().backward()
|
||||
self.assertEqual(w.grad, m * a)
|
||||
|
||||
def test_single_int(self, device):
|
||||
v = torch.randn(5, 7, 3, device=device)
|
||||
self.assertEqual(v[4].shape, (7, 3))
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user