mirror of https://github.com/zebrajr/pytorch.git
unify reduction types from different operators: scatter, scatter_reduce, segment_reduce (#91499)
The target of this PR is to unify `ReductionType` across the reduce operators (scatter, scatter_reduce, segment_reduce) so that they share the same set of reduce utilities, e.g. `init` and `update`, for vectorization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91499
Approved by: https://github.com/ngimel
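To illustrate what the unification buys (a minimal standalone sketch, not code from this PR; the helper names `reduce_init` and `reduce_update` are invented for this example), every reduce operator can share one identity-value helper and one combine helper keyed on a single enum:

#include <limits>

// Illustrative stand-in for the unified enum introduced by this PR.
enum class ReductionType { MAX, MEAN, MIN, SUM, PROD };

// Shared "init": the identity element each operator starts its accumulator from.
template <typename T>
T reduce_init(ReductionType op) {
  switch (op) {
    case ReductionType::SUM:
    case ReductionType::MEAN: return T(0);
    case ReductionType::PROD: return T(1);
    case ReductionType::MAX:  return std::numeric_limits<T>::lowest();
    case ReductionType::MIN:  return std::numeric_limits<T>::max();
  }
  return T(0);
}

// Shared "update": how each operator folds the next value into the accumulator.
template <typename T>
T reduce_update(ReductionType op, T acc, T val) {
  switch (op) {
    case ReductionType::SUM:
    case ReductionType::MEAN: return acc + val;
    case ReductionType::PROD: return acc * val;
    case ReductionType::MAX:  return val > acc ? val : acc;
    case ReductionType::MIN:  return val < acc ? val : acc;
  }
  return acc;
}

With one `ReductionType`, scatter, scatter_reduce and segment_reduce no longer each carry their own enum (`SCATTER_GATHER_OP`, `SegmentReductionType`) and their own copies of such helpers, which is what the diff below removes.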
This commit is contained in:
  parent a70387f0fa
  commit eb7b89771e
aten/src/ATen/native/ReductionType.h (new file, 40 lines added)

@@ -0,0 +1,40 @@
+#pragma once
+
+#include <c10/core/Scalar.h>
+
+namespace at { namespace native {
+
+enum ReductionType {MAX, MEAN, MIN, SUM, PROD};
+
+static inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
+  if (reduce == "amax") {
+    return ReductionType::MAX;
+  } else if (reduce == "mean") {
+    return ReductionType::MEAN;
+  } else if (reduce == "amin") {
+    return ReductionType::MIN;
+  } else if (reduce == "sum") {
+    return ReductionType::SUM;
+  } else if (reduce == "prod") {
+    return ReductionType::PROD;
+  } else {
+    TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce);
+  }
+}
+
+// used for `scatter_reduce`, old options for BC.
+static inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
+  if (use_new_options) {
+    return get_reduction_enum(reduce);
+  } else {
+    if (reduce == "add") {
+      return ReductionType::SUM;
+    } else if (reduce == "multiply") {
+      return ReductionType::PROD;
+    } else {
+      TORCH_CHECK(false, "reduce argument must be either add or multiply.")
+    }
+  }
+}
+
+}} // at::native
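A hedged usage sketch of the new header (the wrapper `parse_reduce` is hypothetical, not part of the commit): operators parse the user-facing `reduce` string once and branch on the shared enum afterwards.

#include <ATen/native/ReductionType.h>

namespace {
// Hypothetical helper showing the two entry points side by side.
at::native::ReductionType parse_reduce(const c10::string_view reduce, bool use_new_options) {
  // use_new_options=true accepts "sum"/"prod"/"mean"/"amax"/"amin";
  // use_new_options=false keeps the legacy "add"/"multiply" spellings for BC.
  return at::native::get_operator_enum(reduce, use_new_options);
}
} // namespace

Under these assumptions, parse_reduce("amax", true) yields ReductionType::MAX and parse_reduce("add", false) yields ReductionType::SUM; any other string fails the TORCH_CHECK in the header.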

@@ -28,25 +28,9 @@ DEFINE_DISPATCH(_segment_reduce_offsets_backward_stub);
 namespace {
 
-SegmentReductionType get_reduction_enum(const c10::string_view& reduce) {
-  if (reduce == "max") {
-    return SegmentReductionType::MAX;
-  } else if (reduce == "mean") {
-    return SegmentReductionType::MEAN;
-  } else if (reduce == "min") {
-    return SegmentReductionType::MIN;
-  } else if (reduce == "sum") {
-    return SegmentReductionType::SUM;
-  } else if (reduce == "prod") {
-    return SegmentReductionType::PROD;
-  } else {
-    TORCH_CHECK(false, "unsupported reduction given! ", reduce);
-  }
-}
-
 template <typename T, bool is_offsets_like=false>
 void _segment_reduce_lengths_cpu_kernel1(
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& data,
     const T* lengths_data,
     int64_t axis,
@@ -90,15 +74,15 @@ void _segment_reduce_lengths_cpu_kernel1(
       scalar_t initial_value;
       if (initial.has_value()) {
         initial_value = initial.value().to<scalar_t>();
-      } else if (reduction == SegmentReductionType::MAX) {
+      } else if (reduction == ReductionType::MAX) {
         initial_value = -std::numeric_limits<scalar_t>::infinity();
       } else if (
-          reduction == SegmentReductionType::MEAN ||
-          reduction == SegmentReductionType::SUM) {
+          reduction == ReductionType::MEAN ||
+          reduction == ReductionType::SUM) {
        initial_value = 0;
-      } else if (reduction == SegmentReductionType::MIN) {
+      } else if (reduction == ReductionType::MIN) {
        initial_value = std::numeric_limits<scalar_t>::infinity();
-      } else if (reduction == SegmentReductionType::PROD) {
+      } else if (reduction == ReductionType::PROD) {
        initial_value = 1;
       }
 
@@ -107,19 +91,19 @@ void _segment_reduce_lengths_cpu_kernel1(
         int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                              + j * data_stride_axis + inner_idx;
         const auto val = values_data[data_index];
-        if (reduction == SegmentReductionType::MAX) {
+        if (reduction == ReductionType::MAX) {
           initial_value = at::_isnan(val)
               ? val
               : std::max<scalar_t>(initial_value, val);
         } else if (
-            reduction == SegmentReductionType::MEAN ||
-            reduction == SegmentReductionType::SUM) {
+            reduction == ReductionType::MEAN ||
+            reduction == ReductionType::SUM) {
           initial_value = initial_value + val;
-        } else if (reduction == SegmentReductionType::MIN) {
+        } else if (reduction == ReductionType::MIN) {
           initial_value = at::_isnan(val)
               ? val
               : std::min<scalar_t>(initial_value, val);
-        } else if (reduction == SegmentReductionType::PROD) {
+        } else if (reduction == ReductionType::PROD) {
           initial_value = initial_value * val;
         }
       }
@@ -128,10 +112,10 @@ void _segment_reduce_lengths_cpu_kernel1(
       TORCH_CHECK(segment_length >= 0);
 
       if (segment_length == 0 && !initial.has_value() &&
-          reduction == SegmentReductionType::MEAN) {
+          reduction == ReductionType::MEAN) {
         initial_value = static_cast<scalar_t>(NAN);
       } else if (
-          reduction == SegmentReductionType::MEAN &&
+          reduction == ReductionType::MEAN &&
           segment_length > 0 && !at::_isnan(initial_value)) {
         initial_value = initial_value / segment_length;
       }
@@ -145,7 +129,7 @@ void _segment_reduce_lengths_cpu_kernel1(
 }
 
 Tensor _segment_reduce_lengths_cpu_kernel(
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& data,
     const Tensor& lengths,
     int64_t axis,
@@ -171,7 +155,7 @@ Tensor _segment_reduce_lengths_cpu_kernel(
 }
 
 Tensor _segment_reduce_offsets_cpu_kernel(
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& data,
     const Tensor& offsets,
     int64_t axis,
@@ -201,7 +185,7 @@ void _segment_reduce_cpu_lengths_backward_kernel1(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const T* lengths_data,
     int64_t axis,
     const c10::optional<Scalar>& initial,
@@ -234,7 +218,7 @@ void _segment_reduce_cpu_lengths_backward_kernel1(
   const auto* values_data = data_contig.data_ptr<scalar_t>();
   // Used to calculate exclusive prod
   scalar_t initial_prod_value;
-  if (reduction == SegmentReductionType::PROD) {
+  if (reduction == ReductionType::PROD) {
     if (initial.has_value()) {
       initial_prod_value = initial.value().to<scalar_t>();
     } else {
@@ -265,8 +249,8 @@ void _segment_reduce_cpu_lengths_backward_kernel1(
       for (const auto inner_idx : c10::irange(inner_offset)) {
         int64_t output_index = outer_idx * output_stride_axis * output_size_axis
                                + dim_idx * output_stride_axis + inner_idx;
-        if (reduction == SegmentReductionType::MAX ||
-            reduction == SegmentReductionType::MIN) {
+        if (reduction == ReductionType::MAX ||
+            reduction == ReductionType::MIN) {
           int64_t counter = 0;
           for (const auto j : c10::irange(segment_start, segment_end)) {
             int64_t data_index = outer_idx * data_stride_axis * data_size_axis
@@ -290,21 +274,21 @@ void _segment_reduce_cpu_lengths_backward_kernel1(
                   grad_input_data[data_index] / counter;
             }
           }
-        } else if (reduction == SegmentReductionType::MEAN) {
+        } else if (reduction == ReductionType::MEAN) {
           auto grad_val = grad_data[output_index] / segment_length;
           for (const auto j : c10::irange(segment_start, segment_end)) {
             int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                                  + j * data_stride_axis + inner_idx;
             grad_input_data[data_index] = grad_val;
           }
-        } else if (reduction == SegmentReductionType::SUM) {
+        } else if (reduction == ReductionType::SUM) {
           const auto& grad_val = grad_data[output_index];
           for (const auto j : c10::irange(segment_start, segment_end)) {
             int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                                  + j * data_stride_axis + inner_idx;
             grad_input_data[data_index] = grad_val;
           }
-        } else if (reduction == SegmentReductionType::PROD) {
+        } else if (reduction == ReductionType::PROD) {
          const auto& grad_val = grad_data[output_index] * output_data[output_index];
          for (const auto j : c10::irange(segment_start, segment_end)) {
            int64_t data_index = outer_idx * data_stride_axis * data_size_axis
@@ -337,7 +321,7 @@ Tensor _segment_reduce_cpu_lengths_backward_kernel(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& lengths_contig,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
@@ -370,7 +354,7 @@ Tensor _segment_reduce_cpu_offsets_backward_kernel(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& offsets_contig,
     int64_t axis,
     const c10::optional<Scalar>& initial) {

@@ -1,6 +1,7 @@
 #pragma once
 
 #include <ATen/native/DispatchStub.h>
+#include <ATen/native/ReductionType.h>
 #include <c10/core/Scalar.h>
 #include <c10/util/Optional.h>
 
@@ -9,10 +10,8 @@ class Tensor;
 
 namespace native {
 
-enum SegmentReductionType { MAX, MEAN, MIN, SUM, PROD};
-
 using segment_reduce_lengths_fn = Tensor (*)(
-    SegmentReductionType,
+    ReductionType,
     const Tensor&,
     const Tensor&,
     int64_t,
@@ -20,7 +19,7 @@ using segment_reduce_lengths_fn = Tensor (*)(
 DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
 
 using segment_reduce_offsets_fn = Tensor (*)(
-    SegmentReductionType,
+    ReductionType,
     const Tensor&,
     const Tensor&,
     int64_t,
@@ -31,7 +30,7 @@ using segment_reduce_lengths_backward_fn = Tensor (*)(
     const Tensor&,
     const Tensor&,
     const Tensor&,
-    SegmentReductionType,
+    ReductionType,
     const Tensor&,
     int64_t,
     const c10::optional<Scalar>&);
@@ -41,7 +40,7 @@ using segment_reduce_offsets_backward_fn = Tensor (*)(
     const Tensor&,
     const Tensor&,
     const Tensor&,
-    SegmentReductionType,
+    ReductionType,
     const Tensor&,
     int64_t,
     const c10::optional<Scalar>&);

@@ -148,32 +148,6 @@ AdvancedIndex make_info(Tensor self, IOptTensorListRef orig);
 namespace meta {
 
-native::SCATTER_GATHER_OP get_operator_enum(const c10::string_view reduce, bool use_new_options = false) {
-  if (use_new_options) {
-    if (reduce == "sum") {
-      return native::SCATTER_GATHER_OP::REDUCE_ADD;
-    } else if (reduce == "prod") {
-      return native::SCATTER_GATHER_OP::REDUCE_MULTIPLY;
-    } else if (reduce == "mean") {
-      return native::SCATTER_GATHER_OP::REDUCE_MEAN;
-    } else if (reduce == "amax") {
-      return native::SCATTER_GATHER_OP::REDUCE_MAXIMUM;
-    } else if (reduce == "amin") {
-      return native::SCATTER_GATHER_OP::REDUCE_MINIMUM;
-    } else {
-      TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin.");
-    }
-  } else {
-    if (reduce == "add") {
-      return native::SCATTER_GATHER_OP::REDUCE_ADD;
-    } else if (reduce == "multiply") {
-      return native::SCATTER_GATHER_OP::REDUCE_MULTIPLY;
-    } else {
-      TORCH_CHECK(false, "reduce argument must be either add or multiply.")
-    }
-  }
-}
-
 TORCH_META_FUNC(gather)
 (const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) {
   const Tensor& result = maybe_get_output(0);
@@ -227,7 +201,7 @@ void scatter_meta_impl(
   meta.set_output_raw_strided(0, self.sizes(), {}, self.options());
   if (reduce.has_value()) {
     // Check if we have a valid reduce operator.
-    get_operator_enum(reduce.value(), use_new_options);
+    at::native::get_operator_enum(reduce.value(), use_new_options);
   }
 }
 
@@ -975,7 +949,7 @@ void index_reduce_func_impl(
   const Tensor& source,
   bool include_self,
   const Tensor& result,
-  const SCATTER_GATHER_OP& op) {
+  const ReductionType& op) {
   if (!result.is_same(self)) result.copy_(self);
   if (!include_self) {
     AT_DISPATCH_ALL_TYPES_AND2(
@@ -983,14 +957,14 @@ void index_reduce_func_impl(
       self.scalar_type(), "index_reduce_func_exclude_input_init", [&] {
       scalar_t init_val;
       switch (op) {
-        case SCATTER_GATHER_OP::REDUCE_MULTIPLY:
+        case ReductionType::PROD:
           init_val = (scalar_t)1;
           break;
-        case SCATTER_GATHER_OP::REDUCE_MAXIMUM:
+        case ReductionType::MAX:
          init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity()
                     : std::numeric_limits<scalar_t>::lowest();
          break;
-        case SCATTER_GATHER_OP::REDUCE_MINIMUM:
+        case ReductionType::MIN:
          init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity()
                     : std::numeric_limits<scalar_t>::max();
          break;
@@ -1037,13 +1011,13 @@ void index_reduce_func_impl(
       iter.unsafe_replace_operand(2, source_data);
 
       switch (op) {
-        case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
+        case ReductionType::PROD :
           mul_stub(iter.device_type(), iter);
           break;
-        case SCATTER_GATHER_OP::REDUCE_MINIMUM :
+        case ReductionType::MIN :
          minimum_stub(iter.device_type(), iter);
          break;
-        case SCATTER_GATHER_OP::REDUCE_MAXIMUM :
+        case ReductionType::MAX :
          maximum_stub(iter.device_type(), iter);
          break;
        default :
@@ -1053,7 +1027,7 @@ void index_reduce_func_impl(
       }
     });
 
-    if (op == SCATTER_GATHER_OP::REDUCE_MEAN) {
+    if (op == ReductionType::MEAN) {
       auto counts = include_self ? at::ones_like(result) : at::zeros_like(result);
       counts.index_add_(dim, index, at::ones_like(source));
       counts.masked_fill_(counts == 0, 1);
@@ -1088,19 +1062,19 @@ void index_reduce_func_impl(
             scalar_t *count_ip;
             scalar_t val;
             switch (op) {
-              case SCATTER_GATHER_OP::REDUCE_MEAN :
+              case ReductionType::MEAN :
                 *self_ip += *(source_ptr + i * source_stride);
                 count_ip = counts_ptr + self_i * counts_stride;
                 *count_ip += 1;
                 break;
-              case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
+              case ReductionType::PROD :
                *self_ip *= *(source_ptr + i * source_stride);
                break;
-              case SCATTER_GATHER_OP::REDUCE_MINIMUM :
+              case ReductionType::MIN :
                val = *(source_ptr + i * source_stride);
                *self_ip = at::_isnan<scalar_t>(val) ? val : std::min(*self_ip, val);
                break;
-              case SCATTER_GATHER_OP::REDUCE_MAXIMUM :
+              case ReductionType::MAX :
                val = *(source_ptr + i * source_stride);
                *self_ip = at::_isnan<scalar_t>(val) ? val : std::max(*self_ip, val);
                break;
@@ -1110,7 +1084,7 @@ void index_reduce_func_impl(
           }
       });
     });
-    if (op == SCATTER_GATHER_OP::REDUCE_MEAN) {
+    if (op == ReductionType::MEAN) {
       counts.masked_fill_(counts == 0, 1);
       if (result.is_floating_point() || result.is_complex()) {
         result.div_(counts);
@@ -1130,7 +1104,7 @@ TORCH_IMPL_FUNC(index_reduce_cpu_out)
  bool include_input,
  const Tensor& result) {
   TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time.");
-  auto op = meta::get_operator_enum(reduce, true);
+  auto op = get_operator_enum(reduce, true);
   index_reduce_func_impl(self, dim, index, source, include_input, result, op);
 }
 
@@ -1511,27 +1485,27 @@ static void scatter_reduce_exclude_self_helper(
   const Tensor& self,
   int64_t dim,
   const Tensor& index,
-  const SCATTER_GATHER_OP& op) {
+  const ReductionType& op) {
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
     at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool,
     self.scalar_type(), "scatter_reduce_exclude_input_init", [&] {
     scalar_t init_val;
     switch (op) {
-      case SCATTER_GATHER_OP::REDUCE_ADD:
+      case ReductionType::SUM:
         init_val = (scalar_t)0;
         break;
-      case SCATTER_GATHER_OP::REDUCE_MULTIPLY:
+      case ReductionType::PROD:
        init_val = (scalar_t)1;
        break;
-      case SCATTER_GATHER_OP::REDUCE_MAXIMUM:
+      case ReductionType::MAX:
        init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity()
                   : std::numeric_limits<scalar_t>::lowest();
        break;
-      case SCATTER_GATHER_OP::REDUCE_MINIMUM:
+      case ReductionType::MIN:
        init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity()
                   : std::numeric_limits<scalar_t>::max();
        break;
-      case SCATTER_GATHER_OP::REDUCE_MEAN:
+      case ReductionType::MEAN:
        init_val = (scalar_t)0;
        break;
     }
@@ -1561,7 +1535,7 @@ void scatter_impl(
   if (index.numel() == 0) return;
 
   if (reduce.has_value()) {
-    auto op = meta::get_operator_enum(reduce.value(), use_new_options);
+    auto op = get_operator_enum(reduce.value(), use_new_options);
     if (!reduce_includes_self) {
       // scatter inits for reduction to appropriate indices (used by scatter_reduce.two)
      scatter_reduce_exclude_self_helper(mut_out, dim, index, op);
@@ -1756,7 +1730,7 @@ TORCH_IMPL_FUNC(scatter_reduce_two)
     out.copy_(self);
   }
 
-  const auto op = meta::get_operator_enum(reduce, true);
+  const auto op = get_operator_enum(reduce, true);
 
   if (can_use_expanded_index_path(out, dim, index, src, /*is_scatter_like*/true)) {
     scatter_reduce_expanded_index_stub(self.device().type(), out, index, src, op, include_self);
@@ -1769,7 +1743,7 @@ TORCH_IMPL_FUNC(scatter_reduce_two)
                  reduce,
                  include_self);
 
-  if (op == SCATTER_GATHER_OP::REDUCE_MEAN) {
+  if (op == ReductionType::MEAN) {
     auto ones = at::ones_like(src);
     auto count = include_self ? at::ones_like(out) : at::zeros_like(out);
     count.scatter_add_(dim, index, ones);

@@ -5,6 +5,7 @@
 #include <ATen/core/List.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/native/DispatchStub.h>
+#include <ATen/native/ReductionType.h>
 #include <ATen/native/cpu/radix_sort.h>
 
 namespace at {
@@ -13,8 +14,6 @@ struct TensorIterator;
 
 namespace at { namespace native {
 
-enum class SCATTER_GATHER_OP: uint8_t {REDUCE_ADD, REDUCE_MULTIPLY, REDUCE_MAXIMUM, REDUCE_MINIMUM, REDUCE_MEAN};
-
 using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
 using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List<c10::optional<Tensor>>& indices, const Tensor& value, double scale, int zero_point, bool unsafe);
 using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
@@ -22,11 +21,11 @@ using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index,
 using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src);
 using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
 using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
-                                  const Tensor& src, const SCATTER_GATHER_OP& reduce);
+                                  const Tensor& src, const ReductionType& reduce);
 using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
-                                         const Scalar& value, const SCATTER_GATHER_OP& reduce);
+                                         const Scalar& value, const ReductionType& reduce);
 using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
-                                      const Tensor& src, const SCATTER_GATHER_OP& reduce);
+                                      const Tensor& src, const ReductionType& reduce);
 
 DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
 DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub);
@@ -95,7 +94,7 @@ static inline bool can_use_expanded_index_path(
 }
 
 using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
-using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const SCATTER_GATHER_OP& reduce, bool);
+using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool);
 using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);
 
 DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);

@@ -573,19 +573,19 @@ struct cpu_scatter_gather_base_kernel {
   }
 };
 
-template <typename scalar_t, SCATTER_GATHER_OP reduce>
+template <typename scalar_t, ReductionType reduce>
 inline void init(scalar_t* ptr, int64_t size, bool include_self) {
   if (!include_self) {
     using acc_t = vec::vec_scalar_t<scalar_t>;
     using Vec = vec::Vectorized<acc_t>;
 
     acc_t val;
-    if (reduce == SCATTER_GATHER_OP::REDUCE_ADD ||
-        reduce == SCATTER_GATHER_OP::REDUCE_MEAN) {
+    if (reduce == ReductionType::SUM ||
+        reduce == ReductionType::MEAN) {
       val = static_cast<acc_t>(0);
-    } else if (reduce == SCATTER_GATHER_OP::REDUCE_MULTIPLY) {
+    } else if (reduce == ReductionType::PROD) {
      val = static_cast<acc_t>(1);
-    } else if (reduce == SCATTER_GATHER_OP::REDUCE_MAXIMUM) {
+    } else if (reduce == ReductionType::MAX) {
      val = std::numeric_limits<acc_t>::lowest();
     } else {
      val = std::numeric_limits<acc_t>::max();
@@ -598,14 +598,14 @@ inline void init(scalar_t* ptr, int64_t size, bool include_self) {
   }
 }
 
-template <typename vec_t, SCATTER_GATHER_OP reduce>
+template <typename vec_t, ReductionType reduce>
 inline vec_t update(const vec_t& x, const vec_t& y) {
-  if (reduce == SCATTER_GATHER_OP::REDUCE_ADD ||
-      reduce == SCATTER_GATHER_OP::REDUCE_MEAN) {
+  if (reduce == ReductionType::SUM ||
+      reduce == ReductionType::MEAN) {
     return x + y;
-  } else if (reduce == SCATTER_GATHER_OP::REDUCE_MULTIPLY) {
+  } else if (reduce == ReductionType::PROD) {
    return x * y;
-  } else if (reduce == SCATTER_GATHER_OP::REDUCE_MAXIMUM) {
+  } else if (reduce == ReductionType::MAX) {
    return vec::maximum(x, y);
   } else {
    return vec::minimum(x, y);
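Because `reduce` is a template parameter of `init` and `update` above, the reduction is chosen at compile time and the inner loop stays branch-free and vectorizable. A rough sketch of the usual dispatch pattern (a standalone example with plain scalars instead of `at::vec::Vectorized`; `reduce_all` and `dispatch_reduce` are invented names, not part of the commit):

#include <algorithm>
#include <limits>
#include <vector>

enum class ReductionType { MAX, MEAN, MIN, SUM, PROD }; // illustrative stand-in

// Compile-time-specialized combine step, analogous to update() above.
template <typename scalar_t, ReductionType reduce>
scalar_t combine(scalar_t x, scalar_t y) {
  if (reduce == ReductionType::SUM || reduce == ReductionType::MEAN) return x + y;
  if (reduce == ReductionType::PROD) return x * y;
  if (reduce == ReductionType::MAX) return std::max(x, y);
  return std::min(x, y);
}

template <typename scalar_t, ReductionType reduce>
scalar_t reduce_all(const std::vector<scalar_t>& v, scalar_t init) {
  scalar_t acc = init;
  for (scalar_t x : v) acc = combine<scalar_t, reduce>(acc, x); // no per-element branch
  return acc;
}

// Runtime enum -> compile-time instantiation, one specialized loop per reduction.
template <typename scalar_t>
scalar_t dispatch_reduce(ReductionType op, const std::vector<scalar_t>& v) {
  switch (op) {
    case ReductionType::SUM:  return reduce_all<scalar_t, ReductionType::SUM>(v, scalar_t(0));
    case ReductionType::MEAN: return reduce_all<scalar_t, ReductionType::MEAN>(v, scalar_t(0)) /
                                     static_cast<scalar_t>(v.empty() ? 1 : v.size());
    case ReductionType::PROD: return reduce_all<scalar_t, ReductionType::PROD>(v, scalar_t(1));
    case ReductionType::MAX:  return reduce_all<scalar_t, ReductionType::MAX>(v, std::numeric_limits<scalar_t>::lowest());
    case ReductionType::MIN:  return reduce_all<scalar_t, ReductionType::MIN>(v, std::numeric_limits<scalar_t>::max());
  }
  return scalar_t(0);
}

This mirrors how the CPU scatter kernels below switch on the runtime `reduce` value and instantiate `cpu_scatter_reduce_expanded_index` once per `ReductionType`.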
@@ -635,7 +635,7 @@ inline vec_t update(const vec_t& x, const vec_t& y) {
 //
 // step 2: spmm reduce, parallel on M and vectorize on K
 //
-template <typename scalar_t, SCATTER_GATHER_OP reduce>
+template <typename scalar_t, ReductionType reduce>
 void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index, const Tensor& src, bool include_self) {
   int64_t* index_data = index.data_ptr<int64_t>();
   scalar_t* self_data = self.data_ptr<scalar_t>();
@@ -735,7 +735,7 @@ void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index,
           K);
     }
 
-    if (reduce == SCATTER_GATHER_OP::REDUCE_MEAN) {
+    if (reduce == ReductionType::MEAN) {
       int64_t count = include_self ? 1 : 0;
       count += off_end - off_start;
       if (count != 0) {
@@ -791,30 +791,30 @@ void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index,
 void scatter_add_expanded_index_kernel(const Tensor& self, const Tensor& index, const Tensor& src) {
   AT_DISPATCH_FLOATING_TYPES_AND(
     ScalarType::BFloat16, self.scalar_type(), "scatter_add_expanded_index", [&] {
-      cpu_scatter_reduce_expanded_index<scalar_t, SCATTER_GATHER_OP::REDUCE_ADD>(self, index, src, /*include_self*/true);
+      cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::SUM>(self, index, src, /*include_self*/true);
   });
 }
 
 void scatter_reduce_expanded_index_kernel(
     const Tensor& self, const Tensor& index, const Tensor& src,
-    const SCATTER_GATHER_OP& reduce, bool include_self) {
+    const ReductionType& reduce, bool include_self) {
   AT_DISPATCH_FLOATING_TYPES_AND(
     ScalarType::BFloat16, self.scalar_type(), "scatter_reduce_expanded_index", [&] {
       switch (reduce) {
-        case SCATTER_GATHER_OP::REDUCE_ADD :
-          cpu_scatter_reduce_expanded_index<scalar_t, SCATTER_GATHER_OP::REDUCE_ADD>(self, index, src, include_self);
+        case ReductionType::SUM :
+          cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::SUM>(self, index, src, include_self);
           break;
-        case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
-          cpu_scatter_reduce_expanded_index<scalar_t, SCATTER_GATHER_OP::REDUCE_MULTIPLY>(self, index, src, include_self);
+        case ReductionType::PROD :
+          cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::PROD>(self, index, src, include_self);
          break;
-        case SCATTER_GATHER_OP::REDUCE_MAXIMUM :
-          cpu_scatter_reduce_expanded_index<scalar_t, SCATTER_GATHER_OP::REDUCE_MAXIMUM>(self, index, src, include_self);
+        case ReductionType::MAX :
+          cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::MAX>(self, index, src, include_self);
          break;
-        case SCATTER_GATHER_OP::REDUCE_MINIMUM :
-          cpu_scatter_reduce_expanded_index<scalar_t, SCATTER_GATHER_OP::REDUCE_MINIMUM>(self, index, src, include_self);
+        case ReductionType::MIN :
+          cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::MIN>(self, index, src, include_self);
          break;
-        case SCATTER_GATHER_OP::REDUCE_MEAN :
-          cpu_scatter_reduce_expanded_index<scalar_t, SCATTER_GATHER_OP::REDUCE_MEAN>(self, index, src, include_self);
+        case ReductionType::MEAN :
+          cpu_scatter_reduce_expanded_index<scalar_t, ReductionType::MEAN>(self, index, src, include_self);
          break;
       }
   });
@@ -850,13 +850,13 @@ void scatter_add_cpu_kernel(const Tensor& self, int64_t dim, const Tensor& index
 }
 
 void scatter_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
-                               const Tensor& src, const SCATTER_GATHER_OP& reduce) {
+                               const Tensor& src, const ReductionType& reduce) {
   switch (reduce) {
-    case SCATTER_GATHER_OP::REDUCE_ADD :
+    case ReductionType::SUM :
       cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                          "scatter_reduce_add_", reduce_add);
       break;
-    case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
+    case ReductionType::PROD :
      cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                         "scatter_reduce_multiply_", reduce_multiply);
      break;
@@ -866,25 +866,25 @@ void scatter_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tens
 }
 
 void scatter_reduce_two_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
-                                   const Tensor& src, const SCATTER_GATHER_OP& reduce) {
+                                   const Tensor& src, const ReductionType& reduce) {
   switch (reduce) {
-    case SCATTER_GATHER_OP::REDUCE_ADD :
+    case ReductionType::SUM :
       cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                          "scatter_reduce_sum_", reduce_add);
       break;
-    case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
+    case ReductionType::PROD :
      cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                         "scatter_reduce_prod_", reduce_multiply);
      break;
-    case SCATTER_GATHER_OP::REDUCE_MAXIMUM :
+    case ReductionType::MAX :
      cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                         "scatter_reduce_amax_", reduce_maximum);
      break;
-    case SCATTER_GATHER_OP::REDUCE_MINIMUM :
+    case ReductionType::MIN :
      cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                         "scatter_reduce_amin_", reduce_minimum);
      break;
-    case SCATTER_GATHER_OP::REDUCE_MEAN :
+    case ReductionType::MEAN :
      cpu_scatter_gather_base_kernel<>()(self, dim, index, src,
                                         "scatter_reduce_mean_", reduce_mean);
      break;
@@ -892,13 +892,13 @@ void scatter_reduce_two_cpu_kernel(const Tensor& self, const int64_t dim, const
 }
 
 void scatter_scalar_reduce_cpu_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
-                                      const Scalar& value, const SCATTER_GATHER_OP& reduce) {
+                                      const Scalar& value, const ReductionType& reduce) {
   switch (reduce) {
-    case SCATTER_GATHER_OP::REDUCE_ADD :
+    case ReductionType::SUM :
       cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
                                          "scatter_scalar_reduce_add_", reduce_add);
       break;
-    case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
+    case ReductionType::PROD :
      cpu_scatter_gather_base_kernel<>()(self, dim, index, value,
                                         "scatter_scalar_reduce_multiply_", reduce_multiply);
      break;

@@ -865,7 +865,7 @@ void index_reduce_func_cuda_impl(
   const Tensor& index,
   const Tensor& source,
   bool include_self,
-  const SCATTER_GATHER_OP& reduce,
+  const ReductionType& reduce,
   const func_t& reduce_func,
   const Tensor& result) {
   globalContext().alertNotDeterministic("index_reduce_cuda");
@@ -886,14 +886,14 @@ void index_reduce_func_cuda_impl(
       self.scalar_type(), "index_reduce_func_cuda_exclude_input_init", [&] {
       scalar_t init_val;
       switch (reduce) {
-        case SCATTER_GATHER_OP::REDUCE_MULTIPLY:
+        case ReductionType::PROD:
           init_val = (scalar_t)1;
           break;
-        case SCATTER_GATHER_OP::REDUCE_MAXIMUM:
+        case ReductionType::MAX:
          init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity()
                     : std::numeric_limits<scalar_t>::lowest();
          break;
-        case SCATTER_GATHER_OP::REDUCE_MINIMUM:
+        case ReductionType::MIN:
          init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity()
                     : std::numeric_limits<scalar_t>::max();
          break;
@@ -1047,9 +1047,9 @@ TORCH_IMPL_FUNC(index_reduce_cuda_out)
   TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time.");
 
   if (reduce == "prod") {
-    index_reduce_func_cuda_impl(self, dim, index, source, include_self, SCATTER_GATHER_OP::REDUCE_MULTIPLY, reduce_multiply, result);
+    index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::PROD, reduce_multiply, result);
   } else if (reduce == "mean") {
-    index_reduce_func_cuda_impl(self, dim, index, source, include_self, SCATTER_GATHER_OP::REDUCE_MEAN, reduce_add, result);
+    index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::MEAN, reduce_add, result);
     auto counts = include_self ? at::ones_like(result) : at::zeros_like(result);
     counts.index_add_(dim, index, at::ones_like(source));
     counts.masked_fill_(counts == 0, 1);
@@ -1059,9 +1059,9 @@ TORCH_IMPL_FUNC(index_reduce_cuda_out)
       result.div_(counts, "floor");
     }
   } else if (reduce == "amax") {
-    index_reduce_func_cuda_impl(self, dim, index, source, include_self, SCATTER_GATHER_OP::REDUCE_MAXIMUM, reduce_maximum, result);
+    index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::MAX, reduce_maximum, result);
   } else if (reduce == "amin") {
-    index_reduce_func_cuda_impl(self, dim, index, source, include_self, SCATTER_GATHER_OP::REDUCE_MINIMUM, reduce_minimum, result);
+    index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::MIN, reduce_minimum, result);
   } else {
     TORCH_CHECK(false, "reduce argument must be either prod, mean, amax or amin, got ", reduce, ".");
   }

@@ -496,13 +496,13 @@ void scatter_add_cuda_kernel(const Tensor& self, int64_t dim, const Tensor& inde
 }
 
 void scatter_reduce_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
-                                const Tensor& src, const SCATTER_GATHER_OP& reduce) {
+                                const Tensor& src, const ReductionType& reduce) {
   switch (reduce) {
-    case SCATTER_GATHER_OP::REDUCE_ADD :
+    case ReductionType::SUM :
       cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
                                                      "scatter_reduce_cuda_add_", reduce_add);
       break;
-    case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
+    case ReductionType::PROD :
      cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
                                                     "scatter_reduce_cuda_multiply_", reduce_multiply);
      break;
@@ -512,26 +512,26 @@ void scatter_reduce_cuda_kernel(const Tensor& self, const int64_t dim, const Ten
 }
 
 void scatter_reduce_two_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
-                                    const Tensor& src, const SCATTER_GATHER_OP& reduce) {
+                                    const Tensor& src, const ReductionType& reduce) {
   globalContext().alertNotDeterministic("scatter_reduce_cuda");
   switch (reduce) {
-    case SCATTER_GATHER_OP::REDUCE_ADD :
+    case ReductionType::SUM :
       cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
                                                      "scatter_reduce_cuda_sum_", reduce_add);
       break;
-    case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
+    case ReductionType::PROD :
      cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
                                                     "scatter_reduce_cuda_prod_", reduce_multiply);
      break;
-    case SCATTER_GATHER_OP::REDUCE_MAXIMUM :
+    case ReductionType::MAX :
      cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
                                                     "scatter_reduce_cuda_amax_", reduce_maximum);
      break;
-    case SCATTER_GATHER_OP::REDUCE_MINIMUM :
+    case ReductionType::MIN :
      cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
                                                     "scatter_reduce_cuda_amin_", reduce_minimum);
      break;
-    case SCATTER_GATHER_OP::REDUCE_MEAN :
+    case ReductionType::MEAN :
      cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
                                                     "scatter_reduce_cuda_mean_", reduce_mean);
      break;
@@ -539,13 +539,13 @@ void scatter_reduce_two_cuda_kernel(const Tensor& self, const int64_t dim, const
 }
 
 void scatter_scalar_reduce_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
                                        const Scalar& value, const SCATTER_GATHER_OP& reduce) {
-                                       const Scalar& value, const SCATTER_GATHER_OP& reduce) {
+                                       const Scalar& value, const ReductionType& reduce) {
   switch (reduce) {
-    case SCATTER_GATHER_OP::REDUCE_ADD :
+    case ReductionType::SUM :
      cuda_scatter_fill_base_kernel<false>()(self, dim, index, value,
                                             "scatter_fill_cuda_add_", reduce_add);
      break;
-    case SCATTER_GATHER_OP::REDUCE_MULTIPLY :
+    case ReductionType::PROD :
      cuda_scatter_fill_base_kernel<false>()(self, dim, index, value,
                                             "scatter_fill_cuda_multiply_", reduce_multiply);
      break;

@@ -104,7 +104,7 @@ __global__ static void post_sum_div_kernel(
 
 template <typename scalar_t, typename index_t>
 __global__ void segment_reduce_forward_kernel(
-    SegmentReductionType reduction,
+    ReductionType reduction,
     scalar_t* output_data,
     scalar_t* values_data,
     const index_t* lengths_data,
@@ -139,18 +139,18 @@ __global__ void segment_reduce_forward_kernel(
                          + j * data_stride_axis + lane_id;
     const auto data = values_data[data_index];
     // TODO: There is no need to branch with every element
-    if (reduction == SegmentReductionType::MAX) {
+    if (reduction == ReductionType::MAX) {
       initial_value =
           at::_isnan(data) ? data : std::max<scalar_t>(initial_value, data);
     } else if (
-        reduction == SegmentReductionType::MEAN ||
-        reduction == SegmentReductionType::SUM) {
+        reduction == ReductionType::MEAN ||
+        reduction == ReductionType::SUM) {
       initial_value = initial_value + data;
-    } else if (reduction == SegmentReductionType::MIN) {
+    } else if (reduction == ReductionType::MIN) {
      initial_value =
          at::_isnan(data) ? data : std::min<scalar_t>(initial_value, data);
     } else if (
-        reduction == SegmentReductionType::PROD) {
+        reduction == ReductionType::PROD) {
      initial_value = initial_value * data;
     }
   }
@@ -159,10 +159,10 @@ __global__ void segment_reduce_forward_kernel(
   int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
   CUDA_KERNEL_ASSERT(lengths_data[lengths_idx] >= 0);
   if (lengths_data[lengths_idx] == 0 && !is_initial_set &&
-      reduction == SegmentReductionType::MEAN) {
+      reduction == ReductionType::MEAN) {
     initial_value = static_cast<scalar_t>(NAN);
   } else if (
-      reduction == SegmentReductionType::MEAN && lengths_data[lengths_idx] > 0 &&
+      reduction == ReductionType::MEAN && lengths_data[lengths_idx] > 0 &&
       !at::_isnan(initial_value)) {
     initial_value = initial_value / lengths_data[lengths_idx];
   }
@@ -174,7 +174,7 @@ __global__ void segment_reduce_forward_kernel(
 
 template <typename scalar_t, typename index_t>
 __global__ void segment_reduce_backward_kernel(
-    SegmentReductionType reduction,
+    ReductionType reduction,
     scalar_t* grad_input_data,
     scalar_t* grad_data,
     scalar_t* output_data,
@@ -213,8 +213,8 @@ __global__ void segment_reduce_backward_kernel(
   int64_t output_index = outer_idx * output_stride_axis * output_size_axis
                          + dim_idx * output_stride_axis + lane_id;
 
-  if (reduction == SegmentReductionType::MAX ||
-      reduction == SegmentReductionType::MIN) {
+  if (reduction == ReductionType::MAX ||
+      reduction == ReductionType::MIN) {
     int64_t counter = 0;
     for (int64_t j = offset_start; j < offset_end; ++j) {
       int64_t data_index = outer_idx * data_stride_axis * data_size_axis
@@ -238,21 +238,21 @@ __global__ void segment_reduce_backward_kernel(
             grad_input_data[data_index] / counter;
       }
     }
-  } else if (reduction == SegmentReductionType::MEAN) {
+  } else if (reduction == ReductionType::MEAN) {
     auto grad_val = grad_data[output_index] / segment_length;
     for (int64_t j = offset_start; j < offset_end; ++j) {
       int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                            + j * data_stride_axis + lane_id;
       grad_input_data[data_index] = grad_val;
     }
-  } else if (reduction == SegmentReductionType::SUM) {
+  } else if (reduction == ReductionType::SUM) {
     const auto& grad_val = grad_data[output_index];
     for (int64_t j = offset_start; j < offset_end; ++j) {
       int64_t data_index = outer_idx * data_stride_axis * data_size_axis
                            + j * data_stride_axis + lane_id;
       grad_input_data[data_index] = grad_val;
     }
-  } else if (reduction == SegmentReductionType::PROD) {
+  } else if (reduction == ReductionType::PROD) {
    const auto& grad_val = grad_data[output_index] * output_data[output_index];
    for (int64_t j = offset_start; j < offset_end; ++j) {
      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
@@ -282,7 +282,7 @@ Tensor _segment_reduce_lengths_offsets_backward_cuda_kernel(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& lengths_or_offsets_contig,
     int64_t axis,
     const c10::optional<Scalar>& initial,
@@ -385,7 +385,7 @@ Tensor _segment_reduce_lengths_backward_cuda_kernel(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& lengths_contig,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
@@ -397,7 +397,7 @@ Tensor _segment_reduce_offsets_backward_cuda_kernel(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& offsets_contig,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
@@ -406,7 +406,7 @@ Tensor _segment_reduce_offsets_backward_cuda_kernel(
 }
 
 Tensor _segment_reduce_lengths_offsets_cuda_kernel(
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& data,
     const Tensor& lengths_or_offsets,
     int64_t axis,
@@ -474,15 +474,15 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
         scalar_t initial_value;
         if (initial.has_value()) {
           initial_value = initial.value().to<scalar_t>();
-        } else if (reduction == SegmentReductionType::MAX) {
+        } else if (reduction == ReductionType::MAX) {
           initial_value = -std::numeric_limits<scalar_t>::infinity();
         } else if (
-            reduction == SegmentReductionType::MEAN ||
-            reduction == SegmentReductionType::SUM) {
+            reduction == ReductionType::MEAN ||
+            reduction == ReductionType::SUM) {
           initial_value = 0;
-        } else if (reduction == SegmentReductionType::MIN) {
+        } else if (reduction == ReductionType::MIN) {
          initial_value = std::numeric_limits<scalar_t>::infinity();
-        } else if (reduction == SegmentReductionType::PROD) {
+        } else if (reduction == ReductionType::PROD) {
          initial_value = 1;
         }
 
@@ -511,7 +511,7 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
           );
           C10_CUDA_KERNEL_LAUNCH_CHECK();
         } else {
-          if (reduction == SegmentReductionType::MAX) {
+          if (reduction == ReductionType::MAX) {
             CustomMax max_op{};
             CUB_WRAPPER(
                 cub::DeviceSegmentedReduce::Reduce,
@@ -523,7 +523,7 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
                 max_op,
                 initial_value,
                 at::cuda::getCurrentCUDAStream());
-          } else if (reduction == SegmentReductionType::MEAN) {
+          } else if (reduction == ReductionType::MEAN) {
            CustomSum sum_op{};
            CUB_WRAPPER(
                cub::DeviceSegmentedReduce::Reduce,
@@ -547,7 +547,7 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
                 initial.has_value(),
                 initial_value);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
-          } else if (reduction == SegmentReductionType::MIN) {
+          } else if (reduction == ReductionType::MIN) {
            CustomMin min_op{};
            CUB_WRAPPER(
                cub::DeviceSegmentedReduce::Reduce,
@@ -559,7 +559,7 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
                 min_op,
                 initial_value,
                 at::cuda::getCurrentCUDAStream());
-          } else if (reduction == SegmentReductionType::SUM) {
+          } else if (reduction == ReductionType::SUM) {
            CustomSum sum_op{};
            CUB_WRAPPER(
                cub::DeviceSegmentedReduce::Reduce,
@@ -571,7 +571,7 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
                 sum_op,
                 initial_value,
                 at::cuda::getCurrentCUDAStream());
-          } else if (reduction == SegmentReductionType::PROD) {
+          } else if (reduction == ReductionType::PROD) {
            CustomProd prod_op{};
            CUB_WRAPPER(
                cub::DeviceSegmentedReduce::Reduce,
@@ -592,7 +592,7 @@ Tensor _segment_reduce_lengths_offsets_cuda_kernel(
 }
 
 Tensor _segment_reduce_lengths_cuda_kernel(
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& data,
     const Tensor& lengths,
     int64_t axis,
@@ -602,7 +602,7 @@ Tensor _segment_reduce_lengths_cuda_kernel(
 }
 
 Tensor _segment_reduce_offsets_cuda_kernel(
-    SegmentReductionType reduction,
+    ReductionType reduction,
     const Tensor& data,
     const Tensor& offsets,
     int64_t axis,

@@ -18,17 +18,17 @@ from torch.testing._internal.common_utils import (
 )
 
 
-reductions = ["max", "mean", "min", "sum", "prod"]
+reductions = ["amax", "mean", "amin", "sum", "prod"]
 
 
 def get_default_value(initial_value, reduction):
     if initial_value is not None:
         return initial_value
-    if reduction == "max":
+    if reduction == "amax":
         return -float("Inf")
     elif reduction == "mean":
         return float("nan")
-    elif reduction == "min":
+    elif reduction == "amin":
         return float("Inf")
     elif reduction == "sum":
         return 0.0
@@ -133,13 +133,13 @@ class TestSegmentReductions(TestCase):
         check_backward = True if initial is not None else False
         initial_value = initial
         default_value = get_default_value(initial_value, reduction)
-        if reduction == "max":
+        if reduction == "amax":
            expected_result = [1, float("nan"), 5, default_value]
            expected_grad = [1, 1, 0, 0, 0.5, 0.5]
         elif reduction == "mean":
            expected_result = [1, float("nan"), 4.666, default_value]
            expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
-        elif reduction == "min":
+        elif reduction == "amin":
            if initial is not None:
                initial_value = 1000  # some high number
            default_value = get_default_value(initial_value, reduction)
@@ -191,7 +191,7 @@ class TestSegmentReductions(TestCase):
         check_backward = True if initial is not None else False
         initial_value = initial
         default_value = get_default_value(initial_value, reduction)
-        if reduction == "max":
+        if reduction == "amax":
            expected_result = [
                [1, 1],
                [float("nan"), float("nan")],
@@ -221,7 +221,7 @@ class TestSegmentReductions(TestCase):
                [0.333, 0.333],
                [0.333, 0.333],
            ]
-        elif reduction == "min":
+        elif reduction == "amin":
            if initial is not None:
                initial_value = 1000  # some high number
            default_value = get_default_value(initial_value, reduction)
@@ -308,7 +308,7 @@ class TestSegmentReductions(TestCase):
            (torch.int, torch.int64),
        )
     )
-    @parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
+    @parametrize("reduce", ['sum', 'prod', 'amin', 'amax', 'mean'])
     def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
         val_dtype, length_dtype = dtypes
         # zero-length segments are filled with reduction inits contrary to pytorch_scatter.
@@ -320,8 +320,8 @@ class TestSegmentReductions(TestCase):
                 'sum': [3, 12, 0, 6],
                 'prod': [2, 60, 1, 6],
                 'mean': [1.5, 4, float('nan'), 6],
-                'min': [1, 3, float('inf'), 6],
-                'max': [2, 5, -float('inf'), 6],
+                'amin': [1, 3, float('inf'), 6],
+                'amax': [2, 5, -float('inf'), 6],
             },
             {
                 'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
@@ -330,8 +330,8 @@ class TestSegmentReductions(TestCase):
                 'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
                 'prod': [[3, 8], [315, 480], [1, 1], [11, 12]],
                 'mean': [[2, 3], [7, 8], [float('nan'), float('nan')], [11, 12]],
-                'min': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
-                'max': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
+                'amin': [[1, 2], [5, 6], [float('inf'), float('inf')], [11, 12]],
+                'amax': [[3, 4], [9, 10], [-float('inf'), -float('inf')], [11, 12]],
             },
             {
                 'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
@@ -340,8 +340,8 @@ class TestSegmentReductions(TestCase):
                 'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
                 'prod': [[3, 315, 1, 11], [48, 80, 12, 1]],
                 'mean': [[2, 7, float('nan'), 11], [4, 9, 12, float('nan')]],
-                'min': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
-                'max': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
+                'amin': [[1, 5, float('inf'), 11], [2, 8, 12, float('inf')]],
+                'amax': [[3, 9, -float('inf'), 11], [6, 10, 12, -float('inf')]],
             },
             {
                 'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
@@ -351,10 +351,10 @@ class TestSegmentReductions(TestCase):
                 'prod': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 143]]],
                 'mean': [[[2, 3], [5, 6], [float('nan'), float('nan')]],
                          [[7, 9], [float('nan'), float('nan')], [11, 12]]],
-                'min': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
-                        [[7, 9], [float('inf'), float('inf')], [10, 11]]],
-                'max': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
-                        [[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
+                'amin': [[[1, 2], [5, 6], [float('inf'), float('inf')]],
+                         [[7, 9], [float('inf'), float('inf')], [10, 11]]],
+                'amax': [[[3, 4], [5, 6], [-float('inf'), -float('inf')]],
+                         [[7, 9], [-float('inf'), -float('inf')], [12, 13]]],
             },
             {
                 'src': [[1, 3], [2, 4]],
@@ -363,8 +363,8 @@ class TestSegmentReductions(TestCase):
                 'sum': [[4], [6]],
                 'prod': [[3], [8]],
                 'mean': [[2], [3]],
-                'min': [[1], [2]],
-                'max': [[3], [4]],
+                'amin': [[1], [2]],
+                'amax': [[3], [4]],
             },
             {
                 'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
@@ -373,8 +373,8 @@ class TestSegmentReductions(TestCase):
                 'sum': [[[4, 4]], [[6, 6]]],
                 'prod': [[[3, 3]], [[8, 8]]],
                 'mean': [[[2, 2]], [[3, 3]]],
-                'min': [[[1, 1]], [[2, 2]]],
-                'max': [[[3, 3]], [[4, 4]]],
+                'amin': [[[1, 1]], [[2, 2]]],
+                'amax': [[[3, 3]], [[4, 4]]],
             },
         ]
         for test in tests:
@@ -409,9 +409,9 @@ class TestSegmentReductions(TestCase):
             initial = 1
             # supply initial values to prevent gradcheck from failing for 0 length segments
            # where nan/inf are reduction identities that produce nans when calculating the numerical jacobian
-            if reduce == 'min':
+            if reduce == 'amin':
                initial = 1000
-            elif reduce == 'max':
+            elif reduce == 'amax':
                initial = -1000
            segment_reduce_args = {x, reduce}
            segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
@@ -442,7 +442,7 @@ class TestSegmentReductions(TestCase):
 
         for reduction in reductions:
             initial_value = 0
-            if reduction == "max":
+            if reduction == "amax":
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),
                    np.max(data[:2], axis=0).tolist(),
@@ -456,7 +456,7 @@ class TestSegmentReductions(TestCase):
                    np.mean(data[2:], axis=0).tolist(),
                    np.full((2, 5), initial_value).tolist(),
                ]
-            elif reduction == "min":
+            elif reduction == "amin":
                initial_value = 1000  # some high number
                expected_result = [
                    np.full((2, 5), initial_value).tolist(),

@@ -783,9 +783,6 @@ def _sparse_csr_segment_reduction_helper(
         )
         new_nnz = new_crow_indices[-1]
         new_col_indices = col_indices.new_zeros(new_nnz)
-        # segment_reduce takes 'max'/'min' rather than 'amax'/'amin', changing this would be BC-breaking
-        if reduce in ["amax", "amin"]:
-            reduce = reduce[1:]
         new_values = torch.segment_reduce(values, reduce, offsets=crow_indices)
         new_shape = [mask_input.size(0), 1]
     else:

@@ -6188,7 +6188,7 @@ def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode=
         ((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
     )
 
-    reductions = ["max", "mean", "min", "sum", "prod"]
+    reductions = ["amax", "mean", "amin", "sum", "prod"]
     for args, reduce, initial in product(test_cases, reductions, [1, 2]):
         inp_shape, dim, lengths, unsafe = args
         lengths_t = torch.tensor(lengths, dtype=torch.long, device=device)