Revert "Add label_smoothing param in nn.BCELoss and nn.BCEWithLogitsLoss (#150282)"

This reverts commit f990490a23.

Reverted https://github.com/pytorch/pytorch/pull/150282 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/150282#issuecomment-3182844949))
commit 641ee74781
parent 6e8865fbc1
Author: PyTorch MergeBot
Date: 2025-08-13 09:01:52 +00:00

5 changed files with 11 additions and 62 deletions
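
For context, the reverted parameter blended hard binary targets toward 0.5 before computing the loss, using the transform that appears throughout the diff below. A minimal sketch of that transform (the helper name smooth_targets is ours, not part of PyTorch):

    import torch

    def smooth_targets(target: torch.Tensor, label_smoothing: float) -> torch.Tensor:
        # Blend hard labels toward 0.5: 1 -> 1 - eps, 0 -> eps.
        assert 0.0 <= label_smoothing <= 1.0
        return target * (1 - label_smoothing) + (1 - target) * label_smoothing

    print(smooth_targets(torch.tensor([0.0, 1.0]), 0.15))  # tensor([0.1500, 0.8500])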

File 1 of 5:

@@ -3472,7 +3472,6 @@ def binary_cross_entropy(
     size_average: Optional[bool] = None,
     reduce: Optional[bool] = None,
     reduction: str = "mean",
-    label_smoothing: float = 0.0,
 ) -> Tensor:
     r"""Compute Binary Cross Entropy between the target and input probabilities.
@@ -3491,11 +3490,9 @@ def binary_cross_entropy(
             elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
             and :attr:`reduce` are in the process of being deprecated, and in the meantime,
             specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
-        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
-            of smoothing when computing the loss, where 0.0 means no smoothing. The targets
-            become a mixture of the original ground truth and a uniform distribution as described in
-            `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.

     Examples::

         >>> input = torch.randn(3, 2, requires_grad=True)
         >>> target = torch.rand(3, 2, requires_grad=False)
         >>> loss = F.binary_cross_entropy(torch.sigmoid(input), target)
@@ -3511,7 +3508,6 @@ def binary_cross_entropy(
             size_average=size_average,
             reduce=reduce,
             reduction=reduction,
-            label_smoothing=label_smoothing,
         )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
@@ -3527,13 +3523,6 @@ def binary_cross_entropy(
         new_size = _infer_size(target.size(), weight.size())
         weight = weight.expand(new_size)

-    assert 0 <= label_smoothing <= 1, (
-        f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
-    )
-
-    if label_smoothing > 0:
-        target = target * (1 - label_smoothing) + (1 - target) * label_smoothing
-
     return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
@@ -3545,7 +3534,6 @@ def binary_cross_entropy_with_logits(
     reduce: Optional[bool] = None,
     reduction: str = "mean",
     pos_weight: Optional[Tensor] = None,
-    label_smoothing: float = 0.0,
 ) -> Tensor:
     r"""Compute Binary Cross Entropy between target and input logits.
@@ -3572,11 +3560,9 @@ def binary_cross_entropy_with_logits(
             [C, H, W] the same pos_weights across the batch. To apply the same positive weight
             along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1].
             Default: ``None``
-        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
-            of smoothing when computing the loss, where 0.0 means no smoothing. The targets
-            become a mixture of the original ground truth and a uniform distribution as described in
-            `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.

     Examples::

         >>> input = torch.randn(3, requires_grad=True)
         >>> target = torch.empty(3).random_(2)
         >>> loss = F.binary_cross_entropy_with_logits(input, target)
@@ -3593,7 +3579,6 @@ def binary_cross_entropy_with_logits(
             reduce=reduce,
             reduction=reduction,
             pos_weight=pos_weight,
-            label_smoothing=label_smoothing,
         )
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
@@ -3605,13 +3590,6 @@ def binary_cross_entropy_with_logits(
             f"Target size ({target.size()}) must be the same as input size ({input.size()})"
         )

-    assert 0 <= label_smoothing <= 1, (
-        f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
-    )
-
-    if label_smoothing > 0:
-        target = target * (1 - label_smoothing) + (1 - target) * label_smoothing
-
     return torch.binary_cross_entropy_with_logits(
         input, target, weight, pos_weight, reduction_enum
     )
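
With the parameter removed from the functional API, callers who relied on it can apply the same smoothing themselves before calling the loss. A sketch of the equivalent usage, assuming a build that includes this revert:

    import torch
    import torch.nn.functional as F

    input = torch.randn(3, 2, requires_grad=True)
    target = torch.rand(3, 2)

    eps = 0.1  # illustrative smoothing factor
    smoothed = target * (1 - eps) + (1 - target) * eps  # same transform the removed code applied

    loss = F.binary_cross_entropy(torch.sigmoid(input), smoothed)
    loss.backward()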

File 2 of 5:

@@ -134,7 +134,6 @@ def binary_cross_entropy_with_logits(
     reduce: bool | None = ...,
     reduction: str = ...,
     pos_weight: Tensor | None = ...,
-    label_smoothing: float = ...,
 ) -> Tensor: ...

 __all__ += ["binary_cross_entropy_with_logits"]
@@ -146,7 +145,6 @@ def binary_cross_entropy(
     size_average: bool | None = ...,
     reduce: bool | None = ...,
     reduction: str = ...,
-    label_smoothing: float = ...,
 ) -> Tensor: ...

 __all__ += ["binary_cross_entropy"]

File 3 of 5:

@@ -692,10 +692,6 @@ class BCELoss(_WeightedLoss):
             elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
             and :attr:`reduce` are in the process of being deprecated, and in the meantime,
             specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
-        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
-            of smoothing when computing the loss, where 0.0 means no smoothing. The targets
-            become a mixture of the original ground truth and a uniform distribution as described in
-            `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.

     Shape:
         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
@@ -721,21 +717,15 @@ class BCELoss(_WeightedLoss):
         size_average=None,
         reduce=None,
         reduction: str = "mean",
-        label_smoothing: float = 0.0,
     ) -> None:
         super().__init__(weight, size_average, reduce, reduction)
-        self.label_smoothing = label_smoothing

     def forward(self, input: Tensor, target: Tensor) -> Tensor:
         """
         Runs the forward pass.
         """
         return F.binary_cross_entropy(
-            input,
-            target,
-            weight=self.weight,
-            reduction=self.reduction,
-            label_smoothing=self.label_smoothing,
+            input, target, weight=self.weight, reduction=self.reduction
         )
@@ -825,10 +815,6 @@ class BCEWithLogitsLoss(_Loss):
             [C, H, W] the same pos_weights across the batch. To apply the same positive weight
             along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1].
             Default: ``None``
-        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
-            of smoothing when computing the loss, where 0.0 means no smoothing. The targets
-            become a mixture of the original ground truth and a uniform distribution as described in
-            `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.

     Shape:
         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
@@ -852,14 +838,12 @@ class BCEWithLogitsLoss(_Loss):
         reduce=None,
         reduction: str = "mean",
         pos_weight: Optional[Tensor] = None,
-        label_smoothing: float = 0.0,
     ) -> None:
         super().__init__(size_average, reduce, reduction)
         self.register_buffer("weight", weight)
         self.register_buffer("pos_weight", pos_weight)
         self.weight: Optional[Tensor]
         self.pos_weight: Optional[Tensor]
-        self.label_smoothing = label_smoothing

     def forward(self, input: Tensor, target: Tensor) -> Tensor:
         """Runs the forward pass."""
@@ -869,7 +853,6 @@ class BCEWithLogitsLoss(_Loss):
             self.weight,
             pos_weight=self.pos_weight,
             reduction=self.reduction,
-            label_smoothing=self.label_smoothing,
         )
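
Users who want the module-level behavior back can wrap the loss rather than patch PyTorch. A hypothetical subclass (the class name and structure are ours) that reproduces what the removed code did:

    import torch
    from torch import nn, Tensor

    class SmoothedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
        # Hypothetical wrapper restoring the reverted label_smoothing behavior.
        def __init__(self, label_smoothing: float = 0.0, **kwargs) -> None:
            super().__init__(**kwargs)
            assert 0.0 <= label_smoothing <= 1.0
            self.label_smoothing = label_smoothing

        def forward(self, input: Tensor, target: Tensor) -> Tensor:
            if self.label_smoothing > 0:
                target = target * (1 - self.label_smoothing) + (1 - target) * self.label_smoothing
            return super().forward(input, target)

    criterion = SmoothedBCEWithLogitsLoss(label_smoothing=0.15)
    loss = criterion(torch.randn(4), torch.empty(4).random_(2))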

File 4 of 5:

@@ -488,7 +488,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
     torch.bernoulli: lambda input, generator=None, out=None: -1,
     torch.bilinear: lambda input1, input2, weight, bias: -1,
     torch.binary_cross_entropy_with_logits: (
-        lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None, label_smoothing=0.0: -1  # noqa: B950
+        lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
     ),
     torch.bincount: lambda input, weights=None, minlength=0: -1,
     torch.binomial: lambda count, prob, generator=None: -1,
@@ -851,10 +851,10 @@ def get_testing_overrides() -> dict[Callable, Callable]:
     ),
     torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
     torch.nn.functional.binary_cross_entropy: (
-        lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", label_smoothing=0.0: -1
+        lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
     ),
     torch.nn.functional.binary_cross_entropy_with_logits: (
-        lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None, label_smoothing=0.0: -1  # noqa: B950
+        lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
     ),
     torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
     torch.nn.functional.cosine_embedding_loss: (
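
The dummy lambdas in get_testing_overrides mirror the public signatures so that __torch_function__ override testing can bind arguments the same way the real ops do, which is why they must drop the parameter along with the implementation. One way to inspect this (output assumes this revert is applied):

    import inspect
    import torch
    from torch.overrides import get_testing_overrides

    dummy = get_testing_overrides()[torch.nn.functional.binary_cross_entropy]
    print(inspect.signature(dummy))
    # (input, target, weight=None, size_average=None, reduce=None, reduction='mean')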

File 5 of 5:

@@ -1463,14 +1463,9 @@ def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, tr
         ('reduction_mean', {'reduction': 'mean'}),
         ('reduction_none', {'reduction': 'none'}),
         ('weights', {'weight': make_weight((10,))}),
-        ('label_smoothing', {'label_smoothing': 0.15}),
     ]

-    def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None, label_smoothing=0.0):
-        assert 0 <= label_smoothing <= 1
-        if label_smoothing > 0:
-            t = t * (1 - label_smoothing) + (1 - t) * label_smoothing
-
+    def bce_loss_reference_fn(m, p, i, t, reduction='mean', weight=None):
         result = -(t * i.log() + (1 - t) * (1 - i).log())
         if weight is not None:
@@ -1516,15 +1511,10 @@ def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, require
         ('reduction_mean', {'reduction': 'mean'}),
         ('reduction_none', {'reduction': 'none'}),
         ('weights', {'weight': make_weight((10,))}),
-        ('scalar_weights', {'weight': make_weight(())}),
-        ('label_smoothing', {'label_smoothing': 0.15}),
+        ('scalar_weights', {'weight': make_weight(())})
     ]

-    def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None, label_smoothing=0.0):
-        assert 0 <= label_smoothing <= 1
-        if label_smoothing > 0:
-            t = t * (1 - label_smoothing) + (1 - t) * label_smoothing
-
+    def bce_withlogitsloss_reference_fn(m, p, i, t, reduction='mean', weight=None):
         # TODO: add pos_weight to the definition here and corresponding SampleInputs
         max_val = (-i).clamp(min=0)
         result = (1 - t).mul_(i).add_(max_val).add_((-max_val).exp_().add_((-i - max_val).exp_()).log_())
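
The reference function above evaluates (1 - t) * i + log(1 + exp(-i)) with the usual log-sum-exp shift (max_val = max(-i, 0)) to avoid overflow for large negative logits. A standalone check of that identity against the real op:

    import torch
    import torch.nn.functional as F

    i = torch.randn(10)             # logits
    t = torch.empty(10).random_(2)  # binary targets

    max_val = (-i).clamp(min=0)
    ref = (1 - t) * i + max_val + ((-max_val).exp() + (-i - max_val).exp()).log()

    torch.testing.assert_close(ref.mean(), F.binary_cross_entropy_with_logits(i, t))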