mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
add nominal support for int32 indices in index/index_put ops (#86309)
Currently, the index_select/index_add decompositions lower to the `index` or `index_put` ops. The problem is that `index_select` and `index_add` accept int32 indices while `index` does not, which leads to an error in the meta function for those decompositions. This PR adds non-performant support for int32 indices to the `index` operations so that the decompositions can go through. Pull Request resolved: https://github.com/pytorch/pytorch/pull/86309 Approved by: https://github.com/lezcano
This commit is contained in:
parent
e8b0bea677
commit
dc9c507d24
|
|
@ -48,12 +48,18 @@ static C10_UNUSED std::vector<Tensor> expandTensors(const Tensor & self, IOptTen
|
|||
return result;
|
||||
}
|
||||
|
||||
// Validates that every defined index tensor in `indices` has an allowed dtype.
//
// Historically, `index` / `index_put` accepted only long, byte (deprecated
// mask) and bool index tensors. When `allow_int` is true, kInt is also
// accepted so that decompositions of index_select/index_add (which take
// int32 indices) can route through these ops; callers cast int indices to
// long afterwards.
//
// Throws an index error (via TORCH_CHECK_INDEX) on the first tensor with a
// disallowed dtype.
static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) {
  for (const auto& tensor : indices) {
    if (tensor.has_value() && tensor->defined()) {
      const auto scalarType = tensor->scalar_type();
      // Single combined condition instead of duplicating the check and the
      // error message in separate allow_int/!allow_int branches.
      const bool ok = scalarType == kLong || scalarType == kByte ||
          scalarType == kBool || (allow_int && scalarType == kInt);
      TORCH_CHECK_INDEX(ok, "tensors used as indices must be long, byte or bool tensors");
    }
  }
}
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ const Tensor& value){
|
|||
}
|
||||
|
||||
static AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
|
||||
checkIndexTensorTypes(orig);
|
||||
checkIndexTensorTypes(orig, /*allow_int*/ true);
|
||||
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
|
||||
auto indices = expandTensors(self, orig);
|
||||
// next broadcast all index tensors together
|
||||
|
|
@ -82,6 +82,12 @@ static AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
|
|||
indice = indice.to(self.device());
|
||||
}
|
||||
}
|
||||
for (auto & indice : indices) {
|
||||
if (indice.defined() && indice.dtype() == at::kInt) {
|
||||
indice = indice.to(at::kLong);
|
||||
}
|
||||
}
|
||||
|
||||
return AdvancedIndex(self, indices);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -291,9 +291,14 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
|
|||
|
||||
|
||||
static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
|
||||
checkIndexTensorTypes(orig);
|
||||
checkIndexTensorTypes(orig, /*allow_int*/true);
|
||||
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
|
||||
auto indices = expandTensors(self, orig);
|
||||
for (auto & i : indices) {
|
||||
if (i.defined() && i.dtype() == at::kInt) {
|
||||
i = i.to(at::kLong);
|
||||
}
|
||||
}
|
||||
// next broadcast all index tensors together
|
||||
indices = expand_outplace(indices);
|
||||
// add missing null Tensors so that it matches self.dim()
|
||||
|
|
|
|||
|
|
@ -881,6 +881,31 @@ class TestIndexing(TestCase):
|
|||
|
||||
self.assertEqual(output, input_list)
|
||||
|
||||
@onlyNativeDeviceTypes
def test_index_ind_dtype(self, device):
    # Advanced indexing must produce identical results whether the index
    # tensors are int32 or int64.
    x = torch.randn(4, 4, device=device)
    ind_long = torch.randint(4, (4,), dtype=torch.long, device=device)
    ind_int = ind_long.int()
    src = torch.randn(4, device=device)
    # Compare long vs. int indexing across several key shapes:
    # both dims indexed, only the first dim, only the second dim.
    index_pairs = (
        ((ind_long, ind_long), (ind_int, ind_int)),
        ((ind_long, slice(None)), (ind_int, slice(None))),
        ((slice(None), ind_long), (slice(None), ind_int)),
    )
    for key_long, key_int in index_pairs:
        self.assertEqual(x[key_long], x[key_int])
    # For index_put_ use non-repeating indices (arange) so the result does
    # not depend on write order when accumulate=False.
    ind_long = torch.arange(4, dtype=torch.long, device=device)
    ind_int = ind_long.int()
    for accum in (True, False):
        inp_ref = x.clone()
        inp_res = x.clone()
        torch.index_put_(inp_ref, (ind_long, ind_long), src, accum)
        torch.index_put_(inp_res, (ind_int, ind_int), src, accum)
        self.assertEqual(inp_ref, inp_res)
|
||||
|
||||
def test_multiple_byte_mask(self, device):
|
||||
v = torch.randn(5, 7, 3, device=device)
|
||||
# note: these broadcast together and are transposed to the first dim
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user