// Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
// Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21169
// We should minimize dependency from perfkernels (we were including eigen
// header files only in cc files not compiled with avx or avx2 options, but
// better to be very strict because it's easy to introduce illegal-instruction
// errors in perfkernels).
// Reviewed By: salexspb
// Differential Revision: D15563839
// fbshipit-source-id: d4b1bca22d7f2e6f20f23664d4b99498e5984586
// (222 lines, 16 KiB, C++)
#include "caffe2/perfkernels/embedding_lookup.h"

#include <cstdint>
#include <cstring>

#include "caffe2/core/types.h"
#include "caffe2/perfkernels/common.h"
namespace caffe2 {
/**
 * Base implementation does runtime dispatch for each segment of reduction
 *
 * For each of the `output_size` segments, accumulates
 *   out[m][j] += w * input[indices[current]][j] (+ b when scale_bias is set)
 * over `lengths[m]` consecutive indices, optionally dividing the segment
 * result by its length afterwards.
 *
 * @param block_size  number of elements per embedding row
 * @param output_size number of output segments (rows written to `out`)
 * @param index_size  total number of entries in `indices`
 * @param data_size   number of rows in `input`; indices must be in [0, data_size)
 * @param weights     optional per-index scale (per-position when
 *                    IS_WEIGHT_POSITIONAL); may be null for a plain sum
 * @param scale_bias  optional interleaved {scale, bias} pairs, one pair per
 *                    input row, used for uint8-quantized input
 * @return false if there is an out-of-bound error (bad index, or the sum of
 *         lengths overruns index_size); the caller re-validates and raises a
 *         descriptive error in that case
 */
template <
    typename IndexType,
    typename InType,
    typename OutType,
    bool IS_WEIGHT_POSITIONAL = false>
static bool EmbeddingLookupGenericSlow(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const InType* input,
    const IndexType* indices,
    const int* lengths,
    const float* weights, // optional, can be null for sum reducer
    const float* scale_bias, // optional scale & bias params for uint8 input
    bool normalize_by_lengths,
    OutType* out) {
  int64_t current = 0;
  // int64_t loop counters: output_size/block_size are int64_t, so `int`
  // counters would be a signed-mismatch and could overflow on huge inputs.
  for (int64_t m = 0; m < output_size; ++m) {
    std::memset(out, 0, sizeof(OutType) * block_size);
    if (current + lengths[m] > index_size) {
      return false; // lengths overrun the indices array
    }
    for (int i = 0; i < lengths[m]; ++i) {
      int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        return false; // out-of-bound embedding row
      }
#ifdef __GNUC__
      // Prefetch the next row for read with low temporal locality.
      if (current + 1 < index_size) {
        __builtin_prefetch(input + block_size * indices[current + 1], 0, 1);
      }
#endif // __GNUC__

      float w = 1.f, b = 0.f;
      if (weights) {
        w = weights[IS_WEIGHT_POSITIONAL ? i : current];
      }
      if (scale_bias) {
        // Reuse the validated int64_t `idx`: `2 * indices[current]` would be
        // int arithmetic for int32 indices and could overflow.
        b = w * scale_bias[2 * idx + 1];
        w = w * scale_bias[2 * idx];
      }

      for (int64_t j = 0; j < block_size; ++j) {
        out[j] += w * input[block_size * idx + j] + b;
      }

      ++current;
    }
    if (normalize_by_lengths && lengths[m]) {
      float scale = 1.f / lengths[m];
      for (int64_t j = 0; j < block_size; ++j) {
        out[j] *= scale;
      }
    }
    out += block_size;
  }
  // Success requires that the lengths consumed exactly all indices.
  return current == index_size;
}
// Proxy back to generic implementation
//
// For each (IndexType, InTypeName, OutType, IS_WEIGHT_POSITIONAL) combination,
// this macro emits:
//   1. `..._<combo>__base`     — the portable fallback, forwarding straight to
//      EmbeddingLookupGenericSlow;
//   2. `..._<combo>__avx2_fma` — a declaration only (via decltype of __base);
//      the definition lives in a translation unit built with AVX2+FMA flags;
//   3. the undecorated `EmbeddingLookup_<combo>` dispatcher, which enforces
//      the scale_bias contract (required for uint8 input, forbidden otherwise)
//      and then selects a kernel at runtime via AVX2_FMA_DO / BASE_DO
//      (both from perfkernels/common.h; each returns on success);
//   4. the `EmbeddingLookup<...>` template specialization callers use: it runs
//      the dispatcher, and only on failure re-walks indices/lengths to raise a
//      descriptive CAFFE_ENFORCE error (out-of-bound index or a lengths sum
//      that does not match index_size).
// NOTE: no `//` comments inside the macro body — they would swallow the `\`
// line continuations.
#define EMBEDDING_SPECIALIZATION(                                          \
    IndexType, InTypeName, InType, OutType, IS_WEIGHT_POSITIONAL)          \
  bool                                                                     \
      EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base( \
          const int64_t block_size,                                        \
          const int64_t output_size,                                       \
          const int64_t index_size,                                        \
          const int64_t data_size,                                         \
          const InType* input,                                             \
          const IndexType* indices,                                        \
          const int* lengths,                                              \
          const float* weights,                                            \
          const float* scale_bias,                                         \
          bool normalize_by_lengths,                                       \
          OutType* out) {                                                  \
    return EmbeddingLookupGenericSlow<                                     \
        IndexType,                                                         \
        InType,                                                            \
        OutType,                                                           \
        IS_WEIGHT_POSITIONAL>(                                             \
        block_size,                                                        \
        output_size,                                                       \
        index_size,                                                        \
        data_size,                                                         \
        input,                                                             \
        indices,                                                           \
        lengths,                                                           \
        weights,                                                           \
        scale_bias,                                                        \
        normalize_by_lengths,                                              \
        out);                                                              \
  }                                                                        \
  decltype(                                                                \
      EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \
      EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \
  bool                                                                     \
      EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \
          const int64_t block_size,                                        \
          const int64_t output_size,                                       \
          const int64_t index_size,                                        \
          const int64_t data_size,                                         \
          const InType* input,                                             \
          const IndexType* indices,                                        \
          const int* lengths,                                              \
          const float* weights,                                            \
          const float* scale_bias,                                         \
          bool normalize_by_lengths,                                       \
          OutType* out) {                                                  \
    if (std::is_same<InType, uint8_t>::value) {                            \
      CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr"); \
    } else {                                                               \
      CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");  \
    }                                                                      \
    AVX2_FMA_DO(                                                           \
        EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \
        block_size,                                                        \
        output_size,                                                       \
        index_size,                                                        \
        data_size,                                                         \
        input,                                                             \
        indices,                                                           \
        lengths,                                                           \
        weights,                                                           \
        scale_bias,                                                        \
        normalize_by_lengths,                                              \
        out);                                                              \
    BASE_DO(                                                               \
        EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \
        block_size,                                                        \
        output_size,                                                       \
        index_size,                                                        \
        data_size,                                                         \
        input,                                                             \
        indices,                                                           \
        lengths,                                                           \
        weights,                                                           \
        scale_bias,                                                        \
        normalize_by_lengths,                                              \
        out);                                                              \
  }                                                                        \
  template <>                                                              \
  void EmbeddingLookup<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL>(  \
      const int64_t block_size,                                            \
      const int64_t output_size,                                           \
      const int64_t index_size,                                            \
      const int64_t data_size,                                             \
      const InType* input,                                                 \
      const IndexType* indices,                                            \
      const int* lengths,                                                  \
      const float* weights,                                                \
      const float* scale_bias,                                             \
      bool normalize_by_lengths,                                           \
      OutType* out) {                                                      \
    bool success =                                                         \
        EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \
            block_size,                                                    \
            output_size,                                                   \
            index_size,                                                    \
            data_size,                                                     \
            input,                                                         \
            indices,                                                       \
            lengths,                                                       \
            weights,                                                       \
            scale_bias,                                                    \
            normalize_by_lengths,                                          \
            out);                                                          \
    if (success) {                                                         \
      return;                                                              \
    }                                                                      \
    int64_t current = 0;                                                   \
    for (int m = 0; m < output_size; ++m) {                                \
      for (int i = 0; i < lengths[m]; ++i) {                               \
        CAFFE_ENFORCE_LT(current, index_size);                             \
        IndexType idx = indices[current];                                  \
        CAFFE_ENFORCE(                                                     \
            0 <= idx && idx < data_size,                                   \
            "Index ",                                                      \
            current,                                                       \
            " is out of bounds: ",                                         \
            idx,                                                           \
            ", range 0 to ",                                               \
            data_size);                                                    \
        ++current;                                                         \
      }                                                                    \
    }                                                                      \
    CAFFE_ENFORCE_EQ(                                                      \
        current,                                                           \
        index_size,                                                        \
        "Your input seems to be incorrect: the sum of lengths values should be " \
        "the size of the indices tensor, but it appears not.");            \
  }
// Instantiate every supported (index type, input type) combination.
// Non-positional weights: weights[k] pairs with the k-th entry of `indices`.
EMBEDDING_SPECIALIZATION(int32_t, float, float, float, false);
EMBEDDING_SPECIALIZATION(int64_t, float, float, float, false);
EMBEDDING_SPECIALIZATION(int32_t, half, at::Half, float, false);
EMBEDDING_SPECIALIZATION(int64_t, half, at::Half, float, false);
EMBEDDING_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
EMBEDDING_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);

// Positional weights: weights[i] pairs with position i within each segment.
EMBEDDING_SPECIALIZATION(int32_t, float, float, float, true);
EMBEDDING_SPECIALIZATION(int64_t, float, float, float, true);
EMBEDDING_SPECIALIZATION(int32_t, half, at::Half, float, true);
EMBEDDING_SPECIALIZATION(int64_t, half, at::Half, float, true);
EMBEDDING_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
EMBEDDING_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);

#undef EMBEDDING_SPECIALIZATION

} // namespace caffe2