[Reland] Embedding thrust->cub migration (#63806)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/63427

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63806

Reviewed By: bdhirsh

Differential Revision: D30498255

Pulled By: ngimel

fbshipit-source-id: 78b7085a92a168cf0163f53dcb712bac922f5235
This commit is contained in:
Xiang Gao 2021-08-24 09:24:50 -07:00 committed by Facebook GitHub Bot
parent 94d621584a
commit 227cb268bc
8 changed files with 95 additions and 88 deletions

View File

@ -3,6 +3,7 @@
#include <cstddef>
#include <type_traits>
#include <iterator>
#include <limits>
// include cub in a safe manner, see:
// https://github.com/pytorch/pytorch/pull/55292
@ -102,6 +103,8 @@ static inline void sort_keys(
const key_t *keys_in, key_t *keys_out,
int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
TORCH_CHECK(n <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
using key_t_ = typename detail::cuda_type<key_t>::type;
const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
@ -124,6 +127,8 @@ static inline void sort_pairs(
const value_t *values_in, value_t *values_out,
int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
TORCH_CHECK(n <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
using key_t_ = typename detail::cuda_type<key_t>::type;
auto allocator = c10::cuda::CUDACachingAllocator::get();
@ -156,6 +161,10 @@ static inline void segmented_sort_pairs(
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
using key_t_ = typename detail::cuda_type<key_t>::type;
auto allocator = c10::cuda::CUDACachingAllocator::get();
@ -305,4 +314,12 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
}
}
}}}
template<typename InputIteratorT , typename OutputIteratorT , typename NumSelectedIteratorT >
inline void unique(InputIteratorT input, OutputIteratorT output, NumSelectedIteratorT num_selected_out, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub unique does not support more than INT_MAX elements");
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSelect::Unique,
input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
}
}}} // namespace at::cuda::cub

View File

@ -7,12 +7,9 @@
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCTensorMathReduce.cuh>
#include <THC/THCThrustAllocator.cuh>
#include <THC/THCReduceApplyUtils.cuh>
#include <thrust/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/unique.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
@ -224,14 +221,19 @@ __global__ void renorm_kernel(
} // anonymous namespace
Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices,
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
int64_t num_weights, int64_t padding_idx,
bool scale_grad_by_freq) {
auto grad_arg = TensorArg(grad_, "grad", 1);
auto indices_arg = TensorArg(indices, "indices", 1);
auto indices_arg = TensorArg(indices_, "indices", 1);
checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
checkSameGPU("embedding_backward", grad_arg, indices_arg);
auto indices = indices_.contiguous();
auto num_indices = indices.numel();
auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)});
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -272,59 +274,16 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Tensor count;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
using device_ptr = thrust::device_ptr<index_t>;
// Sort the inputs into sorted with the corresponding indices; we
// don't need a stable or multidimensional sort, so just use Thrust
// directly
{
sorted_indices.copy_(indices);
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
// Fill sortedOrigIndices with sequential indices
auto count_iter = thrust::counting_iterator<index_t>(0);
auto orig_data = device_ptr(orig_indices.data_ptr<index_t>());
thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);
// Sort; a stable sort is not required
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data,
LTOp<index_t>());
}
auto range = at::arange(num_indices, indices.options());
int64_t nbits = cuda::cub::get_num_bits(num_weights);
cuda::cub::sort_pairs(
indices.data_ptr<index_t>(), sorted_indices.data_ptr<index_t>(),
range.data_ptr<index_t>(), orig_indices.data_ptr<index_t>(),
num_indices, false/*, 0, nbits*/);
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
auto count_data = device_ptr(count.data_ptr<index_t>());
thrust::inclusive_scan_by_key(
policy,
sorted_data,
sorted_data + num_indices,
thrust::make_constant_iterator(1),
count_data
);
// Take the maximum of each count per unique key in reverse:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 3 3 3 2 2 1 2 2
thrust::inclusive_scan_by_key(
policy,
thrust::make_reverse_iterator(sorted_data + num_indices),
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::equal_to<index_t>(),
thrust::maximum<index_t>()
);
embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
}
});
@ -340,23 +299,23 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
checkSameGPU("embedding_renorm", self_arg, indices_arg);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cuda_", [&] () {
using device_ptr = thrust::device_ptr<index_t>;
auto num_indices = indices.numel();
auto indices_contig = std::get<0>(indices.sort()).contiguous();
auto indices_data = device_ptr(indices_contig.data_ptr<index_t>());
auto unique_indices = at::empty(indices.numel(), indices.options());
auto unique_data = device_ptr(unique_indices.data_ptr<index_t>());
auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data);
auto num_unique_indices = static_cast<int>(end - unique_data);
auto num_unique_indices = at::empty({}, indices.options().dtype(kLong));
dim3 grid(num_unique_indices);
dim3 block(128);
cuda::cub::unique(
indices_contig.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
num_unique_indices.data_ptr<int64_t>(),
num_indices
);
dim3 grid = num_unique_indices.item<int64_t>();
dim3 block = 128;
int dim = self.stride(0);
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] {

View File

@ -10,10 +10,6 @@
#include <THC/THCThrustAllocator.cuh>
#include <THC/THCAtomics.cuh>
#include <thrust/execution_policy.h>
#include <thrust/unique.h>
#include <thrust/device_vector.h>
#pragma once
namespace at {

View File

@ -218,9 +218,6 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Ten
std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self, indices, !unsafe);
int64_t num_indices = linearIndex.numel();
TORCH_CHECK(num_indices <= std::numeric_limits<int>::max(),
"index_put of tensors larger than INT_MAX is not supported yet in pytorch");
if (num_indices > 0 && sliceSize > 0) {
const bool permuted = !src.is_contiguous();
auto src_ = permuted ? src.contiguous() : src;

View File

@ -5,6 +5,8 @@
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <thrust/device_ptr.h>
namespace at { namespace native {
@ -30,4 +32,45 @@ void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_
thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, LTOp<int64_t>());
}
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
using device_ptr = thrust::device_ptr<index_t>;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
auto num_indices = count.numel();
// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
auto count_data = device_ptr(count.data_ptr<index_t>());
thrust::inclusive_scan_by_key(
policy,
sorted_data,
sorted_data + num_indices,
thrust::make_constant_iterator(1),
count_data
);
// Take the maximum of each count per unique key in reverse:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 3 3 3 2 2 1 2 2
thrust::inclusive_scan_by_key(
policy,
thrust::make_reverse_iterator(sorted_data + num_indices),
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::equal_to<index_t>(),
thrust::maximum<index_t>()
);
}
template
void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
template
void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);
}}

View File

@ -47,8 +47,6 @@ template <int N> struct alignas(N) OpaqueType { char data[N]; };
Tensor& randperm_out_cuda(int64_t n, c10::optional<Generator> generator, Tensor& result) {
TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
TORCH_CHECK(n <= std::numeric_limits<int>::max(),
"randperm of tensors larger than INT_MAX is not supported yet in pytorch");
check_supported_max_int_with_precision(n, result);

View File

@ -94,13 +94,7 @@ std::tuple<Tensor, Tensor, Tensor, int64_t> compute_unique(
Tensor length = at::empty({1}, options);
int64_t num_out;
if (!return_counts) {
CUB_WRAPPER(
cub::DeviceSelect::Unique,
data,
data_out.data_ptr<scalar_t>(),
length.data_ptr<int64_t>(),
num_inp,
stream);
cuda::cub::unique(data, data_out.data_ptr<scalar_t>(), length.data_ptr<int64_t>(), num_inp);
num_out = length.item<int64_t>();
} else {
counts.resize_(num_inp);
@ -135,11 +129,6 @@ std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
auto options = self.options().dtype(kLong);
int64_t num_inp = self.numel();
TORCH_CHECK(
num_inp <= INT_MAX,
"num_inp ",
num_inp,
" is too big to be handled by cub");
Tensor sorted;
Tensor self_c = self.contiguous();
if (consecutive) {

View File

@ -2774,6 +2774,14 @@ new_module_tests = [
input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
check_gradgrad=False,
),
dict(
module_name='Embedding',
constructor_args=(4, 3),
cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
check_gradgrad=False,
desc='discontiguous'
),
dict(
module_name='EmbeddingBag',
constructor_args=(4, 3),