add nominal support for int32 indices in index/index_put ops (#86309)

Currently the `index_select`/`index_add` decompositions lower to the `index` and `index_put` ops. The problem is that `index_select` and `index_add` accept int32 indices while `index` doesn't, which leads to errors in the meta functions for those decompositions. This PR adds non-performant support for int32 indices to the `index` operations (the indices are simply cast to int64), so that the decompositions can go through.
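
For illustration, a minimal sketch of the mismatch (ad-hoc shapes, not taken from the PR):

```python
import torch

x = torch.randn(4, 4)
ind_int = torch.tensor([0, 2], dtype=torch.int32)

# index_select accepts int32 indices directly:
torch.index_select(x, 0, ind_int)

# ...but its decomposition lowers to aten::index, which previously raised
#   IndexError: tensors used as indices must be long, byte or bool tensors
# With this PR, int32 indices are accepted and cast to int64 internally:
x[ind_int]
```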

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86309
Approved by: https://github.com/lezcano
Author: Natalia Gimelshein
Date: 2022-10-05 23:59:16 +00:00
Committed by: PyTorch MergeBot
commit dc9c507d24 (parent e8b0bea677)
4 changed files with 47 additions and 5 deletions


@@ -48,12 +48,18 @@ static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTensorListRef
   return result;
 }
-static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices) {
+static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
   for (const auto& tensor : indices) {
     if (tensor.has_value() && tensor->defined()) {
       auto scalarType = tensor->scalar_type();
-      if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
-        TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
+      if (allow_int) {
+        if (scalarType != kLong && scalarType != kByte && scalarType != kBool && scalarType != kInt) {
+          TORCH_CHECK_INDEX(false, "tensors used as indices must be long, int, byte or bool tensors");
+        }
+      } else {
+        if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
+          TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
+        }
       }
     }
   }
 }


@@ -57,7 +57,7 @@ const Tensor& value){
 }
 static AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
-  checkIndexTensorTypes(orig);
+  checkIndexTensorTypes(orig, /*allow_int*/ true);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
   auto indices = expandTensors(self, orig);
   // next broadcast all index tensors together
@@ -82,6 +82,12 @@ static AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
       indice = indice.to(self.device());
     }
   }
+  for (auto & indice : indices) {
+    if (indice.defined() && indice.dtype() == at::kInt) {
+      indice = indice.to(at::kLong);
+    }
+  }
   return AdvancedIndex(self, indices);
 }
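
The net effect of the extra pass in `make_info` (an illustrative sketch mirroring the new test below, not part of the diff): int32 and int64 index tensors now take the same advanced-indexing path.

```python
import torch

x = torch.randn(4, 4)
ind_long = torch.randint(4, (4,), dtype=torch.long)
ind_int = ind_long.int()

# make_info casts defined int32 index tensors to long, so both calls
# produce identical results:
assert torch.equal(x[ind_long, ind_long], x[ind_int, ind_int])
```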


@@ -291,9 +291,14 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
 static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
-  checkIndexTensorTypes(orig);
+  checkIndexTensorTypes(orig, /*allow_int*/true);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
   auto indices = expandTensors(self, orig);
+  for (auto & i : indices) {
+    if (i.defined() && i.dtype() == at::kInt) {
+      i = i.to(at::kLong);
+    }
+  }
   // next broadcast all index tensors together
   indices = expand_outplace(indices);
   // add missing null Tensors so that it matches self.dim()
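
`makeLinearIndex` gets the same check-and-cast treatment, so the `index_put_` paths that go through it also take int32 indices now. A usage sketch (assumed shapes; the exact dispatch path is an assumption, the behavior matches the new test):

```python
import torch

x = torch.zeros(4, 4)
src = torch.randn(4)
ind_int = torch.arange(4, dtype=torch.int32)

# index_put_ with int32 index tensors, both accumulate modes:
x.index_put_((ind_int, ind_int), src, accumulate=False)
x.index_put_((ind_int, ind_int), src, accumulate=True)
```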


@@ -881,6 +881,31 @@ class TestIndexing(TestCase):
         self.assertEqual(output, input_list)
 
+    @onlyNativeDeviceTypes
+    def test_index_ind_dtype(self, device):
+        x = torch.randn(4, 4, device=device)
+        ind_long = torch.randint(4, (4,), dtype=torch.long, device=device)
+        ind_int = ind_long.int()
+        src = torch.randn(4, device=device)
+        ref = x[ind_long, ind_long]
+        res = x[ind_int, ind_int]
+        self.assertEqual(ref, res)
+        ref = x[ind_long, :]
+        res = x[ind_int, :]
+        self.assertEqual(ref, res)
+        ref = x[:, ind_long]
+        res = x[:, ind_int]
+        self.assertEqual(ref, res)
+        # no repeating indices for index_put
+        ind_long = torch.arange(4, dtype=torch.long, device=device)
+        ind_int = ind_long.int()
+        for accum in (True, False):
+            inp_ref = x.clone()
+            inp_res = x.clone()
+            torch.index_put_(inp_ref, (ind_long, ind_long), src, accum)
+            torch.index_put_(inp_res, (ind_int, ind_int), src, accum)
+            self.assertEqual(inp_ref, inp_res)
+
     def test_multiple_byte_mask(self, device):
         v = torch.randn(5, 7, 3, device=device)
         # note: these broadcast together and are transposed to the first dim