Add broadcastable check to index_put (#94849)
Copied from
989299802c/aten/src/ATen/native/TensorAdvancedIndexing.cpp (L582-L583),
which is used for both the CPU and CUDA checks, unless the op is called on GPU with `deterministicAlgorithms()` set to true.
Follow-up: do the same for XLA and fix the case when indices are not null.
Fixes https://github.com/pytorch/pytorch/issues/94667
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94849
Approved by: https://github.com/ngimel
parent e0ede1cc30
commit d5d55363d9
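For context, a minimal Python sketch of the failing pattern from the issue, mirroring the regression test added below (the CPU fallback here is an assumption; the original report hit the CUDA accumulate path):

import torch

# Sketch of the issue-94667 pattern (assumption: falls back to CPU when CUDA
# is unavailable; the original bug was hit on the CUDA accumulate path).
device = "cuda" if torch.cuda.is_available() else "cpu"
inp = torch.rand([], dtype=torch.float32, device=device)  # 0-dim (scalar) tensor

# A value of shape [1] cannot be broadcast to the 0-dim indexing result, so
# index_put with empty indices and accumulate=True is expected to raise.
try:
    inp.index_put([], torch.tensor([1.0], device=device), True)
except RuntimeError as err:
    print("raised as expected:", err)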
@@ -248,6 +248,9 @@ static std::vector<int64_t> computeLinearStride(const Tensor & tensor) {
   // computes the stride as if tensor were contiguous
   auto sizes = tensor.sizes();
   std::vector<int64_t> stride(tensor.dim());
+  if (stride.empty()) {
+    return stride;
+  }
   stride[tensor.dim() - 1] = 1;
   std::partial_sum(sizes.rbegin(), sizes.rend() - 1, stride.rbegin() + 1, std::multiplies<int64_t>());
   return stride;
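For illustration, a rough Python equivalent of what computeLinearStride computes (not the actual PyTorch helper, just a sketch of why the new empty-stride guard matters for 0-dim tensors):

# Rough equivalent of computeLinearStride: row-major strides as if the
# tensor were contiguous (illustration only, not PyTorch code).
def linear_stride(sizes):
    stride = [0] * len(sizes)
    # The new guard: a 0-dim (scalar) tensor has no strides, and without the
    # early return the stride[-1] assignment below would touch an empty
    # container (an out-of-bounds write on the C++ side).
    if not stride:
        return stride
    stride[-1] = 1
    # Mirrors std::partial_sum over the reversed size/stride ranges.
    for i in range(len(sizes) - 2, -1, -1):
        stride[i] = stride[i + 1] * sizes[i + 1]
    return stride

assert linear_stride([2, 3, 4]) == [12, 4, 1]
assert linear_stride([]) == []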
@@ -331,6 +334,8 @@ int64_t largestIndex(const Tensor &self) {
 }
 
 void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
+  TORCH_CHECK(!indices.empty() || is_expandable_to(value.sizes(), self.sizes()), "shape mismatch: value tensor of shape ", value.sizes(),
+              " cannot be broadcast to indexing result of shape ", self.sizes());
   if (indices.size() > (size_t)self.dim()) {
     TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   }
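The added TORCH_CHECK can only fail when the indices list is empty; in that case the value tensor must be broadcastable to self's shape. A small sketch of the broadcastability rule behind is_expandable_to (an approximation for illustration, not the actual ATen implementation):

# Sketch of the rule behind is_expandable_to(value.sizes(), self.sizes()):
# `shape` is expandable to `desired` if it has no more dims than `desired`
# and, aligning from the trailing dimension, each dim matches or is 1.
def is_expandable_to(shape, desired):
    if len(shape) > len(desired):
        return False
    for s, d in zip(reversed(shape), reversed(desired)):
        if s != d and s != 1:
            return False
    return True

assert is_expandable_to((3, 1), (2, 3, 4))    # broadcasts fine
assert not is_expandable_to((1,), ())         # the issue-94667 case: shape [1] vs 0-dim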
@@ -15,7 +15,7 @@ from torch.testing._internal.common_utils import (
     TestCase, run_tests, TEST_WITH_TORCHDYNAMO)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA,
-    onlyNativeDeviceTypes)
+    onlyNativeDeviceTypes, skipXLA)
 
 
 class TestIndexing(TestCase):
@@ -911,6 +911,13 @@ class TestIndexing(TestCase):
             torch.index_put_(inp_res, (ind_int, ind_int), src, accum)
             self.assertEqual(inp_ref, inp_res)
 
+    @skipXLA
+    def test_index_put_accumulate_empty(self, device):
+        # Regression test for https://github.com/pytorch/pytorch/issues/94667
+        input = torch.rand([], dtype=torch.float32, device=device)
+        with self.assertRaises(RuntimeError):
+            input.index_put([], torch.tensor([1.0], device=device), True)
+
     def test_multiple_byte_mask(self, device):
         v = torch.randn(5, 7, 3, device=device)
         # note: these broadcast together and are transposed to the first dim