logsumexp for multiple dimensions (#16475)
Summary: Move `logsumexp` and `max_values` to `TensorIterator` and use it to make `logsumexp` work for multiple dimensions.

Timings on a tensor of shape `(10,1000000,10)`, for each combination of (cpu, single-threaded cpu, gpu) and dimension:

**before**

cpu:
- 208 ms ± 2.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
- 279 ms ± 5.07 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
- 199 ms ± 2.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

single-threaded cpu:
- 1.11 s ± 33.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
- 1.25 s ± 25.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
- 1.11 s ± 6.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

gpu:
- 15.4 ms ± 1.02 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
- 132 ms ± 30.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
- 39.6 ms ± 19.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

**after**

cpu:
- 199 ms ± 8.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
- 307 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
- 207 ms ± 7.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

single-threaded cpu:
- 1.16 s ± 8.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
- 1.26 s ± 47.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
- 1.13 s ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

gpu:
- 15.4 ms ± 868 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
- 132 ms ± 27.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
- 39.6 ms ± 21.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/16475

Differential Revision: D13855746

Pulled By: umanwizard

fbshipit-source-id: aaacc0b967c3f89073487e1952ae6f76b7bd7ad3
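As a usage sketch (not part of the commit), the overload added here lets `dim` be a list of dimensions; a naive exp/sum/log reference agrees on well-scaled inputs:

```python
import torch

x = torch.randn(10, 1000, 10)

# Reduce over a single dimension (existing behavior).
single = torch.logsumexp(x, 1)

# Reduce over several dimensions at once (what this change enables).
multi = torch.logsumexp(x, (1, 2))

# Naive reference without the max-subtraction trick; fine here, but it
# overflows for large inputs, which is why the stabilized kernel exists.
ref = x.exp().sum(dim=(1, 2)).log()
print(torch.allclose(multi, ref, atol=1e-5))  # True
```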
This commit is contained in:
parent
4047c97266
commit
1ce188c510
@@ -411,11 +411,11 @@ class CAFFE2_API Tensor {
   Tensor logdet() const;
   Tensor log_softmax(int64_t dim, ScalarType dtype) const;
   Tensor log_softmax(int64_t dim) const;
-  Tensor logsumexp(int64_t dim, bool keepdim=false) const;
+  Tensor logsumexp(IntList dim, bool keepdim=false) const;
   Tensor matmul(const Tensor & other) const;
   Tensor matrix_power(int64_t n) const;
   std::tuple<Tensor,Tensor> max(int64_t dim, bool keepdim=false) const;
-  Tensor max_values(int64_t dim, bool keepdim=false) const;
+  Tensor max_values(IntList dim, bool keepdim=false) const;
   Tensor mean(ScalarType dtype) const;
   Tensor mean() const;
   Tensor mean(IntList dim, bool keepdim, ScalarType dtype) const;
@@ -423,7 +423,7 @@ class CAFFE2_API Tensor {
   Tensor mean(IntList dim, ScalarType dtype) const;
   std::tuple<Tensor,Tensor> median(int64_t dim, bool keepdim=false) const;
   std::tuple<Tensor,Tensor> min(int64_t dim, bool keepdim=false) const;
-  Tensor min_values(int64_t dim, bool keepdim=false) const;
+  Tensor min_values(IntList dim, bool keepdim=false) const;
   Tensor mm(const Tensor & mat2) const;
   std::tuple<Tensor,Tensor> mode(int64_t dim=-1, bool keepdim=false) const;
   Tensor mul(const Tensor & other) const;
@@ -376,7 +376,7 @@ inline Tensor Tensor::log_softmax(int64_t dim, ScalarType dtype) const {
 inline Tensor Tensor::log_softmax(int64_t dim) const {
   return type().log_softmax(*this, dim);
 }
-inline Tensor Tensor::logsumexp(int64_t dim, bool keepdim) const {
+inline Tensor Tensor::logsumexp(IntList dim, bool keepdim) const {
   return type().logsumexp(*this, dim, keepdim);
 }
 inline Tensor Tensor::matmul(const Tensor & other) const {
@@ -388,7 +388,7 @@ inline Tensor Tensor::matrix_power(int64_t n) const {
 inline std::tuple<Tensor,Tensor> Tensor::max(int64_t dim, bool keepdim) const {
   return type().max(*this, dim, keepdim);
 }
-inline Tensor Tensor::max_values(int64_t dim, bool keepdim) const {
+inline Tensor Tensor::max_values(IntList dim, bool keepdim) const {
   return type().max_values(*this, dim, keepdim);
 }
 inline Tensor Tensor::mean(ScalarType dtype) const {
@@ -412,7 +412,7 @@ inline std::tuple<Tensor,Tensor> Tensor::median(int64_t dim, bool keepdim) const
 inline std::tuple<Tensor,Tensor> Tensor::min(int64_t dim, bool keepdim) const {
   return type().min(*this, dim, keepdim);
 }
-inline Tensor Tensor::min_values(int64_t dim, bool keepdim) const {
+inline Tensor Tensor::min_values(IntList dim, bool keepdim) const {
   return type().min_values(*this, dim, keepdim);
 }
 inline Tensor Tensor::mm(const Tensor & mat2) const {
@@ -293,11 +293,11 @@ struct CAFFE2_API Type {
   virtual Tensor logdet(const Tensor & self) const = 0;
   virtual Tensor log_softmax(const Tensor & self, int64_t dim, ScalarType dtype) const = 0;
   virtual Tensor log_softmax(const Tensor & self, int64_t dim) const = 0;
-  virtual Tensor logsumexp(const Tensor & self, int64_t dim, bool keepdim) const = 0;
+  virtual Tensor logsumexp(const Tensor & self, IntList dim, bool keepdim) const = 0;
   virtual Tensor matmul(const Tensor & self, const Tensor & other) const = 0;
   virtual Tensor matrix_power(const Tensor & self, int64_t n) const = 0;
   virtual std::tuple<Tensor,Tensor> max(const Tensor & self, int64_t dim, bool keepdim) const = 0;
-  virtual Tensor max_values(const Tensor & self, int64_t dim, bool keepdim) const = 0;
+  virtual Tensor max_values(const Tensor & self, IntList dim, bool keepdim) const = 0;
   virtual Tensor mean(const Tensor & self, ScalarType dtype) const = 0;
   virtual Tensor mean(const Tensor & self) const = 0;
   virtual Tensor mean(const Tensor & self, IntList dim, bool keepdim, ScalarType dtype) const = 0;
@@ -305,7 +305,7 @@ struct CAFFE2_API Type {
   virtual Tensor mean(const Tensor & self, IntList dim, ScalarType dtype) const = 0;
   virtual std::tuple<Tensor,Tensor> median(const Tensor & self, int64_t dim, bool keepdim) const = 0;
   virtual std::tuple<Tensor,Tensor> min(const Tensor & self, int64_t dim, bool keepdim) const = 0;
-  virtual Tensor min_values(const Tensor & self, int64_t dim, bool keepdim) const = 0;
+  virtual Tensor min_values(const Tensor & self, IntList dim, bool keepdim) const = 0;
   virtual Tensor mm(const Tensor & self, const Tensor & mat2) const = 0;
   virtual std::tuple<Tensor,Tensor> mode(const Tensor & self, int64_t dim, bool keepdim) const = 0;
   virtual Tensor mul(const Tensor & self, const Tensor & other) const = 0;
@@ -419,12 +419,11 @@ Vec256<int16_t> inline operator-(const Vec256<int16_t>& a, const Vec256<int16_t>
   return _mm256_sub_epi16(a, b);
 }
 
-// AVX2 has no intrinsic for int64_t multiply so it needs to be emulated
-// This could be implemented more efficiently using epi32 instructions
-// This is also technically avx compatible, but then we'll need AVX
-// code for add as well.
-template <>
-Vec256<int64_t> inline operator*(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
+// Emulate operations with no native 64-bit support in avx,
+// by extracting each element, performing the operation pointwise,
+// then combining the results into a vector.
+template <typename op_t>
+Vec256<int64_t> inline emulate(const Vec256<int64_t>& a, const Vec256<int64_t>& b, const op_t& op) {
   int64_t a0 = _mm256_extract_epi64(a, 0);
   int64_t a1 = _mm256_extract_epi64(a, 1);
   int64_t a2 = _mm256_extract_epi64(a, 2);
@@ -435,14 +434,23 @@ Vec256<int64_t> inline operator*(const Vec256<int64_t>& a, const Vec256<int64_t>
   int64_t b2 = _mm256_extract_epi64(b, 2);
   int64_t b3 = _mm256_extract_epi64(b, 3);
 
-  int64_t c0 = a0 * b0;
-  int64_t c1 = a1 * b1;
-  int64_t c2 = a2 * b2;
-  int64_t c3 = a3 * b3;
+  int64_t c0 = op(a0, b0);
+  int64_t c1 = op(a1, b1);
+  int64_t c2 = op(a2, b2);
+  int64_t c3 = op(a3, b3);
 
   return _mm256_set_epi64x(c3, c2, c1, c0);
 }
 
+// AVX2 has no intrinsic for int64_t multiply so it needs to be emulated
+// This could be implemented more efficiently using epi32 instructions
+// This is also technically avx compatible, but then we'll need AVX
+// code for add as well.
+template <>
+Vec256<int64_t> inline operator*(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
+  return emulate(a, b, [](int64_t a_point, int64_t b_point){return a_point * b_point;});
+}
+
 template <>
 Vec256<int32_t> inline operator*(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
   return _mm256_mullo_epi32(a, b);
@@ -453,6 +461,36 @@ Vec256<int16_t> inline operator*(const Vec256<int16_t>& a, const Vec256<int16_t>
   return _mm256_mullo_epi16(a, b);
 }
 
+template <>
+Vec256<int64_t> inline minimum(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
+  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::min(a_point, b_point);});
+}
+
+template <>
+Vec256<int32_t> inline minimum(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
+  return _mm256_min_epi32(a, b);
+}
+
+template <>
+Vec256<int16_t> inline minimum(const Vec256<int16_t>& a, const Vec256<int16_t>& b) {
+  return _mm256_min_epi16(a, b);
+}
+
+template <>
+Vec256<int64_t> inline maximum(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
+  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::max(a_point, b_point);});
+}
+
+template <>
+Vec256<int32_t> inline maximum(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
+  return _mm256_max_epi32(a, b);
+}
+
+template <>
+Vec256<int16_t> inline maximum(const Vec256<int16_t>& a, const Vec256<int16_t>& b) {
+  return _mm256_max_epi16(a, b);
+}
+
 template <typename T>
 Vec256<T> inline intdiv_256(const Vec256<T>& a, const Vec256<T>& b) {
   T values_a[Vec256<T>::size()];
@@ -27,6 +27,8 @@ DEFINE_DISPATCH(norm_stub);
 DEFINE_DISPATCH(mean_stub);
 DEFINE_DISPATCH(and_stub);
 DEFINE_DISPATCH(or_stub);
+DEFINE_DISPATCH(min_values_stub);
+DEFINE_DISPATCH(max_values_stub);
 
 static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dtype) {
   ScalarType scalarType = self.type().scalarType();
@@ -382,26 +384,36 @@ Tensor prod(const Tensor& self, int64_t dim, ScalarType dtype) {
   return at::native::prod(self, dim, false, dtype);
 }
 
-Tensor& logsumexp_out(Tensor& result, const Tensor &self, int64_t dim_, bool keepdim) {
-  int64_t dim = maybe_wrap_dim(dim_, self.dim());
+static Tensor squeeze_multiple(const Tensor& self, IntList dims) {
+  int ndims = self.sizes().size();
+  auto dims_to_squeeze = at::dim_list_to_bitset(dims, ndims);
+  Tensor result = self;
+  for (int i = ndims - 1; i >= 0; --i) {
+    if (dims_to_squeeze[i]) {
+      result = result.squeeze(i);
+    }
+  }
+  return result;
+}
+
+Tensor& logsumexp_out(Tensor& result, const Tensor &self, IntList dims, bool keepdim) {
   // can't take max of empty tensor
   if (self.numel() != 0) {
-    auto maxes = at::max_values(self, dim, true);
-    auto maxes_squeezed = (keepdim ? maxes : maxes.squeeze(dim));
+    auto maxes = at::max_values(self, dims, true);
+    auto maxes_squeezed = (keepdim ? maxes : squeeze_multiple(maxes, dims));
     maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
-    at::sum_out(result, at::exp(self - maxes), dim, keepdim);
+    at::sum_out(result, at::exp(self - maxes), dims, keepdim);
     result.log_().add_(maxes_squeezed);
   } else {
-    at::sum_out(result, at::exp(self), dim, keepdim);
+    at::sum_out(result, at::exp(self), dims, keepdim);
     result.log_();
   }
   return result;
 }
 
-Tensor logsumexp(const Tensor &self, int64_t dim_, bool keepdim) {
-  int64_t dim = maybe_wrap_dim(dim_, self.dim());
+Tensor logsumexp(const Tensor &self, IntList dims, bool keepdim) {
   Tensor result = at::empty({0}, self.options());
-  return at::native::logsumexp_out(result, self, dim, keepdim);
+  return at::native::logsumexp_out(result, self, dims, keepdim);
 }
 
 static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p,
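The hunk above is the standard stabilized log-sum-exp: subtract the per-slice maximum before exponentiating so `exp` cannot overflow, add it back after the `log`, and zero out infinite maxes so all-`-inf` slices yield `-inf` rather than NaN from `(-inf) - (-inf)`. A hypothetical Python mirror of that logic, using `torch.amax` as a stand-in for the internal `at::max_values`:

```python
import torch

def logsumexp_sketch(x, dims, keepdim=False):
    # Stabilize: subtract the per-slice max so exp() cannot overflow.
    maxes = torch.amax(x, dim=dims, keepdim=True)
    # Mirror of masked_fill_: zero out infinite maxes so slices that are
    # entirely -inf produce -inf instead of NaN.
    maxes = torch.where(maxes.isinf(), torch.zeros_like(maxes), maxes)
    out = (x - maxes).exp().sum(dim=dims, keepdim=True).log() + maxes
    if not keepdim:
        # Mirror of squeeze_multiple: drop reduced dims, highest index first.
        for d in sorted((d % x.dim() for d in dims), reverse=True):
            out = out.squeeze(d)
    return out

x = torch.randn(10, 100, 10)
print(torch.allclose(logsumexp_sketch(x, (1, 2)), torch.logsumexp(x, (1, 2))))  # True
```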
@@ -559,6 +571,32 @@ Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
   }
 }
 
+Tensor min_values(const Tensor& self, IntList dims, bool keepdim) {
+  if (dims.size() == 1) {
+    return std::get<0>(self.min(dims[0], keepdim));
+  } else {
+    Tensor result = at::empty({0}, self.options());
+    ScalarType dtype = get_dtype(result, self, {}, true);
+    auto iter = make_reduction("min_values", result, self, dims, keepdim, dtype);
+    AT_CHECK(iter->numel() > 0, "min_values on a tensor with no elements is not defined.");
+    min_values_stub(iter->device_type(), *iter);
+    return result;
+  }
+}
+
+Tensor max_values(const Tensor& self, IntList dims, bool keepdim) {
+  if (dims.size() == 1) {
+    return std::get<0>(self.max(dims[0], keepdim));
+  } else {
+    Tensor result = at::empty({0}, self.options());
+    ScalarType dtype = get_dtype(result, self, {}, true);
+    auto iter = make_reduction("max_values", result, self, dims, keepdim, dtype);
+    AT_CHECK(iter->numel() > 0, "max_values on a tensor with no elements is not defined.");
+    max_values_stub(iter->device_type(), *iter);
+    return result;
+  }
+}
+
 static Tensor &std_var_out(Tensor &result, const Tensor &self, IntList dim, bool unbiased, bool keepdim, bool take_sqrt) {
   AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
            "std and var only support CPU AND CUDA backend, got: ", toString(self.type().backend()));
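Note the fast path: a single dim delegates to `self.min`/`self.max`, everything else goes through `make_reduction` and the new stubs. At the Python level (hedged: `max_values` is internal; `torch.amax` is today's closest public analogue), reducing several dims in one pass matches reducing them one at a time:

```python
import torch

x = torch.randn(3, 4, 5)
multi = torch.amax(x, dim=(0, 2))            # both dims in one reduction
seq = x.max(dim=2).values.max(dim=0).values  # one dim at a time
print(torch.equal(multi, seq))  # True
```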
@@ -17,6 +17,8 @@ DECLARE_DISPATCH(reduce_fn, prod_stub);
 DECLARE_DISPATCH(reduce_fn, mean_stub);
 DECLARE_DISPATCH(reduce_fn, and_stub);
 DECLARE_DISPATCH(reduce_fn, or_stub);
+DECLARE_DISPATCH(reduce_fn, min_values_stub);
+DECLARE_DISPATCH(reduce_fn, max_values_stub);
 
 using reduce_std_var_function =
     void (*)(TensorIterator&, bool unbiased, bool take_sqrt);
@@ -196,10 +196,6 @@ std::tuple<Tensor &,Tensor &> max_out(Tensor& max, Tensor& max_indices,
   }
 }
 
-Tensor max_values(const Tensor& self, int64_t dim, bool keepdim) {
-  return std::get<0>(self.max(dim, keepdim));
-}
-
 std::tuple<Tensor &,Tensor &> _min_out_cpu(Tensor& min, Tensor& min_indices,
                                            const Tensor& self, int64_t dim, bool keepdim) {
   if (self.is_contiguous() && min.is_contiguous() && min_indices.is_contiguous()) {
@@ -239,10 +235,6 @@ std::tuple<Tensor &,Tensor &> min_out(Tensor& min, Tensor& min_indices,
   }
 }
 
-Tensor min_values(const Tensor& self, int64_t dim, bool keepdim) {
-  return std::get<0>(self.min(dim, keepdim));
-}
-
 // argmax and argmin
 
 Tensor argmax(const Tensor& self, int64_t dim, bool keepdim) {
@@ -148,6 +148,24 @@ static void or_kernel_impl(TensorIterator& iter) {
     /*ident=*/false);
 }
 
+static void min_values_kernel_impl(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES(iter.type(), "min_values", [&iter] {
+    binary_kernel_reduce_vec(
+      iter,
+      [](scalar_t a, scalar_t b) -> scalar_t { return std::min(a, b); },
+      [](Vec256<scalar_t> a, Vec256<scalar_t> b) { return minimum(a, b); });
+  });
+}
+
+static void max_values_kernel_impl(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES(iter.type(), "min_values", [&iter] {
+    binary_kernel_reduce_vec(
+      iter,
+      [](scalar_t a, scalar_t b) -> scalar_t { return std::max(a, b); },
+      [](Vec256<scalar_t> a, Vec256<scalar_t> b) { return maximum(a, b); });
+  });
+}
+
 } // anonymous namespace
 
 REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);
@@ -157,5 +175,7 @@ REGISTER_DISPATCH(mean_stub, &mean_kernel_impl);
 REGISTER_DISPATCH(norm_stub, &norm_kernel_tensor_iterator_impl);
 REGISTER_DISPATCH(and_stub, &and_kernel_impl);
 REGISTER_DISPATCH(or_stub, &or_kernel_impl);
+REGISTER_DISPATCH(min_values_stub, &min_values_kernel_impl);
+REGISTER_DISPATCH(max_values_stub, &max_values_kernel_impl);
 
 }} // namespace at::native
@@ -2,6 +2,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/Context.h>
 #include <ATen/Dispatch.h>
+#include <ATen/cuda/NumericLimits.cuh>
 #include <ATen/native/cuda/DeviceSqrt.cuh>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/cuda/Reduce.cuh>
@@ -10,6 +11,7 @@
 #include <ATen/native/ReduceOps.h>
 #include <limits>
 #include <tuple>
+#include <THC/THCNumerics.cuh>
 
 
 namespace at { namespace native {
@@ -158,6 +160,34 @@ void or_kernel_cuda(TensorIterator& iter) {
   }), false);
 }
 
+template <typename scalar_t>
+void max_values_kernel_cuda_impl(TensorIterator& iter) {
+  gpu_reduce_kernel<scalar_t, scalar_t>(
+    iter, func_wrapper<scalar_t> ([]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return (THCNumerics<scalar_t>::isnan(a) || a > b) ? a : b;
+    }), at::numeric_limits<scalar_t>::lower_bound());
+}
+
+template <typename scalar_t>
+void min_values_kernel_cuda_impl(TensorIterator& iter) {
+  gpu_reduce_kernel<scalar_t, scalar_t>(
+    iter, func_wrapper<scalar_t> ([]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return (THCNumerics<scalar_t>::isnan(a) || a < b) ? a : b;
+    }), at::numeric_limits<scalar_t>::upper_bound());
+}
+
+void max_values_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES(iter.type(), "max_values", [&]() {
+    max_values_kernel_cuda_impl<scalar_t>(iter);
+  });
+}
+
+void min_values_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES(iter.type(), "min_values", [&]() {
+    min_values_kernel_cuda_impl<scalar_t>(iter);
+  });
+}
+
 REGISTER_DISPATCH(std_var_stub, &std_var_kernel_cuda);
 REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda);
 REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda);
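The GPU lambdas make NaN absorbing: comparisons against NaN are false, so the explicit `isnan(a)` check lets a NaN on the left win, and a NaN on the right wins because `a > b` is then false. A scalar sketch of the same rule:

```python
import math

def nan_propagating_max(a, b):
    # Same rule as the GPU lambda above.
    return a if (math.isnan(a) or a > b) else b

print(nan_propagating_max(float('nan'), 3.0))  # nan
print(nan_propagating_max(3.0, float('nan')))  # nan
print(nan_propagating_max(1.0, 3.0))           # 3.0
```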
@@ -165,5 +195,7 @@ REGISTER_DISPATCH(mean_stub, &mean_kernel_cuda);
 REGISTER_DISPATCH(norm_stub, &norm_kernel_cuda);
 REGISTER_DISPATCH(and_stub, &and_kernel_cuda);
 REGISTER_DISPATCH(or_stub, &or_kernel_cuda);
+REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda);
+REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda);
 
 }} // namespace at::native
@@ -1313,11 +1313,11 @@
     CPU: log_softmax_backward_cpu
     CUDA: log_softmax_backward_cuda
 
-- func: logsumexp(Tensor self, int dim, bool keepdim=False) -> Tensor
+- func: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
   matches_jit_signature: True
   variants: function, method
 
-- func: logsumexp(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: logsumexp(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
 - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
@@ -1345,7 +1345,7 @@
 
 - func: max(Tensor self, int64_t dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) ->(Tensor(a!) values, Tensor(b!) indices)
 
-- func: max_values(Tensor self, int dim, bool keepdim=False) -> Tensor
+- func: max_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
   matches_jit_signature: True
   variants: function, method
 
@@ -1404,7 +1404,7 @@
 
 - func: min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) ->(Tensor(a!), Tensor(b!))
 
-- func: min_values(Tensor self, int dim, bool keepdim=False) -> Tensor
+- func: min_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
   matches_jit_signature: True
   variants: function, method
 
@@ -2225,6 +2225,15 @@ class _TestTorchMixin(object):
                            lambda n, d: n.var(d, ddof=1 if unbiased else 0),
                            use_integral=False)
 
+    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
+    @unittest.skipIf(not TEST_SCIPY, 'Scipy not found')
+    def test_logsumexp_dim(self):
+        from scipy.special import logsumexp
+        self._test_dim_ops(
+            lambda t, d: t.logsumexp(d),
+            lambda n, d: logsumexp(n, d),
+            use_integral=False)
+
     def test_sum_out(self):
         x = torch.rand(100, 100)
         res1 = torch.sum(x, 1)
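A self-contained version of the cross-check this test performs (a sketch; `_test_dim_ops` sweeps dims and dtypes internally):

```python
import torch
from scipy.special import logsumexp

x = torch.randn(4, 5, 6, dtype=torch.float64)
for dims in [0, 1, (0, 2), (1, 2)]:
    expected = torch.from_numpy(logsumexp(x.numpy(), axis=dims))
    assert torch.allclose(torch.logsumexp(x, dims), expected)
```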
@@ -451,7 +451,7 @@
 - name: log_normal_(Tensor self, double mean, double std, Generator generator)
   self: zeros_like(grad)
 
-- name: logsumexp(Tensor self, int64_t dim, bool keepdim)
+- name: logsumexp(Tensor self, IntList dim, bool keepdim)
   self: logsumexp_backward(grad, self, result, dim, keepdim)
 
 - name: lt_(Tensor self, Scalar other)
@@ -429,10 +429,10 @@ Tensor cumsum_backward(const Tensor &x, int64_t dim, ScalarType input_dtype) {
   return cumsum_backward(x.to(input_dtype), dim);
 }
 
-Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim, bool keepdim) {
+Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntList dim, bool keepdim) {
   if (!keepdim && self.dim() != 0) {
-    grad = grad.unsqueeze(dim);
-    result = result.unsqueeze(dim);
+    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
+    result = unsqueeze_multiple(result, dim, self.sizes().size());
   }
   return grad * (self - result).exp();
 }
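The backward uses the identity d/dx_i logsumexp(x) = exp(x_i - logsumexp(x)), i.e. the softmax of the input, which is exactly the `grad * (self - result).exp()` above. A hypothetical Python rendering, assuming `unsqueeze_multiple` re-inserts the reduced dims in ascending order:

```python
import torch

def logsumexp_backward_sketch(grad, self, result, dims, keepdim):
    if not keepdim and self.dim() != 0:
        # Re-insert reduced dims so grad/result broadcast against self.
        for d in sorted(d % self.dim() for d in dims):
            grad = grad.unsqueeze(d)
            result = result.unsqueeze(d)
    return grad * (self - result).exp()
```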
@@ -23,12 +23,34 @@ def parse_kwargs(desc):
     return {desc.split(' ')[0]: desc for desc in kwargs}
 
 
+def merge_dicts(*dicts):
+    return {x: d[x] for d in dicts for x in d}
+
+
 reduceops_common_args = parse_kwargs("""
     dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
         If specified, the input tensor is casted to :attr:`dtype` before the operation
         is performed. This is useful for preventing data type overflows. Default: None.
     keepdim (bool): whether the output tensor has :attr:`dim` retained or not
 """)
 
+multi_dim_common = merge_dicts(reduceops_common_args, parse_kwargs("""
+    dim (int or tuple of ints): the dimension or dimensions to reduce
+"""), {'keepdim_details': """
+If :attr:`keepdim` is ``True``, the output tensor is of the same size
+as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
+Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
+output tensor having 1 (or ``len(dim)``) fewer dimension(s).
+"""})
+
+single_dim_common = merge_dicts(reduceops_common_args, parse_kwargs("""
+    dim (int): the dimension to reduce
+"""), {'keepdim_details': """If :attr:`keepdim` is ``True``, the output tensor is of the same size
+as :attr:`input` except in the dimension :attr:`dim` where it is of size 1.
+Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in
+the output tensor having 1 fewer dimension than :attr:`input`."""})
+
+
 factory_common_args = parse_kwargs("""
     out (Tensor, optional): the output tensor
     dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
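A toy illustration (not from the commit) of this templating pattern: shared kwarg descriptions live in dicts and get spliced into each docstring via `str.format`, which is also why literal braces in the math line below must be doubled to `{{ }}`:

```python
common = {
    'dim': 'dim (int or tuple of ints): the dimension or dimensions to reduce',
    'keepdim': 'keepdim (bool): whether the output tensor has dim retained or not',
}

doc = """
Args:
    input (Tensor): the input tensor
    {dim}
    {keepdim}
""".format(**common)
print(doc)
```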
@@ -2487,17 +2509,14 @@ stabilized.
 For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is
 
 .. math::
-    \text{logsumexp}(x)_{i} = \log \sum_j \exp(x_{ij})
+    \text{{logsumexp}}(x)_{{i}} = \log \sum_j \exp(x_{{ij}})
 
-If :attr:`keepdim` is ``True``, the output tensor is of the same size
-as :attr:`input` except in the dimension :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in
-the output tensor having 1 fewer dimension than :attr:`input`.
+{keepdim_details}
 
 Args:
     input (Tensor): the input tensor
-    dim (int): the dimension to reduce
-    keepdim (bool): whether the output tensor has :attr:`dim` retained or not
+    {dim}
+    {keepdim}
     out (Tensor, optional): the output tensor
 
@@ -2505,7 +2524,7 @@ Example::
     >>> a = torch.randn(3, 3)
     >>> torch.logsumexp(a, 1)
     tensor([ 0.8442,  1.4322,  0.8711])
-""")
+""".format(**multi_dim_common))
 
 add_docstr(torch.lt,
            r"""
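A quick shape check of the documented `keepdim` behavior with a tuple `dim` (a sketch, assuming a build that includes this change):

```python
import torch

a = torch.randn(3, 4, 5)
print(torch.logsumexp(a, (1, 2)).shape)                # torch.Size([3])
print(torch.logsumexp(a, (1, 2), keepdim=True).shape)  # torch.Size([3, 1, 1])
```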
@@ -2730,15 +2749,12 @@ Returns the mean value of each row of the :attr:`input` tensor in the given
 dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
 reduce over all of them.
 
-If :attr:`keepdim` is ``True``, the output tensor is of the same size
-as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
-output tensor having 1 (or ``len(dim)``) fewer dimension(s).
+{keepdim_details}
 
 Args:
     input (Tensor): the input tensor
-    dim (int or tuple of ints): the dimension or dimensions to reduce
-    keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not
+    {dim}
+    {keepdim}
     out (Tensor): the output tensor
 
 Example::
@@ -2756,7 +2772,7 @@ Example::
             [-0.5085],
             [-0.4599],
             [ 0.1807]])
-""")
+""".format(**multi_dim_common))
 
 add_docstr(torch.median,
            r"""
@@ -3582,15 +3598,12 @@ Example::
 Returns the product of each row of the :attr:`input` tensor in the given
 dimension :attr:`dim`.
 
-If :attr:`keepdim` is ``True``, the output tensor is of the same size as
-:attr:`input` except in the dimension :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting
-in the output tensor having 1 fewer dimension than :attr:`input`.
+{keepdim_details}
 
 Args:
     input (Tensor): the input tensor
-    dim (int): the dimension to reduce
-    keepdim (bool): whether the output tensor has :attr:`dim` retained or not
+    {dim}
+    {keepdim}
     {dtype}
 
 Example::
@@ -3603,7 +3616,7 @@ Example::
         [ 1.1131, -1.0629]])
 >>> torch.prod(a, 1)
 tensor([-0.2018, -0.2962, -0.0821, -1.1831])
-""".format(**reduceops_common_args))
+""".format(**single_dim_common))
 
 add_docstr(torch.pstrf, r"""
 pstrf(a, upper=True, out=None) -> (Tensor, Tensor)
@@ -4461,18 +4474,15 @@ Returns the standard-deviation of each row of the :attr:`input` tensor in the
 dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
 reduce over all of them.
 
-If :attr:`keepdim` is ``True``, the output tensor is of the same size
-as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
-output tensor having 1 (or ``len(dim)``) fewer dimension(s).
+{keepdim_details}
 
 If :attr:`unbiased` is ``False``, then the standard-deviation will be calculated
 via the biased estimator. Otherwise, Bessel's correction will be used.
 
 Args:
     input (Tensor): the input tensor
-    dim (int or tuple of ints): the dimension or dimensions to reduce
-    keepdim (bool): whether the output tensor has :attr:`dim` retained or not
+    {dim}
+    {keepdim}
     unbiased (bool): whether to use the unbiased estimation or not
     out (Tensor, optional): the output tensor
 
@@ -4486,7 +4496,7 @@ Example::
         [ 0.1264, -0.5080,  1.6420,  0.1992]])
 >>> torch.std(a, dim=1)
 tensor([ 1.0311,  0.7477,  1.2204,  0.9087])
-""")
+""".format(**multi_dim_common))
 
 add_docstr(torch.sum,
            r"""
@@ -4512,15 +4522,12 @@ Returns the sum of each row of the :attr:`input` tensor in the given
 dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
 reduce over all of them.
 
-If :attr:`keepdim` is ``True``, the output tensor is of the same size
-as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
-output tensor having 1 (or ``len(dim)``) fewer dimension(s).
+{keepdim_details}
 
 Args:
     input (Tensor): the input tensor
-    dim (int or tuple of ints): the dimension or dimensions to reduce
-    keepdim (bool): whether the output tensor has :attr:`dim` retained or not
+    {dim}
+    {keepdim}
     {dtype}
 
 Example::
@@ -4536,7 +4543,7 @@ Example::
 >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6)
 >>> torch.sum(b, (2, 1))
 tensor([  435.,  1335.,  2235.,  3135.])
-""".format(**reduceops_common_args))
+""".format(**multi_dim_common))
 
 add_docstr(torch.svd,
            r"""
@@ -5300,18 +5307,15 @@ Example::
 Returns the variance of each row of the :attr:`input` tensor in the given
 dimension :attr:`dim`.
 
-If :attr:`keepdim` is ``True``, the output tensors are of the same size
-as :attr:`input` except in the dimension :attr:`dim` where they are of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in
-the outputs tensor having 1 fewer dimension than :attr:`input`.
+{keepdim_details}
 
 If :attr:`unbiased` is ``False``, then the variance will be calculated via the
 biased estimator. Otherwise, Bessel's correction will be used.
 
 Args:
     input (Tensor): the input tensor
-    dim (int): the dimension to reduce
-    keepdim (bool): whether the output tensor has :attr:`dim` retained or not
+    {dim}
+    {keepdim}
     unbiased (bool): whether to use the unbiased estimation or not
     out (Tensor, optional): the output tensor
 
@@ -5325,7 +5329,7 @@ Example::
         [-0.7700,  0.6074, -0.1469,  0.7777]])
 >>> torch.var(a, 1)
 tensor([ 1.7444,  1.1363,  0.7356,  0.5112])
-""")
+""".format(**multi_dim_common))
 
 add_docstr(torch.zeros,
            r"""
@@ -904,9 +904,6 @@ class ShapePropagator {
         {
             "aten::argmax(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor",
-            "aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor",
-            "aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor",
-            "aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
             "aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
@@ -957,10 +954,13 @@ class ShapePropagator {
     // - has a bool keepdim argument
     static const register_formula_for multidim_reduce_ops{
         {
+            "aten::logsumexp(Tensor self, int[] dim, bool keepdim) -> Tensor",
             "aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor",
             "aten::norm(Tensor self, Scalar? p, int[] dim, bool keepdim) -> Tensor",
             "aten::std(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
             "aten::var(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
+            "aten::max_values(Tensor self, int[] dim, bool keepdim) -> Tensor",
+            "aten::min_values(Tensor self, int[] dim, bool keepdim) -> Tensor",
         },
         [](Node* node) -> type_vec_t {
             if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
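The formula registered for `multidim_reduce_ops` encodes the usual reduction shape rule; a hypothetical Python rendering:

```python
def multidim_reduce_shape(in_shape, dims, keepdim):
    # Reduced dims become size 1 when keepdim is set, otherwise disappear.
    dims = {d % len(in_shape) for d in dims}
    if keepdim:
        return [1 if i in dims else s for i, s in enumerate(in_shape)]
    return [s for i, s in enumerate(in_shape) if i not in dims]

print(multidim_reduce_shape([10, 1000000, 10], [1, 2], keepdim=False))  # [10]
```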