[TensorExpr] Add lowering for aten::embedding. (#66518)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66518

This registers an NNC lowering for aten::embedding that emits an ExternalCall to a new nnc_aten_embedding external function (which dispatches back to at::embedding at runtime), and adds a C++ test exercising the call through both LLVMCodeGen and SimpleIREvaluator.

Differential Revision: D31590855

Test Plan: Imported from OSS

Reviewed By: pbelevich

Pulled By: ZolotukhinM

fbshipit-source-id: aace0a87b1649330dae44182f7873aca27160d64

commit 00afe9ba7b (parent 008a58d226)
Author: Mikhail Zolotukhin
Date: 2021-11-03 09:41:07 -07:00
Committed by: Facebook GitHub Bot

5 changed files with 105 additions and 0 deletions

@@ -375,6 +375,57 @@ TEST(ExternalCall, Addmm_float) {
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}

TEST(ExternalCall, Embedding) {
  BufHandle Weight("Weight", {256, 100}, kFloat);
  BufHandle Indices("Indices", {1, 115}, kLong);
  BufHandle ResultBuf("Result", {1, 115, 100}, kFloat);
  int64_t padding_idx = -1;
  bool scale_grad_by_freq = false;
  bool sparse = false;

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_embedding",
          {Weight, Indices},
          {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f;
  at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6;
  at::Tensor ref =
      at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);

  at::Tensor nnc_result;
  std::vector<float> weight_buf(256 * 100, 5.f);
  std::vector<int64_t> indices_buf(1 * 115, 6);
  std::vector<float> result_buf(1 * 115 * 100, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Weight, Indices, Result});

  llvm_codegen.call({weight_buf, indices_buf, result_buf});
  nnc_result = at::from_blob(
      result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Weight, Indices, Result});

  ir_eval.call({weight_buf, indices_buf, result_buf});
  nnc_result = at::from_blob(
      result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
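
For reference, the forward pass of at::embedding is just a row gather: each index selects one row of the weight matrix, so the expected output above is a {1, 115, 100} tensor whose every row equals weight row 6 (all 5s here). A minimal standalone sketch of that behavior (not part of this diff):

// Standalone sketch (not part of this diff): at::embedding gathers rows of
// `weight` by `indices`; the output shape is indices.sizes() + {embedding_dim}.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor weight = at::arange(12, at::kFloat).reshape({4, 3}); // 4 rows, dim 3
  at::Tensor indices = at::tensor({1, 3}, at::kLong);
  at::Tensor out = at::embedding(weight, indices);
  std::cout << out << "\n"; // rows 1 and 3: [[3, 4, 5], [9, 10, 11]]
}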
#ifdef USE_XNNPACK
TEST(ExternalCall, Prepacked_Linear_float) {

@@ -475,6 +475,29 @@ void nnc_prepacked_conv2d_clamp_run(
#endif // USE_XNNPACK

void nnc_aten_embedding(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);

  at::Tensor& r = tensors[0];
  const at::Tensor& weight = tensors[1];
  const at::Tensor& indices = tensors[2];
  try {
    r = at::embedding(weight, indices);
  } catch (...) {
    // External calls must not throw across the codegen boundary; on failure
    // the result buffer is simply left untouched.
  }

  // TODO: have to copy the output because at::embedding doesn't have an out
  // variant, and NNC's external calls don't support allocations.
  memcpy(buf_data[0], r.data_ptr(), r.element_size() * r.numel());
}
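
The external function receives raw buffers plus per-buffer rank/shape/dtype metadata through NNC's flat external-call ABI, and constructTensors wraps them in ATen tensors (buffer 0 is the output by convention, which is why the result is memcpy'd back). A conceptual sketch of that wrapping, under the assumption that it is a thin at::from_blob loop rather than the actual helper:

// Conceptual sketch (assumption, not the real constructTensors): wrap each
// raw buffer in a non-owning at::Tensor using the rank/dims/dtype metadata.
#include <ATen/ATen.h>
#include <vector>

std::vector<at::Tensor> constructTensorsSketch(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes) {
  std::vector<at::Tensor> tensors;
  int64_t dim_idx = 0;
  for (int64_t i = 0; i < bufs_num; i++) {
    std::vector<int64_t> dims;
    for (int64_t j = 0; j < buf_ranks[i]; j++) {
      dims.push_back(buf_dims[dim_idx++]);
    }
    auto dtype = static_cast<c10::ScalarType>(buf_dtypes[i]);
    // from_blob does not take ownership; NNC owns the allocations.
    tensors.push_back(
        at::from_blob(buf_data[i], dims, at::TensorOptions().dtype(dtype)));
  }
  return tensors;
}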
#ifndef C10_MOBILE
const static RegisterNNCExternalFunction nnc_conv2d(
@@ -515,6 +538,10 @@ const static RegisterNNCExternalFunction nnc_triangular_solve(
    "nnc_aten_triangular_solve",
    nnc_aten_triangular_solve);

const static RegisterNNCExternalFunction nnc_embedding(
    "nnc_aten_embedding",
    nnc_aten_embedding);

#ifdef USE_XNNPACK
const static RegisterNNCExternalFunction reg_nnc_prepacked_linear_clamp_run(
    "nnc_prepacked_linear_clamp_run",

@@ -1485,6 +1485,9 @@ RegisterNNCLoweringsFunction aten_add(
          : computeTwoOperand(
                "aten_add", inputs, outputShape, outputType, add_lambda);
    });

RegisterNNCLoweringsFunction aten_embedding(
    {"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor"},
    computeEmbedding);

#define NNC_QUANTIZATION_EXPR_QUANT 0
#define NNC_QUANTIZATION_EXPR_DEQUANT 0

@@ -612,6 +612,25 @@ Tensor computeCat(
      });
}

Tensor computeEmbedding(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }

  BufHandle ResultBuf("emb", outputShape, dtype);
  const BufHandle& w = c10::get<BufHandle>(inputs[0]);
  const BufHandle& indices = c10::get<BufHandle>(inputs[1]);

  StmtPtr s =
      ExternalCall::make(ResultBuf, "nnc_aten_embedding", {w, indices}, {});
  return Tensor(ResultBuf.node(), s);
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch
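
Note that computeEmbedding passes an empty extra-args list and nnc_aten_embedding calls at::embedding with its defaults: padding_idx, scale_grad_by_freq, and sparse only influence the backward pass, so dropping them is safe for this forward-only lowering. A quick standalone check (not part of this diff):

// Standalone check (not part of this diff): at::embedding's forward output
// does not depend on padding_idx; the flag only affects gradients.
#include <ATen/ATen.h>
#include <cassert>

int main() {
  at::Tensor weight = at::randn({8, 4});
  at::Tensor indices = at::tensor({0, 2, 5}, at::kLong);
  at::Tensor a = at::embedding(weight, indices, /*padding_idx=*/-1);
  at::Tensor b = at::embedding(weight, indices, /*padding_idx=*/2);
  assert(at::allclose(a, b));
}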

@@ -74,6 +74,11 @@ Tensor computeCat(
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device);

Tensor computeEmbedding(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device);

} // namespace tensorexpr
} // namespace jit