Add BF16-in FP32-out fast path for EmbeddingBag in caffe2 perfkernels (#89198)

Add a BF16-in, FP32-out kernel to the Caffe2 embedding perfkernels, and update the Python code-gen files to generate the kernel.
Unit tests will be covered by the next PR (#89199) in this stack (tested via nn.EmbeddingBag with the BF16 data type).
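
For context, bfloat16 keeps the upper 16 bits of an IEEE-754 float32, so a BF16 value can be widened to FP32 by placing its bits in the high half of a 32-bit word and zero-filling the low half. A minimal scalar sketch of the idea (illustration only, not the kernel code itself):

    #include <cstdint>
    #include <cstring>

    // Widen one bfloat16 value (given as raw uint16_t bits) to float:
    // shift the bits into the upper half of a 32-bit word and reinterpret
    // the result as an IEEE-754 float32; the low 16 mantissa bits are zero.
    float bf16_bits_to_float(uint16_t bf16_bits) {
      uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
      float out;
      std::memcpy(&out, &f32_bits, sizeof(out));
      return out;
    }

The vectorized kernels below apply the same trick eight elements at a time with _mm256_cvtepu16_epi32, _mm256_slli_epi32(..., 16), and _mm256_castsi256_ps.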

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89198
Approved by: https://github.com/jgong5, https://github.com/kit1980
haozhe.zhu 2022-11-29 23:54:54 +00:00 committed by PyTorch MergeBot
parent 68805b08d1
commit 7cd6e6acad
3 changed files with 1355 additions and 4 deletions


@@ -1,5 +1,6 @@
#include "caffe2/perfkernels/embedding_lookup_idx.h"
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/irange.h>
#include "caffe2/core/common.h"
@@ -214,6 +215,8 @@ EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);
@@ -221,6 +224,8 @@ EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);
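
Each EMBEDDING_IDX_SPECIALIZATION line above instantiates the EmbeddingLookupIdx entry point for one (index type, input type) pair, so the new rows make at::BFloat16 input with float output dispatchable. A rough usage sketch follows; the parameter list is paraphrased from memory of caffe2/perfkernels/embedding_lookup_idx.h and the offsets layout is an assumption, so treat it as illustrative rather than authoritative:

    #include <cstdint>
    #include <vector>
    #include <c10/util/BFloat16.h>
    #include "caffe2/perfkernels/embedding_lookup_idx.h"

    // Hypothetical call: sum-pool two bags of BF16 rows into FP32 output.
    void bf16_lookup_example() {
      const int64_t block_size = 16;   // elements per embedding row
      const int64_t output_size = 2;   // number of bags
      const int64_t index_size = 4;    // total number of indices
      const int64_t data_size = 8;     // rows in the embedding table

      std::vector<at::BFloat16> table(data_size * block_size);
      std::vector<int64_t> indices = {0, 3, 1, 7};
      std::vector<int64_t> offsets = {0, 2, 4};  // assumed bag boundaries
      std::vector<float> out(output_size * block_size, 0.f);

      caffe2::EmbeddingLookupIdx<int64_t, at::BFloat16, float>(
          block_size, output_size, index_size, data_size,
          table.data(), indices.data(), offsets.data(),
          /*weights=*/nullptr, /*scale_bias=*/nullptr,
          /*normalize_by_lengths=*/false, out.data());
    }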

File diff suppressed because it is too large


@@ -4,7 +4,7 @@ import argparse
import sys
sizeof = {"float": 4, "at::Half": 2, "uint8_t": 1}
sizeof = {"float": 4, "at::Half": 2, "at::BFloat16": 2, "uint8_t": 1}
def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
@@ -24,6 +24,16 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
" _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))),\n" # noqa
" vop%d);" % (regid, regid, regid)
)
elif InType == "at::BFloat16":
code.append(
" vop%d = _mm256_fmadd_ps(\n"
" vwgt,\n"
" _mm256_castsi256_ps(_mm256_slli_epi32(\n"
" _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
" reinterpret_cast<const __m128i*>(ip + (%d)))),\n"
" 16)),\n" # noqa
" vop%d);" % (regid, regid, regid)
)
elif InType == "uint8_t":
code.append(
" vop%d = _mm256_fmadd_ps(\n"
@@ -104,6 +114,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
if InType == "uint8_t":
code.append(" " + OutType + " wgt = 1.f;")
code.append(" // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
code.append(" " + OutType + " bio;")
code.append(" if (weights) {")
code.append(
@@ -133,7 +144,10 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
code.append(" const {}* ip = &input[idx * fused_block_size];".format(InType))
code.append(
" const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
" ? (dataInd + prefdist_T0)\n : dataInd;".format(
" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
" ? (dataInd + prefdist_T0)\n"
" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
" : dataInd;".format(
IndexType
)
)
@@ -206,6 +220,18 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
" reinterpret_cast<const __m128i*>(&ip[j]))),\n"
" _mm256_loadu_ps(&op[j])));"
)
elif InType == "at::BFloat16":
code.append(
" _mm256_storeu_ps(\n"
" &op[j],\n"
" _mm256_fmadd_ps(\n"
" vwgt,\n"
" _mm256_castsi256_ps(_mm256_slli_epi32(\n"
" _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
" reinterpret_cast<const __m128i*>(&ip[j]))),\n"
" 16)),\n"
" _mm256_loadu_ps(&op[j])));"
)
elif InType == "uint8_t":
code.append(
" _mm256_storeu_ps(\n"
@@ -229,7 +255,8 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
code = []
if InType == "at::Half":
code.append(" alignas(64) at::Half vtmp1[8] = {0};")
if InType == "at::BFloat16":
code.append(" alignas(64) at::BFloat16 vtmp1[8] = {0};")
if use_offsets:
@@ -291,6 +318,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
if InType == "uint8_t":
code.append(" " + OutType + " wgt = 1.f;")
code.append(" // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
code.append(" " + OutType + " bio;")
code.append(" if (weights) {")
code.append(
@@ -320,7 +348,10 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
code.append(" const {}* ip = &input[idx * fused_block_size];".format(InType))
code.append(
" const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
" ? (dataInd + prefdist_T0)\n : dataInd;".format(
" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
" ? (dataInd + prefdist_T0)\n"
" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
" : dataInd;".format(
IndexType
)
)
@@ -351,6 +382,14 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
" _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));"
)
code.append(" op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
elif InType == "at::BFloat16":
code.append(" vtmp1[0] = ip[j];")
code.append(
" __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(\n"
" _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),\n"
" 16));"
)
code.append(" op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
elif InType == "uint8_t":
code.append(" op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
else:
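
For block sizes that are not a multiple of eight, the generic path falls back to a scalar remainder loop; the at::BFloat16 branch above emits roughly the following per element, staging one value through the alignas(64) vtmp1 buffer and using only lane 0 of the converted vector:

    // Generated remainder-loop body (reconstructed from the strings above).
    vtmp1[0] = ip[j];
    __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(
        _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),
        16));
    op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
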
@@ -408,6 +447,8 @@ options = [
["int64_t", "int64_t", "float", "float", "float", "float"],
["int32_t", "int", "half", "at::Half", "float", "float"],
["int64_t", "int64_t", "half", "at::Half", "float", "float"],
["int32_t", "int", "bfloat16", "at::BFloat16", "float", "float"],
["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"],
["int32_t", "int", "uint8_t", "uint8_t", "float", "float"],
["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
]
@@ -422,6 +463,7 @@ code.append("//// DO NOT MODIFY!!!")
code.append("//// --------------------------\n")
code.append("#include <c10/util/Half.h>")
code.append("#include <c10/util/BFloat16.h>")
code.append("#include <immintrin.h>")
code.append("namespace caffe2 {\n")
@@ -461,6 +503,7 @@ for o in options:
code += args
code.append(" const " + IndexType + " prefdist_T0 = 16;")
code.append(" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)")
# block_size is the number of elements and fused_block_size is the size of
# an entire row, including scale and bias.
offset = (8 // sizeof[InType]) if opts.fused else 0
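
Worked example of the new sizeof entry: in the fused variant (opts.fused), the 8-byte per-row scale and bias are counted in units of the input type, so for at::BFloat16 the offset is 8 // 2 = 4 elements and fused_block_size becomes block_size + 4 (compared with +2 for float, +4 for at::Half, and +8 for uint8_t).
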
@@ -484,6 +527,7 @@ for o in options:
code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
code.append(" } else {")
code.append(" // generic code")
code.append(" // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)")
code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
code.append(" }")
code.append(" return dataInd == index_size;")