mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Extract CPU log_softmax kernels to header (#156243)
This allows sharing them with ExecuTorch. Differential Revision: [D76830114](https://our.internmc.facebook.com/intern/diff/D76830114/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/156243 Approved by: https://github.com/janeyx99
This commit is contained in:
parent
96e4c95cd8
commit
c82a174cea
337
aten/src/ATen/native/cpu/LogSoftmaxKernelImpl.h
Normal file
337
aten/src/ATen/native/cpu/LogSoftmaxKernelImpl.h
Normal file
|
|
@ -0,0 +1,337 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <ATen/OpMathType.h>
|
||||||
|
#include <ATen/Parallel.h>
|
||||||
|
#include <ATen/cpu/vec/functional.h>
|
||||||
|
#include <ATen/cpu/vec/vec.h>
|
||||||
|
#include <c10/util/irange.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <limits>
|
||||||
|
#include <memory>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
namespace at::native {
|
||||||
|
inline namespace CPU_CAPABILITY {
|
||||||
|
template <typename scalar_t>
|
||||||
|
int64_t vec_log_softmax_lastdim_chunk_size(int64_t grain_size, int64_t outer_size, int64_t dim_size) {
|
||||||
|
// Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the
|
||||||
|
// size of L1D cache on many processors. Some processors have 48 KB L1D cache
|
||||||
|
// nowadays, so maybe in the future, we can leverage the knowledge of a
|
||||||
|
// machine's L1D cache size.
|
||||||
|
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
|
||||||
|
1,
|
||||||
|
grain_size / (sizeof(scalar_t) * dim_size));
|
||||||
|
return std::min<int64_t>(MAX_CHUNK_SIZE, outer_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
void serial_vec_log_softmax_lastdim_range(
|
||||||
|
const scalar_t* input_data_base,
|
||||||
|
scalar_t* output_data_base,
|
||||||
|
int64_t dim_size,
|
||||||
|
int64_t chunk_size,
|
||||||
|
int64_t begin,
|
||||||
|
int64_t end) {
|
||||||
|
if (end <= begin) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
|
||||||
|
// MSVC requires such a declaration of dynamic arrays
|
||||||
|
// Source: https://stackoverflow.com/a/33423538
|
||||||
|
auto tmp_sum_scalar = std::make_unique<scalar_t[]>(chunk_size);
|
||||||
|
auto max_input_arr = std::make_unique<scalar_t[]>(chunk_size);
|
||||||
|
for (int64_t ii = begin; ii < end; ii += chunk_size) {
|
||||||
|
int64_t loop_end = chunk_size;
|
||||||
|
if (ii + chunk_size > end) {
|
||||||
|
loop_end = end - ii;
|
||||||
|
}
|
||||||
|
for (const auto j : c10::irange(loop_end)) {
|
||||||
|
int64_t i = ii + j;
|
||||||
|
const scalar_t* input_data = input_data_base + i * dim_size;
|
||||||
|
max_input_arr[j] = vec::reduce_all<scalar_t>(
|
||||||
|
[](Vec& x, Vec& y) { return vec::maximum(x, y); },
|
||||||
|
input_data,
|
||||||
|
dim_size);
|
||||||
|
}
|
||||||
|
for (const auto j : c10::irange(loop_end)) {
|
||||||
|
int64_t i = ii + j;
|
||||||
|
const scalar_t* input_data = input_data_base + i * dim_size;
|
||||||
|
scalar_t max_input = max_input_arr[j];
|
||||||
|
tmp_sum_scalar[j] = vec::map_reduce_all<scalar_t>(
|
||||||
|
[max_input](Vec x) { return (x - Vec(max_input)).exp(); },
|
||||||
|
[](Vec x, Vec y) { return x + y; },
|
||||||
|
input_data,
|
||||||
|
dim_size);
|
||||||
|
}
|
||||||
|
// See [Note AVX-SSE transitions] for why this should call the
|
||||||
|
// vectorized version (aside from perf improvements).
|
||||||
|
vec::map(
|
||||||
|
[](Vec x) { return x.log(); },
|
||||||
|
tmp_sum_scalar.get(),
|
||||||
|
tmp_sum_scalar.get(),
|
||||||
|
loop_end);
|
||||||
|
for (const auto j : c10::irange(loop_end)) {
|
||||||
|
int64_t i = ii + j;
|
||||||
|
const scalar_t* input_data = input_data_base + i * dim_size;
|
||||||
|
scalar_t* output_data = output_data_base + i * dim_size;
|
||||||
|
scalar_t tmp_sum = tmp_sum_scalar[j];
|
||||||
|
scalar_t max_input = max_input_arr[j];
|
||||||
|
|
||||||
|
// It's necessary to keep the order of the operations below.
|
||||||
|
// In some cases that input is large digits and the difference
|
||||||
|
// is small, if we compute `max_input` plus `tmp_sum` before,
|
||||||
|
// there would be a numerical problem. See an example in
|
||||||
|
// https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
|
||||||
|
vec::map(
|
||||||
|
[tmp_sum, max_input](Vec x) {
|
||||||
|
return x - Vec(max_input) - Vec(tmp_sum);
|
||||||
|
},
|
||||||
|
output_data,
|
||||||
|
input_data,
|
||||||
|
dim_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can't include ATen/Parallel.h.
|
||||||
|
// TODO: find a way to have only one copy of divup.
|
||||||
|
inline int64_t divup(int64_t x, int64_t y) {
|
||||||
|
return (x + y - 1) / y;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, int64_t BLOCK_SIZE = 128 * 1024>
|
||||||
|
std::pair<int64_t,int64_t> vec_logsoftmax_chunk_size_and_num_chunks(int64_t inner_size, int64_t dim_size) {
|
||||||
|
using Vec = vec::Vectorized<scalar_t>;
|
||||||
|
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
|
||||||
|
MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
|
||||||
|
int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
|
||||||
|
int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
|
||||||
|
return {CHUNK_SIZE, num_chunks};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
|
||||||
|
serial_vec_logsoftmax_range(
|
||||||
|
const scalar_t* input_data_base,
|
||||||
|
scalar_t* output_data_base,
|
||||||
|
int64_t inner_size,
|
||||||
|
int64_t chunk_size,
|
||||||
|
int64_t num_chunks,
|
||||||
|
int64_t dim_size,
|
||||||
|
int64_t begin,
|
||||||
|
int64_t end) {
|
||||||
|
using Vec = vec::Vectorized<scalar_t>;
|
||||||
|
// thread local temp buffer which holds vertical reduction result: max and sum.
|
||||||
|
auto buffer = std::make_unique<scalar_t []>(chunk_size * 2);
|
||||||
|
scalar_t* input_max_data = buffer.get();
|
||||||
|
scalar_t* tmp_sum_data = buffer.get() + chunk_size;
|
||||||
|
|
||||||
|
for (int64_t i = begin; i < end; i++) {
|
||||||
|
int64_t outer_idx = i / num_chunks;
|
||||||
|
int64_t k = i % num_chunks;
|
||||||
|
int64_t inner_idx_begin = k * chunk_size;
|
||||||
|
int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);
|
||||||
|
|
||||||
|
// init
|
||||||
|
Vec zero_vec = Vec(scalar_t(0));
|
||||||
|
Vec min_vec = Vec(-std::numeric_limits<scalar_t>::infinity());
|
||||||
|
int64_t d0 = 0;
|
||||||
|
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
|
||||||
|
min_vec.store(input_max_data + d0);
|
||||||
|
zero_vec.store(tmp_sum_data + d0);
|
||||||
|
}
|
||||||
|
for (; d0 < size; d0++) {
|
||||||
|
input_max_data[d0] = -std::numeric_limits<scalar_t>::infinity();
|
||||||
|
tmp_sum_data[d0] = scalar_t(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute max
|
||||||
|
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
||||||
|
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
|
||||||
|
+ dim_idx * inner_size + inner_idx_begin;
|
||||||
|
|
||||||
|
int64_t d1 = 0;
|
||||||
|
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
|
||||||
|
Vec data_vec = Vec::loadu(input_ptr + d1);
|
||||||
|
Vec max_vec = Vec::loadu(input_max_data + d1);
|
||||||
|
max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec);
|
||||||
|
max_vec.store(input_max_data + d1);
|
||||||
|
}
|
||||||
|
for (; d1 < size; d1++) {
|
||||||
|
scalar_t data_val = input_ptr[d1];
|
||||||
|
scalar_t max_val = input_max_data[d1];
|
||||||
|
input_max_data[d1] = data_val > max_val ? data_val : max_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute sum of (x - max).exp()
|
||||||
|
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
||||||
|
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
|
||||||
|
+ dim_idx * inner_size + inner_idx_begin;
|
||||||
|
|
||||||
|
int64_t d2 = 0;
|
||||||
|
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
|
||||||
|
Vec data_vec = Vec::loadu(input_ptr + d2);
|
||||||
|
Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
|
||||||
|
Vec max_vec = Vec::loadu(input_max_data + d2);
|
||||||
|
sum_vec += (data_vec - max_vec).exp();
|
||||||
|
sum_vec.store(tmp_sum_data + d2);
|
||||||
|
}
|
||||||
|
for (; d2 < size; d2++) {
|
||||||
|
scalar_t data_val = input_ptr[d2];
|
||||||
|
scalar_t max_val = input_max_data[d2];
|
||||||
|
tmp_sum_data[d2] += std::exp(data_val - max_val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply log
|
||||||
|
vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
|
||||||
|
|
||||||
|
// compute x - max - sum
|
||||||
|
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
||||||
|
int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin;
|
||||||
|
const scalar_t* input_ptr = input_data_base + offset;
|
||||||
|
scalar_t* output_ptr = output_data_base + offset;
|
||||||
|
|
||||||
|
int64_t d3 = 0;
|
||||||
|
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
|
||||||
|
Vec data_vec = Vec::loadu(input_ptr + d3);
|
||||||
|
Vec max_vec = Vec::loadu(input_max_data + d3);
|
||||||
|
Vec sum_vec = Vec::loadu(tmp_sum_data + d3);
|
||||||
|
Vec out_vec = data_vec - max_vec - sum_vec;
|
||||||
|
out_vec.store(output_ptr + d3);
|
||||||
|
}
|
||||||
|
for (; d3 < size; d3++) {
|
||||||
|
output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
|
||||||
|
serial_vec_logsoftmax_range(
|
||||||
|
const scalar_t* input_data_base,
|
||||||
|
scalar_t* output_data_base,
|
||||||
|
int64_t inner_size,
|
||||||
|
int64_t chunk_size,
|
||||||
|
int64_t num_chunks,
|
||||||
|
int64_t dim_size,
|
||||||
|
int64_t begin,
|
||||||
|
int64_t end) {
|
||||||
|
using Vec = vec::Vectorized<scalar_t>;
|
||||||
|
using fVec = vec::Vectorized<float>;
|
||||||
|
auto buffer = std::make_unique<float []>(chunk_size * 2);
|
||||||
|
float* input_max_data = buffer.get();
|
||||||
|
float* tmp_sum_data = buffer.get() + chunk_size;
|
||||||
|
|
||||||
|
// thread local buffer that holds input data in float32 to save next 2 dtype conversion
|
||||||
|
auto input_buffer = std::make_unique<float []>(dim_size * chunk_size);
|
||||||
|
float* input_buffer_data = input_buffer.get();
|
||||||
|
|
||||||
|
// init
|
||||||
|
for (int64_t i = begin; i < end; i++) {
|
||||||
|
int64_t outer_idx = i / num_chunks;
|
||||||
|
int64_t k = i % num_chunks;
|
||||||
|
int64_t inner_idx_begin = k * chunk_size;
|
||||||
|
int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);
|
||||||
|
|
||||||
|
fVec zero_fvec = fVec(float(0));
|
||||||
|
fVec min_fvec = fVec(-std::numeric_limits<float>::infinity());
|
||||||
|
int64_t d0 = 0;
|
||||||
|
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
|
||||||
|
min_fvec.store(input_max_data + d0);
|
||||||
|
min_fvec.store(input_max_data + d0 + fVec::size());
|
||||||
|
zero_fvec.store(tmp_sum_data + d0);
|
||||||
|
zero_fvec.store(tmp_sum_data + d0 + fVec::size());
|
||||||
|
}
|
||||||
|
for (; d0 < size; d0++) {
|
||||||
|
input_max_data[d0] = -std::numeric_limits<float>::infinity();
|
||||||
|
tmp_sum_data[d0] = float(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute max
|
||||||
|
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
||||||
|
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
|
||||||
|
+ dim_idx * inner_size + inner_idx_begin;
|
||||||
|
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
|
||||||
|
|
||||||
|
int64_t d1 = 0;
|
||||||
|
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
|
||||||
|
Vec data_vec = Vec::loadu(input_ptr + d1);
|
||||||
|
auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
|
||||||
|
fVec max_fvec0 = fVec::loadu(input_max_data + d1);
|
||||||
|
fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size());
|
||||||
|
max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0);
|
||||||
|
max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1);
|
||||||
|
max_fvec0.store(input_max_data + d1);
|
||||||
|
max_fvec1.store(input_max_data + d1 + fVec::size());
|
||||||
|
|
||||||
|
// cache the 'converted' float input
|
||||||
|
data_fvec0.store(input_buffer_ptr + d1);
|
||||||
|
data_fvec1.store(input_buffer_ptr + d1 + fVec::size());
|
||||||
|
}
|
||||||
|
for (; d1 < size; d1++) {
|
||||||
|
float data_val = float(input_ptr[d1]);
|
||||||
|
float max_val = input_max_data[d1];
|
||||||
|
input_max_data[d1] = data_val > max_val ? data_val : max_val;
|
||||||
|
input_buffer_ptr[d1] = data_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute sum of (x - max).exp()
|
||||||
|
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
||||||
|
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
|
||||||
|
|
||||||
|
int64_t d2 = 0;
|
||||||
|
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
|
||||||
|
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2);
|
||||||
|
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size());
|
||||||
|
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
|
||||||
|
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
|
||||||
|
fVec max_fvec0 = fVec::loadu(input_max_data + d2);
|
||||||
|
fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size());
|
||||||
|
sum_fvec0 += (data_fvec0 - max_fvec0).exp();
|
||||||
|
sum_fvec1 += (data_fvec1 - max_fvec1).exp();
|
||||||
|
sum_fvec0.store(tmp_sum_data + d2);
|
||||||
|
sum_fvec1.store(tmp_sum_data + d2 + fVec::size());
|
||||||
|
}
|
||||||
|
for (; d2 < size; d2++) {
|
||||||
|
float data_val = input_buffer_ptr[d2];
|
||||||
|
float max_val = input_max_data[d2];
|
||||||
|
tmp_sum_data[d2] += std::exp(data_val - max_val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply log
|
||||||
|
vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
|
||||||
|
|
||||||
|
// compute x - max - sum
|
||||||
|
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
||||||
|
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
|
||||||
|
scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size
|
||||||
|
+ dim_idx * inner_size + inner_idx_begin;
|
||||||
|
|
||||||
|
int64_t d3 = 0;
|
||||||
|
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
|
||||||
|
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3);
|
||||||
|
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size());
|
||||||
|
fVec max_fvec0 = fVec::loadu(input_max_data + d3);
|
||||||
|
fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size());
|
||||||
|
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3);
|
||||||
|
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size());
|
||||||
|
fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0;
|
||||||
|
fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1;
|
||||||
|
Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
|
||||||
|
out_vec.store(output_ptr + d3);
|
||||||
|
}
|
||||||
|
for (; d3 < size; d3++) {
|
||||||
|
output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace CPU_CAPABILITY
|
||||||
|
}} // namespace at::native
|
||||||
|
|
@ -2,6 +2,8 @@
|
||||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||||
#include <ATen/native/cpu/SoftmaxKernel.h>
|
#include <ATen/native/cpu/SoftmaxKernel.h>
|
||||||
|
|
||||||
|
#include <ATen/native/cpu/LogSoftmaxKernelImpl.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
@ -28,7 +30,6 @@
|
||||||
// We use a chunk size such that it'd fit in L1D.
|
// We use a chunk size such that it'd fit in L1D.
|
||||||
|
|
||||||
namespace at::native {
|
namespace at::native {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
inline void _vec_log_softmax_lastdim(
|
inline void _vec_log_softmax_lastdim(
|
||||||
|
|
@ -36,15 +37,10 @@ inline void _vec_log_softmax_lastdim(
|
||||||
scalar_t* output_data_base,
|
scalar_t* output_data_base,
|
||||||
int64_t outer_size,
|
int64_t outer_size,
|
||||||
int64_t dim_size) {
|
int64_t dim_size) {
|
||||||
using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
|
const auto chunk_size = vec_log_softmax_lastdim_chunk_size<scalar_t>(
|
||||||
// Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the
|
at::internal::GRAIN_SIZE,
|
||||||
// size of L1D cache on many processors. Some processors have 48 KB L1D cache
|
outer_size,
|
||||||
// nowadays, so maybe in the future, we can leverage the knowledge of a
|
dim_size);
|
||||||
// machine's L1D cache size.
|
|
||||||
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
|
|
||||||
1,
|
|
||||||
at::internal::GRAIN_SIZE / (sizeof(scalar_t) * dim_size));
|
|
||||||
int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, outer_size);
|
|
||||||
// Note: grain_size value of 0
|
// Note: grain_size value of 0
|
||||||
// We don't change the number of OpenMP threads in the OpenMP thread-pool,
|
// We don't change the number of OpenMP threads in the OpenMP thread-pool,
|
||||||
// so some threads do useful work, while others don't.
|
// so some threads do useful work, while others don't.
|
||||||
|
|
@ -52,60 +48,13 @@ inline void _vec_log_softmax_lastdim(
|
||||||
// work among threads in an equitable manner. We compute CHUNK_SIZE to ensure
|
// work among threads in an equitable manner. We compute CHUNK_SIZE to ensure
|
||||||
// each thread's computations would be efficient.
|
// each thread's computations would be efficient.
|
||||||
parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
|
parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
|
||||||
// MSVC requires such a declaration of dynamic arrays
|
serial_vec_log_softmax_lastdim_range(
|
||||||
// Source: https://stackoverflow.com/a/33423538
|
input_data_base,
|
||||||
auto tmp_sum_scalar = std::make_unique<scalar_t[]>(CHUNK_SIZE);
|
output_data_base,
|
||||||
auto max_input_arr = std::make_unique<scalar_t[]>(CHUNK_SIZE);
|
dim_size,
|
||||||
for (int64_t ii = begin; ii < end; ii += CHUNK_SIZE) {
|
chunk_size,
|
||||||
int64_t loop_end = CHUNK_SIZE;
|
begin,
|
||||||
if (ii + CHUNK_SIZE > end)
|
end);
|
||||||
loop_end = end - ii;
|
|
||||||
for (const auto j : c10::irange(loop_end)) {
|
|
||||||
int64_t i = ii + j;
|
|
||||||
const scalar_t* input_data = input_data_base + i * dim_size;
|
|
||||||
max_input_arr[j] = vec::reduce_all<scalar_t>(
|
|
||||||
[](Vec& x, Vec& y) { return vec::maximum(x, y); },
|
|
||||||
input_data,
|
|
||||||
dim_size);
|
|
||||||
}
|
|
||||||
for (const auto j : c10::irange(loop_end)) {
|
|
||||||
int64_t i = ii + j;
|
|
||||||
const scalar_t* input_data = input_data_base + i * dim_size;
|
|
||||||
scalar_t max_input = max_input_arr[j];
|
|
||||||
tmp_sum_scalar[j] = vec::map_reduce_all<scalar_t>(
|
|
||||||
[max_input](Vec x) { return (x - Vec(max_input)).exp(); },
|
|
||||||
[](Vec x, Vec y) { return x + y; },
|
|
||||||
input_data,
|
|
||||||
dim_size);
|
|
||||||
}
|
|
||||||
// See [Note AVX-SSE transitions] for why this should call the
|
|
||||||
// vectorized version (aside from perf improvements).
|
|
||||||
vec::map(
|
|
||||||
[](Vec x) { return x.log(); },
|
|
||||||
tmp_sum_scalar.get(),
|
|
||||||
tmp_sum_scalar.get(),
|
|
||||||
loop_end);
|
|
||||||
for (const auto j : c10::irange(loop_end)) {
|
|
||||||
int64_t i = ii + j;
|
|
||||||
const scalar_t* input_data = input_data_base + i * dim_size;
|
|
||||||
scalar_t* output_data = output_data_base + i * dim_size;
|
|
||||||
scalar_t tmp_sum = tmp_sum_scalar[j];
|
|
||||||
scalar_t max_input = max_input_arr[j];
|
|
||||||
|
|
||||||
// It's necessary to keep the order of the operations below.
|
|
||||||
// In some cases that input is large digits and the difference
|
|
||||||
// is small, if we compute `max_input` plus `tmp_sum` before,
|
|
||||||
// there would be a numerical problem. See an example in
|
|
||||||
// https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
|
|
||||||
vec::map(
|
|
||||||
[tmp_sum, max_input](Vec x) {
|
|
||||||
return x - Vec(max_input) - Vec(tmp_sum);
|
|
||||||
},
|
|
||||||
output_data,
|
|
||||||
input_data,
|
|
||||||
dim_size);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -891,100 +840,23 @@ _vec_logsoftmax(
|
||||||
int64_t outer_size,
|
int64_t outer_size,
|
||||||
int64_t inner_size,
|
int64_t inner_size,
|
||||||
int64_t dim_size) {
|
int64_t dim_size) {
|
||||||
using Vec = vec::Vectorized<scalar_t>;
|
const auto [CHUNK_SIZE_binding, num_chunks_binding] = vec_logsoftmax_chunk_size_and_num_chunks<scalar_t>(
|
||||||
int64_t BLOCK_SIZE = 128 * 1024;
|
inner_size, dim_size);
|
||||||
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
|
// Work around "capturing a structured binding is not yet supported in OpenMP".
|
||||||
MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
|
const auto CHUNK_SIZE = CHUNK_SIZE_binding;
|
||||||
int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
|
const auto num_chunks = num_chunks_binding;
|
||||||
int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
|
|
||||||
|
|
||||||
// See Note: grain_size value of 0
|
// See Note: grain_size value of 0
|
||||||
at::parallel_for(0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
|
at::parallel_for(0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
|
||||||
// thread local temp buffer which holds vertical reduction result: max and sum.
|
serial_vec_logsoftmax_range(
|
||||||
auto buffer = std::make_unique<scalar_t []>(CHUNK_SIZE * 2);
|
input_data_base,
|
||||||
scalar_t* input_max_data = buffer.get();
|
output_data_base,
|
||||||
scalar_t* tmp_sum_data = buffer.get() + CHUNK_SIZE;
|
inner_size,
|
||||||
|
CHUNK_SIZE,
|
||||||
for (int64_t i = begin; i < end; i++) {
|
num_chunks,
|
||||||
int64_t outer_idx = i / num_chunks;
|
dim_size,
|
||||||
int64_t k = i % num_chunks;
|
begin,
|
||||||
int64_t inner_idx_begin = k * CHUNK_SIZE;
|
end);
|
||||||
int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);
|
|
||||||
|
|
||||||
// init
|
|
||||||
Vec zero_vec = Vec(scalar_t(0));
|
|
||||||
Vec min_vec = Vec(-std::numeric_limits<scalar_t>::infinity());
|
|
||||||
int64_t d0 = 0;
|
|
||||||
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
|
|
||||||
min_vec.store(input_max_data + d0);
|
|
||||||
zero_vec.store(tmp_sum_data + d0);
|
|
||||||
}
|
|
||||||
for (; d0 < size; d0++) {
|
|
||||||
input_max_data[d0] = -std::numeric_limits<scalar_t>::infinity();
|
|
||||||
tmp_sum_data[d0] = scalar_t(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// compute max
|
|
||||||
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
|
||||||
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
|
|
||||||
+ dim_idx * inner_size + inner_idx_begin;
|
|
||||||
|
|
||||||
int64_t d1 = 0;
|
|
||||||
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
|
|
||||||
Vec data_vec = Vec::loadu(input_ptr + d1);
|
|
||||||
Vec max_vec = Vec::loadu(input_max_data + d1);
|
|
||||||
max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec);
|
|
||||||
max_vec.store(input_max_data + d1);
|
|
||||||
}
|
|
||||||
for (; d1 < size; d1++) {
|
|
||||||
scalar_t data_val = input_ptr[d1];
|
|
||||||
scalar_t max_val = input_max_data[d1];
|
|
||||||
input_max_data[d1] = data_val > max_val ? data_val : max_val;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// compute sum of (x - max).exp()
|
|
||||||
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
|
||||||
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
|
|
||||||
+ dim_idx * inner_size + inner_idx_begin;
|
|
||||||
|
|
||||||
int64_t d2 = 0;
|
|
||||||
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
|
|
||||||
Vec data_vec = Vec::loadu(input_ptr + d2);
|
|
||||||
Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
|
|
||||||
Vec max_vec = Vec::loadu(input_max_data + d2);
|
|
||||||
sum_vec += (data_vec - max_vec).exp();
|
|
||||||
sum_vec.store(tmp_sum_data + d2);
|
|
||||||
}
|
|
||||||
for (; d2 < size; d2++) {
|
|
||||||
scalar_t data_val = input_ptr[d2];
|
|
||||||
scalar_t max_val = input_max_data[d2];
|
|
||||||
tmp_sum_data[d2] += std::exp(data_val - max_val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// apply log
|
|
||||||
vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
|
|
||||||
|
|
||||||
// compute x - max - sum
|
|
||||||
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
|
||||||
int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin;
|
|
||||||
const scalar_t* input_ptr = input_data_base + offset;
|
|
||||||
scalar_t* output_ptr = output_data_base + offset;
|
|
||||||
|
|
||||||
int64_t d3 = 0;
|
|
||||||
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
|
|
||||||
Vec data_vec = Vec::loadu(input_ptr + d3);
|
|
||||||
Vec max_vec = Vec::loadu(input_max_data + d3);
|
|
||||||
Vec sum_vec = Vec::loadu(tmp_sum_data + d3);
|
|
||||||
Vec out_vec = data_vec - max_vec - sum_vec;
|
|
||||||
out_vec.store(output_ptr + d3);
|
|
||||||
}
|
|
||||||
for (; d3 < size; d3++) {
|
|
||||||
output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -996,125 +868,23 @@ _vec_logsoftmax(
|
||||||
int64_t outer_size,
|
int64_t outer_size,
|
||||||
int64_t inner_size,
|
int64_t inner_size,
|
||||||
int64_t dim_size) {
|
int64_t dim_size) {
|
||||||
using Vec = vec::Vectorized<scalar_t>;
|
const auto [CHUNK_SIZE_binding, num_chunks_binding] = vec_logsoftmax_chunk_size_and_num_chunks<scalar_t>(
|
||||||
using fVec = vec::Vectorized<float>;
|
inner_size, dim_size);
|
||||||
int64_t BLOCK_SIZE = 128 * 1024;
|
// Work around "capturing a structured binding is not yet supported in OpenMP".
|
||||||
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
|
const auto CHUNK_SIZE = CHUNK_SIZE_binding;
|
||||||
MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
|
const auto num_chunks = num_chunks_binding;
|
||||||
int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
|
|
||||||
int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
|
|
||||||
|
|
||||||
// See Note: grain_size value of 0
|
// See Note: grain_size value of 0
|
||||||
at::parallel_for(0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
|
at::parallel_for(0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
|
||||||
auto buffer = std::make_unique<float []>(CHUNK_SIZE * 2);
|
serial_vec_logsoftmax_range(
|
||||||
float* input_max_data = buffer.get();
|
input_data_base,
|
||||||
float* tmp_sum_data = buffer.get() + CHUNK_SIZE;
|
output_data_base,
|
||||||
|
inner_size,
|
||||||
// thread local buffer that holds input data in float32 to save next 2 dtype conversion
|
CHUNK_SIZE,
|
||||||
auto input_buffer = std::make_unique<float []>(dim_size * CHUNK_SIZE);
|
num_chunks,
|
||||||
float* input_buffer_data = input_buffer.get();
|
dim_size,
|
||||||
|
begin,
|
||||||
// init
|
end);
|
||||||
for (int64_t i = begin; i < end; i++) {
|
|
||||||
int64_t outer_idx = i / num_chunks;
|
|
||||||
int64_t k = i % num_chunks;
|
|
||||||
int64_t inner_idx_begin = k * CHUNK_SIZE;
|
|
||||||
int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);
|
|
||||||
|
|
||||||
fVec zero_fvec = fVec(float(0));
|
|
||||||
fVec min_fvec = fVec(-std::numeric_limits<float>::infinity());
|
|
||||||
int64_t d0 = 0;
|
|
||||||
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
|
|
||||||
min_fvec.store(input_max_data + d0);
|
|
||||||
min_fvec.store(input_max_data + d0 + fVec::size());
|
|
||||||
zero_fvec.store(tmp_sum_data + d0);
|
|
||||||
zero_fvec.store(tmp_sum_data + d0 + fVec::size());
|
|
||||||
}
|
|
||||||
for (; d0 < size; d0++) {
|
|
||||||
input_max_data[d0] = -std::numeric_limits<float>::infinity();
|
|
||||||
tmp_sum_data[d0] = float(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// compute max
|
|
||||||
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
|
||||||
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
|
|
||||||
+ dim_idx * inner_size + inner_idx_begin;
|
|
||||||
float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE;
|
|
||||||
|
|
||||||
int64_t d1 = 0;
|
|
||||||
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
|
|
||||||
Vec data_vec = Vec::loadu(input_ptr + d1);
|
|
||||||
auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
|
|
||||||
fVec max_fvec0 = fVec::loadu(input_max_data + d1);
|
|
||||||
fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size());
|
|
||||||
max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0);
|
|
||||||
max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1);
|
|
||||||
max_fvec0.store(input_max_data + d1);
|
|
||||||
max_fvec1.store(input_max_data + d1 + fVec::size());
|
|
||||||
|
|
||||||
// cache the 'converted' float input
|
|
||||||
data_fvec0.store(input_buffer_ptr + d1);
|
|
||||||
data_fvec1.store(input_buffer_ptr + d1 + fVec::size());
|
|
||||||
}
|
|
||||||
for (; d1 < size; d1++) {
|
|
||||||
float data_val = float(input_ptr[d1]);
|
|
||||||
float max_val = input_max_data[d1];
|
|
||||||
input_max_data[d1] = data_val > max_val ? data_val : max_val;
|
|
||||||
input_buffer_ptr[d1] = data_val;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// compute sum of (x - max).exp()
|
|
||||||
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
|
||||||
float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE;
|
|
||||||
|
|
||||||
int64_t d2 = 0;
|
|
||||||
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
|
|
||||||
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2);
|
|
||||||
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size());
|
|
||||||
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
|
|
||||||
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
|
|
||||||
fVec max_fvec0 = fVec::loadu(input_max_data + d2);
|
|
||||||
fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size());
|
|
||||||
sum_fvec0 += (data_fvec0 - max_fvec0).exp();
|
|
||||||
sum_fvec1 += (data_fvec1 - max_fvec1).exp();
|
|
||||||
sum_fvec0.store(tmp_sum_data + d2);
|
|
||||||
sum_fvec1.store(tmp_sum_data + d2 + fVec::size());
|
|
||||||
}
|
|
||||||
for (; d2 < size; d2++) {
|
|
||||||
float data_val = input_buffer_ptr[d2];
|
|
||||||
float max_val = input_max_data[d2];
|
|
||||||
tmp_sum_data[d2] += std::exp(data_val - max_val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// apply log
|
|
||||||
vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
|
|
||||||
|
|
||||||
// compute x - max - sum
|
|
||||||
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
|
|
||||||
float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE;
|
|
||||||
scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size
|
|
||||||
+ dim_idx * inner_size + inner_idx_begin;
|
|
||||||
|
|
||||||
int64_t d3 = 0;
|
|
||||||
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
|
|
||||||
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3);
|
|
||||||
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size());
|
|
||||||
fVec max_fvec0 = fVec::loadu(input_max_data + d3);
|
|
||||||
fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size());
|
|
||||||
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3);
|
|
||||||
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size());
|
|
||||||
fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0;
|
|
||||||
fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1;
|
|
||||||
Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
|
|
||||||
out_vec.store(output_ptr + d3);
|
|
||||||
}
|
|
||||||
for (; d3 < size; d3++) {
|
|
||||||
output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user