optimize dim reduce performance on norm, argmax and argmin (#72083)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72083

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64479

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D33862408

Pulled By: frank-wei

fbshipit-source-id: eb291d59144e2ddc566d8c1491fe09b5b3f53fb0
(cherry picked from commit 11c384049d)
Author: mingfeima, 2022-02-10 11:26:14 -08:00 (committed by PyTorch MergeBot)
Parent: 4d01789f69
Commit: 75e769449d
3 changed files with 117 additions and 0 deletions


@@ -283,4 +283,32 @@ void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vo
});
}
// When the reduction is over the innermost dimension (dim 0 in TensorIterator)
// and the input's innermost dimension is contiguous,
// `binary_kernel_reduce_lastdim` can be used.
static inline bool is_reduce_lastdim(TensorIteratorBase& iter) {
return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0)
&& iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1);
}
template <typename reduce_func_t>
void binary_kernel_reduce_lastdim(TensorIteratorBase& iter, reduce_func_t reduce_op) {
auto shape = iter.shape();
int64_t dim_size = shape[0];
int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / dim_size);
TensorIterator sub_iter(iter);
// create a sub-iterator to parallelize over all non-reduced dims
sub_iter.narrow(0, 0, 1);
auto loop = [&](char** data, const int64_t* strides, int64_t size) {
char* out = data[0];
char* in = data[1];
for (int64_t i = 0; i < size; ++i) {
reduce_op(out, in, dim_size);
out += strides[0];
in += strides[1];
}
};
sub_iter.for_each(loop, grain_size);
}
}}} // namespace at::native::<anonymous>
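For orientation, a minimal Python sketch of the case `is_reduce_lastdim` detects: a single reduced dimension that is also the contiguous innermost dimension of the input (the shape below is illustrative, not from this PR):

import torch

x = torch.randn(3, 5, 20)        # innermost dim is contiguous
y = torch.argmax(x, dim=-1)      # exactly one reduce dim, and it is the last dim
print(y.shape)                   # torch.Size([3, 5])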


@@ -4,6 +4,7 @@
#include <ATen/Dispatch.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
@@ -191,6 +192,19 @@ static void prod_kernel_impl(TensorIterator& iter) {
}
}
template <typename scalar_t, typename acc_t>
inline void norm_two_reduce_step(Vectorized<acc_t>& acc_vec, Vectorized<scalar_t>& data_vec) {
acc_vec += data_vec * data_vec;
}
template <>
inline void norm_two_reduce_step(Vectorized<float>& acc_fvec, Vectorized<BFloat16>& data_bvec) {
Vectorized<float> data_fvec0, data_fvec1;
std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
acc_fvec += data_fvec0 * data_fvec0;
acc_fvec += data_fvec1 * data_fvec1;
}
static void norm_kernel_tensor_iterator_impl(
TensorIterator& iter,
const Scalar& p) {
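A rough Python analogue of the BFloat16 specialization of `norm_two_reduce_step` above: squares are accumulated in float32 rather than in bfloat16, and only the final result is cast back (sizes and tolerances are illustrative):

import torch

x = torch.randn(4, 64, dtype=torch.bfloat16)
acc = (x.float() * x.float()).sum(dim=-1)     # accumulate squares in float32
ref = acc.sqrt().to(torch.bfloat16)           # cast only the final result back to bfloat16
print(torch.allclose(ref, torch.norm(x, p=2, dim=-1), atol=1e-2, rtol=1e-2))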
@@ -205,6 +219,9 @@ static void norm_kernel_tensor_iterator_impl(
AT_ERROR("norm_kernel_tensor_iterator_impl expects norm to be integer or float");
}
bool use_fast_path = is_reduce_lastdim(iter) && iter.dtype(0) == iter.input_dtype()
&& (iter.input_dtype() == kFloat || iter.input_dtype() == kBFloat16);
// In the dispatch code blocks below, reduction kernels accumulate results as
// the type `acc_t`. When `scalar_t` is complex, `acc_t` is the downgraded
// real number type. Otherwise, `acc_t` and `scalar_t` are the same type.
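To illustrate the `acc_t` comment above: for a complex input the 2-norm is accumulated and returned in the corresponding real dtype (the dtype below is chosen only for illustration):

import torch

z = torch.randn(8, dtype=torch.complex64)
print(torch.norm(z).dtype)   # torch.float32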
@@ -227,6 +244,36 @@ static void norm_kernel_tensor_iterator_impl(
);
});
} else if (val == 2) {
if (use_fast_path) {
AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.input_dtype(), "norm_cpu", [&] {
// use float as accumulate type for BFloat16
using acc_t = vec_scalar_t<scalar_t>;
binary_kernel_reduce_lastdim(iter, [](char* result_data_bytes, char* self_data_bytes, int64_t size) {
scalar_t* result_data = (scalar_t*)result_data_bytes;
scalar_t* self_data = (scalar_t*)self_data_bytes;
using Vec = Vectorized<scalar_t>;
using fVec = Vectorized<acc_t>;
fVec acc_vec{acc_t(0)};
acc_t buffer[fVec::size()];
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec data_vec = Vec::loadu(self_data + d);
norm_two_reduce_step(acc_vec, data_vec);
}
acc_vec.store(buffer);
for (int j = 1; j < fVec::size(); j++) {
buffer[0] = buffer[0] + buffer[j];
}
for (; d < size; d++) {
acc_t data_val = acc_t(self_data[d]);
buffer[0] += data_val * data_val;
}
result_data[0] = scalar_t(std::sqrt(buffer[0]));
});
});
return;
}
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] {
using acc_t = typename scalar_value_type<scalar_t>::type;
binary_kernel_reduce(
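A quick sanity check of the p=2 last-dim fast path above against a plain sqrt-of-sum-of-squares reference (shape and tolerance are illustrative):

import torch

x = torch.randn(3, 5, 40)                 # contiguous last dim, so the fast path is eligible
fast = torch.norm(x, p=2, dim=-1)
ref = (x * x).sum(dim=-1).sqrt()
print(torch.allclose(fast, ref, atol=1e-6))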
@@ -396,6 +443,21 @@ static void max_values_kernel_impl(TensorIterator& iter) {
static void argmax_kernel_impl(TensorIterator &iter) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(1), "argmax_cpu", [&] {
if (is_reduce_lastdim(iter)) {
using arg_t = std::pair<scalar_t, int64_t>;
auto op = ArgMaxOps<scalar_t>{};
binary_kernel_reduce_lastdim(iter, [&](char* result_data_bytes, char* self_data_bytes, int64_t size) {
int64_t* result_data = (int64_t*)result_data_bytes;
scalar_t* self_data = (scalar_t*)self_data_bytes;
arg_t acc = arg_t(lower_bound<scalar_t>(), 0);
for (int64_t i = 0; i < size; i++) {
acc = op.reduce(acc, self_data[i], i);
}
result_data[0] = acc.second;
});
return;
}
binary_kernel_reduce(
iter,
ArgMaxOps<scalar_t>{},
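Similarly for argmax: a strided input falls back to the generic `binary_kernel_reduce` path, so comparing it against a contiguous copy exercises both paths (this mirrors the test added below; sizes are illustrative):

import torch

x = torch.randn(3, 5, 40)[:, :, ::2]    # non-contiguous last dim, generic path
xc = x.contiguous()                      # contiguous last dim, fast path
print(torch.equal(torch.argmax(x, dim=-1), torch.argmax(xc, dim=-1)))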
@@ -405,6 +467,21 @@ static void argmax_kernel_impl(TensorIterator &iter) {
static void argmin_kernel_impl(TensorIterator &iter) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(1), "argmin_cpu", [&] {
if (is_reduce_lastdim(iter)) {
using arg_t = std::pair<scalar_t, int64_t>;
auto op = ArgMinOps<scalar_t>{};
binary_kernel_reduce_lastdim(iter, [&](char* result_data_bytes, char* self_data_bytes, int64_t size) {
int64_t* result_data = (int64_t*)result_data_bytes;
scalar_t* self_data = (scalar_t*)self_data_bytes;
arg_t acc = arg_t(upper_bound<scalar_t>(), 0);
for (int64_t i = 0; i < size; i++) {
acc = op.reduce(acc, self_data[i], i);
}
result_data[0] = acc.second;
});
return;
}
binary_kernel_reduce(
iter,
ArgMinOps<scalar_t>{},


@@ -459,6 +459,18 @@ class TestReductions(TestCase):
with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
op(x, -1)
@onlyCPU
@dtypes(torch.float, torch.bfloat16)
def test_dim_reduction_lastdim(self, device, dtype):
x = torch.randn(3, 5, 40, device=device, dtype=dtype)
x = x[:, :, 0:40:2]
x2 = x.contiguous()
ops = [torch.norm, torch.argmax, torch.argmin]
for op in ops:
y = op(x, dim=-1)
y2 = op(x2, dim=-1)
self.assertEqual(y, y2)
@skipIfNoSciPy
def test_logsumexp(self, device):
from scipy.special import logsumexp
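
To gauge the impact of these fast paths, a rough timing sketch with torch.utils.benchmark (tensor sizes are illustrative; this is not a benchmark from the PR):

import torch
from torch.utils import benchmark

x = torch.randn(1024, 256)   # reduce over the contiguous last dim
t = benchmark.Timer(stmt="torch.norm(x, p=2, dim=-1)",
                    globals={"torch": torch, "x": x})
print(t.blocked_autorange())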