Add Weighted Loss Functions to PyTorch: WMSE, WMAE, and Weighted Huber Loss (#132049)

#### Summary
This pull request introduces weighted loss functions to the PyTorch library: `weighted_huber_loss`, `wmse_loss`, and `wmae_loss`, exposed as an optional `weight` argument on the existing `F.huber_loss`, `F.mse_loss`, and `F.l1_loss` functionals. These allow precise control over the influence of each sample during training, which is important for imbalanced data or when certain samples are more significant than others.

#### Changes
- **`weighted_huber_loss`**: Huber loss modified to incorporate per-sample weights, balancing L1 and L2 behavior through the `delta` parameter.
- **`wmse_loss`** (Weighted Mean Squared Error): Applies weights to the standard MSE loss, useful for emphasizing certain samples in regression tasks.
- **`wmae_loss`** (Weighted Mean Absolute Error): Adjusts the MAE calculation by including weights, helpful for datasets with outliers; see the sketch below.
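
As the diffs below show, the weighting is exposed through the existing functional APIs rather than as new top-level entry points. A minimal sketch of the MSE and MAE variants, assuming a build that includes this PR:

```python
import torch
import torch.nn.functional as F

input = torch.tensor([0.5, 2.5, 2.0])
target = torch.tensor([0.0, 2.0, 1.5])
weights = torch.tensor([1.0, 0.5, 1.5])

# Weighted MSE: sum(weights * (input - target)**2) / sum(weights)
wmse = F.mse_loss(input, target, weight=weights)   # tensor(0.2500)

# Weighted MAE: sum(weights * |input - target|) / sum(weights)
wmae = F.l1_loss(input, target, weight=weights)    # tensor(0.5000)
```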

#### Code Details
- **Input Validation**: Ensures the `input`, `target`, and `weights` tensors match in size, preventing broadcasting errors.
- **Reduction Options**: Supports `none`, `mean`, and `sum` reductions. For the weighted MSE and MAE paths, `mean` normalizes by the sum of the weights, while the weighted Huber path takes a plain mean over elements; see the reference sketch below.
- **Backward Compatibility**: Maintains support for the deprecated `size_average` and `reduce` arguments, while encouraging use of the `reduction` argument.
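
To make the `mean` semantics concrete, here is a small pure-PyTorch reference of the weighted MSE branch added in this PR (the helper name `weighted_mse_reference` is mine, for illustration; it mirrors the `torch/nn/functional.py` hunk further down):

```python
import torch

def weighted_mse_reference(input, target, weight, reduction="mean"):
    # Input validation mirrors the diff: the tensors must match in size.
    if not (input.size() == target.size() == weight.size()):
        raise ValueError("input, target, and weights must have the same size.")
    weighted_sq_errors = (input - target) ** 2 * weight
    if reduction == "none":
        return weighted_sq_errors                       # per-sample losses
    if reduction == "sum":
        return weighted_sq_errors.sum()                 # total weighted loss
    if reduction == "mean":
        # Normalized by the total weight, not the element count.
        return weighted_sq_errors.sum() / weight.sum()
    raise ValueError(f"Invalid reduction mode: {reduction}.")

input = torch.tensor([0.5, 2.5, 2.0])
target = torch.tensor([0.0, 2.0, 1.5])
weights = torch.tensor([1.0, 0.5, 1.5])
print(weighted_mse_reference(input, target, weights))  # tensor(0.2500)
```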

#### Usage Example
```python
import torch
import torch.nn.functional as F

input = torch.tensor([0.5, 2.5, 2.0], dtype=torch.float32)
target = torch.tensor([0.0, 2.0, 1.5], dtype=torch.float32)
weights = torch.tensor([1.0, 0.5, 1.5], dtype=torch.float32)

# The weighted Huber loss is exposed through the existing functional API.
# Every |error| here is 0.5 < delta, so each element contributes 0.5 * 0.5**2 = 0.125;
# weighting gives [0.125, 0.0625, 0.1875], and 'mean' averages over elements.
loss = F.huber_loss(input, target, weight=weights, delta=1.0)
print(loss)  # tensor(0.1250)
```
---

Feedback on these implementations is welcome; please let me know if further modifications are required.

Resolves #132465

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132049
Approved by: https://github.com/mikaylagawarecki

Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
Donald Tolley, 2024-10-31 21:59:41 +00:00 (committed by PyTorch MergeBot)
commit c1e7d85ce6, parent 82e74ad40e
3 changed files with 145 additions and 23 deletions

**test/test_nn.py**

```diff
@@ -2577,6 +2577,30 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
             self.assertEqual(len(w), 1)
             self.assertIn('Please ensure they have the same size.', str(w[0]))

+    def test_weighted_mse_loss(self):
+        inputs = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
+        targets = torch.tensor([1.5, 2.5, 3.5, 4.5])
+        weight = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        loss = F.mse_loss(inputs, targets, weight=weight, reduction='mean')
+        expected_loss = torch.tensor(0.25)
+        self.assertTrue(torch.isclose(loss, expected_loss), f"Expected {expected_loss}, but got {loss}")
+
+    def test_weighted_l1_loss_with_weights(self):
+        inputs = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
+        targets = torch.tensor([1.5, 2.5, 3.5, 4.5])
+        weight = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        loss = F.l1_loss(inputs, targets, weight=weight, reduction='mean')
+        expected_loss = torch.tensor(0.5)
+        self.assertTrue(torch.isclose(loss, expected_loss), f"Expected {expected_loss}, but got {loss}")
+
+    def test_weighted_huber_loss(self):
+        inputs = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
+        targets = torch.tensor([1.5, 2.5, 3.5, 4.5])
+        weight = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        loss = F.huber_loss(input=inputs, target=targets, weight=weight, reduction='mean', delta=1.0)
+        # Each |error| is 0.5 < delta, so the per-element loss is 0.5 * 0.5**2 = 0.125;
+        # the weighted 'mean' averages over elements: (0.125 + 0.25 + 0.375 + 0.5) / 4 = 0.3125.
+        expected_loss = torch.tensor(0.3125)
+        self.assertTrue(torch.isclose(loss, expected_loss, atol=1e-6), f"Expected {expected_loss}, but got {loss}")
+
     def test_gaussian_nll_loss_broadcasting(self):
         input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
         target_full = torch.tensor([[1., 2., 3.], [1., 2., 3.]])
```
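
As a sanity check, the expected constants in these tests follow from stock tensor ops (no patched build needed). The Huber case comes out to 0.3125 rather than 0.25 because its `mean` reduction averages over elements instead of normalizing by the weight sum:

```python
import torch
import torch.nn.functional as F

i = torch.tensor([1.0, 2.0, 3.0, 4.0])
t = torch.tensor([1.5, 2.5, 3.5, 4.5])
w = torch.tensor([1.0, 2.0, 3.0, 4.0])

print(((i - t) ** 2 * w).sum() / w.sum())   # tensor(0.2500), weighted MSE 'mean'
print(((i - t).abs() * w).sum() / w.sum())  # tensor(0.5000), weighted MAE 'mean'
# Weighted Huber 'mean': scale the stock per-element Huber loss, then average.
print((F.huber_loss(i, t, reduction="none", delta=1.0) * w).mean())  # tensor(0.3125)
```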

**torch/nn/functional.py**

```diff
@@ -3693,8 +3693,11 @@ def huber_loss(
     target: Tensor,
     reduction: str = "mean",
     delta: float = 1.0,
+    weight: Optional[Tensor] = None,
 ) -> Tensor:
-    r"""Compute the Huber loss.
+    r"""huber_loss(input, target, reduction='mean', delta=1.0, weight=None) -> Tensor
+
+    Computes the Huber loss, with optional weighting.

     Function uses a squared term if the absolute
     element-wise error falls below delta and a delta-scaled L1 term otherwise.
@@ -3702,17 +3705,30 @@ def huber_loss(
     When delta equals 1, this loss is equivalent to SmoothL1Loss.
     In general, Huber loss differs from SmoothL1Loss by a factor of delta (AKA beta in Smooth L1).

     See :class:`~torch.nn.HuberLoss` for details.
+
+    Args:
+        input (Tensor): Predicted values.
+        target (Tensor): Ground truth values.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
+            'sum': the output will be summed. 'none': no reduction will be applied.
+            Default: 'mean'.
+        delta (float, optional): The threshold at which to change between
+            delta-scaled L1 and L2 loss. Default: 1.0.
+        weight (Tensor, optional): Weights for each sample. Default: None.
+
+    Returns:
+        Tensor: Huber loss (optionally weighted).
     """
-    if has_torch_function_variadic(input, target):
+    if has_torch_function_variadic(input, target, weight):
         return handle_torch_function(
             huber_loss,
-            (input, target),
+            (input, target, weight),
             input,
             target,
             reduction=reduction,
             delta=delta,
+            weight=weight,
         )
     if not (target.size() == input.size()):
         warnings.warn(
             f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
@@ -3722,9 +3738,34 @@ def huber_loss(
         )
     expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-    return torch._C._nn.huber_loss(
-        expanded_input, expanded_target, _Reduction.get_enum(reduction), delta
-    )
+
+    if weight is None:
+        # Use the optimized C++ backend for standard Huber loss
+        return torch._C._nn.huber_loss(
+            expanded_input, expanded_target, _Reduction.get_enum(reduction), delta
+        )
+    else:
+        if weight.size() != input.size():
+            raise ValueError("Weights and input must have the same size.")
+        # Calculate the unweighted loss first
+        unweighted_loss = torch._C._nn.huber_loss(
+            expanded_input, expanded_target, _Reduction.get_enum("none"), delta
+        )
+        # Apply weight to the unweighted loss
+        weighted_loss = unweighted_loss * weight
+        if reduction == "none":
+            return weighted_loss
+        elif reduction == "sum":
+            return torch.sum(weighted_loss)
+        elif reduction == "mean":
+            return weighted_loss.mean()
+        else:
+            raise ValueError(
+                f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'."
+            )


 def l1_loss(
```
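
One design wrinkle in the hunk above: the weighted Huber `mean` divides by the element count, while the weighted MSE and L1 branches below divide by the weight sum. Callers who want weight-normalized Huber behavior can build it from `reduction='none'`; a sketch, assuming a build with this PR's `weight` argument:

```python
import torch
import torch.nn.functional as F

input = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([1.5, 2.5, 3.5, 4.5])
weights = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Per-element weighted Huber losses, left unreduced.
per_elem = F.huber_loss(input, target, weight=weights, reduction="none", delta=1.0)

print(per_elem.mean())                 # tensor(0.3125), what reduction='mean' returns
print(per_elem.sum() / weights.sum())  # tensor(0.1250), weight-normalized alternative
```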
```diff
@@ -3733,6 +3774,7 @@ def l1_loss(
     size_average: Optional[bool] = None,
     reduce: Optional[bool] = None,
     reduction: str = "mean",
+    weight: Optional[Tensor] = None,
 ) -> Tensor:  # noqa: D400,D402
     r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@@ -3743,7 +3785,7 @@ def l1_loss(
     if has_torch_function_variadic(input, target):
         return handle_torch_function(
             l1_loss,
-            (input, target),
+            (input, target, weight),
             input,
             target,
             size_average=size_average,
@@ -3761,9 +3803,28 @@ def l1_loss(
         reduction = _Reduction.legacy_get_string(size_average, reduce)
     expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-    return torch._C._nn.l1_loss(
-        expanded_input, expanded_target, _Reduction.get_enum(reduction)
-    )
+
+    if weight is not None:
+        if weight.size() != input.size():
+            raise ValueError("Weights and input must have the same size.")
+        absolute_errors = torch.abs(expanded_input - expanded_target)
+        weighted_absolute_errors = absolute_errors * weight
+        if reduction == "none":
+            return weighted_absolute_errors
+        elif reduction == "sum":
+            return torch.sum(weighted_absolute_errors)
+        elif reduction == "mean":
+            return torch.sum(weighted_absolute_errors) / torch.sum(weight)
+        else:
+            raise ValueError(
+                f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'."
+            )
+    else:
+        return torch._C._nn.l1_loss(
+            expanded_input, expanded_target, _Reduction.get_enum(reduction)
+        )


 def mse_loss(
```
```diff
@@ -3772,22 +3833,38 @@ def mse_loss(
     size_average: Optional[bool] = None,
     reduce: Optional[bool] = None,
     reduction: str = "mean",
-) -> Tensor:  # noqa: D400,D402
-    r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
+    weight: Optional[Tensor] = None,
+) -> Tensor:
+    r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean', weight=None) -> Tensor

-    Measures the element-wise mean squared error.
-    See :class:`~torch.nn.MSELoss` for details.
+    Measures the element-wise mean squared error, with optional weighting.
+
+    Args:
+        input (Tensor): Predicted values.
+        target (Tensor): Ground truth values.
+        size_average (bool, optional): Deprecated (use reduction).
+        reduce (bool, optional): Deprecated (use reduction).
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
+            'sum': the output will be summed. 'none': no reduction will be applied.
+            Default: 'mean'.
+        weight (Tensor, optional): Weights for each sample. Default: None.
+
+    Returns:
+        Tensor: Mean Squared Error loss (optionally weighted).
     """
-    if has_torch_function_variadic(input, target):
+    if has_torch_function_variadic(input, target, weight):
         return handle_torch_function(
             mse_loss,
-            (input, target),
+            (input, target, weight),
             input,
             target,
             size_average=size_average,
             reduce=reduce,
             reduction=reduction,
+            weight=weight,
         )
     if not (target.size() == input.size()):
         warnings.warn(
             f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
@@ -3795,13 +3872,34 @@ def mse_loss(
         "Please ensure they have the same size.",
         stacklevel=2,
     )
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
     expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-    return torch._C._nn.mse_loss(
-        expanded_input, expanded_target, _Reduction.get_enum(reduction)
-    )
+
+    if weight is not None:
+        if weight.size() != input.size():
+            raise ValueError("Weights and input must have the same size.")
+        # Perform weighted MSE loss manually
+        squared_errors = torch.pow(expanded_input - expanded_target, 2)
+        weighted_squared_errors = squared_errors * weight
+        if reduction == "none":
+            return weighted_squared_errors
+        elif reduction == "sum":
+            return torch.sum(weighted_squared_errors)
+        elif reduction == "mean":
+            return torch.sum(weighted_squared_errors) / torch.sum(weight)
+        else:
+            raise ValueError(
+                f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'."
+            )
+    else:
+        return torch._C._nn.mse_loss(
+            expanded_input, expanded_target, _Reduction.get_enum(reduction)
+        )


 def margin_ranking_loss(
```

**torch/overrides.py**

```diff
@@ -901,7 +901,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
             lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1  # noqa: B950
         ),
         torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,  # noqa: B950
-        torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
+        torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", weight=None: -1,
         torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
         torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
         torch.nn.functional.linear: lambda input, weight, bias=None: -1,
@@ -935,7 +935,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
         torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
         torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
-        torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
+        torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", weight=None: -1,
         torch.nn.functional.multi_head_attention_forward: (
             lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1  # noqa: B950
         ),
@@ -968,7 +968,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.nn.functional.mish: lambda input, inplace=False: -1,
         torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
         torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1,  # noqa: B950
-        torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0: -1,
+        torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0, weight=None: -1,
         torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,  # noqa: B950
         torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
         torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
```
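
These entries are the dummy lambdas returned by `torch.overrides.get_testing_overrides`; the override tests check that each lambda's signature mirrors the real functional, which is why `weight=None` has to be added here as well. A quick way to inspect the correspondence (a sketch; the `weight` parameter only appears on builds containing this PR):

```python
import inspect
import torch
from torch.overrides import get_testing_overrides

# Fetch the dummy override registered for mse_loss and compare signatures.
dummy = get_testing_overrides()[torch.nn.functional.mse_loss]
print(inspect.signature(dummy))
print(inspect.signature(torch.nn.functional.mse_loss))
```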