remove index_fill side-effect for scalar tensors (#52209)

Summary:
`index_fill` silently promotes 0-dim tensors to 1-dim tensors: the old implementation unsqueezed `self` in place, so the result (and, for the in-place `index_fill_`, the input itself) gained a dimension. This PR fixes that.
Was:
```
In [1]: import torch

In [2]: x = torch.tensor(1)

In [3]: idx = torch.tensor(0).long()

In [4]: x.dim()
Out[4]: 0

In [5]: x.index_fill(0, idx, -1).dim()
Out[5]: 1

```
Now:
```
In [6]: x.index_fill(0, idx, -1).dim()
Out[6]: 0

```
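
The fix swaps the in-place `self.unsqueeze_(-1)` for a temporary view, `self.unsqueeze(-1)`. Because `unsqueeze` returns a view that shares storage with `self`, filling through the view still writes into the original 0-dim tensor while leaving its shape metadata untouched. A minimal illustration of that view semantics (not part of the PR):
```
In [7]: x = torch.tensor(1)

In [8]: v = x.unsqueeze(-1)  # view: shares storage with x

In [9]: v.fill_(-1)
Out[9]: tensor([-1])

In [10]: x  # still 0-dim; the write went through the view
Out[10]: tensor(-1)
```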

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52209

Reviewed By: ejguan

Differential Revision: D26446470

Pulled By: ngimel

fbshipit-source-id: 4737e6941a7216b57f3416b59362817834df3a3a
Author: Nikita Vedeneev
Date: 2021-02-25 00:33:07 -08:00
Committed by: Facebook GitHub Bot
Commit: 0048d97eda (parent 57947c5d85)
2 changed files with 14 additions and 11 deletions


```diff
@@ -737,17 +737,15 @@ Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar so
   }
   // Handle the case when `self` is 0-dim
-  if (0 == self.dim()) {
-    self.unsqueeze_(-1);
-  }
+  Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self;
-  dim = at::maybe_wrap_dim(dim, self);
+  dim = at::maybe_wrap_dim(dim, self_nonzero_dim);
   TORCH_CHECK(index.dim() <= 1, "Index has to be a vector/scalar");
   // Prepare `index` for TensorIterator.
   // It is restrided to be broadcastable over `self` in TensorIterator.
-  auto index_sizes = std::vector<int64_t>(self.dim(), 1);
-  auto index_strides = std::vector<int64_t>(self.dim(), 0);
+  auto index_sizes = std::vector<int64_t>(self_nonzero_dim.dim(), 1);
+  auto index_strides = std::vector<int64_t>(self_nonzero_dim.dim(), 0);
   index_sizes[dim] = index.numel();
   index_strides[dim] = (index.dim() > 0) ? index.stride(0) : 1; // `index` is 1d or scalar
   auto index_restrided = index.as_strided(
@@ -762,11 +760,11 @@ Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar so
   // match as required by TensorIterator (input shape should
   // strictly broadcast over output shape, i.e.
   // output.shape[i] >= input.shape[i] for i in range(dims)).
-  auto self_sizes = self.sizes().vec();
-  auto self_strides = self.strides().vec();
+  auto self_sizes = self_nonzero_dim.sizes().vec();
+  auto self_strides = self_nonzero_dim.strides().vec();
   self_sizes[dim] = index.numel();
   self_strides[dim] = 0;
-  auto self_restrided = self.as_strided(self_sizes, self_strides);
+  auto self_restrided = self_nonzero_dim.as_strided(self_sizes, self_strides);
   auto iter = TensorIteratorConfig()
     // We do not check for overlap because `self` is restrided
@@ -779,8 +777,8 @@ Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar so
     .add_input(index_restrided)
     .build();
-  auto self_dim_size = (self.sizes())[dim];
-  auto self_dim_stride = (self.strides())[dim];
+  auto self_dim_size = (self_nonzero_dim.sizes())[dim];
+  auto self_dim_stride = (self_nonzero_dim.strides())[dim];
   index_fill_stub(
     iter.device_type(),
     iter,
```


```diff
@@ -4449,6 +4449,11 @@ class TestTorchDeviceType(TestCase):
         if not x.is_complex():
             with self.assertRaisesRegex(RuntimeError, r"Scalar"):
                 x.index_fill_(1, index, 1 + 1j)
+        # Make sure that the result stays 0-dim while applied to
+        # a 0-dim input
+        x = torch.tensor(1, dtype=dtype, device=device)
+        self.assertEqual(0, x.index_fill(0, index, -1).dim())
+        self.assertEqual(0, x.index_fill_(0, index, -1).dim())

     def test_index_select(self, device):
         for dtype in [torch.int, torch.long]:
```