From 90fceb015d5324662ebc24fcf9cbfc4ff9dd127f Mon Sep 17 00:00:00 2001
From: Nikita Vedeneev
Date: Fri, 11 Feb 2022 05:21:03 +0000
Subject: [PATCH] Faster index_select for COO tensors

---
 aten/src/ATen/native/TensorShape.cpp       | 241 +++++++++++++++++++++
 aten/src/ATen/native/native_functions.yaml |   3 +-
 2 files changed, 242 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 3999805fee1..c7f0084921c 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -1390,6 +1390,247 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index)
   }
 }
 
+Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& index) {
+  /*
+    Algorithm:
+    index - a 1-D tensor of indices with shape (n,)
+    self - sparse tensor, its shape is sizes = sparse_shape + dense_shape
+    indices - 2-D tensor of indices, shape is (sparse_dims, nnz)
+    values - (1 + len(dense_shape))-D tensor of values, shape is (nnz,) + dense_shape
+
+    index_select(dim, index) returns a sparse tensor with the following data:
+    new_sizes = sizes[:dim] + (n,) + sizes[dim+1:]
+    new_indices - shape is (sparse_dims, new_nnz)
+    new_values - shape is (new_nnz,) + dense_shape
+
+    if dim < len(sparse_shape):
+        for i, idx in enumerate(index):
+            for j, jdx in enumerate(indices[dim]):
+                if idx == jdx:
+                    icol = indices[:dim, j] + (i,) + indices[dim+1:, j]
+                    new_indices.add_column(icol)
+                    new_values.add_row(values[j])
+    else:
+        new_indices = indices
+        new_values = values.index_select(dim - sparse_dim + 1, index)
+  */
+  const auto ndim = self.dim();
+  if (ndim == 0) {
+    TORCH_CHECK_INDEX(false, "index_select() cannot be applied to a 0-dim tensor.");
+  }
+  if (!(index.dim() == 1 && index.dtype() == at::kLong)) {
+    TORCH_CHECK_INDEX(false, "index_select() argument index must be a 1-D long tensor.");
+  }
+  dim = maybe_wrap_dim(dim, ndim);
+  const auto size = self.size(dim);
+  const auto sparse_dim = self.sparse_dim();
+  const auto dense_dim = self.dense_dim();
+  const auto indices = self._indices();
+  const auto values = self._values();
+  const auto nnz = values.size(0);
+  const auto index_len = index.size(0);
+  auto res_sizes = self.sizes().vec();
+  res_sizes[dim] = index_len;
+
+  // TODO: decide on optimal grain size for sparse tensors.
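+  // Illustrative sizing (the figures below are hypothetical, added for
+  // clarity): at::internal::GRAIN_SIZE is 32768, so e.g. nnz = 100000 splits
+  // into ceil(100000 / 32768) = 4 chunks, and the cap below works out to
+  // min(4, at::get_num_threads()) threads.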
+  const auto grain_size = at::internal::GRAIN_SIZE;
+
+  // 1 <= n_threads_nnz <= min(ceil(nnz / grain_size), get_num_threads())
+  const auto n_threads_nnz = std::max<int64_t>(
+      1,
+      std::min<int64_t>((nnz + grain_size - 1) / grain_size, at::get_num_threads())
+  );
+
+  // If indexing into sparse dimensions
+  if (dim < sparse_dim) {
+    auto nneg_index = at::empty_like(index);
+    {
+      auto* ptr_index = index.data_ptr<int64_t>();
+      auto* ptr_index_end = ptr_index + index_len;
+      auto* ptr_nneg_index = nneg_index.data_ptr<int64_t>();
+      while (ptr_index != ptr_index_end) {
+        auto idx = *ptr_index++;
+        if (idx < -size || idx >= size) {
+          TORCH_CHECK_INDEX(false,
+              "index_select(): index contains ", idx, " that is out of range for tensor of size ",
+              self.sizes(), " at dimension ", dim
+          );
+        }
+        if (idx < 0) {
+          idx += size;
+        }
+        *ptr_nneg_index++ = idx;
+      }
+    }
+
+    // Much faster than at::unique_dim
+    const auto unique_with_counts = [](
+        const Tensor& t, int64_t len
+    ) -> std::tuple<Tensor, Tensor, Tensor, int64_t> {
+      Tensor t_unique, t_idx;
+      std::tie(t_unique, t_idx) = at::sort(t);
+
+      auto t_counts = at::ones_like(t_unique);
+      // The in-place scan below assumes a non-empty input.
+      if (len == 0) {
+        return std::make_tuple(t_idx, t_unique, t_counts, 0);
+      }
+      int64_t len_unique;
+      {
+        auto* ptr_counts = t_counts.data_ptr<int64_t>();
+        auto* first = t_unique.data_ptr<int64_t>();
+        auto* last = first + len;
+        auto* curr = first;
+        auto* counts = ptr_counts;
+        while (++first != last) {
+          if (*curr != *first) {
+            ++counts;
+            if (++curr != first) {
+              // std::swap(*curr, *first);
+              *curr = *first;
+            }
+          }
+          else {
+            ++(*counts);
+          }
+        }
+        len_unique = counts - ptr_counts + 1;
+      }
+
+      return std::make_tuple(t_idx, t_unique, t_counts, len_unique);
+    };
+
+    Tensor dim_sort_indices, dim_indices_unique, dim_indices_counts;
+    int64_t n_unique_dim_indices;
+    std::tie(dim_sort_indices, dim_indices_unique, dim_indices_counts, n_unique_dim_indices)
+      = unique_with_counts(indices[dim].contiguous(), nnz);
+
+    Tensor sel_sort_indices, sel_indices_unique, sel_indices_counts;
+    int64_t n_unique_sel_indices;
+    std::tie(sel_sort_indices, sel_indices_unique, sel_indices_counts, n_unique_sel_indices)
+      = unique_with_counts(nneg_index, index_len);
+
+    const auto compute_index_intersections_and_nnz = [](
+        const Tensor& t1, const Tensor& c1, int64_t l1,
+        const Tensor& t2, const Tensor& c2, int64_t l2
+    ) -> std::tuple<Tensor, Tensor, int64_t, int64_t> {
+      const auto lmin = std::min(l1, l2);
+      auto t1_idx = at::empty({lmin}, t1.options());
+      auto t2_idx = at::empty({lmin}, t2.options());
+      int64_t nnz = 0;
+      int64_t n_intersect = 0;
+
+      auto* ptr_t1_idx = t1_idx.data_ptr<int64_t>();
+      auto* ptr_t2_idx = t2_idx.data_ptr<int64_t>();
+
+      auto* ptr_t1 = t1.data_ptr<int64_t>();
+      auto* ptr_c1 = c1.data_ptr<int64_t>();
+      auto* ptr_t2 = t2.data_ptr<int64_t>();
+      auto* ptr_c2 = c2.data_ptr<int64_t>();
+
+      // t2 is sorted by construction; binary-search it for each element of t1.
+      auto* first = ptr_t2;
+      auto* last = ptr_t2 + l2;
+
+      for (const auto i : c10::irange(l1)) {
+        const auto idx = ptr_t1[i];
+        const auto idx_pos = std::lower_bound(first, last, idx);
+        if (idx_pos != last && *idx_pos == idx) {
+          const auto j = idx_pos - first;
+          const auto count1 = ptr_c1[i];
+          const auto count2 = ptr_c2[j];
+          *ptr_t1_idx++ = i;
+          *ptr_t2_idx++ = j;
+          ++n_intersect;
+          nnz += count1 * count2;
+        }
+      }
+
+      return std::make_tuple(t1_idx, t2_idx, n_intersect, nnz);
+    };
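+    // Illustrative cost estimate (hypothetical sizes, added for clarity):
+    // probing each of l1 unique values into a sorted array of l2 values costs
+    // about l1 * log2(l2) comparisons. For l1 = 8 and l2 = 1024 that is
+    // 8 * 10 = 80, versus 1024 * 3 = 3072 the other way round, so the
+    // heuristic below iterates over the smaller set of uniques and
+    // binary-searches the larger one.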
+    const auto search_in_index = (
+        n_unique_dim_indices * std::log2(n_unique_sel_indices)
+        < std::log2(n_unique_dim_indices) * n_unique_sel_indices
+    );
+
+    Tensor dim_indices_intersect, index_intersect;
+    int64_t n_intersect, res_nnz;
+
+    if (search_in_index) {
+      std::tie(dim_indices_intersect, index_intersect, n_intersect, res_nnz)
+        = compute_index_intersections_and_nnz(
+            dim_indices_unique, dim_indices_counts, n_unique_dim_indices,
+            sel_indices_unique, sel_indices_counts, n_unique_sel_indices
+        );
+    }
+    else {
+      std::tie(index_intersect, dim_indices_intersect, n_intersect, res_nnz)
+        = compute_index_intersections_and_nnz(
+            sel_indices_unique, sel_indices_counts, n_unique_sel_indices,
+            dim_indices_unique, dim_indices_counts, n_unique_dim_indices
+        );
+    }
+
+    const auto compute_offsets = [](const Tensor& counts, int64_t len) {
+      const auto narrowed_counts = counts.narrow(-1, 0, len);
+      // Exclusive prefix sum: offsets[i] = sum(counts[:i]).
+      return narrowed_counts.cumsum(/*dim=*/0).sub_(narrowed_counts);
+    };
+
+    const auto dim_sort_indices_offsets = compute_offsets(dim_indices_counts, n_unique_dim_indices);
+    const auto sel_sort_indices_offsets = compute_offsets(sel_indices_counts, n_unique_sel_indices);
+
+    auto selected_dim_indices = at::empty({res_nnz}, dim_indices_unique.options());
+    auto res_dim_indices = at::empty({res_nnz}, sel_indices_unique.options());
+    {
+      auto* ptr_selected_dim_indices = selected_dim_indices.data_ptr<int64_t>();
+      auto* ptr_res_dim_indices = res_dim_indices.data_ptr<int64_t>();
+
+      const auto* ptr_dim_indices_intersect = dim_indices_intersect.data_ptr<int64_t>();
+      const auto* ptr_sel_indices_intersect = index_intersect.data_ptr<int64_t>();
+
+      const auto* ptr_dim_indices_counts = dim_indices_counts.data_ptr<int64_t>();
+      const auto* ptr_sel_indices_counts = sel_indices_counts.data_ptr<int64_t>();
+
+      const auto* ptr_dim_sort_indices = dim_sort_indices.data_ptr<int64_t>();
+      const auto* ptr_sel_sort_indices = sel_sort_indices.data_ptr<int64_t>();
+
+      const auto* ptr_dim_indices_offsets = dim_sort_indices_offsets.data_ptr<int64_t>();
+      const auto* ptr_sel_indices_offsets = sel_sort_indices_offsets.data_ptr<int64_t>();
+
+      for (const auto ii : c10::irange(n_intersect)) {
+        (void)ii; // only the iteration count matters here
+        const auto i_idx = *ptr_dim_indices_intersect++;
+        const auto j_idx = *ptr_sel_indices_intersect++;
+
+        const auto i_idx_count = ptr_dim_indices_counts[i_idx];
+        const auto j_idx_count = ptr_sel_indices_counts[j_idx];
+
+        const auto i_idx_offset = ptr_dim_indices_offsets[i_idx];
+        const auto j_idx_offset = ptr_sel_indices_offsets[j_idx];
+
+        const auto* src_dim_sort_indices = ptr_dim_sort_indices + i_idx_offset;
+        const auto* src_sel_sort_indices = ptr_sel_sort_indices + j_idx_offset;
+
+        // Emit the full cross product of the i_idx_count matches in
+        // indices[dim] and the j_idx_count matches in index: element k of the
+        // chunk pairs (k % i_idx_count) with (k / i_idx_count). Using modulo
+        // for both factors would skip pairs whenever the two counts share a
+        // common divisor.
+        const auto copy_chunk_len = i_idx_count * j_idx_count;
+        for (const auto chunk_elem_idx : c10::irange(copy_chunk_len)) {
+          *ptr_selected_dim_indices++ = src_dim_sort_indices[chunk_elem_idx % i_idx_count];
+          *ptr_res_dim_indices++ = src_sel_sort_indices[chunk_elem_idx / i_idx_count];
+        }
+      }
+    }
+
+    auto res_indices = indices.index_select(1, selected_dim_indices);
+    res_indices[dim] = res_dim_indices;
+    const auto res_values = values.index_select(0, selected_dim_indices);
+
+    return _sparse_coo_tensor_with_dims_and_tensors(
+        sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
+  }
+  // If indexing into dense dimensions
+  else {
+    // It is sufficient to just perform `index_select` on values
+    // if `dim` refers to dense dimensions.
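+    // Shape walk-through (hypothetical sizes, added for clarity): values has
+    // shape (nnz,) + dense_shape, so dimension `dim` of self maps to
+    // dimension dim - sparse_dim + 1 of values. E.g. with sparse_dim = 2 and
+    // dense_shape = (4, 5), dim = 3 selects along the size-5 axis, which is
+    // dimension 3 - 2 + 1 = 2 of values.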
+    const auto res_values = values.index_select(dim - sparse_dim + 1, index);
+
+    return _sparse_coo_tensor_with_dims_and_tensors(
+        sparse_dim, dense_dim, res_sizes, indices, res_values, self.options());
+  }
+}
+
 Tensor slice(
     const Tensor& self,
     int64_t dim,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 5eb907537f3..91a169fa1c6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6864,8 +6864,7 @@
     QuantizedCPU: index_select_quantized_cpu_
     CUDA: index_select_cuda
     QuantizedCUDA: index_select_quantized_cuda
-    SparseCPU: index_select_sparse
-    SparseCUDA: index_select_sparse
+    SparseCPU: index_select_sparse_cpu
 
 - func: index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
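--
Note: a minimal sketch of how the new CPU fast path can be exercised through
the public ATen API. The shapes and values below are invented for
illustration; the snippet only assumes the standard at::sparse_coo_tensor and
at::index_select entry points, which dispatch to the function added above for
sparse CPU tensors.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // 2 x (nnz = 3) indices and 3 values -> a 3 x 4 sparse COO tensor with
  // nonzeros at (0, 2), (1, 0) and (1, 2).
  const auto indices =
      at::tensor({0, 1, 1, 2, 0, 2}, at::kLong).reshape({2, 3});
  const auto values = at::tensor({10.0, 20.0, 30.0});
  const auto self = at::sparse_coo_tensor(indices, values, {3, 4});

  // Select rows 1, 1, 0 along sparse dim 0. Duplicates in `index` are legal
  // and duplicate the matching rows, exercising the cross-product copy above.
  const auto index = at::tensor({1, 1, 0}, at::kLong);
  const auto out = at::index_select(self, /*dim=*/0, index);

  std::cout << out.to_dense() << std::endl;  // 3 x 4 dense result
  return 0;
}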