[torch][segment_reduce] Add support for int lengths as well (#61141)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61141

Currently only int64 (long) lengths are supported. This diff adds support for other index types as well.
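
As a usage sketch (assuming the prototype torch.segment_reduce Python API exposed at the time of this change; argument names are illustrative), int32 lengths can now be passed directly:

    import torch

    data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
    # Lengths given as int32 (torch.int) rather than int64 -- the case this diff enables.
    lengths = torch.tensor([2, 3], dtype=torch.int)
    out = torch.segment_reduce(data, "max", lengths=lengths, axis=0)
    # expected: tensor([2., 5.])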

Next Steps:
- Update the default, refactor the unit tests, and test the non-initial-value case as well
- Cleanup (more tests, benchmarks, documentation updates)

Test Plan: Updated the unit tests; rely on CI.

Reviewed By: ngimel

Differential Revision: D29526308

fbshipit-source-id: b4043603483851ef7e0e93b0bb02ac7849c6449d
Author: Serhat Yilmaz, 2021-07-07 13:07:12 -07:00 (committed by Facebook GitHub Bot)
Parent: 423523d8bb
Commit: a78ad5dc4c
3 changed files with 325 additions and 280 deletions
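
The hunks below touch the CPU kernel, the CUDA kernel, and the Python tests. The core change is nesting the existing floating-point dispatch inside AT_DISPATCH_INDEX_TYPES, so the lengths tensor may be int32 or int64 and the kernels index it through index_t instead of a hard-coded int64_t. A simplified, hypothetical sketch of that dispatch pattern on the CPU side (segment_sum_sketch and its sum-only loop are illustrative, not the actual kernel):

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>

    // Sum-reduce each segment of a 1-D `data` tensor; `lengths` may be int32 or int64.
    at::Tensor segment_sum_sketch(const at::Tensor& data, const at::Tensor& lengths) {
      auto output = at::empty({lengths.numel()}, data.options());
      // Outer dispatch binds index_t to the dtype of `lengths` (kInt or kLong).
      AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "segment_sum_sketch", ([&] {
        const auto* lengths_data = lengths.data_ptr<index_t>();
        // Inner dispatch binds scalar_t to the dtype of `data`.
        AT_DISPATCH_FLOATING_TYPES_AND2(
            at::kBFloat16, at::kHalf, data.scalar_type(), "segment_sum_sketch", ([&]() {
              const auto* values_data = data.data_ptr<scalar_t>();
              auto* output_data = output.data_ptr<scalar_t>();
              int64_t offset = 0;
              for (int64_t i = 0; i < lengths.numel(); ++i) {
                scalar_t acc = static_cast<scalar_t>(0);
                for (index_t j = 0; j < lengths_data[i]; ++j) {
                  acc = acc + values_data[offset + j];
                }
                output_data[i] = acc;
                offset += lengths_data[i];
              }
            }));
      }));
      return output;
    }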


@@ -39,65 +39,72 @@ Tensor _segment_reduce_cpu_kernel(
auto output = at::empty(output_shape, data.options());
int64_t stride_count = data.numel() / data.size(axis);
const auto* lengths_data = lengths.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", ([&]() {
auto* output_data = output.data_ptr<scalar_t>();
const auto* values_data = data.data_ptr<scalar_t>();
int64_t lengths_cum_sum = 0;
for (int64_t i = 0; i < segment_count; ++i) {
for (int64_t l = 0; l < stride_count; ++l) {
// ===== step1: initialize starting value
scalar_t initial_value;
if (initial.has_value()) {
initial_value = initial.value().to<scalar_t>();
} else if (reduction == SegmentReductionType::MAX) {
initial_value = -std::numeric_limits<scalar_t>::infinity();
} else if (
reduction == SegmentReductionType::MEAN ||
reduction == SegmentReductionType::SUM) {
initial_value = 0;
} else if (reduction == SegmentReductionType::MIN) {
initial_value = std::numeric_limits<scalar_t>::infinity();
}
AT_DISPATCH_INDEX_TYPES(
lengths.type(), "_segment_reduce_cpu_kernel1", ([&] {
const auto* lengths_data = lengths.data_ptr<index_t>();
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16,
kHalf,
data.scalar_type(),
"_segment_reduce_cpu",
([&]() {
auto* output_data = output.data_ptr<scalar_t>();
const auto* values_data = data.data_ptr<scalar_t>();
int64_t lengths_cum_sum = 0;
for (int64_t i = 0; i < segment_count; ++i) {
for (int64_t l = 0; l < stride_count; ++l) {
// ===== step1: initialize starting value
scalar_t initial_value;
if (initial.has_value()) {
initial_value = initial.value().to<scalar_t>();
} else if (reduction == SegmentReductionType::MAX) {
initial_value = -std::numeric_limits<scalar_t>::infinity();
} else if (
reduction == SegmentReductionType::MEAN ||
reduction == SegmentReductionType::SUM) {
initial_value = 0;
} else if (reduction == SegmentReductionType::MIN) {
initial_value = std::numeric_limits<scalar_t>::infinity();
}
// ===== step2: apply reduction
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
const auto data = values_data[starting_index];
// TODO: There is no need to branch with every element
if (reduction == SegmentReductionType::MAX) {
initial_value = at::_isnan(data)
? data
: std::max<scalar_t>(initial_value, data);
} else if (
reduction == SegmentReductionType::MEAN ||
reduction == SegmentReductionType::SUM) {
initial_value = initial_value + data;
} else if (reduction == SegmentReductionType::MIN) {
initial_value = at::_isnan(data)
? data
: std::min<scalar_t>(initial_value, data);
// ===== step2: apply reduction
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
const auto data = values_data[starting_index];
// TODO: There is no need to branch with every element
if (reduction == SegmentReductionType::MAX) {
initial_value = at::_isnan(data)
? data
: std::max<scalar_t>(initial_value, data);
} else if (
reduction == SegmentReductionType::MEAN ||
reduction == SegmentReductionType::SUM) {
initial_value = initial_value + data;
} else if (reduction == SegmentReductionType::MIN) {
initial_value = at::_isnan(data)
? data
: std::min<scalar_t>(initial_value, data);
}
}
// ===== step3: finalize reduction
TORCH_CHECK(lengths_data[i] >= 0);
if (lengths_data[i] == 0 && !initial.has_value()) {
initial_value = static_cast<scalar_t>(NAN);
} else if (
reduction == SegmentReductionType::MEAN &&
lengths_data[i] > 0 && !at::_isnan(initial_value)) {
initial_value = initial_value / lengths_data[i];
}
int64_t output_index = (i * stride_count) + l;
output_data[output_index] = initial_value;
}
lengths_cum_sum += lengths_data[i];
}
}
// ===== step3: finalize reduction
TORCH_CHECK(lengths_data[i] >= 0);
if (lengths_data[i] == 0 && !initial.has_value()) {
initial_value = static_cast<scalar_t>(NAN);
} else if (
reduction == SegmentReductionType::MEAN &&
lengths_data[i] > 0 && !at::_isnan(initial_value)) {
initial_value = initial_value / lengths_data[i];
}
int64_t output_index = (i * stride_count) + l;
output_data[output_index] = initial_value;
}
lengths_cum_sum += lengths_data[i];
}
}));
}));
return output;
@@ -116,73 +123,81 @@ Tensor _segment_reduce_cpu_backward_kernel(
auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
int64_t stride_count = data_contig.numel() / data_contig.size(axis);
const auto* lengths_data = lengths_contig.data_ptr<int64_t>();
// TODO: Switch to TensorIterator for better maintainability and readability
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16,
kHalf,
data_contig.scalar_type(),
"_segment_reduce_cpu",
([&]() {
auto* output_data = output_contig.data_ptr<scalar_t>();
auto* grad_data = grad_contig.data_ptr<scalar_t>();
auto* grad_input_data = grad_input.data_ptr<scalar_t>();
const auto* values_data = data_contig.data_ptr<scalar_t>();
AT_DISPATCH_INDEX_TYPES(
lengths_contig.type(),
"_segment_reduce_cpu_backward_kernel1",
([&] {
const auto* lengths_data = lengths_contig.data_ptr<index_t>();
// TODO: Switch to TensorIterator for better maintainability and
// readability
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16,
kHalf,
data_contig.scalar_type(),
"_segment_reduce_cpu",
([&]() {
auto* output_data = output_contig.data_ptr<scalar_t>();
auto* grad_data = grad_contig.data_ptr<scalar_t>();
auto* grad_input_data = grad_input.data_ptr<scalar_t>();
const auto* values_data = data_contig.data_ptr<scalar_t>();
int64_t lengths_cum_sum = 0;
for (int64_t i = 0; i < segment_count; ++i) {
if (lengths_data[i] == 0) {
continue;
}
for (int64_t l = 0; l < stride_count; ++l) {
int64_t output_index = (i * stride_count) + l;
if (reduction == SegmentReductionType::MAX ||
reduction == SegmentReductionType::MIN) {
int64_t counter = 0;
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
if (at::_isnan(values_data[starting_index]) ||
values_data[starting_index] == output_data[output_index]) {
grad_input_data[starting_index] = grad_data[output_index];
counter++;
int64_t lengths_cum_sum = 0;
for (int64_t i = 0; i < segment_count; ++i) {
if (lengths_data[i] == 0) {
continue;
}
}
// Average gradient based on number of maximum elements in the
// segment
if (counter < 2) {
continue;
}
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
if (grad_input_data[starting_index] > 0) {
grad_input_data[starting_index] =
grad_input_data[starting_index] / counter;
}
}
} else if (reduction == SegmentReductionType::MEAN) {
auto grad_val = grad_data[output_index] / lengths_data[i];
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
grad_input_data[starting_index] = grad_val;
}
} else if (reduction == SegmentReductionType::SUM) {
const auto& grad_val = grad_data[output_index];
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
grad_input_data[starting_index] = grad_val;
}
}
}
lengths_cum_sum += lengths_data[i];
}
for (int64_t l = 0; l < stride_count; ++l) {
int64_t output_index = (i * stride_count) + l;
if (reduction == SegmentReductionType::MAX ||
reduction == SegmentReductionType::MIN) {
int64_t counter = 0;
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
if (at::_isnan(values_data[starting_index]) ||
values_data[starting_index] ==
output_data[output_index]) {
grad_input_data[starting_index] =
grad_data[output_index];
counter++;
}
}
// Average gradient based on number of maximum elements in
// the segment
if (counter < 2) {
continue;
}
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
if (grad_input_data[starting_index] > 0) {
grad_input_data[starting_index] =
grad_input_data[starting_index] / counter;
}
}
} else if (reduction == SegmentReductionType::MEAN) {
auto grad_val = grad_data[output_index] / lengths_data[i];
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
grad_input_data[starting_index] = grad_val;
}
} else if (reduction == SegmentReductionType::SUM) {
const auto& grad_val = grad_data[output_index];
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
grad_input_data[starting_index] = grad_val;
}
}
}
lengths_cum_sum += lengths_data[i];
}
}));
}));
return grad_input;


@@ -50,23 +50,25 @@ Tensor _get_complete_sum(const Tensor& lengths) {
TORCH_CHECK(segment_count < INT_MAX);
auto offsets = at::empty({segment_count + 1}, lengths.options());
offsets[0].zero_();
auto* lengths_data_ptr = lengths.data_ptr<int64_t>();
auto* offsets_data_ptr = offsets.data_ptr<int64_t>();
CUB_WRAPPER(
cub::DeviceScan::InclusiveSum,
lengths_data_ptr,
offsets_data_ptr + 1,
segment_count,
at::cuda::getCurrentCUDAStream());
AT_DISPATCH_INDEX_TYPES(
lengths.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
auto* lengths_data_ptr = lengths.data_ptr<index_t>();
auto* offsets_data_ptr = offsets.data_ptr<index_t>();
CUB_WRAPPER(
cub::DeviceScan::InclusiveSum,
lengths_data_ptr,
offsets_data_ptr + 1,
segment_count,
at::cuda::getCurrentCUDAStream());
}));
return offsets;
}
template <typename scalar_t>
template <typename scalar_t, typename index_t>
__global__ static void post_sum_div_kernel(
scalar_t* output_data,
const int64_t* lengths_data,
const index_t* lengths_data,
const int64_t segment_count,
bool is_initial_set,
scalar_t initial) {
@@ -84,13 +86,13 @@ __global__ static void post_sum_div_kernel(
}
}
template <typename scalar_t>
__global__ static void segment_reduce_forward_kernel(
template <typename scalar_t, typename index_t>
__global__ void segment_reduce_forward_kernel(
SegmentReductionType reduction,
scalar_t* output_data,
scalar_t* values_data,
const int64_t* lengths_data,
const int64_t* lengths_cumsum_data,
const index_t* lengths_data,
const index_t* lengths_cumsum_data,
const int64_t segment_count,
const int64_t stride_count,
bool is_initial_set,
@@ -135,15 +137,15 @@ __global__ static void segment_reduce_forward_kernel(
output_data[output_index] = initial_value;
}
template <typename scalar_t>
__global__ static void segment_reduce_backward_kernel(
template <typename scalar_t, typename index_t>
__global__ void segment_reduce_backward_kernel(
SegmentReductionType reduction,
scalar_t* grad_input_data,
scalar_t* grad_data,
scalar_t* output_data,
const scalar_t* values_data,
const int64_t* lengths_data,
const int64_t* lengths_cumsum_data,
const index_t* lengths_data,
const index_t* lengths_cumsum_data,
const int64_t segment_count,
const int64_t stride_count) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -215,10 +217,8 @@ Tensor _segment_reduce_cuda_backward_kernel(
auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
int64_t stride_count = data_contig.numel() / data_contig.size(axis);
const auto* lengths_data = lengths_contig.data_ptr<int64_t>();
auto offsets = _get_complete_sum(lengths_contig);
auto* offsets_data = offsets.data_ptr<int64_t>();
constexpr int threads_per_block = 256;
int64_t num_blocks =
@@ -227,35 +227,41 @@ Tensor _segment_reduce_cuda_backward_kernel(
num_blocks = std::max(num_blocks, (int64_t)1);
// TODO: Switch to TensorIterator for better maintainability and readability
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16,
kHalf,
data_contig.scalar_type(),
"_segment_reduce_cpu",
([&]() {
auto* output_data = output_contig.data_ptr<scalar_t>();
auto* grad_data = grad_contig.data_ptr<scalar_t>();
auto* grad_input_data = grad_input.data_ptr<scalar_t>();
const auto* values_data = data_contig.data_ptr<scalar_t>();
AT_DISPATCH_INDEX_TYPES(
lengths_contig.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
const auto* lengths_data = lengths_contig.data_ptr<index_t>();
auto* offsets_data = offsets.data_ptr<index_t>();
segment_reduce_backward_kernel<scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
reduction,
grad_input_data,
grad_data,
output_data,
values_data,
lengths_data,
offsets_data,
segment_count,
stride_count);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// TODO: Switch to TensorIterator for better maintainability and
// readability
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16,
kHalf,
data_contig.scalar_type(),
"_segment_reduce_cpu",
([&]() {
auto* output_data = output_contig.data_ptr<scalar_t>();
auto* grad_data = grad_contig.data_ptr<scalar_t>();
auto* grad_input_data = grad_input.data_ptr<scalar_t>();
const auto* values_data = data_contig.data_ptr<scalar_t>();
segment_reduce_backward_kernel<scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
reduction,
grad_input_data,
grad_data,
output_data,
values_data,
lengths_data,
offsets_data,
segment_count,
stride_count);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
}));
return grad_input;
}
@@ -271,10 +277,8 @@ Tensor _segment_reduce_cuda_kernel(
auto output = at::empty(output_shape, data.options());
int64_t stride_count = data.numel() / data.size(axis);
const auto* lengths_data = lengths.data_ptr<int64_t>();
auto offsets = _get_complete_sum(lengths);
auto* offsets_data_ptr = offsets.data_ptr<int64_t>();
constexpr int threads_per_block = 256;
int64_t num_blocks =
@@ -282,111 +286,115 @@ Tensor _segment_reduce_cuda_kernel(
threads_per_block;
num_blocks = std::max(num_blocks, (int64_t)1);
auto* lengths_data_ptr = lengths.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
data.scalar_type(),
"segment_reduce_cuda",
[&]() {
auto* data_data_ptr = data.data_ptr<scalar_t>();
auto* output_data_ptr = output.data_ptr<scalar_t>();
AT_DISPATCH_INDEX_TYPES(
lengths.type(), "_segment_reduce_cuda_kernel1", ([&] {
auto* offsets_data_ptr = offsets.data_ptr<index_t>();
auto* lengths_data_ptr = lengths.data_ptr<index_t>();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
data.scalar_type(),
"segment_reduce_cuda",
[&]() {
auto* data_data_ptr = data.data_ptr<scalar_t>();
auto* output_data_ptr = output.data_ptr<scalar_t>();
// initialize starting value
scalar_t initial_value;
if (initial.has_value()) {
initial_value = initial.value().to<scalar_t>();
} else if (reduction == SegmentReductionType::MAX) {
initial_value = -std::numeric_limits<scalar_t>::infinity();
} else if (
reduction == SegmentReductionType::MEAN ||
reduction == SegmentReductionType::SUM) {
initial_value = 0;
} else if (reduction == SegmentReductionType::MIN) {
initial_value = std::numeric_limits<scalar_t>::infinity();
}
// initialize starting value
scalar_t initial_value;
if (initial.has_value()) {
initial_value = initial.value().to<scalar_t>();
} else if (reduction == SegmentReductionType::MAX) {
initial_value = -std::numeric_limits<scalar_t>::infinity();
} else if (
reduction == SegmentReductionType::MEAN ||
reduction == SegmentReductionType::SUM) {
initial_value = 0;
} else if (reduction == SegmentReductionType::MIN) {
initial_value = std::numeric_limits<scalar_t>::infinity();
}
if (output_shape.size() > 1) {
segment_reduce_forward_kernel<scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
reduction,
output_data_ptr,
data_data_ptr,
lengths_data_ptr,
offsets_data_ptr,
segment_count,
stride_count,
initial.has_value(),
initial_value);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
if (reduction == SegmentReductionType::MAX) {
CustomMax max_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
data_data_ptr,
output_data_ptr,
segment_count,
offsets_data_ptr,
offsets_data_ptr + 1,
max_op,
initial_value,
at::cuda::getCurrentCUDAStream());
} else if (reduction == SegmentReductionType::MEAN) {
CustomSum sum_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
data_data_ptr,
output_data_ptr,
segment_count,
offsets_data_ptr,
offsets_data_ptr + 1,
sum_op,
initial_value,
at::cuda::getCurrentCUDAStream());
if (output_shape.size() > 1) {
segment_reduce_forward_kernel<scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
reduction,
output_data_ptr,
data_data_ptr,
lengths_data_ptr,
offsets_data_ptr,
segment_count,
stride_count,
initial.has_value(),
initial_value);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
if (reduction == SegmentReductionType::MAX) {
CustomMax max_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
data_data_ptr,
output_data_ptr,
segment_count,
offsets_data_ptr,
offsets_data_ptr + 1,
max_op,
initial_value,
at::cuda::getCurrentCUDAStream());
} else if (reduction == SegmentReductionType::MEAN) {
CustomSum sum_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
data_data_ptr,
output_data_ptr,
segment_count,
offsets_data_ptr,
offsets_data_ptr + 1,
sum_op,
initial_value,
at::cuda::getCurrentCUDAStream());
post_sum_div_kernel<scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
output_data_ptr,
lengths_data_ptr,
segment_count,
initial.has_value(),
initial_value);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else if (reduction == SegmentReductionType::MIN) {
CustomMin min_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
data_data_ptr,
output_data_ptr,
segment_count,
offsets_data_ptr,
offsets_data_ptr + 1,
min_op,
initial_value,
at::cuda::getCurrentCUDAStream());
} else if (reduction == SegmentReductionType::SUM) {
CustomSum sum_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
data_data_ptr,
output_data_ptr,
segment_count,
offsets_data_ptr,
offsets_data_ptr + 1,
sum_op,
initial_value,
at::cuda::getCurrentCUDAStream());
}
}
});
post_sum_div_kernel<scalar_t>
<<<num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
output_data_ptr,
lengths_data_ptr,
segment_count,
initial.has_value(),
initial_value);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else if (reduction == SegmentReductionType::MIN) {
CustomMin min_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
data_data_ptr,
output_data_ptr,
segment_count,
offsets_data_ptr,
offsets_data_ptr + 1,
min_op,
initial_value,
at::cuda::getCurrentCUDAStream());
} else if (reduction == SegmentReductionType::SUM) {
CustomSum sum_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
data_data_ptr,
output_data_ptr,
segment_count,
offsets_data_ptr,
offsets_data_ptr + 1,
sum_op,
initial_value,
at::cuda::getCurrentCUDAStream());
}
}
});
}));
return output;
}


@@ -1,3 +1,5 @@
from itertools import product
import numpy as np
import torch
from torch.testing._internal.common_device_type import (
@@ -28,8 +30,9 @@ class TestSegmentReductions(TestCase):
expected_arr,
expected_grad_arr,
check_backward,
lengths_dtype=torch.int,
):
lengths = torch.tensor(lengths_arr, device=device)
lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
data = torch.tensor(
data_arr,
device=device,
@@ -85,8 +88,14 @@ class TestSegmentReductions(TestCase):
)
)
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
def test_simple_1d(self, device, dtype):
@dtypes(
*product(
(torch.half, torch.bfloat16, torch.float, torch.double),
(torch.int, torch.int64),
)
)
def test_simple_1d(self, device, dtypes):
val_dtype, length_type = dtypes
lengths = [1, 2, 3, 0]
data = [1, float("nan"), 3, 4, 5, 5]
check_backward = True
@@ -114,7 +123,7 @@ class TestSegmentReductions(TestCase):
self._test_common(
reduction,
device,
dtype,
val_dtype,
unsafe,
axis,
initial_value,
@@ -123,10 +132,17 @@ class TestSegmentReductions(TestCase):
expected_result,
expected_grad,
check_backward,
length_type,
)
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
def test_multi_d_simple(self, device, dtype):
@dtypes(
*product(
(torch.half, torch.bfloat16, torch.float, torch.double),
(torch.int, torch.int64),
)
)
def test_multi_d_simple(self, device, dtypes):
val_dtype, length_type = dtypes
check_backward = True
axis = 0
lengths = [1, 2, 3, 0]
@@ -202,7 +218,7 @@ class TestSegmentReductions(TestCase):
self._test_common(
reduction,
device,
dtype,
val_dtype,
unsafe,
axis,
initial_value,
@@ -213,8 +229,14 @@ class TestSegmentReductions(TestCase):
check_backward,
)
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
def test_multi_d(self, device, dtype):
@dtypes(
*product(
(torch.half, torch.bfloat16, torch.float, torch.double),
(torch.int, torch.int64),
)
)
def test_multi_d(self, device, dtypes):
val_dtype, length_type = dtypes
axis = 0
lengths = [0, 2]
data = np.arange(20).reshape(2, 2, 5).tolist()
@@ -253,7 +275,7 @@ class TestSegmentReductions(TestCase):
self._test_common(
reduction,
device,
dtype,
val_dtype,
unsafe,
axis,
initial_value,