diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp
index fb6ebea07d0..00bb25d23a9 100644
--- a/aten/src/ATen/native/SegmentReduce.cpp
+++ b/aten/src/ATen/native/SegmentReduce.cpp
@@ -18,6 +18,10 @@ SegmentReductionType get_reduction_enum(const c10::string_view& reduce) {
     return SegmentReductionType::MAX;
   } else if (reduce == "mean") {
     return SegmentReductionType::MEAN;
+  } else if (reduce == "min") {
+    return SegmentReductionType::MIN;
+  } else if (reduce == "sum") {
+    return SegmentReductionType::SUM;
   } else {
     TORCH_CHECK(false, "unsopported reduction given! ", reduce);
   }
@@ -37,7 +41,7 @@ Tensor _segment_reduce_cpu_kernel(
   int64_t stride_count = data.numel() / data.size(axis);
   const auto* lengths_data = lengths.data_ptr<int64_t>();
 
-  AT_DISPATCH_ALL_TYPES_AND2(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", ([&]() {
         auto* output_data = output.data_ptr<scalar_t>();
         const auto* values_data = data.data_ptr<scalar_t>();
@@ -49,9 +53,13 @@ Tensor _segment_reduce_cpu_kernel(
           if (initial.has_value()) {
             initial_value = initial.value().to<scalar_t>();
           } else if (reduction == SegmentReductionType::MAX) {
-            initial_value = std::numeric_limits<scalar_t>::lowest();
-          } else if (reduction == SegmentReductionType::MEAN) {
+            initial_value = -std::numeric_limits<scalar_t>::infinity();
+          } else if (
+              reduction == SegmentReductionType::MEAN ||
+              reduction == SegmentReductionType::SUM) {
             initial_value = 0;
+          } else if (reduction == SegmentReductionType::MIN) {
+            initial_value = std::numeric_limits<scalar_t>::infinity();
           }
 
           // ===== step2: apply reduction
@@ -64,8 +72,14 @@ Tensor _segment_reduce_cpu_kernel(
               initial_value = at::_isnan(data)
                   ? data
                   : std::max(initial_value, data);
-            } else if (reduction == SegmentReductionType::MEAN) {
+            } else if (
+                reduction == SegmentReductionType::MEAN ||
+                reduction == SegmentReductionType::SUM) {
               initial_value = initial_value + data;
+            } else if (reduction == SegmentReductionType::MIN) {
+              initial_value = at::_isnan(data)
+                  ? data
+                  : std::min(initial_value, data);
             }
           }
 
@@ -105,7 +119,7 @@ Tensor _segment_reduce_cpu_backward_kernel(
   const auto* lengths_data = lengths_contig.data_ptr<int64_t>();
 
   // TODO: Swtich to TensorIterator for better maintainablility and readability
-  AT_DISPATCH_ALL_TYPES_AND2(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16,
       kHalf,
       data_contig.scalar_type(),
@@ -125,7 +139,8 @@ Tensor _segment_reduce_cpu_backward_kernel(
       for (int64_t l = 0; l < stride_count; ++l) {
         int64_t output_index = (i * stride_count) + l;
 
-        if (reduction == SegmentReductionType::MAX) {
+        if (reduction == SegmentReductionType::MAX ||
+            reduction == SegmentReductionType::MIN) {
           int64_t counter = 0;
           for (int64_t j = 0; j < lengths_data[i]; ++j) {
             int64_t starting_index =
@@ -156,6 +171,13 @@ Tensor _segment_reduce_cpu_backward_kernel(
                 ((lengths_cum_sum + j) * stride_count) + l;
             grad_input_data[starting_index] = grad_val;
           }
+        } else if (reduction == SegmentReductionType::SUM) {
+          const auto& grad_val = grad_data[output_index];
+          for (int64_t j = 0; j < lengths_data[i]; ++j) {
+            int64_t starting_index =
+                ((lengths_cum_sum + j) * stride_count) + l;
+            grad_input_data[starting_index] = grad_val;
+          }
         }
       }
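For reference, the per-segment semantics the CPU kernel above implements can be mirrored in a few lines of Python. Note the switch from lowest() to -infinity() (and infinity() for "min") is also why the dispatch macros narrow from ALL_TYPES to FLOATING_TYPES: only floating dtypes can represent these defaults. This is an illustrative sketch only; the helper name segment_reduce_ref is hypothetical, and the final divide-by-length step for "mean" is assumed from the surrounding file rather than shown in the hunks.

import math

def segment_reduce_ref(data, lengths, reduce, initial=None):
    # Pure-Python mirror of _segment_reduce_cpu_kernel for 1-D input.
    out, idx = [], 0
    for n in lengths:
        # step1: initialize the starting value (same rules as the hunk above)
        if initial is not None:
            acc = initial
        elif reduce == "max":
            acc = -math.inf
        elif reduce == "min":
            acc = math.inf
        else:  # "mean" and "sum" both accumulate from zero
            acc = 0.0
        # step2: apply the reduction; for min/max a NaN anywhere wins
        for x in data[idx:idx + n]:
            if reduce in ("min", "max") and (math.isnan(x) or math.isnan(acc)):
                acc = float("nan")
            elif reduce == "max":
                acc = max(acc, x)
            elif reduce == "min":
                acc = min(acc, x)
            else:
                acc += x
        if reduce == "mean" and n > 0:
            acc /= n  # assumed finalization; not part of the hunks shown
        out.append(acc)
        idx += n
    return out

An empty segment simply yields the starting value, which is why the tests below pass an explicit initial of 1000 for "min": segment_reduce_ref([1, float("nan"), 3, 4, 5, 5], [1, 2, 3, 0], "min", initial=1000) gives [1, nan, 4, 1000], matching expected_result in test_simple_1d.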
diff --git a/aten/src/ATen/native/SegmentReduce.h b/aten/src/ATen/native/SegmentReduce.h
index 5eb87d798ee..11a399ae77a 100644
--- a/aten/src/ATen/native/SegmentReduce.h
+++ b/aten/src/ATen/native/SegmentReduce.h
@@ -7,7 +7,7 @@
 namespace at {
 namespace native {
 
-enum SegmentReductionType { MAX, MEAN };
+enum SegmentReductionType { MAX, MEAN, MIN, SUM };
 
 using segment_reduce_fn = Tensor (*)(
     SegmentReductionType,
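With the enum extended, the backward kernels (the CPU one above, the CUDA one below) route gradients the same way for the new reductions: "sum" broadcasts the segment's output gradient to every input element, while "min" reuses the existing "max" machinery, sending the gradient to the element(s) that attained the extremum and splitting it evenly among ties. A hedged 1-D sketch of that routing, with a hypothetical helper name:

import math

def segment_reduce_backward_ref(grad_out, data, output, lengths, reduce):
    # 1-D mirror of the backward kernels' gradient-routing rules.
    grad_in, idx = [0.0] * len(data), 0
    for i, n in enumerate(lengths):
        seg = data[idx:idx + n]
        if reduce == "sum":
            for j in range(n):
                grad_in[idx + j] = grad_out[i]  # broadcast, as in the SUM branch
        elif reduce == "mean":
            for j in range(n):
                grad_in[idx + j] = grad_out[i] / n  # even split
        else:  # "min"/"max": route to elements matching the segment's result
            matches = [j for j, x in enumerate(seg)
                       if x == output[i] or (math.isnan(x) and math.isnan(output[i]))]
            for j in matches:
                grad_in[idx + j] = grad_out[i] / len(matches)  # ties share evenly
        idx += n
    return grad_in

This matches the expected_grad values in the tests below: for "max" over the segment [4, 5, 5], both fives receive 0.5, and for "min" over [nan, 3] the NaN element receives the whole gradient.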
diff --git a/aten/src/ATen/native/cuda/SegmentReduce.cu b/aten/src/ATen/native/cuda/SegmentReduce.cu
index 36ae51a4eae..c133e31b570 100644
--- a/aten/src/ATen/native/cuda/SegmentReduce.cu
+++ b/aten/src/ATen/native/cuda/SegmentReduce.cu
@@ -32,6 +32,19 @@ struct CustomSum {
   }
 };
 
+struct CustomMin {
+  template <typename OutputT>
+  __host__ __device__ __forceinline__ OutputT
+  operator()(const OutputT& a, const OutputT& b) const {
+    if (at::_isnan(a)) {
+      return a;
+    } else if (at::_isnan(b)) {
+      return b;
+    }
+    return std::min(a, b);
+  }
+};
+
 Tensor _get_complete_sum(const Tensor& lengths) {
   int64_t segment_count = lengths.numel();
   TORCH_CHECK(segment_count < INT_MAX);
@@ -99,8 +112,13 @@ __global__ static void segment_reduce_forward_kernel(
     if (reduction == SegmentReductionType::MAX) {
       initial_value =
           at::_isnan(data) ? data : std::max(initial_value, data);
-    } else if (reduction == SegmentReductionType::MEAN) {
+    } else if (
+        reduction == SegmentReductionType::MEAN ||
+        reduction == SegmentReductionType::SUM) {
       initial_value = initial_value + data;
+    } else if (reduction == SegmentReductionType::MIN) {
+      initial_value =
+          at::_isnan(data) ? data : std::min(initial_value, data);
     }
   }
 
@@ -144,7 +162,8 @@ __global__ static void segment_reduce_backward_kernel(
 
   int64_t output_index = (row_id * stride_count) + lane_id;
 
-  if (reduction == SegmentReductionType::MAX) {
+  if (reduction == SegmentReductionType::MAX ||
+      reduction == SegmentReductionType::MIN) {
     int64_t counter = 0;
     for (int64_t j = offset_start; j < offset_end; ++j) {
       int64_t starting_index = (j * stride_count) + lane_id;
@@ -172,6 +191,12 @@ __global__ static void segment_reduce_backward_kernel(
       int64_t starting_index = (j * stride_count) + lane_id;
       grad_input_data[starting_index] = grad_val;
     }
+  } else if (reduction == SegmentReductionType::SUM) {
+    const auto& grad_val = grad_data[output_index];
+    for (int64_t j = offset_start; j < offset_end; ++j) {
+      int64_t starting_index = (j * stride_count) + lane_id;
+      grad_input_data[starting_index] = grad_val;
+    }
   }
 }
 
@@ -203,7 +228,7 @@ Tensor _segment_reduce_cuda_backward_kernel(
   num_blocks = std::max(num_blocks, (int64_t)1);
 
   // TODO: Swtich to TensorIterator for better maintainablility and readability
-  AT_DISPATCH_ALL_TYPES_AND2(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16,
       kHalf,
       data_contig.scalar_type(),
@@ -259,7 +284,7 @@ Tensor _segment_reduce_cuda_kernel(
   num_blocks = std::max(num_blocks, (int64_t)1);
   auto* lengths_data_ptr = lengths.data_ptr<int64_t>();
 
-  AT_DISPATCH_ALL_TYPES_AND2(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
       data.scalar_type(),
@@ -273,9 +298,13 @@ Tensor _segment_reduce_cuda_kernel(
         if (initial.has_value()) {
           initial_value = initial.value().to<scalar_t>();
         } else if (reduction == SegmentReductionType::MAX) {
-          initial_value = std::numeric_limits<scalar_t>::lowest();
-        } else if (reduction == SegmentReductionType::MEAN) {
+          initial_value = -std::numeric_limits<scalar_t>::infinity();
+        } else if (
+            reduction == SegmentReductionType::MEAN ||
+            reduction == SegmentReductionType::SUM) {
           initial_value = 0;
+        } else if (reduction == SegmentReductionType::MIN) {
+          initial_value = std::numeric_limits<scalar_t>::infinity();
         }
 
         if (output_shape.size() > 1) {
@@ -331,6 +360,30 @@ Tensor _segment_reduce_cuda_kernel(
               initial.has_value(),
               initial_value);
           C10_CUDA_KERNEL_LAUNCH_CHECK();
+        } else if (reduction == SegmentReductionType::MIN) {
+          CustomMin min_op{};
+          CUB_WRAPPER(
+              cub::DeviceSegmentedReduce::Reduce,
+              data_data_ptr,
+              output_data_ptr,
+              segment_count,
+              offsets_data_ptr,
+              offsets_data_ptr + 1,
+              min_op,
+              initial_value,
+              at::cuda::getCurrentCUDAStream());
+        } else if (reduction == SegmentReductionType::SUM) {
+          CustomSum sum_op{};
+          CUB_WRAPPER(
+              cub::DeviceSegmentedReduce::Reduce,
+              data_data_ptr,
+              output_data_ptr,
+              segment_count,
+              offsets_data_ptr,
+              offsets_data_ptr + 1,
+              sum_op,
+              initial_value,
+              at::cuda::getCurrentCUDAStream());
         }
       }
     });
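One detail worth spelling out: cub::DeviceSegmentedReduce::Reduce consumes begin/end offset pointers rather than lengths, which is what _get_complete_sum (visible in context above) supplies as a prefix sum. CustomMin exists only to keep NaN propagation consistent with the CPU path, since a plain minimum would let non-NaN values win. A small illustration of the offsets layout, using the lengths from the tests below:

lengths = [1, 2, 3, 0]
offsets = [0]
for n in lengths:
    offsets.append(offsets[-1] + n)
# offsets == [0, 1, 3, 6, 6]; segment i is data[offsets[i]:offsets[i+1]],
# exactly the ranges handed to CUB via offsets_data_ptr / offsets_data_ptr + 1.
# The repeated trailing 6 encodes the empty fourth segment.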
diff --git a/test/test_segment_reductions.py b/test/test_segment_reductions.py
index 8b349b4c0b3..ad8e53ea776 100644
--- a/test/test_segment_reductions.py
+++ b/test/test_segment_reductions.py
@@ -3,7 +3,6 @@ import torch
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     dtypes,
-    dtypesIfCUDA,
 )
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -12,6 +11,9 @@ from torch.testing._internal.common_utils import (
 )
 
 
+reductions = ["max", "mean", "min", "sum"]
+
+
 class TestSegmentReductions(TestCase):
     def _test_common(
         self,
@@ -83,48 +85,56 @@ class TestSegmentReductions(TestCase):
             )
         )
 
-    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
     @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
     def test_simple_1d(self, device, dtype):
         lengths = [1, 2, 3, 0]
         data = [1, float("nan"), 3, 4, 5, 5]
-        initial_value = 0
         check_backward = True
 
-        for reduction in ("max", "mean"):
+        for reduction in reductions:
             if reduction == "max":
+                initial_value = 0
                 expected_result = [1, float("nan"), 5, initial_value]
                 expected_grad = [1, 1, 0, 0, 0.5, 0.5]
             elif reduction == "mean":
+                initial_value = 0
                 expected_result = [1, float("nan"), 4.666, initial_value]
                 expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
+            elif reduction == "min":
+                initial_value = 1000  # some high number
+                expected_result = [1, float("nan"), 4, initial_value]
+                expected_grad = [1.0, 1.0, 0, 1, 0, 0]
+            elif reduction == "sum":
+                initial_value = 0
+                expected_result = [1, float("nan"), 14, initial_value]
+                expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
             for axis in [0, -1]:
                 for unsafe in [True, False]:
-                    self._test_common(
-                        reduction,
-                        device,
-                        dtype,
-                        unsafe,
-                        axis,
-                        initial_value,
-                        data,
-                        lengths,
-                        expected_result,
-                        expected_grad,
-                        check_backward,
-                    )
+                    for initial in [initial_value, None]:
+                        self._test_common(
+                            reduction,
+                            device,
+                            dtype,
+                            unsafe,
+                            axis,
+                            initial_value,
+                            data,
+                            lengths,
+                            expected_result,
+                            expected_grad,
+                            check_backward,
+                        )
 
-    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
     @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
     def test_multi_d_simple(self, device, dtype):
-        initial_value = 0
         check_backward = True
         axis = 0
         lengths = [1, 2, 3, 0]
         data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]
 
-        for reduction in ("max", "mean"):
+        for reduction in reductions:
             if reduction == "max":
+                initial_value = 0
                 expected_result = [
                     [1, 1],
                     [float("nan"), float("nan")],
@@ -140,6 +150,7 @@ class TestSegmentReductions(TestCase):
                     [0, 1],
                 ]
             elif reduction == "mean":
+                initial_value = 0
                 expected_result = [
                     [1, 1],
                     [float("nan"), float("nan")],
@@ -154,25 +165,56 @@ class TestSegmentReductions(TestCase):
                     [0.333, 0.333],
                     [0.333, 0.333],
                 ]
+            elif reduction == "min":
+                initial_value = 1000  # some high number
+                expected_result = [
+                    [1, 1],
+                    [float("nan"), float("nan")],
+                    [2, 1],
+                    [initial_value, initial_value],
+                ]
+                expected_grad = [
+                    [1.0, 1.0],
+                    [1, 0],
+                    [0, 1],
+                    [0, 1],
+                    [0, 0],
+                    [1, 0],
+                ]
+            elif reduction == "sum":
+                initial_value = 0
+                expected_result = [
+                    [1, 1],
+                    [float("nan"), float("nan")],
+                    [9, 6],
+                    [initial_value, initial_value],
+                ]
+                expected_grad = [
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                    [1.0, 1.0],
+                ]
             for unsafe in [True, False]:
-                self._test_common(
-                    reduction,
-                    device,
-                    dtype,
-                    unsafe,
-                    axis,
-                    initial_value,
-                    data,
-                    lengths,
-                    expected_result,
-                    expected_grad,
-                    check_backward,
-                )
+                for initial in [initial_value, None]:
+                    self._test_common(
+                        reduction,
+                        device,
+                        dtype,
+                        unsafe,
+                        axis,
+                        initial_value,
+                        data,
+                        lengths,
+                        expected_result,
+                        expected_grad,
+                        check_backward,
+                    )
 
-    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
     @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
     def test_multi_d(self, device, dtype):
-        initial_value = 0
         axis = 0
         lengths = [0, 2]
         data = np.arange(20).reshape(2, 2, 5).tolist()
         expected_grad = []
@@ -181,31 +223,46 @@ class TestSegmentReductions(TestCase):
 
         # TODO: calculate grad and check correctness
         check_backward = False
 
-        for reduction in ["max", "mean"]:
+        for reduction in reductions:
             if reduction == "max":
+                initial_value = 0
                 expected_result = [
                     np.full((2, 5), initial_value).tolist(),
                     np.max(data, axis=0).tolist(),
                 ]
             elif reduction == "mean":
+                initial_value = 0
                 expected_result = [
                     np.full((2, 5), initial_value).tolist(),
                     np.mean(data, axis=0).tolist(),
                 ]
+            elif reduction == "min":
+                initial_value = 1000  # some high number
+                expected_result = [
+                    np.full((2, 5), initial_value).tolist(),
+                    np.min(data, axis=0).tolist(),
+                ]
+            elif reduction == "sum":
+                initial_value = 0
+                expected_result = [
+                    np.full((2, 5), initial_value).tolist(),
+                    np.sum(data, axis=0).tolist(),
+                ]
             for unsafe in [True, False]:
-                self._test_common(
-                    reduction,
-                    device,
-                    dtype,
-                    unsafe,
-                    axis,
-                    initial_value,
-                    data,
-                    lengths,
-                    expected_result,
-                    expected_grad,
-                    check_backward,
-                )
+                for initial in [initial_value, None]:
+                    self._test_common(
+                        reduction,
+                        device,
+                        dtype,
+                        unsafe,
+                        axis,
+                        initial_value,
+                        data,
+                        lengths,
+                        expected_result,
+                        expected_grad,
+                        check_backward,
+                    )
 
 
 instantiate_device_type_tests(TestSegmentReductions, globals())
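Taken together, the patch is meant to make calls like the following work end to end. This is a usage sketch, not code from the PR: it assumes the op is exposed as torch.segment_reduce with the lengths/initial keywords exercised by _test_common, which should be checked against the actual build.

import torch

data = torch.tensor([1.0, float("nan"), 3.0, 4.0, 5.0, 5.0], requires_grad=True)
lengths = torch.tensor([1, 2, 3, 0])

# "min" with an explicit initial value: the empty fourth segment yields 1000.
out_min = torch.segment_reduce(data, "min", lengths=lengths, initial=1000)
# expected: tensor([   1.,   nan,    4., 1000.])

# "sum" needs no initial value; empty segments produce 0.
out_sum = torch.segment_reduce(data, "sum", lengths=lengths)
# expected: tensor([ 1., nan, 14.,  0.])

# The new SUM backward branch broadcasts each segment's gradient.
out_sum.sum().backward()
# expected: data.grad == tensor([1., 1., 1., 1., 1., 1.])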