[torch][segment_reduce] Add support for sum and min reductions (#60379)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60379

This concludes support for all reduction types initially planned (min, max, mean, sum).
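
For orientation, a hedged usage sketch of the op this diff extends (assuming it is exposed as torch.segment_reduce at this revision; lengths partitions data into consecutive segments along axis):

    import torch

    data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 5.0])
    lengths = torch.tensor([1, 2, 3])  # segments: [1], [2, 3], [4, 5, 5]

    torch.segment_reduce(data, "sum", lengths=lengths)   # tensor([ 1.,  5., 14.])
    torch.segment_reduce(data, "min", lengths=lengths)   # tensor([1., 2., 4.])
    torch.segment_reduce(data, "max", lengths=lengths)   # tensor([1., 3., 5.])
    torch.segment_reduce(data, "mean", lengths=lengths)  # tensor([1.0000, 2.5000, 4.6667])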

Next Steps:
- Cleanups
    - update default values when length is 0 and `initial` is not given
    - templatize the code to avoid branching on every item (and other known improvements)
- more unit tests, verification
- benchmarking

Test Plan: updated unit tests.

Reviewed By: ngimel

Differential Revision: D29268218

fbshipit-source-id: c77d91671e01dcf96c18c758fa3ea522b2e13db9
Author: Serhat Yilmaz, 2021-06-23 18:50:36 -07:00 (committed by Facebook GitHub Bot)
parent 63219f1f9f
commit af66824c1f
4 changed files with 194 additions and 62 deletions

aten/src/ATen/native/SegmentReduce.cpp

@@ -18,6 +18,10 @@ SegmentReductionType get_reduction_enum(const c10::string_view& reduce) {
     return SegmentReductionType::MAX;
   } else if (reduce == "mean") {
     return SegmentReductionType::MEAN;
+  } else if (reduce == "min") {
+    return SegmentReductionType::MIN;
+  } else if (reduce == "sum") {
+    return SegmentReductionType::SUM;
   } else {
     TORCH_CHECK(false, "unsupported reduction given! ", reduce);
   }
@@ -37,7 +41,7 @@ Tensor _segment_reduce_cpu_kernel(
   int64_t stride_count = data.numel() / data.size(axis);
   const auto* lengths_data = lengths.data_ptr<int64_t>();
-  AT_DISPATCH_ALL_TYPES_AND2(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", ([&]() {
         auto* output_data = output.data_ptr<scalar_t>();
         const auto* values_data = data.data_ptr<scalar_t>();
@@ -49,9 +53,13 @@ Tensor _segment_reduce_cpu_kernel(
         if (initial.has_value()) {
           initial_value = initial.value().to<scalar_t>();
         } else if (reduction == SegmentReductionType::MAX) {
-          initial_value = std::numeric_limits<scalar_t>::lowest();
-        } else if (reduction == SegmentReductionType::MEAN) {
+          initial_value = -std::numeric_limits<scalar_t>::infinity();
+        } else if (
+            reduction == SegmentReductionType::MEAN ||
+            reduction == SegmentReductionType::SUM) {
           initial_value = 0;
+        } else if (reduction == SegmentReductionType::MIN) {
+          initial_value = std::numeric_limits<scalar_t>::infinity();
         }
         // ===== step2: apply reduction
@@ -64,8 +72,14 @@ Tensor _segment_reduce_cpu_kernel(
               initial_value = at::_isnan(data)
                   ? data
                   : std::max<scalar_t>(initial_value, data);
-            } else if (reduction == SegmentReductionType::MEAN) {
+            } else if (
+                reduction == SegmentReductionType::MEAN ||
+                reduction == SegmentReductionType::SUM) {
               initial_value = initial_value + data;
+            } else if (reduction == SegmentReductionType::MIN) {
+              initial_value = at::_isnan(data)
+                  ? data
+                  : std::min<scalar_t>(initial_value, data);
             }
           }
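
Taken together, the two hunks above pick a per-segment start value and then fold elements in, with NaN propagating for min/max. A plain-Python sketch of the forward semantics (illustrative only, not the PyTorch API; the mean's division by segment length happens later in the kernel, outside these hunks):

    import math

    def segment_reduce_ref(data, lengths, reduction, initial=None):
        # Reference for the 1-D forward pass: each segment starts from
        # `initial` (or the reduction's default) and folds its elements in.
        defaults = {"max": -math.inf, "min": math.inf, "mean": 0.0, "sum": 0.0}
        out, idx = [], 0
        for n in lengths:
            acc = initial if initial is not None else defaults[reduction]
            for _ in range(n):
                x = data[idx]
                idx += 1
                if reduction in ("sum", "mean"):
                    acc = acc + x
                elif math.isnan(x):  # NaN wins for min/max, as in the at::_isnan checks
                    acc = x
                else:
                    acc = max(acc, x) if reduction == "max" else min(acc, x)
            if reduction == "mean" and n > 0:
                acc = acc / n  # done later in the kernel, outside these hunks
            out.append(acc)
        return out

For example, segment_reduce_ref([1, float("nan"), 3, 4, 5, 5], [1, 2, 3, 0], "sum") gives [1, nan, 14, 0.0], matching the expectations in the updated tests.
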
@@ -105,7 +119,7 @@ Tensor _segment_reduce_cpu_backward_kernel(
   const auto* lengths_data = lengths_contig.data_ptr<int64_t>();
   // TODO: Switch to TensorIterator for better maintainability and readability
-  AT_DISPATCH_ALL_TYPES_AND2(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16,
       kHalf,
       data_contig.scalar_type(),
@@ -125,7 +139,8 @@ Tensor _segment_reduce_cpu_backward_kernel(
         for (int64_t l = 0; l < stride_count; ++l) {
           int64_t output_index = (i * stride_count) + l;
-          if (reduction == SegmentReductionType::MAX) {
+          if (reduction == SegmentReductionType::MAX ||
+              reduction == SegmentReductionType::MIN) {
             int64_t counter = 0;
             for (int64_t j = 0; j < lengths_data[i]; ++j) {
               int64_t starting_index =
@@ -156,6 +171,13 @@ Tensor _segment_reduce_cpu_backward_kernel(
                   ((lengths_cum_sum + j) * stride_count) + l;
               grad_input_data[starting_index] = grad_val;
             }
+          } else if (reduction == SegmentReductionType::SUM) {
+            const auto& grad_val = grad_data[output_index];
+            for (int64_t j = 0; j < lengths_data[i]; ++j) {
+              int64_t starting_index =
+                  ((lengths_cum_sum + j) * stride_count) + l;
+              grad_input_data[starting_index] = grad_val;
+            }
           }
         }
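
The backward hunks follow three per-segment rules, sketched below in plain Python (illustrative, not the kernel itself): sum broadcasts the upstream gradient to every element, mean scales it by 1/length, and min/max split it evenly across the elements that tie with the forward output (NaN ties with NaN).

    import math

    def _matches(a, b):
        # NaN compares equal to NaN here, mirroring the kernels' at::_isnan handling
        return a == b or (math.isnan(a) and math.isnan(b))

    def segment_reduce_backward_ref(grad, data, out, lengths, reduction):
        # grad: upstream gradient per segment; out: forward result per segment
        grad_in, idx = [0.0] * len(data), 0
        for i, n in enumerate(lengths):
            seg = range(idx, idx + n)
            if reduction == "sum":
                for j in seg:
                    grad_in[j] = grad[i]
            elif reduction == "mean":
                for j in seg:
                    grad_in[j] = grad[i] / n
            else:  # "min" / "max": split the gradient across tying elements
                ties = [j for j in seg if _matches(data[j], out[i])]
                for j in ties:
                    grad_in[j] = grad[i] / len(ties)
            idx += n
        return grad_in

With data [1, nan, 3, 4, 5, 5], lengths [1, 2, 3, 0], and reduction "max", this yields [1, 1, 0, 0, 0.5, 0.5], matching expected_grad in the updated test_simple_1d.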

aten/src/ATen/native/SegmentReduce.h

@@ -7,7 +7,7 @@
 namespace at {
 namespace native {
-enum SegmentReductionType { MAX, MEAN };
+enum SegmentReductionType { MAX, MEAN, MIN, SUM };
 using segment_reduce_fn = Tensor (*)(
     SegmentReductionType,

aten/src/ATen/native/cuda/SegmentReduce.cu

@@ -32,6 +32,19 @@ struct CustomSum {
   }
 };
+
+struct CustomMin {
+  template <typename OutputT>
+  __host__ __device__ __forceinline__ OutputT
+  operator()(const OutputT& a, const OutputT& b) const {
+    if (at::_isnan(a)) {
+      return a;
+    } else if (at::_isnan(b)) {
+      return b;
+    }
+    return std::min<OutputT>(a, b);
+  }
+};
 
 Tensor _get_complete_sum(const Tensor& lengths) {
   int64_t segment_count = lengths.numel();
   TORCH_CHECK(segment_count < INT_MAX);
@@ -99,8 +112,13 @@ __global__ static void segment_reduce_forward_kernel(
       if (reduction == SegmentReductionType::MAX) {
         initial_value =
             at::_isnan(data) ? data : std::max<scalar_t>(initial_value, data);
-      } else if (reduction == SegmentReductionType::MEAN) {
+      } else if (
+          reduction == SegmentReductionType::MEAN ||
+          reduction == SegmentReductionType::SUM) {
         initial_value = initial_value + data;
+      } else if (reduction == SegmentReductionType::MIN) {
+        initial_value =
+            at::_isnan(data) ? data : std::min<scalar_t>(initial_value, data);
       }
     }
@@ -144,7 +162,8 @@ __global__ static void segment_reduce_backward_kernel(
     int64_t output_index = (row_id * stride_count) + lane_id;
-    if (reduction == SegmentReductionType::MAX) {
+    if (reduction == SegmentReductionType::MAX ||
+        reduction == SegmentReductionType::MIN) {
       int64_t counter = 0;
       for (int64_t j = offset_start; j < offset_end; ++j) {
         int64_t starting_index = (j * stride_count) + lane_id;
@@ -172,6 +191,12 @@ __global__ static void segment_reduce_backward_kernel(
         int64_t starting_index = (j * stride_count) + lane_id;
         grad_input_data[starting_index] = grad_val;
       }
+    } else if (reduction == SegmentReductionType::SUM) {
+      const auto& grad_val = grad_data[output_index];
+      for (int64_t j = offset_start; j < offset_end; ++j) {
+        int64_t starting_index = (j * stride_count) + lane_id;
+        grad_input_data[starting_index] = grad_val;
+      }
     }
   }
@@ -203,7 +228,7 @@ Tensor _segment_reduce_cuda_backward_kernel(
   num_blocks = std::max(num_blocks, (int64_t)1);
   // TODO: Switch to TensorIterator for better maintainability and readability
-  AT_DISPATCH_ALL_TYPES_AND2(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16,
       kHalf,
      data_contig.scalar_type(),
@@ -259,7 +284,7 @@ Tensor _segment_reduce_cuda_kernel(
   num_blocks = std::max(num_blocks, (int64_t)1);
   auto* lengths_data_ptr = lengths.data_ptr<int64_t>();
-  AT_DISPATCH_ALL_TYPES_AND2(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
       data.scalar_type(),
@@ -273,9 +298,13 @@ Tensor _segment_reduce_cuda_kernel(
         if (initial.has_value()) {
           initial_value = initial.value().to<scalar_t>();
         } else if (reduction == SegmentReductionType::MAX) {
-          initial_value = std::numeric_limits<scalar_t>::lowest();
-        } else if (reduction == SegmentReductionType::MEAN) {
+          initial_value = -std::numeric_limits<scalar_t>::infinity();
+        } else if (
+            reduction == SegmentReductionType::MEAN ||
+            reduction == SegmentReductionType::SUM) {
           initial_value = 0;
+        } else if (reduction == SegmentReductionType::MIN) {
+          initial_value = std::numeric_limits<scalar_t>::infinity();
         }
 
         if (output_shape.size() > 1) {
@@ -331,6 +360,30 @@ Tensor _segment_reduce_cuda_kernel(
               initial.has_value(),
               initial_value);
           C10_CUDA_KERNEL_LAUNCH_CHECK();
+        } else if (reduction == SegmentReductionType::MIN) {
+          CustomMin min_op{};
+          CUB_WRAPPER(
+              cub::DeviceSegmentedReduce::Reduce,
+              data_data_ptr,
+              output_data_ptr,
+              segment_count,
+              offsets_data_ptr,
+              offsets_data_ptr + 1,
+              min_op,
+              initial_value,
+              at::cuda::getCurrentCUDAStream());
+        } else if (reduction == SegmentReductionType::SUM) {
+          CustomSum sum_op{};
+          CUB_WRAPPER(
+              cub::DeviceSegmentedReduce::Reduce,
+              data_data_ptr,
+              output_data_ptr,
+              segment_count,
+              offsets_data_ptr,
+              offsets_data_ptr + 1,
+              sum_op,
+              initial_value,
+              at::cuda::getCurrentCUDAStream());
         }
       }
     });
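
The MIN/SUM branches above hand whole segments to cub::DeviceSegmentedReduce::Reduce with the CustomMin/CustomSum functors. Roughly, in Python terms (a sketch of the semantics; offsets is the cumulative sum of lengths starting at 0, as built by _get_complete_sum):

    def segmented_reduce(data, offsets, op, initial):
        # Segment i spans [offsets[i], offsets[i + 1]); empty segments yield `initial`.
        out = []
        for i in range(len(offsets) - 1):
            acc = initial
            for j in range(offsets[i], offsets[i + 1]):
                acc = op(acc, data[j])
            out.append(acc)
        return out

For lengths [1, 2, 3], offsets is [0, 1, 3, 6], and segmented_reduce([1, 2, 3, 4, 5, 5], [0, 1, 3, 6], min, float("inf")) returns [1, 2, 4].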

test/test_segment_reductions.py

@@ -3,7 +3,6 @@ import torch
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     dtypes,
-    dtypesIfCUDA,
 )
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -12,6 +11,9 @@ from torch.testing._internal.common_utils import (
 )
+
+reductions = ["max", "mean", "min", "sum"]
+
 class TestSegmentReductions(TestCase):
     def _test_common(
         self,
@@ -83,48 +85,56 @@ class TestSegmentReductions(TestCase):
             )
         )
 
-    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
     def test_simple_1d(self, device, dtype):
         lengths = [1, 2, 3, 0]
         data = [1, float("nan"), 3, 4, 5, 5]
-        initial_value = 0
         check_backward = True
 
-        for reduction in ("max", "mean"):
+        for reduction in reductions:
             if reduction == "max":
+                initial_value = 0
                 expected_result = [1, float("nan"), 5, initial_value]
                 expected_grad = [1, 1, 0, 0, 0.5, 0.5]
             elif reduction == "mean":
+                initial_value = 0
                 expected_result = [1, float("nan"), 4.666, initial_value]
                 expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
+            elif reduction == "min":
+                initial_value = 1000  # some high number
+                expected_result = [1, float("nan"), 4, initial_value]
+                expected_grad = [1.0, 1.0, 0, 1, 0, 0]
+            elif reduction == "sum":
+                initial_value = 0
+                expected_result = [1, float("nan"), 14, initial_value]
+                expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
             for axis in [0, -1]:
                 for unsafe in [True, False]:
-                    self._test_common(
-                        reduction,
-                        device,
-                        dtype,
-                        unsafe,
-                        axis,
-                        initial_value,
-                        data,
-                        lengths,
-                        expected_result,
-                        expected_grad,
-                        check_backward,
-                    )
+                    for initial in [initial_value, None]:
+                        self._test_common(
+                            reduction,
+                            device,
+                            dtype,
+                            unsafe,
+                            axis,
+                            initial_value,
+                            data,
+                            lengths,
+                            expected_result,
+                            expected_grad,
+                            check_backward,
+                        )
-    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
     def test_multi_d_simple(self, device, dtype):
-        initial_value = 0
         check_backward = True
         axis = 0
         lengths = [1, 2, 3, 0]
         data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]
 
-        for reduction in ("max", "mean"):
+        for reduction in reductions:
             if reduction == "max":
+                initial_value = 0
                 expected_result = [
                     [1, 1],
                     [float("nan"), float("nan")],
@@ -140,6 +150,7 @@ class TestSegmentReductions(TestCase):
                     [0, 1],
                 ]
             elif reduction == "mean":
+                initial_value = 0
                 expected_result = [
                     [1, 1],
                     [float("nan"), float("nan")],
@@ -154,25 +165,56 @@ class TestSegmentReductions(TestCase):
                     [0.333, 0.333],
                     [0.333, 0.333],
                 ]
+            elif reduction == "min":
+                initial_value = 1000  # some high number
+                expected_result = [
+                    [1, 1],
+                    [float("nan"), float("nan")],
+                    [2, 1],
+                    [initial_value, initial_value],
+                ]
+                expected_grad = [
+                    [1.0, 1.0],
+                    [1, 0],
+                    [0, 1],
+                    [0, 1],
+                    [0, 0],
+                    [1, 0],
+                ]
+            elif reduction == "sum":
+                initial_value = 0
+                expected_result = [
+                    [1, 1],
+                    [float("nan"), float("nan")],
+                    [9, 6],
+                    [initial_value, initial_value],
+                ]
+                expected_grad = [
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                ]
             for unsafe in [True, False]:
-                self._test_common(
-                    reduction,
-                    device,
-                    dtype,
-                    unsafe,
-                    axis,
-                    initial_value,
-                    data,
-                    lengths,
-                    expected_result,
-                    expected_grad,
-                    check_backward,
-                )
+                for initial in [initial_value, None]:
+                    self._test_common(
+                        reduction,
+                        device,
+                        dtype,
+                        unsafe,
+                        axis,
+                        initial_value,
+                        data,
+                        lengths,
+                        expected_result,
+                        expected_grad,
+                        check_backward,
+                    )
-    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
     def test_multi_d(self, device, dtype):
-        initial_value = 0
         axis = 0
         lengths = [0, 2]
         data = np.arange(20).reshape(2, 2, 5).tolist()
@@ -181,31 +223,46 @@ class TestSegmentReductions(TestCase):
         # TODO: calculate grad and check correctness
         check_backward = False
 
-        for reduction in ["max", "mean"]:
+        for reduction in reductions:
             if reduction == "max":
+                initial_value = 0
                 expected_result = [
                     np.full((2, 5), initial_value).tolist(),
                     np.max(data, axis=0).tolist(),
                 ]
             elif reduction == "mean":
+                initial_value = 0
                 expected_result = [
                     np.full((2, 5), initial_value).tolist(),
                     np.mean(data, axis=0).tolist(),
                 ]
+            elif reduction == "min":
+                initial_value = 1000  # some high number
+                expected_result = [
+                    np.full((2, 5), initial_value).tolist(),
+                    np.min(data, axis=0).tolist(),
+                ]
+            elif reduction == "sum":
+                initial_value = 0
+                expected_result = [
+                    np.full((2, 5), initial_value).tolist(),
+                    np.sum(data, axis=0).tolist(),
+                ]
             for unsafe in [True, False]:
-                self._test_common(
-                    reduction,
-                    device,
-                    dtype,
-                    unsafe,
-                    axis,
-                    initial_value,
-                    data,
-                    lengths,
-                    expected_result,
-                    expected_grad,
-                    check_backward,
-                )
+                for initial in [initial_value, None]:
+                    self._test_common(
+                        reduction,
+                        device,
+                        dtype,
+                        unsafe,
+                        axis,
+                        initial_value,
+                        data,
+                        lengths,
+                        expected_result,
+                        expected_grad,
+                        check_backward,
+                    )
instantiate_device_type_tests(TestSegmentReductions, globals())
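
As a quick cross-check of the 1-D expectations in test_simple_1d above, plain arithmetic on the third segment [4, 5, 5]:

    seg = [4, 5, 5]                  # third segment: lengths[2] == 3
    assert sum(seg) == 14            # "sum" expectation
    assert min(seg) == 4             # "min" expectation
    assert max(seg) == 5             # "max" expectation
    assert abs(sum(seg) / len(seg) - 4.666) < 1e-3  # "mean": 14 / 3 = 4.666...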