[ATen][CUDA][CUB] Implement changes to CCCL (CUB/Thrust/LibCUDACXX) usage in ATen (#153373)

The CCCL 3.0.0 major release introduces some backward-compatibility-breaking changes. Namely, iterators such as TransformInputIterator and ConstantInputIterator were moved from CUB to Thrust, and operators such as Max and Sum were moved to libcudacxx.

For more information on the changes, please visit: https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html

This is a follow-up to PR #147493. The description from the original PR:
> Several cub iterators have been deprecated and removed in the latest CCCL (cub) development https://github.com/NVIDIA/cccl/pull/3831. This PR replaced the usage of those cub iterators with thrust iterators.
>
> Some cub thread operators were also deprecated and removed in https://github.com/NVIDIA/cccl/pull/3918. This PR replaced those operators with libcudacxx ops.
>
> This might also affect ROCM usability a bit.
>
> This patch is tested to work with CCCL commit at 82befb0894
>
> Tracking of CCCL/CUB deprecations in the most recent development https://github.com/NVIDIA/cccl/issues/101

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153373
Approved by: https://github.com/cyyever, https://github.com/atalman
This commit is contained in:
Aidyn-A 2025-06-28 05:44:47 +00:00 committed by PyTorch MergeBot
parent a92b24cd83
commit 51eb8e8f84
9 changed files with 50 additions and 25 deletions

View File

@ -15,8 +15,7 @@ struct SumOp {
template <typename input_t, typename output_t>
void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t num_items) {
using NO_ROCM(at_cuda_detail)::cub::Sum;
inclusive_scan(input, output, Sum{}, num_items);
inclusive_scan(input, output, NO_ROCM(::cuda)::std::plus<>{}, num_items);
}
template void inclusive_sum_truncating(const int32_t *input, int32_t *output, int64_t num_items);
@ -42,8 +41,7 @@ struct CountMaskOp {
void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n) {
CountMaskOp op{};
auto iter = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<
bool, decltype(op), decltype(mask)>(mask, op);
auto iter = ATEN_CUB_TRANSFORM_ITERATOR(bool, decltype(op), decltype(mask))(mask, op);
exclusive_scan(iter, output_idx, SumOp<int64_t>{}, int64_t{0}, n);
}

View File

@ -6,6 +6,10 @@
#include <iterator>
#include <limits>
#ifndef USE_ROCM
#include <cuda/std/functional>
#endif
#include <ATen/cuda/cub_definitions.cuh>
#include <ATen/cuda/CUDAContextLight.h>
@ -51,6 +55,21 @@
#define ROCM_HIPCUB(x) x
#endif
// Compatibility shims that hide where the iterator and max-operator helpers
// live, selected by CUB_V3_PLUS().
//
// Each ATEN_CUB_*_ITERATOR(...) macro expands to an iterator *type*, not an
// object: call sites append the constructor arguments themselves, e.g.
// ATEN_CUB_CONSTANT_ITERATOR(T)(1). ATEN_CUB_MAXIMUM() expands to a
// default-constructed binary "max" functor object.
#if CUB_V3_PLUS()
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/constant_iterator.h>
// Thrust spelling. The leading ValueType argument is accepted only so that
// call sites can keep the CUB-style argument list; it is dropped here and the
// remaining arguments are forwarded to ::thrust::transform_iterator.
// NOTE(review): arguments are forwarded positionally, so any extra trailing
// template arguments passed by a caller must mean the same thing to the thrust
// templates as they did to the CUB ones -- confirm at call sites.
#define ATEN_CUB_TRANSFORM_ITERATOR(ValueType, ...) ::thrust::transform_iterator<__VA_ARGS__>
#define ATEN_CUB_COUNTING_ITERATOR(...) ::thrust::counting_iterator<__VA_ARGS__>
#define ATEN_CUB_CONSTANT_ITERATOR(...) ::thrust::constant_iterator<__VA_ARGS__>
#define ATEN_CUB_MAXIMUM() ::cuda::maximum<>()
#else
// CUB spelling: every argument (including the value type) is forwarded to the
// corresponding *InputIterator template. The namespace prefix is assembled
// from the NO_ROCM / ROCM_HIPCUB helper macros, which are defined elsewhere.
#define ATEN_CUB_TRANSFORM_ITERATOR(...) NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::TransformInputIterator<__VA_ARGS__>
#define ATEN_CUB_COUNTING_ITERATOR(...) NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::CountingInputIterator<__VA_ARGS__>
#define ATEN_CUB_CONSTANT_ITERATOR(...) NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator<__VA_ARGS__>
#define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max()
#endif
#if (!defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()) || defined(USE_ROCM)
#if !defined(USE_ROCM)
@ -270,7 +289,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
return x.value;
}
};
auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<input_t, decltype(input_iter_transform), ArgIndexInputIterator>(
auto input_ = ATEN_CUB_TRANSFORM_ITERATOR(input_t, decltype(input_iter_transform), ArgIndexInputIterator)(
ArgIndexInputIterator(input + i), input_iter_transform);
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
@ -425,7 +444,7 @@ __global__ void calc_block_sums(const T * d_in, aggT * agg, int64_t nelem, int i
aggT data[ITEMS_PER_THREAD];
aggT agg_val = 0;
TransformFunctor<T, aggT, nonzero> transform_functor;
auto iter_in = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator<aggT, TransformFunctor<T, aggT, nonzero>, const T*>(d_in, transform_functor);
auto iter_in = ATEN_CUB_TRANSFORM_ITERATOR(aggT, TransformFunctor<T, aggT, nonzero>, const T*)(d_in, transform_functor);
for (int i=0; i<iters_per_cta; i++){
if (remaining >= BLOCK_THREADS * ITEMS_PER_THREAD) {
BlockLoadT(temp_storage.load).Load(iter_in, data);
@ -568,7 +587,7 @@ inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT i
"cub InclusiveSumByKey does not support more than INT_MAX elements");
#if !defined(USE_ROCM)
CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey,
keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
keys, input, output, num_items, NO_ROCM(::cuda)::std::equal_to<>(), at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(cub::DeviceScan::InclusiveSumByKey,
keys, input, output, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream());
@ -581,7 +600,7 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT
"cub InclusiveSumByKey does not support more than INT_MAX elements");
#if !defined(USE_ROCM)
CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey,
keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
keys, input, output, scan_op, num_items, NO_ROCM(::cuda)::std::equal_to<>(), at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(cub::DeviceScan::InclusiveScanByKey,
keys, input, output, scan_op, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream());

View File

@ -51,3 +51,11 @@
#else
#define CUB_SUPPORTS_FUTURE_VALUE() false
#endif
// The CCCL v3.0.0 major release introduced many backward-compatibility-breaking
// changes. Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html
//
// CUB_V3_PLUS() is a function-like macro that expands to true when
// CUB_VERSION >= 300000 and to false otherwise. It is meant to be tested as
// `#if CUB_V3_PLUS()`.
#if CUB_VERSION >= 300000
#define CUB_V3_PLUS() true
#else
#define CUB_V3_PLUS() false
#endif

View File

@ -317,7 +317,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
auto count_data = count.mutable_data_ptr<index_t>();
cuda::cub::inclusive_sum_by_key(
sorted_data,
NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator<index_t>(1),
ATEN_CUB_CONSTANT_ITERATOR(index_t)(1),
count_data,
num_indices
);
@ -329,7 +329,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
thrust::make_reverse_iterator(sorted_data + num_indices),
thrust::make_reverse_iterator(static_cast<const index_t*>(count_data) + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max(),
ATEN_CUB_MAXIMUM(),
num_indices
);
});

View File

@ -210,7 +210,7 @@ Tensor embedding_bag_backward_cuda_sum_avg(
auto count_data = count.mutable_data_ptr<index_t>();
cuda::cub::inclusive_sum_by_key(
sorted_data,
NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator<index_t>(1),
ATEN_CUB_CONSTANT_ITERATOR(index_t)(1),
count_data,
num_indices
);
@ -222,7 +222,7 @@ Tensor embedding_bag_backward_cuda_sum_avg(
thrust::make_reverse_iterator(sorted_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max(),
ATEN_CUB_MAXIMUM(),
num_indices
);
});

View File

@ -94,7 +94,7 @@ __global__ void flag_kernel(const T* d_in, int64_t * d_out, const int64_t * agg,
// Specialize BlockScan type for our thread block
using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<int, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
using TransformInputIteratorT = ROCM_HIPCUB(at_cuda_detail::cub)::TransformInputIterator<int, NonZeroOp<T>, const T*>;
using TransformInputIteratorT = ATEN_CUB_TRANSFORM_ITERATOR(int, NonZeroOp<T>, const T*);
using BlockExchangeT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockExchange<int, BLOCK_THREADS, ITEMS_PER_THREAD>;
// Shared memory
@ -184,7 +184,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
auto num_nonzeros = allocator.allocate(sizeof(int) * num_chunks);
for (int64_t idx = 0; idx < num_chunks; idx++) {
int64_t remaining = std::min(chunk_size, self.numel() - idx * chunk_size);
cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, const scalar_t*> itr(
ATEN_CUB_TRANSFORM_ITERATOR(bool, NonZeroOp<scalar_t>, const scalar_t*) itr(
self_.const_data_ptr<scalar_t>() + idx * chunk_size,
NonZeroOp<scalar_t>());
AT_CUDA_CHECK(cub::DeviceReduce::Sum(
@ -243,8 +243,8 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) {
for (int64_t idx = 0; idx < num_chunks; idx++) {
int remaining = std::min(chunk_size, self.numel() - idx * chunk_size);
cub::CountingInputIterator<int64_t> counting_itr(idx * chunk_size);
cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, const scalar_t*>
ATEN_CUB_COUNTING_ITERATOR(int64_t) counting_itr(idx * chunk_size);
ATEN_CUB_TRANSFORM_ITERATOR(bool, NonZeroOp<scalar_t>, const scalar_t*)
itr(self_.const_data_ptr<scalar_t>() + idx * chunk_size,
NonZeroOp<scalar_t>());
temp_storage_bytes = 0;

View File

@ -721,8 +721,8 @@ void launch(
desired, counts, num_blocks, blocks_per_slice, kthCounts);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Do a prefix scan of withinKCounts and kthCounts using slice_idx as keys to get the starting index of each block
using counting_iter_t = cub::CountingInputIterator<uint32_t, uint32_t>;
using slice_idx_iter_t = cub::TransformInputIterator<uint32_t, BlockIdxToKey, counting_iter_t>;
using counting_iter_t = ATEN_CUB_COUNTING_ITERATOR(uint32_t, uint32_t);
using slice_idx_iter_t = ATEN_CUB_TRANSFORM_ITERATOR(uint32_t, BlockIdxToKey, counting_iter_t);
slice_idx_iter_t slice_idx_iter(counting_iter_t(0), BlockIdxToKey(blocks_per_slice));
at::cuda::cub::inclusive_sum_by_key(slice_idx_iter, withinKCounts, withinKCounts, num_blocks);
at::cuda::cub::inclusive_sum_by_key(slice_idx_iter, kthCounts, kthCounts, num_blocks);

View File

@ -54,7 +54,7 @@ struct LoadBoolOp {
auto wrap_input_iterator(const bool *data) {
// See NOTE [Loading boolean values]
LoadBoolOp op;
return NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<bool, LoadBoolOp, const uint8_t*, int>(
return ATEN_CUB_TRANSFORM_ITERATOR(bool, LoadBoolOp, const uint8_t*, int)(
reinterpret_cast<const uint8_t*>(data), op);
}
@ -259,10 +259,10 @@ struct UniqueCub<bool> {
const bool* self_data = self.const_data_ptr<bool>();
MapNumberOfTrueValues op;
NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<int, MapNumberOfTrueValues, const uint8_t*, int>
ATEN_CUB_TRANSFORM_ITERATOR(int, MapNumberOfTrueValues, const uint8_t*, int)
data_iter(reinterpret_cast<const uint8_t*>(self_data), op);
at::cuda::cub::reduce(data_iter, tmp_num_true.get(), num_inp,
NO_ROCM(at_cuda_detail)::cub::Sum{}, 0);
NO_ROCM(::cuda)::std::plus<>{}, 0);
auto options = self.options();
output = at::empty({2}, self.options());

View File

@ -146,8 +146,8 @@ TEST(InclusiveScanSplit, CubTest) {
cudaMallocManaged(&output1, sizeof(int) * 10);
cudaDeviceSynchronize();
at::cuda::cub::inclusive_scan<int *, int *, ::at_cuda_detail::cub::Sum, /*max_cub_size=*/2>(
input, output1, ::at_cuda_detail::cub::Sum(), 10);
at::cuda::cub::inclusive_scan<int *, int *, NO_ROCM(::cuda)::std::plus<>, /*max_cub_size=*/2>(
input, output1, NO_ROCM(::cuda)::std::plus<>(), 10);
cudaDeviceSynchronize();
ASSERT_EQ(output1[0], 1);
@ -172,8 +172,8 @@ TEST(ExclusiveScanSplit, CubTest) {
cudaMallocManaged(&output2, sizeof(int) * 10);
cudaDeviceSynchronize();
at::cuda::cub::exclusive_scan<int *, int *, ::at_cuda_detail::cub::Sum, int, /*max_cub_size=*/2>(
input, output2, ::at_cuda_detail::cub::Sum(), 0, 10);
at::cuda::cub::exclusive_scan<int *, int *, NO_ROCM(::cuda)::std::plus<>, int, /*max_cub_size=*/2>(
input, output2, NO_ROCM(::cuda)::std::plus<>(), 0, 10);
cudaDeviceSynchronize();
ASSERT_EQ(output2[0], 0);