[torch][segment_reduce] Add support for sum and min reductions (#60379)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60379

This concludes support for all of the initially planned reduction types (min, max, mean, sum).

Next steps:
- Cleanups
- Update default values when length is 0 and initial is not given
- Templatize the code to avoid branching on every item (and other known improvements)
- More unit tests and verification
- Benchmarking

Test Plan: updated unit tests.

Reviewed By: ngimel

Differential Revision: D29268218

fbshipit-source-id: c77d91671e01dcf96c18c758fa3ea522b2e13db9
This commit is contained in:
  parent 63219f1f9f
  commit af66824c1f
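For orientation before the diff, here is a minimal usage sketch of the operator this commit extends. It assumes the (then-undocumented) torch.segment_reduce entry point that the updated test suite below exercises; the printed values follow directly from the per-segment definitions.

    import torch

    # Six values split into segments of lengths 1, 2, and 3.
    data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 5.0])
    lengths = torch.tensor([1, 2, 3])

    # "min" and "sum" are the reductions added by this commit;
    # "max" and "mean" were already supported.
    for reduce in ("max", "mean", "min", "sum"):
        print(reduce, torch.segment_reduce(data, reduce, lengths=lengths))
    # max  -> [1.0, 3.0, 5.0]
    # mean -> [1.0, 2.5, 4.6667]
    # min  -> [1.0, 2.0, 4.0]
    # sum  -> [1.0, 5.0, 14.0]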
@@ -18,6 +18,10 @@ SegmentReductionType get_reduction_enum(const c10::string_view& reduce) {
     return SegmentReductionType::MAX;
   } else if (reduce == "mean") {
     return SegmentReductionType::MEAN;
+  } else if (reduce == "min") {
+    return SegmentReductionType::MIN;
+  } else if (reduce == "sum") {
+    return SegmentReductionType::SUM;
   } else {
     TORCH_CHECK(false, "unsopported reduction given! ", reduce);
   }
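The hunk above is the string-to-enum dispatch; restated as a Python sketch for readers skimming the diff (names mirror the C++, and the final branch corresponds to the TORCH_CHECK failure, whose "unsopported" typo is pre-existing in the source):

    from enum import Enum

    class SegmentReductionType(Enum):
        MAX, MEAN, MIN, SUM = range(4)

    def get_reduction_enum(reduce: str) -> SegmentReductionType:
        # "max"/"mean" were already accepted; this commit adds "min"/"sum".
        try:
            return SegmentReductionType[reduce.upper()]
        except KeyError:
            raise ValueError(f"unsupported reduction given! {reduce}")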
@@ -37,7 +41,7 @@ Tensor _segment_reduce_cpu_kernel(
   int64_t stride_count = data.numel() / data.size(axis);
   const auto* lengths_data = lengths.data_ptr<int64_t>();

-  AT_DISPATCH_ALL_TYPES_AND2(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", ([&]() {
         auto* output_data = output.data_ptr<scalar_t>();
         const auto* values_data = data.data_ptr<scalar_t>();
@@ -49,9 +53,13 @@ Tensor _segment_reduce_cpu_kernel(
           if (initial.has_value()) {
             initial_value = initial.value().to<scalar_t>();
           } else if (reduction == SegmentReductionType::MAX) {
-            initial_value = std::numeric_limits<scalar_t>::lowest();
-          } else if (reduction == SegmentReductionType::MEAN) {
+            initial_value = -std::numeric_limits<scalar_t>::infinity();
+          } else if (
+              reduction == SegmentReductionType::MEAN ||
+              reduction == SegmentReductionType::SUM) {
             initial_value = 0;
+          } else if (reduction == SegmentReductionType::MIN) {
+            initial_value = std::numeric_limits<scalar_t>::infinity();
           }

           // ===== step2: apply reduction
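Why these defaults: when no initial value is supplied, each reduction starts from its identity element, so empty segments and the first fold both come out right. Since +/-infinity only exists for floating-point types, this is also why the dispatch above narrows from AT_DISPATCH_ALL_TYPES_AND2 to AT_DISPATCH_FLOATING_TYPES_AND2. A Python sketch of the rule:

    import math

    def default_initial(reduction: str) -> float:
        # Identity elements: max folds up from -inf, min folds down
        # from +inf, and mean/sum accumulate from 0.
        if reduction == "max":
            return -math.inf
        if reduction in ("mean", "sum"):
            return 0.0
        if reduction == "min":
            return math.inf
        raise ValueError(reduction)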
@@ -64,8 +72,14 @@ Tensor _segment_reduce_cpu_kernel(
                 initial_value = at::_isnan(data)
                     ? data
                     : std::max<scalar_t>(initial_value, data);
-              } else if (reduction == SegmentReductionType::MEAN) {
+              } else if (
+                  reduction == SegmentReductionType::MEAN ||
+                  reduction == SegmentReductionType::SUM) {
                 initial_value = initial_value + data;
+              } else if (reduction == SegmentReductionType::MIN) {
+                initial_value = at::_isnan(data)
+                    ? data
+                    : std::min<scalar_t>(initial_value, data);
               }
             }

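Step 2 folds each segment's elements into that starting value. Note the NaN handling: for min/max a NaN, once seen, sticks, which plain std::min/std::max alone would not guarantee. A plain-Python sketch of one update under the same rules (mean's final divide happens outside this hunk and is elided):

    import math

    def fold(acc: float, x: float, reduction: str) -> float:
        # sum/mean just accumulate; NaN already propagates through addition.
        if reduction in ("sum", "mean"):
            return acc + x
        # min/max: a NaN operand wins and then keeps winning.
        if math.isnan(acc) or math.isnan(x):
            return float("nan")
        return max(acc, x) if reduction == "max" else min(acc, x)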
@@ -105,7 +119,7 @@ Tensor _segment_reduce_cpu_backward_kernel(
   const auto* lengths_data = lengths_contig.data_ptr<int64_t>();

   // TODO: Swtich to TensorIterator for better maintainablility and readability
-  AT_DISPATCH_ALL_TYPES_AND2(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16,
       kHalf,
       data_contig.scalar_type(),
@@ -125,7 +139,8 @@ Tensor _segment_reduce_cpu_backward_kernel(
       for (int64_t l = 0; l < stride_count; ++l) {
         int64_t output_index = (i * stride_count) + l;

-        if (reduction == SegmentReductionType::MAX) {
+        if (reduction == SegmentReductionType::MAX ||
+            reduction == SegmentReductionType::MIN) {
           int64_t counter = 0;
           for (int64_t j = 0; j < lengths_data[i]; ++j) {
             int64_t starting_index =
@@ -156,6 +171,13 @@ Tensor _segment_reduce_cpu_backward_kernel(
                 ((lengths_cum_sum + j) * stride_count) + l;
             grad_input_data[starting_index] = grad_val;
           }
+        } else if (reduction == SegmentReductionType::SUM) {
+          const auto& grad_val = grad_data[output_index];
+          for (int64_t j = 0; j < lengths_data[i]; ++j) {
+            int64_t starting_index =
+                ((lengths_cum_sum + j) * stride_count) + l;
+            grad_input_data[starting_index] = grad_val;
+          }
         }
       }

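The new SUM branch in the backward kernel is the simplest of the four cases: d(sum)/dx is 1 for every element, so each input slot simply receives its segment's upstream gradient. A hedged 1-D sketch of the effect (hypothetical helper, not the actual kernel):

    def segment_sum_backward(grad_out, lengths):
        # Broadcast each segment's upstream gradient to all of
        # that segment's inputs.
        grad_in = []
        for g, n in zip(grad_out, lengths):
            grad_in.extend([g] * n)
        return grad_in

    # segment_sum_backward([1.0, 0.5], [2, 3])
    # -> [1.0, 1.0, 0.5, 0.5, 0.5]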
@@ -7,7 +7,7 @@
 namespace at {
 namespace native {

-enum SegmentReductionType { MAX, MEAN };
+enum SegmentReductionType { MAX, MEAN, MIN, SUM };

 using segment_reduce_fn = Tensor (*)(
     SegmentReductionType,
@@ -32,6 +32,19 @@ struct CustomSum {
   }
 };

+struct CustomMin {
+  template <typename OutputT>
+  __host__ __device__ __forceinline__ OutputT
+  operator()(const OutputT& a, const OutputT& b) const {
+    if (at::_isnan(a)) {
+      return a;
+    } else if (at::_isnan(b)) {
+      return b;
+    }
+    return std::min<OutputT>(a, b);
+  }
+};
+
 Tensor _get_complete_sum(const Tensor& lengths) {
   int64_t segment_count = lengths.numel();
   TORCH_CHECK(segment_count < INT_MAX);
@@ -99,8 +112,13 @@ __global__ static void segment_reduce_forward_kernel(
     if (reduction == SegmentReductionType::MAX) {
       initial_value =
           at::_isnan(data) ? data : std::max<scalar_t>(initial_value, data);
-    } else if (reduction == SegmentReductionType::MEAN) {
+    } else if (
+        reduction == SegmentReductionType::MEAN ||
+        reduction == SegmentReductionType::SUM) {
       initial_value = initial_value + data;
+    } else if (reduction == SegmentReductionType::MIN) {
+      initial_value =
+          at::_isnan(data) ? data : std::min<scalar_t>(initial_value, data);
     }
   }
@@ -144,7 +162,8 @@ __global__ static void segment_reduce_backward_kernel(

   int64_t output_index = (row_id * stride_count) + lane_id;

-  if (reduction == SegmentReductionType::MAX) {
+  if (reduction == SegmentReductionType::MAX ||
+      reduction == SegmentReductionType::MIN) {
     int64_t counter = 0;
     for (int64_t j = offset_start; j < offset_end; ++j) {
       int64_t starting_index = (j * stride_count) + lane_id;
@@ -172,6 +191,12 @@ __global__ static void segment_reduce_backward_kernel(
       int64_t starting_index = (j * stride_count) + lane_id;
       grad_input_data[starting_index] = grad_val;
     }
+  } else if (reduction == SegmentReductionType::SUM) {
+    const auto& grad_val = grad_data[output_index];
+    for (int64_t j = offset_start; j < offset_end; ++j) {
+      int64_t starting_index = (j * stride_count) + lane_id;
+      grad_input_data[starting_index] = grad_val;
+    }
   }
 }

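On the backward side the MAX branch now serves MIN as well: gradient flows only to the positions that attained the segment's extreme value, and the counter in these kernels splits it evenly among ties. A rough 1-D Python sketch (hypothetical helper; the kernels' NaN bookkeeping is omitted for brevity):

    def segment_extremum_backward(grad_out, outputs, data, lengths):
        grad_in, start = [], 0
        for g, out, n in zip(grad_out, outputs, lengths):
            segment = data[start:start + n]
            # Positions that produced the output share the gradient.
            winners = [i for i, x in enumerate(segment) if x == out]
            for i in range(n):
                grad_in.append(g / len(winners) if i in winners else 0.0)
            start += n
        return grad_in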
@@ -203,7 +228,7 @@ Tensor _segment_reduce_cuda_backward_kernel(
   num_blocks = std::max(num_blocks, (int64_t)1);

   // TODO: Swtich to TensorIterator for better maintainablility and readability
-  AT_DISPATCH_ALL_TYPES_AND2(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16,
       kHalf,
       data_contig.scalar_type(),
@@ -259,7 +284,7 @@ Tensor _segment_reduce_cuda_kernel(
   num_blocks = std::max(num_blocks, (int64_t)1);
   auto* lengths_data_ptr = lengths.data_ptr<int64_t>();

-  AT_DISPATCH_ALL_TYPES_AND2(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
       data.scalar_type(),
@@ -273,9 +298,13 @@ Tensor _segment_reduce_cuda_kernel(
         if (initial.has_value()) {
           initial_value = initial.value().to<scalar_t>();
         } else if (reduction == SegmentReductionType::MAX) {
-          initial_value = std::numeric_limits<scalar_t>::lowest();
-        } else if (reduction == SegmentReductionType::MEAN) {
+          initial_value = -std::numeric_limits<scalar_t>::infinity();
+        } else if (
+            reduction == SegmentReductionType::MEAN ||
+            reduction == SegmentReductionType::SUM) {
           initial_value = 0;
+        } else if (reduction == SegmentReductionType::MIN) {
+          initial_value = std::numeric_limits<scalar_t>::infinity();
         }

         if (output_shape.size() > 1) {
@@ -331,6 +360,30 @@ Tensor _segment_reduce_cuda_kernel(
               initial.has_value(),
               initial_value);
           C10_CUDA_KERNEL_LAUNCH_CHECK();
+        } else if (reduction == SegmentReductionType::MIN) {
+          CustomMin min_op{};
+          CUB_WRAPPER(
+              cub::DeviceSegmentedReduce::Reduce,
+              data_data_ptr,
+              output_data_ptr,
+              segment_count,
+              offsets_data_ptr,
+              offsets_data_ptr + 1,
+              min_op,
+              initial_value,
+              at::cuda::getCurrentCUDAStream());
+        } else if (reduction == SegmentReductionType::SUM) {
+          CustomSum sum_op{};
+          CUB_WRAPPER(
+              cub::DeviceSegmentedReduce::Reduce,
+              data_data_ptr,
+              output_data_ptr,
+              segment_count,
+              offsets_data_ptr,
+              offsets_data_ptr + 1,
+              sum_op,
+              initial_value,
+              at::cuda::getCurrentCUDAStream());
         }
       }
     });
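cub::DeviceSegmentedReduce::Reduce takes per-segment begin/end offset iterators; here both point into a single cumulative-sum array built from the lengths (see _get_complete_sum above), shifted by one slot. A sketch of that layout:

    def lengths_to_offsets(lengths):
        # [1, 2, 3] -> [0, 1, 3, 6]; segment i spans offsets[i]:offsets[i+1],
        # which is why the kernel passes offsets_data_ptr as the begin
        # iterator and offsets_data_ptr + 1 as the end iterator.
        offsets = [0]
        for n in lengths:
            offsets.append(offsets[-1] + n)
        return offsets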
@@ -3,7 +3,6 @@ import torch
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     dtypes,
-    dtypesIfCUDA,
 )
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -12,6 +11,9 @@ from torch.testing._internal.common_utils import (
 )


+reductions = ["max", "mean", "min", "sum"]
+
+
 class TestSegmentReductions(TestCase):
     def _test_common(
         self,
@@ -83,48 +85,56 @@ class TestSegmentReductions(TestCase):
             )
         )

-    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
     def test_simple_1d(self, device, dtype):
         lengths = [1, 2, 3, 0]
         data = [1, float("nan"), 3, 4, 5, 5]
-        initial_value = 0
         check_backward = True

-        for reduction in ("max", "mean"):
+        for reduction in reductions:
             if reduction == "max":
+                initial_value = 0
                 expected_result = [1, float("nan"), 5, initial_value]
                 expected_grad = [1, 1, 0, 0, 0.5, 0.5]
             elif reduction == "mean":
+                initial_value = 0
                 expected_result = [1, float("nan"), 4.666, initial_value]
                 expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
+            elif reduction == "min":
+                initial_value = 1000  # some high number
+                expected_result = [1, float("nan"), 4, initial_value]
+                expected_grad = [1.0, 1.0, 0, 1, 0, 0]
+            elif reduction == "sum":
+                initial_value = 0
+                expected_result = [1, float("nan"), 14, initial_value]
+                expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
             for axis in [0, -1]:
                 for unsafe in [True, False]:
-                    self._test_common(
-                        reduction,
-                        device,
-                        dtype,
-                        unsafe,
-                        axis,
-                        initial_value,
-                        data,
-                        lengths,
-                        expected_result,
-                        expected_grad,
-                        check_backward,
-                    )
+                    for initial in [initial_value, None]:
+                        self._test_common(
+                            reduction,
+                            device,
+                            dtype,
+                            unsafe,
+                            axis,
+                            initial_value,
+                            data,
+                            lengths,
+                            expected_result,
+                            expected_grad,
+                            check_backward,
+                        )

-    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
     def test_multi_d_simple(self, device, dtype):
-        initial_value = 0
         check_backward = True
         axis = 0
         lengths = [1, 2, 3, 0]
         data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]

-        for reduction in ("max", "mean"):
+        for reduction in reductions:
             if reduction == "max":
+                initial_value = 0
                 expected_result = [
                     [1, 1],
                     [float("nan"), float("nan")],
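Sanity-checking the new 1-D expectations above: lengths [1, 2, 3, 0] split the data into [1], [nan, 3], [4, 5, 5], and an empty segment that falls back to initial_value. A quick check of the sum and min rows:

    segment = [4, 5, 5]  # third segment of data under lengths [1, 2, 3, 0]
    assert sum(segment) == 14  # expected_result[2] for "sum" (initial 0)
    assert min(segment) == 4   # expected_result[2] for "min" (initial 1000)
    # The [nan, 3] segment yields nan for every reduction, since NaN
    # propagates; the empty fourth segment reduces to initial_value.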
@@ -140,6 +150,7 @@ class TestSegmentReductions(TestCase):
                     [0, 1],
                 ]
             elif reduction == "mean":
+                initial_value = 0
                 expected_result = [
                     [1, 1],
                     [float("nan"), float("nan")],
@@ -154,25 +165,56 @@ class TestSegmentReductions(TestCase):
                     [0.333, 0.333],
                     [0.333, 0.333],
                 ]
+            elif reduction == "min":
+                initial_value = 1000  # some high number
+                expected_result = [
+                    [1, 1],
+                    [float("nan"), float("nan")],
+                    [2, 1],
+                    [initial_value, initial_value],
+                ]
+                expected_grad = [
+                    [1.0, 1.0],
+                    [1, 0],
+                    [0, 1],
+                    [0, 1],
+                    [0, 0],
+                    [1, 0],
+                ]
+            elif reduction == "sum":
+                initial_value = 0
+                expected_result = [
+                    [1, 1],
+                    [float("nan"), float("nan")],
+                    [9, 6],
+                    [initial_value, initial_value],
+                ]
+                expected_grad = [
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                ]
             for unsafe in [True, False]:
-                self._test_common(
-                    reduction,
-                    device,
-                    dtype,
-                    unsafe,
-                    axis,
-                    initial_value,
-                    data,
-                    lengths,
-                    expected_result,
-                    expected_grad,
-                    check_backward,
-                )
+                for initial in [initial_value, None]:
+                    self._test_common(
+                        reduction,
+                        device,
+                        dtype,
+                        unsafe,
+                        axis,
+                        initial_value,
+                        data,
+                        lengths,
+                        expected_result,
+                        expected_grad,
+                        check_backward,
+                    )

-    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
+    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
     def test_multi_d(self, device, dtype):
-        initial_value = 0
         axis = 0
         lengths = [0, 2]
         data = np.arange(20).reshape(2, 2, 5).tolist()
@@ -181,31 +223,46 @@ class TestSegmentReductions(TestCase):
         # TODO: calculate grad and check correctness
         check_backward = False

-        for reduction in ["max", "mean"]:
+        for reduction in reductions:
             if reduction == "max":
+                initial_value = 0
                 expected_result = [
                     np.full((2, 5), initial_value).tolist(),
                     np.max(data, axis=0).tolist(),
                 ]
             elif reduction == "mean":
+                initial_value = 0
                 expected_result = [
                     np.full((2, 5), initial_value).tolist(),
                     np.mean(data, axis=0).tolist(),
                 ]
+            elif reduction == "min":
+                initial_value = 1000  # some high number
+                expected_result = [
+                    np.full((2, 5), initial_value).tolist(),
+                    np.min(data, axis=0).tolist(),
+                ]
+            elif reduction == "sum":
+                initial_value = 0
+                expected_result = [
+                    np.full((2, 5), initial_value).tolist(),
+                    np.sum(data, axis=0).tolist(),
+                ]
             for unsafe in [True, False]:
-                self._test_common(
-                    reduction,
-                    device,
-                    dtype,
-                    unsafe,
-                    axis,
-                    initial_value,
-                    data,
-                    lengths,
-                    expected_result,
-                    expected_grad,
-                    check_backward,
-                )
+                for initial in [initial_value, None]:
+                    self._test_common(
+                        reduction,
+                        device,
+                        dtype,
+                        unsafe,
+                        axis,
+                        initial_value,
+                        data,
+                        lengths,
+                        expected_result,
+                        expected_grad,
+                        check_backward,
+                    )


 instantiate_device_type_tests(TestSegmentReductions, globals())