[torch][segment_reduce] Add support for initial value (#56923)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56923

Next steps, in order:
- Add backward support for CUDA
- Add support for more aggregation types
- Benchmarking (mainly for CUDA), more testing, and documentation
- Support for multi-dimensional input

Test Plan: Updated the unit test to also cover a zero-length segment.
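For illustration, a minimal sketch of the new argument's behavior, mirroring the
updated unit test (segment_reduce was a prototype op at this point, so the exact
Python surface may differ between releases):

    import torch

    data = torch.tensor([1.0, float("nan"), 3.0, 4.0, 5.0, 5.0])
    lengths = torch.tensor([1, 2, 3, 0])  # note the trailing zero-length segment

    # With unsafe=False, a zero-length segment is only accepted when `initial`
    # is provided; that segment's output is `initial` itself.
    out = torch.segment_reduce(
        data=data, reduce="max", lengths=lengths, axis=0, unsafe=False, initial=0
    )
    # out: tensor([1., nan, 5., 0.]), since NaN propagates through max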

Reviewed By: ngimel

Differential Revision: D27992228

fbshipit-source-id: 28851811f8a784a63162721c511d69e617a93727
Author: Serhat Yilmaz
Date: 2021-04-30 18:00:20 -07:00 (committed by Facebook GitHub Bot)
Parent: bd347012ec
Commit: 20eac093a7

6 changed files with 63 additions and 53 deletions

aten/src/ATen/native/SegmentReduce.cpp

@@ -17,45 +17,32 @@ Tensor _segment_reduce_cpu_kernel(
     const Tensor& data,
     const Tensor& lengths,
     int64_t axis,
-    bool unsafe) {
-  const auto lengths_contig = lengths.contiguous();
-  const auto data_contig = data.contiguous();
-  int64_t batch_size = lengths_contig.numel();
+    const c10::optional<Scalar>& initial) {
+  int64_t batch_size = lengths.numel();
   auto output = at::empty({batch_size}, data.options());
-  const auto* lengths_data = lengths_contig.data_ptr<int64_t>();
-  if (!unsafe) {
-    int64_t sum = 0;
-    for (int64_t i = 0; i < batch_size; ++i) {
-      TORCH_CHECK(
-          (lengths_data[i] > 0), "lengths contains non positive value!");
-      sum += lengths_data[i];
-    }
-    TORCH_CHECK(sum == data.numel());
-  }
+  const auto* lengths_data = lengths.data_ptr<int64_t>();
   AT_DISPATCH_ALL_TYPES_AND2(
-      kBFloat16,
-      kHalf,
-      data_contig.scalar_type(),
-      "_segment_reduce_cpu",
-      ([&]() {
+      kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", ([&]() {
         auto* output_data = output.data_ptr<scalar_t>();
-        const auto* values_data = data_contig.data_ptr<scalar_t>();
+        const auto* values_data = data.data_ptr<scalar_t>();
         int64_t k = 0;
         for (int64_t i = 0; i < batch_size; ++i) {
-          scalar_t reduction = std::numeric_limits<scalar_t>::lowest();
+          scalar_t initial_value = initial.has_value()
+              ? initial.value().to<scalar_t>()
+              : std::numeric_limits<scalar_t>::lowest();
           for (int64_t j = 0; j < lengths_data[i]; ++j) {
             const auto data = values_data[k];
-            reduction =
-                at::_isnan(data) ? data : std::max<scalar_t>(reduction, data);
+            initial_value = at::_isnan(data)
+                ? data
+                : std::max<scalar_t>(initial_value, data);
             k++;
           }
           // If unsafe is false, check on lengths or indices should cover cases
-          // where lengths for a particular segment is non-positive. If unsafe
-          // is true, simply set to numerical limits for particular reduction
-          output_data[i] = reduction;
+          // where lengths for a particular segment is negative. If unsafe
+          // is true, simply set to initial_value for particular reduction
+          output_data[i] = initial_value;
         }
       }));
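As a reading aid, a pure-Python sketch of what the rewritten CPU kernel computes
(illustrative only, not part of the diff; -inf stands in for
std::numeric_limits<scalar_t>::lowest()):

    import math

    def segment_max_reference(values, lengths, initial=None):
        # Each segment's accumulator starts at `initial` when provided,
        # otherwise at the lowest representable value; NaN is sticky,
        # matching the at::_isnan branch in the kernel.
        out, k = [], 0
        for length in lengths:
            acc = initial if initial is not None else -math.inf
            for _ in range(length):
                v = values[k]
                acc = v if math.isnan(v) else max(acc, v)
                k += 1
            out.append(acc)  # a zero-length segment yields `initial` unchanged
        return out

    # segment_max_reference([1, float("nan"), 3, 4, 5, 5], [1, 2, 3, 0], initial=0)
    # returns [1, nan, 5, 0]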
@@ -93,8 +80,7 @@ Tensor _segment_reduce_cpu_backward_kernel(
            }
            k++;
          }
-        // Average gradient output based on number of maximum in the segment
-        TORCH_INTERNAL_ASSERT(counter > 0);
+        // Average gradient based on number of maximum elements in the segment
          if (counter < 2) {
            continue;
          }
@@ -124,7 +110,8 @@ Tensor segment_reduce_kernel(
     const c10::optional<Tensor>& lengths,
     const c10::optional<Tensor>& indices,
     int64_t axis,
-    bool unsafe) {
+    bool unsafe,
+    const c10::optional<Scalar>& initial) {
   axis = maybe_wrap_dim(axis, data.ndimension());
   TORCH_CHECK(axis == 0, "Currently only dim=0 is supported!");
   TORCH_CHECK(data.dim() == 1);
@@ -142,8 +129,22 @@ Tensor segment_reduce_kernel(
   TORCH_CHECK(data.get_device() == lengths_value.get_device());
   TORCH_CHECK(data.dim() >= lengths_value.dim());
 
+  if (!unsafe) {
+    auto min_length = lengths_value.min().item<int64_t>();
+    TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
+    TORCH_CHECK(min_length != 0 || initial.has_value());
+    TORCH_CHECK(lengths_value.sum().item<int64_t>() == data.numel());
+  }
+
+  const auto data_contig = data.contiguous();
+  const auto lengths_contig = lengths_value.contiguous();
+
   return _segment_reduce_stub(
-      data.device().type(), data, lengths_value, axis, unsafe);
+      data_contig.device().type(),
+      data_contig,
+      lengths_contig,
+      axis,
+      initial);
 }
 
 REGISTER_ARCH_DISPATCH(
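The key behavioral change in the safe path: zero-length segments are now allowed,
but only when `initial` is supplied. Restated in Python (illustrative, not the
C++ source):

    def check_lengths(lengths, data_numel, initial):
        min_length = int(lengths.min())
        assert min_length >= 0, "lengths contains negative value!"
        # A zero-length segment has no elements to reduce over, so an
        # explicit initial value is required to define its output.
        assert min_length != 0 or initial is not None
        # The lengths must partition the data exactly.
        assert int(lengths.sum()) == data_numel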

aten/src/ATen/native/SegmentReduce.h

@@ -7,8 +7,11 @@
 namespace at {
 namespace native {
 
-using segment_reduce_fn =
-    Tensor (*)(const Tensor&, const Tensor&, int64_t, bool);
+using segment_reduce_fn = Tensor (*)(
+    const Tensor&,
+    const Tensor&,
+    int64_t,
+    const c10::optional<Scalar>&);
 DECLARE_DISPATCH(segment_reduce_fn, _segment_reduce_stub);
 
 using segment_reduce_backward_fn =

aten/src/ATen/native/cuda/SegmentReduce.cu

@@ -45,20 +45,11 @@ Tensor _segment_reduce_cuda_kernel(
     const Tensor& data,
     const Tensor& lengths,
     int64_t axis,
-    bool unsafe) {
-  if (!unsafe) {
-    TORCH_CHECK(
-        (lengths.min().item<int64_t>() > 0),
-        "lengths contains non positive value!");
-    TORCH_CHECK(lengths.sum().item<int64_t>() == data.numel());
-  }
-
+    const c10::optional<Scalar>& initial) {
   int64_t segment_count = lengths.numel();
-  const auto data_contig = data.contiguous();
   auto output = at::empty({segment_count}, data.options());
-  const auto lengths_contig = lengths.contiguous();
-  auto offsets = _get_complete_sum(lengths_contig);
+  auto offsets = _get_complete_sum(lengths);
   auto* offsets_data_ptr = offsets.data_ptr<int64_t>();
 
   AT_DISPATCH_ALL_TYPES_AND2(
@@ -67,14 +58,16 @@ Tensor _segment_reduce_cuda_kernel(
       data.scalar_type(),
       "segment_reduce_cuda",
       [&]() {
-        auto* data_contig_data_ptr = data_contig.data_ptr<scalar_t>();
+        auto* data_data_ptr = data.data_ptr<scalar_t>();
         auto* output_data_ptr = output.data_ptr<scalar_t>();
 
         CustomMax max_op{};
-        scalar_t initial_value = std::numeric_limits<scalar_t>::lowest();
+        scalar_t initial_value = initial.has_value()
+            ? initial.value().to<scalar_t>()
+            : std::numeric_limits<scalar_t>::lowest();
         CUB_WRAPPER(
             cub::DeviceSegmentedReduce::Reduce,
-            data_contig_data_ptr,
+            data_data_ptr,
             output_data_ptr,
             segment_count,
             offsets_data_ptr,
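cub::DeviceSegmentedReduce::Reduce consumes begin/end offsets rather than
lengths; _get_complete_sum presumably builds them as a cumulative sum with a
leading zero, along these lines (illustrative Python; the exact implementation
of _get_complete_sum is inferred):

    import torch

    lengths = torch.tensor([1, 2, 3, 0])
    offsets = torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(0)])
    # offsets: tensor([0, 1, 3, 6, 6])
    # Segment i spans [offsets[i], offsets[i+1]); an empty segment has equal
    # begin and end offsets, so the reduction falls back to initial_value.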

aten/src/ATen/native/native_functions.yaml

@@ -9003,7 +9003,7 @@
   cpp_no_default_args: ['a', 'b']
   python_module: nn
 
-- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False) -> Tensor
+- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
   variants: function
   dispatch:
     CPU, CUDA: segment_reduce_kernel

test/test_segment_reductions.py

@@ -13,16 +13,24 @@ from torch.testing._internal.common_utils import (
 
 class TestSegmentReductions(TestCase):
     def _test_max_simple_1d(self, device, dtype, unsafe, axis):
-        lengths = torch.tensor([1, 2, 3], device=device)
+        lengths = torch.tensor([1, 2, 3, 0], device=device)
         data = torch.tensor(
             [1, float("nan"), 3, 4, 5, 5],
             device=device,
             dtype=dtype,
             requires_grad=True,
         )
-        expected_result = torch.tensor([1, float("nan"), 5], device=device, dtype=dtype)
+        initial_value = 0
+        expected_result = torch.tensor(
+            [1, float("nan"), 5, initial_value], device=device, dtype=dtype
+        )
         actual_result = torch.segment_reduce(
-            data=data, reduce="max", lengths=lengths, axis=axis, unsafe=unsafe
+            data=data,
+            reduce="max",
+            lengths=lengths,
+            axis=axis,
+            unsafe=unsafe,
+            initial=initial_value,
         )
         self.assertEqual(
             expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
@@ -52,7 +60,12 @@ class TestSegmentReductions(TestCase):
         self.assertTrue(
             gradcheck(
                 lambda x: torch.segment_reduce(
-                    data=x, reduce="max", lengths=lengths, axis=axis, unsafe=unsafe
+                    data=x,
+                    reduce="max",
+                    lengths=lengths,
+                    axis=axis,
+                    unsafe=unsafe,
+                    initial=initial_value,
                 ),
                 (data,),
             )

tools/autograd/derivatives.yaml

@@ -2062,5 +2062,5 @@
 - name: nonzero(Tensor self) -> Tensor
   output_differentiability: [False]
 
-- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False) -> Tensor
+- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
   data: segment_reduce_backward(grad, result, data, lengths)