[PyTorch][Static Runtime] Fix extra refcount bumps in layer_norm (#71237)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71237

Noticed these on inspection.
ghstack-source-id: 147171799

Test Plan: CI

Reviewed By: mikeiovine

Differential Revision: D33519799

fbshipit-source-id: 167c63323b345a5822303cecdbbbbb959f66f6e4
This commit is contained in:
Scott Wolchok 2022-01-19 16:01:54 -08:00 committed by Facebook GitHub Bot
parent dbbef542c0
commit 57e8da2d35

View File

@ -2019,6 +2019,13 @@ REGISTER_OPERATOR_FUNCTOR(aten::softmax, aten_softmax, [](Node* n) -> SROperator
};
});
// Wraps an IValue that is either None or a Tensor in a
// c10::MaybeOwned<at::Tensor>.
//
// - If `iv` holds a value (anything other than None), the result borrows the
//   tensor returned by `iv.toTensor()`.
// - If `iv` is None, the result owns a tensor constructed in place with no
//   arguments.
//
// NOTE(review): in the borrowed case the result refers to the tensor inside
// `iv`, so `iv` presumably has to outlive the returned value -- confirm
// against c10::MaybeOwned's contract.
static c10::MaybeOwned<at::Tensor> borrow_from_optional_tensor_ivalue(
    const IValue& iv) {
  using MaybeOwnedTensor = c10::MaybeOwned<at::Tensor>;
  if (!iv.isNone()) {
    // A value is present: hand back a borrow of it.
    return MaybeOwnedTensor::borrowed(iv.toTensor());
  }
  // None: fall back to an owned, in-place constructed tensor.
  return MaybeOwnedTensor::owned(c10::in_place);
}
REGISTER_OPERATOR_FUNCTOR(
static_runtime::layer_norm,
aten_layer_norm,
@ -2032,15 +2039,13 @@ REGISTER_OPERATOR_FUNCTOR(
// ignore Input(5): `bool cudnn_enable=True`
const auto& input = p_node->Input(0).toTensor();
const auto normalized_shape = p_node->Input(1).toDimVector();
auto weight_opt = p_node->Input(2).toOptional<at::Tensor>();
auto bias_opt = p_node->Input(3).toOptional<at::Tensor>();
float eps = p_node->Input(4).toDouble();
c10::MaybeOwned<at::Tensor> weight_maybe_owned =
at::borrow_from_optional_tensor(weight_opt);
borrow_from_optional_tensor_ivalue(p_node->Input(2));
const at::Tensor& weight = *weight_maybe_owned;
c10::MaybeOwned<at::Tensor> bias_maybe_owned =
at::borrow_from_optional_tensor(bias_opt);
borrow_from_optional_tensor_ivalue(p_node->Input(3));
const at::Tensor& bias = *bias_maybe_owned;
auto M_N = at::native::_check_layer_norm_inputs(
@ -2074,8 +2079,8 @@ REGISTER_OPERATOR_FUNCTOR(
at::native::resize_(p_node->Output(2).toTensor(), {M}, c10::nullopt);
}
at::Tensor& output = p_node->Output(0).toTensor();
at::Tensor mean = p_node->Output(1).toTensor();
at::Tensor rstd = p_node->Output(2).toTensor();
at::Tensor& mean = p_node->Output(1).toTensor();
at::Tensor& rstd = p_node->Output(2).toTensor();
at::native::layer_norm_cpu_out(
output,
mean,