Extract CPU log_softmax kernels to header (#156243)

This allows sharing them with ExecuTorch.

Differential Revision: [D76830114](https://our.internmc.facebook.com/intern/diff/D76830114/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156243
Approved by: https://github.com/janeyx99
Scott Wolchok, 2025-06-23 09:16:06 -07:00 (committed by PyTorch MergeBot)
parent 96e4c95cd8
commit c82a174cea
2 changed files with 378 additions and 271 deletions

aten/src/ATen/native/cpu/LogSoftmaxKernelImpl.h (new file)

@@ -0,0 +1,337 @@
#pragma once
#include <ATen/OpMathType.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <type_traits>
namespace at::native {
inline namespace CPU_CAPABILITY {
template <typename scalar_t>
int64_t vec_log_softmax_lastdim_chunk_size(int64_t grain_size, int64_t outer_size, int64_t dim_size) {
// Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the
// size of L1D cache on many processors. Some processors have 48 KB L1D cache
// nowadays, so maybe in the future, we can leverage the knowledge of a
// machine's L1D cache size.
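// Worked example (illustrative values): with grain_size = 32768, scalar_t =
// float (4 bytes) and dim_size = 1024, MAX_CHUNK_SIZE = max(1, 32768 /
// (4 * 1024)) = 8, so each chunk covers at most 8 rows (fewer if outer_size
// is smaller).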
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
1,
grain_size / (sizeof(scalar_t) * dim_size));
return std::min<int64_t>(MAX_CHUNK_SIZE, outer_size);
}
template <typename scalar_t>
void serial_vec_log_softmax_lastdim_range(
const scalar_t* input_data_base,
scalar_t* output_data_base,
int64_t dim_size,
int64_t chunk_size,
int64_t begin,
int64_t end) {
if (end <= begin) {
return;
}
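// For each row of the last dimension this computes the stable decomposition
// out = x - max(x) - log(sum(exp(x - max(x)))), handling rows in batches of
// chunk_size so the per-batch temporaries stay small.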
using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
// MSVC requires such a declaration of dynamic arrays
// Source: https://stackoverflow.com/a/33423538
auto tmp_sum_scalar = std::make_unique<scalar_t[]>(chunk_size);
auto max_input_arr = std::make_unique<scalar_t[]>(chunk_size);
for (int64_t ii = begin; ii < end; ii += chunk_size) {
int64_t loop_end = chunk_size;
if (ii + chunk_size > end) {
loop_end = end - ii;
}
for (const auto j : c10::irange(loop_end)) {
int64_t i = ii + j;
const scalar_t* input_data = input_data_base + i * dim_size;
max_input_arr[j] = vec::reduce_all<scalar_t>(
[](Vec& x, Vec& y) { return vec::maximum(x, y); },
input_data,
dim_size);
}
for (const auto j : c10::irange(loop_end)) {
int64_t i = ii + j;
const scalar_t* input_data = input_data_base + i * dim_size;
scalar_t max_input = max_input_arr[j];
tmp_sum_scalar[j] = vec::map_reduce_all<scalar_t>(
[max_input](Vec x) { return (x - Vec(max_input)).exp(); },
[](Vec x, Vec y) { return x + y; },
input_data,
dim_size);
}
// See [Note AVX-SSE transitions] for why this should call the
// vectorized version (aside from perf improvements).
vec::map(
[](Vec x) { return x.log(); },
tmp_sum_scalar.get(),
tmp_sum_scalar.get(),
loop_end);
for (const auto j : c10::irange(loop_end)) {
int64_t i = ii + j;
const scalar_t* input_data = input_data_base + i * dim_size;
scalar_t* output_data = output_data_base + i * dim_size;
scalar_t tmp_sum = tmp_sum_scalar[j];
scalar_t max_input = max_input_arr[j];
// It's necessary to keep the order of the operations below.
// When the input values are large and their pairwise differences are
// small, computing `max_input` plus `tmp_sum` first would introduce a
// numerical error. See an example in
// https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
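// Illustrative example: with max_input = 1e8f and tmp_sum = 0.5f, float32
// spacing near 1e8 is 8, so (max_input + tmp_sum) rounds back to 1e8f and
// the tmp_sum term would vanish from the output.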
vec::map(
[tmp_sum, max_input](Vec x) {
return x - Vec(max_input) - Vec(tmp_sum);
},
output_data,
input_data,
dim_size);
}
}
}
// We can't include ATen/Parallel.h here (this header is shared with
// ExecuTorch), so we keep a local copy of divup.
// TODO: find a way to have only one copy of divup.
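// For example, divup(10, 4) == 3.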
inline int64_t divup(int64_t x, int64_t y) {
return (x + y - 1) / y;
}
template <typename scalar_t, int64_t BLOCK_SIZE = 128 * 1024>
std::pair<int64_t,int64_t> vec_logsoftmax_chunk_size_and_num_chunks(int64_t inner_size, int64_t dim_size) {
using Vec = vec::Vectorized<scalar_t>;
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
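// Worked example (illustrative): for float (4 bytes) with dim_size = 64,
// MAX_CHUNK_SIZE = 128 * 1024 / 64 / 4 = 512, already a multiple of
// Vec::size() (e.g. 8 for AVX2 floats); with inner_size = 1000 the lines
// below then give CHUNK_SIZE = 512 and num_chunks = divup(1000, 512) = 2.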
int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
return {CHUNK_SIZE, num_chunks};
}
template <typename scalar_t>
std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
serial_vec_logsoftmax_range(
const scalar_t* input_data_base,
scalar_t* output_data_base,
int64_t inner_size,
int64_t chunk_size,
int64_t num_chunks,
int64_t dim_size,
int64_t begin,
int64_t end) {
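// Each index i in [begin, end) names an (outer_idx, chunk) pair. Within a
// chunk of the inner dimension we make three vertical passes over dim_size:
// compute max, accumulate sum of exp(x - max), then write x - max - log(sum).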
using Vec = vec::Vectorized<scalar_t>;
// Thread-local temp buffer holding the vertical reduction results: max and sum.
auto buffer = std::make_unique<scalar_t []>(chunk_size * 2);
scalar_t* input_max_data = buffer.get();
scalar_t* tmp_sum_data = buffer.get() + chunk_size;
for (int64_t i = begin; i < end; i++) {
int64_t outer_idx = i / num_chunks;
int64_t k = i % num_chunks;
int64_t inner_idx_begin = k * chunk_size;
int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);
// init
Vec zero_vec = Vec(scalar_t(0));
Vec min_vec = Vec(-std::numeric_limits<scalar_t>::infinity());
int64_t d0 = 0;
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
min_vec.store(input_max_data + d0);
zero_vec.store(tmp_sum_data + d0);
}
for (; d0 < size; d0++) {
input_max_data[d0] = -std::numeric_limits<scalar_t>::infinity();
tmp_sum_data[d0] = scalar_t(0);
}
// compute max
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
int64_t d1 = 0;
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d1);
Vec max_vec = Vec::loadu(input_max_data + d1);
max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec);
max_vec.store(input_max_data + d1);
}
for (; d1 < size; d1++) {
scalar_t data_val = input_ptr[d1];
scalar_t max_val = input_max_data[d1];
input_max_data[d1] = data_val > max_val ? data_val : max_val;
}
}
// compute sum of (x - max).exp()
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
int64_t d2 = 0;
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d2);
Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
Vec max_vec = Vec::loadu(input_max_data + d2);
sum_vec += (data_vec - max_vec).exp();
sum_vec.store(tmp_sum_data + d2);
}
for (; d2 < size; d2++) {
scalar_t data_val = input_ptr[d2];
scalar_t max_val = input_max_data[d2];
tmp_sum_data[d2] += std::exp(data_val - max_val);
}
}
// apply log
vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
// compute x - max - sum
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin;
const scalar_t* input_ptr = input_data_base + offset;
scalar_t* output_ptr = output_data_base + offset;
int64_t d3 = 0;
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d3);
Vec max_vec = Vec::loadu(input_max_data + d3);
Vec sum_vec = Vec::loadu(tmp_sum_data + d3);
Vec out_vec = data_vec - max_vec - sum_vec;
out_vec.store(output_ptr + d3);
}
for (; d3 < size; d3++) {
output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3];
}
}
}
}
template <typename scalar_t>
std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
serial_vec_logsoftmax_range(
const scalar_t* input_data_base,
scalar_t* output_data_base,
int64_t inner_size,
int64_t chunk_size,
int64_t num_chunks,
int64_t dim_size,
int64_t begin,
int64_t end) {
using Vec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
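// For reduced-precision scalar_t (e.g. BFloat16 or Half), Vec holds twice as
// many lanes as fVec, so each loaded Vec converts to a pair of float vectors
// (data_fvec0/data_fvec1 below).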
auto buffer = std::make_unique<float []>(chunk_size * 2);
float* input_max_data = buffer.get();
float* tmp_sum_data = buffer.get() + chunk_size;
// Thread-local buffer holding the input converted to float32, to avoid
// redoing the dtype conversion in the next two passes.
auto input_buffer = std::make_unique<float []>(dim_size * chunk_size);
float* input_buffer_data = input_buffer.get();
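// Without this cache, both the sum pass and the final output pass would have
// to re-convert the same scalar_t inputs to float32.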
// init
for (int64_t i = begin; i < end; i++) {
int64_t outer_idx = i / num_chunks;
int64_t k = i % num_chunks;
int64_t inner_idx_begin = k * chunk_size;
int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);
fVec zero_fvec = fVec(float(0));
fVec min_fvec = fVec(-std::numeric_limits<float>::infinity());
int64_t d0 = 0;
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
min_fvec.store(input_max_data + d0);
min_fvec.store(input_max_data + d0 + fVec::size());
zero_fvec.store(tmp_sum_data + d0);
zero_fvec.store(tmp_sum_data + d0 + fVec::size());
}
for (; d0 < size; d0++) {
input_max_data[d0] = -std::numeric_limits<float>::infinity();
tmp_sum_data[d0] = float(0);
}
// compute max
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
int64_t d1 = 0;
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d1);
auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
fVec max_fvec0 = fVec::loadu(input_max_data + d1);
fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size());
max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0);
max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1);
max_fvec0.store(input_max_data + d1);
max_fvec1.store(input_max_data + d1 + fVec::size());
// cache the 'converted' float input
data_fvec0.store(input_buffer_ptr + d1);
data_fvec1.store(input_buffer_ptr + d1 + fVec::size());
}
for (; d1 < size; d1++) {
float data_val = float(input_ptr[d1]);
float max_val = input_max_data[d1];
input_max_data[d1] = data_val > max_val ? data_val : max_val;
input_buffer_ptr[d1] = data_val;
}
}
// compute sum of (x - max).exp()
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
int64_t d2 = 0;
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2);
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size());
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
fVec max_fvec0 = fVec::loadu(input_max_data + d2);
fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size());
sum_fvec0 += (data_fvec0 - max_fvec0).exp();
sum_fvec1 += (data_fvec1 - max_fvec1).exp();
sum_fvec0.store(tmp_sum_data + d2);
sum_fvec1.store(tmp_sum_data + d2 + fVec::size());
}
for (; d2 < size; d2++) {
float data_val = input_buffer_ptr[d2];
float max_val = input_max_data[d2];
tmp_sum_data[d2] += std::exp(data_val - max_val);
}
}
// apply log
vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
// compute x - max - sum
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
int64_t d3 = 0;
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3);
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size());
fVec max_fvec0 = fVec::loadu(input_max_data + d3);
fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size());
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3);
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size());
fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0;
fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1;
Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
out_vec.store(output_ptr + d3);
}
for (; d3 < size; d3++) {
output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]);
}
}
}
} // namespace CPU_CAPABILITY
} // namespace at::native
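For context, here is a minimal sketch, assuming a single-threaded caller (hypothetical; log_softmax_lastdim_single_thread is not part of this diff), of how a client that cannot use at::parallel_for, such as ExecuTorch, might drive the extracted last-dim kernel:

#include <ATen/native/cpu/LogSoftmaxKernelImpl.h>
#include <cstdint>

void log_softmax_lastdim_single_thread(
    const float* input,
    float* output,
    int64_t outer_size,
    int64_t dim_size) {
  // 32768 mirrors at::internal::GRAIN_SIZE; any positive grain size works.
  const int64_t chunk_size =
      at::native::vec_log_softmax_lastdim_chunk_size<float>(
          /*grain_size=*/32768, outer_size, dim_size);
  // Cover every row in one serial range; a threaded caller would instead
  // split [0, outer_size) across workers.
  at::native::serial_vec_log_softmax_lastdim_range(
      input, output, dim_size, chunk_size, /*begin=*/0, /*end=*/outer_size);
}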

aten/src/ATen/native/cpu/SoftMaxKernel.cpp

@@ -2,6 +2,8 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cpu/SoftmaxKernel.h>
#include <ATen/native/cpu/LogSoftmaxKernelImpl.h>
#include <algorithm>
#include <iterator>
#include <numeric>
@@ -28,7 +30,6 @@
// We use a chunk size such that it'd fit in L1D.
namespace at::native {
namespace {
template <typename scalar_t>
inline void _vec_log_softmax_lastdim(
@@ -36,15 +37,10 @@ inline void _vec_log_softmax_lastdim(
scalar_t* output_data_base,
int64_t outer_size,
int64_t dim_size) {
using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
// Coincidentally, at::internal::GRAIN_SIZE is 32768, which is equal to the
// size of L1D cache on many processors. Some processors have 48 KB L1D cache
// nowadays, so maybe in the future, we can leverage the knowledge of a
// machine's L1D cache size.
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(
1,
at::internal::GRAIN_SIZE / (sizeof(scalar_t) * dim_size));
int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, outer_size);
const auto chunk_size = vec_log_softmax_lastdim_chunk_size<scalar_t>(
at::internal::GRAIN_SIZE,
outer_size,
dim_size);
// Note: grain_size value of 0
// We don't change the number of OpenMP threads in the OpenMP thread-pool,
// so some threads do useful work, while others don't.
@@ -52,60 +48,13 @@ inline void _vec_log_softmax_lastdim(
// work among threads in an equitable manner. We compute CHUNK_SIZE to ensure
// each thread's computations would be efficient.
parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
// MSVC requires such a declaration of dynamic arrays
// Source: https://stackoverflow.com/a/33423538
auto tmp_sum_scalar = std::make_unique<scalar_t[]>(CHUNK_SIZE);
auto max_input_arr = std::make_unique<scalar_t[]>(CHUNK_SIZE);
for (int64_t ii = begin; ii < end; ii += CHUNK_SIZE) {
int64_t loop_end = CHUNK_SIZE;
if (ii + CHUNK_SIZE > end)
loop_end = end - ii;
for (const auto j : c10::irange(loop_end)) {
int64_t i = ii + j;
const scalar_t* input_data = input_data_base + i * dim_size;
max_input_arr[j] = vec::reduce_all<scalar_t>(
[](Vec& x, Vec& y) { return vec::maximum(x, y); },
input_data,
dim_size);
}
for (const auto j : c10::irange(loop_end)) {
int64_t i = ii + j;
const scalar_t* input_data = input_data_base + i * dim_size;
scalar_t max_input = max_input_arr[j];
tmp_sum_scalar[j] = vec::map_reduce_all<scalar_t>(
[max_input](Vec x) { return (x - Vec(max_input)).exp(); },
[](Vec x, Vec y) { return x + y; },
input_data,
dim_size);
}
// See [Note AVX-SSE transitions] for why this should call the
// vectorized version (aside from perf improvements).
vec::map(
[](Vec x) { return x.log(); },
tmp_sum_scalar.get(),
tmp_sum_scalar.get(),
loop_end);
for (const auto j : c10::irange(loop_end)) {
int64_t i = ii + j;
const scalar_t* input_data = input_data_base + i * dim_size;
scalar_t* output_data = output_data_base + i * dim_size;
scalar_t tmp_sum = tmp_sum_scalar[j];
scalar_t max_input = max_input_arr[j];
// It's necessary to keep the order of the operations below.
// When the input values are large and their pairwise differences are
// small, computing `max_input` plus `tmp_sum` first would introduce a
// numerical error. See an example in
// https://github.com/pytorch/pytorch/issues/11752#issuecomment-422883379
vec::map(
[tmp_sum, max_input](Vec x) {
return x - Vec(max_input) - Vec(tmp_sum);
},
output_data,
input_data,
dim_size);
}
}
serial_vec_log_softmax_lastdim_range(
input_data_base,
output_data_base,
dim_size,
chunk_size,
begin,
end);
});
}
@@ -891,100 +840,23 @@ _vec_logsoftmax(
int64_t outer_size,
int64_t inner_size,
int64_t dim_size) {
using Vec = vec::Vectorized<scalar_t>;
int64_t BLOCK_SIZE = 128 * 1024;
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
const auto [CHUNK_SIZE_binding, num_chunks_binding] = vec_logsoftmax_chunk_size_and_num_chunks<scalar_t>(
inner_size, dim_size);
// Work around "capturing a structured binding is not yet supported in OpenMP".
const auto CHUNK_SIZE = CHUNK_SIZE_binding;
const auto num_chunks = num_chunks_binding;
// See Note: grain_size value of 0
at::parallel_for(0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
// Thread-local temp buffer holding the vertical reduction results: max and sum.
auto buffer = std::make_unique<scalar_t []>(CHUNK_SIZE * 2);
scalar_t* input_max_data = buffer.get();
scalar_t* tmp_sum_data = buffer.get() + CHUNK_SIZE;
for (int64_t i = begin; i < end; i++) {
int64_t outer_idx = i / num_chunks;
int64_t k = i % num_chunks;
int64_t inner_idx_begin = k * CHUNK_SIZE;
int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);
// init
Vec zero_vec = Vec(scalar_t(0));
Vec min_vec = Vec(-std::numeric_limits<scalar_t>::infinity());
int64_t d0 = 0;
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
min_vec.store(input_max_data + d0);
zero_vec.store(tmp_sum_data + d0);
}
for (; d0 < size; d0++) {
input_max_data[d0] = -std::numeric_limits<scalar_t>::infinity();
tmp_sum_data[d0] = scalar_t(0);
}
// compute max
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
int64_t d1 = 0;
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d1);
Vec max_vec = Vec::loadu(input_max_data + d1);
max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec);
max_vec.store(input_max_data + d1);
}
for (; d1 < size; d1++) {
scalar_t data_val = input_ptr[d1];
scalar_t max_val = input_max_data[d1];
input_max_data[d1] = data_val > max_val ? data_val : max_val;
}
}
// compute sum of (x - max).exp()
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
int64_t d2 = 0;
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d2);
Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
Vec max_vec = Vec::loadu(input_max_data + d2);
sum_vec += (data_vec - max_vec).exp();
sum_vec.store(tmp_sum_data + d2);
}
for (; d2 < size; d2++) {
scalar_t data_val = input_ptr[d2];
scalar_t max_val = input_max_data[d2];
tmp_sum_data[d2] += std::exp(data_val - max_val);
}
}
// apply log
vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
// compute x - max - sum
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin;
const scalar_t* input_ptr = input_data_base + offset;
scalar_t* output_ptr = output_data_base + offset;
int64_t d3 = 0;
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d3);
Vec max_vec = Vec::loadu(input_max_data + d3);
Vec sum_vec = Vec::loadu(tmp_sum_data + d3);
Vec out_vec = data_vec - max_vec - sum_vec;
out_vec.store(output_ptr + d3);
}
for (; d3 < size; d3++) {
output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3];
}
}
}
serial_vec_logsoftmax_range(
input_data_base,
output_data_base,
inner_size,
CHUNK_SIZE,
num_chunks,
dim_size,
begin,
end);
});
}
@@ -996,125 +868,23 @@ _vec_logsoftmax(
int64_t outer_size,
int64_t inner_size,
int64_t dim_size) {
using Vec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
int64_t BLOCK_SIZE = 128 * 1024;
int64_t MAX_CHUNK_SIZE = std::max<int64_t>(BLOCK_SIZE / dim_size / sizeof(scalar_t), Vec::size());
MAX_CHUNK_SIZE = MAX_CHUNK_SIZE / Vec::size() * Vec::size();
int64_t CHUNK_SIZE = std::min<int64_t>(MAX_CHUNK_SIZE, inner_size);
int64_t num_chunks = divup(inner_size, CHUNK_SIZE);
const auto [CHUNK_SIZE_binding, num_chunks_binding] = vec_logsoftmax_chunk_size_and_num_chunks<scalar_t>(
inner_size, dim_size);
// Work around "capturing a structured binding is not yet supported in OpenMP".
const auto CHUNK_SIZE = CHUNK_SIZE_binding;
const auto num_chunks = num_chunks_binding;
// See Note: grain_size value of 0
at::parallel_for(0, outer_size * num_chunks, 0, [&](int64_t begin, int64_t end) {
auto buffer = std::make_unique<float []>(CHUNK_SIZE * 2);
float* input_max_data = buffer.get();
float* tmp_sum_data = buffer.get() + CHUNK_SIZE;
// Thread-local buffer holding the input converted to float32, to avoid
// redoing the dtype conversion in the next two passes.
auto input_buffer = std::make_unique<float []>(dim_size * CHUNK_SIZE);
float* input_buffer_data = input_buffer.get();
// init
for (int64_t i = begin; i < end; i++) {
int64_t outer_idx = i / num_chunks;
int64_t k = i % num_chunks;
int64_t inner_idx_begin = k * CHUNK_SIZE;
int64_t size = std::min(CHUNK_SIZE, inner_size - inner_idx_begin);
fVec zero_fvec = fVec(float(0));
fVec min_fvec = fVec(-std::numeric_limits<float>::infinity());
int64_t d0 = 0;
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
min_fvec.store(input_max_data + d0);
min_fvec.store(input_max_data + d0 + fVec::size());
zero_fvec.store(tmp_sum_data + d0);
zero_fvec.store(tmp_sum_data + d0 + fVec::size());
}
for (; d0 < size; d0++) {
input_max_data[d0] = -std::numeric_limits<float>::infinity();
tmp_sum_data[d0] = float(0);
}
// compute max
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE;
int64_t d1 = 0;
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d1);
auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
fVec max_fvec0 = fVec::loadu(input_max_data + d1);
fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size());
max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0);
max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1);
max_fvec0.store(input_max_data + d1);
max_fvec1.store(input_max_data + d1 + fVec::size());
// cache the 'converted' float input
data_fvec0.store(input_buffer_ptr + d1);
data_fvec1.store(input_buffer_ptr + d1 + fVec::size());
}
for (; d1 < size; d1++) {
float data_val = float(input_ptr[d1]);
float max_val = input_max_data[d1];
input_max_data[d1] = data_val > max_val ? data_val : max_val;
input_buffer_ptr[d1] = data_val;
}
}
// compute sum of (x - max).exp()
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE;
int64_t d2 = 0;
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2);
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size());
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
fVec max_fvec0 = fVec::loadu(input_max_data + d2);
fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size());
sum_fvec0 += (data_fvec0 - max_fvec0).exp();
sum_fvec1 += (data_fvec1 - max_fvec1).exp();
sum_fvec0.store(tmp_sum_data + d2);
sum_fvec1.store(tmp_sum_data + d2 + fVec::size());
}
for (; d2 < size; d2++) {
float data_val = input_buffer_ptr[d2];
float max_val = input_max_data[d2];
tmp_sum_data[d2] += std::exp(data_val - max_val);
}
}
// apply log
vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
// compute x - max - sum
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
float* input_buffer_ptr = input_buffer_data + dim_idx * CHUNK_SIZE;
scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
int64_t d3 = 0;
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3);
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size());
fVec max_fvec0 = fVec::loadu(input_max_data + d3);
fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size());
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3);
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size());
fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0;
fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1;
Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
out_vec.store(output_ptr + d3);
}
for (; d3 < size; d3++) {
output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]);
}
}
}
serial_vec_logsoftmax_range(
input_data_base,
output_data_base,
inner_size,
CHUNK_SIZE,
num_chunks,
dim_size,
begin,
end);
});
}