remove index_fill side-effect for scalar tensors (#52209)
Summary: `index_fill` silently promotes zero-dim Tensors to 1-dim Tensors. This PR fixes that.

Was:
```
In [1]: import torch

In [2]: x = torch.tensor(1)

In [3]: idx = torch.tensor(0).long()

In [4]: x.dim()
Out[4]: 0

In [5]: x.index_fill(0, idx, -1).dim()
Out[5]: 1
```

Now:
```
In [6]: x.index_fill(0, idx, -1).dim()
Out[6]: 0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52209

Reviewed By: ejguan

Differential Revision: D26446470

Pulled By: ngimel

fbshipit-source-id: 4737e6941a7216b57f3416b59362817834df3a3a
This commit is contained in:
parent 57947c5d85
commit 0048d97eda
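The core of the fix is visible in the first hunk below: the old code called `unsqueeze_` in place, mutating the caller's 0-dim tensor, while the new code keeps a separate 1-dim view (`self_nonzero_dim`) for internal bookkeeping. A minimal Python sketch of that distinction (the variable names here are illustrative, not the kernel's):

```python
import torch

x = torch.tensor(1)       # 0-dim scalar tensor
view = x.unsqueeze(-1)    # non-mutating: a 1-dim view for bookkeeping only
view.fill_(-1)            # writes land in x's storage...
print(x.dim(), x.item())  # ...but x itself stays 0-dim: prints "0 -1"
```

By contrast, `x.unsqueeze_(-1)` would have left `x` 1-dim, which is exactly the side effect the PR removes.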
```diff
@@ -737,17 +737,15 @@ Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar so
   }
 
   // Handle the case when `self` is 0-dim
-  if (0 == self.dim()) {
-    self.unsqueeze_(-1);
-  }
+  Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self;
 
-  dim = at::maybe_wrap_dim(dim, self);
+  dim = at::maybe_wrap_dim(dim, self_nonzero_dim);
   TORCH_CHECK(index.dim() <= 1, "Index has to be a vector/scalar");
 
   // Prepare `index` for TensorIterator.
   // It is restrided to be broadcastable over `self` in TensorIterator.
-  auto index_sizes = std::vector<int64_t>(self.dim(), 1);
-  auto index_strides = std::vector<int64_t>(self.dim(), 0);
+  auto index_sizes = std::vector<int64_t>(self_nonzero_dim.dim(), 1);
+  auto index_strides = std::vector<int64_t>(self_nonzero_dim.dim(), 0);
   index_sizes[dim] = index.numel();
   index_strides[dim] = (index.dim() > 0) ? index.stride(0) : 1; // `index` is 1d or scalar
   auto index_restrided = index.as_strided(
```
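For reference, the `index` bookkeeping kept by this hunk (now computed from `self_nonzero_dim`) can be mirrored with `as_strided` in Python. A sketch with made-up shapes:

```python
import torch

# Size 1 everywhere except `dim`, stride 0 everywhere except `dim`,
# so `index` broadcasts over `self` inside TensorIterator.
self_t = torch.zeros(4, 3)            # stand-in for `self`
index = torch.tensor([0, 2])
dim = 0

index_sizes = [1] * self_t.dim()
index_strides = [0] * self_t.dim()
index_sizes[dim] = index.numel()
index_strides[dim] = index.stride(0) if index.dim() > 0 else 1
index_restrided = index.as_strided(index_sizes, index_strides)
print(index_restrided.shape)          # torch.Size([2, 1])
```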
```diff
@@ -762,11 +760,11 @@ Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar so
   // match as required by TensorIterator (input shape should
   // strictly broadcast over output shape, i.e.
   // output.shape[i] >= input.shape[i] for i in range(dims)).
-  auto self_sizes = self.sizes().vec();
-  auto self_strides = self.strides().vec();
+  auto self_sizes = self_nonzero_dim.sizes().vec();
+  auto self_strides = self_nonzero_dim.strides().vec();
   self_sizes[dim] = index.numel();
   self_strides[dim] = 0;
-  auto self_restrided = self.as_strided(self_sizes, self_strides);
+  auto self_restrided = self_nonzero_dim.as_strided(self_sizes, self_strides);
 
   auto iter = TensorIteratorConfig()
     // We do not check for overlap because `self` is restrided
```
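The matching restride of `self` in the hunk above (size `index.numel()`, stride 0 along `dim`) can be sketched the same way; the shapes are again illustrative:

```python
import torch

# Along `dim` the size becomes index.numel() and the stride becomes 0,
# so every index entry maps onto the same base slice; the kernel then
# applies the real per-index offset itself.
self_t = torch.zeros(4, 3)
dim = 0
index_numel = 2

self_sizes = list(self_t.size())
self_strides = list(self_t.stride())
self_sizes[dim] = index_numel
self_strides[dim] = 0
self_restrided = self_t.as_strided(self_sizes, self_strides)
print(self_restrided.size(), self_restrided.stride())  # torch.Size([2, 3]) (0, 1)
```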
```diff
@@ -779,8 +777,8 @@ Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar so
     .add_input(index_restrided)
     .build();
 
-  auto self_dim_size = (self.sizes())[dim];
-  auto self_dim_stride = (self.strides())[dim];
+  auto self_dim_size = (self_nonzero_dim.sizes())[dim];
+  auto self_dim_stride = (self_nonzero_dim.strides())[dim];
   index_fill_stub(
     iter.device_type(),
     iter,
```
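`self_dim_size` and `self_dim_stride` are what `index_fill_stub` uses to turn each index value into a memory offset. A rough sketch of that arithmetic, assuming the usual bounds check and negative-index wrapping (the real kernel sits behind the stub and may differ in detail):

```python
def fill_offset(idx: int, self_dim_size: int, self_dim_stride: int) -> int:
    # Assumed semantics: reject out-of-range indices, wrap negatives,
    # then step to the target slice via the stride along `dim`.
    if not (-self_dim_size <= idx < self_dim_size):
        raise IndexError(f"index {idx} is out of bounds for size {self_dim_size}")
    if idx < 0:
        idx += self_dim_size
    return idx * self_dim_stride
```

The test hunk below adds the regression check for the 0-dim case.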
```diff
@@ -4449,6 +4449,11 @@ class TestTorchDeviceType(TestCase):
         if not x.is_complex():
             with self.assertRaisesRegex(RuntimeError, r"Scalar"):
                 x.index_fill_(1, index, 1 + 1j)
+        # Make sure that the result stays 0-dim while applied to
+        # a 0-dim input
+        x = torch.tensor(1, dtype=dtype, device=device)
+        self.assertEqual(0, x.index_fill(0, index, -1).dim())
+        self.assertEqual(0, x.index_fill_(0, index, -1).dim())
 
     def test_index_select(self, device):
         for dtype in [torch.int, torch.long]:
```
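The new assertions can be reproduced standalone on a build that includes this change, using the 0-dim index from the summary:

```python
import torch

x = torch.tensor(1)
idx = torch.tensor(0).long()
assert x.index_fill(0, idx, -1).dim() == 0   # out-of-place result stays 0-dim
assert x.index_fill_(0, idx, -1).dim() == 0  # in-place result stays 0-dim
```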