Add BF16-in FP32-out fast path for EmbeddingBag in caffe2 perfkernels (#89198)
Add a BF16-in, FP32-out kernel to the Caffe2 embedding perfkernels, and update the Python code-gen file to generate it. Unit tests will be added in the next PR (#89199) in this stack (tested through nn.EmbeddingBag with the BF16 data type).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89198
Approved by: https://github.com/jgong5, https://github.com/kit1980
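For orientation, here is a minimal scalar sketch of the BF16-in, FP32-out behavior this patch adds (illustrative only; the helper name and simplified signature below are hypothetical and do not match the real perfkernel entry points): every BF16 table element is widened to float before the weighted accumulation, so the running sum stays in FP32.

    #include <c10/util/BFloat16.h>

    #include <cmath>
    #include <cstdint>

    // Hypothetical helper, for illustration only: sum BF16 rows into an FP32 output.
    void embedding_sum_bf16_fp32_sketch(
        std::int64_t block_size,        // elements per embedding row
        std::int64_t index_size,        // number of indices to gather
        const std::int64_t* indices,    // row ids into the BF16 table
        const c10::BFloat16* input,     // BF16 embedding table
        const float* weights,           // optional per-index weights (may be null)
        float* out) {                   // FP32 output row
      for (std::int64_t j = 0; j < block_size; ++j) {
        out[j] = 0.f;
      }
      for (std::int64_t i = 0; i < index_size; ++i) {
        const c10::BFloat16* ip = input + indices[i] * block_size;
        const float wgt = weights ? weights[i] : 1.f;
        for (std::int64_t j = 0; j < block_size; ++j) {
          // c10::BFloat16 converts to float; the accumulation stays in FP32.
          out[j] = std::fma(wgt, static_cast<float>(ip[j]), out[j]);
        }
      }
    }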
This commit is contained in:
parent 68805b08d1
commit 7cd6e6acad
@@ -1,5 +1,6 @@
 #include "caffe2/perfkernels/embedding_lookup_idx.h"
 
+#include <c10/util/BFloat16.h>
 #include <c10/util/Half.h>
 #include <c10/util/irange.h>
 #include "caffe2/core/common.h"
@@ -214,6 +215,8 @@ EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, false);
 EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false);
 EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false);
 EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false);
+EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, false);
+EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, false);
 EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
 EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);
 
@@ -221,6 +224,8 @@ EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, true);
 EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true);
 EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true);
 EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true);
+EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, true);
+EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, true);
 EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
 EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);
 
File diff suppressed because it is too large
@@ -4,7 +4,7 @@ import argparse
 import sys
 
 
-sizeof = {"float": 4, "at::Half": 2, "uint8_t": 1}
+sizeof = {"float": 4, "at::Half": 2, "at::BFloat16": 2, "uint8_t": 1}
 
 
 def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
@@ -24,6 +24,16 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
             " _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))),\n" # noqa
             " vop%d);" % (regid, regid, regid)
         )
+    elif InType == "at::BFloat16":
+        code.append(
+            " vop%d = _mm256_fmadd_ps(\n"
+            " vwgt,\n"
+            " _mm256_castsi256_ps(_mm256_slli_epi32(\n"
+            " _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
+            " reinterpret_cast<const __m128i*>(ip + (%d)))),\n"
+            " 16)),\n" # noqa
+            " vop%d);" % (regid, regid, regid)
+        )
     elif InType == "uint8_t":
         code.append(
             " vop%d = _mm256_fmadd_ps(\n"
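The intrinsic sequence in the block above is the heart of the fast path. As a standalone sketch (the helper name is hypothetical; the intrinsics are the ones the generated code uses), widening eight BF16 values to eight FP32 lanes takes only a zero-extension and a 16-bit left shift, because a BF16 value is the upper half of an IEEE-754 single-precision value:

    #include <immintrin.h>

    // Convert 8 BF16 values (raw 16-bit patterns) to 8 packed floats.
    static inline __m256 bf16x8_to_fp32(const void* src) {
      const __m128i raw = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src));
      const __m256i widened = _mm256_cvtepu16_epi32(raw);      // u16 -> u32 per lane
      const __m256i shifted = _mm256_slli_epi32(widened, 16);  // move bits into the float's high half
      return _mm256_castsi256_ps(shifted);                     // reinterpret as 8 floats
    }

The generated kernel feeds the converted lanes straight into _mm256_fmadd_ps with the weight vector, mirroring what the FP16 path does after _mm256_cvtph_ps.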
@@ -104,6 +114,7 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
 
     if InType == "uint8_t":
         code.append(" " + OutType + " wgt = 1.f;")
+        code.append(" // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
         code.append(" " + OutType + " bio;")
         code.append(" if (weights) {")
         code.append(
@@ -133,7 +144,10 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
     code.append(" const {}* ip = &input[idx * fused_block_size];".format(InType))
     code.append(
         " const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
-        " ? (dataInd + prefdist_T0)\n : dataInd;".format(
+        " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
+        " ? (dataInd + prefdist_T0)\n"
+        " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
+        " : dataInd;".format(
             IndexType
         )
     )
@@ -206,6 +220,18 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
             " reinterpret_cast<const __m128i*>(&ip[j]))),\n"
             " _mm256_loadu_ps(&op[j])));"
         )
+    elif InType == "at::BFloat16":
+        code.append(
+            " _mm256_storeu_ps(\n"
+            " &op[j],\n"
+            " _mm256_fmadd_ps(\n"
+            " vwgt,\n"
+            " _mm256_castsi256_ps(_mm256_slli_epi32(\n"
+            " _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
+            " reinterpret_cast<const __m128i*>(&ip[j]))),\n"
+            " 16)),\n"
+            " _mm256_loadu_ps(&op[j])));"
+        )
     elif InType == "uint8_t":
         code.append(
             " _mm256_storeu_ps(\n"
@@ -229,7 +255,8 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
     code = []
     if InType == "at::Half":
         code.append(" alignas(64) at::Half vtmp1[8] = {0};")
 
+    if InType == "at::BFloat16":
+        code.append(" alignas(64) at::BFloat16 vtmp1[8] = {0};")
 
     if use_offsets:
@@ -291,6 +318,7 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
 
     if InType == "uint8_t":
         code.append(" " + OutType + " wgt = 1.f;")
+        code.append(" // NOLINTNEXTLINE(cppcoreguidelines-init-variables)")
        code.append(" " + OutType + " bio;")
         code.append(" if (weights) {")
         code.append(
@@ -320,7 +348,10 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
     code.append(" const {}* ip = &input[idx * fused_block_size];".format(InType))
     code.append(
         " const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
-        " ? (dataInd + prefdist_T0)\n : dataInd;".format(
+        " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
+        " ? (dataInd + prefdist_T0)\n"
+        " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
+        " : dataInd;".format(
             IndexType
         )
     )
@@ -351,6 +382,14 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
             " _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));"
         )
         code.append(" op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
+    elif InType == "at::BFloat16":
+        code.append(" vtmp1[0] = ip[j];")
+        code.append(
+            " __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(\n"
+            " _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),\n"
+            " 16));"
+        )
+        code.append(" op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);")
     elif InType == "uint8_t":
         code.append(" op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);")
     else:
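The scalar tail above routes a single element through the aligned vtmp1 buffer and a vector shift. A scalar equivalent of that conversion, as a sketch (hypothetical helper, not part of the patch), is a 16-bit left shift followed by a bit cast:

    #include <cstdint>
    #include <cstring>

    // BF16 is the top 16 bits of an FP32 value, so widening is a shift plus a bit cast.
    static inline float bf16_bits_to_fp32(std::uint16_t bits) {
      const std::uint32_t widened = static_cast<std::uint32_t>(bits) << 16;
      float out;
      std::memcpy(&out, &widened, sizeof(out));  // bit cast; no numeric conversion
      return out;
    }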
@@ -408,6 +447,8 @@ options = [
     ["int64_t", "int64_t", "float", "float", "float", "float"],
     ["int32_t", "int", "half", "at::Half", "float", "float"],
     ["int64_t", "int64_t", "half", "at::Half", "float", "float"],
+    ["int32_t", "int", "bfloat16", "at::BFloat16", "float", "float"],
+    ["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"],
     ["int32_t", "int", "uint8_t", "uint8_t", "float", "float"],
     ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"],
 ]
@@ -422,6 +463,7 @@ code.append("//// DO NOT MODIFY!!!")
 code.append("//// --------------------------\n")
 
 code.append("#include <c10/util/Half.h>")
+code.append("#include <c10/util/BFloat16.h>")
 code.append("#include <immintrin.h>")
 
 code.append("namespace caffe2 {\n")
@@ -461,6 +503,7 @@ for o in options:
     code += args
 
     code.append(" const " + IndexType + " prefdist_T0 = 16;")
+    code.append(" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)")
     # block_size is the number of elements and fused_block_size is the size of
     # an entire row, including scale and bias.
     offset = (8 // sizeof[InType]) if opts.fused else 0
@@ -484,6 +527,7 @@
     code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
     code.append(" } else {")
     code.append(" // generic code")
+    code.append(" // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)")
     code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets)
     code.append(" }")
     code.append(" return dataInd == index_size;")