Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[torch][segment_reduce] Add support for int lengths as well (#61141)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61141

Currently only long (int64) lengths are supported. This diff adds support for other index types as well.

Next steps:
- Update the default, refactor the unit tests, and test a non-initial value as well
- Cleanup (more tests, benchmark, update documentation)

Test Plan: updated unit tests; rely on CI.

Reviewed By: ngimel

Differential Revision: D29526308

fbshipit-source-id: b4043603483851ef7e0e93b0bb02ac7849c6449d
This commit is contained in:
parent 423523d8bb
commit a78ad5dc4c
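The same pattern is applied to every kernel touched by this patch: the existing AT_DISPATCH_FLOATING_TYPES_AND2 call is nested inside an AT_DISPATCH_INDEX_TYPES dispatch on `lengths`, so the kernels read the lengths through `index_t` (int32 or int64) instead of a hard-coded `int64_t`. A minimal sketch of that nesting, assuming an ATen/libtorch build; the helper name `sum_by_lengths` and its 1-D SUM body are illustrative only and not part of the patch:

// Sketch only: the nested dispatch pattern this patch introduces, shown on a
// made-up helper. Assumes an ATen/libtorch build.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

at::Tensor sum_by_lengths(const at::Tensor& data, const at::Tensor& lengths) {
  // One output value per segment; assumes 1-D, contiguous floating-point data.
  auto output = at::zeros({lengths.numel()}, data.options());
  AT_DISPATCH_INDEX_TYPES(
      lengths.scalar_type(), "sum_by_lengths_index", ([&] {
        // index_t is int32_t or int64_t depending on lengths.dtype().
        const auto* lengths_data = lengths.data_ptr<index_t>();
        AT_DISPATCH_FLOATING_TYPES_AND2(
            at::kBFloat16, at::kHalf, data.scalar_type(), "sum_by_lengths", ([&]() {
              const auto* values_data = data.data_ptr<scalar_t>();
              auto* output_data = output.data_ptr<scalar_t>();
              int64_t cursor = 0;
              for (int64_t i = 0; i < lengths.numel(); ++i) {
                scalar_t acc = 0;
                for (index_t j = 0; j < lengths_data[i]; ++j) {
                  acc = acc + values_data[cursor + j];
                }
                output_data[i] = acc;
                cursor += lengths_data[i];
              }
            }));
      }));
  return output;
}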
@@ -39,65 +39,72 @@ Tensor _segment_reduce_cpu_kernel(
  auto output = at::empty(output_shape, data.options());

  int64_t stride_count = data.numel() / data.size(axis);
- const auto* lengths_data = lengths.data_ptr<int64_t>();
-
- AT_DISPATCH_FLOATING_TYPES_AND2(
-     kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", ([&]() {
+
+ AT_DISPATCH_INDEX_TYPES(
+     lengths.type(), "_segment_reduce_cpu_kernel1", ([&] {
+       const auto* lengths_data = lengths.data_ptr<index_t>();
+       AT_DISPATCH_FLOATING_TYPES_AND2(
+           kBFloat16,
+           kHalf,
+           data.scalar_type(),
+           "_segment_reduce_cpu",
+           ([&]() {
              auto* output_data = output.data_ptr<scalar_t>();
              const auto* values_data = data.data_ptr<scalar_t>();
              int64_t lengths_cum_sum = 0;
              for (int64_t i = 0; i < segment_count; ++i) {
                for (int64_t l = 0; l < stride_count; ++l) {
                  // ===== step1: initialize starting value
                  scalar_t initial_value;
                  if (initial.has_value()) {
                    initial_value = initial.value().to<scalar_t>();
                  } else if (reduction == SegmentReductionType::MAX) {
                    initial_value = -std::numeric_limits<scalar_t>::infinity();
                  } else if (
                      reduction == SegmentReductionType::MEAN ||
                      reduction == SegmentReductionType::SUM) {
                    initial_value = 0;
                  } else if (reduction == SegmentReductionType::MIN) {
                    initial_value = std::numeric_limits<scalar_t>::infinity();
                  }

                  // ===== step2: apply reduction
                  for (int64_t j = 0; j < lengths_data[i]; ++j) {
                    int64_t starting_index =
                        ((lengths_cum_sum + j) * stride_count) + l;
                    const auto data = values_data[starting_index];
                    // TODO: There is no need to branch with every element
                    if (reduction == SegmentReductionType::MAX) {
                      initial_value = at::_isnan(data)
                          ? data
                          : std::max<scalar_t>(initial_value, data);
                    } else if (
                        reduction == SegmentReductionType::MEAN ||
                        reduction == SegmentReductionType::SUM) {
                      initial_value = initial_value + data;
                    } else if (reduction == SegmentReductionType::MIN) {
                      initial_value = at::_isnan(data)
                          ? data
                          : std::min<scalar_t>(initial_value, data);
                    }
                  }

                  // ===== step3: finalize reduction
                  TORCH_CHECK(lengths_data[i] >= 0);

                  if (lengths_data[i] == 0 && !initial.has_value()) {
                    initial_value = static_cast<scalar_t>(NAN);
                  } else if (
                      reduction == SegmentReductionType::MEAN &&
                      lengths_data[i] > 0 && !at::_isnan(initial_value)) {
                    initial_value = initial_value / lengths_data[i];
                  }
                  int64_t output_index = (i * stride_count) + l;
                  output_data[output_index] = initial_value;
                }
                lengths_cum_sum += lengths_data[i];
              }
            }));
+      }));

  return output;
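The kernel above treats `data` as `segment_count` consecutive rows of `stride_count` elements, so element `j` of segment `i` at inner offset `l` lives at `((lengths_cum_sum + j) * stride_count) + l`. A standalone sketch of that index arithmetic with made-up values (plain C++, not PyTorch code):

// Standalone sketch of the index arithmetic above: two segments of lengths
// {2, 1} over rows of stride_count = 3, reduced with SUM. Made-up data.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> values = {1, 2, 3,   // row 0 (segment 0)
                               4, 5, 6,   // row 1 (segment 0)
                               7, 8, 9};  // row 2 (segment 1)
  std::vector<int> lengths = {2, 1};      // int32 lengths, as enabled by the patch
  const int64_t stride_count = 3;

  int64_t lengths_cum_sum = 0;
  for (std::size_t i = 0; i < lengths.size(); ++i) {
    for (int64_t l = 0; l < stride_count; ++l) {
      float acc = 0;  // SUM reduction, starting value 0
      for (int64_t j = 0; j < lengths[i]; ++j) {
        int64_t starting_index = ((lengths_cum_sum + j) * stride_count) + l;
        acc += values[starting_index];
      }
      std::printf("segment %zu, col %lld -> %g\n", i, (long long)l, acc);
    }
    lengths_cum_sum += lengths[i];
  }
  return 0;
}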
@@ -116,73 +123,81 @@ Tensor _segment_reduce_cpu_backward_kernel(
  auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());

  int64_t stride_count = data_contig.numel() / data_contig.size(axis);
- const auto* lengths_data = lengths_contig.data_ptr<int64_t>();
-
- // TODO: Swtich to TensorIterator for better maintainablility and readability
- AT_DISPATCH_FLOATING_TYPES_AND2(
-     kBFloat16,
-     kHalf,
-     data_contig.scalar_type(),
-     "_segment_reduce_cpu",
-     ([&]() {
-       auto* output_data = output_contig.data_ptr<scalar_t>();
-       auto* grad_data = grad_contig.data_ptr<scalar_t>();
-       auto* grad_input_data = grad_input.data_ptr<scalar_t>();
-       const auto* values_data = data_contig.data_ptr<scalar_t>();
+
+ AT_DISPATCH_INDEX_TYPES(
+     lengths_contig.type(),
+     "_segment_reduce_cpu_backward_kernel1",
+     ([&] {
+       const auto* lengths_data = lengths_contig.data_ptr<index_t>();
+       // TODO: Swtich to TensorIterator for better maintainablility and
+       // readability
+       AT_DISPATCH_FLOATING_TYPES_AND2(
+           kBFloat16,
+           kHalf,
+           data_contig.scalar_type(),
+           "_segment_reduce_cpu",
+           ([&]() {
+             auto* output_data = output_contig.data_ptr<scalar_t>();
+             auto* grad_data = grad_contig.data_ptr<scalar_t>();
+             auto* grad_input_data = grad_input.data_ptr<scalar_t>();
+             const auto* values_data = data_contig.data_ptr<scalar_t>();

              int64_t lengths_cum_sum = 0;
              for (int64_t i = 0; i < segment_count; ++i) {
                if (lengths_data[i] == 0) {
                  continue;
                }

                for (int64_t l = 0; l < stride_count; ++l) {
                  int64_t output_index = (i * stride_count) + l;

                  if (reduction == SegmentReductionType::MAX ||
                      reduction == SegmentReductionType::MIN) {
                    int64_t counter = 0;
                    for (int64_t j = 0; j < lengths_data[i]; ++j) {
                      int64_t starting_index =
                          ((lengths_cum_sum + j) * stride_count) + l;
                      if (at::_isnan(values_data[starting_index]) ||
                          values_data[starting_index] ==
                              output_data[output_index]) {
                        grad_input_data[starting_index] =
                            grad_data[output_index];
                        counter++;
                      }
                    }
                    // Average gradient based on number of maximum elements in
                    // the segment
                    if (counter < 2) {
                      continue;
                    }
                    for (int64_t j = 0; j < lengths_data[i]; ++j) {
                      int64_t starting_index =
                          ((lengths_cum_sum + j) * stride_count) + l;
                      if (grad_input_data[starting_index] > 0) {
                        grad_input_data[starting_index] =
                            grad_input_data[starting_index] / counter;
                      }
                    }
                  } else if (reduction == SegmentReductionType::MEAN) {
                    auto grad_val = grad_data[output_index] / lengths_data[i];
                    for (int64_t j = 0; j < lengths_data[i]; ++j) {
                      int64_t starting_index =
                          ((lengths_cum_sum + j) * stride_count) + l;
                      grad_input_data[starting_index] = grad_val;
                    }
                  } else if (reduction == SegmentReductionType::SUM) {
                    const auto& grad_val = grad_data[output_index];
                    for (int64_t j = 0; j < lengths_data[i]; ++j) {
                      int64_t starting_index =
                          ((lengths_cum_sum + j) * stride_count) + l;
                      grad_input_data[starting_index] = grad_val;
                    }
                  }
                }

                lengths_cum_sum += lengths_data[i];
              }
            }));
+      }));

  return grad_input;
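For MAX and MIN, the backward kernel above routes the upstream gradient to every element that matches the reduced output (or is NaN) and, when several elements tie, splits it evenly by the tie count. A small sketch of that rule on a single segment (plain C++, made-up data):

// Sketch of the tie-splitting rule in the MAX/MIN backward above, on one segment.
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> values = {3, 5, 5, 1};  // the forward max (5) appears twice
  float output = 5.0f;                       // forward result for this segment
  float grad_output = 1.0f;                  // upstream gradient for this segment
  std::vector<float> grad_input(values.size(), 0.0f);

  // Route the gradient to every element equal to the output
  // (the real kernel also routes NaNs here), counting the matches.
  int counter = 0;
  for (std::size_t j = 0; j < values.size(); ++j) {
    if (values[j] == output) {
      grad_input[j] = grad_output;
      ++counter;
    }
  }
  // With two or more matches, split the gradient evenly, as the kernel does.
  if (counter >= 2) {
    for (auto& g : grad_input) {
      if (g > 0) {
        g /= counter;
      }
    }
  }
  for (float g : grad_input) {
    std::printf("%g ", g);  // prints: 0 0.5 0.5 0
  }
  std::printf("\n");
  return 0;
}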
@@ -50,23 +50,25 @@ Tensor _get_complete_sum(const Tensor& lengths) {
  TORCH_CHECK(segment_count < INT_MAX);
  auto offsets = at::empty({segment_count + 1}, lengths.options());
  offsets[0].zero_();
- auto* lengths_data_ptr = lengths.data_ptr<int64_t>();
- auto* offsets_data_ptr = offsets.data_ptr<int64_t>();
-
- CUB_WRAPPER(
-     cub::DeviceScan::InclusiveSum,
-     lengths_data_ptr,
-     offsets_data_ptr + 1,
-     segment_count,
-     at::cuda::getCurrentCUDAStream());
-
+ AT_DISPATCH_INDEX_TYPES(
+     lengths.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+       auto* lengths_data_ptr = lengths.data_ptr<index_t>();
+       auto* offsets_data_ptr = offsets.data_ptr<index_t>();
+       CUB_WRAPPER(
+           cub::DeviceScan::InclusiveSum,
+           lengths_data_ptr,
+           offsets_data_ptr + 1,
+           segment_count,
+           at::cuda::getCurrentCUDAStream());
+     }));
  return offsets;
}

-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
__global__ static void post_sum_div_kernel(
    scalar_t* output_data,
-   const int64_t* lengths_data,
+   const index_t* lengths_data,
    const int64_t segment_count,
    bool is_initial_set,
    scalar_t initial) {
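`_get_complete_sum` builds an offsets tensor with a leading zero followed by an inclusive prefix sum of `lengths`, which the CUB scan above now produces for either index type. An equivalent CPU-side sketch of the result (plain C++, not the CUDA code):

// CPU-side sketch of what _get_complete_sum produces: offsets[0] = 0 followed
// by an inclusive prefix sum of lengths. Plain C++, made-up lengths.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int32_t> lengths = {1, 2, 3, 0};  // int32 lengths, as enabled by the patch
  std::vector<int32_t> offsets(lengths.size() + 1, 0);
  for (std::size_t i = 0; i < lengths.size(); ++i) {
    offsets[i + 1] = offsets[i] + lengths[i];
  }
  for (int32_t o : offsets) {
    std::printf("%d ", o);  // prints: 0 1 3 6 6
  }
  std::printf("\n");
  return 0;
}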
@@ -84,13 +86,13 @@ __global__ static void post_sum_div_kernel(
  }
}

-template <typename scalar_t>
-__global__ static void segment_reduce_forward_kernel(
+template <typename scalar_t, typename index_t>
+__global__ void segment_reduce_forward_kernel(
    SegmentReductionType reduction,
    scalar_t* output_data,
    scalar_t* values_data,
-   const int64_t* lengths_data,
-   const int64_t* lengths_cumsum_data,
+   const index_t* lengths_data,
+   const index_t* lengths_cumsum_data,
    const int64_t segment_count,
    const int64_t stride_count,
    bool is_initial_set,
@@ -135,15 +137,15 @@ __global__ static void segment_reduce_forward_kernel(
  output_data[output_index] = initial_value;
}

-template <typename scalar_t>
-__global__ static void segment_reduce_backward_kernel(
+template <typename scalar_t, typename index_t>
+__global__ void segment_reduce_backward_kernel(
    SegmentReductionType reduction,
    scalar_t* grad_input_data,
    scalar_t* grad_data,
    scalar_t* output_data,
    const scalar_t* values_data,
-   const int64_t* lengths_data,
-   const int64_t* lengths_cumsum_data,
+   const index_t* lengths_data,
+   const index_t* lengths_cumsum_data,
    const int64_t segment_count,
    const int64_t stride_count) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -215,10 +217,8 @@ Tensor _segment_reduce_cuda_backward_kernel(
  auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());

  int64_t stride_count = data_contig.numel() / data_contig.size(axis);
- const auto* lengths_data = lengths_contig.data_ptr<int64_t>();

  auto offsets = _get_complete_sum(lengths_contig);
- auto* offsets_data = offsets.data_ptr<int64_t>();

  constexpr int threads_per_block = 256;
  int64_t num_blocks =
@@ -227,35 +227,41 @@ Tensor _segment_reduce_cuda_backward_kernel(

  num_blocks = std::max(num_blocks, (int64_t)1);

- // TODO: Swtich to TensorIterator for better maintainablility and readability
- AT_DISPATCH_FLOATING_TYPES_AND2(
-     kBFloat16,
-     kHalf,
-     data_contig.scalar_type(),
-     "_segment_reduce_cpu",
-     ([&]() {
-       auto* output_data = output_contig.data_ptr<scalar_t>();
-       auto* grad_data = grad_contig.data_ptr<scalar_t>();
-       auto* grad_input_data = grad_input.data_ptr<scalar_t>();
-       const auto* values_data = data_contig.data_ptr<scalar_t>();
+ AT_DISPATCH_INDEX_TYPES(
+     lengths_contig.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+       const auto* lengths_data = lengths_contig.data_ptr<index_t>();
+       auto* offsets_data = offsets.data_ptr<index_t>();
+       // TODO: Swtich to TensorIterator for better maintainablility and
+       // readability
+       AT_DISPATCH_FLOATING_TYPES_AND2(
+           kBFloat16,
+           kHalf,
+           data_contig.scalar_type(),
+           "_segment_reduce_cpu",
+           ([&]() {
+             auto* output_data = output_contig.data_ptr<scalar_t>();
+             auto* grad_data = grad_contig.data_ptr<scalar_t>();
+             auto* grad_input_data = grad_input.data_ptr<scalar_t>();
+             const auto* values_data = data_contig.data_ptr<scalar_t>();

              segment_reduce_backward_kernel<scalar_t>
                  <<<num_blocks,
                     threads_per_block,
                     0,
                     at::cuda::getCurrentCUDAStream()>>>(
                      reduction,
                      grad_input_data,
                      grad_data,
                      output_data,
                      values_data,
                      lengths_data,
                      offsets_data,
                      segment_count,
                      stride_count);
              C10_CUDA_KERNEL_LAUNCH_CHECK();
            }));
+      }));

  return grad_input;
}
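The launch configuration around these kernels uses 256 threads per block and a block count clamped to at least one; the numerator of the division is cut off in this view of the diff, but it is presumably the total number of output elements (`segment_count * stride_count`). A sketch of that arithmetic under that assumption:

// Sketch of the launch-size arithmetic: ceiling division of the assumed
// element count (segment_count * stride_count) by 256 threads per block,
// clamped to at least one block. The numerator is an assumption; it is
// elided in this view of the diff.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int threads_per_block = 256;
  int64_t segment_count = 4;  // made-up sizes
  int64_t stride_count = 10;
  int64_t num_blocks =
      (segment_count * stride_count + threads_per_block - 1) / threads_per_block;
  num_blocks = std::max(num_blocks, (int64_t)1);
  std::printf("num_blocks = %lld\n", (long long)num_blocks);  // prints: num_blocks = 1
  return 0;
}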
@@ -271,10 +277,8 @@ Tensor _segment_reduce_cuda_kernel(
  auto output = at::empty(output_shape, data.options());

  int64_t stride_count = data.numel() / data.size(axis);
- const auto* lengths_data = lengths.data_ptr<int64_t>();

  auto offsets = _get_complete_sum(lengths);
- auto* offsets_data_ptr = offsets.data_ptr<int64_t>();

  constexpr int threads_per_block = 256;
  int64_t num_blocks =
@@ -282,111 +286,115 @@ Tensor _segment_reduce_cuda_kernel(
      threads_per_block;

  num_blocks = std::max(num_blocks, (int64_t)1);
- auto* lengths_data_ptr = lengths.data_ptr<int64_t>();
-
- AT_DISPATCH_FLOATING_TYPES_AND2(
-     at::ScalarType::Half,
-     at::ScalarType::BFloat16,
-     data.scalar_type(),
-     "segment_reduce_cuda",
-     [&]() {
-       auto* data_data_ptr = data.data_ptr<scalar_t>();
-       auto* output_data_ptr = output.data_ptr<scalar_t>();
+
+ AT_DISPATCH_INDEX_TYPES(
+     lengths.type(), "_segment_reduce_cuda_kernel1", ([&] {
+       auto* offsets_data_ptr = offsets.data_ptr<index_t>();
+       auto* lengths_data_ptr = lengths.data_ptr<index_t>();
+       AT_DISPATCH_FLOATING_TYPES_AND2(
+           at::ScalarType::Half,
+           at::ScalarType::BFloat16,
+           data.scalar_type(),
+           "segment_reduce_cuda",
+           [&]() {
+             auto* data_data_ptr = data.data_ptr<scalar_t>();
+             auto* output_data_ptr = output.data_ptr<scalar_t>();

              // initialize starting value
              scalar_t initial_value;
              if (initial.has_value()) {
                initial_value = initial.value().to<scalar_t>();
              } else if (reduction == SegmentReductionType::MAX) {
                initial_value = -std::numeric_limits<scalar_t>::infinity();
              } else if (
                  reduction == SegmentReductionType::MEAN ||
                  reduction == SegmentReductionType::SUM) {
                initial_value = 0;
              } else if (reduction == SegmentReductionType::MIN) {
                initial_value = std::numeric_limits<scalar_t>::infinity();
              }

              if (output_shape.size() > 1) {
                segment_reduce_forward_kernel<scalar_t>
                    <<<num_blocks,
                       threads_per_block,
                       0,
                       at::cuda::getCurrentCUDAStream()>>>(
                        reduction,
                        output_data_ptr,
                        data_data_ptr,
                        lengths_data_ptr,
                        offsets_data_ptr,
                        segment_count,
                        stride_count,
                        initial.has_value(),
                        initial_value);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              } else {
                if (reduction == SegmentReductionType::MAX) {
                  CustomMax max_op{};
                  CUB_WRAPPER(
                      cub::DeviceSegmentedReduce::Reduce,
                      data_data_ptr,
                      output_data_ptr,
                      segment_count,
                      offsets_data_ptr,
                      offsets_data_ptr + 1,
                      max_op,
                      initial_value,
                      at::cuda::getCurrentCUDAStream());
                } else if (reduction == SegmentReductionType::MEAN) {
                  CustomSum sum_op{};
                  CUB_WRAPPER(
                      cub::DeviceSegmentedReduce::Reduce,
                      data_data_ptr,
                      output_data_ptr,
                      segment_count,
                      offsets_data_ptr,
                      offsets_data_ptr + 1,
                      sum_op,
                      initial_value,
                      at::cuda::getCurrentCUDAStream());

                  post_sum_div_kernel<scalar_t>
                      <<<num_blocks,
                         threads_per_block,
                         0,
                         at::cuda::getCurrentCUDAStream()>>>(
                          output_data_ptr,
                          lengths_data_ptr,
                          segment_count,
                          initial.has_value(),
                          initial_value);
                  C10_CUDA_KERNEL_LAUNCH_CHECK();
                } else if (reduction == SegmentReductionType::MIN) {
                  CustomMin min_op{};
                  CUB_WRAPPER(
                      cub::DeviceSegmentedReduce::Reduce,
                      data_data_ptr,
                      output_data_ptr,
                      segment_count,
                      offsets_data_ptr,
                      offsets_data_ptr + 1,
                      min_op,
                      initial_value,
                      at::cuda::getCurrentCUDAStream());
                } else if (reduction == SegmentReductionType::SUM) {
                  CustomSum sum_op{};
                  CUB_WRAPPER(
                      cub::DeviceSegmentedReduce::Reduce,
                      data_data_ptr,
                      output_data_ptr,
                      segment_count,
                      offsets_data_ptr,
                      offsets_data_ptr + 1,
                      sum_op,
                      initial_value,
                      at::cuda::getCurrentCUDAStream());
                }
              }
            });
+     }));

  return output;
}
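On the 1-D CUB path above, MEAN is computed as a segmented SUM followed by `post_sum_div_kernel`, which divides each segment's sum by its length and, mirroring the CPU finalize step, leaves NaN for empty segments when no initial value is given (the exact handling of a supplied initial value is an assumption here, since the kernel body is not shown in this hunk). A CPU-side sketch of that finalization:

// CPU-side sketch of the MEAN finalization done by post_sum_div_kernel on the
// CUB path: divide each segment sum by its length; an empty segment becomes
// NaN when no initial value was supplied. Plain C++, made-up data; the exact
// treatment of a supplied initial value is an assumption.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> output = {2, 9, 0};     // per-segment sums after the reduce
  std::vector<int32_t> lengths = {1, 3, 0};  // int32 lengths, as enabled by the patch
  bool is_initial_set = false;

  for (std::size_t i = 0; i < output.size(); ++i) {
    if (lengths[i] == 0) {
      if (!is_initial_set) {
        output[i] = NAN;
      }
    } else {
      output[i] = output[i] / lengths[i];
    }
  }
  for (float v : output) {
    std::printf("%g ", v);  // prints: 2 3 nan
  }
  std::printf("\n");
  return 0;
}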
@@ -1,3 +1,5 @@
+from itertools import product
+
 import numpy as np
 import torch
 from torch.testing._internal.common_device_type import (
@@ -28,8 +30,9 @@ class TestSegmentReductions(TestCase):
        expected_arr,
        expected_grad_arr,
        check_backward,
+       lengths_dtype=torch.int,
    ):
-       lengths = torch.tensor(lengths_arr, device=device)
+       lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
        data = torch.tensor(
            data_arr,
            device=device,
@@ -85,8 +88,14 @@ class TestSegmentReductions(TestCase):
            )
        )

-   @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
-   def test_simple_1d(self, device, dtype):
+   @dtypes(
+       *product(
+           (torch.half, torch.bfloat16, torch.float, torch.double),
+           (torch.int, torch.int64),
+       )
+   )
+   def test_simple_1d(self, device, dtypes):
+       val_dtype, length_type = dtypes
        lengths = [1, 2, 3, 0]
        data = [1, float("nan"), 3, 4, 5, 5]
        check_backward = True
@@ -114,7 +123,7 @@ class TestSegmentReductions(TestCase):
            self._test_common(
                reduction,
                device,
-               dtype,
+               val_dtype,
                unsafe,
                axis,
                initial_value,
@@ -123,10 +132,17 @@ class TestSegmentReductions(TestCase):
                expected_result,
                expected_grad,
                check_backward,
+               length_type,
            )

-   @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
-   def test_multi_d_simple(self, device, dtype):
+   @dtypes(
+       *product(
+           (torch.half, torch.bfloat16, torch.float, torch.double),
+           (torch.int, torch.int64),
+       )
+   )
+   def test_multi_d_simple(self, device, dtypes):
+       val_dtype, length_type = dtypes
        check_backward = True
        axis = 0
        lengths = [1, 2, 3, 0]
@@ -202,7 +218,7 @@ class TestSegmentReductions(TestCase):
            self._test_common(
                reduction,
                device,
-               dtype,
+               val_dtype,
                unsafe,
                axis,
                initial_value,
@@ -213,8 +229,14 @@ class TestSegmentReductions(TestCase):
                check_backward,
            )

-   @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
-   def test_multi_d(self, device, dtype):
+   @dtypes(
+       *product(
+           (torch.half, torch.bfloat16, torch.float, torch.double),
+           (torch.int, torch.int64),
+       )
+   )
+   def test_multi_d(self, device, dtypes):
+       val_dtype, length_type = dtypes
        axis = 0
        lengths = [0, 2]
        data = np.arange(20).reshape(2, 2, 5).tolist()
@@ -253,7 +275,7 @@ class TestSegmentReductions(TestCase):
            self._test_common(
                reduction,
                device,
-               dtype,
+               val_dtype,
                unsafe,
                axis,
                initial_value,