[torch][segment_reduce] Add support for initial value (#56923)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56923

Next steps, in order:
- Add backward support for CUDA
- Add support for more aggregation types
- Benchmarking (mainly for CUDA), more testing, and documentation
- Support for multiple dimensions

Test Plan: Updated the unit test to include a 0-length segment as well.

Reviewed By: ngimel

Differential Revision: D27992228

fbshipit-source-id: 28851811f8a784a63162721c511d69e617a93727
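For context, a minimal usage sketch of the new initial argument, mirroring the updated unit test in the diff below (the empty fourth segment reduces to the initial value instead of the dtype's lowest representable number):

import torch

# Three non-empty segments plus one empty segment.
lengths = torch.tensor([1, 2, 3, 0])
data = torch.tensor([1.0, float("nan"), 3.0, 4.0, 5.0, 5.0])

# initial seeds every segment's max reduction, so the empty segment
# yields 0, and NaN still propagates within its own segment.
out = torch.segment_reduce(data, "max", lengths=lengths, initial=0)
print(out)  # tensor([1., nan, 5., 0.])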
This commit is contained in: parent bd347012ec, commit 20eac093a7
@@ -17,45 +17,32 @@ Tensor _segment_reduce_cpu_kernel(
     const Tensor& data,
     const Tensor& lengths,
     int64_t axis,
-    bool unsafe) {
-  const auto lengths_contig = lengths.contiguous();
-  const auto data_contig = data.contiguous();
-
-  int64_t batch_size = lengths_contig.numel();
+    const c10::optional<Scalar>& initial) {
+  int64_t batch_size = lengths.numel();
   auto output = at::empty({batch_size}, data.options());
 
-  const auto* lengths_data = lengths_contig.data_ptr<int64_t>();
-  if (!unsafe) {
-    int64_t sum = 0;
-    for (int64_t i = 0; i < batch_size; ++i) {
-      TORCH_CHECK(
-          (lengths_data[i] > 0), "lengths contains non positive value!");
-      sum += lengths_data[i];
-    }
-    TORCH_CHECK(sum == data.numel());
-  }
+  const auto* lengths_data = lengths.data_ptr<int64_t>();
 
   AT_DISPATCH_ALL_TYPES_AND2(
-      kBFloat16,
-      kHalf,
-      data_contig.scalar_type(),
-      "_segment_reduce_cpu",
-      ([&]() {
+      kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", ([&]() {
         auto* output_data = output.data_ptr<scalar_t>();
-        const auto* values_data = data_contig.data_ptr<scalar_t>();
+        const auto* values_data = data.data_ptr<scalar_t>();
         int64_t k = 0;
         for (int64_t i = 0; i < batch_size; ++i) {
-          scalar_t reduction = std::numeric_limits<scalar_t>::lowest();
+          scalar_t initial_value = initial.has_value()
+              ? initial.value().to<scalar_t>()
+              : std::numeric_limits<scalar_t>::lowest();
           for (int64_t j = 0; j < lengths_data[i]; ++j) {
             const auto data = values_data[k];
-            reduction =
-                at::_isnan(data) ? data : std::max<scalar_t>(reduction, data);
+            initial_value = at::_isnan(data)
+                ? data
+                : std::max<scalar_t>(initial_value, data);
             k++;
           }
           // If unsafe is false, check on lengths or indices should cover cases
-          // where lengths for a particular segment is non-positive. If unsafe
-          // is true, simply set to numerical limits for particular reduction
-          output_data[i] = reduction;
+          // where lengths for a particular segment is negative. If unsafe
+          // is true, simply set to initial_value for particular reduction
+          output_data[i] = initial_value;
         }
       }));
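The loop above seeds each segment's running max with initial (falling back to std::numeric_limits::lowest) and propagates NaN, so a zero-length segment emits the seed unchanged. A pure-Python sketch of the same semantics, for illustration only (using -inf in place of the dtype's lowest value):

import math

def segment_max(values, lengths, initial=None):
    # Reference semantics of the CPU kernel above: seed the accumulator,
    # let NaN take over once encountered, and emit the untouched seed
    # for zero-length segments.
    out, k = [], 0
    for length in lengths:
        acc = initial if initial is not None else -math.inf
        for _ in range(length):
            v = values[k]
            acc = v if math.isnan(v) else max(acc, v)
            k += 1
        out.append(acc)
    return out

print(segment_max([1, float("nan"), 3, 4, 5, 5], [1, 2, 3, 0], initial=0))
# [1, nan, 5, 0]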
@@ -93,8 +80,7 @@ Tensor _segment_reduce_cpu_backward_kernel(
           }
           k++;
         }
-        // Average gradient output based on number of maximum in the segment
-        TORCH_INTERNAL_ASSERT(counter > 0);
+        // Average gradient based on number of maximum elements in the segment
         if (counter < 2) {
           continue;
         }
@@ -124,7 +110,8 @@ Tensor segment_reduce_kernel(
     const c10::optional<Tensor>& lengths,
     const c10::optional<Tensor>& indices,
     int64_t axis,
-    bool unsafe) {
+    bool unsafe,
+    const c10::optional<Scalar>& initial) {
   axis = maybe_wrap_dim(axis, data.ndimension());
   TORCH_CHECK(axis == 0, "Currently only dim=0 is supported!");
   TORCH_CHECK(data.dim() == 1);
@@ -142,8 +129,22 @@ Tensor segment_reduce_kernel(
   TORCH_CHECK(data.get_device() == lengths_value.get_device());
   TORCH_CHECK(data.dim() >= lengths_value.dim());
 
+  if (!unsafe) {
+    auto min_length = lengths_value.min().item<int64_t>();
+    TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
+    TORCH_CHECK(min_length != 0 || initial.has_value());
+    TORCH_CHECK(lengths_value.sum().item<int64_t>() == data.numel());
+  }
+
+  const auto data_contig = data.contiguous();
+  const auto lengths_contig = lengths_value.contiguous();
+
   return _segment_reduce_stub(
-      data.device().type(), data, lengths_value, axis, unsafe);
+      data_contig.device().type(),
+      data_contig,
+      lengths_contig,
+      axis,
+      initial);
 }
 
 REGISTER_ARCH_DISPATCH(
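With these checks, unsafe=False now tolerates zero-length segments, but only when an initial value is supplied; negative lengths are always rejected. A sketch of the expected behavior (the exact error text comes from the TORCH_CHECKs above):

import torch

lengths = torch.tensor([1, 0])
data = torch.tensor([2.0])

# Zero-length segment plus an initial value: accepted.
out = torch.segment_reduce(data, "max", lengths=lengths, initial=0)
print(out)  # tensor([2., 0.])

# Without initial, TORCH_CHECK(min_length != 0 || initial.has_value())
# fires and the call raises.
try:
    torch.segment_reduce(data, "max", lengths=lengths)
except RuntimeError as err:
    print("rejected:", err)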
@@ -7,8 +7,11 @@
 namespace at {
 namespace native {
 
-using segment_reduce_fn =
-    Tensor (*)(const Tensor&, const Tensor&, int64_t, bool);
+using segment_reduce_fn = Tensor (*)(
+    const Tensor&,
+    const Tensor&,
+    int64_t,
+    const c10::optional<Scalar>&);
 DECLARE_DISPATCH(segment_reduce_fn, _segment_reduce_stub);
 
 using segment_reduce_backward_fn =
@@ -45,20 +45,11 @@ Tensor _segment_reduce_cuda_kernel(
     const Tensor& data,
     const Tensor& lengths,
     int64_t axis,
-    bool unsafe) {
-  if (!unsafe) {
-    TORCH_CHECK(
-        (lengths.min().item<int64_t>() > 0),
-        "lengths contains non positive value!");
-    TORCH_CHECK(lengths.sum().item<int64_t>() == data.numel());
-  }
-
+    const c10::optional<Scalar>& initial) {
   int64_t segment_count = lengths.numel();
-  const auto data_contig = data.contiguous();
   auto output = at::empty({segment_count}, data.options());
 
-  const auto lengths_contig = lengths.contiguous();
-  auto offsets = _get_complete_sum(lengths_contig);
+  auto offsets = _get_complete_sum(lengths);
   auto* offsets_data_ptr = offsets.data_ptr<int64_t>();
 
   AT_DISPATCH_ALL_TYPES_AND2(
@@ -67,14 +58,16 @@ Tensor _segment_reduce_cuda_kernel(
       data.scalar_type(),
       "segment_reduce_cuda",
       [&]() {
-        auto* data_contig_data_ptr = data_contig.data_ptr<scalar_t>();
+        auto* data_data_ptr = data.data_ptr<scalar_t>();
         auto* output_data_ptr = output.data_ptr<scalar_t>();
 
         CustomMax max_op{};
-        scalar_t initial_value = std::numeric_limits<scalar_t>::lowest();
+        scalar_t initial_value = initial.has_value()
+            ? initial.value().to<scalar_t>()
+            : std::numeric_limits<scalar_t>::lowest();
         CUB_WRAPPER(
             cub::DeviceSegmentedReduce::Reduce,
-            data_contig_data_ptr,
+            data_data_ptr,
             output_data_ptr,
             segment_count,
             offsets_data_ptr,
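On the CUDA side the reduction itself is delegated to cub::DeviceSegmentedReduce, which takes per-segment begin/end offsets; _get_complete_sum evidently builds the complete prefix sum of lengths for that purpose. A sketch of the assumed offset layout:

import torch

lengths = torch.tensor([1, 2, 3, 0])
# offsets[i]..offsets[i+1] delimit segment i; an empty segment has
# begin == end, so CUB writes out the initial value untouched.
offsets = torch.cat([torch.zeros(1, dtype=torch.int64), lengths.cumsum(0)])
print(offsets)  # tensor([0, 1, 3, 6, 6])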
@@ -9003,7 +9003,7 @@
   cpp_no_default_args: ['a', 'b']
   python_module: nn
 
-- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False) -> Tensor
+- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
   variants: function
   dispatch:
     CPU, CUDA: segment_reduce_kernel
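Because the schema defaults initial to None, existing call sites are unaffected; the argument only becomes necessary (under unsafe=False) once a zero-length segment appears. For example, a pre-existing call pattern still works:

import torch

out = torch.segment_reduce(
    torch.tensor([1.0, 2.0, 3.0]),
    "max",
    lengths=torch.tensor([1, 2]),
)
print(out)  # tensor([1., 3.])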
@@ -13,16 +13,24 @@ from torch.testing._internal.common_utils import (
 
 class TestSegmentReductions(TestCase):
     def _test_max_simple_1d(self, device, dtype, unsafe, axis):
-        lengths = torch.tensor([1, 2, 3], device=device)
+        lengths = torch.tensor([1, 2, 3, 0], device=device)
         data = torch.tensor(
             [1, float("nan"), 3, 4, 5, 5],
             device=device,
             dtype=dtype,
             requires_grad=True,
         )
-        expected_result = torch.tensor([1, float("nan"), 5], device=device, dtype=dtype)
+        initial_value = 0
+        expected_result = torch.tensor(
+            [1, float("nan"), 5, initial_value], device=device, dtype=dtype
+        )
         actual_result = torch.segment_reduce(
-            data=data, reduce="max", lengths=lengths, axis=axis, unsafe=unsafe
+            data=data,
+            reduce="max",
+            lengths=lengths,
+            axis=axis,
+            unsafe=unsafe,
+            initial=initial_value,
         )
         self.assertEqual(
             expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
@@ -52,7 +60,12 @@ class TestSegmentReductions(TestCase):
         self.assertTrue(
             gradcheck(
                 lambda x: torch.segment_reduce(
-                    data=x, reduce="max", lengths=lengths, axis=axis, unsafe=unsafe
+                    data=x,
+                    reduce="max",
+                    lengths=lengths,
+                    axis=axis,
+                    unsafe=unsafe,
+                    initial=initial_value,
                 ),
                 (data,),
             )
@@ -2062,5 +2062,5 @@
 - name: nonzero(Tensor self) -> Tensor
   output_differentiability: [False]
 
-- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False) -> Tensor
+- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
   data: segment_reduce_backward(grad, result, data, lengths)
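The backward entry is unchanged apart from the schema; as the comment in the CPU backward kernel notes, the incoming gradient is split evenly across tied maxima within a segment. A quick numeric check of that behavior (assuming the CPU backward path, since CUDA backward is still listed under next steps):

import torch

data = torch.tensor([4.0, 5.0, 5.0], requires_grad=True)
out = torch.segment_reduce(data, "max", lengths=torch.tensor([3]), initial=0)
out.sum().backward()
print(data.grad)  # expected: tensor([0.0000, 0.5000, 0.5000])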