optimize dim reduce performance on norm, argmax and argmin (#72083)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72083
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64479
Test Plan: Imported from OSS
Reviewed By: VitalyFedyunin
Differential Revision: D33862408
Pulled By: frank-wei
fbshipit-source-id: eb291d59144e2ddc566d8c1491fe09b5b3f53fb0
(cherry picked from commit 11c384049d)
This commit is contained in:
parent 4d01789f69
commit 75e769449d
@@ -283,4 +283,32 @@ void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop,
  });
}

// When the reduction is over the innermost dimension (dim 0 in TensorIterator)
// and the input's innermost dimension is contiguous, `binary_kernel_reduce_lastdim`
// can be used.
static inline bool is_reduce_lastdim(TensorIteratorBase& iter) {
  return iter.num_reduce_dims() == 1 && iter.is_dim_reduced(0)
      && iter.ninputs() == 1 && iter.strides(1)[0] == iter.element_size(1);
}

template <typename reduce_func_t>
void binary_kernel_reduce_lastdim(TensorIteratorBase& iter, reduce_func_t reduce_op) {
  auto shape = iter.shape();
  int64_t dim_size = shape[0];
  int64_t grain_size = std::max((int64_t) 1, at::internal::GRAIN_SIZE / dim_size);
  TensorIterator sub_iter(iter);
  // create a sub-iterator to parallelize over all non-reduce dims
  sub_iter.narrow(0, 0, 1);
  auto loop = [&](char** data, const int64_t* strides, int64_t size) {
    char* out = data[0];
    char* in = data[1];
    for (int64_t i = 0; i < size; ++i) {
      reduce_op(out, in, dim_size);
      out += strides[0];
      in += strides[1];
    }
  };
  sub_iter.for_each(loop, grain_size);
}

}}} // namespace at::native::<anonymous>
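Note (not part of this commit): binary_kernel_reduce_lastdim only walks the non-reduced elements; the caller's reduce_op receives a raw output pointer, a raw input pointer, and the length of the contiguous last dimension. A minimal standalone sketch of that calling pattern, using a plain row-wise sum over a row-major array instead of TensorIterator (all names below are illustrative, not ATen APIs):

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative stand-in for a reduce_op passed to a last-dim reducer:
// it gets raw byte pointers plus the length of the contiguous last dim.
static void sum_lastdim(char* out, const char* in, int64_t dim_size) {
  float acc = 0.f;
  const float* src = reinterpret_cast<const float*>(in);
  for (int64_t i = 0; i < dim_size; ++i) {
    acc += src[i];
  }
  *reinterpret_cast<float*>(out) = acc;
}

int main() {
  // 3 "rows" (non-reduced elements), contiguous last dim of size 4.
  const int64_t rows = 3, dim_size = 4;
  std::vector<float> input = {1, 2, 3, 4,  5, 6, 7, 8,  9, 10, 11, 12};
  std::vector<float> output(rows, 0.f);

  // The role played by the sub-iterator loop: walk the non-reduced dims,
  // advancing the output by its stride and the input by one contiguous row.
  char* out = reinterpret_cast<char*>(output.data());
  char* in = reinterpret_cast<char*>(input.data());
  for (int64_t i = 0; i < rows; ++i) {
    sum_lastdim(out, in, dim_size);
    out += sizeof(float);
    in += dim_size * sizeof(float);
  }

  for (float v : output) std::cout << v << "\n";  // 10 26 42
}

The kernels added below (norm, argmax, argmin) plug their own reduce_op into exactly this slot.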
@@ -4,6 +4,7 @@

#include <ATen/Dispatch.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
@@ -191,6 +192,19 @@ static void prod_kernel_impl(TensorIterator& iter) {
  }
}

template <typename scalar_t, typename acc_t>
inline void norm_two_reduce_step(Vectorized<acc_t>& acc_vec, Vectorized<scalar_t>& data_vec) {
  acc_vec += data_vec * data_vec;
}

template <>
inline void norm_two_reduce_step(Vectorized<float>& acc_fvec, Vectorized<BFloat16>& data_bvec) {
  Vectorized<float> data_fvec0, data_fvec1;
  std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
  acc_fvec += data_fvec0 * data_fvec0;
  acc_fvec += data_fvec1 * data_fvec1;
}

static void norm_kernel_tensor_iterator_impl(
    TensorIterator& iter,
    const Scalar& p) {
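Note (not part of this commit): convert_bfloat16_float widens a BFloat16 vector into two float vectors so the squares can be accumulated in float. Per element this amounts to placing the 16 BFloat16 bits in the upper half of a float32. A standalone scalar sketch of that widening-then-accumulate step (illustrative only):

#include <cstdint>
#include <cstring>
#include <iostream>

// A bfloat16 value is the upper 16 bits of an IEEE-754 float32.
static float bf16_bits_to_float(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Accumulate the sum of squares in float, as the BFloat16 specialization
// above does lane-wise after the conversion.
static float sum_of_squares_bf16(const uint16_t* data, int64_t size) {
  float acc = 0.f;
  for (int64_t i = 0; i < size; ++i) {
    float x = bf16_bits_to_float(data[i]);
    acc += x * x;
  }
  return acc;
}

int main() {
  // 0x3F80 is 1.0f and 0x4000 is 2.0f in bfloat16 encoding
  uint16_t vals[] = {0x3F80, 0x4000};
  std::cout << sum_of_squares_bf16(vals, 2) << "\n";  // prints 5
}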
@@ -205,6 +219,9 @@ static void norm_kernel_tensor_iterator_impl(
    AT_ERROR("norm_kernel_tensor_iterator_impl expects norm to be integer or float");
  }

  bool use_fast_path = is_reduce_lastdim(iter) && iter.dtype(0) == iter.input_dtype()
      && (iter.input_dtype() == kFloat || iter.input_dtype() == kBFloat16);

  // In the dispatch code blocks below, reduction kernels accumulate results as
  // the type `acc_t`. When `scalar_t` is complex, `acc_t` is the downgraded
  // real number type. Otherwise, `acc_t` and `scalar_t` are the same type.
@@ -227,6 +244,36 @@ static void norm_kernel_tensor_iterator_impl(
      );
    });
  } else if (val == 2) {
    if (use_fast_path) {
      AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.input_dtype(), "norm_cpu", [&] {
        // use float as the accumulation type for BFloat16
        using acc_t = vec_scalar_t<scalar_t>;
        binary_kernel_reduce_lastdim(iter, [](char* result_data_bytes, char* self_data_bytes, int64_t size) {
          scalar_t* result_data = (scalar_t*)result_data_bytes;
          scalar_t* self_data = (scalar_t*)self_data_bytes;

          using Vec = Vectorized<scalar_t>;
          using fVec = Vectorized<acc_t>;
          fVec acc_vec{acc_t(0)};
          acc_t buffer[fVec::size()];
          int64_t d = 0;
          for (; d < size - (size % Vec::size()); d += Vec::size()) {
            Vec data_vec = Vec::loadu(self_data + d);
            norm_two_reduce_step(acc_vec, data_vec);
          }
          acc_vec.store(buffer);
          for (int j = 1; j < fVec::size(); j++) {
            buffer[0] = buffer[0] + buffer[j];
          }
          for (; d < size; d++) {
            acc_t data_val = acc_t(self_data[d]);
            buffer[0] += data_val * data_val;
          }
          result_data[0] = scalar_t(std::sqrt(buffer[0]));
        });
      });
      return;
    }
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "norm_cpu", [&] {
      using acc_t = typename scalar_value_type<scalar_t>::type;
      binary_kernel_reduce(
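Note (not part of this commit): the fast path above follows a standard vectorization pattern: accumulate squares lane-wise in the main loop, do one horizontal sum across lanes at the end, and finish the remainder that is not a multiple of the vector width in scalar code. A standalone sketch of that structure, with a fixed-size lane array standing in for Vectorized (illustrative only):

#include <array>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative L2 norm over a contiguous buffer, mirroring the kernel's shape:
// blocked accumulation, one horizontal sum, then a scalar tail.
static float l2_norm_lastdim(const float* data, int64_t size) {
  constexpr int64_t kLanes = 8;              // stand-in for Vec::size()
  std::array<float, kLanes> acc{};           // stand-in for the vector accumulator

  int64_t d = 0;
  for (; d < size - (size % kLanes); d += kLanes) {
    for (int64_t j = 0; j < kLanes; ++j) {   // lane-wise sum of squares
      acc[j] += data[d + j] * data[d + j];
    }
  }
  for (int64_t j = 1; j < kLanes; ++j) {     // horizontal sum across lanes
    acc[0] += acc[j];
  }
  for (; d < size; ++d) {                    // scalar tail
    acc[0] += data[d] * data[d];
  }
  return std::sqrt(acc[0]);
}

int main() {
  std::vector<float> v(20, 1.f);             // ||v||_2 = sqrt(20)
  std::cout << l2_norm_lastdim(v.data(), (int64_t)v.size()) << "\n";
}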
@@ -396,6 +443,21 @@ static void max_values_kernel_impl(TensorIterator& iter) {

static void argmax_kernel_impl(TensorIterator &iter) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(1), "argmax_cpu", [&] {
    if (is_reduce_lastdim(iter)) {
      using arg_t = std::pair<scalar_t, int64_t>;
      auto op = ArgMaxOps<scalar_t>{};
      binary_kernel_reduce_lastdim(iter, [&](char* result_data_bytes, char* self_data_bytes, int64_t size) {
        int64_t* result_data = (int64_t*)result_data_bytes;
        scalar_t* self_data = (scalar_t*)self_data_bytes;

        arg_t acc = arg_t(lower_bound<scalar_t>(), 0);
        for (int64_t i = 0; i < size; i++) {
          acc = op.reduce(acc, self_data[i], i);
        }
        result_data[0] = acc.second;
      });
      return;
    }
    binary_kernel_reduce(
      iter,
      ArgMaxOps<scalar_t>{},
@@ -405,6 +467,21 @@ static void argmax_kernel_impl(TensorIterator &iter) {

static void argmin_kernel_impl(TensorIterator &iter) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(1), "argmin_cpu", [&] {
    if (is_reduce_lastdim(iter)) {
      using arg_t = std::pair<scalar_t, int64_t>;
      auto op = ArgMinOps<scalar_t>{};
      binary_kernel_reduce_lastdim(iter, [&](char* result_data_bytes, char* self_data_bytes, int64_t size) {
        int64_t* result_data = (int64_t*)result_data_bytes;
        scalar_t* self_data = (scalar_t*)self_data_bytes;

        arg_t acc = arg_t(upper_bound<scalar_t>(), 0);
        for (int64_t i = 0; i < size; i++) {
          acc = op.reduce(acc, self_data[i], i);
        }
        result_data[0] = acc.second;
      });
      return;
    }
    binary_kernel_reduce(
      iter,
      ArgMinOps<scalar_t>{},
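Note (not part of this commit): the argmax and argmin fast paths keep a running (best value, index) pair and fold in each element of the contiguous last dimension; only the index is written to the int64 result. A standalone sketch of that accumulator, with a plain comparison standing in for ArgMaxOps::reduce (illustrative only; the argmin path mirrors it with < and upper_bound):

#include <cstdint>
#include <iostream>
#include <limits>
#include <utility>
#include <vector>

// Illustrative argmax over a contiguous run, mirroring the (value, index)
// accumulator used by the last-dim fast path.
static int64_t argmax_lastdim(const float* data, int64_t size) {
  // start from the lowest possible value, index 0 (cf. lower_bound<scalar_t>())
  std::pair<float, int64_t> acc{-std::numeric_limits<float>::infinity(), 0};
  for (int64_t i = 0; i < size; ++i) {
    // stand-in for op.reduce(acc, data[i], i): keep the larger value and its index
    if (data[i] > acc.first) {
      acc = {data[i], i};
    }
  }
  return acc.second;  // only the index goes into the int64 result tensor
}

int main() {
  std::vector<float> row = {0.5f, -1.0f, 3.25f, 2.0f};
  std::cout << argmax_lastdim(row.data(), (int64_t)row.size()) << "\n";  // prints 2
}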
@@ -459,6 +459,18 @@ class TestReductions(TestCase):
        with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
            op(x, -1)

    @onlyCPU
    @dtypes(torch.float, torch.bfloat16)
    def test_dim_reduction_lastdim(self, device, dtype):
        x = torch.randn(3, 5, 40, device=device, dtype=dtype)
        x = x[:, :, 0:40:2]
        x2 = x.contiguous()
        ops = [torch.norm, torch.argmax, torch.argmin]
        for op in ops:
            y = op(x, dim=-1)
            y2 = op(x2, dim=-1)
            self.assertEqual(y, y2)

    @skipIfNoSciPy
    def test_logsumexp(self, device):
        from scipy.special import logsumexp