logsumexp for multiple dimensions (#16475)

Summary:
Move `logsumexp`, `max_values`, and `min_values` to `TensorIterator`, and use its multi-dimensional reduction support to make `logsumexp` work over multiple dimensions.
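
For reference, the user-visible change is that `dim` may now be a tuple. A minimal sketch (shapes chosen arbitrarily):

```python
import torch

t = torch.randn(10, 1000000, 10)

# previously only a single dim was accepted:
single = torch.logsumexp(t, 1)                    # shape (10, 10)

# after this change, several dims can be reduced at once:
multi = torch.logsumexp(t, (0, 2))                # shape (1000000,)

# keepdim retains the reduced dimensions as size 1:
kept = torch.logsumexp(t, (0, 2), keepdim=True)   # shape (1, 1000000, 1)
```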

Timings on a tensor of shape `(10,1000000,10)`, for each combination of device setting (cpu, single-threaded cpu, gpu) and reduced dimension:

**before**

cpu:
208 ms ± 2.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
279 ms ± 5.07 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
199 ms ± 2.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

single-threaded cpu:
1.11 s ± 33.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.25 s ± 25.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.11 s ± 6.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

gpu:
15.4 ms ± 1.02 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
132 ms ± 30.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
39.6 ms ± 19.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

**after**

cpu:
199 ms ± 8.23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
307 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
207 ms ± 7.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

single-threaded cpu:
1.16 s ± 8.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.26 s ± 47.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.13 s ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

gpu:
15.4 ms ± 868 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
132 ms ± 27.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
39.6 ms ± 21.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
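
The PR does not include the benchmark script; a plausible IPython sketch of how timings like these are collected (the dim values, threading, and device handling are assumptions, not from the PR):

```python
import torch

t = torch.randn(10, 1000000, 10)

# single-threaded cpu rows would use torch.set_num_threads(1);
# gpu rows would use t = t.cuda() first.
for dim in (0, 1, 2):
    %timeit torch.logsumexp(t, dim)   # one %timeit line per (setting, dim)
```
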
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16475

Differential Revision: D13855746

Pulled By: umanwizard

fbshipit-source-id: aaacc0b967c3f89073487e1952ae6f76b7bd7ad3
Author: Brennan Vincent
Date: 2019-02-05 08:27:04 -08:00 (committed by Facebook Github Bot)
Commit: 1ce188c510 (parent: 4047c97266)
15 changed files with 225 additions and 90 deletions

---

@@ -411,11 +411,11 @@ class CAFFE2_API Tensor {
   Tensor logdet() const;
   Tensor log_softmax(int64_t dim, ScalarType dtype) const;
   Tensor log_softmax(int64_t dim) const;
-  Tensor logsumexp(int64_t dim, bool keepdim=false) const;
+  Tensor logsumexp(IntList dim, bool keepdim=false) const;
   Tensor matmul(const Tensor & other) const;
   Tensor matrix_power(int64_t n) const;
   std::tuple<Tensor,Tensor> max(int64_t dim, bool keepdim=false) const;
-  Tensor max_values(int64_t dim, bool keepdim=false) const;
+  Tensor max_values(IntList dim, bool keepdim=false) const;
   Tensor mean(ScalarType dtype) const;
   Tensor mean() const;
   Tensor mean(IntList dim, bool keepdim, ScalarType dtype) const;
@@ -423,7 +423,7 @@ class CAFFE2_API Tensor {
   Tensor mean(IntList dim, ScalarType dtype) const;
   std::tuple<Tensor,Tensor> median(int64_t dim, bool keepdim=false) const;
   std::tuple<Tensor,Tensor> min(int64_t dim, bool keepdim=false) const;
-  Tensor min_values(int64_t dim, bool keepdim=false) const;
+  Tensor min_values(IntList dim, bool keepdim=false) const;
   Tensor mm(const Tensor & mat2) const;
   std::tuple<Tensor,Tensor> mode(int64_t dim=-1, bool keepdim=false) const;
   Tensor mul(const Tensor & other) const;

---

@@ -376,7 +376,7 @@ inline Tensor Tensor::log_softmax(int64_t dim, ScalarType dtype) const {
 inline Tensor Tensor::log_softmax(int64_t dim) const {
   return type().log_softmax(*this, dim);
 }
-inline Tensor Tensor::logsumexp(int64_t dim, bool keepdim) const {
+inline Tensor Tensor::logsumexp(IntList dim, bool keepdim) const {
   return type().logsumexp(*this, dim, keepdim);
 }
 inline Tensor Tensor::matmul(const Tensor & other) const {
@@ -388,7 +388,7 @@ inline Tensor Tensor::matrix_power(int64_t n) const {
 inline std::tuple<Tensor,Tensor> Tensor::max(int64_t dim, bool keepdim) const {
   return type().max(*this, dim, keepdim);
 }
-inline Tensor Tensor::max_values(int64_t dim, bool keepdim) const {
+inline Tensor Tensor::max_values(IntList dim, bool keepdim) const {
   return type().max_values(*this, dim, keepdim);
 }
 inline Tensor Tensor::mean(ScalarType dtype) const {
@@ -412,7 +412,7 @@ inline std::tuple<Tensor,Tensor> Tensor::median(int64_t dim, bool keepdim) const
 inline std::tuple<Tensor,Tensor> Tensor::min(int64_t dim, bool keepdim) const {
   return type().min(*this, dim, keepdim);
 }
-inline Tensor Tensor::min_values(int64_t dim, bool keepdim) const {
+inline Tensor Tensor::min_values(IntList dim, bool keepdim) const {
   return type().min_values(*this, dim, keepdim);
 }
 inline Tensor Tensor::mm(const Tensor & mat2) const {

---

@@ -293,11 +293,11 @@ struct CAFFE2_API Type {
   virtual Tensor logdet(const Tensor & self) const = 0;
   virtual Tensor log_softmax(const Tensor & self, int64_t dim, ScalarType dtype) const = 0;
   virtual Tensor log_softmax(const Tensor & self, int64_t dim) const = 0;
-  virtual Tensor logsumexp(const Tensor & self, int64_t dim, bool keepdim) const = 0;
+  virtual Tensor logsumexp(const Tensor & self, IntList dim, bool keepdim) const = 0;
   virtual Tensor matmul(const Tensor & self, const Tensor & other) const = 0;
   virtual Tensor matrix_power(const Tensor & self, int64_t n) const = 0;
   virtual std::tuple<Tensor,Tensor> max(const Tensor & self, int64_t dim, bool keepdim) const = 0;
-  virtual Tensor max_values(const Tensor & self, int64_t dim, bool keepdim) const = 0;
+  virtual Tensor max_values(const Tensor & self, IntList dim, bool keepdim) const = 0;
   virtual Tensor mean(const Tensor & self, ScalarType dtype) const = 0;
   virtual Tensor mean(const Tensor & self) const = 0;
   virtual Tensor mean(const Tensor & self, IntList dim, bool keepdim, ScalarType dtype) const = 0;
@@ -305,7 +305,7 @@ struct CAFFE2_API Type {
   virtual Tensor mean(const Tensor & self, IntList dim, ScalarType dtype) const = 0;
   virtual std::tuple<Tensor,Tensor> median(const Tensor & self, int64_t dim, bool keepdim) const = 0;
   virtual std::tuple<Tensor,Tensor> min(const Tensor & self, int64_t dim, bool keepdim) const = 0;
-  virtual Tensor min_values(const Tensor & self, int64_t dim, bool keepdim) const = 0;
+  virtual Tensor min_values(const Tensor & self, IntList dim, bool keepdim) const = 0;
   virtual Tensor mm(const Tensor & self, const Tensor & mat2) const = 0;
   virtual std::tuple<Tensor,Tensor> mode(const Tensor & self, int64_t dim, bool keepdim) const = 0;
   virtual Tensor mul(const Tensor & self, const Tensor & other) const = 0;

---

@@ -419,12 +419,11 @@ Vec256<int16_t> inline operator-(const Vec256<int16_t>& a, const Vec256<int16_t>
   return _mm256_sub_epi16(a, b);
 }
 
-// AVX2 has no intrinsic for int64_t multiply so it needs to be emulated
-// This could be implemented more efficiently using epi32 instructions
-// This is also technically avx compatible, but then we'll need AVX
-// code for add as well.
-template <>
-Vec256<int64_t> inline operator*(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
+// Emulate operations with no native 64-bit support in avx,
+// by extracting each element, performing the operation pointwise,
+// then combining the results into a vector.
+template <typename op_t>
+Vec256<int64_t> inline emulate(const Vec256<int64_t>& a, const Vec256<int64_t>& b, const op_t& op) {
   int64_t a0 = _mm256_extract_epi64(a, 0);
   int64_t a1 = _mm256_extract_epi64(a, 1);
   int64_t a2 = _mm256_extract_epi64(a, 2);
@@ -435,14 +434,23 @@ Vec256<int64_t> inline operator*(const Vec256<int64_t>& a, const Vec256<int64_t>
   int64_t b2 = _mm256_extract_epi64(b, 2);
   int64_t b3 = _mm256_extract_epi64(b, 3);
 
-  int64_t c0 = a0 * b0;
-  int64_t c1 = a1 * b1;
-  int64_t c2 = a2 * b2;
-  int64_t c3 = a3 * b3;
+  int64_t c0 = op(a0, b0);
+  int64_t c1 = op(a1, b1);
+  int64_t c2 = op(a2, b2);
+  int64_t c3 = op(a3, b3);
   return _mm256_set_epi64x(c3, c2, c1, c0);
 }
 
+// AVX2 has no intrinsic for int64_t multiply so it needs to be emulated
+// This could be implemented more efficiently using epi32 instructions
+// This is also technically avx compatible, but then we'll need AVX
+// code for add as well.
+template <>
+Vec256<int64_t> inline operator*(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
+  return emulate(a, b, [](int64_t a_point, int64_t b_point){return a_point * b_point;});
+}
+
 template <>
 Vec256<int32_t> inline operator*(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
   return _mm256_mullo_epi32(a, b);
@@ -453,6 +461,36 @@ Vec256<int16_t> inline operator*(const Vec256<int16_t>& a, const Vec256<int16_t>
   return _mm256_mullo_epi16(a, b);
 }
 
+template <>
+Vec256<int64_t> inline minimum(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
+  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::min(a_point, b_point);});
+}
+
+template <>
+Vec256<int32_t> inline minimum(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
+  return _mm256_min_epi32(a, b);
+}
+
+template <>
+Vec256<int16_t> inline minimum(const Vec256<int16_t>& a, const Vec256<int16_t>& b) {
+  return _mm256_min_epi16(a, b);
+}
+
+template <>
+Vec256<int64_t> inline maximum(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
+  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::max(a_point, b_point);});
+}
+
+template <>
+Vec256<int32_t> inline maximum(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
+  return _mm256_max_epi32(a, b);
+}
+
+template <>
+Vec256<int16_t> inline maximum(const Vec256<int16_t>& a, const Vec256<int16_t>& b) {
+  return _mm256_max_epi16(a, b);
+}
+
 template <typename T>
 Vec256<T> inline intdiv_256(const Vec256<T>& a, const Vec256<T>& b) {
   T values_a[Vec256<T>::size()];
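
Translated out of intrinsics, `emulate` is just elementwise application over the four 64-bit lanes; a pure-Python sketch of the same extract/apply/recombine pattern:

```python
# a, b: length-4 lists standing in for the four int64 lanes of a Vec256<int64_t>
def emulate(a, b, op):
    # extract each lane, apply op pointwise, recombine into a "vector"
    return [op(x, y) for x, y in zip(a, b)]

emulate([1, 2, 3, 4], [4, 3, 2, 1], min)   # -> [1, 2, 2, 1]
```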

---

@@ -27,6 +27,8 @@ DEFINE_DISPATCH(norm_stub);
 DEFINE_DISPATCH(mean_stub);
 DEFINE_DISPATCH(and_stub);
 DEFINE_DISPATCH(or_stub);
+DEFINE_DISPATCH(min_values_stub);
+DEFINE_DISPATCH(max_values_stub);
 
 static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dtype) {
   ScalarType scalarType = self.type().scalarType();
@@ -382,26 +384,36 @@ Tensor prod(const Tensor& self, int64_t dim, ScalarType dtype) {
   return at::native::prod(self, dim, false, dtype);
 }
 
-Tensor& logsumexp_out(Tensor& result, const Tensor &self, int64_t dim_, bool keepdim) {
-  int64_t dim = maybe_wrap_dim(dim_, self.dim());
+static Tensor squeeze_multiple(const Tensor& self, IntList dims) {
+  int ndims = self.sizes().size();
+  auto dims_to_squeeze = at::dim_list_to_bitset(dims, ndims);
+  Tensor result = self;
+  for (int i = ndims - 1; i >= 0; --i) {
+    if (dims_to_squeeze[i]) {
+      result = result.squeeze(i);
+    }
+  }
+  return result;
+}
+
+Tensor& logsumexp_out(Tensor& result, const Tensor &self, IntList dims, bool keepdim) {
   // can't take max of empty tensor
   if (self.numel() != 0) {
-    auto maxes = at::max_values(self, dim, true);
-    auto maxes_squeezed = (keepdim ? maxes : maxes.squeeze(dim));
+    auto maxes = at::max_values(self, dims, true);
+    auto maxes_squeezed = (keepdim ? maxes : squeeze_multiple(maxes, dims));
     maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
-    at::sum_out(result, at::exp(self - maxes), dim, keepdim);
+    at::sum_out(result, at::exp(self - maxes), dims, keepdim);
     result.log_().add_(maxes_squeezed);
   } else {
-    at::sum_out(result, at::exp(self), dim, keepdim);
+    at::sum_out(result, at::exp(self), dims, keepdim);
     result.log_();
   }
   return result;
 }
 
-Tensor logsumexp(const Tensor &self, int64_t dim_, bool keepdim) {
-  int64_t dim = maybe_wrap_dim(dim_, self.dim());
+Tensor logsumexp(const Tensor &self, IntList dims, bool keepdim) {
   Tensor result = at::empty({0}, self.options());
-  return at::native::logsumexp_out(result, self, dim, keepdim);
+  return at::native::logsumexp_out(result, self, dims, keepdim);
 }
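
The max subtraction here is the usual log-sum-exp stabilization, `logsumexp(x) = max(x) + log(sum(exp(x - max(x))))`; masking infinite maxes to 0 avoids `inf - inf = nan` for slices whose max is infinite. A small float32 illustration of why the subtraction matters:

```python
import torch

x = torch.full((3,), 1000.0)
naive = x.exp().sum().log()              # exp(1000) overflows float32 -> inf
m = x.max()
stable = m + (x - m).exp().sum().log()   # 1000 + log(3), finite
```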
static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p,
@@ -559,6 +571,32 @@ Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
   }
 }
 
+Tensor min_values(const Tensor& self, IntList dims, bool keepdim) {
+  if (dims.size() == 1) {
+    return std::get<0>(self.min(dims[0], keepdim));
+  } else {
+    Tensor result = at::empty({0}, self.options());
+    ScalarType dtype = get_dtype(result, self, {}, true);
+    auto iter = make_reduction("min_values", result, self, dims, keepdim, dtype);
+    AT_CHECK(iter->numel() > 0, "min_values on a tensor with no elements is not defined.");
+    min_values_stub(iter->device_type(), *iter);
+    return result;
+  }
+}
+
+Tensor max_values(const Tensor& self, IntList dims, bool keepdim) {
+  if (dims.size() == 1) {
+    return std::get<0>(self.max(dims[0], keepdim));
+  } else {
+    Tensor result = at::empty({0}, self.options());
+    ScalarType dtype = get_dtype(result, self, {}, true);
+    auto iter = make_reduction("max_values", result, self, dims, keepdim, dtype);
+    AT_CHECK(iter->numel() > 0, "max_values on a tensor with no elements is not defined.");
+    max_values_stub(iter->device_type(), *iter);
+    return result;
+  }
+}
+
 static Tensor &std_var_out(Tensor &result, const Tensor &self, IntList dim, bool unbiased, bool keepdim, bool take_sqrt) {
   AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
            "std and var only support CPU AND CUDA backend, got: ", toString(self.type().backend()));

---

@@ -17,6 +17,8 @@ DECLARE_DISPATCH(reduce_fn, prod_stub);
 DECLARE_DISPATCH(reduce_fn, mean_stub);
 DECLARE_DISPATCH(reduce_fn, and_stub);
 DECLARE_DISPATCH(reduce_fn, or_stub);
+DECLARE_DISPATCH(reduce_fn, min_values_stub);
+DECLARE_DISPATCH(reduce_fn, max_values_stub);
 
 using reduce_std_var_function =
     void (*)(TensorIterator&, bool unbiased, bool take_sqrt);

---

@@ -196,10 +196,6 @@ std::tuple<Tensor &,Tensor &> max_out(Tensor& max, Tensor& max_indices,
   }
 }
 
-Tensor max_values(const Tensor& self, int64_t dim, bool keepdim) {
-  return std::get<0>(self.max(dim, keepdim));
-}
-
 std::tuple<Tensor &,Tensor &> _min_out_cpu(Tensor& min, Tensor& min_indices,
                                            const Tensor& self, int64_t dim, bool keepdim) {
   if (self.is_contiguous() && min.is_contiguous() && min_indices.is_contiguous()) {
@@ -239,10 +235,6 @@ std::tuple<Tensor &,Tensor &> min_out(Tensor& min, Tensor& min_indices,
   }
 }
 
-Tensor min_values(const Tensor& self, int64_t dim, bool keepdim) {
-  return std::get<0>(self.min(dim, keepdim));
-}
-
 // argmax and argmin
 
 Tensor argmax(const Tensor& self, int64_t dim, bool keepdim) {

---

@@ -148,6 +148,24 @@ static void or_kernel_impl(TensorIterator& iter) {
     /*ident=*/false);
 }
 
+static void min_values_kernel_impl(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES(iter.type(), "min_values", [&iter] {
+    binary_kernel_reduce_vec(
+      iter,
+      [](scalar_t a, scalar_t b) -> scalar_t { return std::min(a, b); },
+      [](Vec256<scalar_t> a, Vec256<scalar_t> b) { return minimum(a, b); });
+  });
+}
+
+static void max_values_kernel_impl(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES(iter.type(), "max_values", [&iter] {
+    binary_kernel_reduce_vec(
+      iter,
+      [](scalar_t a, scalar_t b) -> scalar_t { return std::max(a, b); },
+      [](Vec256<scalar_t> a, Vec256<scalar_t> b) { return maximum(a, b); });
+  });
+}
+
 }  // anonymous namespace
 
 REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);
@@ -157,5 +175,7 @@ REGISTER_DISPATCH(mean_stub, &mean_kernel_impl);
 REGISTER_DISPATCH(norm_stub, &norm_kernel_tensor_iterator_impl);
 REGISTER_DISPATCH(and_stub, &and_kernel_impl);
 REGISTER_DISPATCH(or_stub, &or_kernel_impl);
+REGISTER_DISPATCH(min_values_stub, &min_values_kernel_impl);
+REGISTER_DISPATCH(max_values_stub, &max_values_kernel_impl);
 
 }}  // namespace at::native

---

@@ -2,6 +2,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/Context.h>
 #include <ATen/Dispatch.h>
+#include <ATen/cuda/NumericLimits.cuh>
 #include <ATen/native/cuda/DeviceSqrt.cuh>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/cuda/Reduce.cuh>
@@ -10,6 +11,7 @@
 #include <ATen/native/ReduceOps.h>
 #include <limits>
 #include <tuple>
+#include <THC/THCNumerics.cuh>
 
 namespace at { namespace native {
@@ -158,6 +160,34 @@ void or_kernel_cuda(TensorIterator& iter) {
   }), false);
 }
 
+template <typename scalar_t>
+void max_values_kernel_cuda_impl(TensorIterator& iter) {
+  gpu_reduce_kernel<scalar_t, scalar_t>(
+    iter, func_wrapper<scalar_t> ([]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return (THCNumerics<scalar_t>::isnan(a) || a > b) ? a : b;
+    }), at::numeric_limits<scalar_t>::lower_bound());
+}
+
+template <typename scalar_t>
+void min_values_kernel_cuda_impl(TensorIterator& iter) {
+  gpu_reduce_kernel<scalar_t, scalar_t>(
+    iter, func_wrapper<scalar_t> ([]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return (THCNumerics<scalar_t>::isnan(a) || a < b) ? a : b;
+    }), at::numeric_limits<scalar_t>::upper_bound());
+}
+
+void max_values_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES(iter.type(), "max_values", [&]() {
+    max_values_kernel_cuda_impl<scalar_t>(iter);
+  });
+}
+
+void min_values_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES(iter.type(), "min_values", [&]() {
+    min_values_kernel_cuda_impl<scalar_t>(iter);
+  });
+}
+
 REGISTER_DISPATCH(std_var_stub, &std_var_kernel_cuda);
 REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda);
 REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda);
@@ -165,5 +195,7 @@ REGISTER_DISPATCH(mean_stub, &mean_kernel_cuda);
 REGISTER_DISPATCH(norm_stub, &norm_kernel_cuda);
 REGISTER_DISPATCH(and_stub, &and_kernel_cuda);
 REGISTER_DISPATCH(or_stub, &or_kernel_cuda);
+REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda);
+REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda);
 
 }}  // namespace at::native
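
The ternary in each CUDA comparator makes NaN win every comparison, so a NaN anywhere in the reduced slice propagates to the output. The same predicate in scalar Python form:

```python
import math

def max_prop(a, b):
    # mirrors the GPU lambda: (isnan(a) || a > b) ? a : b
    return a if (math.isnan(a) or a > b) else b

max_prop(float('nan'), 1.0)   # nan
max_prop(1.0, float('nan'))   # nan (1.0 > nan is False, so b is returned)
```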

---

@@ -1313,11 +1313,11 @@
     CPU: log_softmax_backward_cpu
     CUDA: log_softmax_backward_cuda
 
-- func: logsumexp(Tensor self, int dim, bool keepdim=False) -> Tensor
+- func: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
   matches_jit_signature: True
   variants: function, method
 
-- func: logsumexp(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+- func: logsumexp(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
 - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
@@ -1345,7 +1345,7 @@
 - func: max(Tensor self, int64_t dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) ->(Tensor(a!) values, Tensor(b!) indices)
 
-- func: max_values(Tensor self, int dim, bool keepdim=False) -> Tensor
+- func: max_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
   matches_jit_signature: True
   variants: function, method
@@ -1404,7 +1404,7 @@
 - func: min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) ->(Tensor(a!), Tensor(b!))
 
-- func: min_values(Tensor self, int dim, bool keepdim=False) -> Tensor
+- func: min_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
   matches_jit_signature: True
   variants: function, method
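
The schema type `int[1]` marks an int list whose value may also be given as a bare int, promoted to a one-element list, so existing single-dim call sites keep working unchanged; for instance:

```python
import torch

t = torch.randn(2, 3, 4)
torch.logsumexp(t, 1)        # bare int, promoted to a one-element list
torch.logsumexp(t, (0, 2))   # the new multi-dim form
```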

---

@@ -2225,6 +2225,15 @@ class _TestTorchMixin(object):
             lambda n, d: n.var(d, ddof=1 if unbiased else 0),
             use_integral=False)
 
+    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
+    @unittest.skipIf(not TEST_SCIPY, 'Scipy not found')
+    def test_logsumexp_dim(self):
+        from scipy.special import logsumexp
+        self._test_dim_ops(
+            lambda t, d: t.logsumexp(d),
+            lambda n, d: logsumexp(n, d),
+            use_integral=False)
+
     def test_sum_out(self):
         x = torch.rand(100, 100)
         res1 = torch.sum(x, 1)
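
The test leans on `scipy.special.logsumexp`, which accepts a tuple of axes, so the same reference covers both the single- and multi-dim paths; roughly the comparison it performs is:

```python
import numpy as np
import torch
from scipy.special import logsumexp

t = torch.randn(10, 5, 10)
ref = logsumexp(t.numpy(), axis=(0, 2))
out = torch.logsumexp(t, (0, 2)).numpy()
assert np.allclose(out, ref)
```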

---

@@ -451,7 +451,7 @@
 - name: log_normal_(Tensor self, double mean, double std, Generator generator)
   self: zeros_like(grad)
 
-- name: logsumexp(Tensor self, int64_t dim, bool keepdim)
+- name: logsumexp(Tensor self, IntList dim, bool keepdim)
   self: logsumexp_backward(grad, self, result, dim, keepdim)
 
 - name: lt_(Tensor self, Scalar other)

---

@@ -429,10 +429,10 @@ Tensor cumsum_backward(const Tensor &x, int64_t dim, ScalarType input_dtype) {
   return cumsum_backward(x.to(input_dtype), dim);
 }
 
-Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim, bool keepdim) {
+Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntList dim, bool keepdim) {
   if (!keepdim && self.dim() != 0) {
-    grad = grad.unsqueeze(dim);
-    result = result.unsqueeze(dim);
+    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
+    result = unsqueeze_multiple(result, dim, self.sizes().size());
   }
   return grad * (self - result).exp();
 }
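
The backward uses the fact that the derivative of logsumexp is the softmax over the reduced dimensions, i.e. `d/dx_i logsumexp(x) = exp(x_i - logsumexp(x))`, which is exactly `grad * (self - result).exp()` once `unsqueeze_multiple` re-inserts the reduced dims for broadcasting. A gradcheck sketch (shapes and dims are arbitrary):

```python
import torch

x = torch.randn(3, 4, 5, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(lambda t: t.logsumexp((0, 2)), (x,))
```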

---

@@ -23,12 +23,34 @@ def parse_kwargs(desc):
     return {desc.split(' ')[0]: desc for desc in kwargs}
 
 def merge_dicts(*dicts):
     return {x: d[x] for d in dicts for x in d}
 
+reduceops_common_args = parse_kwargs("""
+    dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+        If specified, the input tensor is casted to :attr:`dtype` before the operation
+        is performed. This is useful for preventing data type overflows. Default: None.
+    keepdim (bool): whether the output tensor has :attr:`dim` retained or not
+""")
+
+multi_dim_common = merge_dicts(reduceops_common_args, parse_kwargs("""
+    dim (int or tuple of ints): the dimension or dimensions to reduce
+"""), {'keepdim_details': """
+If :attr:`keepdim` is ``True``, the output tensor is of the same size
+as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
+Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
+output tensor having 1 (or ``len(dim)``) fewer dimension(s).
+"""})
+
+single_dim_common = merge_dicts(reduceops_common_args, parse_kwargs("""
+    dim (int): the dimension to reduce
+"""), {'keepdim_details': """If :attr:`keepdim` is ``True``, the output tensor is of the same size
+as :attr:`input` except in the dimension :attr:`dim` where it is of size 1.
+Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in
+the output tensor having 1 fewer dimension than :attr:`input`."""})
+
 factory_common_args = parse_kwargs("""
     out (Tensor, optional): the output tensor
     dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
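
As defined above, `parse_kwargs` keys each argument description by its first word and `merge_dicts` unions the fragment dicts, so docstrings can pull in shared `Args:` entries by name. A simplified, runnable stand-in (the real `parse_kwargs` splits entries on indentation; this sketch assumes one entry per line):

```python
def parse_kwargs(desc):
    kwargs = [line.strip() for line in desc.strip().splitlines()]
    return {d.split(' ')[0]: d for d in kwargs}

def merge_dicts(*dicts):
    return {x: d[x] for d in dicts for x in d}

common = merge_dicts(
    parse_kwargs("dim (int): the dimension to reduce"),
    {'keepdim_details': 'If :attr:`keepdim` is ``True``, ...'})

doc = """Args:
    {dim}

{keepdim_details}
""".format(**common)
```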
@@ -2487,17 +2509,14 @@ stabilized.
 
 For summation index :math:`j` given by `dim` and other indices :math:`i`, the result is
 
 .. math::
-    \text{logsumexp}(x)_{i} = \log \sum_j \exp(x_{ij})
+    \text{{logsumexp}}(x)_{{i}} = \log \sum_j \exp(x_{{ij}})
 
-If :attr:`keepdim` is ``True``, the output tensor is of the same size
-as :attr:`input` except in the dimension :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in
-the output tensor having 1 fewer dimension than :attr:`input`.
+{keepdim_details}
 
 Args:
     input (Tensor): the input tensor
-    dim (int): the dimension to reduce
-    keepdim (bool): whether the output tensor has :attr:`dim` retained or not
+    {dim}
+    {keepdim}
     out (Tensor, optional): the output tensor
 
@@ -2505,7 +2524,7 @@ Example::
 
     >>> a = torch.randn(3, 3)
     >>> torch.logsumexp(a, 1)
     tensor([ 0.8442, 1.4322, 0.8711])
-""")
+""".format(**multi_dim_common))
 
 add_docstr(torch.lt,
            r"""
@@ -2730,15 +2749,12 @@ Returns the mean value of each row of the :attr:`input` tensor in the given
 dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
 reduce over all of them.
 
-If :attr:`keepdim` is ``True``, the output tensor is of the same size
-as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
-output tensor having 1 (or ``len(dim)``) fewer dimension(s).
+{keepdim_details}
 
 Args:
     input (Tensor): the input tensor
-    dim (int or tuple of ints): the dimension or dimensions to reduce
-    keepdim (bool, optional): whether the output tensor has :attr:`dim` retained or not
+    {dim}
+    {keepdim}
     out (Tensor): the output tensor
 
 Example::
@@ -2756,7 +2772,7 @@ Example::
             [-0.5085],
             [-0.4599],
             [ 0.1807]])
-""")
+""".format(**multi_dim_common))
 
 add_docstr(torch.median,
            r"""
@@ -3582,15 +3598,12 @@ Example::
 Returns the product of each row of the :attr:`input` tensor in the given
 dimension :attr:`dim`.
 
-If :attr:`keepdim` is ``True``, the output tensor is of the same size as
-:attr:`input` except in the dimension :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting
-in the output tensor having 1 fewer dimension than :attr:`input`.
+{keepdim_details}
 
 Args:
     input (Tensor): the input tensor
-    dim (int): the dimension to reduce
-    keepdim (bool): whether the output tensor has :attr:`dim` retained or not
+    {dim}
+    {keepdim}
     {dtype}
 
 Example::
@@ -3603,7 +3616,7 @@ Example::
             [ 1.1131, -1.0629]])
     >>> torch.prod(a, 1)
     tensor([-0.2018, -0.2962, -0.0821, -1.1831])
-""".format(**reduceops_common_args))
+""".format(**single_dim_common))
 
 add_docstr(torch.pstrf, r"""
 pstrf(a, upper=True, out=None) -> (Tensor, Tensor)
@@ -4461,18 +4474,15 @@ Returns the standard-deviation of each row of the :attr:`input` tensor in the
 dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
 reduce over all of them.
 
-If :attr:`keepdim` is ``True``, the output tensor is of the same size
-as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
-output tensor having 1 (or ``len(dim)``) fewer dimension(s).
+{keepdim_details}
 
 If :attr:`unbiased` is ``False``, then the standard-deviation will be calculated
 via the biased estimator. Otherwise, Bessel's correction will be used.
 
 Args:
     input (Tensor): the input tensor
-    dim (int or tuple of ints): the dimension or dimensions to reduce
-    keepdim (bool): whether the output tensor has :attr:`dim` retained or not
+    {dim}
+    {keepdim}
     unbiased (bool): whether to use the unbiased estimation or not
     out (Tensor, optional): the output tensor
 
@@ -4486,7 +4496,7 @@ Example::
             [ 0.1264, -0.5080, 1.6420, 0.1992]])
     >>> torch.std(a, dim=1)
     tensor([ 1.0311, 0.7477, 1.2204, 0.9087])
-""")
+""".format(**multi_dim_common))
 
 add_docstr(torch.sum,
            r"""
@@ -4512,15 +4522,12 @@ Returns the sum of each row of the :attr:`input` tensor in the given
 dimension :attr:`dim`. If :attr:`dim` is a list of dimensions,
 reduce over all of them.
 
-If :attr:`keepdim` is ``True``, the output tensor is of the same size
-as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
-output tensor having 1 (or ``len(dim)``) fewer dimension(s).
+{keepdim_details}
 
 Args:
     input (Tensor): the input tensor
-    dim (int or tuple of ints): the dimension or dimensions to reduce
-    keepdim (bool): whether the output tensor has :attr:`dim` retained or not
+    {dim}
+    {keepdim}
     {dtype}
 
 Example::
@@ -4536,7 +4543,7 @@ Example::
     >>> b = torch.arange(4 * 5 * 6).view(4, 5, 6)
     >>> torch.sum(b, (2, 1))
     tensor([ 435., 1335., 2235., 3135.])
-""".format(**reduceops_common_args))
+""".format(**multi_dim_common))
 
 add_docstr(torch.svd,
            r"""
@@ -5300,18 +5307,15 @@ Example::
 Returns the variance of each row of the :attr:`input` tensor in the given
 dimension :attr:`dim`.
 
-If :attr:`keepdim` is ``True``, the output tensors are of the same size
-as :attr:`input` except in the dimension :attr:`dim` where they are of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in
-the outputs tensor having 1 fewer dimension than :attr:`input`.
+{keepdim_details}
 
 If :attr:`unbiased` is ``False``, then the variance will be calculated via the
 biased estimator. Otherwise, Bessel's correction will be used.
 
 Args:
     input (Tensor): the input tensor
-    dim (int): the dimension to reduce
-    keepdim (bool): whether the output tensor has :attr:`dim` retained or not
+    {dim}
+    {keepdim}
     unbiased (bool): whether to use the unbiased estimation or not
     out (Tensor, optional): the output tensor
 
@@ -5325,7 +5329,7 @@ Example::
             [-0.7700, 0.6074, -0.1469, 0.7777]])
     >>> torch.var(a, 1)
     tensor([ 1.7444, 1.1363, 0.7356, 0.5112])
-""")
+""".format(**multi_dim_common))
 
 add_docstr(torch.zeros,
            r"""

---

@@ -904,9 +904,6 @@ class ShapePropagator {
       {
         "aten::argmax(Tensor self, int dim, bool keepdim) -> Tensor",
         "aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor",
-        "aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor",
-        "aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor",
-        "aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor",
         "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
         "aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
@@ -957,10 +954,13 @@ class ShapePropagator {
     // - has a bool keepdim argument
     static const register_formula_for multidim_reduce_ops{
         {
+            "aten::logsumexp(Tensor self, int[] dim, bool keepdim) -> Tensor",
             "aten::mean(Tensor self, int[] dim, bool keepdim) -> Tensor",
             "aten::norm(Tensor self, Scalar? p, int[] dim, bool keepdim) -> Tensor",
             "aten::std(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
             "aten::var(Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor",
+            "aten::max_values(Tensor self, int[] dim, bool keepdim) -> Tensor",
+            "aten::min_values(Tensor self, int[] dim, bool keepdim) -> Tensor",
         },
         [](Node* node) -> type_vec_t {
           if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {