Remove unneeded optional dereference (#141578)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141578
Approved by: https://github.com/swolchok
cyy 2024-12-12 04:34:41 +00:00 committed by PyTorch MergeBot
parent f7b9533c3f
commit 20df80a669
6 changed files with 20 additions and 81 deletions
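Across all six files the change follows the same shape: the ATen "hacky wrapper removal for optional tensor" preamble, which dereferences the incoming std::optional<Tensor> into a const Tensor& up front via at::borrow_from_optional_tensor, is dropped wherever the function only forwards the argument or checks whether a weight was supplied. Because borrow_from_optional_tensor produces an undefined Tensor when the optional is empty, testing weight_opt.has_value() && weight_opt->defined() at the point of use preserves the old behavior. A minimal sketch of the before/after pattern, using a hypothetical apply_weight helper that is not part of this commit:

#include <ATen/ATen.h>
#include <optional>

// Before: materialize a Tensor reference from the optional, then test Tensor::defined().
void apply_weight_before(at::Tensor& loss, const std::optional<at::Tensor>& weight_opt) {
  c10::MaybeOwned<at::Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const at::Tensor& weight = *weight_maybe_owned;  // undefined Tensor when weight_opt is nullopt
  if (weight.defined()) {
    loss.mul_(weight);
  }
}

// After: keep the std::optional and dereference it only where it is actually consumed.
void apply_weight_after(at::Tensor& loss, const std::optional<at::Tensor>& weight_opt) {
  if (weight_opt.has_value() && weight_opt->defined()) {
    loss.mul_(*weight_opt);
  }
}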

View File

@@ -1732,11 +1732,10 @@ std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const std::option
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> ggI_maybe_owned = at::borrow_from_optional_tensor(ggI_opt);
   const Tensor& ggI = *ggI_maybe_owned;
-  const Tensor& ggW_r = ggW_r_opt.value_or(Tensor());
+  Tensor ggW = ggW_r_opt.value_or(Tensor());
   const Tensor& ggb = ggb_opt.value_or(Tensor());
-  auto ggW = ggW_r;
   auto gO = gO_r;
   auto weight = weight_r;
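This first hunk is slightly different from the rest: std::optional<Tensor>::value_or returns by value, so the old code bound a const Tensor& to that temporary and then copied it into a second variable, while the new code constructs the owning Tensor in one step. A reduced illustration of the same simplification, with a hypothetical materialize_weight helper that is not part of the diff:

#include <ATen/ATen.h>
#include <optional>

at::Tensor materialize_weight(const std::optional<at::Tensor>& ggW_r_opt) {
  // Old shape: bind a reference to the temporary from value_or, then copy it.
  //   const at::Tensor& ggW_r = ggW_r_opt.value_or(at::Tensor());
  //   at::Tensor ggW = ggW_r;
  // New shape: construct the owning value directly (same behavior, one variable fewer).
  return ggW_r_opt.value_or(at::Tensor());
}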

View File

@@ -251,20 +251,12 @@ Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction, bool
 }
 Tensor binary_cross_entropy_cpu(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
   Tensor loss = at::empty_like(input);
   return at::native::binary_cross_entropy_out_cpu(
-      input, target, weight, reduction, loss);
+      input, target, weight_opt, reduction, loss);
 }
 Tensor& binary_cross_entropy_out_cpu(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, Tensor& loss) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
   Tensor loss_squeezed = at::squeeze(loss);
   auto iter = TensorIteratorConfig()
@@ -297,8 +289,8 @@ Tensor& binary_cross_entropy_out_cpu(const Tensor& input, const Tensor& target,
     });
   });
-  if (weight.defined()) {
-    loss.mul_(weight);
+  if (weight_opt.has_value() && weight_opt->defined()) {
+    loss.mul_(*weight_opt);
   }
   if (reduction != at::Reduction::None) {
     Tensor loss_reduced = apply_loss_reduction(loss, reduction);
@@ -308,20 +300,12 @@ Tensor& binary_cross_entropy_out_cpu(const Tensor& input, const Tensor& target,
 }
 Tensor binary_cross_entropy_backward_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
   Tensor grad_input = at::empty_like(input);
   return at::native::binary_cross_entropy_backward_out_cpu(
-      grad, input, target, weight, reduction, grad_input);
+      grad, input, target, weight_opt, reduction, grad_input);
 }
 Tensor& binary_cross_entropy_backward_out_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, Tensor& grad_input) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
   Tensor grad_input_squeezed = at::squeeze(grad_input);
   auto iter = TensorIteratorConfig()
@@ -350,8 +334,8 @@ Tensor& binary_cross_entropy_backward_out_cpu(const Tensor& grad, const Tensor&
     });
   });
-  if (weight.defined()) {
-    grad_input.mul_(weight);
+  if (weight_opt.has_value() && weight_opt->defined()) {
+    grad_input.mul_(*weight_opt);
   }
   if (reduction == at::Reduction::Mean) {
     grad_input.div_(input.numel());
@@ -360,23 +344,17 @@ Tensor& binary_cross_entropy_backward_out_cpu(const Tensor& grad, const Tensor&
 }
 Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& pos_weight_opt, int64_t reduction) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
-  c10::MaybeOwned<Tensor> pos_weight_maybe_owned = at::borrow_from_optional_tensor(pos_weight_opt);
-  const Tensor& pos_weight = *pos_weight_maybe_owned;
   auto log_sigmoid_input = at::log_sigmoid(input);
-  if (pos_weight.defined()) {
+  if (pos_weight_opt.has_value() && pos_weight_opt->defined()) {
     // pos_weight need to be broadcasted, thus mul(target) is not inplace.
-    auto log_weight = (pos_weight - 1).mul(target).add_(1);
+    auto log_weight = (*pos_weight_opt - 1).mul(target).add_(1);
     log_sigmoid_input.mul_(log_weight);
   }
   Tensor loss = (1 - target).mul_(input).sub_(log_sigmoid_input);
-  if (weight.defined()) {
-    loss.mul_(weight);
+  if (weight_opt.has_value() && weight_opt->defined()) {
+    loss.mul_(*weight_opt);
   }
   return apply_loss_reduction(loss, reduction);

View File

@@ -659,20 +659,12 @@ Tensor cross_entropy_loss_symint(
 }
 Tensor & nll_loss_out(const Tensor & self, const Tensor & target, const std::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, Tensor & output) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
   Tensor total_weight = at::empty({0}, self.options());
-  return std::get<0>(at::nll_loss_forward_out(output, total_weight, self, target, weight, reduction, ignore_index));
+  return std::get<0>(at::nll_loss_forward_out(output, total_weight, self, target, weight_opt, reduction, ignore_index));
 }
 Tensor nll_loss_symint(const Tensor & self, const Tensor & target, const std::optional<Tensor>& weight_opt, int64_t reduction, c10::SymInt ignore_index) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
-  return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, std::move(ignore_index)));
+  return std::get<0>(at::nll_loss_forward_symint(self, target, weight_opt, reduction, std::move(ignore_index)));
 }
 Tensor nll_loss_nd_symint(

View File

@@ -424,14 +424,10 @@ std::tuple<Tensor, Tensor> nll_loss2d_forward_cpu(
     const Tensor& target, const std::optional<Tensor>& weight_opt,
     int64_t reduction,
     int64_t ignore_index) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
   auto output = at::empty({0}, self.options());
   auto total_weight = at::empty({0}, self.options());
   at::native::nll_loss2d_forward_out_cpu(
-      self, target, weight, reduction, ignore_index, output, total_weight);
+      self, target, weight_opt, reduction, ignore_index, output, total_weight);
   return std::make_tuple(output, total_weight);
 }
@@ -465,16 +461,12 @@ Tensor nll_loss2d_backward_cpu(
     int64_t reduction,
     int64_t ignore_index,
     const Tensor& total_weight) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
   auto grad_input = at::zeros_like(self);
   at::native::nll_loss2d_backward_out_cpu(
       grad_output,
       self,
       target,
-      weight,
+      weight_opt,
       reduction,
       ignore_index,
       total_weight,
@@ -483,20 +475,12 @@ Tensor nll_loss2d_backward_cpu(
 }
 Tensor & nll_loss2d_out(const Tensor & self, const Tensor & target, const std::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, Tensor & output) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
   Tensor total_weight = at::empty({0}, self.options());
-  return std::get<0>(at::nll_loss2d_forward_out(output, total_weight, self, target, weight, reduction, ignore_index));
+  return std::get<0>(at::nll_loss2d_forward_out(output, total_weight, self, target, weight_opt, reduction, ignore_index));
 }
 Tensor nll_loss2d_symint(const Tensor & self, const Tensor & target, const std::optional<Tensor>& weight_opt, int64_t reduction, c10::SymInt ignore_index) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
-  return std::get<0>(at::nll_loss2d_forward_symint(self, target, weight, reduction, std::move(ignore_index)));
+  return std::get<0>(at::nll_loss2d_forward_symint(self, target, weight_opt, reduction, std::move(ignore_index)));
 }
 } // namespace at::native

View File

@@ -63,13 +63,9 @@ void binary_cross_entropy_backward_out_kernel(Tensor& grad_input, const Tensor&
 namespace at::native {
 Tensor binary_cross_entropy_cuda(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
   Tensor loss = at::empty_like(input);
   return at::native::binary_cross_entropy_out_cuda(
-      input, target, weight, reduction, loss);
+      input, target, weight_opt, reduction, loss);
 }
 Tensor& binary_cross_entropy_out_cuda(const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, Tensor& loss) {
@@ -122,13 +118,9 @@ Tensor& binary_cross_entropy_out_cuda(const Tensor& input, const Tensor& target,
 }
 Tensor binary_cross_entropy_backward_cuda(const Tensor& grad, const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
   Tensor grad_input = at::empty_like(input);
   return at::native::binary_cross_entropy_backward_out_cuda(
-      grad, input, target, weight, reduction, grad_input);
+      grad, input, target, weight_opt, reduction, grad_input);
 }
 Tensor& binary_cross_entropy_backward_out_cuda(const Tensor& grad, const Tensor& input, const Tensor& target, const std::optional<Tensor>& weight_opt, int64_t reduction, Tensor& grad_input) {

View File

@@ -190,13 +190,7 @@ Tensor layer_norm_symint(
     c10::SymIntArrayRef normalized_shape, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& bias_opt /* optional */,
     double eps,
     bool /* cudnn_enable, deprecated */) {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
-  const Tensor& weight = *weight_maybe_owned;
-  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
-  const Tensor& bias = *bias_maybe_owned;
-  return std::get<0>(at::native_layer_norm_symint(input, normalized_shape, weight, bias, eps));
+  return std::get<0>(at::native_layer_norm_symint(input, normalized_shape, weight_opt, bias_opt, eps));
 }
 DEFINE_DISPATCH(LayerNormKernel);
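Callers are unaffected: the public wrappers keep their std::optional<Tensor> parameters and now simply forward them. A small usage sketch, not part of the commit, with illustrative shapes and assuming a libtorch build:

#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::randn({8, 16});
  // With affine weight and bias supplied.
  at::Tensor w = at::ones({16});
  at::Tensor b = at::zeros({16});
  at::Tensor y1 = at::layer_norm(x, {16}, w, b, 1e-5, /*cudnn_enable=*/true);
  // Without them: the optional weight/bias stay disengaged inside the wrapper.
  at::Tensor y2 = at::layer_norm(x, {16}, {}, {}, 1e-5, /*cudnn_enable=*/true);
  return 0;
}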