[TensorExpr] Allow for 'keepdim' argument in aten::mean in NNC's external call. (#68756)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68756

That fixes some warnings in our tests.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D32600952

Pulled By: ZolotukhinM

fbshipit-source-id: 548eaf3659e20795cce44d8f57e77f4a47d44d98
This commit is contained in:
Mikhail Zolotukhin 2021-11-30 00:03:21 -08:00 committed by Facebook GitHub Bot
parent a93f505ee5
commit 75ce040620
3 changed files with 16 additions and 10 deletions

View File

@ -716,7 +716,7 @@ TEST(ExternalCall, UnaryFloat) {
{100},
[](at::Tensor x) { return at::mean(x, {1}); },
"nnc_aten_mean",
toExprHandleVec({1})});
toExprHandleVec({1, /*keepdim=*/0})});
for (auto curTest : tests) {
std::vector<int64_t> aShape, resShape;
TensorFunc torchFunc;

View File

@ -687,12 +687,13 @@ void nnc_aten_mean(
at::Tensor& r = tensors[0];
const at::Tensor& x = tensors[1];
std::vector<int64_t> mean_dims(args_num);
if (args_num > 0) {
memcpy(mean_dims.data(), extra_args, sizeof(int64_t) * args_num);
std::vector<int64_t> mean_dims(args_num - 1);
bool keepdim = (bool)extra_args[args_num - 1];
if (args_num > 1) {
memcpy(mean_dims.data(), extra_args, sizeof(int64_t) * (args_num - 1));
}
try {
at::mean_out(r, x, mean_dims);
at::mean_out(r, x, mean_dims, keepdim);
} catch (...) {
}
}

View File

@ -111,22 +111,27 @@ Tensor computeMean(
if (outputType) {
dtype = Dtype(*outputType);
}
bool keepdim = false;
BufHandle ResultBuf("mean", outputShape, dtype);
BufHandle InputBuf = c10::get<BufHandle>(inputs[0]);
std::vector<ExprHandle> mean_dims_expr;
std::vector<ExprHandle> extra_args;
if (inputs.size() > 2) {
keepdim = c10::get<bool>(inputs[2]);
}
if (auto mean_dims = c10::get_if<IntList>(&inputs[1])) {
mean_dims_expr = c10::fmap<ExprHandle>(*mean_dims);
extra_args = c10::fmap<ExprHandle>(*mean_dims);
} else {
// When dims argument is not specified, reduce over all dimensions
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
for (int64_t idx = 0; idx < InputBuf.ndim(); idx++) {
mean_dims_expr.emplace_back(idx);
extra_args.emplace_back(idx);
}
}
extra_args.push_back(LongImm::make(static_cast<int64_t>(keepdim)));
return Tensor(
ResultBuf.node(),
ExternalCall::make(
ResultBuf, "nnc_aten_mean", {InputBuf}, mean_dims_expr));
ExternalCall::make(ResultBuf, "nnc_aten_mean", {InputBuf}, extra_args));
}
Tensor computeMax(