mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Fix backward of binary_cross_entropy_with_logits
The previous PR in this stack uncovered an error in the forward over backward for this function. In this PR, we fix this error and we also fix the gradgrad implementation (and make it more stable and faster using `logsigmoid`). We also move the double backward for this function to `FunctoinsManual` as there's no reason for it to be in `native_functions` Pull Request resolved: https://github.com/pytorch/pytorch/pull/79381 Approved by: https://github.com/soulitzer
This commit is contained in:
parent
6b20ef6b91
commit
28a7ee8cec
|
|
@ -351,49 +351,6 @@ Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& targe
|
|||
return apply_loss_reduction(loss, reduction);
|
||||
}
|
||||
|
||||
Tensor binary_cross_entropy_with_logits_backward(
|
||||
const Tensor& grad,
|
||||
const Tensor& input,
|
||||
const Tensor& target,
|
||||
const c10::optional<Tensor>& weight_opt,
|
||||
const c10::optional<Tensor>& pos_weight_opt,
|
||||
int64_t reduction) {
|
||||
// See [Note: hacky wrapper removal for optional tensor]
|
||||
c10::MaybeOwned<Tensor> weight_maybe_owned =
|
||||
at::borrow_from_optional_tensor(weight_opt);
|
||||
const Tensor& weight = *weight_maybe_owned;
|
||||
const Tensor& pos_weight =
|
||||
c10::value_or_else(pos_weight_opt, [] { return Tensor(); });
|
||||
|
||||
Tensor grad_input;
|
||||
auto hasSubclassTensors = at::areAnyTensorSubclassLike({grad, input, target});
|
||||
|
||||
// If there are subclassed tensors use the out of place version
|
||||
if (pos_weight.defined()) {
|
||||
// pos_weight might need to be broadcasted, thus mul(target) is not inplace.
|
||||
auto t = pos_weight.mul(target);
|
||||
grad_input = hasSubclassTensors
|
||||
? t.add(1).sub(target).mul(input.sigmoid()).sub(t).mul(grad)
|
||||
: t.add(1).sub_(target).mul_(input.sigmoid()).sub_(t).mul_(grad);
|
||||
} else {
|
||||
grad_input = hasSubclassTensors ? (input.sigmoid() - target).mul(grad)
|
||||
: (input.sigmoid() - target).mul_(grad);
|
||||
}
|
||||
if (weight.defined()) {
|
||||
if (at::areAnyTensorSubclassLike({grad_input, weight})) {
|
||||
grad_input = grad_input.mul(weight);
|
||||
} else {
|
||||
grad_input.mul_(weight);
|
||||
}
|
||||
}
|
||||
|
||||
if (reduction == at::Reduction::Mean) {
|
||||
return grad_input / input.numel();
|
||||
}
|
||||
|
||||
return grad_input;
|
||||
}
|
||||
|
||||
Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool log_input, const bool full, const double eps, const int64_t reduction)
|
||||
{
|
||||
Tensor loss;
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ void launch_glu_backward_kernel(const TensorIteratorBase& iter,
|
|||
// -----------------------------------
|
||||
|
||||
void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.common_dtype(),
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(),
|
||||
"log_sigmoid_forward_cuda", [&] {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
|
||||
|
|
@ -149,7 +149,7 @@ void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) {
|
|||
// -----------------------------------
|
||||
|
||||
void log_sigmoid_backward_kernel(TensorIterator& iter) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.common_dtype(),
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(),
|
||||
"log_sigmoid_backward_cuda", [&] {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
gpu_kernel(iter,
|
||||
|
|
|
|||
|
|
@ -966,9 +966,6 @@
|
|||
dispatch:
|
||||
CompositeExplicitAutograd: binary_cross_entropy_with_logits
|
||||
|
||||
- func: binary_cross_entropy_with_logits_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor
|
||||
variants: function
|
||||
|
||||
- func: bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor
|
||||
variants: function, method
|
||||
dispatch:
|
||||
|
|
|
|||
|
|
@ -51,9 +51,7 @@ ALLOW_LIST = [
|
|||
("aten::randperm", datetime.date(9999, 1, 1)),
|
||||
("aten::linalg_solve", datetime.date(2022, 8, 31)),
|
||||
("aten::linalg_solve.out", datetime.date(2022, 8, 31)),
|
||||
("aten::l1_loss_backward.grad_input", datetime.date(2022, 7, 1)),
|
||||
("aten::l1_loss_backward", datetime.date(2022, 7, 1)),
|
||||
("aten::l1_loss.out", datetime.date(2022, 7, 1)),
|
||||
("aten::binary_cross_entropy_with_logits_backward", datetime.date(2022, 9, 21)),
|
||||
("aten::_linalg_qr_helper", datetime.date(2022, 8, 1)),
|
||||
("aten::linalg_lu_solve", datetime.date(2022, 8, 1)),
|
||||
("aten::linalg_lu_solve.out", datetime.date(2022, 8, 1)),
|
||||
|
|
|
|||
|
|
@ -463,7 +463,6 @@ traced_operators:
|
|||
aten::binary_cross_entropy: 13
|
||||
aten::binary_cross_entropy_backward: 12
|
||||
aten::binary_cross_entropy_with_logits: 3
|
||||
aten::binary_cross_entropy_with_logits_backward: 2
|
||||
aten::bitwise_and.Tensor: 13
|
||||
aten::bitwise_and_.Tensor: 1
|
||||
aten::bitwise_not: 13
|
||||
|
|
|
|||
|
|
@ -1999,6 +1999,57 @@ Tensor binary_cross_entropy_double_backward_target(
|
|||
return res;
|
||||
}
|
||||
|
||||
Tensor binary_cross_entropy_with_logits_backward(
|
||||
const Tensor& grad,
|
||||
const Tensor& input,
|
||||
const Tensor& target,
|
||||
const c10::optional<Tensor>& weight,
|
||||
const c10::optional<Tensor>& pos_weight,
|
||||
int64_t reduction) {
|
||||
// Trivial case
|
||||
if (grad._is_zerotensor()) {
|
||||
return at::_efficientzerotensor(input.sizes(), input.options());
|
||||
}
|
||||
|
||||
// -w * [ pos * y * (1 -sigmoid(x)) - (1 - y) sigmoid(x)] * grad
|
||||
|
||||
// If there are subclassed tensors use the out of place version
|
||||
Tensor grad_input;
|
||||
if (isDefined(pos_weight)) {
|
||||
// pos_weight might need to be broadcasted, thus mul(target) is not inplace.
|
||||
auto t = pos_weight->mul(target);
|
||||
grad_input = at::areAnyTensorSubclassLike({input, target}) ||
|
||||
at::GradMode::is_enabled()
|
||||
? t.add(1).sub(target).mul(input.sigmoid()).sub(t)
|
||||
: t.add(1).sub_(target).mul_(input.sigmoid()).sub_(t);
|
||||
} else {
|
||||
grad_input = at::areAnyTensorSubclassLike({input, target}) ||
|
||||
at::GradMode::is_enabled()
|
||||
? input.sigmoid().sub(target)
|
||||
: input.sigmoid().sub_(target);
|
||||
}
|
||||
|
||||
if (at::isTensorSubclassLike(grad) || at::GradMode::is_enabled()) {
|
||||
grad_input = grad_input.mul(grad);
|
||||
} else {
|
||||
grad_input.mul_(grad);
|
||||
}
|
||||
|
||||
if (isDefined(weight)) {
|
||||
if (at::isTensorSubclassLike(*weight) || at::GradMode::is_enabled()) {
|
||||
grad_input = grad_input.mul(*weight);
|
||||
} else {
|
||||
grad_input.mul_(*weight);
|
||||
}
|
||||
}
|
||||
|
||||
if (reduction == at::Reduction::Mean) {
|
||||
grad_input.div_(input.numel());
|
||||
}
|
||||
|
||||
return grad_input;
|
||||
}
|
||||
|
||||
Tensor binary_cross_entropy_with_logits_target_backward(
|
||||
const Tensor& grad_output,
|
||||
const Tensor& self,
|
||||
|
|
@ -2006,28 +2057,30 @@ Tensor binary_cross_entropy_with_logits_target_backward(
|
|||
const c10::optional<Tensor>& weight,
|
||||
const c10::optional<Tensor>& pos_weight,
|
||||
int64_t reduction) {
|
||||
if (grad_output._is_zerotensor()) {
|
||||
return at::_efficientzerotensor(target.sizes(), target.options());
|
||||
}
|
||||
|
||||
Tensor grad_target;
|
||||
if (isDefined(pos_weight)) {
|
||||
if (!areAnyTensorSubclassLike({*pos_weight, grad_output})) {
|
||||
grad_target = (1. - self.sigmoid())
|
||||
.log_()
|
||||
.sub_(pos_weight->mul(self.sigmoid().log_()))
|
||||
.mul_(grad_output);
|
||||
} else {
|
||||
grad_target = (1. - self.sigmoid())
|
||||
.log_()
|
||||
.sub(pos_weight->mul(self.sigmoid().log_()))
|
||||
if (areAnyTensorSubclassLike({*pos_weight, grad_output})) {
|
||||
grad_target = at::log_sigmoid(-self)
|
||||
.sub(at::log_sigmoid(self).mul(*pos_weight))
|
||||
.mul(grad_output);
|
||||
} else {
|
||||
grad_target = at::log_sigmoid(-self)
|
||||
.sub_(at::log_sigmoid(self).mul_(*pos_weight))
|
||||
.mul_(grad_output);
|
||||
}
|
||||
} else {
|
||||
grad_target = self.mul(-grad_output);
|
||||
grad_target = -self * grad_output;
|
||||
}
|
||||
|
||||
if (isDefined(weight)) {
|
||||
if (!isTensorSubclassLike(*weight)) {
|
||||
grad_target.mul_(*weight);
|
||||
} else {
|
||||
if (at::isTensorSubclassLike(*weight)) {
|
||||
grad_target = grad_target.mul(*weight);
|
||||
} else {
|
||||
grad_target.mul_(*weight);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -421,6 +421,13 @@ Tensor binary_cross_entropy_double_backward_target(
|
|||
const Tensor& target,
|
||||
const c10::optional<Tensor>& weight,
|
||||
int64_t reduction);
|
||||
Tensor binary_cross_entropy_with_logits_backward(
|
||||
const Tensor& grad,
|
||||
const Tensor& input,
|
||||
const Tensor& target,
|
||||
const c10::optional<Tensor>& weight_opt,
|
||||
const c10::optional<Tensor>& pos_weight_opt,
|
||||
int64_t reduction);
|
||||
at::Tensor binary_cross_entropy_with_logits_target_backward(
|
||||
const at::Tensor& grad_output,
|
||||
const at::Tensor& self,
|
||||
|
|
|
|||
|
|
@ -13663,7 +13663,6 @@ op_db: List[OpInfo] = [
|
|||
'test_variant_consistency_jit',
|
||||
dtypes=(torch.float32,)
|
||||
),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestGradients', "test_fn_gradgrad", dtypes=(torch.float64,)),
|
||||
),
|
||||
),
|
||||
UnaryUfuncInfo(
|
||||
|
|
@ -14153,7 +14152,7 @@ op_db: List[OpInfo] = [
|
|||
ref=_NOTHING,
|
||||
supports_out=False,
|
||||
dtypes=floating_types_and(torch.bfloat16),
|
||||
dtypesIfCUDA=floating_types_and(torch.float16),
|
||||
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
|
||||
sample_inputs_func=sample_inputs_multilabel_soft_margin_loss,
|
||||
supports_forward_ad=True,
|
||||
decorators=(
|
||||
|
|
@ -14666,7 +14665,7 @@ op_db: List[OpInfo] = [
|
|||
aten_backward_name='log_sigmoid_backward',
|
||||
ref=reference_logsigmoid,
|
||||
dtypes=floating_types_and(torch.bfloat16),
|
||||
dtypesIfCUDA=floating_types_and(torch.float16),
|
||||
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
|
||||
supports_autograd=True,
|
||||
assert_autodiffed=False,
|
||||
supports_forward_ad=True,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user