Fix backward of binary_cross_entropy_with_logits

The previous PR in this stack uncovered an error in the forward-over-backward
computation (forward-mode AD applied to the backward) of this function.

In this PR, we fix this error and also fix the gradgrad
implementation (making it more stable and faster by using `logsigmoid`).
We also move the double backward for this function to `FunctionsManual`,
as there's no reason for it to be registered in `native_functions.yaml`.
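
For reference, the fixed double backward can be exercised end to end with `gradgradcheck`; a minimal, hypothetical sketch (not part of this PR):

```python
# Minimal sketch (not part of this PR): numerically check the first and second
# derivatives of binary_cross_entropy_with_logits, i.e. the paths touched here.
import torch
import torch.nn.functional as F

inp = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
tgt = torch.rand(3, 4, dtype=torch.double, requires_grad=True)
pos_weight = torch.rand(4, dtype=torch.double)

def fn(x, y):
    return F.binary_cross_entropy_with_logits(x, y, pos_weight=pos_weight)

torch.autograd.gradcheck(fn, (inp, tgt))      # first-order backward
torch.autograd.gradgradcheck(fn, (inp, tgt))  # gradgrad path fixed in this PR
```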

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79381

Approved by: https://github.com/soulitzer
lezcano 2022-06-22 01:35:13 +00:00 committed by PyTorch MergeBot
parent 6b20ef6b91
commit 28a7ee8cec
8 changed files with 78 additions and 68 deletions

View File

@@ -351,49 +351,6 @@ Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& targe
return apply_loss_reduction(loss, reduction);
}
Tensor binary_cross_entropy_with_logits_backward(
const Tensor& grad,
const Tensor& input,
const Tensor& target,
const c10::optional<Tensor>& weight_opt,
const c10::optional<Tensor>& pos_weight_opt,
int64_t reduction) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned =
at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& pos_weight =
c10::value_or_else(pos_weight_opt, [] { return Tensor(); });
Tensor grad_input;
auto hasSubclassTensors = at::areAnyTensorSubclassLike({grad, input, target});
// If there are subclassed tensors use the out of place version
if (pos_weight.defined()) {
// pos_weight might need to be broadcasted, thus mul(target) is not inplace.
auto t = pos_weight.mul(target);
grad_input = hasSubclassTensors
? t.add(1).sub(target).mul(input.sigmoid()).sub(t).mul(grad)
: t.add(1).sub_(target).mul_(input.sigmoid()).sub_(t).mul_(grad);
} else {
grad_input = hasSubclassTensors ? (input.sigmoid() - target).mul(grad)
: (input.sigmoid() - target).mul_(grad);
}
if (weight.defined()) {
if (at::areAnyTensorSubclassLike({grad_input, weight})) {
grad_input = grad_input.mul(weight);
} else {
grad_input.mul_(weight);
}
}
if (reduction == at::Reduction::Mean) {
return grad_input / input.numel();
}
return grad_input;
}
Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool log_input, const bool full, const double eps, const int64_t reduction)
{
Tensor loss;
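
For reference (my summary, not part of the diff): the function removed here, and reintroduced in `FunctionsManual.cpp` further down, computes the input gradient of the BCE-with-logits loss. With w = `weight`, p = `pos_weight`, and σ the sigmoid:

$$
\ell(x, y) = -w\,\bigl[p\,y \log\sigma(x) + (1 - y)\log\bigl(1 - \sigma(x)\bigr)\bigr]
\quad\Longrightarrow\quad
\frac{\partial \ell}{\partial x} = w\,\bigl[(p\,y + 1 - y)\,\sigma(x) - p\,y\bigr]
$$

which matches `(t + 1 - target) * sigmoid(input) - t` with `t = pos_weight * target`, scaled by `grad` and, for mean reduction, divided by `input.numel()`.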

View File

@@ -130,7 +130,7 @@ void launch_glu_backward_kernel(const TensorIteratorBase& iter,
// -----------------------------------
void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.common_dtype(),
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(),
"log_sigmoid_forward_cuda", [&] {
using opmath_t = at::opmath_type<scalar_t>;
@@ -149,7 +149,7 @@ void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) {
// -----------------------------------
void log_sigmoid_backward_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.common_dtype(),
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(),
"log_sigmoid_backward_cuda", [&] {
using opmath_t = at::opmath_type<scalar_t>;
gpu_kernel(iter,
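
The dispatch changes above extend the CUDA `log_sigmoid` forward and backward kernels to bfloat16, matching the OpInfo dtype updates at the end of this commit. A small usage sketch, assuming a CUDA build:

```python
# Sketch: the kBFloat16 dispatch added above enables log_sigmoid and its backward
# for bfloat16 tensors on CUDA (previously only half/float/double were dispatched).
import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    x = torch.randn(8, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    y = F.logsigmoid(x)   # hits log_sigmoid_forward_cuda
    y.sum().backward()    # hits log_sigmoid_backward_cuda
    print(x.grad.dtype)   # torch.bfloat16
```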

View File

@@ -966,9 +966,6 @@
dispatch:
CompositeExplicitAutograd: binary_cross_entropy_with_logits
- func: binary_cross_entropy_with_logits_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor
variants: function
- func: bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor
variants: function, method
dispatch:

View File

@@ -51,9 +51,7 @@ ALLOW_LIST = [
("aten::randperm", datetime.date(9999, 1, 1)),
("aten::linalg_solve", datetime.date(2022, 8, 31)),
("aten::linalg_solve.out", datetime.date(2022, 8, 31)),
("aten::l1_loss_backward.grad_input", datetime.date(2022, 7, 1)),
("aten::l1_loss_backward", datetime.date(2022, 7, 1)),
("aten::l1_loss.out", datetime.date(2022, 7, 1)),
("aten::binary_cross_entropy_with_logits_backward", datetime.date(2022, 9, 21)),
("aten::_linalg_qr_helper", datetime.date(2022, 8, 1)),
("aten::linalg_lu_solve", datetime.date(2022, 8, 1)),
("aten::linalg_lu_solve.out", datetime.date(2022, 8, 1)),

View File

@@ -463,7 +463,6 @@ traced_operators:
aten::binary_cross_entropy: 13
aten::binary_cross_entropy_backward: 12
aten::binary_cross_entropy_with_logits: 3
aten::binary_cross_entropy_with_logits_backward: 2
aten::bitwise_and.Tensor: 13
aten::bitwise_and_.Tensor: 1
aten::bitwise_not: 13

View File

@@ -1999,6 +1999,57 @@ Tensor binary_cross_entropy_double_backward_target(
return res;
}
Tensor binary_cross_entropy_with_logits_backward(
const Tensor& grad,
const Tensor& input,
const Tensor& target,
const c10::optional<Tensor>& weight,
const c10::optional<Tensor>& pos_weight,
int64_t reduction) {
// Trivial case
if (grad._is_zerotensor()) {
return at::_efficientzerotensor(input.sizes(), input.options());
}
// -w * [ pos * y * (1 -sigmoid(x)) - (1 - y) sigmoid(x)] * grad
// If there are subclassed tensors use the out of place version
Tensor grad_input;
if (isDefined(pos_weight)) {
// pos_weight might need to be broadcasted, thus mul(target) is not inplace.
auto t = pos_weight->mul(target);
grad_input = at::areAnyTensorSubclassLike({input, target}) ||
at::GradMode::is_enabled()
? t.add(1).sub(target).mul(input.sigmoid()).sub(t)
: t.add(1).sub_(target).mul_(input.sigmoid()).sub_(t);
} else {
grad_input = at::areAnyTensorSubclassLike({input, target}) ||
at::GradMode::is_enabled()
? input.sigmoid().sub(target)
: input.sigmoid().sub_(target);
}
if (at::isTensorSubclassLike(grad) || at::GradMode::is_enabled()) {
grad_input = grad_input.mul(grad);
} else {
grad_input.mul_(grad);
}
if (isDefined(weight)) {
if (at::isTensorSubclassLike(*weight) || at::GradMode::is_enabled()) {
grad_input = grad_input.mul(*weight);
} else {
grad_input.mul_(*weight);
}
}
if (reduction == at::Reduction::Mean) {
grad_input.div_(input.numel());
}
return grad_input;
}
Tensor binary_cross_entropy_with_logits_target_backward(
const Tensor& grad_output,
const Tensor& self,
@@ -2006,28 +2057,30 @@ Tensor binary_cross_entropy_with_logits_target_backward(
const c10::optional<Tensor>& weight,
const c10::optional<Tensor>& pos_weight,
int64_t reduction) {
if (grad_output._is_zerotensor()) {
return at::_efficientzerotensor(target.sizes(), target.options());
}
Tensor grad_target;
if (isDefined(pos_weight)) {
if (!areAnyTensorSubclassLike({*pos_weight, grad_output})) {
grad_target = (1. - self.sigmoid())
.log_()
.sub_(pos_weight->mul(self.sigmoid().log_()))
.mul_(grad_output);
} else {
grad_target = (1. - self.sigmoid())
.log_()
.sub(pos_weight->mul(self.sigmoid().log_()))
if (areAnyTensorSubclassLike({*pos_weight, grad_output})) {
grad_target = at::log_sigmoid(-self)
.sub(at::log_sigmoid(self).mul(*pos_weight))
.mul(grad_output);
} else {
grad_target = at::log_sigmoid(-self)
.sub_(at::log_sigmoid(self).mul_(*pos_weight))
.mul_(grad_output);
}
} else {
grad_target = self.mul(-grad_output);
grad_target = -self * grad_output;
}
if (isDefined(weight)) {
if (!isTensorSubclassLike(*weight)) {
grad_target.mul_(*weight);
} else {
if (at::isTensorSubclassLike(*weight)) {
grad_target = grad_target.mul(*weight);
} else {
grad_target.mul_(*weight);
}
}
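
For context on the rewrite above (my note, not part of the diff): the old code computed `(1 - self.sigmoid()).log_()` and `self.sigmoid().log_()`, which underflow or lose precision for large-magnitude logits; `at::log_sigmoid` avoids that. A small illustrative sketch:

```python
# Sketch: log(sigmoid(x)) computed naively underflows for very negative logits,
# while logsigmoid evaluates min(x, 0) - log1p(exp(-|x|)) and stays finite.
import torch
import torch.nn.functional as F

x = torch.tensor([-200.0, -20.0, 0.0, 20.0])
print(torch.log(torch.sigmoid(x)))  # -inf at x = -200 (sigmoid underflows to 0)
print(F.logsigmoid(x))              # about -200 at x = -200, still finite
# log(1 - sigmoid(x)) has the same problem for large positive x; it equals
# logsigmoid(-x), which is what the new at::log_sigmoid(-self) term computes.
```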

View File

@@ -421,6 +421,13 @@ Tensor binary_cross_entropy_double_backward_target(
const Tensor& target,
const c10::optional<Tensor>& weight,
int64_t reduction);
Tensor binary_cross_entropy_with_logits_backward(
const Tensor& grad,
const Tensor& input,
const Tensor& target,
const c10::optional<Tensor>& weight_opt,
const c10::optional<Tensor>& pos_weight_opt,
int64_t reduction);
at::Tensor binary_cross_entropy_with_logits_target_backward(
const at::Tensor& grad_output,
const at::Tensor& self,

View File

@@ -13663,7 +13663,6 @@ op_db: List[OpInfo] = [
'test_variant_consistency_jit',
dtypes=(torch.float32,)
),
DecorateInfo(unittest.expectedFailure, 'TestGradients', "test_fn_gradgrad", dtypes=(torch.float64,)),
),
),
UnaryUfuncInfo(
@@ -14153,7 +14152,7 @@ op_db: List[OpInfo] = [
ref=_NOTHING,
supports_out=False,
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_multilabel_soft_margin_loss,
supports_forward_ad=True,
decorators=(
@@ -14666,7 +14665,7 @@ op_db: List[OpInfo] = [
aten_backward_name='log_sigmoid_backward',
ref=reference_logsigmoid,
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_autograd=True,
assert_autodiffed=False,
supports_forward_ad=True,