Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
[Reland] Embedding thrust->cub migration (#63806)
Summary: Fixes https://github.com/pytorch/pytorch/issues/63427
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63806
Reviewed By: bdhirsh
Differential Revision: D30498255
Pulled By: ngimel
fbshipit-source-id: 78b7085a92a168cf0163f53dcb712bac922f5235
This commit is contained in:
parent 94d621584a
commit 227cb268bc

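The diff below does two things: it extends the at::cuda::cub wrappers (adding a unique wrapper and INT_MAX guards on the sort entry points) and it rewrites embedding_dense_backward_cuda / embedding_renorm_cuda_ to call those wrappers instead of hand-rolled thrust code (device_ptr, THCThrustAllocator execution policies, thrust::sort_by_key, thrust::unique_copy). For orientation, here is a minimal sketch of the resulting call pattern, assuming int64 indices and using hypothetical helper and variable names (an illustration, not code from the commit):

// Sketch only: sort a CUDA int64 `indices` tensor, carrying each element's original
// position along as the value, via the ATen cub wrapper shown in the hunks below.
#include <utility>
#include <ATen/ATen.h>
#include <ATen/cuda/cub.cuh>

std::pair<at::Tensor, at::Tensor> sort_with_positions(const at::Tensor& indices) {
  int64_t n = indices.numel();
  auto sorted = at::empty_like(indices);
  auto positions = at::arange(n, indices.options());      // 0 .. n-1
  auto sorted_positions = at::empty_like(positions);
  // One wrapper call replaces the thrust device_ptr/policy/sort_by_key boilerplate;
  // temporary storage is taken from the CUDA caching allocator inside the wrapper,
  // and the sort runs on the current CUDA stream.
  at::cuda::cub::sort_pairs(
      indices.data_ptr<int64_t>(), sorted.data_ptr<int64_t>(),
      positions.data_ptr<int64_t>(), sorted_positions.data_ptr<int64_t>(),
      n);
  return {sorted, sorted_positions};
}
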
@@ -3,6 +3,7 @@
#include <cstddef>
#include <type_traits>
#include <iterator>
#include <limits>

// include cub in a safe manner, see:
// https://github.com/pytorch/pytorch/pull/55292

@@ -102,6 +103,8 @@ static inline void sort_keys(
const key_t *keys_in, key_t *keys_out,
int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
TORCH_CHECK(n <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
using key_t_ = typename detail::cuda_type<key_t>::type;

const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);

@@ -124,6 +127,8 @@ static inline void sort_pairs(
const value_t *values_in, value_t *values_out,
int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
TORCH_CHECK(n <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
using key_t_ = typename detail::cuda_type<key_t>::type;

auto allocator = c10::cuda::CUDACachingAllocator::get();

@@ -156,6 +161,10 @@ static inline void segmented_sort_pairs(
OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
) {
TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
using key_t_ = typename detail::cuda_type<key_t>::type;

auto allocator = c10::cuda::CUDACachingAllocator::get();

@@ -305,4 +314,12 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
}
}

}}}
template<typename InputIteratorT , typename OutputIteratorT , typename NumSelectedIteratorT >
inline void unique(InputIteratorT input, OutputIteratorT output, NumSelectedIteratorT num_selected_out, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub unique does not support more than INT_MAX elements");
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSelect::Unique,
input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
}

}}} // namespace at::cuda::cub

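The unique wrapper added above is a thin shim over cub::DeviceSelect::Unique; later hunks in this diff call it from embedding_renorm_cuda_ and from compute_unique. A hedged usage sketch with a hypothetical helper name, assuming an already-sorted int64 input so that "unique among consecutive runs" coincides with "globally unique":

// Sketch only: count the distinct values of an already-sorted CUDA int64 tensor.
#include <ATen/ATen.h>
#include <ATen/cuda/cub.cuh>

int64_t count_unique_sorted(const at::Tensor& sorted) {
  auto unique_vals = at::empty_like(sorted);
  // The wrapper writes the number of selected (unique) elements into device memory.
  auto num_unique = at::empty({}, sorted.options().dtype(at::kLong));
  at::cuda::cub::unique(
      sorted.data_ptr<int64_t>(),
      unique_vals.data_ptr<int64_t>(),
      num_unique.data_ptr<int64_t>(),
      sorted.numel());
  // Reading the count on the host is an item() round trip (a stream sync), which is
  // what the embedding_renorm_cuda_ hunk below does to size its launch grid.
  return num_unique.item<int64_t>();
}
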
@@ -7,12 +7,9 @@

#include <THC/THCDeviceUtils.cuh>
#include <THC/THCTensorMathReduce.cuh>
#include <THC/THCThrustAllocator.cuh>
#include <THC/THCReduceApplyUtils.cuh>

#include <thrust/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/unique.h>
#include <ATen/cuda/cub.cuh>

#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>

@@ -224,14 +221,19 @@ __global__ void renorm_kernel(

} // anonymous namespace

Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices,
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);

Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
int64_t num_weights, int64_t padding_idx,
bool scale_grad_by_freq) {
auto grad_arg = TensorArg(grad_, "grad", 1);
auto indices_arg = TensorArg(indices, "indices", 1);
auto indices_arg = TensorArg(indices_, "indices", 1);
checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
checkSameGPU("embedding_backward", grad_arg, indices_arg);

auto indices = indices_.contiguous();

auto num_indices = indices.numel();
auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)});
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

@@ -272,59 +274,16 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Tensor count;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
using device_ptr = thrust::device_ptr<index_t>;

// Sort the inputs into sorted with the corresponding indices; we
// don't need a stable or multidimensional sort, so just use Thrust
// directly
{
sorted_indices.copy_(indices);

auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);

// Fill sortedOrigIndices with sequential indices
auto count_iter = thrust::counting_iterator<index_t>(0);
auto orig_data = device_ptr(orig_indices.data_ptr<index_t>());
thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);

// Sort; a stable sort is not required
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data,
LTOp<index_t>());
}
auto range = at::arange(num_indices, indices.options());
int64_t nbits = cuda::cub::get_num_bits(num_weights);
cuda::cub::sort_pairs(
indices.data_ptr<index_t>(), sorted_indices.data_ptr<index_t>(),
range.data_ptr<index_t>(), orig_indices.data_ptr<index_t>(),
num_indices, false/*, 0, nbits*/);

if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);

// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
auto count_data = device_ptr(count.data_ptr<index_t>());
thrust::inclusive_scan_by_key(
policy,
sorted_data,
sorted_data + num_indices,
thrust::make_constant_iterator(1),
count_data
);

// Take the maximum of each count per unique key in reverse:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 3 3 3 2 2 1 2 2
thrust::inclusive_scan_by_key(
policy,
thrust::make_reverse_iterator(sorted_data + num_indices),
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::equal_to<index_t>(),
thrust::maximum<index_t>()
);
embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
}
});

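One detail in the hunk above: the new path computes nbits = cuda::cub::get_num_bits(num_weights), but the call as shipped passes only the defaults; the `/*, 0, nbits*/` tail is commented out, so the radix sort still covers all key bits. A hedged sketch of what passing the bit range would look like, based on the sort_pairs signature in the cub.cuh hunk earlier in this diff (the function and parameter names here are mine, and this variant is not what the commit enables):

#include <cstdint>
#include <ATen/cuda/cub.cuh>

// Sketch only: restrict the radix sort to the low bits that can occur for keys in
// [0, num_weights), assuming get_num_bits(num_weights) returns that bit count.
template <typename index_t>
void sort_pairs_with_bit_range(const index_t* keys_in, index_t* keys_out,
                               const index_t* values_in, index_t* values_out,
                               int64_t n, int64_t num_weights) {
  int64_t nbits = at::cuda::cub::get_num_bits(num_weights);
  at::cuda::cub::sort_pairs(keys_in, keys_out, values_in, values_out,
                            n, /*descending=*/false, /*begin_bit=*/0, /*end_bit=*/nbits);
}
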
@@ -340,23 +299,23 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
checkSameGPU("embedding_renorm", self_arg, indices_arg);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cuda_", [&] () {
using device_ptr = thrust::device_ptr<index_t>;

auto num_indices = indices.numel();
auto indices_contig = std::get<0>(indices.sort()).contiguous();
auto indices_data = device_ptr(indices_contig.data_ptr<index_t>());

auto unique_indices = at::empty(indices.numel(), indices.options());
auto unique_data = device_ptr(unique_indices.data_ptr<index_t>());
auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data);
auto num_unique_indices = static_cast<int>(end - unique_data);
auto num_unique_indices = at::empty({}, indices.options().dtype(kLong));

dim3 grid(num_unique_indices);
dim3 block(128);
cuda::cub::unique(
indices_contig.data_ptr<index_t>(),
unique_indices.data_ptr<index_t>(),
num_unique_indices.data_ptr<int64_t>(),
num_indices
);

dim3 grid = num_unique_indices.item<int64_t>();
dim3 block = 128;
int dim = self.stride(0);

AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_backward", [&] {

@@ -10,10 +10,6 @@
#include <THC/THCThrustAllocator.cuh>
#include <THC/THCAtomics.cuh>

#include <thrust/execution_policy.h>
#include <thrust/unique.h>
#include <thrust/device_vector.h>

#pragma once

namespace at {

@@ -218,9 +218,6 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Ten
std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self, indices, !unsafe);
int64_t num_indices = linearIndex.numel();

TORCH_CHECK(num_indices <= std::numeric_limits<int>::max(),
"index_put of tensors larger than INT_MAX is not supported yet in pytorch");

if (num_indices > 0 && sliceSize > 0) {
const bool permuted = !src.is_contiguous();
auto src_ = permuted ? src.contiguous() : src;

@@ -5,6 +5,8 @@
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <thrust/device_ptr.h>

namespace at { namespace native {

@@ -30,4 +32,45 @@ void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_
thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, LTOp<int64_t>());
}

template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
using device_ptr = thrust::device_ptr<index_t>;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);

auto num_indices = count.numel();

// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = device_ptr(sorted_indices.data_ptr<index_t>());
auto count_data = device_ptr(count.data_ptr<index_t>());
thrust::inclusive_scan_by_key(
policy,
sorted_data,
sorted_data + num_indices,
thrust::make_constant_iterator(1),
count_data
);

// Take the maximum of each count per unique key in reverse:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 3 3 3 2 2 1 2 2
thrust::inclusive_scan_by_key(
policy,
thrust::make_reverse_iterator(sorted_data + num_indices),
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::equal_to<index_t>(),
thrust::maximum<index_t>()
);
}

template
void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
template
void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);

}}

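The helper added above reproduces the scale_grad_by_freq counting that the earlier hunk removed from embedding_dense_backward_cuda: a forward inclusive_scan_by_key over a constant-1 input produces a running occurrence count per key, and a reverse max-scan by key then broadcasts each key's total count back over all of its occurrences. A host-side C++ sketch of the same two passes, for illustration only (not code from the commit):

// keys   : 2 5 5 5 7 7 8 9 9
// pass 1 : 1 1 2 3 1 2 1 1 2   (running occurrence count within each key segment)
// pass 2 : 1 3 3 3 2 2 1 2 2   (each element replaced by its key's total count)
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> per_key_totals(const std::vector<int64_t>& sorted_keys) {
  std::vector<int64_t> count(sorted_keys.size());
  // Pass 1: inclusive scan by key with a constant 1 -> running count per key.
  for (std::size_t i = 0; i < sorted_keys.size(); ++i) {
    bool same_key = i > 0 && sorted_keys[i] == sorted_keys[i - 1];
    count[i] = same_key ? count[i - 1] + 1 : 1;
  }
  // Pass 2: reverse inclusive max-scan by key -> propagate each key's final
  // count back to every occurrence of that key.
  for (std::size_t i = sorted_keys.size(); i-- > 1; ) {
    if (sorted_keys[i - 1] == sorted_keys[i] && count[i] > count[i - 1]) {
      count[i - 1] = count[i];
    }
  }
  return count;
}

int main() {
  for (int64_t v : per_key_totals({2, 5, 5, 5, 7, 7, 8, 9, 9})) {
    std::printf("%lld ", static_cast<long long>(v));   // prints: 1 3 3 3 2 2 1 2 2
  }
  std::printf("\n");
  return 0;
}
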
@@ -47,8 +47,6 @@ template <int N> struct alignas(N) OpaqueType { char data[N]; };

Tensor& randperm_out_cuda(int64_t n, c10::optional<Generator> generator, Tensor& result) {
TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
TORCH_CHECK(n <= std::numeric_limits<int>::max(),
"randperm of tensors larger than INT_MAX is not supported yet in pytorch");

check_supported_max_int_with_precision(n, result);

@@ -94,13 +94,7 @@ std::tuple<Tensor, Tensor, Tensor, int64_t> compute_unique(
Tensor length = at::empty({1}, options);
int64_t num_out;
if (!return_counts) {
CUB_WRAPPER(
cub::DeviceSelect::Unique,
data,
data_out.data_ptr<scalar_t>(),
length.data_ptr<int64_t>(),
num_inp,
stream);
cuda::cub::unique(data, data_out.data_ptr<scalar_t>(), length.data_ptr<int64_t>(), num_inp);
num_out = length.item<int64_t>();
} else {
counts.resize_(num_inp);

@@ -135,11 +129,6 @@ std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(

auto options = self.options().dtype(kLong);
int64_t num_inp = self.numel();
TORCH_CHECK(
num_inp <= INT_MAX,
"num_inp ",
num_inp,
" is too big to be handled by cub");
Tensor sorted;
Tensor self_c = self.contiguous();
if (consecutive) {

@@ -2774,6 +2774,14 @@ new_module_tests = [
input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
check_gradgrad=False,
),
dict(
module_name='Embedding',
constructor_args=(4, 3),
cpp_constructor_args='torch::nn::EmbeddingOptions(4, 3)',
input_fn=lambda: torch.empty(1, 512, dtype=torch.long).random_(4).expand(7, 512),
check_gradgrad=False,
desc='discontiguous'
),
dict(
module_name='EmbeddingBag',
constructor_args=(4, 3),