Add broadcastable check to index_put (#94849)

The check is copied from
989299802c/aten/src/ATen/native/TensorAdvancedIndexing.cpp (L582-L583)

That check covers both the CPU and CUDA paths, except when the op is called on GPU with `deterministicAlgorithms()` set to true.
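For illustration, a rough Python sketch of how that GPU path is reached (my assumption here: on CUDA, `index_put_` with `accumulate=True`, or any index_put under `torch.use_deterministic_algorithms(True)`, dispatches to the sort-based kernel patched below):

```python
import torch

if torch.cuda.is_available():
    dst = torch.zeros(5, device="cuda")
    idx = torch.tensor([0, 1, 1], device="cuda")
    vals = torch.ones(3, device="cuda")
    # accumulate=True on CUDA takes the sort-based index_put kernel,
    # which is the path that gains the shape check in this PR
    dst.index_put_((idx,), vals, accumulate=True)
    print(dst)  # tensor([1., 2., 0., 0., 0.], device='cuda:0')
```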

Follow-up: do the same for XLA and fix the case when the indices are not null.

Fixes https://github.com/pytorch/pytorch/issues/94667
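A minimal repro, mirroring the regression test added below (with empty indices, the value tensor must be broadcastable to `self`'s shape, so the call now raises a RuntimeError instead of reaching the kernel with mismatched shapes):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
inp = torch.rand([], dtype=torch.float32, device=device)  # 0-dim tensor, shape []
try:
    # a value of shape [1] cannot be broadcast to the indexing result of shape []
    inp.index_put([], torch.tensor([1.0], device=device), accumulate=True)
except RuntimeError as err:
    print(err)
```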

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94849
Approved by: https://github.com/ngimel
Nikita Shulga 2023-02-17 20:37:20 +00:00 committed by PyTorch MergeBot
parent e0ede1cc30
commit d5d55363d9
2 changed files with 13 additions and 1 deletion


@@ -248,6 +248,9 @@ static std::vector<int64_t> computeLinearStride(const Tensor & tensor) {
   // computes the stride as if tensor were contiguous
   auto sizes = tensor.sizes();
   std::vector<int64_t> stride(tensor.dim());
+  if (stride.empty()) {
+    return stride;
+  }
   stride[tensor.dim() - 1] = 1;
   std::partial_sum(sizes.rbegin(), sizes.rend() - 1, stride.rbegin() + 1, std::multiplies<int64_t>());
   return stride;
@@ -331,6 +334,8 @@ int64_t largestIndex(const Tensor &self) {
 }
 
 void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
+  TORCH_CHECK(!indices.empty() || is_expandable_to(value.sizes(), self.sizes()), "shape mismatch: value tensor of shape ", value.sizes(),
+              " cannot be broadcast to indexing result of shape ", self.sizes());
   if (indices.size() > (size_t)self.dim()) {
     TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   }


@@ -15,7 +15,7 @@ from torch.testing._internal.common_utils import (
     TestCase, run_tests, TEST_WITH_TORCHDYNAMO)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA,
-    onlyNativeDeviceTypes)
+    onlyNativeDeviceTypes, skipXLA)
 
 
 class TestIndexing(TestCase):
@@ -911,6 +911,13 @@ class TestIndexing(TestCase):
             torch.index_put_(inp_res, (ind_int, ind_int), src, accum)
             self.assertEqual(inp_ref, inp_res)
 
+    @skipXLA
+    def test_index_put_accumulate_empty(self, device):
+        # Regression test for https://github.com/pytorch/pytorch/issues/94667
+        input = torch.rand([], dtype=torch.float32, device=device)
+        with self.assertRaises(RuntimeError):
+            input.index_put([], torch.tensor([1.0], device=device), True)
+
     def test_multiple_byte_mask(self, device):
         v = torch.randn(5, 7, 3, device=device)
         # note: these broadcast together and are transposed to the first dim