mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PyTorch][Static Runtime] Fix extra refcount bumps in layer_norm (#71237)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71237 Noticed these on inspection. ghstack-source-id: 147171799 Test Plan: CI Reviewed By: mikeiovine Differential Revision: D33519799 fbshipit-source-id: 167c63323b345a5822303cecdbbbbb959f66f6e4
This commit is contained in:
parent
dbbef542c0
commit
57e8da2d35
|
|
@ -2019,6 +2019,13 @@ REGISTER_OPERATOR_FUNCTOR(aten::softmax, aten_softmax, [](Node* n) -> SROperator
|
|||
};
|
||||
});
|
||||
|
||||
static c10::MaybeOwned<at::Tensor> borrow_from_optional_tensor_ivalue(
|
||||
const IValue& iv) {
|
||||
if (iv.isNone()) {
|
||||
return c10::MaybeOwned<at::Tensor>::owned(c10::in_place);
|
||||
}
|
||||
return c10::MaybeOwned<at::Tensor>::borrowed(iv.toTensor());
|
||||
}
|
||||
REGISTER_OPERATOR_FUNCTOR(
|
||||
static_runtime::layer_norm,
|
||||
aten_layer_norm,
|
||||
|
|
@ -2032,15 +2039,13 @@ REGISTER_OPERATOR_FUNCTOR(
|
|||
// ignore Input(5): `bool cudnn_enable=True`
|
||||
const auto& input = p_node->Input(0).toTensor();
|
||||
const auto normalized_shape = p_node->Input(1).toDimVector();
|
||||
auto weight_opt = p_node->Input(2).toOptional<at::Tensor>();
|
||||
auto bias_opt = p_node->Input(3).toOptional<at::Tensor>();
|
||||
float eps = p_node->Input(4).toDouble();
|
||||
|
||||
c10::MaybeOwned<at::Tensor> weight_maybe_owned =
|
||||
at::borrow_from_optional_tensor(weight_opt);
|
||||
borrow_from_optional_tensor_ivalue(p_node->Input(2));
|
||||
const at::Tensor& weight = *weight_maybe_owned;
|
||||
c10::MaybeOwned<at::Tensor> bias_maybe_owned =
|
||||
at::borrow_from_optional_tensor(bias_opt);
|
||||
borrow_from_optional_tensor_ivalue(p_node->Input(3));
|
||||
const at::Tensor& bias = *bias_maybe_owned;
|
||||
|
||||
auto M_N = at::native::_check_layer_norm_inputs(
|
||||
|
|
@ -2074,8 +2079,8 @@ REGISTER_OPERATOR_FUNCTOR(
|
|||
at::native::resize_(p_node->Output(2).toTensor(), {M}, c10::nullopt);
|
||||
}
|
||||
at::Tensor& output = p_node->Output(0).toTensor();
|
||||
at::Tensor mean = p_node->Output(1).toTensor();
|
||||
at::Tensor rstd = p_node->Output(2).toTensor();
|
||||
at::Tensor& mean = p_node->Output(1).toTensor();
|
||||
at::Tensor& rstd = p_node->Output(2).toTensor();
|
||||
at::native::layer_norm_cpu_out(
|
||||
output,
|
||||
mean,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user