Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Support multi-dimensional lengths in segment_reduce to support pytorch_scatter.segment_* functionalities (CUDA)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77061 Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent c0a7c1d02e
commit 40f7ef1f3d
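For context, a small usage sketch (not part of the commit) of the behavior this change enables on CUDA, assuming a CUDA-enabled build: an N-D lengths tensor, with the reduction applied along axis = lengths.dim() - 1, mirroring pytorch_scatter.segment_*. The tensor values and shapes below are illustrative only.

import torch

# 2 rows of 6 elements each; each row carries its own segment lengths.
data = torch.arange(12., device="cuda").reshape(2, 6)
lengths = torch.tensor([[1, 2, 3], [3, 2, 1]], device="cuda")

# With unsafe=False, every row of `lengths` must sum to data.size(axis).
out = torch.segment_reduce(data, "sum", lengths=lengths, axis=1, unsafe=False)
# out.shape == (2, 3); row 0 -> [0, 1+2, 3+4+5], row 1 -> [6+7+8, 9+10, 11]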
@@ -13,6 +13,8 @@
 #else
 #include <ATen/ops/empty.h>
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/cat.h>
+#include <ATen/ops/cumsum.h>
 #endif

 namespace at {

@@ -68,7 +70,7 @@ Tensor _get_complete_sum(const Tensor& lengths) {
   offsets[0].zero_();

   AT_DISPATCH_INDEX_TYPES(
-      lengths.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+      lengths.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
         auto* lengths_data_ptr = lengths.data_ptr<index_t>();
         auto* offsets_data_ptr = offsets.data_ptr<index_t>();
         at::cuda::cub::inclusive_sum(

@@ -108,22 +110,34 @@ __global__ void segment_reduce_forward_kernel(
     const index_t* lengths_data,
     const index_t* lengths_cumsum_data,
     const int64_t segment_count,
-    const int64_t stride_count,
+    const int64_t lengths_stride_axis,
     bool is_initial_set,
-    scalar_t initial_value) {
+    scalar_t initial_value,
+    const int64_t outer_offset,
+    const int64_t inner_offset,
+    const int64_t data_stride_axis,
+    const int64_t data_size_axis,
+    const int64_t output_stride_axis,
+    const int64_t output_size_axis,
+    const int64_t lengths_cumsum_stride_axis) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t row_id = idx / stride_count;
-  int64_t lane_id = idx % stride_count;
-  if (idx >= (segment_count * stride_count)) {
+  if (idx >= (outer_offset * segment_count * inner_offset)) {
     return;
   }
-  int64_t offset_start = lengths_cumsum_data[row_id];
-  int64_t offset_end = lengths_cumsum_data[row_id + 1];
+  int64_t row_id = idx / inner_offset;
+  int64_t lane_id = idx % inner_offset; // lane_id is the inner_idx
+  int64_t outer_idx = row_id / segment_count;
+  int64_t dim_idx = row_id % segment_count;
+
+  int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
+  index_t offset_start = lengths_cumsum_data[offset_idx];
+  index_t offset_end = lengths_cumsum_data[offset_idx + 1];

   // ===== step2: apply reduction
-  for (int64_t j = offset_start; j < offset_end; ++j) {
-    int64_t starting_index = (j * stride_count) + lane_id;
-    const auto data = values_data[starting_index];
+  for (index_t j = offset_start; j < offset_end; ++j) {
+    int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+        + j * data_stride_axis + lane_id;
+    const auto data = values_data[data_index];
     // TODO: There is no need to branch with every element
     if (reduction == SegmentReductionType::MAX) {
       initial_value =

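To make the new indexing easier to follow, here is a plain-Python restatement (not part of the diff) of how each CUDA thread maps its flat index onto one output element and one slice of the input; the names mirror the kernel arguments above.

def decompose_thread_index(idx, segment_count, inner_offset):
    # one thread per (outer_idx, dim_idx, lane_id) output element
    row_id = idx // inner_offset
    lane_id = idx % inner_offset          # position within the dims after the reduced axis
    outer_idx = row_id // segment_count   # position within the dims before the reduced axis
    dim_idx = row_id % segment_count      # which segment along the reduced axis
    return outer_idx, dim_idx, lane_id

def flat_data_index(outer_idx, j, lane_id, data_stride_axis, data_size_axis):
    # element j of the segment; same formula as the kernel's data_index
    return outer_idx * data_stride_axis * data_size_axis + j * data_stride_axis + lane_id

For contiguous data, data_stride_axis equals the product of the dims after the axis (the inner_offset of the data tensor), so the formula reduces to ordinary row-major indexing.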
@@ -142,19 +156,22 @@ __global__ void segment_reduce_forward_kernel(
   }

   // ===== step3: finalize reduction
-  CUDA_KERNEL_ASSERT(lengths_data[row_id] >= 0);
-  if (lengths_data[row_id] == 0 && !is_initial_set &&
+  int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
+  CUDA_KERNEL_ASSERT(lengths_data[lengths_idx] >= 0);
+  if (lengths_data[lengths_idx] == 0 && !is_initial_set &&
       reduction == SegmentReductionType::MEAN) {
     initial_value = static_cast<scalar_t>(NAN);
   } else if (
-      reduction == SegmentReductionType::MEAN && lengths_data[row_id] > 0 &&
+      reduction == SegmentReductionType::MEAN && lengths_data[lengths_idx] > 0 &&
       !at::_isnan(initial_value)) {
-    initial_value = initial_value / lengths_data[row_id];
+    initial_value = initial_value / lengths_data[lengths_idx];
   }
-  int64_t output_index = (row_id * stride_count) + lane_id;
+  int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+      + dim_idx * output_stride_axis + lane_id;
   output_data[output_index] = initial_value;
 }

+
 template <typename scalar_t, typename index_t>
 __global__ void segment_reduce_backward_kernel(
     SegmentReductionType reduction,

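The finalization step above only changes where the segment length is read from (lengths_idx instead of row_id); the MEAN rule itself is unchanged. A short sketch of that rule, assuming acc already holds the accumulated sum for one output element:

import math

def finalize_mean(acc, segment_length, is_initial_set):
    if segment_length == 0 and not is_initial_set:
        return math.nan              # empty segment and no initial value given
    if segment_length > 0 and not math.isnan(acc):
        return acc / segment_length  # ordinary mean
    return acc                       # empty segment with an initial value, or NaN accumulator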
@@ -165,32 +182,46 @@ __global__ void segment_reduce_backward_kernel(
     const index_t* lengths_data,
     const index_t* lengths_cumsum_data,
     const int64_t segment_count,
-    const int64_t stride_count,
-    scalar_t initial_prod_value) {
+    const int64_t lengths_stride_axis,
+    scalar_t initial_prod_value,
+    const int64_t outer_offset,
+    const int64_t inner_offset,
+    const int64_t data_stride_axis,
+    const int64_t data_size_axis,
+    const int64_t output_stride_axis,
+    const int64_t output_size_axis,
+    const int64_t lengths_cumsum_stride_axis) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t row_id = idx / stride_count;
-  int64_t lane_id = idx % stride_count;
-
-  if (idx >= (segment_count * stride_count)) {
+  if (idx >= (outer_offset * segment_count * inner_offset)) {
     return;
   }
-  if (lengths_data[row_id] == 0) {
+  int64_t row_id = idx / inner_offset;
+  int64_t lane_id = idx % inner_offset; // lane_id is the inner_idx
+  int64_t outer_idx = row_id / segment_count;
+  int64_t dim_idx = row_id % segment_count;
+
+  int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
+  auto segment_length = lengths_data[lengths_idx];
+  if (segment_length == 0) {
     return;
   }

-  int64_t offset_start = lengths_cumsum_data[row_id];
-  int64_t offset_end = lengths_cumsum_data[row_id + 1];
+  int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
+  index_t offset_start = lengths_cumsum_data[offset_idx];
+  index_t offset_end = lengths_cumsum_data[offset_idx + 1];

-  int64_t output_index = (row_id * stride_count) + lane_id;
+  int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+      + dim_idx * output_stride_axis + lane_id;

   if (reduction == SegmentReductionType::MAX ||
       reduction == SegmentReductionType::MIN) {
     int64_t counter = 0;
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      if (at::_isnan(values_data[starting_index]) ||
-          values_data[starting_index] == output_data[output_index]) {
-        grad_input_data[starting_index] = grad_data[output_index];
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      if (at::_isnan(values_data[data_index]) ||
+          values_data[data_index] == output_data[output_index]) {
+        grad_input_data[data_index] = grad_data[output_index];
         counter++;
       }
     }

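A Python restatement (not part of the diff) of the MAX/MIN backward rule for a single segment and lane: the incoming gradient is routed to every element that matches the reduced output (or is NaN), and the division by counter happens in the next hunk; the > 0 guard below mirrors the kernel's guard.

import math

def minmax_backward(values, out_val, grad_out):
    grad_in = [0.0] * len(values)
    counter = 0
    for j, v in enumerate(values):
        if math.isnan(v) or v == out_val:
            grad_in[j] = grad_out       # route the output gradient to every matching element
            counter += 1
    for j in range(len(values)):
        if grad_in[j] > 0:              # same guard as the kernel before dividing by counter
            grad_in[j] /= counter
    return grad_in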
@@ -200,42 +231,47 @@ __global__ void segment_reduce_backward_kernel(
       return;
     }
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      if (grad_input_data[starting_index] > 0) {
-        grad_input_data[starting_index] =
-            grad_input_data[starting_index] / counter;
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      if (grad_input_data[data_index] > 0) {
+        grad_input_data[data_index] =
+            grad_input_data[data_index] / counter;
       }
     }
   } else if (reduction == SegmentReductionType::MEAN) {
-    auto grad_val = grad_data[output_index] / lengths_data[row_id];
+    auto grad_val = grad_data[output_index] / segment_length;
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      grad_input_data[starting_index] = grad_val;
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      grad_input_data[data_index] = grad_val;
     }
   } else if (reduction == SegmentReductionType::SUM) {
     const auto& grad_val = grad_data[output_index];
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      grad_input_data[starting_index] = grad_val;
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      grad_input_data[data_index] = grad_val;
     }
   } else if (reduction == SegmentReductionType::PROD) {
     const auto& grad_val = grad_data[output_index] * output_data[output_index];
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      if (at::_isnan(values_data[starting_index]) ||
-          values_data[starting_index] == 0) {
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      if (at::_isnan(values_data[data_index]) ||
+          values_data[data_index] == 0) {
         // explicitly compute exclusive prod
         scalar_t exclusive_prod = initial_prod_value;
-        int64_t idx;
+        int64_t prod_idx;
         for (int64_t k = offset_start; k < offset_end; ++k) {
           if (k != j) {
-            idx = (k * stride_count) + lane_id;
-            exclusive_prod *= values_data[idx];
+            prod_idx = outer_idx * data_stride_axis * data_size_axis
+                + k * data_stride_axis + lane_id;
+            exclusive_prod *= values_data[prod_idx];
           }
         }
-        grad_input_data[starting_index] = grad_data[output_index] * exclusive_prod;
+        grad_input_data[data_index] = grad_data[output_index] * exclusive_prod;
       } else {
-        grad_input_data[starting_index] = grad_val / values_data[starting_index];
+        grad_input_data[data_index] = grad_val / values_data[data_index];
       }
     }
   }

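Only the flat indexing changes in the MEAN/SUM/PROD branches; the gradient rules are untouched. For reference, a Python sketch (not part of the diff) of the PROD rule for one segment and lane, where out_prod is the forward result:

import math

def prod_backward(values, out_prod, grad_out, initial_prod_value=1.0):
    grad_in = []
    for j, v in enumerate(values):
        if math.isnan(v) or v == 0:
            # cannot divide by v: recompute the product of every other element
            excl = initial_prod_value
            for k, w in enumerate(values):
                if k != j:
                    excl *= w
            grad_in.append(grad_out * excl)
        else:
            # d(prod)/dv_j = prod / v_j
            grad_in.append(grad_out * out_prod / v)
    return grad_in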
@@ -251,28 +287,43 @@ Tensor _segment_reduce_cuda_backward_kernel(
     const Tensor& lengths_contig,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
-  int64_t segment_count = lengths_contig.numel();
-  auto output_shape = data_contig.sizes().vec();
-  output_shape[axis] = segment_count;
+  axis = lengths_contig.dim() - 1;
+  int64_t segment_count = lengths_contig.size(axis);
+  int64_t lengths_stride_axis = lengths_contig.stride(axis);
   auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());

-  int64_t stride_count = data_contig.numel() / data_contig.size(axis);
+  auto zeros_shape = lengths_contig.sizes().vec();
+  zeros_shape[axis] = 1;
+  auto offsets = at::cat({at::zeros(zeros_shape, lengths_contig.options()), lengths_contig}, axis);
+  offsets.cumsum_(axis);

-  auto offsets = _get_complete_sum(lengths_contig);
+  // outer_offset is the size of the outer dimensions of output (before axis)
+  // inner_offset is the size of the inner dimensions of output (after axis)
+  int64_t outer_offset = 1, inner_offset = 1;
+  for (int64_t d = 0; d < axis; d++) {
+    outer_offset *= output_contig.size(d);
+  }
+  for (int64_t d = axis + 1; d < output_contig.dim(); d++) {
+    inner_offset *= output_contig.size(d);
+  }

   constexpr int threads_per_block = 256;
-  int64_t num_blocks =
-      ((segment_count * stride_count) + threads_per_block - 1) /
-      threads_per_block;
+  int64_t num_blocks = (outer_offset * inner_offset * segment_count + threads_per_block - 1) / threads_per_block;

   num_blocks = std::max(num_blocks, (int64_t)1);
+
+  auto data_stride_axis = data_contig.stride(axis);
+  auto data_size_axis = data_contig.size(axis);
+  auto output_stride_axis = output_contig.stride(axis);
+  auto output_size_axis = output_contig.size(axis);
+  auto offsets_stride_axis = offsets.stride(axis);

   AT_DISPATCH_INDEX_TYPES(
-      lengths_contig.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+      lengths_contig.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
         const auto* lengths_data = lengths_contig.data_ptr<index_t>();
         auto* offsets_data = offsets.data_ptr<index_t>();

-        // TODO: Swtich to TensorIterator for better maintainablility and
+        // TODO: Switch to TensorIterator for better maintainablility and
         // readability
         AT_DISPATCH_FLOATING_TYPES_AND2(
             kBFloat16,

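The host-side changes above (and repeated in the forward launcher further down) do the same bookkeeping: build the offsets tensor by prepending a zero slice and taking a cumulative sum, then collapse the output shape into outer_offset x segment_count x inner_offset for the 1-D launch grid. A Python sketch of that bookkeeping (not part of the diff), with names mirroring the C++ locals:

import torch

def launch_params(data, lengths, output, threads_per_block=256):
    axis = lengths.dim() - 1                 # reduction runs along the last lengths dim
    segment_count = lengths.size(axis)

    # offsets: prepend a zero slice along `axis`, then cumulative sum,
    # giving [0, l0, l0+l1, ...] per row, as at::cat + cumsum_ do above
    zeros_shape = list(lengths.shape)
    zeros_shape[axis] = 1
    offsets = torch.cat([lengths.new_zeros(zeros_shape), lengths], dim=axis).cumsum(axis)

    outer_offset = 1                         # product of output dims before `axis`
    for d in range(axis):
        outer_offset *= output.size(d)
    inner_offset = 1                         # product of output dims after `axis`
    for d in range(axis + 1, output.dim()):
        inner_offset *= output.size(d)

    total = outer_offset * segment_count * inner_offset   # one thread per output element
    num_blocks = max((total + threads_per_block - 1) // threads_per_block, 1)
    return offsets, outer_offset, inner_offset, num_blocks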
@@ -305,8 +356,16 @@ Tensor _segment_reduce_cuda_backward_kernel(
                     lengths_data,
                     offsets_data,
                     segment_count,
-                    stride_count,
-                    initial_prod_value);
+                    lengths_stride_axis,
+                    initial_prod_value,
+                    outer_offset,
+                    inner_offset,
+                    data_stride_axis,
+                    data_size_axis,
+                    output_stride_axis,
+                    output_size_axis,
+                    offsets_stride_axis
+                );
                 C10_CUDA_KERNEL_LAUNCH_CHECK();
               }));
         }));

@@ -319,24 +378,46 @@ Tensor _segment_reduce_cuda_kernel(
     const Tensor& lengths,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
-  int64_t segment_count = lengths.numel();
+  // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
+  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
+  TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous.");
+  axis = lengths.dim() - 1;
+  int64_t segment_count = lengths.size(axis);
+  int64_t lengths_stride_axis = lengths.stride(axis);
   auto output_shape = data.sizes().vec();
   output_shape[axis] = segment_count;
   auto output = at::empty(output_shape, data.options());

-  int64_t stride_count = data.numel() / data.size(axis);
+  // _get_complete_sum only supports 1D?
+  auto zeros_shape = lengths.sizes().vec();
+  zeros_shape[axis] = 1;
+  auto offsets = at::cat({at::zeros(zeros_shape, lengths.options()), lengths}, axis);
+  offsets.cumsum_(axis);

-  auto offsets = _get_complete_sum(lengths);
+  // outer_offset is the size of the outer dimensions of output (before axis)
+  // inner_offset is the size of the inner dimensions of output (after axis)
+  int64_t outer_offset = 1, inner_offset = 1;
+  for (int64_t d = 0; d < axis; d++) {
+    outer_offset *= output.size(d);
+  }
+  for (int64_t d = axis + 1; d < output.dim(); d++) {
+    inner_offset *= output.size(d);
+  }

   constexpr int threads_per_block = 256;
-  int64_t num_blocks =
-      ((segment_count * stride_count) + threads_per_block - 1) /
-      threads_per_block;
+  // segment_count * stride_count is just output.numel() ?
+  int64_t num_blocks = (output.numel() + threads_per_block - 1) / threads_per_block;

   num_blocks = std::max(num_blocks, (int64_t)1);
+
+  auto data_stride_axis = data.stride(axis);
+  auto data_size_axis = data.size(axis);
+  auto output_stride_axis = output.stride(axis);
+  auto output_size_axis = output.size(axis);
+  auto offsets_stride_axis = offsets.stride(axis);

   AT_DISPATCH_INDEX_TYPES(
-      lengths.type(), "_segment_reduce_cuda_kernel1", ([&] {
+      lengths.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
         auto* offsets_data_ptr = offsets.data_ptr<index_t>();
         auto* lengths_data_ptr = lengths.data_ptr<index_t>();
         AT_DISPATCH_FLOATING_TYPES_AND2(

@@ -376,9 +457,17 @@ Tensor _segment_reduce_cuda_kernel(
                       lengths_data_ptr,
                       offsets_data_ptr,
                       segment_count,
-                      stride_count,
+                      lengths_stride_axis,
                       initial.has_value(),
-                      initial_value);
+                      initial_value,
+                      outer_offset,
+                      inner_offset,
+                      data_stride_axis,
+                      data_size_axis,
+                      output_stride_axis,
+                      output_size_axis,
+                      offsets_stride_axis
+                  );
                   C10_CUDA_KERNEL_LAUNCH_CHECK();
                 } else {
                   if (reduction == SegmentReductionType::MAX) {

@@ -7,13 +7,12 @@ import torch
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     dtypes,
-    onlyCPU
 )
 from torch.testing._internal.common_utils import (
     TestCase,
     run_tests,
     gradcheck,
-    parametrize
+    parametrize,
 )

@@ -300,7 +299,6 @@ class TestSegmentReductions(TestCase):
         )
     )
     @parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
-    @onlyCPU  # will be removed in next PR where CUDA implementation of segment_reduce is adjusted
     def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
         val_dtype, length_dtype = dtypes
         # zero-length segments are filled with reduction inits contrary to pytorch_scatter.

@@ -384,7 +382,6 @@ class TestSegmentReductions(TestCase):
                 axis=dim,
                 unsafe=True,
             )
-
             self.assertEqual(actual_result, expected)

             if val_dtype == torch.float64:

@@ -469,20 +466,19 @@ class TestSegmentReductions(TestCase):
                 check_backward,
             )

-    @onlyCPU
     @dtypes(torch.int, torch.int64)
     def test_unsafe_flag(self, device, dtype):
         length_type = dtype
-        lengths = torch.tensor([0, 2, 3, 0], dtype=length_type)
-        data = torch.arange(6).float()
+        lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
+        data = torch.arange(6, dtype=torch.float, device=device)

         # test for error on 1-D lenghts
         with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
             torch.segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)

         # test for error on multi-D lengths
-        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type)
-        nd_data = torch.arange(12).reshape(2, 6).float()
+        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
+        nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
         with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
             torch.segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)

@@ -8327,7 +8327,7 @@ def sample_inputs_scatter_reduce(op_info, device, dtype, requires_grad, **kwargs
                      args=(1, idx, src, reduce),
                      kwargs={'include_self': True})

-def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, **kwargs):
+def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode='lengths', **kwargs):
     def _tensor(shape, dtype=dtype, low=None, high=None):
         return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)

@@ -8340,6 +8340,11 @@ def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, **kwargs
         ((S, S), 0, [0, 1, 2, 2], False),
         # test when lengths do not sum to dim size
         ((M, S, S), 0, [1, 2, 0, 6, 0], True),
+        # test for higher dimensions
+        ((S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
+        ((S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
+        ((S, S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
+        ((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
     )

     reductions = ["max", "mean", "min", "sum", "prod"]

@@ -19373,6 +19378,7 @@ op_db: List[OpInfo] = [
     ),
     OpInfo(
         'segment_reduce',
+        variant_test_name='lengths',
         dtypes=floating_types_and(torch.float16, torch.bfloat16),
         supports_out=False,
         # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented