[torch] Add cuda support for segment reduction 'max' (#56704)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56704

This is a re-submit of PR https://github.com/pytorch/pytorch/pull/54175. Main changes compared to the original PR:
- Switch to importing <ATen/cuda/cub.cuh>.
- Use CUB_WRAPPER to reduce boilerplate code.

Test Plan: Will check CI status. Added a unit test.

Reviewed By: ngimel

Differential Revision: D27941257

fbshipit-source-id: 24a0e0c7f6c46126d2606fe42ed03dca15684415
This commit is contained in: parent d578e8cfa2, commit 6c37788cb1
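For orientation, this is roughly how the new operator is reached from the ATen C++ API. A minimal sketch: the entry point name and argument order follow the native_functions.yaml schema in this diff, but the factory-call details here are illustrative rather than taken from the commit.

#include <ATen/ATen.h>

// Sketch: max-reduce three segments of lengths 1, 2, 3 on the GPU.
// Segments are [1], [2, 3], [4, 5, 6] -> per-segment max is [1, 3, 6].
at::Tensor segment_max_example() {
  auto data = at::tensor({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
                         at::device(at::kCUDA));
  auto lengths = at::tensor({int64_t(1), int64_t(2), int64_t(3)},
                            at::device(at::kCUDA).dtype(at::kLong));
  // Only reduce="max" with lengths-based, 1-D input is supported by this commit.
  return at::segment_reduce(data, "max", lengths, /*indices=*/c10::nullopt,
                            /*axis=*/0, /*unsafe=*/false);
}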
@@ -499,6 +499,7 @@ filegroup(
         "aten/src/ATen/native/cuda/Repeat.cu.cc",
         "aten/src/ATen/native/cuda/ReplicationPadding.cu.cc",
         "aten/src/ATen/native/cuda/Resize.cu.cc",
+        "aten/src/ATen/native/cuda/SegmentReduce.cu.cc",
         "aten/src/ATen/native/cuda/SoftMax.cu.cc",
         "aten/src/ATen/native/cuda/SortingKthValue.cu.cc",
         "aten/src/ATen/native/cuda/SparseMM.cu.cc",
aten/src/ATen/cuda/cub.cuh
@@ -17,12 +17,13 @@
 #include <c10/cuda/CUDAStream.h>
 
 // handle the temporary storage and 'twice' calls for cub API
-#define CUB_WRAPPER(func, ...) do {                                  \
-  size_t temp_storage_bytes = 0;                                     \
-  func(nullptr, temp_storage_bytes, __VA_ARGS__);                    \
-  auto temp_storage = allocator->allocate(temp_storage_bytes);       \
-  func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__);         \
-  AT_CUDA_CHECK(cudaGetLastError());                                 \
+#define CUB_WRAPPER(func, ...) do {                                       \
+  size_t temp_storage_bytes = 0;                                          \
+  func(nullptr, temp_storage_bytes, __VA_ARGS__);                         \
+  auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get();    \
+  auto temp_storage = caching_allocator.allocate(temp_storage_bytes);     \
+  func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__);              \
+  AT_CUDA_CHECK(cudaGetLastError());                                      \
 } while (false)
 
 #ifdef __HIP_PLATFORM_HCC__
@@ -57,8 +58,6 @@ static inline void sort_keys(
 ) {
   using key_t_ = typename cuda_type<key_t>::type;
 
-  auto allocator = c10::cuda::CUDACachingAllocator::get();
-
   const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
   key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);
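Background on the 'twice' calls the comment above refers to: every cub device-wide algorithm is invoked once with a null temp-storage pointer, which only reports the scratch size needed, and then again to do the real work. A minimal sketch of that protocol without the macro (illustrative buffer management, not PyTorch code; CUB_WRAPPER additionally routes the allocation through the CUDA caching allocator):

#include <cub/cub.cuh>

// First call: d_temp == nullptr, so cub only writes the required scratch
// size into temp_storage_bytes. Second call: does the actual scan.
void inclusive_sum_raw(const int64_t* d_in, int64_t* d_out, int n,
                       cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_storage_bytes = 0;
  cub::DeviceScan::InclusiveSum(d_temp, temp_storage_bytes, d_in, d_out, n, stream);
  cudaMalloc(&d_temp, temp_storage_bytes);
  cub::DeviceScan::InclusiveSum(d_temp, temp_storage_bytes, d_in, d_out, n, stream);
  cudaFree(d_temp);
}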
aten/src/ATen/native/SegmentReduce.cpp
@@ -1,43 +1,22 @@
 #include <ATen/native/SegmentReduce.h>
 
 #include <ATen/ATen.h>
 #include <ATen/Dispatch.h>
 #include <ATen/NumericUtils.h>
 
 namespace at {
 namespace native {
 
-DEFINE_DISPATCH(segment_reduce_stub);
+DEFINE_DISPATCH(_segment_reduce_stub);
 
-enum ReductionType { MAX };
-const std::map<std::string, ReductionType> reduce2REDUCE = {
-    {"max", MAX},
-};
+namespace {
 
-Tensor _segment_reduce_cpu(
+Tensor _segment_reduce_cpu_kernel(
     const Tensor& data,
-    std::string reduce,
-    const c10::optional<Tensor>& lengths,
-    const c10::optional<Tensor>& indices,
+    const Tensor& lengths,
     int64_t axis,
     bool unsafe) {
-  axis = maybe_wrap_dim(axis, data.ndimension());
-  TORCH_CHECK(axis == 0, "Currently only dim=0 is supported!");
-  TORCH_CHECK(data.dim() == 1);
-  TORCH_CHECK(data.numel() > 0);
-  TORCH_CHECK(
-      reduce2REDUCE.at(reduce) == MAX,
-      "Currently only 'max' reduction is supported!");
-
-  // length related checks
-  TORCH_CHECK(
-      lengths.has_value() && !indices.has_value(),
-      "Currently only lengths based reduction is supported!")
-  const auto& lengths_value = lengths.value();
-  TORCH_CHECK(lengths_value.dim() == 1);
-  TORCH_CHECK(data.get_device() == lengths_value.get_device());
-  TORCH_CHECK(data.dim() >= lengths_value.dim());
-
-  const auto lengths_contig = lengths_value.contiguous();
+  const auto lengths_contig = lengths.contiguous();
   const auto data_contig = data.contiguous();
 
   int64_t batch_size = lengths_contig.numel();

@@ -47,7 +26,8 @@ Tensor _segment_reduce_cpu(
   if (!unsafe) {
     int64_t sum = 0;
     for (int64_t i = 0; i < batch_size; ++i) {
-      TORCH_CHECK(lengths_data[i] > 0);
+      TORCH_CHECK(
+          (lengths_data[i] > 0), "lengths contains non positive value!");
       sum += lengths_data[i];
     }
     TORCH_CHECK(sum == data.numel());

@@ -80,5 +60,49 @@ Tensor _segment_reduce_cpu(
   return output;
 }
 
+} // namespace
+
+enum SegmentReductionType { MAX };
+static const std::map<std::string, SegmentReductionType> segmentReduce2REDUCE =
+    {
+        {"max", MAX},
+};
+
+Tensor segment_reduce_kernel(
+    const Tensor& data,
+    std::string reduce,
+    const c10::optional<Tensor>& lengths,
+    const c10::optional<Tensor>& indices,
+    int64_t axis,
+    bool unsafe) {
+  axis = maybe_wrap_dim(axis, data.ndimension());
+  TORCH_CHECK(axis == 0, "Currently only dim=0 is supported!");
+  TORCH_CHECK(data.dim() == 1);
+  TORCH_CHECK(data.numel() > 0);
+  TORCH_CHECK(
+      at::native::segmentReduce2REDUCE.at(reduce) == MAX,
+      "Currently only 'max' reduction is supported!");
+
+  // length related checks
+  TORCH_CHECK(
+      lengths.has_value() && !indices.has_value(),
+      "Currently only lengths based reduction is supported!")
+  const auto& lengths_value = lengths.value();
+  TORCH_CHECK(lengths_value.dim() == 1);
+  TORCH_CHECK(data.get_device() == lengths_value.get_device());
+  TORCH_CHECK(data.dim() >= lengths_value.dim());
+
+  return _segment_reduce_stub(
+      data.device().type(), data, lengths_value, axis, unsafe);
+}
+
+REGISTER_ARCH_DISPATCH(
+    _segment_reduce_stub,
+    DEFAULT,
+    &_segment_reduce_cpu_kernel);
+REGISTER_AVX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
+REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
+REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
+
 } // namespace native
 } // namespace at
aten/src/ATen/native/SegmentReduce.h
@@ -7,14 +7,9 @@
 namespace at {
 namespace native {
 
-using segment_reduce_fn = void (*)(
-    const Tensor&,
-    std::string,
-    const c10::optional<Tensor>&,
-    const c10::optional<Tensor>&,
-    int64_t,
-    bool);
-DECLARE_DISPATCH(segment_reduce_fn, segment_reduce_stub);
+using segment_reduce_fn =
+    Tensor (*)(const Tensor&, const Tensor&, int64_t, bool);
+DECLARE_DISPATCH(segment_reduce_fn, _segment_reduce_stub);
 
 } // namespace native
 } // namespace at
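For readers unfamiliar with the DECLARE_DISPATCH/REGISTER_DISPATCH machinery: the stub declared here is, conceptually, a per-device-type function table. The REGISTER_* macros in SegmentReduce.cpp and SegmentReduce.cu fill in slots, and the call in segment_reduce_kernel selects one by data.device().type(). A deliberately simplified toy model, not the actual DispatchStub implementation:

#include <cstdint>
#include <stdexcept>

// Toy model of a dispatch stub: one function-pointer slot per backend.
struct ToyStub {
  using fn_t = void (*)(int64_t axis, bool unsafe);
  fn_t cpu = nullptr;   // analogous to the REGISTER_* macros in SegmentReduce.cpp
  fn_t cuda = nullptr;  // analogous to REGISTER_DISPATCH in SegmentReduce.cu
  void operator()(bool is_cuda, int64_t axis, bool unsafe) const {
    fn_t fn = is_cuda ? cuda : cpu;
    if (fn == nullptr) throw std::runtime_error("no kernel registered");
    fn(axis, unsafe);
  }
};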
aten/src/ATen/native/cuda/SegmentReduce.cu (new file, 93 lines)
@@ -0,0 +1,93 @@
+#include <ATen/native/SegmentReduce.h>
+
+#include <ATen/ATen.h>
+#include <ATen/NumericUtils.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <ATen/cuda/cub.cuh>
+
+namespace at {
+namespace native {
+
+struct CustomMax {
+  template <typename OutputT>
+  __host__ __device__ __forceinline__ OutputT
+  operator()(const OutputT& a, const OutputT& b) const {
+    if (at::_isnan(a)) {
+      return a;
+    } else if (at::_isnan(b)) {
+      return b;
+    }
+    return std::max<OutputT>(a, b);
+  }
+};
+
+Tensor _get_complete_sum(const Tensor& lengths) {
+  int64_t segment_count = lengths.numel();
+  TORCH_CHECK(segment_count < INT_MAX);
+  auto offsets = at::empty({segment_count + 1}, lengths.options());
+  offsets[0].zero_();
+  auto* lengths_data_ptr = lengths.data_ptr<int64_t>();
+  auto* offsets_data_ptr = offsets.data_ptr<int64_t>();
+
+  CUB_WRAPPER(
+      cub::DeviceScan::InclusiveSum,
+      lengths_data_ptr,
+      offsets_data_ptr + 1,
+      segment_count,
+      at::cuda::getCurrentCUDAStream());
+
+  return offsets;
+}
+
+Tensor _segment_reduce_cuda_kernel(
+    const Tensor& data,
+    const Tensor& lengths,
+    int64_t axis,
+    bool unsafe) {
+  if (!unsafe) {
+    TORCH_CHECK(
+        (lengths.min().item<int64_t>() > 0),
+        "lengths contains non positive value!");
+    TORCH_CHECK(lengths.sum().item<int64_t>() == data.numel());
+  }
+
+  int64_t segment_count = lengths.numel();
+  const auto data_contig = data.contiguous();
+  auto output = at::empty({segment_count}, data.options());
+
+  const auto lengths_contig = lengths.contiguous();
+  auto offsets = _get_complete_sum(lengths_contig);
+  auto* offsets_data_ptr = offsets.data_ptr<int64_t>();
+
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      data.scalar_type(),
+      "segment_reduce_cuda",
+      [&]() {
+        auto* data_contig_data_ptr = data_contig.data_ptr<scalar_t>();
+        auto* output_data_ptr = output.data_ptr<scalar_t>();
+
+        CustomMax max_op{};
+        scalar_t initial_value = std::numeric_limits<scalar_t>::lowest();
+        CUB_WRAPPER(
+            cub::DeviceSegmentedReduce::Reduce,
+            data_contig_data_ptr,
+            output_data_ptr,
+            segment_count,
+            offsets_data_ptr,
+            offsets_data_ptr + 1,
+            max_op,
+            initial_value,
+            at::cuda::getCurrentCUDAStream());
+      });
+
+  return output;
+}
+
+REGISTER_DISPATCH(_segment_reduce_stub, &_segment_reduce_cuda_kernel);
+
+} // namespace native
+} // namespace at
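Worth spelling out how the offsets feed cub: _get_complete_sum turns lengths into begin/end offsets, and cub::DeviceSegmentedReduce::Reduce treats segment i as the half-open range [offsets[i], offsets[i+1]). Note also that CustomMax deliberately propagates NaN, which is what the test below asserts with equal_nan=True. A host-side sketch of the same offset computation (hypothetical helper, for illustration only):

#include <cstdint>
#include <numeric>
#include <vector>

// lengths {1, 2, 3} -> offsets {0, 1, 3, 6}: offsets[0] stays zero and the
// inclusive prefix sum of lengths is written starting at offsets + 1,
// exactly as the InclusiveSum call above does on the GPU.
std::vector<int64_t> complete_sum(const std::vector<int64_t>& lengths) {
  std::vector<int64_t> offsets(lengths.size() + 1, 0);
  std::partial_sum(lengths.begin(), lengths.end(), offsets.begin() + 1);
  return offsets;
}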
aten/src/ATen/native/native_functions.yaml
@@ -8967,4 +8967,4 @@
 - func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False) -> Tensor
   variants: function
   dispatch:
-    CPU: _segment_reduce_cpu
+    CPU, CUDA: segment_reduce_kernel
test/test_segment_reductions.py
@@ -1,8 +1,8 @@
 import torch
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
-    onlyCPU,
     dtypes,
+    dtypesIfCUDA,
 )
 from torch.testing._internal.common_utils import (
     TestCase,

@@ -11,25 +11,29 @@ from torch.testing._internal.common_utils import (
 
 
 class TestSegmentReductions(TestCase):
-    @onlyCPU
-    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
-    def test_max_simple_1d(self, device, dtype):
+    def _test_max_simple_1d(self, device, dtype, unsafe):
         lengths = torch.tensor([1, 2, 3], device=device)
         data = torch.tensor([1, float("nan"), 3, 4, 5, 6], device=device, dtype=dtype)
         expected_result = torch.tensor([1, float("nan"), 6], device=device, dtype=dtype)
         actual_result = torch.segment_reduce(
-            data=data, reduce="max", lengths=lengths, axis=0, unsafe=False
+            data=data, reduce="max", lengths=lengths, axis=0, unsafe=unsafe
         )
         self.assertEqual(
             expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
         )
         actual_result = torch.segment_reduce(
-            data=data, reduce="max", lengths=lengths, axis=-1, unsafe=False
+            data=data, reduce="max", lengths=lengths, axis=-1, unsafe=unsafe
         )
         self.assertEqual(
             expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
         )
 
+    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
+    def test_max_simple_1d(self, device, dtype):
+        self._test_max_simple_1d(device, dtype, False)
+        self._test_max_simple_1d(device, dtype, True)
+
 
 instantiate_device_type_tests(TestSegmentReductions, globals())
torch/utils/hipify/cuda_to_hip_mappings.py
@@ -7738,6 +7738,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
     ("cub::DeviceReduce", ("hipcub::DeviceReduce", CONV_SPECIAL_FUNC, API_RUNTIME)),
     ("cub::DeviceScan", ("hipcub::DeviceScan", CONV_SPECIAL_FUNC, API_RUNTIME)),
     ("cub::DeviceSegmentedRadixSort", ("hipcub::DeviceSegmentedRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)),
+    ("cub::DeviceSegmentedReduce", ("hipcub::DeviceSegmentedReduce", CONV_SPECIAL_FUNC, API_RUNTIME)),
     ("cub::DeviceSelect", ("hipcub::DeviceSelect", CONV_SPECIAL_FUNC, API_RUNTIME)),
     ("cub::KeyValuePair", ("hipcub::KeyValuePair", CONV_SPECIAL_FUNC, API_RUNTIME)),
     ("cub::Max", ("hipcub::Max", CONV_SPECIAL_FUNC, API_RUNTIME)),