Reland: Add offsets-based reduction to segment_reduce (CPU, CUDA)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79725

Approved by: https://github.com/george-qi
This commit is contained in:
Mikayla Gawarecki 2022-06-16 17:32:10 +00:00 committed by PyTorch MergeBot
parent 64f3742b2b
commit 7360b53ff3
9 changed files with 491 additions and 175 deletions
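For context, a hedged sketch of the user-facing change (values made up for illustration): lengths gives per-segment sizes, while the new offsets argument gives cumulative segment boundaries of length segment_count + 1; both describe the same segmentation and should reduce to the same result.

import torch

data = torch.tensor([1., 2., 3., 4., 5., 6.])
lengths = torch.tensor([2, 3, 1])        # per-segment sizes
offsets = torch.tensor([0, 2, 5, 6])     # boundaries; offsets.diff() == lengths

out_lengths = torch.segment_reduce(data, "sum", lengths=lengths)   # existing path
out_offsets = torch.segment_reduce(data, "sum", offsets=offsets)   # path added by this PR
# both: tensor([ 3., 12.,  6.])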

View File

@@ -8,8 +8,10 @@
namespace at {
namespace native {
-DEFINE_DISPATCH(_segment_reduce_stub);
-DEFINE_DISPATCH(_segment_reduce_backward_stub);
+DEFINE_DISPATCH(_segment_reduce_lengths_stub);
+DEFINE_DISPATCH(_segment_reduce_offsets_stub);
+DEFINE_DISPATCH(_segment_reduce_lengths_backward_stub);
+DEFINE_DISPATCH(_segment_reduce_offsets_backward_stub);
namespace {
@@ -29,8 +31,8 @@ SegmentReductionType get_reduction_enum(const c10::string_view& reduce) {
}
}
-template <typename T>
-void _segment_reduce_cpu_kernel1(
+template <typename T, bool is_offsets_like=false>
+void _segment_reduce_lengths_cpu_kernel1(
SegmentReductionType reduction,
const Tensor& data,
const T* lengths_data,
@@ -46,14 +48,30 @@ void _segment_reduce_cpu_kernel1(
outer_offset *= output.size(d);
for (int64_t d = axis + 1; d < output.dim(); d++)
inner_offset *= output.size(d);
int64_t lengths_size_axis = is_offsets_like ? segment_count + 1 : segment_count;
auto data_stride_axis = data.stride(axis);
auto data_size_axis = data.size(axis);
auto output_stride_axis = output.stride(axis);
auto output_size_axis = output.size(axis);
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", [&]() {
auto* output_data = output.data_ptr<scalar_t>();
const auto* values_data = data.data_ptr<scalar_t>();
for (const auto outer_idx : c10::irange(outer_offset)) {
-int64_t lengths_cum_sum = 0;
+int64_t segment_start, segment_length;
+int64_t segment_end = is_offsets_like ?
+lengths_data[outer_idx * lengths_stride_axis * lengths_size_axis] :
+0;
for (const auto dim_idx : c10::irange(segment_count)) {
-int64_t segment_length = lengths_data[outer_idx * lengths_stride_axis * segment_count + dim_idx];
+segment_start = segment_end;
+auto lengths_idx = outer_idx * lengths_stride_axis * lengths_size_axis + dim_idx;
+if (is_offsets_like) {
+segment_end = lengths_data[lengths_idx + 1];
+segment_length = segment_end - segment_start;
+} else {
+segment_length = lengths_data[lengths_idx];
+segment_end += segment_length;
+}
for (const auto inner_idx : c10::irange(inner_offset)) {
// ===== step1: initialize starting value
scalar_t initial_value;
@@ -72,9 +90,9 @@ void _segment_reduce_cpu_kernel1(
}
// ===== step2: apply reduction
-for (const auto j : c10::irange(segment_length)) {
-int64_t data_index = outer_idx * data.stride(axis) * data.size(axis)
-+ (lengths_cum_sum + j) * data.stride(axis) + inner_idx;
+for (const auto j : c10::irange(segment_start, segment_end)) {
+int64_t data_index = outer_idx * data_stride_axis * data_size_axis
++ j * data_stride_axis + inner_idx;
const auto val = values_data[data_index];
if (reduction == SegmentReductionType::MAX) {
initial_value = at::_isnan(val)
@@ -104,17 +122,16 @@ void _segment_reduce_cpu_kernel1(
segment_length > 0 && !at::_isnan(initial_value)) {
initial_value = initial_value / segment_length;
}
-int64_t output_index = outer_idx * output.stride(axis) * output.size(axis)
-+ dim_idx * output.stride(axis) + inner_idx;
+int64_t output_index = outer_idx * output_stride_axis * output_size_axis
++ dim_idx * output_stride_axis + inner_idx;
output_data[output_index] = initial_value;
}
-lengths_cum_sum += segment_length;
}
}
});
}
-Tensor _segment_reduce_cpu_kernel(
+Tensor _segment_reduce_lengths_cpu_kernel(
SegmentReductionType reduction,
const Tensor& data,
const Tensor& lengths,
@@ -131,17 +148,43 @@ Tensor _segment_reduce_cpu_kernel(
output_shape[axis] = segment_count;
auto output = at::empty(output_shape, data.options());
-AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "_segment_reduce_cpu_kernel1", [&]() {
+AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "_segment_reduce_lengths_cpu_kernel1", [&]() {
const auto* lengths_data = lengths.data_ptr<index_t>();
-_segment_reduce_cpu_kernel1(
+_segment_reduce_lengths_cpu_kernel1(
reduction, data, lengths_data, axis, initial, output, segment_count, lengths_stride_axis);
});
return output;
}
-template <typename T>
-void _segment_reduce_cpu_backward_kernel1(
+Tensor _segment_reduce_offsets_cpu_kernel(
+SegmentReductionType reduction,
const Tensor& data,
const Tensor& offsets,
int64_t axis,
const c10::optional<Scalar>& initial) {
// data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
TORCH_CHECK(offsets.is_contiguous(), "Expected offsets to be contiguous.");
// reduction axis should always be the last dimension of lengths
axis = offsets.dim() - 1;
int64_t segment_count = offsets.size(axis) - 1;
int64_t offsets_stride_axis = offsets.stride(axis);
auto output_shape = data.sizes().vec();
output_shape[axis] = segment_count;
auto output = at::empty(output_shape, data.options());
AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_segment_reduce_offsets_cpu_kernel1", [&]() {
const auto* offsets_data = offsets.data_ptr<index_t>();
_segment_reduce_lengths_cpu_kernel1<index_t, /*is_offsets_like=*/true>(
reduction, data, offsets_data, axis, initial, output, segment_count, offsets_stride_axis);
});
return output;
}
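For intuition, a small Python sketch (not part of the diff) of the boundary logic selected by the is_offsets_like template flag in _segment_reduce_lengths_cpu_kernel1: with lengths, segment_end is a running sum; with offsets, consecutive entries already give [segment_start, segment_end).

def segment_bounds(lengths_or_offsets, is_offsets_like):
    # mirrors the kernel: seed segment_end with the first offset (or 0 for lengths)
    end = lengths_or_offsets[0] if is_offsets_like else 0
    n = len(lengths_or_offsets) - 1 if is_offsets_like else len(lengths_or_offsets)
    for i in range(n):
        start = end
        if is_offsets_like:
            end = lengths_or_offsets[i + 1]
        else:
            end = end + lengths_or_offsets[i]
        yield start, end

print(list(segment_bounds([2, 3, 1], False)))     # [(0, 2), (2, 5), (5, 6)]
print(list(segment_bounds([0, 2, 5, 6], True)))   # [(0, 2), (2, 5), (5, 6)]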
template <typename T, bool is_offsets_like = false>
void _segment_reduce_cpu_lengths_backward_kernel1(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
@@ -159,7 +202,12 @@ void _segment_reduce_cpu_backward_kernel1(
outer_offset *= output_contig.size(d);
for (int64_t d = axis + 1; d < output_contig.dim(); d++)
inner_offset *= output_contig.size(d);
-// TODO: Swtich to TensorIterator for better maintainablility and
+int64_t lengths_size_axis = is_offsets_like ? segment_count + 1 : segment_count;
+auto data_stride_axis = data_contig.stride(axis);
+auto data_size_axis = data_contig.size(axis);
+auto output_stride_axis = output_contig.stride(axis);
+auto output_size_axis = output_contig.size(axis);
+// TODO: Switch to TensorIterator for better maintainablility and
// readability
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16,
@ -182,21 +230,34 @@ void _segment_reduce_cpu_backward_kernel1(
} }
for (const auto outer_idx : c10::irange(outer_offset)) { for (const auto outer_idx : c10::irange(outer_offset)) {
int64_t lengths_cum_sum = 0; // int64_t lengths_cum_sum = 0;
int64_t segment_start, segment_length;
int64_t segment_end = is_offsets_like ?
lengths_data[outer_idx * lengths_stride_axis * lengths_size_axis] :
0;
for (const auto dim_idx : c10::irange(segment_count)) { for (const auto dim_idx : c10::irange(segment_count)) {
int64_t segment_length = lengths_data[outer_idx * lengths_stride_axis * segment_count + dim_idx]; // int64_t segment_length = lengths_data[outer_idx * lengths_stride_axis * segment_count + dim_idx];
segment_start = segment_end;
auto lengths_idx = outer_idx * lengths_stride_axis * lengths_size_axis + dim_idx;
if (is_offsets_like) {
segment_end = lengths_data[lengths_idx + 1];
segment_length = segment_end - segment_start;
} else {
segment_length = lengths_data[lengths_idx];
segment_end += segment_length;
}
if (segment_length == 0) { if (segment_length == 0) {
continue; continue;
} }
for (const auto inner_idx : c10::irange(inner_offset)) { for (const auto inner_idx : c10::irange(inner_offset)) {
int64_t output_index = outer_idx * output_contig.stride(axis) * output_contig.size(axis) int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+ dim_idx * output_contig.stride(axis) + inner_idx; + dim_idx * output_stride_axis + inner_idx;
if (reduction == SegmentReductionType::MAX || if (reduction == SegmentReductionType::MAX ||
reduction == SegmentReductionType::MIN) { reduction == SegmentReductionType::MIN) {
int64_t counter = 0; int64_t counter = 0;
for (const auto j : c10::irange(segment_length)) { for (const auto j : c10::irange(segment_start, segment_end)) {
int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis) int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx; + j * data_stride_axis + inner_idx;
if (at::_isnan(values_data[data_index]) || if (at::_isnan(values_data[data_index]) ||
values_data[data_index] == output_data[output_index]) { values_data[data_index] == output_data[output_index]) {
grad_input_data[data_index] = grad_data[output_index]; grad_input_data[data_index] = grad_data[output_index];
@ -208,9 +269,9 @@ void _segment_reduce_cpu_backward_kernel1(
if (counter < 2) { if (counter < 2) {
continue; continue;
} }
for (const auto j : c10::irange(segment_length)) { for (const auto j : c10::irange(segment_start, segment_end)) {
int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis) int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx; + j * data_stride_axis + inner_idx;
if (grad_input_data[data_index] > 0) { if (grad_input_data[data_index] > 0) {
grad_input_data[data_index] = grad_input_data[data_index] =
grad_input_data[data_index] / counter; grad_input_data[data_index] / counter;
@ -218,32 +279,32 @@ void _segment_reduce_cpu_backward_kernel1(
} }
} else if (reduction == SegmentReductionType::MEAN) { } else if (reduction == SegmentReductionType::MEAN) {
auto grad_val = grad_data[output_index] / segment_length; auto grad_val = grad_data[output_index] / segment_length;
for (const auto j : c10::irange(segment_length)) { for (const auto j : c10::irange(segment_start, segment_end)) {
int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis) int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx; + j * data_stride_axis + inner_idx;
grad_input_data[data_index] = grad_val; grad_input_data[data_index] = grad_val;
} }
} else if (reduction == SegmentReductionType::SUM) { } else if (reduction == SegmentReductionType::SUM) {
const auto& grad_val = grad_data[output_index]; const auto& grad_val = grad_data[output_index];
for (const auto j : c10::irange(segment_length)) { for (const auto j : c10::irange(segment_start, segment_end)) {
int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis) int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx; + j * data_stride_axis + inner_idx;
grad_input_data[data_index] = grad_val; grad_input_data[data_index] = grad_val;
} }
} else if (reduction == SegmentReductionType::PROD) { } else if (reduction == SegmentReductionType::PROD) {
const auto& grad_val = grad_data[output_index] * output_data[output_index]; const auto& grad_val = grad_data[output_index] * output_data[output_index];
for (const auto j : c10::irange(segment_length)) { for (const auto j : c10::irange(segment_start, segment_end)) {
int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis) int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx; + j * data_stride_axis + inner_idx;
if (at::_isnan(values_data[data_index]) || if (at::_isnan(values_data[data_index]) ||
values_data[data_index] == 0) { values_data[data_index] == 0) {
// explicitly compute exclusive prod // explicitly compute exclusive prod
scalar_t exclusive_prod = initial_prod_value; scalar_t exclusive_prod = initial_prod_value;
int64_t idx; int64_t idx;
for (const auto k : c10::irange(segment_length)) { for (const auto k : c10::irange(segment_start, segment_end)) {
if (k != j) { if (k != j) {
idx = outer_idx * data_contig.stride(axis) * data_contig.size(axis) idx = outer_idx * data_stride_axis * data_size_axis
+ (lengths_cum_sum + k) * data_contig.stride(axis) + inner_idx; + k * data_stride_axis + inner_idx;
exclusive_prod *= values_data[idx]; exclusive_prod *= values_data[idx];
} }
} }
@ -254,13 +315,12 @@ void _segment_reduce_cpu_backward_kernel1(
} }
} }
} }
lengths_cum_sum += segment_length;
} }
} }
}); });
} }
Tensor _segment_reduce_cpu_backward_kernel( Tensor _segment_reduce_cpu_lengths_backward_kernel(
const Tensor& grad_contig, const Tensor& grad_contig,
const Tensor& output_contig, const Tensor& output_contig,
const Tensor& data_contig, const Tensor& data_contig,
@ -274,9 +334,9 @@ Tensor _segment_reduce_cpu_backward_kernel(
auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options()); auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
AT_DISPATCH_INDEX_TYPES( AT_DISPATCH_INDEX_TYPES(
lengths_contig.scalar_type(), "_segment_reduce_cpu_backward_kernel1", [&] { lengths_contig.scalar_type(), "_segment_reduce_cpu_lengths_backward_kernel1", [&] {
const auto* lengths_data = lengths_contig.data_ptr<index_t>(); const auto* lengths_data = lengths_contig.data_ptr<index_t>();
_segment_reduce_cpu_backward_kernel1( _segment_reduce_cpu_lengths_backward_kernel1(
grad_contig, grad_contig,
output_contig, output_contig,
data_contig, data_contig,
@ -292,6 +352,39 @@ Tensor _segment_reduce_cpu_backward_kernel(
return grad_input; return grad_input;
} }
Tensor _segment_reduce_cpu_offsets_backward_kernel(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
SegmentReductionType reduction,
const Tensor& offsets_contig,
int64_t axis,
const c10::optional<Scalar>& initial) {
axis = offsets_contig.dim() - 1;
int64_t segment_count = offsets_contig.size(axis) - 1;
int64_t offsets_stride_axis = offsets_contig.stride(axis);
auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
AT_DISPATCH_INDEX_TYPES(
offsets_contig.scalar_type(), "_segment_reduce_cpu_offsets_backward_kernel1", [&] {
const auto* offsets_data = offsets_contig.data_ptr<index_t>();
_segment_reduce_cpu_lengths_backward_kernel1<index_t, /*is_offsets_like=*/true>(
grad_contig,
output_contig,
data_contig,
reduction,
offsets_data,
axis,
initial,
grad_input,
segment_count,
offsets_stride_axis);
});
return grad_input;
}
} // namespace
Tensor segment_reduce_kernel(
@@ -299,49 +392,94 @@ Tensor segment_reduce_kernel(
c10::string_view reduce,
const c10::optional<Tensor>& lengths,
const c10::optional<Tensor>& indices,
+const c10::optional<Tensor>& offsets,
int64_t axis,
bool unsafe,
const c10::optional<Scalar>& initial) {
axis = maybe_wrap_dim(axis, data.ndimension());
TORCH_CHECK(data.numel() > 0);
-// length related checks
+// check that one of lengths or offsets is defined
+auto lengths_has_value = lengths.has_value();
+auto offsets_has_value = offsets.has_value();
TORCH_CHECK(
-lengths.has_value() && !indices.has_value(),
-"Currently only lengths based reduction is supported!")
+!indices.has_value(),
+"segment_reduce(): indices based reduction is not supported yet.");
+TORCH_CHECK(
+lengths_has_value || offsets_has_value,
+"segment_reduce(): Either lengths or offsets must be defined.")
-const auto& lengths_value = lengths.value();
-TORCH_CHECK(data.get_device() == lengths_value.get_device());
-TORCH_CHECK(data.dim() >= lengths_value.dim());
-TORCH_CHECK(axis == lengths_value.dim() - 1, "Expected axis to be equal to lengths.ndim() - 1 but got ", axis, ".");
-if (!unsafe) {
-auto min_length = lengths_value.min().item<int64_t>();
-TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
-TORCH_CHECK(all(lengths_value.sum({-1}) == data.size(axis)).item<bool>(),
-"Expected all rows of lengths to sum to data.size(lengths.dim()-1) when unsafe=False");
-}
auto reduction = get_reduction_enum(reduce);
const auto data_contig = data.contiguous();
-const auto lengths_contig = lengths_value.contiguous();
-return _segment_reduce_stub(
+if (offsets_has_value) {
+const auto& offsets_value = offsets.value();
+// offsets related checks
+TORCH_CHECK(data.get_device() == offsets_value.get_device());
+TORCH_CHECK(data.dim() >= offsets_value.dim());
+TORCH_CHECK(axis == offsets_value.dim() - 1,
+"segment_reduce(): Expected axis to be the last dimension of offsets but got ", axis, ".");
+// TODO: add checks when !unsafe
+const auto offsets_contig = offsets_value.contiguous();
+return _segment_reduce_offsets_stub(
+data_contig.device().type(),
+reduction,
+data_contig,
+offsets_contig,
+axis,
+initial);
+} else {
+const auto& lengths_value = lengths.value();
+// length related checks
+TORCH_CHECK(data.get_device() == lengths_value.get_device());
+TORCH_CHECK(data.dim() >= lengths_value.dim());
+TORCH_CHECK(axis == lengths_value.dim() - 1,
+"segment_reduce(): Expected axis to be the last dimension of lengths but got ", axis, ".");
+if (!unsafe) {
+auto min_length = lengths_value.min().item<int64_t>();
+TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
+TORCH_CHECK(all(lengths_value.sum({-1}) == data.size(axis)).item<bool>(),
+"segment_reduce(): Expected all rows of lengths along axis ",
+"to sum to data.size(lengths.dim()-1) when !unsafe.");
+}
+const auto lengths_contig = lengths_value.contiguous();
+return _segment_reduce_lengths_stub(
data_contig.device().type(),
reduction,
data_contig,
lengths_contig,
axis,
initial);
+}
}
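As the checks above show, the lengths path still validates under unsafe=False that lengths is non-negative and that each row sums to data.size(axis), while the new offsets path currently skips validation (see the TODO). A hedged illustration of the expected behavior:

import torch

data = torch.arange(6, dtype=torch.float)
bad_lengths = torch.tensor([2, 3])   # sums to 5, but data.size(0) == 6

try:
    torch.segment_reduce(data, "sum", lengths=bad_lengths, unsafe=False)
except RuntimeError:
    print("lengths that do not cover data are rejected when unsafe=False")

# unsafe=True trusts the caller and simply never reads the trailing element
print(torch.segment_reduce(data, "sum", lengths=bad_lengths, unsafe=True))   # tensor([1., 9.])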
REGISTER_ARCH_DISPATCH(
-_segment_reduce_stub,
+_segment_reduce_lengths_stub,
DEFAULT,
-&_segment_reduce_cpu_kernel);
+&_segment_reduce_lengths_cpu_kernel);
-REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
-REGISTER_AVX512_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
-REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
-REGISTER_ZVECTOR_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
+REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
+REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
+REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
+REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
// offsets dispatches
REGISTER_ARCH_DISPATCH(
_segment_reduce_offsets_stub,
DEFAULT,
&_segment_reduce_offsets_cpu_kernel);
REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
// Currently some computation is being duplicated across forward and backward. // Currently some computation is being duplicated across forward and backward.
// TODO: Cache indices in forward pass to re-use in backward // TODO: Cache indices in forward pass to re-use in backward
@@ -351,21 +489,40 @@ Tensor _segment_reduce_backward_kernel(
const Tensor& data,
c10::string_view reduce,
const c10::optional<Tensor>& lengths,
+const c10::optional<Tensor>& offsets,
int64_t axis,
const c10::optional<Scalar>& initial) {
axis = maybe_wrap_dim(axis, data.ndimension());
+// check that one of lengths or offsets is defined
+// codegen for derivatives.yaml passes an undefined Tensor for None rather than a c10::optional
+// so checking .has_value() doesn't work unlike in the forward pass
+auto lengths_has_value = lengths.has_value() && lengths.value().defined();
+auto offsets_has_value = offsets.has_value() && offsets.value().defined();
TORCH_CHECK(
-lengths.has_value(),
-"Currently only lengths based reduction is supported!")
+lengths_has_value || offsets_has_value,
+"segment_reduce(): Either lengths or offsets must be defined.");
-const auto& lengths_value = lengths.value();
const auto grad_contig = grad.contiguous();
const auto output_contig = output.contiguous();
const auto data_contig = data.contiguous();
-const auto lengths_contig = lengths_value.contiguous();
auto reduction = get_reduction_enum(reduce);
-return _segment_reduce_backward_stub(
+if (offsets_has_value) {
+const auto& offsets_value = offsets.value();
+const auto offsets_contig = offsets_value.contiguous();
+return _segment_reduce_offsets_backward_stub(
+grad_contig.device().type(),
+grad_contig,
+output_contig,
+data_contig,
+reduction,
+offsets_contig,
+axis,
+initial);
+} else {
+const auto& lengths_value = lengths.value();
+const auto lengths_contig = lengths_value.contiguous();
+return _segment_reduce_lengths_backward_stub(
grad_contig.device().type(),
grad_contig,
output_contig,
@@ -374,24 +531,42 @@ Tensor _segment_reduce_backward_kernel(
lengths_contig,
axis,
initial);
+}
} }
REGISTER_ARCH_DISPATCH(
-_segment_reduce_backward_stub,
+_segment_reduce_lengths_backward_stub,
DEFAULT,
-&_segment_reduce_cpu_backward_kernel);
+&_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_AVX512_DISPATCH(
-_segment_reduce_backward_stub,
-&_segment_reduce_cpu_backward_kernel);
+_segment_reduce_lengths_backward_stub,
+&_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_AVX2_DISPATCH(
-_segment_reduce_backward_stub,
-&_segment_reduce_cpu_backward_kernel);
+_segment_reduce_lengths_backward_stub,
+&_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_VSX_DISPATCH(
-_segment_reduce_backward_stub,
-&_segment_reduce_cpu_backward_kernel);
+_segment_reduce_lengths_backward_stub,
+&_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_ZVECTOR_DISPATCH(
-_segment_reduce_backward_stub,
-&_segment_reduce_cpu_backward_kernel);
+_segment_reduce_lengths_backward_stub,
+&_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_ARCH_DISPATCH(
_segment_reduce_offsets_backward_stub,
DEFAULT,
&_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_AVX512_DISPATCH(
_segment_reduce_offsets_backward_stub,
&_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_AVX2_DISPATCH(
_segment_reduce_offsets_backward_stub,
&_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_VSX_DISPATCH(
_segment_reduce_offsets_backward_stub,
&_segment_reduce_cpu_offsets_backward_kernel);
REGISTER_ZVECTOR_DISPATCH(
_segment_reduce_offsets_backward_stub,
&_segment_reduce_cpu_offsets_backward_kernel);
} // namespace native } // namespace native
} // namespace at } // namespace at

View File

@@ -11,15 +11,23 @@ namespace native {
enum SegmentReductionType { MAX, MEAN, MIN, SUM, PROD};
-using segment_reduce_fn = Tensor (*)(
+using segment_reduce_lengths_fn = Tensor (*)(
SegmentReductionType,
const Tensor&,
const Tensor&,
int64_t,
const c10::optional<Scalar>&);
-DECLARE_DISPATCH(segment_reduce_fn, _segment_reduce_stub);
+DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
-using segment_reduce_backward_fn = Tensor (*)(
+using segment_reduce_offsets_fn = Tensor (*)(
+SegmentReductionType,
+const Tensor&,
+const Tensor&,
+int64_t,
+const c10::optional<Scalar>&);
+DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub);
+using segment_reduce_lengths_backward_fn = Tensor (*)(
const Tensor&,
const Tensor&,
const Tensor&,
@@ -27,7 +35,17 @@ using segment_reduce_backward_fn = Tensor (*)(
const Tensor&,
int64_t,
const c10::optional<Scalar>&);
-DECLARE_DISPATCH(segment_reduce_backward_fn, _segment_reduce_backward_stub);
+DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub);
using segment_reduce_offsets_backward_fn = Tensor (*)(
const Tensor&,
const Tensor&,
const Tensor&,
SegmentReductionType,
const Tensor&,
int64_t,
const c10::optional<Scalar>&);
DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub);
} // namespace native
} // namespace at

View File

@@ -70,7 +70,7 @@ Tensor _get_complete_sum(const Tensor& lengths) {
offsets[0].zero_();
AT_DISPATCH_INDEX_TYPES(
-lengths.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+lengths.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
auto* lengths_data_ptr = lengths.data_ptr<index_t>();
auto* offsets_data_ptr = offsets.data_ptr<index_t>();
at::cuda::cub::inclusive_sum(
@@ -278,23 +278,33 @@ __global__ void segment_reduce_backward_kernel(
}
} // namespace
-Tensor _segment_reduce_cuda_backward_kernel(
+Tensor _segment_reduce_lengths_offsets_backward_cuda_kernel(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
SegmentReductionType reduction,
-const Tensor& lengths_contig,
+const Tensor& lengths_or_offsets_contig,
int64_t axis,
-const c10::optional<Scalar>& initial) {
-axis = lengths_contig.dim() - 1;
-int64_t segment_count = lengths_contig.size(axis);
-int64_t lengths_stride_axis = lengths_contig.stride(axis);
+const c10::optional<Scalar>& initial,
+bool is_offsets_like) {
+axis = lengths_or_offsets_contig.dim() - 1;
+int64_t segment_count = is_offsets_like ?
+lengths_or_offsets_contig.size(axis) - 1 :
+lengths_or_offsets_contig.size(axis);
+int64_t lengths_stride_axis = lengths_or_offsets_contig.stride(axis);
auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
-auto zeros_shape = lengths_contig.sizes().vec();
-zeros_shape[axis] = 1;
-auto offsets = at::cat({at::zeros(zeros_shape, lengths_contig.options()), lengths_contig}, axis);
-offsets.cumsum_(axis);
+auto offsets = lengths_or_offsets_contig;
+auto lengths = lengths_or_offsets_contig;
+if (is_offsets_like) {
+lengths = lengths.diff();
+} else {
+// _get_complete_sum only supports 1D
+auto zeros_shape = offsets.sizes().vec();
+zeros_shape[axis] = 1;
+offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
+offsets.cumsum_(axis);
+}
// outer_offset is the size of the outer dimensions of output (before axis) // outer_offset is the size of the outer dimensions of output (before axis)
// inner_offset is the size of the inner dimensions of output (after axis) // inner_offset is the size of the inner dimensions of output (after axis)
@ -318,8 +328,8 @@ Tensor _segment_reduce_cuda_backward_kernel(
auto offsets_stride_axis = offsets.stride(axis); auto offsets_stride_axis = offsets.stride(axis);
AT_DISPATCH_INDEX_TYPES( AT_DISPATCH_INDEX_TYPES(
lengths_contig.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] { lengths_or_offsets_contig.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
const auto* lengths_data = lengths_contig.data_ptr<index_t>(); const auto* lengths_data = lengths.data_ptr<index_t>();
auto* offsets_data = offsets.data_ptr<index_t>(); auto* offsets_data = offsets.data_ptr<index_t>();
// TODO: Switch to TensorIterator for better maintainablility and // TODO: Switch to TensorIterator for better maintainablility and
@ -371,27 +381,59 @@ Tensor _segment_reduce_cuda_backward_kernel(
return grad_input; return grad_input;
} }
-Tensor _segment_reduce_cuda_kernel(
-SegmentReductionType reduction,
-const Tensor& data,
-const Tensor& lengths,
-int64_t axis,
-const c10::optional<Scalar>& initial) {
-// data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
-TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
-TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous.");
-axis = lengths.dim() - 1;
-int64_t segment_count = lengths.size(axis);
-int64_t lengths_stride_axis = lengths.stride(axis);
+Tensor _segment_reduce_lengths_backward_cuda_kernel(
+const Tensor& grad_contig,
+const Tensor& output_contig,
+const Tensor& data_contig,
+SegmentReductionType reduction,
+const Tensor& lengths_contig,
+int64_t axis,
+const c10::optional<Scalar>& initial) {
+return _segment_reduce_lengths_offsets_backward_cuda_kernel(
+grad_contig, output_contig, data_contig, reduction, lengths_contig, axis, initial, /*is_offsets_like=*/false);
+}
+Tensor _segment_reduce_offsets_backward_cuda_kernel(
+const Tensor& grad_contig,
+const Tensor& output_contig,
+const Tensor& data_contig,
+SegmentReductionType reduction,
+const Tensor& offsets_contig,
+int64_t axis,
+const c10::optional<Scalar>& initial) {
+return _segment_reduce_lengths_offsets_backward_cuda_kernel(
+grad_contig, output_contig, data_contig, reduction, offsets_contig, axis, initial, /*is_offsets_like=*/true);
+}
Tensor _segment_reduce_lengths_offsets_cuda_kernel(
SegmentReductionType reduction,
const Tensor& data,
const Tensor& lengths_or_offsets,
int64_t axis,
const c10::optional<Scalar>& initial,
bool is_offsets_like) {
// data and lengths_or_offsets should be contiguous from the call to .contiguous in segment_reduce_kernel
TORCH_CHECK(data.is_contiguous());
TORCH_CHECK(lengths_or_offsets.is_contiguous());
axis = lengths_or_offsets.dim() - 1;
int64_t segment_count = is_offsets_like ? lengths_or_offsets.size(axis) - 1 : lengths_or_offsets.size(axis);
int64_t lengths_stride_axis = lengths_or_offsets.stride(axis);
auto output_shape = data.sizes().vec(); auto output_shape = data.sizes().vec();
output_shape[axis] = segment_count; output_shape[axis] = segment_count;
auto output = at::empty(output_shape, data.options()); auto output = at::empty(output_shape, data.options());
// _get_complete_sum only supports 1D?
auto zeros_shape = lengths.sizes().vec(); auto offsets = lengths_or_offsets;
zeros_shape[axis] = 1; auto lengths = lengths_or_offsets;
auto offsets = at::cat({at::zeros(zeros_shape, lengths.options()), lengths}, axis); if (is_offsets_like) {
offsets.cumsum_(axis); lengths = lengths.diff();
} else {
// _get_complete_sum only supports 1D
auto zeros_shape = offsets.sizes().vec();
zeros_shape[axis] = 1;
offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
offsets.cumsum_(axis);
}
// outer_offset is the size of the outer dimensions of output (before axis) // outer_offset is the size of the outer dimensions of output (before axis)
// inner_offset is the size of the inner dimensions of output (after axis) // inner_offset is the size of the inner dimensions of output (after axis)
@ -416,7 +458,7 @@ Tensor _segment_reduce_cuda_kernel(
auto offsets_stride_axis = offsets.stride(axis); auto offsets_stride_axis = offsets.stride(axis);
AT_DISPATCH_INDEX_TYPES( AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] { lengths_or_offsets.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
auto* offsets_data_ptr = offsets.data_ptr<index_t>(); auto* offsets_data_ptr = offsets.data_ptr<index_t>();
auto* lengths_data_ptr = lengths.data_ptr<index_t>(); auto* lengths_data_ptr = lengths.data_ptr<index_t>();
AT_DISPATCH_FLOATING_TYPES_AND2( AT_DISPATCH_FLOATING_TYPES_AND2(
@ -549,10 +591,34 @@ Tensor _segment_reduce_cuda_kernel(
return output; return output;
} }
-REGISTER_DISPATCH(_segment_reduce_stub, &_segment_reduce_cuda_kernel);
+Tensor _segment_reduce_lengths_cuda_kernel(
SegmentReductionType reduction,
const Tensor& data,
const Tensor& lengths,
int64_t axis,
const c10::optional<Scalar>& initial) {
return _segment_reduce_lengths_offsets_cuda_kernel(
reduction, data, lengths, axis, initial, /*is_offsets_like=*/false);
}
Tensor _segment_reduce_offsets_cuda_kernel(
SegmentReductionType reduction,
const Tensor& data,
const Tensor& offsets,
int64_t axis,
const c10::optional<Scalar>& initial) {
return _segment_reduce_lengths_offsets_cuda_kernel(
reduction, data, offsets, axis, initial, /*is_offsets_like=*/true);
}
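The CUDA path normalizes both inputs into a (lengths, offsets) pair so that one kernel serves both stubs: offsets.diff() recovers lengths, and a zero-prepended cumulative sum of lengths recovers offsets. A small sketch of that equivalence (illustration only, mirroring the cat/cumsum_ and diff() calls above):

import torch

lengths = torch.tensor([2, 3, 1])
offsets = torch.tensor([0, 2, 5, 6])

assert torch.equal(offsets.diff(), lengths)                     # offsets -> lengths
assert torch.equal(
    torch.cat([lengths.new_zeros(1), lengths]).cumsum(0),       # lengths -> offsets
    offsets)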
REGISTER_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cuda_kernel);
REGISTER_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cuda_kernel);
REGISTER_DISPATCH(
-_segment_reduce_backward_stub,
-&_segment_reduce_cuda_backward_kernel);
+_segment_reduce_lengths_backward_stub,
+&_segment_reduce_lengths_backward_cuda_kernel);
REGISTER_DISPATCH(
_segment_reduce_offsets_backward_stub,
&_segment_reduce_offsets_backward_cuda_kernel);
} // namespace native } // namespace native
} // namespace at } // namespace at

View File

@@ -11924,12 +11924,12 @@
  dispatch:
    CompositeExplicitAutograd: _test_warn_in_autograd
-- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
+- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
  variants: function
  dispatch:
    CPU, CUDA: segment_reduce_kernel
-- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, int axis=0, Scalar? initial=None) -> Tensor
+- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor
  variants: function
  dispatch:
    CPU, CUDA: _segment_reduce_backward_kernel

View File

@@ -87,8 +87,8 @@ ALLOW_LIST = [
("prim::infer_squeeze_size", datetime.date(9999, 1, 1)),
("aten::_weight_norm_cuda_interface", datetime.date(9999, 1, 1)),
("aten::_weight_norm_cuda_interface_backward", datetime.date(9999, 1, 1)),
-("aten::segment_reduce", datetime.date(9999, 1, 1)),
-("aten::_segment_reduce_backward", datetime.date(9999, 1, 1)),
+("aten::segment_reduce", datetime.date(2022, 6, 30)),
+("aten::_segment_reduce_backward", datetime.date(2022, 6, 30)),
("aten::empty.SymInt", datetime.date(9999, 1, 1)),
# TODO: FIXME: prims shouldn't be checked
("prims::.*", datetime.date(9999, 1, 1)),

View File

@@ -1,6 +1,7 @@
# Owner(s): ["module: scatter & gather ops"]
from itertools import product
+from functools import partial
import numpy as np
import torch
@@ -52,6 +53,11 @@ class TestSegmentReductions(TestCase):
lengths_dtype=torch.int,
):
lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
+# generate offsets from lengths
+zeros_shape = list(lengths.shape)
+zeros_shape[-1] = 1
+offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)
data = torch.tensor( data = torch.tensor(
data_arr, data_arr,
device=device, device=device,
@ -60,52 +66,56 @@ class TestSegmentReductions(TestCase):
) )
expected_result = torch.tensor(expected_arr, device=device, dtype=dtype) expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype) expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
actual_result = torch.segment_reduce( for mode in ['lengths', 'offsets']:
data=data, segment_reduce_kwargs = dict(
reduce=reduction, axis=axis,
lengths=lengths, unsafe=unsafe,
axis=axis, initial=initial_value)
unsafe=unsafe, if (mode == 'lengths'):
initial=initial_value, segment_reduce_kwargs['lengths'] = lengths
) else:
self.assertEqual( segment_reduce_kwargs['offsets'] = offsets
expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True actual_result = torch.segment_reduce(
) data=data,
reduce=reduction,
if not check_backward: **segment_reduce_kwargs
return
# Test backward
actual_result.sum().backward()
self.assertEqual(
expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
)
# gradcheck does not work well with bfloat16 or fp16 cpu types
# also there is small numerical difference with fp32
if dtype not in [torch.half, torch.bfloat16, torch.float]:
# gradcheck does not like "nan" input, setting to random 10
d_non_nan = np.nan_to_num(data_arr, nan=10)
data = torch.tensor(
# [10 if v == float("nan") else v for v in data],
d_non_nan,
device=device,
dtype=dtype,
requires_grad=True,
) )
self.assertTrue( self.assertEqual(
gradcheck( expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
lambda x: torch.segment_reduce( )
data=x,
reduce=reduction, if not check_backward:
lengths=lengths, return
axis=axis,
unsafe=unsafe, # Test backward
initial=initial_value, actual_result.sum().backward()
), self.assertEqual(
(data,), expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
)
data = data.clone().detach().requires_grad_(True)
# gradcheck does not work well with bfloat16 or fp16 cpu types
# also there is small numerical difference with fp32
if dtype not in [torch.half, torch.bfloat16, torch.float]:
# gradcheck does not like "nan" input, setting to random 10
d_non_nan = np.nan_to_num(data_arr, nan=10)
new_data = torch.tensor(
# [10 if v == float("nan") else v for v in data],
d_non_nan,
device=device,
dtype=dtype,
requires_grad=True,
)
self.assertTrue(
gradcheck(
lambda x: torch.segment_reduce(
data=x,
reduce=reduction,
**segment_reduce_kwargs
),
(new_data,),
)
) )
)
@dtypes( @dtypes(
*product( *product(
@ -384,8 +394,18 @@ class TestSegmentReductions(TestCase):
) )
self.assertEqual(actual_result, expected) self.assertEqual(actual_result, expected)
# test offsets
actual_result = torch.segment_reduce(
data=data,
reduce=reduce,
offsets=indptr,
axis=dim,
unsafe=True,
)
self.assertEqual(actual_result, expected)
if val_dtype == torch.float64: if val_dtype == torch.float64:
def fn(x): def fn(x, mode='lengths'):
initial = 1 initial = 1
# supply initial values to prevent gradcheck from failing for 0 length segments # supply initial values to prevent gradcheck from failing for 0 length segments
# where nan/inf are reduction identities that produce nans when calculating the numerical jacobian # where nan/inf are reduction identities that produce nans when calculating the numerical jacobian
@ -393,8 +413,16 @@ class TestSegmentReductions(TestCase):
initial = 1000 initial = 1000
elif reduce == 'max': elif reduce == 'max':
initial = -1000 initial = -1000
return torch.segment_reduce(x, reduce, lengths=lengths, axis=dim, unsafe=True, initial=initial) segment_reduce_args = {x, reduce}
self.assertTrue(gradcheck(fn, (data.clone().detach().requires_grad_(True)))) segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
if mode == 'lengths':
segment_reduce_kwargs[mode] = lengths
elif mode == 'offsets':
segment_reduce_kwargs[mode] = indptr
return torch.segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)
self.assertTrue(gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True))))
self.assertTrue(gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True))))
@dtypes( @dtypes(
*product( *product(

View File

@@ -2731,8 +2731,8 @@
- name: nonzero(Tensor self) -> Tensor
  output_differentiability: [False]
-- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
-  data: _segment_reduce_backward(grad, result, data, reduce, lengths, axis, initial)
+- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
+  data: _segment_reduce_backward(grad, result, data, reduce, lengths, offsets, axis, initial)
- name: _pin_memory(Tensor self, Device? device=None) -> Tensor
  self: grad
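With the derivative entry now forwarding offsets to _segment_reduce_backward, autograd should work through the offsets path as well; a quick hedged check (made-up values):

import torch

data = torch.tensor([1., 2., 3., 4., 5., 6.], requires_grad=True)
offsets = torch.tensor([0, 2, 5, 6])

torch.segment_reduce(data, "mean", offsets=offsets).sum().backward()
print(data.grad)   # each element receives 1/segment_length: [0.5, 0.5, 1/3, 1/3, 1/3, 1.0]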

View File

@@ -948,7 +948,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.scatter_add: lambda input, dim, index, src: -1,
torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
-torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
+torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
torch.select: lambda input, dim, index: -1,
torch.select_scatter: lambda input, src, dim, index: -1,
torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,

View File

@ -8449,9 +8449,19 @@ def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode=
for args, reduce, initial in product(test_cases, reductions, [1, 2]): for args, reduce, initial in product(test_cases, reductions, [1, 2]):
inp_shape, dim, lengths, unsafe = args inp_shape, dim, lengths, unsafe = args
lengths_t = torch.tensor(lengths, dtype=torch.long, device=device) lengths_t = torch.tensor(lengths, dtype=torch.long, device=device)
sample_input_kwargs = {'axis': dim, 'unsafe': unsafe, 'initial': initial}
if mode == 'lengths':
sample_input_kwargs['lengths'] = lengths_t
elif mode == 'offsets':
zeros_shape = list(lengths_t.shape)
zeros_shape[dim] = 1
offsets_t = torch.cat((lengths_t.new_zeros(zeros_shape), lengths_t), dim).cumsum_(dim)
sample_input_kwargs['offsets'] = offsets_t
else:
raise RuntimeError(f"mode must be one of 'offsets' or 'lengths' got '{mode}'.")
yield SampleInput(_tensor(inp_shape), yield SampleInput(_tensor(inp_shape),
args=(reduce,), args=(reduce,),
kwargs={'lengths': lengths_t, 'axis': dim, 'unsafe': unsafe, 'initial': initial}) kwargs=sample_input_kwargs)
def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs): def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs):
@ -19586,6 +19596,25 @@ op_db: List[OpInfo] = [
), ),
), ),
), ),
OpInfo(
'segment_reduce',
variant_test_name='offsets',
dtypes=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
# RuntimeError: derivative for aten::_segment_reduce_backward is not implemented
supports_gradgrad=False,
sample_inputs_func=partial(sample_inputs_segment_reduce, mode='offsets'),
skips=(
# FIXME: CUDA driver API confirmed a leak in
# __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32
DecorateInfo(
unittest.skip("Skipped!"),
"TestJit",
"test_variant_consistency_jit",
device_type="cuda",
),
),
),
UnaryUfuncInfo( UnaryUfuncInfo(
'special.bessel_j0', 'special.bessel_j0',
decorators=( decorators=(