mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[TensorExpr] Add lowering for aten::embedding. (#66518)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66518 Differential Revision: D31590855 Test Plan: Imported from OSS Reviewed By: pbelevich Pulled By: ZolotukhinM fbshipit-source-id: aace0a87b1649330dae44182f7873aca27160d64
This commit is contained in:
parent
008a58d226
commit
00afe9ba7b
|
|
@ -375,6 +375,57 @@ TEST(ExternalCall, Addmm_float) {
|
|||
ASSERT_TRUE(at::allclose(nnc_result, ref));
|
||||
}
|
||||
|
||||
TEST(ExternalCall, Embedding) {
  // IR-level inputs: a 256x100 float weight table, a 1x115 int64 index
  // tensor, and a 1x115x100 float output written by the external call.
  BufHandle Weight("Weight", {256, 100}, kFloat);
  BufHandle Indices("Indices", {1, 115}, kLong);
  BufHandle ResultBuf("Result", {1, 115, 100}, kFloat);
  int64_t padding_idx = -1;
  bool scale_grad_by_freq = false;
  bool sparse = false;

  // Lower to a call of the "nnc_aten_embedding" shim, passing the three
  // scalar options through as int64 extra args.
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_embedding",
          {Weight, Indices},
          {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse}));
  LoopNest nest({Result});
  nest.prepareForCodegen();
  nest.simplify();

  // Reference result computed with eager-mode ATen.
  auto options = at::TensorOptions()
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f;
  at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6;
  at::Tensor ref =
      at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);

  // Flat host buffers holding the same data as the reference tensors; the
  // result buffer is poisoned with -1 so a no-op call would be caught.
  at::Tensor nnc_result;
  std::vector<float> weight_buf(256 * 100, 5.f);
  std::vector<int64_t> indices_buf(1 * 115, 6);
  std::vector<float> result_buf(1 * 115 * 100, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(nest.root_stmt(), {Weight, Indices, Result});

  llvm_codegen.call({weight_buf, indices_buf, result_buf});
  nnc_result = at::from_blob(
      result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  // The interpreter path must match the reference as well.
  SimpleIREvaluator ir_eval(nest.root_stmt(), {Weight, Indices, Result});

  ir_eval.call({weight_buf, indices_buf, result_buf});
  nnc_result = at::from_blob(
      result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
|
||||
|
||||
#ifdef USE_XNNPACK
|
||||
|
||||
TEST(ExternalCall, Prepacked_Linear_float) {
|
||||
|
|
|
|||
|
|
@ -475,6 +475,29 @@ void nnc_prepacked_conv2d_clamp_run(
|
|||
|
||||
#endif // USE_XNNPACK
|
||||
|
||||
void nnc_aten_embedding(
|
||||
int64_t bufs_num,
|
||||
void** buf_data,
|
||||
int64_t* buf_ranks,
|
||||
int64_t* buf_dims,
|
||||
int8_t* buf_dtypes,
|
||||
int64_t args_num,
|
||||
int64_t* extra_args) {
|
||||
std::vector<at::Tensor> tensors =
|
||||
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
|
||||
|
||||
at::Tensor& r = tensors[0];
|
||||
const at::Tensor& weight = tensors[1];
|
||||
const at::Tensor& indices = tensors[2];
|
||||
try {
|
||||
r = at::embedding(weight, indices);
|
||||
} catch (...) {
|
||||
}
|
||||
// TODO: have to copy output because at::embedding doesnt have an out variant
|
||||
// and NNC's external calls don't support allocations
|
||||
memcpy(buf_data[0], r.data_ptr(), r.element_size() * r.numel());
|
||||
}
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
|
||||
const static RegisterNNCExternalFunction nnc_conv2d(
|
||||
|
|
@ -515,6 +538,10 @@ const static RegisterNNCExternalFunction nnc_triangular_solve(
|
|||
"nnc_aten_triangular_solve",
|
||||
nnc_aten_triangular_solve);
|
||||
|
||||
// Expose the embedding shim under the name the lowering emits in its
// ExternalCall ("nnc_aten_embedding").
const static RegisterNNCExternalFunction nnc_embedding(
    "nnc_aten_embedding",
    nnc_aten_embedding);
|
||||
|
||||
#ifdef USE_XNNPACK
|
||||
const static RegisterNNCExternalFunction reg_nnc_prepacked_linear_clamp_run(
|
||||
"nnc_prepacked_linear_clamp_run",
|
||||
|
|
|
|||
|
|
@ -1485,6 +1485,9 @@ RegisterNNCLoweringsFunction aten_add(
|
|||
: computeTwoOperand(
|
||||
"aten_add", inputs, outputShape, outputType, add_lambda);
|
||||
});
|
||||
RegisterNNCLoweringsFunction aten_embedding(
|
||||
{"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor"},
|
||||
computeEmbedding);
|
||||
|
||||
#define NNC_QUANTIZATION_EXPR_QUANT 0
|
||||
#define NNC_QUANTIZATION_EXPR_DEQUANT 0
|
||||
|
|
|
|||
|
|
@ -612,6 +612,25 @@ Tensor computeCat(
|
|||
});
|
||||
}
|
||||
|
||||
// Lowers aten::embedding to an ExternalCall on "nnc_aten_embedding",
// forwarding the weight and indices buffers.
Tensor computeEmbedding(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device) {
  // Default to float when the node carries no output dtype.
  Dtype dtype = outputType ? Dtype(*outputType) : kFloat;

  BufHandle ResultBuf("emb", outputShape, dtype);
  const BufHandle& weight = c10::get<BufHandle>(inputs[0]);
  const BufHandle& indices = c10::get<BufHandle>(inputs[1]);

  // NOTE(review): the schema's scalar args (padding_idx, scale_grad_by_freq,
  // sparse) are not forwarded here — the shim runs with its defaults.
  StmtPtr externCall = ExternalCall::make(
      ResultBuf, "nnc_aten_embedding", {weight, indices}, {});
  return Tensor(ResultBuf.node(), externCall);
}
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -74,6 +74,11 @@ Tensor computeCat(
|
|||
const std::vector<ExprHandle>& outputShape,
|
||||
const c10::optional<ScalarType>& outputType,
|
||||
at::Device device);
|
||||
// Lowering for aten::embedding (implementation in the corresponding .cpp).
Tensor computeEmbedding(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const c10::optional<ScalarType>& outputType,
    at::Device device);
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user