Support multi-dimensional lengths in segment_reduce to enable pytorch_scatter.segment_* functionality (CUDA)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77061

Approved by: https://github.com/cpuhrsch
Mikayla Gawarecki 2022-06-09 04:03:07 +00:00 committed by PyTorch MergeBot
parent c0a7c1d02e
commit 40f7ef1f3d
3 changed files with 169 additions and 78 deletions
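
For context, the feature being landed can be exercised from Python roughly as follows. This is a minimal sketch with made-up values (it is not taken from the PR's tests and assumes a CUDA device is available): each row of a 2-D lengths tensor describes the segments of the corresponding row of data along axis=1, mirroring the per-row semantics of pytorch_scatter.segment_*.

import torch

# Illustrative values, not from the PR. Each row of `lengths` must sum to data.size(1).
data = torch.tensor([[1., 2., 3., 4., 5., 6.],
                     [6., 5., 4., 3., 2., 1.]], device="cuda")
lengths = torch.tensor([[2, 3, 1],
                        [1, 2, 3]], device="cuda")

out = torch.segment_reduce(data, "sum", lengths=lengths, axis=1, unsafe=False)
# out:
# tensor([[ 3., 12.,  6.],
#         [ 6.,  9.,  6.]], device='cuda:0')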

View File

@@ -13,6 +13,8 @@
 #else
 #include <ATen/ops/empty.h>
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/cat.h>
+#include <ATen/ops/cumsum.h>
 #endif

 namespace at {
@@ -68,7 +70,7 @@ Tensor _get_complete_sum(const Tensor& lengths) {
   offsets[0].zero_();

   AT_DISPATCH_INDEX_TYPES(
-      lengths.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+      lengths.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
        auto* lengths_data_ptr = lengths.data_ptr<index_t>();
        auto* offsets_data_ptr = offsets.data_ptr<index_t>();
        at::cuda::cub::inclusive_sum(
@@ -108,22 +110,34 @@ __global__ void segment_reduce_forward_kernel(
     const index_t* lengths_data,
     const index_t* lengths_cumsum_data,
     const int64_t segment_count,
-    const int64_t stride_count,
+    const int64_t lengths_stride_axis,
     bool is_initial_set,
-    scalar_t initial_value) {
+    scalar_t initial_value,
+    const int64_t outer_offset,
+    const int64_t inner_offset,
+    const int64_t data_stride_axis,
+    const int64_t data_size_axis,
+    const int64_t output_stride_axis,
+    const int64_t output_size_axis,
+    const int64_t lengths_cumsum_stride_axis) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t row_id = idx / stride_count;
-  int64_t lane_id = idx % stride_count;
-  if (idx >= (segment_count * stride_count)) {
+  if (idx >= (outer_offset * segment_count * inner_offset)) {
     return;
   }
-  int64_t offset_start = lengths_cumsum_data[row_id];
-  int64_t offset_end = lengths_cumsum_data[row_id + 1];
+  int64_t row_id = idx / inner_offset;
+  int64_t lane_id = idx % inner_offset; // lane_id is the inner_idx
+  int64_t outer_idx = row_id / segment_count;
+  int64_t dim_idx = row_id % segment_count;
+
+  int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
+  index_t offset_start = lengths_cumsum_data[offset_idx];
+  index_t offset_end = lengths_cumsum_data[offset_idx + 1];

   // ===== step2: apply reduction
-  for (int64_t j = offset_start; j < offset_end; ++j) {
-    int64_t starting_index = (j * stride_count) + lane_id;
-    const auto data = values_data[starting_index];
+  for (index_t j = offset_start; j < offset_end; ++j) {
+    int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+                         + j * data_stride_axis + lane_id;
+    const auto data = values_data[data_index];
     // TODO: There is no need to branch with every element
     if (reduction == SegmentReductionType::MAX) {
       initial_value =
@@ -142,19 +156,22 @@ __global__ void segment_reduce_forward_kernel(
   }

   // ===== step3: finalize reduction
-  CUDA_KERNEL_ASSERT(lengths_data[row_id] >= 0);
-  if (lengths_data[row_id] == 0 && !is_initial_set &&
+  int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
+  CUDA_KERNEL_ASSERT(lengths_data[lengths_idx] >= 0);
+  if (lengths_data[lengths_idx] == 0 && !is_initial_set &&
       reduction == SegmentReductionType::MEAN) {
     initial_value = static_cast<scalar_t>(NAN);
   } else if (
-      reduction == SegmentReductionType::MEAN && lengths_data[row_id] > 0 &&
+      reduction == SegmentReductionType::MEAN && lengths_data[lengths_idx] > 0 &&
       !at::_isnan(initial_value)) {
-    initial_value = initial_value / lengths_data[row_id];
+    initial_value = initial_value / lengths_data[lengths_idx];
   }
-  int64_t output_index = (row_id * stride_count) + lane_id;
+  int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+                         + dim_idx * output_stride_axis + lane_id;
   output_data[output_index] = initial_value;
 }

 template <typename scalar_t, typename index_t>
 __global__ void segment_reduce_backward_kernel(
     SegmentReductionType reduction,
@@ -165,32 +182,46 @@ __global__ void segment_reduce_backward_kernel(
     const index_t* lengths_data,
     const index_t* lengths_cumsum_data,
     const int64_t segment_count,
-    const int64_t stride_count,
-    scalar_t initial_prod_value) {
+    const int64_t lengths_stride_axis,
+    scalar_t initial_prod_value,
+    const int64_t outer_offset,
+    const int64_t inner_offset,
+    const int64_t data_stride_axis,
+    const int64_t data_size_axis,
+    const int64_t output_stride_axis,
+    const int64_t output_size_axis,
+    const int64_t lengths_cumsum_stride_axis) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t row_id = idx / stride_count;
-  int64_t lane_id = idx % stride_count;
-  if (idx >= (segment_count * stride_count)) {
+  if (idx >= (outer_offset * segment_count * inner_offset)) {
     return;
   }
-  if (lengths_data[row_id] == 0) {
+  int64_t row_id = idx / inner_offset;
+  int64_t lane_id = idx % inner_offset; // lane_id is the inner_idx
+  int64_t outer_idx = row_id / segment_count;
+  int64_t dim_idx = row_id % segment_count;
+
+  int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
+  auto segment_length = lengths_data[lengths_idx];
+  if (segment_length == 0) {
     return;
   }
-  int64_t offset_start = lengths_cumsum_data[row_id];
-  int64_t offset_end = lengths_cumsum_data[row_id + 1];
+  int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
+  index_t offset_start = lengths_cumsum_data[offset_idx];
+  index_t offset_end = lengths_cumsum_data[offset_idx + 1];

-  int64_t output_index = (row_id * stride_count) + lane_id;
+  int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+                         + dim_idx * output_stride_axis + lane_id;

   if (reduction == SegmentReductionType::MAX ||
       reduction == SegmentReductionType::MIN) {
     int64_t counter = 0;
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      if (at::_isnan(values_data[starting_index]) ||
-          values_data[starting_index] == output_data[output_index]) {
-        grad_input_data[starting_index] = grad_data[output_index];
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+                           + j * data_stride_axis + lane_id;
+      if (at::_isnan(values_data[data_index]) ||
+          values_data[data_index] == output_data[output_index]) {
+        grad_input_data[data_index] = grad_data[output_index];
         counter++;
       }
     }
@@ -200,42 +231,47 @@ __global__ void segment_reduce_backward_kernel(
       return;
     }
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      if (grad_input_data[starting_index] > 0) {
-        grad_input_data[starting_index] =
-            grad_input_data[starting_index] / counter;
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+                           + j * data_stride_axis + lane_id;
+      if (grad_input_data[data_index] > 0) {
+        grad_input_data[data_index] =
+            grad_input_data[data_index] / counter;
       }
     }
   } else if (reduction == SegmentReductionType::MEAN) {
-    auto grad_val = grad_data[output_index] / lengths_data[row_id];
+    auto grad_val = grad_data[output_index] / segment_length;
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      grad_input_data[starting_index] = grad_val;
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+                           + j * data_stride_axis + lane_id;
+      grad_input_data[data_index] = grad_val;
     }
   } else if (reduction == SegmentReductionType::SUM) {
     const auto& grad_val = grad_data[output_index];
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      grad_input_data[starting_index] = grad_val;
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+                           + j * data_stride_axis + lane_id;
+      grad_input_data[data_index] = grad_val;
     }
   } else if (reduction == SegmentReductionType::PROD) {
     const auto& grad_val = grad_data[output_index] * output_data[output_index];
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      if (at::_isnan(values_data[starting_index]) ||
-          values_data[starting_index] == 0) {
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+                           + j * data_stride_axis + lane_id;
+      if (at::_isnan(values_data[data_index]) ||
+          values_data[data_index] == 0) {
         // explicitly compute exclusive prod
         scalar_t exclusive_prod = initial_prod_value;
-        int64_t idx;
+        int64_t prod_idx;
         for (int64_t k = offset_start; k < offset_end; ++k) {
           if (k != j) {
-            idx = (k * stride_count) + lane_id;
-            exclusive_prod *= values_data[idx];
+            prod_idx = outer_idx * data_stride_axis * data_size_axis
+                       + k * data_stride_axis + lane_id;
+            exclusive_prod *= values_data[prod_idx];
           }
         }
-        grad_input_data[starting_index] = grad_data[output_index] * exclusive_prod;
+        grad_input_data[data_index] = grad_data[output_index] * exclusive_prod;
       } else {
-        grad_input_data[starting_index] = grad_val / values_data[starting_index];
+        grad_input_data[data_index] = grad_val / values_data[data_index];
       }
     }
   }
@@ -251,28 +287,43 @@ Tensor _segment_reduce_cuda_backward_kernel(
     const Tensor& lengths_contig,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
-  int64_t segment_count = lengths_contig.numel();
-  auto output_shape = data_contig.sizes().vec();
-  output_shape[axis] = segment_count;
+  axis = lengths_contig.dim() - 1;
+  int64_t segment_count = lengths_contig.size(axis);
+  int64_t lengths_stride_axis = lengths_contig.stride(axis);
   auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());

-  int64_t stride_count = data_contig.numel() / data_contig.size(axis);
+  auto zeros_shape = lengths_contig.sizes().vec();
+  zeros_shape[axis] = 1;
+  auto offsets = at::cat({at::zeros(zeros_shape, lengths_contig.options()), lengths_contig}, axis);
+  offsets.cumsum_(axis);

-  auto offsets = _get_complete_sum(lengths_contig);
+  // outer_offset is the size of the outer dimensions of output (before axis)
+  // inner_offset is the size of the inner dimensions of output (after axis)
+  int64_t outer_offset = 1, inner_offset = 1;
+  for (int64_t d = 0; d < axis; d++) {
+    outer_offset *= output_contig.size(d);
+  }
+  for (int64_t d = axis + 1; d < output_contig.dim(); d++) {
+    inner_offset *= output_contig.size(d);
+  }

   constexpr int threads_per_block = 256;
-  int64_t num_blocks =
-      ((segment_count * stride_count) + threads_per_block - 1) /
-      threads_per_block;
+  int64_t num_blocks = (outer_offset * inner_offset * segment_count + threads_per_block - 1) / threads_per_block;
   num_blocks = std::max(num_blocks, (int64_t)1);

+  auto data_stride_axis = data_contig.stride(axis);
+  auto data_size_axis = data_contig.size(axis);
+  auto output_stride_axis = output_contig.stride(axis);
+  auto output_size_axis = output_contig.size(axis);
+  auto offsets_stride_axis = offsets.stride(axis);
+
   AT_DISPATCH_INDEX_TYPES(
-      lengths_contig.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+      lengths_contig.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
        const auto* lengths_data = lengths_contig.data_ptr<index_t>();
        auto* offsets_data = offsets.data_ptr<index_t>();

-        // TODO: Swtich to TensorIterator for better maintainablility and
+        // TODO: Switch to TensorIterator for better maintainablility and
         // readability
         AT_DISPATCH_FLOATING_TYPES_AND2(
             kBFloat16,
@@ -305,8 +356,16 @@ Tensor _segment_reduce_cuda_backward_kernel(
                     lengths_data,
                     offsets_data,
                     segment_count,
-                    stride_count,
-                    initial_prod_value);
+                    lengths_stride_axis,
+                    initial_prod_value,
+                    outer_offset,
+                    inner_offset,
+                    data_stride_axis,
+                    data_size_axis,
+                    output_stride_axis,
+                    output_size_axis,
+                    offsets_stride_axis
+                );
                 C10_CUDA_KERNEL_LAUNCH_CHECK();
               }));
         }));
@@ -319,24 +378,46 @@ Tensor _segment_reduce_cuda_kernel(
     const Tensor& lengths,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
-  int64_t segment_count = lengths.numel();
+  // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
+  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
+  TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous.");
+  axis = lengths.dim() - 1;
+  int64_t segment_count = lengths.size(axis);
+  int64_t lengths_stride_axis = lengths.stride(axis);
   auto output_shape = data.sizes().vec();
   output_shape[axis] = segment_count;
   auto output = at::empty(output_shape, data.options());

-  int64_t stride_count = data.numel() / data.size(axis);
+  // _get_complete_sum only supports 1D?
+  auto zeros_shape = lengths.sizes().vec();
+  zeros_shape[axis] = 1;
+  auto offsets = at::cat({at::zeros(zeros_shape, lengths.options()), lengths}, axis);
+  offsets.cumsum_(axis);

-  auto offsets = _get_complete_sum(lengths);
+  // outer_offset is the size of the outer dimensions of output (before axis)
+  // inner_offset is the size of the inner dimensions of output (after axis)
+  int64_t outer_offset = 1, inner_offset = 1;
+  for (int64_t d = 0; d < axis; d++) {
+    outer_offset *= output.size(d);
+  }
+  for (int64_t d = axis + 1; d < output.dim(); d++) {
+    inner_offset *= output.size(d);
+  }

   constexpr int threads_per_block = 256;
-  int64_t num_blocks =
-      ((segment_count * stride_count) + threads_per_block - 1) /
-      threads_per_block;
+  // segment_count * stride_count is just output.numel() ?
+  int64_t num_blocks = (output.numel() + threads_per_block - 1) / threads_per_block;
   num_blocks = std::max(num_blocks, (int64_t)1);

+  auto data_stride_axis = data.stride(axis);
+  auto data_size_axis = data.size(axis);
+  auto output_stride_axis = output.stride(axis);
+  auto output_size_axis = output.size(axis);
+  auto offsets_stride_axis = offsets.stride(axis);
+
   AT_DISPATCH_INDEX_TYPES(
-      lengths.type(), "_segment_reduce_cuda_kernel1", ([&] {
+      lengths.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
        auto* offsets_data_ptr = offsets.data_ptr<index_t>();
        auto* lengths_data_ptr = lengths.data_ptr<index_t>();
        AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -376,9 +457,17 @@ Tensor _segment_reduce_cuda_kernel(
                   lengths_data_ptr,
                   offsets_data_ptr,
                   segment_count,
-                  stride_count,
+                  lengths_stride_axis,
                   initial.has_value(),
-                  initial_value);
+                  initial_value,
+                  outer_offset,
+                  inner_offset,
+                  data_stride_axis,
+                  data_size_axis,
+                  output_stride_axis,
+                  output_size_axis,
+                  offsets_stride_axis
+              );
               C10_CUDA_KERNEL_LAUNCH_CHECK();
             } else {
               if (reduction == SegmentReductionType::MAX) {
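
All of the kernels above share one indexing scheme: a flat thread index is decomposed into (outer_idx, dim_idx, lane_id), and the per-segment range [offset_start, offset_end) comes from a row-wise cumsum of lengths with a zero column prepended. The following Python sketch of the sum case mirrors that arithmetic for contiguous inputs; segment_sum_reference is a hypothetical name and the code is a reading aid, not part of the diff.

import torch

def segment_sum_reference(data, lengths):
    # Sketch of the kernels' index math (sum reduction, contiguous tensors assumed).
    axis = lengths.dim() - 1
    segment_count = lengths.size(axis)

    # offsets: prepend a zero column, then cumulative-sum along the last axis,
    # matching the at::cat + cumsum_ done on the host side.
    zeros_shape = list(lengths.shape)
    zeros_shape[axis] = 1
    zeros = torch.zeros(zeros_shape, dtype=lengths.dtype, device=lengths.device)
    offsets = torch.cat([zeros, lengths], axis).cumsum_(axis)

    output_shape = list(data.shape)
    output_shape[axis] = segment_count
    output = data.new_zeros(output_shape)

    # outer_offset / inner_offset: product of the sizes before / after `axis`.
    outer_offset = 1
    for d in range(axis):
        outer_offset *= output.size(d)
    inner_offset = 1
    for d in range(axis + 1, output.dim()):
        inner_offset *= output.size(d)

    data_flat = data.reshape(outer_offset, data.size(axis), inner_offset)
    out_flat = output.reshape(outer_offset, segment_count, inner_offset)
    offsets_flat = offsets.reshape(outer_offset, segment_count + 1)

    # One "thread" per output element, decomposed exactly like the CUDA kernels.
    for idx in range(outer_offset * segment_count * inner_offset):
        row_id, lane_id = divmod(idx, inner_offset)
        outer_idx, dim_idx = divmod(row_id, segment_count)
        start = int(offsets_flat[outer_idx, dim_idx])
        end = int(offsets_flat[outer_idx, dim_idx + 1])
        for j in range(start, end):
            out_flat[outer_idx, dim_idx, lane_id] += data_flat[outer_idx, j, lane_id]
    return output

With the 2-D example shown earlier, segment_sum_reference(data, lengths) should agree with torch.segment_reduce(data, "sum", lengths=lengths, axis=1).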

View File

@@ -7,13 +7,12 @@ import torch
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     dtypes,
-    onlyCPU
 )
 from torch.testing._internal.common_utils import (
     TestCase,
     run_tests,
     gradcheck,
-    parametrize
+    parametrize,
 )
@@ -300,7 +299,6 @@ class TestSegmentReductions(TestCase):
         )
     )
     @parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
-    @onlyCPU  # will be removed in next PR where CUDA implementation of segment_reduce is adjusted
     def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
         val_dtype, length_dtype = dtypes
         # zero-length segments are filled with reduction inits contrary to pytorch_scatter.
@@ -384,7 +382,6 @@ class TestSegmentReductions(TestCase):
                     axis=dim,
                     unsafe=True,
                 )
-
                 self.assertEqual(actual_result, expected)
                 if val_dtype == torch.float64:
@@ -469,20 +466,19 @@ class TestSegmentReductions(TestCase):
             check_backward,
         )

-    @onlyCPU
     @dtypes(torch.int, torch.int64)
     def test_unsafe_flag(self, device, dtype):
         length_type = dtype
-        lengths = torch.tensor([0, 2, 3, 0], dtype=length_type)
-        data = torch.arange(6).float()
+        lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
+        data = torch.arange(6, dtype=torch.float, device=device)

         # test for error on 1-D lenghts
         with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
             torch.segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)

         # test for error on multi-D lengths
-        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type)
-        nd_data = torch.arange(12).reshape(2, 6).float()
+        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
+        nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
         with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
             torch.segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)
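
The validation exercised by this test appears to require, when unsafe=False, that every row of a multi-dimensional lengths tensor be non-negative and sum to data.size(axis) along the reduced axis. A small illustration with made-up values (not part of the test suite):

import torch

nd_data = torch.arange(12, dtype=torch.float).reshape(2, 6)

# Row 0 sums to 6 (== nd_data.size(1)), row 1 sums to 5, so the check should raise.
bad_lengths = torch.tensor([[2, 3, 1], [1, 2, 2]])
try:
    torch.segment_reduce(nd_data, "sum", lengths=bad_lengths, axis=1, unsafe=False)
except RuntimeError as e:
    print(e)  # expected: "Expected all rows of lengths along axis ..."

# With every row summing to nd_data.size(1), the same call goes through.
good_lengths = torch.tensor([[2, 3, 1], [1, 2, 3]])
print(torch.segment_reduce(nd_data, "sum", lengths=good_lengths, axis=1, unsafe=False))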

View File

@@ -8327,7 +8327,7 @@ def sample_inputs_scatter_reduce(op_info, device, dtype, requires_grad, **kwargs
                           args=(1, idx, src, reduce),
                           kwargs={'include_self': True})

-def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, **kwargs):
+def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode='lengths', **kwargs):
     def _tensor(shape, dtype=dtype, low=None, high=None):
         return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
@@ -8340,6 +8340,11 @@ def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, **kwargs
         ((S, S), 0, [0, 1, 2, 2], False),
         # test when lengths do not sum to dim size
         ((M, S, S), 0, [1, 2, 0, 6, 0], True),
+        # test for higher dimensions
+        ((S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
+        ((S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
+        ((S, S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
+        ((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
     )

     reductions = ["max", "mean", "min", "sum", "prod"]
@@ -19373,6 +19378,7 @@ op_db: List[OpInfo] = [
     ),
     OpInfo(
         'segment_reduce',
+        variant_test_name='lengths',
         dtypes=floating_types_and(torch.float16, torch.bfloat16),
         supports_out=False,
         # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented