diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 4588a9ccb72..8744b4a8788 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -737,17 +737,15 @@ Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar so
   }
 
   // Handle the case when `self` is 0-dim
-  if (0 == self.dim()) {
-    self.unsqueeze_(-1);
-  }
+  Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self;
 
-  dim = at::maybe_wrap_dim(dim, self);
+  dim = at::maybe_wrap_dim(dim, self_nonzero_dim);
   TORCH_CHECK(index.dim() <= 1, "Index has to be a vector/scalar");
 
   // Prepare `index` for TensorIterator.
   // It is restrided to be broadcastable over `self` in TensorIterator.
-  auto index_sizes = std::vector<int64_t>(self.dim(), 1);
-  auto index_strides = std::vector<int64_t>(self.dim(), 0);
+  auto index_sizes = std::vector<int64_t>(self_nonzero_dim.dim(), 1);
+  auto index_strides = std::vector<int64_t>(self_nonzero_dim.dim(), 0);
   index_sizes[dim] = index.numel();
   index_strides[dim] = (index.dim() > 0) ? index.stride(0) : 1; // `index` is 1d or scalar
   auto index_restrided = index.as_strided(
@@ -762,11 +760,11 @@ Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar so
   // match as required by TensorIterator (input shape should
   // strictly broadcast over output shape, i.e.
   // output.shape[i] >= input.shape[i] for i in range(dims)).
-  auto self_sizes = self.sizes().vec();
-  auto self_strides = self.strides().vec();
+  auto self_sizes = self_nonzero_dim.sizes().vec();
+  auto self_strides = self_nonzero_dim.strides().vec();
   self_sizes[dim] = index.numel();
   self_strides[dim] = 0;
-  auto self_restrided = self.as_strided(self_sizes, self_strides);
+  auto self_restrided = self_nonzero_dim.as_strided(self_sizes, self_strides);
 
   auto iter = TensorIteratorConfig()
     // We do not check for overlap because `self` is restrided
@@ -779,8 +777,8 @@ Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar so
     .add_input(index_restrided)
     .build();
 
-  auto self_dim_size = (self.sizes())[dim];
-  auto self_dim_stride = (self.strides())[dim];
+  auto self_dim_size = (self_nonzero_dim.sizes())[dim];
+  auto self_dim_stride = (self_nonzero_dim.strides())[dim];
   index_fill_stub(
     iter.device_type(),
     iter,
diff --git a/test/test_torch.py b/test/test_torch.py
index 3e03f82c862..eaed87efb3c 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -4449,6 +4449,11 @@ class TestTorchDeviceType(TestCase):
         if not x.is_complex():
             with self.assertRaisesRegex(RuntimeError, r"Scalar"):
                 x.index_fill_(1, index, 1 + 1j)
+        # Make sure that the result stays 0-dim while applied to
+        # a 0-dim input
+        x = torch.tensor(1, dtype=dtype, device=device)
+        self.assertEqual(0, x.index_fill(0, index, -1).dim())
+        self.assertEqual(0, x.index_fill_(0, index, -1).dim())
 
     def test_index_select(self, device):
         for dtype in [torch.int, torch.long]: