[torch] Add cuda support for segment reduction 'max' (#54175)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54175

Building on top of the previous PR, this PR adds CUDA support for 1D 'max' segment reduction.

Next steps:
- Add support for other major reduction types (e.g., min, sum) for 1D tensors
- Documentation for the op
- Perf optimizations and benchmark util
- Backward support (not high priority)
- Support for multi dimensional tensors (on data and lengths) (not high priority)
- Support for 'indices' (not high priority)

Test Plan: Added unit test
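
For illustration, the new CUDA path is exercised the same way as the existing CPU op; a minimal sketch mirroring the unit test (assumes a CUDA-capable build; argument names follow the native_functions.yaml entry in this diff):

import torch

# 1D data on the GPU, partitioned into segments of the given lengths
lengths = torch.tensor([1, 2, 3], device="cuda")
data = torch.tensor([1.0, float("nan"), 3, 4, 5, 6], device="cuda")

# per-segment max -> [1., nan, 6.] (NaN is propagated)
out = torch.segment_reduce(data, "max", lengths=lengths, axis=0, unsafe=False)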

Reviewed By: ngimel

Differential Revision: D27121170

fbshipit-source-id: 1c2565f42e2903e6fc089d56983ce8857efbfa3c
Author: Serhat Yilmaz, 2021-04-08 13:22:58 -07:00, committed by Facebook GitHub Bot
parent 778f9eab6c
commit eb5e1fc713
6 changed files with 189 additions and 43 deletions


@@ -498,6 +498,7 @@ filegroup(
"aten/src/ATen/native/cuda/Repeat.cu.cc",
"aten/src/ATen/native/cuda/ReplicationPadding.cu.cc",
"aten/src/ATen/native/cuda/Resize.cu.cc",
"aten/src/ATen/native/cuda/SegmentReduce.cu.cc",
"aten/src/ATen/native/cuda/SoftMax.cu.cc",
"aten/src/ATen/native/cuda/SortingKthValue.cu.cc",
"aten/src/ATen/native/cuda/SparseMM.cu.cc",


@@ -1,43 +1,22 @@
#include <ATen/native/SegmentReduce.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
namespace at {
namespace native {
DEFINE_DISPATCH(segment_reduce_stub);
DEFINE_DISPATCH(_segment_reduce_stub);
enum ReductionType { MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"max", MAX},
};
namespace {
Tensor _segment_reduce_cpu(
Tensor _segment_reduce_cpu_kernel(
const Tensor& data,
std::string reduce,
const c10::optional<Tensor>& lengths,
const c10::optional<Tensor>& indices,
const Tensor& lengths,
int64_t axis,
bool unsafe) {
axis = maybe_wrap_dim(axis, data.ndimension());
TORCH_CHECK(axis == 0, "Currently only dim=0 is supported!");
TORCH_CHECK(data.dim() == 1);
TORCH_CHECK(data.numel() > 0);
TORCH_CHECK(
reduce2REDUCE.at(reduce) == MAX,
"Currently only 'max' reduction is supported!");
// length related checks
TORCH_CHECK(
lengths.has_value() && !indices.has_value(),
"Currently only lengths based reduction is supported!")
const auto& lengths_value = lengths.value();
TORCH_CHECK(lengths_value.dim() == 1);
TORCH_CHECK(data.get_device() == lengths_value.get_device());
TORCH_CHECK(data.dim() >= lengths_value.dim());
const auto lengths_contig = lengths_value.contiguous();
const auto lengths_contig = lengths.contiguous();
const auto data_contig = data.contiguous();
int64_t batch_size = lengths_contig.numel();
@@ -47,7 +26,8 @@ Tensor _segment_reduce_cpu(
if (!unsafe) {
int64_t sum = 0;
for (int64_t i = 0; i < batch_size; ++i) {
TORCH_CHECK(lengths_data[i] > 0);
TORCH_CHECK(
(lengths_data[i] > 0), "lengths contains non positive value!");
sum += lengths_data[i];
}
TORCH_CHECK(sum == data.numel());
@@ -80,5 +60,49 @@ Tensor _segment_reduce_cpu(
return output;
}
} // namespace
enum SegmentReductionType { MAX };
static const std::map<std::string, SegmentReductionType> segmentReduce2REDUCE =
{
{"max", MAX},
};
Tensor segment_reduce_kernel(
const Tensor& data,
std::string reduce,
const c10::optional<Tensor>& lengths,
const c10::optional<Tensor>& indices,
int64_t axis,
bool unsafe) {
axis = maybe_wrap_dim(axis, data.ndimension());
TORCH_CHECK(axis == 0, "Currently only dim=0 is supported!");
TORCH_CHECK(data.dim() == 1);
TORCH_CHECK(data.numel() > 0);
TORCH_CHECK(
at::native::segmentReduce2REDUCE.at(reduce) == MAX,
"Currently only 'max' reduction is supported!");
// length related checks
TORCH_CHECK(
lengths.has_value() && !indices.has_value(),
"Currently only lengths based reduction is supported!")
const auto& lengths_value = lengths.value();
TORCH_CHECK(lengths_value.dim() == 1);
TORCH_CHECK(data.get_device() == lengths_value.get_device());
TORCH_CHECK(data.dim() >= lengths_value.dim());
return _segment_reduce_stub(
data.device().type(), data, lengths_value, axis, unsafe);
}
REGISTER_ARCH_DISPATCH(
_segment_reduce_stub,
DEFAULT,
&_segment_reduce_cpu_kernel);
REGISTER_AVX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
} // namespace native
} // namespace at
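
As a plain reference for what the CPU kernel above (and the CUDA kernel below) compute, the 1D lengths-based 'max' semantics can be sketched in Python; segment_max_reference is a hypothetical helper used only for illustration, not part of this PR:

import torch

def segment_max_reference(data, lengths):
    # segment i covers data[offset : offset + lengths[i]]; reduce each segment with max
    out, offset = [], 0
    for n in lengths.tolist():
        out.append(data[offset:offset + n].amax())
        offset += n
    return torch.stack(out)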


@@ -7,14 +7,9 @@
namespace at {
namespace native {
using segment_reduce_fn = void (*)(
const Tensor&,
std::string,
const c10::optional<Tensor>&,
const c10::optional<Tensor>&,
int64_t,
bool);
DECLARE_DISPATCH(segment_reduce_fn, segment_reduce_stub);
using segment_reduce_fn =
Tensor (*)(const Tensor&, const Tensor&, int64_t, bool);
DECLARE_DISPATCH(segment_reduce_fn, _segment_reduce_stub);
} // namespace native
} // namespace at


@@ -0,0 +1,122 @@
#include <ATen/native/SegmentReduce.h>
#include <ATen/ATen.h>
#include <ATen/NumericUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/CubUtils.cuh>
#include <iostream>
namespace at {
namespace native {
struct CustomMax {
template <typename OutputT>
__host__ __device__ __forceinline__ OutputT
operator()(const OutputT& a, const OutputT& b) {
if (at::_isnan(a)) {
return a;
} else if (at::_isnan(b)) {
return b;
}
return std::max<OutputT>(a, b);
}
};
Tensor _get_complete_sum(const Tensor& lengths) {
int64_t segment_count = lengths.numel();
auto offsets = at::empty({segment_count + 1}, lengths.options());
offsets[0].zero_();
auto* lengths_data_ptr = lengths.data_ptr<int64_t>();
auto* offsets_data_ptr = offsets.data_ptr<int64_t>();
size_t temp_storage_bytes = 0;
AT_CUDA_CHECK(cub::DeviceScan::InclusiveSum(
nullptr,
temp_storage_bytes,
lengths_data_ptr,
offsets_data_ptr + 1,
segment_count,
at::cuda::getCurrentCUDAStream()););
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr = allocator.allocate(temp_storage_bytes);
AT_CUDA_CHECK(cub::DeviceScan::InclusiveSum(
dataPtr.get(),
temp_storage_bytes,
lengths_data_ptr,
offsets_data_ptr + 1,
segment_count,
at::cuda::getCurrentCUDAStream()););
return offsets;
}
Tensor _segment_reduce_cuda_kernel(
const Tensor& data,
const Tensor& lengths,
int64_t axis,
bool unsafe) {
if (!unsafe) {
TORCH_CHECK(
(lengths.min().item<int64_t>() > 0),
"lengths contains non positive value!");
TORCH_CHECK(lengths.sum().item<int64_t>() == data.numel());
}
int64_t segment_count = lengths.numel();
const auto data_contig = data.contiguous();
auto output = at::empty({segment_count}, data.options());
const auto lengths_contig = lengths.contiguous();
auto offsets = _get_complete_sum(lengths_contig);
auto* offsets_data_ptr = offsets.data_ptr<int64_t>();
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
data.scalar_type(),
"segment_reduce_cuda",
[&]() {
auto* data_contig_data_ptr = data_contig.data_ptr<scalar_t>();
auto* output_data_ptr = output.data_ptr<scalar_t>();
CustomMax max_op{};
size_t temp_storage_bytes = 0;
scalar_t initial_value = std::numeric_limits<scalar_t>::lowest();
AT_CUDA_CHECK(cub::DeviceSegmentedReduce::Reduce(
nullptr,
temp_storage_bytes,
data_contig_data_ptr,
output_data_ptr,
segment_count,
offsets_data_ptr,
offsets_data_ptr + 1,
max_op,
initial_value,
at::cuda::getCurrentCUDAStream()););
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr = allocator.allocate(temp_storage_bytes);
AT_CUDA_CHECK(cub::DeviceSegmentedReduce::Reduce(
dataPtr.get(),
temp_storage_bytes,
data_contig_data_ptr,
output_data_ptr,
segment_count,
offsets_data_ptr,
offsets_data_ptr + 1,
max_op,
initial_value,
at::cuda::getCurrentCUDAStream()););
});
return output;
}
REGISTER_DISPATCH(_segment_reduce_stub, &_segment_reduce_cuda_kernel);
} // namespace native
} // namespace at
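
The paired CUB calls above follow CUB's usual two-pass pattern: the first call with a nullptr workspace only queries temp_storage_bytes, and the second call runs the scan/reduction using the allocated buffer. The offsets produced by _get_complete_sum are equivalent to the following sketch (offsets_from_lengths is a hypothetical name used only for illustration):

import torch

def offsets_from_lengths(lengths):
    # [0, l0, l0+l1, ...]; segment i spans data[offsets[i] : offsets[i+1]]
    offsets = torch.zeros(lengths.numel() + 1, dtype=lengths.dtype, device=lengths.device)
    offsets[1:] = torch.cumsum(lengths, dim=0)
    return offsets

# cub::DeviceSegmentedReduce::Reduce with CustomMax then computes, per segment i:
#   output[i] = max over data[offsets[i] : offsets[i+1]], starting from numeric_limits::lowest()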


@@ -8872,4 +8872,4 @@
- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False) -> Tensor
variants: function
dispatch:
CPU: _segment_reduce_cpu
CPU, CUDA: segment_reduce_kernel


@@ -1,8 +1,8 @@
import torch
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCPU,
dtypes,
dtypesIfCUDA,
)
from torch.testing._internal.common_utils import (
TestCase,
@@ -11,25 +11,29 @@ from torch.testing._internal.common_utils import (
class TestSegmentReductions(TestCase):
@onlyCPU
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
def test_max_simple_1d(self, device, dtype):
def _test_max_simple_1d(self, device, dtype, unsafe):
lengths = torch.tensor([1, 2, 3], device=device)
data = torch.tensor([1, float("nan"), 3, 4, 5, 6], device=device, dtype=dtype)
expected_result = torch.tensor([1, float("nan"), 6], device=device, dtype=dtype)
actual_result = torch.segment_reduce(
data=data, reduce="max", lengths=lengths, axis=0, unsafe=False
data=data, reduce="max", lengths=lengths, axis=0, unsafe=unsafe
)
self.assertEqual(
expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
)
actual_result = torch.segment_reduce(
data=data, reduce="max", lengths=lengths, axis=-1, unsafe=False
data=data, reduce="max", lengths=lengths, axis=-1, unsafe=unsafe
)
self.assertEqual(
expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
)
@dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
def test_max_simple_1d(self, device, dtype):
self._test_max_simple_1d(device, dtype, False)
self._test_max_simple_1d(device, dtype, True)
instantiate_device_type_tests(TestSegmentReductions, globals())
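
For a concrete reading of the expected values in _test_max_simple_1d: the data splits into segments of the given lengths, and 'max' propagates NaN, matching CustomMax in the CUDA kernel. A standalone illustration:

import torch

data = torch.tensor([1.0, float("nan"), 3, 4, 5, 6])
segments = torch.split(data, [1, 2, 3])    # ([1.], [nan, 3.], [4., 5., 6.])
# torch.amax propagates NaN, so the per-segment maxima are [1.0, nan, 6.0]
print([s.amax().item() for s in segments])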