[TensorExpr] Allow for 'keepdim' argument in aten::mean in NNC's external call. (#68756)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68756

This fixes some warnings in our tests.

Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D32600952
Pulled By: ZolotukhinM
fbshipit-source-id: 548eaf3659e20795cce44d8f57e77f4a47d44d98
This commit is contained in:
parent
a93f505ee5
commit
75ce040620
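For context, the 'keepdim' flag controls whether the reduced dimension is retained with size 1 in the output of at::mean. The hunks below forward it to the external function as one extra trailing int64 in extra_args, after the reduction dims, which is why the updated test supplies toExprHandleVec({1, /*keepdim=*/0}). A minimal standalone ATen sketch of the keepdim semantics (illustrative shapes, not part of this patch):

#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::rand({4, 100});
  // Without keepdim the reduced dimension is dropped from the output.
  at::Tensor a = at::mean(x, {1});                    // sizes: {4}
  // With keepdim the reduced dimension is kept with size 1.
  at::Tensor b = at::mean(x, {1}, /*keepdim=*/true);  // sizes: {4, 1}
  return 0;
}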
@@ -716,7 +716,7 @@ TEST(ExternalCall, UnaryFloat) {
        {100},
        [](at::Tensor x) { return at::mean(x, {1}); },
        "nnc_aten_mean",
-       toExprHandleVec({1})});
+       toExprHandleVec({1, /*keepdim=*/0})});
   for (auto curTest : tests) {
     std::vector<int64_t> aShape, resShape;
     TensorFunc torchFunc;
@@ -687,12 +687,13 @@ void nnc_aten_mean(

   at::Tensor& r = tensors[0];
   const at::Tensor& x = tensors[1];
-  std::vector<int64_t> mean_dims(args_num);
-  if (args_num > 0) {
-    memcpy(mean_dims.data(), extra_args, sizeof(int64_t) * args_num);
+  std::vector<int64_t> mean_dims(args_num - 1);
+  bool keepdim = (bool)extra_args[args_num - 1];
+  if (args_num > 1) {
+    memcpy(mean_dims.data(), extra_args, sizeof(int64_t) * (args_num - 1));
   }
   try {
-    at::mean_out(r, x, mean_dims);
+    at::mean_out(r, x, mean_dims, keepdim);
   } catch (...) {
   }
 }
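A minimal standalone sketch (hypothetical helper names, not part of the patch) of how nnc_aten_mean interprets extra_args after this change: the first args_num - 1 entries are reduction dims and the last entry is the keepdim flag.

#include <cstdint>
#include <cstring>
#include <vector>

// Decoded form of the trailing-keepdim convention sketched above.
struct DecodedMeanArgs {
  std::vector<int64_t> dims;
  bool keepdim;
};

DecodedMeanArgs decodeMeanArgs(const int64_t* extra_args, int64_t args_num) {
  DecodedMeanArgs out;
  // The last entry is the keepdim flag; everything before it is a dim index.
  out.dims.resize(args_num - 1);
  out.keepdim = static_cast<bool>(extra_args[args_num - 1]);
  if (args_num > 1) {
    std::memcpy(out.dims.data(), extra_args, sizeof(int64_t) * (args_num - 1));
  }
  return out;
}

With the test's arguments {1, /*keepdim=*/0} this yields dims = {1} and keepdim = false.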
@@ -111,22 +111,27 @@ Tensor computeMean(
   if (outputType) {
     dtype = Dtype(*outputType);
   }
+  bool keepdim = false;
   BufHandle ResultBuf("mean", outputShape, dtype);
   BufHandle InputBuf = c10::get<BufHandle>(inputs[0]);
-  std::vector<ExprHandle> mean_dims_expr;
+  std::vector<ExprHandle> extra_args;
+  if (inputs.size() > 2) {
+    keepdim = c10::get<bool>(inputs[2]);
+  }
+
   if (auto mean_dims = c10::get_if<IntList>(&inputs[1])) {
-    mean_dims_expr = c10::fmap<ExprHandle>(*mean_dims);
+    extra_args = c10::fmap<ExprHandle>(*mean_dims);
   } else {
     // When dims argument is not specified, reduce over all dimensions
     // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
     for (int64_t idx = 0; idx < InputBuf.ndim(); idx++) {
-      mean_dims_expr.emplace_back(idx);
+      extra_args.emplace_back(idx);
     }
   }
+  extra_args.push_back(LongImm::make(static_cast<int64_t>(keepdim)));
   return Tensor(
       ResultBuf.node(),
-      ExternalCall::make(
-          ResultBuf, "nnc_aten_mean", {InputBuf}, mean_dims_expr));
+      ExternalCall::make(ResultBuf, "nnc_aten_mean", {InputBuf}, extra_args));
 }
 
 Tensor computeMax(
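For symmetry, a minimal standalone sketch (hypothetical function name, not part of the patch) of the packing that computeMean performs on the lowering side: reduction dims first, all dims when none are specified, then the keepdim flag as a trailing value.

#include <cstdint>
#include <vector>

// Packs reduction dims plus a trailing keepdim flag, mirroring the
// extra_args vector that computeMean builds for the ExternalCall.
std::vector<int64_t> packMeanArgs(
    const std::vector<int64_t>& dims,
    int64_t input_ndim,
    bool keepdim) {
  std::vector<int64_t> extra_args;
  if (!dims.empty()) {
    extra_args = dims;
  } else {
    // When no dims are specified, reduce over all dimensions.
    for (int64_t idx = 0; idx < input_ndim; idx++) {
      extra_args.push_back(idx);
    }
  }
  extra_args.push_back(static_cast<int64_t>(keepdim));
  return extra_args;
}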