Faster index_select for COO tensors

Nikita Vedeneev 2022-02-11 05:21:03 +00:00
parent 987f146185
commit 90fceb015d
2 changed files with 242 additions and 2 deletions

@@ -1390,6 +1390,247 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index)
}
}
Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& index) {
  /*
    Algorithm:
    index - a 1-D tensor of indices with shape (n,)
    self - a sparse tensor whose shape is sizes = sparse_shape + dense_shape
      indices - a 2-D tensor of indices with shape (sparse_dims, nnz)
      values - a (1 + len(dense_shape))-D tensor of values with shape (nnz,) + dense_shape

    index_select(dim, index) returns a sparse tensor with the following data
      new_sizes = sizes[:dim] + (n,) + sizes[dim+1:]
      new_indices - shape is (sparse_dims, new_nnz)
      new_values - shape is (new_nnz,) + dense_shape

    if dim < len(sparse_shape):
      for i, idx in enumerate(index):
        for j, jdx in enumerate(indices[dim]):
          if idx == jdx:
            icol = indices[:dim, j] + (i,) + indices[dim+1:, j]
            new_indices.add_column(icol)
            new_values.add_row(values[j])
    else:
      new_indices = indices
      new_values = values.index_select(dim - sparse_dim + 1, index)
  */
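  // A worked illustration of the sparse branch above (hypothetical data, not
  // part of the original commit): let self be a (3, 4) sparse tensor with
  //   indices = [[0, 1, 1],
  //              [2, 0, 3]]  and  values = [a, b, c],
  // and let index = [1, 1, 2]. Then self.index_select(0, index) picks row 1
  // twice (entries b and c) and row 2 once (no entries), so, up to the order
  // in which the kernel emits entries,
  //   new_indices = [[0, 0, 1, 1],
  //                  [0, 3, 0, 3]]  and  new_values = [b, c, b, c],
  // with new_sizes = (3, 4).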
  const auto ndim = self.dim();
  if (ndim == 0) {
    TORCH_CHECK_INDEX(false, "index_select() cannot be applied to a 0-dim tensor.");
  }
  if (!(index.dim() == 1 && index.dtype() == at::kLong)) {
    TORCH_CHECK_INDEX(false, "index_select() argument index must be 1-D long-tensor.");
  }
  dim = maybe_wrap_dim(dim, ndim);
  const auto size = self.size(dim);
  const auto sparse_dim = self.sparse_dim();
  const auto dense_dim = self.dense_dim();
  const auto indices = self._indices();
  const auto values = self._values();
  const auto nnz = values.size(0);
  const auto index_len = index.size(0);
  auto res_sizes = self.sizes().vec();
  res_sizes[dim] = index_len;

  // TODO: decide on optimal grain size for sparse tensors.
  const auto grain_size = at::internal::GRAIN_SIZE;
  // 1 <= n_threads_nnz <= min(ceil(nnz / grain_size), get_num_threads())
  const auto n_threads_nnz = std::max<int64_t>(
      1,
      std::min<int64_t>((nnz + grain_size - 1) / grain_size, at::get_num_threads())
  );

  // If indexing into sparse dimensions
  if (dim < sparse_dim) {
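    // Normalize the selection index: every entry must lie in [-size, size),
    // and negative entries are wrapped into [0, size).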
    auto nneg_index = at::empty_like(index);
    {
      auto* ptr_index = index.data_ptr<int64_t>();
      auto* ptr_index_end = ptr_index + index_len;
      auto* ptr_nneg_index = nneg_index.data_ptr<int64_t>();
      while (ptr_index != ptr_index_end) {
        auto idx = *ptr_index++;
        if (idx < -size || idx >= size) {
          TORCH_CHECK_INDEX(false,
              "index_select(): index contains ", idx, " that is out of range for tensor of size ",
              self.sizes(), " at dimension ", dim
          );
        }
        if (idx < 0) {
          idx += size;
        }
        *ptr_nneg_index++ = idx;
      }
    }

    // Much faster than at::unique_dim
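    // Returns (sort permutation, unique values, counts, number of uniques):
    // the input is sorted, duplicates are compacted in place towards the
    // front of the sorted tensor, and only the first `len_unique` slots of
    // the unique/count tensors are meaningful.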
    const auto unique_with_counts = [](
        const Tensor& t, int64_t len
    ) -> std::tuple<Tensor, Tensor, Tensor, int64_t> {
      Tensor t_unique, t_idx;
      std::tie(t_unique, t_idx) = at::sort(t);
      auto t_counts = at::ones_like(t_unique);
      int64_t len_unique = 0;
      // NOTE: guard against empty inputs (nnz == 0 or an empty index);
      // the two-pointer pass below assumes at least one element.
      if (len > 0) {
        auto* ptr_counts = t_counts.data_ptr<int64_t>();
        auto* first = t_unique.data_ptr<int64_t>();
        auto* last = first + len;
        // Two-pointer, in-place compaction over the sorted data: `curr`
        // points at the last unique value written so far, `counts` at its
        // running multiplicity.
        auto* curr = first;
        auto* counts = ptr_counts;
        while (++first != last) {
          if (*curr != *first) {
            ++counts;
            if (++curr != first) {
              // std::swap(*curr, *first);
              *curr = *first;
            }
          }
          else {
            ++(*counts);
          }
        }
        len_unique = counts - ptr_counts + 1;
      }
      return std::make_tuple(t_idx, t_unique, t_counts, len_unique);
    };

    Tensor dim_sort_indices, dim_indices_unique, dim_indices_counts;
    int64_t n_unique_dim_indices;
    std::tie(dim_sort_indices, dim_indices_unique, dim_indices_counts, n_unique_dim_indices)
      = unique_with_counts(indices[dim].contiguous(), nnz);
    Tensor sel_sort_indices, sel_indices_unique, sel_indices_counts;
    int64_t n_unique_sel_indices;
    std::tie(sel_sort_indices, sel_indices_unique, sel_indices_counts, n_unique_sel_indices)
      = unique_with_counts(nneg_index, index_len);
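    // Intersect two sorted arrays of unique values (with their counts): for
    // every value present in both, record its position in each array. Also
    // accumulates the output nnz, since a value occurring count1 times among
    // the dim indices and count2 times in `index` contributes count1 * count2
    // entries to the result.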
    const auto compute_index_intersections_and_nnz = [](
        const Tensor& t1, const Tensor& c1, int64_t l1,
        const Tensor& t2, const Tensor& c2, int64_t l2
    ) -> std::tuple<Tensor, Tensor, int64_t, int64_t> {
      const auto lmin = std::min(l1, l2);
      auto t1_idx = at::empty({lmin}, t1.options());
      auto t2_idx = at::empty({lmin}, t2.options());
      int64_t nnz = 0;
      int64_t n_intersect = 0;
      auto* ptr_t1_idx = t1_idx.data_ptr<int64_t>();
      auto* ptr_t2_idx = t2_idx.data_ptr<int64_t>();
      auto* ptr_t1 = t1.data_ptr<int64_t>();
      auto* ptr_c1 = c1.data_ptr<int64_t>();
      auto* ptr_t2 = t2.data_ptr<int64_t>();
      auto* ptr_c2 = c2.data_ptr<int64_t>();
      // we assume search in t2
      auto* first = ptr_t2;
      auto* last = ptr_t2 + l2;
      for (const auto i : c10::irange(l1)) {
        const auto idx = ptr_t1[i];
        const auto idx_pos = std::lower_bound(first, last, idx);
        if (idx_pos != last && *idx_pos == idx) {
          const auto j = idx_pos - first;
          const auto count1 = ptr_c1[i];
          const auto count2 = ptr_c2[j];
          *ptr_t1_idx++ = i;
          *ptr_t2_idx++ = j;
          ++n_intersect;
          nnz += count1 * count2;
        }
      }
      return std::make_tuple(t1_idx, t2_idx, n_intersect, nnz);
    };
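    // Iterating over the first array while binary-searching in the second
    // costs l1 * log2(l2) comparisons, so pick the direction with the
    // smaller cost.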
    const auto search_in_index = (
        n_unique_dim_indices * std::log2(n_unique_sel_indices)
          < std::log2(n_unique_dim_indices) * n_unique_sel_indices
    );
    Tensor dim_indices_intersect, index_intersect;
    int64_t n_intersect, res_nnz;
    if (search_in_index) {
      std::tie(dim_indices_intersect, index_intersect, n_intersect, res_nnz)
        = compute_index_intersections_and_nnz(
            dim_indices_unique, dim_indices_counts, n_unique_dim_indices,
            sel_indices_unique, sel_indices_counts, n_unique_sel_indices
        );
    }
    else {
      std::tie(index_intersect, dim_indices_intersect, n_intersect, res_nnz)
        = compute_index_intersections_and_nnz(
            sel_indices_unique, sel_indices_counts, n_unique_sel_indices,
            dim_indices_unique, dim_indices_counts, n_unique_dim_indices
        );
    }
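    // Exclusive prefix sum of the counts: offsets[i] is where the i-th
    // unique value's group starts within the sort permutation.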
    const auto compute_offsets = [](const Tensor& counts, int64_t len) {
      const auto narrowed_counts = counts.narrow(-1, 0, len);
      return narrowed_counts.cumsum(/*dim=*/0).sub_(narrowed_counts);
    };
    const auto dim_sort_indices_offsets = compute_offsets(dim_indices_counts, n_unique_dim_indices);
    const auto sel_sort_indices_offsets = compute_offsets(sel_indices_counts, n_unique_sel_indices);

    auto selected_dim_indices = at::empty({res_nnz}, dim_indices_unique.options());
    auto res_dim_indices = at::empty({res_nnz}, sel_indices_unique.options());
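    // For every value shared between indices[dim] and `index`, emit the cross
    // product of (nnz positions that hold the value) x (positions of the
    // value in `index`): selected_dim_indices says which original entries to
    // copy, res_dim_indices supplies their new coordinate along `dim`.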
    {
      auto* ptr_selected_dim_indices = selected_dim_indices.data_ptr<int64_t>();
      auto* ptr_res_dim_indices = res_dim_indices.data_ptr<int64_t>();
      const auto* ptr_dim_indices_intersect = dim_indices_intersect.data_ptr<int64_t>();
      const auto* ptr_sel_indices_intersect = index_intersect.data_ptr<int64_t>();
      const auto* ptr_dim_indices_counts = dim_indices_counts.data_ptr<int64_t>();
      const auto* ptr_sel_indices_counts = sel_indices_counts.data_ptr<int64_t>();
      const auto* ptr_dim_sort_indices = dim_sort_indices.data_ptr<int64_t>();
      const auto* ptr_sel_sort_indices = sel_sort_indices.data_ptr<int64_t>();
      const auto* ptr_dim_indices_offsets = dim_sort_indices_offsets.data_ptr<int64_t>();
      const auto* ptr_sel_indices_offsets = sel_sort_indices_offsets.data_ptr<int64_t>();
      for (const auto ii : c10::irange(n_intersect)) {
        const auto i_idx = *ptr_dim_indices_intersect++;
        const auto j_idx = *ptr_sel_indices_intersect++;
        const auto i_idx_count = ptr_dim_indices_counts[i_idx];
        const auto j_idx_count = ptr_sel_indices_counts[j_idx];
        const auto i_idx_offset = ptr_dim_indices_offsets[i_idx];
        const auto j_idx_offset = ptr_sel_indices_offsets[j_idx];
        const auto* src_dim_sort_indices = ptr_dim_sort_indices + i_idx_offset;
        const auto* src_sel_sort_indices = ptr_sel_sort_indices + j_idx_offset;
        const auto copy_chunk_len = i_idx_count * j_idx_count;
        // NOTE: enumerate the full cross product with div/mod; using mod on
        // both sides would skip pairs whenever the two counts share a factor.
        for (const auto chunk_elem_idx : c10::irange(copy_chunk_len)) {
          *ptr_selected_dim_indices++ = *(src_dim_sort_indices + (chunk_elem_idx / j_idx_count));
          *ptr_res_dim_indices++ = *(src_sel_sort_indices + (chunk_elem_idx % j_idx_count));
        }
      }
    }
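    // Assemble the result: gather the selected columns of `indices` and the
    // matching slices of `values`, then overwrite row `dim` with the new
    // coordinates computed above.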
    auto res_indices = indices.index_select(1, selected_dim_indices);
    res_indices[dim] = res_dim_indices;
    const auto res_values = values.index_select(0, selected_dim_indices);

    return _sparse_coo_tensor_with_dims_and_tensors(
        sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
  }
  // If indexing into dense dimensions
  else {
    // It is sufficient to just perform `index_select` on values
    // if `dim` refers to dense dimensions.
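    // values has shape (nnz,) + dense_shape, so dense dimension `dim` of
    // self maps to dimension dim - sparse_dim + 1 of values (dim 0 is nnz).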
    const auto res_values = values.index_select(dim - sparse_dim + 1, index);
    return _sparse_coo_tensor_with_dims_and_tensors(
        sparse_dim, dense_dim, res_sizes, indices, res_values, self.options());
  }
}
Tensor slice(
    const Tensor& self,
    int64_t dim,

@@ -6864,8 +6864,7 @@
    QuantizedCPU: index_select_quantized_cpu_
    CUDA: index_select_cuda
    QuantizedCUDA: index_select_quantized_cuda
-   SparseCPU: index_select_sparse
-   SparseCUDA: index_select_sparse
+   SparseCPU: index_select_sparse_cpu

- func: index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)