mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Faster index_select for COO tensors
This commit is contained in:
parent
987f146185
commit
90fceb015d
|
|
@ -1390,6 +1390,247 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& index) {
|
||||||
|
/*
|
||||||
|
Algorithm:
|
||||||
|
index - a 1-D tensor of indicies with shape (n,)
|
||||||
|
self - sparse tensor, its shape is sizes = sparse_shape + dense_shape
|
||||||
|
indices - 2-D tensor of indices, shape is (sparse_dims, nnz)
|
||||||
|
values - (1+len(dense_shape))-D tensor of values, shape is (nnz,) + dense_shape
|
||||||
|
index_select(dim, index) returns a sparse tensor with the following data
|
||||||
|
new_sizes = sizes[:dim] + (n,) + sizes[dim+1:]
|
||||||
|
new_indices - shape is (sparse_dims, new_nnz)
|
||||||
|
new_values - shape is (new_nnz,) + dense_shape
|
||||||
|
|
||||||
|
if dim < len(sparse_shape):
|
||||||
|
for i, idx in enumerate(index):
|
||||||
|
for j, jdx in enumerate(indices[dim]):
|
||||||
|
if idx == jdx:
|
||||||
|
icol = indices[:dim][j] + (i,) + indices[dim+1:][j]
|
||||||
|
new_indices.add_column(icol)
|
||||||
|
new_values.add_row(values[j])
|
||||||
|
else:
|
||||||
|
new_indices = indices
|
||||||
|
new_values = values.index_select(dim - sparse_dim + 1, index);
|
||||||
|
*/
|
||||||
|
const auto ndim = self.dim();
|
||||||
|
if (ndim == 0) {
|
||||||
|
TORCH_CHECK_INDEX(false, "index_select() cannot be applied to a 0-dim tensor.");
|
||||||
|
}
|
||||||
|
if (!(index.dim() == 1 && index.dtype() == at::kLong)) {
|
||||||
|
TORCH_CHECK_INDEX(false, "index_select() argument index must be 1-D long-tensor.");
|
||||||
|
}
|
||||||
|
dim = maybe_wrap_dim(dim, ndim);
|
||||||
|
const auto size = self.size(dim);
|
||||||
|
const auto sparse_dim = self.sparse_dim();
|
||||||
|
const auto dense_dim = self.dense_dim();
|
||||||
|
const auto indices = self._indices();
|
||||||
|
const auto values = self._values();
|
||||||
|
const auto nnz = values.size(0);
|
||||||
|
const auto index_len = index.size(0);
|
||||||
|
auto res_sizes = self.sizes().vec();
|
||||||
|
res_sizes[dim] = index_len;
|
||||||
|
|
||||||
|
// TODO: decide on optimal grain size for sparse tensors.
|
||||||
|
const auto grain_size = at::internal::GRAIN_SIZE;
|
||||||
|
|
||||||
|
// 1 <= n_threads_nnz <= min(ceil(nnz / grain_size), get_num_threads())
|
||||||
|
const auto n_threads_nnz = std::max<int64_t>(
|
||||||
|
1,
|
||||||
|
std::min<int64_t>((nnz + grain_size - 1) / grain_size, at::get_num_threads())
|
||||||
|
);
|
||||||
|
|
||||||
|
// If indexing into sparse dimensions
|
||||||
|
if (dim < sparse_dim) {
|
||||||
|
auto nneg_index = at::empty_like(index);
|
||||||
|
{
|
||||||
|
auto* ptr_index = index.data_ptr<int64_t>();
|
||||||
|
auto* ptr_index_end = ptr_index + index_len;
|
||||||
|
auto* ptr_nneg_index = nneg_index.data_ptr<int64_t>();
|
||||||
|
while (ptr_index != ptr_index_end) {
|
||||||
|
auto idx = *ptr_index++;
|
||||||
|
if (idx < -size || idx >= size) {
|
||||||
|
TORCH_CHECK_INDEX(false,
|
||||||
|
"index_select(): index contains ", idx, " that is out of range for tensor of size ",
|
||||||
|
self.sizes(), " at dimension ", dim
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (idx < 0) {
|
||||||
|
idx += size;
|
||||||
|
}
|
||||||
|
*ptr_nneg_index++ = idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Much faster than at::unique_dim
|
||||||
|
const auto unique_with_counts = [](
|
||||||
|
const Tensor& t, int64_t len
|
||||||
|
) -> std::tuple<Tensor, Tensor, Tensor, int64_t> {
|
||||||
|
Tensor t_unique, t_idx;
|
||||||
|
std::tie(t_unique, t_idx) = at::sort(t);
|
||||||
|
|
||||||
|
auto t_counts = at::ones_like(t_unique);
|
||||||
|
int64_t len_unique;
|
||||||
|
{
|
||||||
|
auto* ptr_counts = t_counts.data_ptr<int64_t>();
|
||||||
|
auto* first = t_unique.data_ptr<int64_t>();
|
||||||
|
auto* last = first + len;
|
||||||
|
auto* curr = first;
|
||||||
|
auto* counts = ptr_counts;
|
||||||
|
while (++first != last) {
|
||||||
|
if (*curr != *first) {
|
||||||
|
++counts;
|
||||||
|
if (++curr != first) {
|
||||||
|
// std::swap(*curr, *first);
|
||||||
|
*curr = *first;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
++(*counts);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
len_unique = counts - ptr_counts + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_tuple(t_idx, t_unique, t_counts, len_unique);
|
||||||
|
};
|
||||||
|
|
||||||
|
Tensor dim_sort_indices, dim_indices_unique, dim_indices_counts;
|
||||||
|
int64_t n_unique_dim_indices;
|
||||||
|
std::tie(dim_sort_indices, dim_indices_unique, dim_indices_counts, n_unique_dim_indices)
|
||||||
|
= unique_with_counts(indices[dim].contiguous(), nnz);
|
||||||
|
|
||||||
|
Tensor sel_sort_indices, sel_indices_unique, sel_indices_counts;
|
||||||
|
int64_t n_unique_sel_indices;
|
||||||
|
std::tie(sel_sort_indices, sel_indices_unique, sel_indices_counts, n_unique_sel_indices)
|
||||||
|
= unique_with_counts(nneg_index, index_len);
|
||||||
|
|
||||||
|
const auto compute_index_intersections_and_nnz = [](
|
||||||
|
const Tensor& t1, const Tensor& c1, int64_t l1,
|
||||||
|
const Tensor& t2, const Tensor& c2, int64_t l2
|
||||||
|
) -> std::tuple<Tensor, Tensor, int64_t, int64_t> {
|
||||||
|
const auto lmin = std::min(l1, l2);
|
||||||
|
auto t1_idx = at::empty({lmin}, t1.options());
|
||||||
|
auto t2_idx = at::empty({lmin}, t2.options());
|
||||||
|
int64_t nnz = 0;
|
||||||
|
int64_t n_intersect = 0;
|
||||||
|
|
||||||
|
auto* ptr_t1_idx = t1_idx.data_ptr<int64_t>();
|
||||||
|
auto* ptr_t2_idx = t2_idx.data_ptr<int64_t>();
|
||||||
|
|
||||||
|
auto* ptr_t1 = t1.data_ptr<int64_t>();
|
||||||
|
auto* ptr_c1 = c1.data_ptr<int64_t>();
|
||||||
|
auto* ptr_t2 = t2.data_ptr<int64_t>();
|
||||||
|
auto* ptr_c2 = c2.data_ptr<int64_t>();
|
||||||
|
|
||||||
|
// we assume search in t2
|
||||||
|
auto* first = ptr_t2;
|
||||||
|
auto* last = ptr_t2 + l2;
|
||||||
|
|
||||||
|
for (const auto i : c10::irange(l1)) {
|
||||||
|
const auto idx = ptr_t1[i];
|
||||||
|
const auto idx_pos = std::lower_bound(first, last, idx);
|
||||||
|
if (idx_pos != last && *idx_pos == idx) {
|
||||||
|
const auto j = idx_pos - first;
|
||||||
|
const auto count1 = ptr_c1[i];
|
||||||
|
const auto count2 = ptr_c2[j];
|
||||||
|
*ptr_t1_idx++ = i;
|
||||||
|
*ptr_t2_idx++ = j;
|
||||||
|
++n_intersect;
|
||||||
|
nnz += count1 * count2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_tuple(t1_idx, t2_idx, n_intersect, nnz);
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto search_in_index = (
|
||||||
|
n_unique_dim_indices * std::log2(n_unique_sel_indices) < std::log2(n_unique_dim_indices) * n_unique_sel_indices
|
||||||
|
);
|
||||||
|
|
||||||
|
Tensor dim_indices_intersect, index_intersect;
|
||||||
|
int64_t n_intersect, res_nnz;
|
||||||
|
|
||||||
|
if (search_in_index) {
|
||||||
|
std::tie(dim_indices_intersect, index_intersect, n_intersect, res_nnz)
|
||||||
|
= compute_index_intersections_and_nnz(
|
||||||
|
dim_indices_unique, dim_indices_counts, n_unique_dim_indices,
|
||||||
|
sel_indices_unique, sel_indices_counts, n_unique_sel_indices
|
||||||
|
);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::tie(index_intersect, dim_indices_intersect, n_intersect, res_nnz)
|
||||||
|
= compute_index_intersections_and_nnz(
|
||||||
|
sel_indices_unique, sel_indices_counts, n_unique_sel_indices,
|
||||||
|
dim_indices_unique, dim_indices_counts, n_unique_dim_indices
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto compute_offsets = [](const Tensor& counts, int64_t len) {
|
||||||
|
const auto narrowed_counts = counts.narrow(-1, 0, len);
|
||||||
|
return narrowed_counts.cumsum(/*dim=*/0).sub_(narrowed_counts);
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto dim_sort_indices_offsets = compute_offsets(dim_indices_counts, n_unique_dim_indices);
|
||||||
|
const auto sel_sort_indices_offsets = compute_offsets(sel_indices_counts, n_unique_sel_indices);
|
||||||
|
|
||||||
|
auto selected_dim_indices = at::empty({res_nnz}, dim_indices_unique.options());
|
||||||
|
auto res_dim_indices = at::empty({res_nnz}, sel_indices_unique.options());
|
||||||
|
{
|
||||||
|
auto* ptr_selected_dim_indices = selected_dim_indices.data_ptr<int64_t>();
|
||||||
|
auto* ptr_res_dim_indices = res_dim_indices.data_ptr<int64_t>();
|
||||||
|
|
||||||
|
const auto* ptr_dim_indices_intersect = dim_indices_intersect.data_ptr<int64_t>();
|
||||||
|
const auto* ptr_sel_indices_intersect = index_intersect.data_ptr<int64_t>();
|
||||||
|
|
||||||
|
const auto* ptr_dim_indices_counts = dim_indices_counts.data_ptr<int64_t>();
|
||||||
|
const auto* ptr_sel_indices_counts = sel_indices_counts.data_ptr<int64_t>();
|
||||||
|
|
||||||
|
const auto* ptr_dim_sort_indices = dim_sort_indices.data_ptr<int64_t>();
|
||||||
|
const auto* ptr_sel_sort_indices = sel_sort_indices.data_ptr<int64_t>();
|
||||||
|
|
||||||
|
const auto* ptr_dim_indices_offsets = dim_sort_indices_offsets.data_ptr<int64_t>();
|
||||||
|
const auto* ptr_sel_indices_offsets = sel_sort_indices_offsets.data_ptr<int64_t>();
|
||||||
|
|
||||||
|
for (const auto ii : c10::irange(n_intersect)) {
|
||||||
|
const auto i_idx = *ptr_dim_indices_intersect++;
|
||||||
|
const auto j_idx = *ptr_sel_indices_intersect++;
|
||||||
|
|
||||||
|
const auto i_idx_count = ptr_dim_indices_counts[i_idx];
|
||||||
|
const auto j_idx_count = ptr_sel_indices_counts[j_idx];
|
||||||
|
|
||||||
|
const auto i_idx_offset = ptr_dim_indices_offsets[i_idx];
|
||||||
|
const auto j_idx_offset = ptr_sel_indices_offsets[j_idx];
|
||||||
|
|
||||||
|
const auto* src_dim_sort_indices = ptr_dim_sort_indices + i_idx_offset;
|
||||||
|
const auto* src_sel_sort_indices = ptr_sel_sort_indices + j_idx_offset;
|
||||||
|
|
||||||
|
const auto copy_chunk_len = i_idx_count * j_idx_count;
|
||||||
|
for (const auto chunk_elem_idx : c10::irange(copy_chunk_len)) {
|
||||||
|
*ptr_selected_dim_indices++ = *(src_dim_sort_indices + (chunk_elem_idx % i_idx_count));
|
||||||
|
*ptr_res_dim_indices++ = *(src_sel_sort_indices + (chunk_elem_idx % j_idx_count));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res_indices = indices.index_select(1, selected_dim_indices);
|
||||||
|
res_indices[dim] = res_dim_indices;
|
||||||
|
const auto res_values = values.index_select(0, selected_dim_indices);
|
||||||
|
|
||||||
|
return _sparse_coo_tensor_with_dims_and_tensors(
|
||||||
|
sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
|
||||||
|
}
|
||||||
|
// If indexing into dense dimensions
|
||||||
|
else {
|
||||||
|
// It is sufficient to just perform `index_select` on values
|
||||||
|
// if `dim` refers to dense dimensions.
|
||||||
|
const auto res_values = values.index_select(dim - sparse_dim + 1, index);
|
||||||
|
|
||||||
|
return _sparse_coo_tensor_with_dims_and_tensors(
|
||||||
|
sparse_dim, dense_dim, res_sizes, indices, res_values, self.options());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Tensor slice(
|
Tensor slice(
|
||||||
const Tensor& self,
|
const Tensor& self,
|
||||||
int64_t dim,
|
int64_t dim,
|
||||||
|
|
|
||||||
|
|
@ -6864,8 +6864,7 @@
|
||||||
QuantizedCPU: index_select_quantized_cpu_
|
QuantizedCPU: index_select_quantized_cpu_
|
||||||
CUDA: index_select_cuda
|
CUDA: index_select_cuda
|
||||||
QuantizedCUDA: index_select_quantized_cuda
|
QuantizedCUDA: index_select_quantized_cuda
|
||||||
SparseCPU: index_select_sparse
|
SparseCPU: index_select_sparse_cpu
|
||||||
SparseCUDA: index_select_sparse
|
|
||||||
|
|
||||||
- func: index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
|
- func: index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user