unify reduction types from different operators: scatter, scatter_reduce, segment_reduce (#91499)

The goal of this PR is to unify the reduction types used by the different reduce operators (`scatter`, `scatter_reduce`, `segment_reduce`) into a single `ReductionType` enum, so that they can share the same set of reduce utilities, e.g. the vectorized `init` and `update` helpers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91499
Approved by: https://github.com/ngimel
Authored by mingfeima on 2023-01-12 10:19:30 +08:00; committed by PyTorch MergeBot
parent a70387f0fa
commit eb7b89771e
12 changed files with 211 additions and 218 deletions
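
For context before the diffs: with a single `ReductionType`, the per-operator string parsing and the vectorization helpers (`init`, which picks the identity element, and `update`, which folds one value into the accumulator) only need to be written once. Below is a minimal, self-contained sketch of that pattern in plain C++; it mirrors the helpers in the CPU kernel diff further down but is not the PR's code (the real versions operate on `at::vec::Vectorized` and live under `at::native`).

    #include <algorithm>
    #include <iostream>
    #include <limits>
    #include <vector>

    enum ReductionType { MAX, MEAN, MIN, SUM, PROD };

    // Identity element for a given reduction (what `init` fills the output with
    // when the input element is excluded).
    template <typename scalar_t, ReductionType reduce>
    scalar_t init_value() {
      if (reduce == ReductionType::SUM || reduce == ReductionType::MEAN) {
        return static_cast<scalar_t>(0);
      } else if (reduce == ReductionType::PROD) {
        return static_cast<scalar_t>(1);
      } else if (reduce == ReductionType::MAX) {
        return std::numeric_limits<scalar_t>::lowest();
      } else {
        return std::numeric_limits<scalar_t>::max();
      }
    }

    // Fold one value into the accumulator; the same body serves every operator
    // that accepts a `reduce=` argument.
    template <typename scalar_t, ReductionType reduce>
    scalar_t update(scalar_t acc, scalar_t val) {
      if (reduce == ReductionType::SUM || reduce == ReductionType::MEAN) {
        return acc + val;
      } else if (reduce == ReductionType::PROD) {
        return acc * val;
      } else if (reduce == ReductionType::MAX) {
        return std::max(acc, val);
      } else {
        return std::min(acc, val);
      }
    }

    int main() {
      std::vector<float> segment{2.f, 5.f, 3.f};
      float acc = init_value<float, ReductionType::MAX>();
      for (float v : segment) {
        acc = update<float, ReductionType::MAX>(acc, v);
      }
      std::cout << "amax over segment = " << acc << "\n";  // prints 5
    }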

View File

@ -0,0 +1,40 @@
#pragma once
#include <c10/core/Scalar.h>
namespace at { namespace native {
enum ReductionType {MAX, MEAN, MIN, SUM, PROD};
static inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
if (reduce == "amax") {
return ReductionType::MAX;
} else if (reduce == "mean") {
return ReductionType::MEAN;
} else if (reduce == "amin") {
return ReductionType::MIN;
} else if (reduce == "sum") {
return ReductionType::SUM;
} else if (reduce == "prod") {
return ReductionType::PROD;
} else {
TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce);
}
}
// used for `scatter_reduce`, old options for BC.
static inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
if (use_new_options) {
return get_reduction_enum(reduce);
} else {
if (reduce == "add") {
return ReductionType::SUM;
} else if (reduce == "multiply") {
return ReductionType::PROD;
} else {
TORCH_CHECK(false, "reduce argument must be either add or multiply.")
}
}
}
}} // at::native
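
A usage sketch for the helpers above (assumptions: compiled inside the ATen tree so the header and `c10::string_view` are available; the function `example` is illustrative and not part of the PR):

    #include <ATen/native/ReductionType.h>

    // Both spellings funnel into the same enum, so kernels only switch on
    // ReductionType regardless of which operator parsed the string.
    void example() {
      // New-style names used by scatter_reduce.two / segment_reduce.
      auto amax = at::native::get_reduction_enum("amax");  // ReductionType::MAX
      // Legacy scatter(..., reduce=) names kept for backward compatibility.
      auto add = at::native::get_operator_enum("add", /*use_new_options=*/false);  // ReductionType::SUM
      (void)amax;
      (void)add;
    }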

View File

@ -28,25 +28,9 @@ DEFINE_DISPATCH(_segment_reduce_offsets_backward_stub);
namespace {
SegmentReductionType get_reduction_enum(const c10::string_view& reduce) {
if (reduce == "max") {
return SegmentReductionType::MAX;
} else if (reduce == "mean") {
return SegmentReductionType::MEAN;
} else if (reduce == "min") {
return SegmentReductionType::MIN;
} else if (reduce == "sum") {
return SegmentReductionType::SUM;
} else if (reduce == "prod") {
return SegmentReductionType::PROD;
} else {
TORCH_CHECK(false, "unsupported reduction given! ", reduce);
}
}
template <typename T, bool is_offsets_like=false>
void _segment_reduce_lengths_cpu_kernel1(
SegmentReductionType reduction,
ReductionType reduction,
const Tensor& data,
const T* lengths_data,
int64_t axis,
@ -90,15 +74,15 @@ void _segment_reduce_lengths_cpu_kernel1(
scalar_t initial_value;
if (initial.has_value()) {
initial_value = initial.value().to<scalar_t>();
} else if (reduction == SegmentReductionType::MAX) {
} else if (reduction == ReductionType::MAX) {
initial_value = -std::numeric_limits<scalar_t>::infinity();
} else if (
reduction == SegmentReductionType::MEAN ||
reduction == SegmentReductionType::SUM) {
reduction == ReductionType::MEAN ||
reduction == ReductionType::SUM) {
initial_value = 0;
} else if (reduction == SegmentReductionType::MIN) {
} else if (reduction == ReductionType::MIN) {
initial_value = std::numeric_limits<scalar_t>::infinity();
} else if (reduction == SegmentReductionType::PROD) {
} else if (reduction == ReductionType::PROD) {
initial_value = 1;
}
@ -107,19 +91,19 @@ void _segment_reduce_lengths_cpu_kernel1(
int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ j * data_stride_axis + inner_idx;
const auto val = values_data[data_index];
if (reduction == SegmentReductionType::MAX) {
if (reduction == ReductionType::MAX) {
initial_value = at::_isnan(val)
? val
: std::max<scalar_t>(initial_value, val);
} else if (
reduction == SegmentReductionType::MEAN ||
reduction == SegmentReductionType::SUM) {
reduction == ReductionType::MEAN ||
reduction == ReductionType::SUM) {
initial_value = initial_value + val;
} else if (reduction == SegmentReductionType::MIN) {
} else if (reduction == ReductionType::MIN) {
initial_value = at::_isnan(val)
? val
: std::min<scalar_t>(initial_value, val);
} else if (reduction == SegmentReductionType::PROD) {
} else if (reduction == ReductionType::PROD) {
initial_value = initial_value * val;
}
}
@ -128,10 +112,10 @@ void _segment_reduce_lengths_cpu_kernel1(
TORCH_CHECK(segment_length >= 0);
if (segment_length == 0 && !initial.has_value() &&
reduction == SegmentReductionType::MEAN) {
reduction == ReductionType::MEAN) {
initial_value = static_cast<scalar_t>(NAN);
} else if (
reduction == SegmentReductionType::MEAN &&
reduction == ReductionType::MEAN &&
segment_length > 0 && !at::_isnan(initial_value)) {
initial_value = initial_value / segment_length;
}
@ -145,7 +129,7 @@ void _segment_reduce_lengths_cpu_kernel1(
}
Tensor _segment_reduce_lengths_cpu_kernel(
SegmentReductionType reduction,
ReductionType reduction,
const Tensor& data,
const Tensor& lengths,
int64_t axis,
@ -171,7 +155,7 @@ Tensor _segment_reduce_lengths_cpu_kernel(
}
Tensor _segment_reduce_offsets_cpu_kernel(
SegmentReductionType reduction,
ReductionType reduction,
const Tensor& data,
const Tensor& offsets,
int64_t axis,
@ -201,7 +185,7 @@ void _segment_reduce_cpu_lengths_backward_kernel1(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
SegmentReductionType reduction,
ReductionType reduction,
const T* lengths_data,
int64_t axis,
const c10::optional<Scalar>& initial,
@ -234,7 +218,7 @@ void _segment_reduce_cpu_lengths_backward_kernel1(
const auto* values_data = data_contig.data_ptr<scalar_t>();
// Used to calculate exclusive prod
scalar_t initial_prod_value;
if (reduction == SegmentReductionType::PROD) {
if (reduction == ReductionType::PROD) {
if (initial.has_value()) {
initial_prod_value = initial.value().to<scalar_t>();
} else {
@ -265,8 +249,8 @@ void _segment_reduce_cpu_lengths_backward_kernel1(
for (const auto inner_idx : c10::irange(inner_offset)) {
int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+ dim_idx * output_stride_axis + inner_idx;
if (reduction == SegmentReductionType::MAX ||
reduction == SegmentReductionType::MIN) {
if (reduction == ReductionType::MAX ||
reduction == ReductionType::MIN) {
int64_t counter = 0;
for (const auto j : c10::irange(segment_start, segment_end)) {
int64_t data_index = outer_idx * data_stride_axis * data_size_axis
@ -290,21 +274,21 @@ void _segment_reduce_cpu_lengths_backward_kernel1(
grad_input_data[data_index] / counter;
}
}
} else if (reduction == SegmentReductionType::MEAN) {
} else if (reduction == ReductionType::MEAN) {
auto grad_val = grad_data[output_index] / segment_length;
for (const auto j : c10::irange(segment_start, segment_end)) {
int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ j * data_stride_axis + inner_idx;
grad_input_data[data_index] = grad_val;
}
} else if (reduction == SegmentReductionType::SUM) {
} else if (reduction == ReductionType::SUM) {
const auto& grad_val = grad_data[output_index];
for (const auto j : c10::irange(segment_start, segment_end)) {
int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ j * data_stride_axis + inner_idx;
grad_input_data[data_index] = grad_val;
}
} else if (reduction == SegmentReductionType::PROD) {
} else if (reduction == ReductionType::PROD) {
const auto& grad_val = grad_data[output_index] * output_data[output_index];
for (const auto j : c10::irange(segment_start, segment_end)) {
int64_t data_index = outer_idx * data_stride_axis * data_size_axis
@ -337,7 +321,7 @@ Tensor _segment_reduce_cpu_lengths_backward_kernel(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
SegmentReductionType reduction,
ReductionType reduction,
const Tensor& lengths_contig,
int64_t axis,
const c10::optional<Scalar>& initial) {
@ -370,7 +354,7 @@ Tensor _segment_reduce_cpu_offsets_backward_kernel(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
SegmentReductionType reduction,
ReductionType reduction,
const Tensor& offsets_contig,
int64_t axis,
const c10::optional<Scalar>& initial) {

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReductionType.h>
#include <c10/core/Scalar.h>
#include <c10/util/Optional.h>
@ -9,10 +10,8 @@ class Tensor;
namespace native {
enum SegmentReductionType { MAX, MEAN, MIN, SUM, PROD};
using segment_reduce_lengths_fn = Tensor (*)(
SegmentReductionType,
ReductionType,
const Tensor&,
const Tensor&,
int64_t,
@ -20,7 +19,7 @@ using segment_reduce_lengths_fn = Tensor (*)(
DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
using segment_reduce_offsets_fn = Tensor (*)(
SegmentReductionType,
ReductionType,
const Tensor&,
const Tensor&,
int64_t,
@ -31,7 +30,7 @@ using segment_reduce_lengths_backward_fn = Tensor (*)(
const Tensor&,
const Tensor&,
const Tensor&,
SegmentReductionType,
ReductionType,
const Tensor&,
int64_t,
const c10::optional<Scalar>&);
@ -41,7 +40,7 @@ using segment_reduce_offsets_backward_fn = Tensor (*)(
const Tensor&,
const Tensor&,
const Tensor&,
SegmentReductionType,
ReductionType,
const Tensor&,
int64_t,
const c10::optional<Scalar>&);

View File

@ -148,32 +148,6 @@ AdvancedIndex make_info(Tensor self, IOptTensorListRef orig);
namespace meta {
native::SCATTER_GATHER_OP get_operator_enum(const c10::string_view reduce, bool use_new_options = false) {
if (use_new_options) {
if (reduce == "sum") {
return native::SCATTER_GATHER_OP::REDUCE_ADD;
} else if (reduce == "prod") {
return native::SCATTER_GATHER_OP::REDUCE_MULTIPLY;
} else if (reduce == "mean") {
return native::SCATTER_GATHER_OP::REDUCE_MEAN;
} else if (reduce == "amax") {
return native::SCATTER_GATHER_OP::REDUCE_MAXIMUM;
} else if (reduce == "amin") {
return native::SCATTER_GATHER_OP::REDUCE_MINIMUM;
} else {
TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin.");
}
} else {
if (reduce == "add") {
return native::SCATTER_GATHER_OP::REDUCE_ADD;
} else if (reduce == "multiply") {
return native::SCATTER_GATHER_OP::REDUCE_MULTIPLY;
} else {
TORCH_CHECK(false, "reduce argument must be either add or multiply.")
}
}
}
TORCH_META_FUNC(gather)
(const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) {
const Tensor& result = maybe_get_output(0);
@ -227,7 +201,7 @@ void scatter_meta_impl(
meta.set_output_raw_strided(0, self.sizes(), {}, self.options());
if (reduce.has_value()) {
// Check if we have a valid reduce operator.
get_operator_enum(reduce.value(), use_new_options);
at::native::get_operator_enum(reduce.value(), use_new_options);
}
}
@ -975,7 +949,7 @@ void index_reduce_func_impl(
const Tensor& source,
bool include_self,
const Tensor& result,
const SCATTER_GATHER_OP& op) {
const ReductionType& op) {
if (!result.is_same(self)) result.copy_(self);
if (!include_self) {
AT_DISPATCH_ALL_TYPES_AND2(
@ -983,14 +957,14 @@ void index_reduce_func_impl(
self.scalar_type(), "index_reduce_func_exclude_input_init", [&] {
scalar_t init_val;
switch (op) {
case SCATTER_GATHER_OP::REDUCE_MULTIPLY:
case ReductionType::PROD:
init_val = (scalar_t)1;
break;
case SCATTER_GATHER_OP::REDUCE_MAXIMUM:
case ReductionType::MAX:
init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity()
: std::numeric_limits<scalar_t>::lowest();
break;
case SCATTER_GATHER_OP::REDUCE_MINIMUM:
case ReductionType::MIN:
init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity()
: std::numeric_limits<scalar_t>::max();
break;
@ -1037,13 +1011,13 @@ void index_reduce_func_impl(
iter.unsafe_replace_operand(2, source_data);
switch (op) {
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
case ReductionType::PROD :
mul_stub(iter.device_type(), iter);
break;
case SCATTER_GATHER_OP::REDUCE_MINIMUM :
case ReductionType::MIN :
minimum_stub(iter.device_type(), iter);
break;
case SCATTER_GATHER_OP::REDUCE_MAXIMUM :
case ReductionType::MAX :
maximum_stub(iter.device_type(), iter);
break;
default :
@ -1053,7 +1027,7 @@ void index_reduce_func_impl(
}
});
if (op == SCATTER_GATHER_OP::REDUCE_MEAN) {
if (op == ReductionType::MEAN) {
auto counts = include_self ? at::ones_like(result) : at::zeros_like(result);
counts.index_add_(dim, index, at::ones_like(source));
counts.masked_fill_(counts == 0, 1);
@ -1088,19 +1062,19 @@ void index_reduce_func_impl(
scalar_t *count_ip;
scalar_t val;
switch (op) {
case SCATTER_GATHER_OP::REDUCE_MEAN :
case ReductionType::MEAN :
*self_ip += *(source_ptr + i * source_stride);
count_ip = counts_ptr + self_i * counts_stride;
*count_ip += 1;
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
case ReductionType::PROD :
*self_ip *= *(source_ptr + i * source_stride);
break;
case SCATTER_GATHER_OP::REDUCE_MINIMUM :
case ReductionType::MIN :
val = *(source_ptr + i * source_stride);
*self_ip = at::_isnan<scalar_t>(val) ? val : std::min(*self_ip, val);
break;
case SCATTER_GATHER_OP::REDUCE_MAXIMUM :
case ReductionType::MAX :
val = *(source_ptr + i * source_stride);
*self_ip = at::_isnan<scalar_t>(val) ? val : std::max(*self_ip, val);
break;
@ -1110,7 +1084,7 @@ void index_reduce_func_impl(
}
});
});
if (op == SCATTER_GATHER_OP::REDUCE_MEAN) {
if (op == ReductionType::MEAN) {
counts.masked_fill_(counts == 0, 1);
if (result.is_floating_point() || result.is_complex()) {
result.div_(counts);
@ -1130,7 +1104,7 @@ TORCH_IMPL_FUNC(index_reduce_cpu_out)
bool include_input,
const Tensor& result) {
TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time.");
auto op = meta::get_operator_enum(reduce, true);
auto op = get_operator_enum(reduce, true);
index_reduce_func_impl(self, dim, index, source, include_input, result, op);
}
@ -1511,27 +1485,27 @@ static void scatter_reduce_exclude_self_helper(
const Tensor& self,
int64_t dim,
const Tensor& index,
const SCATTER_GATHER_OP& op) {
const ReductionType& op) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool,
self.scalar_type(), "scatter_reduce_exclude_input_init", [&] {
scalar_t init_val;
switch (op) {
case SCATTER_GATHER_OP::REDUCE_ADD:
case ReductionType::SUM:
init_val = (scalar_t)0;
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY:
case ReductionType::PROD:
init_val = (scalar_t)1;
break;
case SCATTER_GATHER_OP::REDUCE_MAXIMUM:
case ReductionType::MAX:
init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity()
: std::numeric_limits<scalar_t>::lowest();
break;
case SCATTER_GATHER_OP::REDUCE_MINIMUM:
case ReductionType::MIN:
init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity()
: std::numeric_limits<scalar_t>::max();
break;
case SCATTER_GATHER_OP::REDUCE_MEAN:
case ReductionType::MEAN:
init_val = (scalar_t)0;
break;
}
@ -1561,7 +1535,7 @@ void scatter_impl(
if (index.numel() == 0) return;
if (reduce.has_value()) {
auto op = meta::get_operator_enum(reduce.value(), use_new_options);
auto op = get_operator_enum(reduce.value(), use_new_options);
if (!reduce_includes_self) {
// scatter inits for reduction to appropriate indices (used by scatter_reduce.two)
scatter_reduce_exclude_self_helper(mut_out, dim, index, op);
@ -1756,7 +1730,7 @@ TORCH_IMPL_FUNC(scatter_reduce_two)
out.copy_(self);
}
const auto op = meta::get_operator_enum(reduce, true);
const auto op = get_operator_enum(reduce, true);
if (can_use_expanded_index_path(out, dim, index, src, /*is_scatter_like*/true)) {
scatter_reduce_expanded_index_stub(self.device().type(), out, index, src, op, include_self);
@ -1769,7 +1743,7 @@ TORCH_IMPL_FUNC(scatter_reduce_two)
reduce,
include_self);
if (op == SCATTER_GATHER_OP::REDUCE_MEAN) {
if (op == ReductionType::MEAN) {
auto ones = at::ones_like(src);
auto count = include_self ? at::ones_like(out) : at::zeros_like(out);
count.scatter_add_(dim, index, ones);

View File

@ -5,6 +5,7 @@
#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReductionType.h>
#include <ATen/native/cpu/radix_sort.h>
namespace at {
@ -13,8 +14,6 @@ struct TensorIterator;
namespace at { namespace native {
enum class SCATTER_GATHER_OP: uint8_t {REDUCE_ADD, REDUCE_MULTIPLY, REDUCE_MAXIMUM, REDUCE_MINIMUM, REDUCE_MEAN};
using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List<c10::optional<Tensor>>& indices, const Tensor& value, double scale, int zero_point, bool unsafe);
using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
@ -22,11 +21,11 @@ using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index,
using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src);
using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
const Tensor& src, const SCATTER_GATHER_OP& reduce);
const Tensor& src, const ReductionType& reduce);
using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
const Scalar& value, const SCATTER_GATHER_OP& reduce);
const Scalar& value, const ReductionType& reduce);
using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
const Tensor& src, const SCATTER_GATHER_OP& reduce);
const Tensor& src, const ReductionType& reduce);
DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub);
@ -95,7 +94,7 @@ static inline bool can_use_expanded_index_path(
}
using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const SCATTER_GATHER_OP& reduce, bool);
using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool);
using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);
DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);

View File

@ -573,19 +573,19 @@ struct cpu_scatter_gather_base_kernel {
}
};
template <typename scalar_t, SCATTER_GATHER_OP reduce>
template <typename scalar_t, ReductionType reduce>
inline void init(scalar_t* ptr, int64_t size, bool include_self) {
if (!include_self) {
using acc_t = vec::vec_scalar_t<scalar_t>;
using Vec = vec::Vectorized<acc_t>;
acc_t val;
if (reduce == SCATTER_GATHER_OP::REDUCE_ADD ||
reduce == SCATTER_GATHER_OP::REDUCE_MEAN) {
if (reduce == ReductionType::SUM ||
reduce == ReductionType::MEAN) {
val = static_cast<acc_t>(0);
} else if (reduce == SCATTER_GATHER_OP::REDUCE_MULTIPLY) {
} else if (reduce == ReductionType::PROD) {
val = static_cast<acc_t>(1);
} else if (reduce == SCATTER_GATHER_OP::REDUCE_MAXIMUM) {
} else if (reduce == ReductionType::MAX) {
val = std::numeric_limits<acc_t>::lowest();
} else {
val = std::numeric_limits<acc_t>::max();
@ -598,14 +598,14 @@ inline void init(scalar_t* ptr, int64_t size, bool include_self) {
}
}
template <typename vec_t, SCATTER_GATHER_OP reduce>
template <typename vec_t, ReductionType reduce>
inline vec_t update(const vec_t& x, const vec_t& y) {
if (reduce == SCATTER_GATHER_OP::REDUCE_ADD ||
reduce == SCATTER_GATHER_OP::REDUCE_MEAN) {
if (reduce == ReductionType::SUM ||
reduce == ReductionType::MEAN) {
return x + y;
} else if (reduce == SCATTER_GATHER_OP::REDUCE_MULTIPLY) {
} else if (reduce == ReductionType::PROD) {
return x * y;
} else if (reduce == SCATTER_GATHER_OP::REDUCE_MAXIMUM) {
} else if (reduce == ReductionType::MAX) {
return vec::maximum(x, y);
} else {
return vec::minimum(x, y);
@ -635,7 +635,7 @@ inline vec_t update(const vec_t& x, const vec_t& y) {
//
// step 2: spmm reduce, parallel on M and vectorize on K
//
template <typename scalar_t, SCATTER_GATHER_OP reduce>
template <typename scalar_t, ReductionType reduce>
void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index, const Tensor& src, bool include_self) {
int64_t* index_data = index.data_ptr<int64_t>();
scalar_t* self_data = self.data_ptr<scalar_t>();
@ -735,7 +735,7 @@ void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index,
K);
}
if (reduce == SCATTER_GATHER_OP::REDUCE_MEAN) {
if (reduce == ReductionType::MEAN) {
int64_t count = include_self ? 1 : 0;
count += off_end - off_start;
if (count != 0) {
@ -791,30 +791,30 @@ void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index,
void scatter_add_expanded_index_kernel(const Tensor& self, const Tensor& index, const Tensor& src) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, self.scalar_type(), "scatter_add_expanded_index", [&] {
cpu_scatter_reduce_expanded_index<scalar_t, SCATTER_GATHER_OP::REDUCE_ADD>(self, index, src, /*include_self*/true);
cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::SUM>(self, index, src, /*include_self*/true);
});
}
void scatter_reduce_expanded_index_kernel(
const Tensor& self, const Tensor& index, const Tensor& src,
const SCATTER_GATHER_OP& reduce, bool include_self) {
const ReductionType& reduce, bool include_self) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, self.scalar_type(), "scatter_reduce_expanded_index", [&] {
switch (reduce) {
case SCATTER_GATHER_OP::REDUCE_ADD :
cpu_scatter_reduce_expanded_index<scalar_t, SCATTER_GATHER_OP::REDUCE_ADD>(self, index, src, include_self);
case ReductionType::SUM :
cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::SUM>(self, index, src, include_self);
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
cpu_scatter_reduce_expanded_index<scalar_t, SCATTER_GATHER_OP::REDUCE_MULTIPLY>(self, index, src, include_self);
case ReductionType::PROD :
cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::PROD>(self, index, src, include_self);
break;
case SCATTER_GATHER_OP::REDUCE_MAXIMUM :
cpu_scatter_reduce_expanded_index<scalar_t, SCATTER_GATHER_OP::REDUCE_MAXIMUM>(self, index, src, include_self);
case ReductionType::MAX :
cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::MAX>(self, index, src, include_self);
break;
case SCATTER_GATHER_OP::REDUCE_MINIMUM :
cpu_scatter_reduce_expanded_index<scalar_t, SCATTER_GATHER_OP::REDUCE_MINIMUM>(self, index, src, include_self);
case ReductionType::MIN :
cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::MIN>(self, index, src, include_self);
break;
case SCATTER_GATHER_OP::REDUCE_MEAN :
cpu_scatter_reduce_expanded_index<scalar_t, SCATTER_GATHER_OP::REDUCE_MEAN>(self, index, src, include_self);
case ReductionType::MEAN :
cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::MEAN>(self, index, src, include_self);
break;
}
});
@ -850,13 +850,13 @@ void scatter_add_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index
}
void scatter_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
const Tensor& src, const SCATTER_GATHER_OP& reduce) {
const Tensor& src, const ReductionType& reduce) {
switch (reduce) {
case SCATTER_GATHER_OP::REDUCE_ADD :
case ReductionType::SUM :
cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
"scatter_reduce_add_", reduce_add);
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
case ReductionType::PROD :
cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
"scatter_reduce_multiply_", reduce_multiply);
break;
@ -866,25 +866,25 @@ void scatter_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tens
}
void scatter_reduce_two_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
const Tensor& src, const SCATTER_GATHER_OP& reduce) {
const Tensor& src, const ReductionType& reduce) {
switch (reduce) {
case SCATTER_GATHER_OP::REDUCE_ADD :
case ReductionType::SUM :
cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
"scatter_reduce_sum_", reduce_add);
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
case ReductionType::PROD :
cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
"scatter_reduce_prod_", reduce_multiply);
break;
case SCATTER_GATHER_OP::REDUCE_MAXIMUM :
case ReductionType::MAX :
cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
"scatter_reduce_amax_", reduce_maximum);
break;
case SCATTER_GATHER_OP::REDUCE_MINIMUM :
case ReductionType::MIN :
cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
"scatter_reduce_amin_", reduce_minimum);
break;
case SCATTER_GATHER_OP::REDUCE_MEAN :
case ReductionType::MEAN :
cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
"scatter_reduce_mean_", reduce_mean);
break;
@ -892,13 +892,13 @@ void scatter_reduce_two_cpu_kernel(const Tensor& self, const int64_t dim, const
}
void scatter_scalar_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
const Scalar& value, const SCATTER_GATHER_OP& reduce) {
const Scalar& value, const ReductionType& reduce) {
switch (reduce) {
case SCATTER_GATHER_OP::REDUCE_ADD :
case ReductionType::SUM :
cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
"scatter_scalar_reduce_add_", reduce_add);
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
case ReductionType::PROD :
cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
"scatter_scalar_reduce_multiply_", reduce_multiply);
break;

View File

@ -865,7 +865,7 @@ void index_reduce_func_cuda_impl(
const Tensor& index,
const Tensor& source,
bool include_self,
const SCATTER_GATHER_OP& reduce,
const ReductionType& reduce,
const func_t& reduce_func,
const Tensor& result) {
globalContext().alertNotDeterministic("index_reduce_cuda");
@ -886,14 +886,14 @@ void index_reduce_func_cuda_impl(
self.scalar_type(), "index_reduce_func_cuda_exclude_input_init", [&] {
scalar_t init_val;
switch (reduce) {
case SCATTER_GATHER_OP::REDUCE_MULTIPLY:
case ReductionType::PROD:
init_val = (scalar_t)1;
break;
case SCATTER_GATHER_OP::REDUCE_MAXIMUM:
case ReductionType::MAX:
init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity()
: std::numeric_limits<scalar_t>::lowest();
break;
case SCATTER_GATHER_OP::REDUCE_MINIMUM:
case ReductionType::MIN:
init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity()
: std::numeric_limits<scalar_t>::max();
break;
@ -1047,9 +1047,9 @@ TORCH_IMPL_FUNC(index_reduce_cuda_out)
TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time.");
if (reduce == "prod") {
index_reduce_func_cuda_impl(self, dim, index, source, include_self, SCATTER_GATHER_OP::REDUCE_MULTIPLY, reduce_multiply, result);
index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::PROD, reduce_multiply, result);
} else if (reduce == "mean") {
index_reduce_func_cuda_impl(self, dim, index, source, include_self, SCATTER_GATHER_OP::REDUCE_MEAN, reduce_add, result);
index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::MEAN, reduce_add, result);
auto counts = include_self ? at::ones_like(result) : at::zeros_like(result);
counts.index_add_(dim, index, at::ones_like(source));
counts.masked_fill_(counts == 0, 1);
@ -1059,9 +1059,9 @@ TORCH_IMPL_FUNC(index_reduce_cuda_out)
result.div_(counts, "floor");
}
} else if (reduce == "amax") {
index_reduce_func_cuda_impl(self, dim, index, source, include_self, SCATTER_GATHER_OP::REDUCE_MAXIMUM, reduce_maximum, result);
index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::MAX, reduce_maximum, result);
} else if (reduce == "amin") {
index_reduce_func_cuda_impl(self, dim, index, source, include_self, SCATTER_GATHER_OP::REDUCE_MINIMUM, reduce_minimum, result);
index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::MIN, reduce_minimum, result);
} else {
TORCH_CHECK(false, "reduce argument must be either prod, mean, amax or amin, got ", reduce, ".");
}

View File

@ -496,13 +496,13 @@ void scatter_add_cuda_kernel(const Tensor& self, int64_t dim, const Tensor& inde
}
void scatter_reduce_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
const Tensor& src, const SCATTER_GATHER_OP& reduce) {
const Tensor& src, const ReductionType& reduce) {
switch (reduce) {
case SCATTER_GATHER_OP::REDUCE_ADD :
case ReductionType::SUM :
cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
"scatter_reduce_cuda_add_", reduce_add);
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
case ReductionType::PROD :
cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
"scatter_reduce_cuda_multiply_", reduce_multiply);
break;
@ -512,26 +512,26 @@ void scatter_reduce_cuda_kernel(const Tensor& self, const int64_t dim, const Ten
}
void scatter_reduce_two_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
const Tensor& src, const SCATTER_GATHER_OP& reduce) {
const Tensor& src, const ReductionType& reduce) {
globalContext().alertNotDeterministic("scatter_reduce_cuda");
switch (reduce) {
case SCATTER_GATHER_OP::REDUCE_ADD :
case ReductionType::SUM :
cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
"scatter_reduce_cuda_sum_", reduce_add);
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
case ReductionType::PROD :
cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
"scatter_reduce_cuda_prod_", reduce_multiply);
break;
case SCATTER_GATHER_OP::REDUCE_MAXIMUM :
case ReductionType::MAX :
cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
"scatter_reduce_cuda_amax_", reduce_maximum);
break;
case SCATTER_GATHER_OP::REDUCE_MINIMUM :
case ReductionType::MIN :
cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
"scatter_reduce_cuda_amin_", reduce_minimum);
break;
case SCATTER_GATHER_OP::REDUCE_MEAN :
case ReductionType::MEAN :
cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
"scatter_reduce_cuda_mean_", reduce_mean);
break;
@ -539,13 +539,13 @@ void scatter_reduce_two_cuda_kernel(const Tensor& self, const int64_t dim, const
}
void scatter_scalar_reduce_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
const Scalar& value, const SCATTER_GATHER_OP& reduce) {
const Scalar& value, const ReductionType& reduce) {
switch (reduce) {
case SCATTER_GATHER_OP::REDUCE_ADD :
case ReductionType::SUM :
cuda_scatter_fill_base_kernel<false>()(self, dim, index, value,
"scatter_fill_cuda_add_", reduce_add);
break;
case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
case ReductionType::PROD :
cuda_scatter_fill_base_kernel<false>()(self, dim, index, value,
"scatter_fill_cuda_multiply_", reduce_multiply);
break;

View File

@ -104,7 +104,7 @@ __global__ static void post_sum_div_kernel(
template <typename scalar_t, typename index_t>
__global__ void segment_reduce_forward_kernel(
SegmentReductionType reduction,
ReductionType reduction,
scalar_t* output_data,
scalar_t* values_data,
const index_t* lengths_data,
@ -139,18 +139,18 @@ __global__ void segment_reduce_forward_kernel(
+ j * data_stride_axis + lane_id;
const auto data = values_data[data_index];
// TODO: There is no need to branch with every element
if (reduction == SegmentReductionType::MAX) {
if (reduction == ReductionType::MAX) {
initial_value =
at::_isnan(data) ? data : std::max<scalar_t>(initial_value, data);
} else if (
reduction == SegmentReductionType::MEAN ||
reduction == SegmentReductionType::SUM) {
reduction == ReductionType::MEAN ||
reduction == ReductionType::SUM) {
initial_value = initial_value + data;
} else if (reduction == SegmentReductionType::MIN) {
} else if (reduction == ReductionType::MIN) {
initial_value =
at::_isnan(data) ? data : std::min<scalar_t>(initial_value, data);
} else if (
reduction == SegmentReductionType::PROD) {
reduction == ReductionType::PROD) {
initial_value = initial_value * data;
}
}
@ -159,10 +159,10 @@ __global__ void segment_reduce_forward_kernel(
int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
CUDA_KERNEL_ASSERT(lengths_data[lengths_idx] >= 0);
if (lengths_data[lengths_idx] == 0 && !is_initial_set &&
reduction == SegmentReductionType::MEAN) {
reduction == ReductionType::MEAN) {
initial_value = static_cast<scalar_t>(NAN);
} else if (
reduction == SegmentReductionType::MEAN && lengths_data[lengths_idx] > 0 &&
reduction == ReductionType::MEAN && lengths_data[lengths_idx] > 0 &&
!at::_isnan(initial_value)) {
initial_value = initial_value / lengths_data[lengths_idx];
}
@ -174,7 +174,7 @@ __global__ void segment_reduce_forward_kernel(
template <typename scalar_t, typename index_t>
__global__ void segment_reduce_backward_kernel(
SegmentReductionType reduction,
ReductionType reduction,
scalar_t* grad_input_data,
scalar_t* grad_data,
scalar_t* output_data,
@ -213,8 +213,8 @@ __global__ void segment_reduce_backward_kernel(
int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+ dim_idx * output_stride_axis + lane_id;
if (reduction == SegmentReductionType::MAX ||
reduction == SegmentReductionType::MIN) {
if (reduction == ReductionType::MAX ||
reduction == ReductionType::MIN) {
int64_t counter = 0;
for (int64_t j = offset_start; j < offset_end; ++j) {
int64_t data_index = outer_idx * data_stride_axis * data_size_axis
@ -238,21 +238,21 @@ __global__ void segment_reduce_backward_kernel(
grad_input_data[data_index] / counter;
}
}
} else if (reduction == SegmentReductionType::MEAN) {
} else if (reduction == ReductionType::MEAN) {
auto grad_val = grad_data[output_index] / segment_length;
for (int64_t j = offset_start; j < offset_end; ++j) {
int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ j * data_stride_axis + lane_id;
grad_input_data[data_index] = grad_val;
}
} else if (reduction == SegmentReductionType::SUM) {
} else if (reduction == ReductionType::SUM) {
const auto& grad_val = grad_data[output_index];
for (int64_t j = offset_start; j < offset_end; ++j) {
int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ j * data_stride_axis + lane_id;
grad_input_data[data_index] = grad_val;
}
} else if (reduction == SegmentReductionType::PROD) {
} else if (reduction == ReductionType::PROD) {
const auto& grad_val = grad_data[output_index] * output_data[output_index];
for (int64_t j = offset_start; j < offset_end; ++j) {
int64_t data_index = outer_idx * data_stride_axis * data_size_axis
@ -282,7 +282,7 @@ Tensor _segment_reduce_lengths_offsets_backward_cuda_kernel(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
SegmentReductionType reduction,
ReductionType reduction,
const Tensor& lengths_or_offsets_contig,
int64_t axis,
const c10::optional<Scalar>& initial,
@ -385,7 +385,7 @@ Tensor _segment_reduce_lengths_backward_cuda_kernel(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
SegmentReductionType reduction,
ReductionType reduction,
const Tensor& lengths_contig,
int64_t axis,
const c10::optional<Scalar>& initial) {
@ -397,7 +397,7 @@ Tensor _segment_reduce_offsets_backward_cuda_kernel(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
SegmentReductionType reduction,
ReductionType reduction,
const Tensor& offsets_contig,
int64_t axis,
const c10::optional<Scalar>& initial) {
@ -406,7 +406,7 @@ Tensor _segment_reduce_offsets_backward_cuda_kernel(
}
Tensor _segment_reduce_lengths_offsets_cuda_kernel(
SegmentReductionType reduction,
ReductionType reduction,
const Tensor& data,
const Tensor& lengths_or_offsets,
int64_t axis,
@ -474,15 +474,15 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
scalar_t initial_value;
if (initial.has_value()) {
initial_value = initial.value().to<scalar_t>();
} else if (reduction == SegmentReductionType::MAX) {
} else if (reduction == ReductionType::MAX) {
initial_value = -std::numeric_limits<scalar_t>::infinity();
} else if (
reduction == SegmentReductionType::MEAN ||
reduction == SegmentReductionType::SUM) {
reduction == ReductionType::MEAN ||
reduction == ReductionType::SUM) {
initial_value = 0;
} else if (reduction == SegmentReductionType::MIN) {
} else if (reduction == ReductionType::MIN) {
initial_value = std::numeric_limits<scalar_t>::infinity();
} else if (reduction == SegmentReductionType::PROD) {
} else if (reduction == ReductionType::PROD) {
initial_value = 1;
}
@ -511,7 +511,7 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
if (reduction == SegmentReductionType::MAX) {
if (reduction == ReductionType::MAX) {
CustomMax max_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
@ -523,7 +523,7 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
max_op,
initial_value,
at::cuda::getCurrentCUDAStream());
} else if (reduction == SegmentReductionType::MEAN) {
} else if (reduction == ReductionType::MEAN) {
CustomSum sum_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
@ -547,7 +547,7 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
initial.has_value(),
initial_value);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else if (reduction == SegmentReductionType::MIN) {
} else if (reduction == ReductionType::MIN) {
CustomMin min_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
@ -559,7 +559,7 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
min_op,
initial_value,
at::cuda::getCurrentCUDAStream());
} else if (reduction == SegmentReductionType::SUM) {
} else if (reduction == ReductionType::SUM) {
CustomSum sum_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
@ -571,7 +571,7 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
sum_op,
initial_value,
at::cuda::getCurrentCUDAStream());
} else if (reduction == SegmentReductionType::PROD) {
} else if (reduction == ReductionType::PROD) {
CustomProd prod_op{};
CUB_WRAPPER(
cub::DeviceSegmentedReduce::Reduce,
@ -592,7 +592,7 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
}
Tensor _segment_reduce_lengths_cuda_kernel(
SegmentReductionType reduction,
ReductionType reduction,
const Tensor& data,
const Tensor& lengths,
int64_t axis,
@ -602,7 +602,7 @@ Tensor _segment_reduce_lengths_cuda_kernel(
}
Tensor _segment_reduce_offsets_cuda_kernel(
SegmentReductionType reduction,
ReductionType reduction,
const Tensor& data,
const Tensor& offsets,
int64_t axis,

View File

@ -18,17 +18,17 @@ from torch.testing._internal.common_utils import (
)
reductions = ["max", "mean", "min", "sum", "prod"]
reductions = ["amax", "mean", "amin", "sum", "prod"]
def get_default_value(initial_value, reduction):
if initial_value is not None:
return initial_value
if reduction == "max":
if reduction == "amax":
return -float("Inf")
elif reduction == "mean":
return float("nan")
elif reduction == "min":
elif reduction == "amin":
return float("Inf")
elif reduction == "sum":
return 0.0
@ -133,13 +133,13 @@ class TestSegmentReductions(TestCase):
check_backward = True if initial is not None else False
initial_value = initial
default_value = get_default_value(initial_value, reduction)
if reduction == "max":
if reduction == "amax":
expected_result = [1, float("nan"), 5, default_value]
expected_grad = [1, 1, 0, 0, 0.5, 0.5]
elif reduction == "mean":
expected_result = [1, float("nan"), 4.666, default_value]
expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
elif reduction == "min":
elif reduction == "amin":
if initial is not None:
initial_value = 1000 # some high number
default_value = get_default_value(initial_value, reduction)
@ -191,7 +191,7 @@ class TestSegmentReductions(TestCase):
check_backward = True if initial is not None else False
initial_value = initial
default_value = get_default_value(initial_value, reduction)
if reduction == "max":
if reduction == "amax":
expected_result = [
[1, 1],
[float("nan"), float("nan")],
@ -221,7 +221,7 @@ class TestSegmentReductions(TestCase):
[0.333, 0.333],
[0.333, 0.333],
]
elif reduction == "min":
elif reduction == "amin":
if initial is not None:
initial_value = 1000 # some high number
default_value = get_default_value(initial_value, reduction)
@ -308,7 +308,7 @@ class TestSegmentReductions(TestCase):
(torch.int, torch.int64),
)
)
@parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
@parametrize("reduce", ['sum', 'prod', 'amin', 'amax', 'mean'])
def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
val_dtype, length_dtype = dtypes
# zero-length segments are filled with reduction inits contrary to pytorch_scatter.
@ -320,8 +320,8 @@ class TestSegmentReductions(TestCase):
'sum': [3, 12, 0, 6],
'prod': [2, 60, 1, 6],
'mean': [1.5, 4, float('nan'), 6],
'min': [1, 3, float('inf'), 6],
'max': [2, 5, -float('inf'), 6],
'amin': [1, 3, float('inf'), 6],
'amax': [2, 5, -float('inf'), 6],
},
{
'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
@ -330,8 +330,8 @@ class TestSegmentReductions(TestCase):
'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
'prod': [[3, 8], [315, 480], [1, 1], [11, 12]],
'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]],
'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
'amin': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
'amax': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
},
{
'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
@ -340,8 +340,8 @@ class TestSegmentReductions(TestCase):
'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
'prod': [[3, 315, 1, 11], [48, 80, 12, 1]],
'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]],
'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
'amin': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
'amax': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
},
{
'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
@ -351,10 +351,10 @@ class TestSegmentReductions(TestCase):
'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]],
'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]],
[[7, 9], [float('nan'), float('nan')], [11, 12]]],
'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
[[7, 9], [float('inf'), float('inf')], [10, 11]]],
'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
[[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
'amin': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
[[7, 9], [float('inf'), float('inf')], [10, 11]]],
'amax': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
[[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
},
{
'src': [[1, 3], [2, 4]],
@ -363,8 +363,8 @@ class TestSegmentReductions(TestCase):
'sum': [[4], [6]],
'prod': [[3], [8]],
'mean': [[2], [3]],
'min': [[1], [2]],
'max': [[3], [4]],
'amin': [[1], [2]],
'amax': [[3], [4]],
},
{
'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
@ -373,8 +373,8 @@ class TestSegmentReductions(TestCase):
'sum': [[[4, 4]], [[6, 6]]],
'prod': [[[3, 3]], [[8, 8]]],
'mean': [[[2, 2]], [[3, 3]]],
'min': [[[1, 1]], [[2, 2]]],
'max': [[[3, 3]], [[4, 4]]],
'amin': [[[1, 1]], [[2, 2]]],
'amax': [[[3, 3]], [[4, 4]]],
},
]
for test in tests:
@ -409,9 +409,9 @@ class TestSegmentReductions(TestCase):
initial = 1
# supply initial values to prevent gradcheck from failing for 0 length segments
# where nan/inf are reduction identities that produce nans when calculating the numerical jacobian
if reduce == 'min':
if reduce == 'amin':
initial = 1000
elif reduce == 'max':
elif reduce == 'amax':
initial = -1000
segment_reduce_args = {x, reduce}
segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
@ -442,7 +442,7 @@ class TestSegmentReductions(TestCase):
for reduction in reductions:
initial_value = 0
if reduction == "max":
if reduction == "amax":
expected_result = [
np.full((2, 5), initial_value).tolist(),
np.max(data[:2], axis=0).tolist(),
@ -456,7 +456,7 @@ class TestSegmentReductions(TestCase):
np.mean(data[2:], axis=0).tolist(),
np.full((2, 5), initial_value).tolist(),
]
elif reduction == "min":
elif reduction == "amin":
initial_value = 1000 # some high number
expected_result = [
np.full((2, 5), initial_value).tolist(),

View File

@ -783,9 +783,6 @@ def _sparse_csr_segment_reduction_helper(
)
new_nnz = new_crow_indices[-1]
new_col_indices = col_indices.new_zeros(new_nnz)
# segment_reduce takes 'max'/'min' rather than 'amax'/'amin', changing this would be BC-breaking
if reduce in ["amax", "amin"]:
reduce = reduce[1:]
new_values = torch.segment_reduce(values, reduce, offsets=crow_indices)
new_shape = [mask_input.size(0), 1]
else:

View File

@ -6188,7 +6188,7 @@ def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode=
((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
)
reductions = ["max", "mean", "min", "sum", "prod"]
reductions = ["amax", "mean", "amin", "sum", "prod"]
for args, reduce, initial in product(test_cases, reductions, [1, 2]):
inp_shape, dim, lengths, unsafe = args
lengths_t = torch.tensor(lengths, dtype=torch.long, device=device)