# Add Weighted Loss Functions to PyTorch: WMSE, WMAE, and Weighted Huber Loss (#132049)
#### Summary

This pull request introduces new weighted loss functions to the PyTorch library: `weighted_huber_loss`, `wmse_loss`, and `wmae_loss`, implemented as an optional `weight` argument on the existing `huber_loss`, `mse_loss`, and `l1_loss` functionals. These functions allow precise control over the influence of each sample during training, which is important for imbalanced data or when certain samples are more significant than others.

#### Changes

- **`weighted_huber_loss`**: Huber loss modified to incorporate weights, providing a balance between L1 and L2 loss based on the `delta` parameter.
- **`wmse_loss`** (Weighted Mean Squared Error): Applies weights to the standard MSE loss, useful for emphasizing certain samples in regression tasks.
- **`wmae_loss`** (Weighted Mean Absolute Error): Adjusts the MAE loss calculation by including weights, ideal for datasets with outliers.

#### Code Details

- **Input Validation**: Ensures `input`, `target`, and `weights` tensors match in size to prevent broadcasting errors.
- **Reduction Options**: Supports `none`, `mean`, and `sum` reductions to suit various computational needs.
- **Backward Compatibility**: Maintains support for the deprecated arguments `size_average` and `reduce`, while encouraging use of the `reduction` argument.

#### Usage Example

```python
import torch
import torch.nn.functional as F

input = torch.tensor([0.5, 2.5, 2.0], dtype=torch.float32)
target = torch.tensor([0.0, 2.0, 1.5], dtype=torch.float32)
weights = torch.tensor([1.0, 0.5, 1.5], dtype=torch.float32)

# Weighted Huber loss via the new `weight` argument on the existing functional
loss = F.huber_loss(input, target, weight=weights, delta=1.0)
print(loss)
```

---

Feedback on these implementations is welcome; please let me know if further modifications are required.

Resolves #132465

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132049
Approved by: https://github.com/mikaylagawarecki
Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
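Supplementing the example above, the weighted MSE and MAE paths work the same way. A minimal sketch, assuming the `weight` keyword on `F.mse_loss` and `F.l1_loss` exactly as added in the diff below; note that under `reduction='mean'` both divide by the sum of the weights rather than the element count:

```python
import torch
import torch.nn.functional as F

input = torch.tensor([0.5, 2.5, 2.0])
target = torch.tensor([0.0, 2.0, 1.5])
weights = torch.tensor([1.0, 0.5, 1.5])

# Weighted MSE: sum(w * (input - target)**2) / sum(w)
wmse = F.mse_loss(input, target, weight=weights, reduction="mean")

# Weighted MAE: sum(w * |input - target|) / sum(w)
wmae = F.l1_loss(input, target, weight=weights, reduction="mean")

print(wmse, wmae)  # tensor(0.2500) tensor(0.5000); every absolute error here is 0.5
```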
This commit is contained in:
parent 82e74ad40e
commit c1e7d85ce6
**`test/test_nn.py`**

```diff
@@ -2577,6 +2577,30 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
             self.assertEqual(len(w), 1)
             self.assertIn('Please ensure they have the same size.', str(w[0]))
 
+    def test_weighted_mse_loss(self):
+        inputs = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
+        targets = torch.tensor([1.5, 2.5, 3.5, 4.5])
+        weight = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        loss = F.mse_loss(inputs, targets, weight=weight, reduction='mean')
+        # Every squared error is 0.25, so sum(w * 0.25) / sum(w) = 2.5 / 10 = 0.25
+        expected_loss = torch.tensor(0.25)
+        self.assertTrue(torch.isclose(loss, expected_loss), f"Expected {expected_loss}, but got {loss}")
+
+    def test_weighted_l1_loss_with_weights(self):
+        inputs = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
+        targets = torch.tensor([1.5, 2.5, 3.5, 4.5])
+        weight = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        loss = F.l1_loss(inputs, targets, weight=weight, reduction='mean')
+        # Every absolute error is 0.5, so sum(w * 0.5) / sum(w) = 5.0 / 10 = 0.5
+        expected_loss = torch.tensor(0.5)
+        self.assertTrue(torch.isclose(loss, expected_loss), f"Expected {expected_loss}, but got {loss}")
+
+    def test_weighted_huber_loss(self):
+        inputs = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
+        targets = torch.tensor([1.5, 2.5, 3.5, 4.5])
+        weight = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        loss = F.huber_loss(input=inputs, target=targets, weight=weight, reduction='mean', delta=1.0)
+        # Every per-element Huber loss is 0.5 * 0.5**2 = 0.125; Huber's 'mean' averages
+        # the weighted values over the element count: sum(w * 0.125) / 4 = 1.25 / 4 = 0.3125
+        expected_loss = torch.tensor(0.3125)
+        self.assertTrue(torch.isclose(loss, expected_loss, atol=1e-6), f"Expected {expected_loss}, but got {loss}")
+
     def test_gaussian_nll_loss_broadcasting(self):
         input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]])
         target_full = torch.tensor([[1., 2., 3.], [1., 2., 3.]])
```
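The expected values in these tests follow directly from the reduction semantics and can be checked by hand with stock tensor ops (nothing from this PR is needed):

```python
import torch

inputs = torch.tensor([1.0, 2.0, 3.0, 4.0])
targets = torch.tensor([1.5, 2.5, 3.5, 4.5])
weight = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Weighted MSE 'mean' divides by the weight sum (a true weighted average).
wmse = (weight * (inputs - targets) ** 2).sum() / weight.sum()   # 0.25

# Weighted L1 'mean' likewise divides by the weight sum.
wmae = (weight * (inputs - targets).abs()).sum() / weight.sum()  # 0.50

# Weighted Huber 'mean' instead averages over the element count.
err = (inputs - targets).abs()
huber = torch.where(err < 1.0, 0.5 * err ** 2, err - 0.5)  # delta = 1.0
whuber = (weight * huber).mean()                           # 0.3125

print(wmse, wmae, whuber)
```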
**`torch/nn/functional.py`**

```diff
@@ -3693,8 +3693,11 @@ def huber_loss(
     target: Tensor,
     reduction: str = "mean",
     delta: float = 1.0,
+    weight: Optional[Tensor] = None,
 ) -> Tensor:
-    r"""Compute the Huber loss.
+    r"""huber_loss(input, target, reduction='mean', delta=1.0, weight=None) -> Tensor
+
+    Computes the Huber loss, with optional weighting.
 
     Function uses a squared term if the absolute
     element-wise error falls below delta and a delta-scaled L1 term otherwise.
```
```diff
@@ -3702,17 +3705,30 @@ def huber_loss(
     When delta equals 1, this loss is equivalent to SmoothL1Loss.
     In general, Huber loss differs from SmoothL1Loss by a factor of delta (AKA beta in Smooth L1).
 
-    See :class:`~torch.nn.HuberLoss` for details.
+    Args:
+        input (Tensor): Predicted values.
+        target (Tensor): Ground truth values.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
+            'sum': the output will be summed. 'none': no reduction will be applied.
+            Default: 'mean'.
+        delta (float, optional): The threshold at which to change between delta-scaled L1 and L2 loss. Default: 1.0.
+        weight (Tensor, optional): Weights for each sample. Default: None.
+
+    Returns:
+        Tensor: Huber loss (optionally weighted).
     """
-    if has_torch_function_variadic(input, target):
+    if has_torch_function_variadic(input, target, weight):
         return handle_torch_function(
             huber_loss,
-            (input, target),
+            (input, target, weight),
             input,
             target,
             reduction=reduction,
             delta=delta,
+            weight=weight,
         )
 
     if not (target.size() == input.size()):
         warnings.warn(
             f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
```
```diff
@@ -3722,9 +3738,34 @@ def huber_loss(
         )
 
     expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-    return torch._C._nn.huber_loss(
-        expanded_input, expanded_target, _Reduction.get_enum(reduction), delta
-    )
+
+    if weight is None:
+        # Use the optimized C++ backend for standard Huber loss
+        return torch._C._nn.huber_loss(
+            expanded_input, expanded_target, _Reduction.get_enum(reduction), delta
+        )
+    else:
+        if weight.size() != input.size():
+            raise ValueError("Weights and input must have the same size.")
+
+        # Calculate the unweighted loss first
+        unweighted_loss = torch._C._nn.huber_loss(
+            expanded_input, expanded_target, _Reduction.get_enum("none"), delta
+        )
+
+        # Apply weight to the unweighted loss
+        weighted_loss = unweighted_loss * weight
+
+        if reduction == "none":
+            return weighted_loss
+        elif reduction == "sum":
+            return torch.sum(weighted_loss)
+        elif reduction == "mean":
+            return weighted_loss.mean()
+        else:
+            raise ValueError(
+                f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'."
+            )
 
 
 def l1_loss(
```
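A quick consistency check for the branch above (a sketch, assuming the new `weight` argument is available): an all-ones weight should reproduce the unweighted C++ backend result for every reduction mode, since multiplying the elementwise losses by ones is a no-op and the `'mean'` branch here averages over the element count.

```python
import torch
import torch.nn.functional as F

x = torch.randn(8)
y = torch.randn(8)
ones = torch.ones_like(x)

for reduction in ("none", "sum", "mean"):
    baseline = F.huber_loss(x, y, reduction=reduction, delta=1.0)
    weighted = F.huber_loss(x, y, weight=ones, reduction=reduction, delta=1.0)
    # The weighted Python path must agree with the C++ backend here.
    assert torch.allclose(baseline, weighted), reduction
```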
```diff
@@ -3733,6 +3774,7 @@ def l1_loss(
     size_average: Optional[bool] = None,
     reduce: Optional[bool] = None,
     reduction: str = "mean",
+    weight: Optional[Tensor] = None,
 ) -> Tensor:  # noqa: D400,D402
     r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
 
@@ -3743,7 +3785,7 @@ def l1_loss(
     if has_torch_function_variadic(input, target):
         return handle_torch_function(
             l1_loss,
-            (input, target),
+            (input, target, weight),
             input,
             target,
             size_average=size_average,
```
```diff
@@ -3761,9 +3803,28 @@ def l1_loss(
         reduction = _Reduction.legacy_get_string(size_average, reduce)
 
     expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-    return torch._C._nn.l1_loss(
-        expanded_input, expanded_target, _Reduction.get_enum(reduction)
-    )
+
+    if weight is not None:
+        if weight.size() != input.size():
+            raise ValueError("Weights and input must have the same size.")
+
+        absolute_errors = torch.abs(expanded_input - expanded_target)
+        weighted_absolute_errors = absolute_errors * weight
+
+        if reduction == "none":
+            return weighted_absolute_errors
+        elif reduction == "sum":
+            return torch.sum(weighted_absolute_errors)
+        elif reduction == "mean":
+            return torch.sum(weighted_absolute_errors) / torch.sum(weight)
+        else:
+            raise ValueError(
+                f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'."
+            )
+    else:
+        return torch._C._nn.l1_loss(
+            expanded_input, expanded_target, _Reduction.get_enum(reduction)
+        )
 
 
 def mse_loss(
```
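Note that the `'mean'` branch above divides by `torch.sum(weight)` rather than the element count, so the result is a weighted average and is invariant to rescaling the weights. A small illustration (again assuming the `weight` keyword from this diff):

```python
import torch
import torch.nn.functional as F

x = torch.randn(6)
y = torch.randn(6)
w = torch.rand(6) + 0.1  # strictly positive weights

a = F.l1_loss(x, y, weight=w, reduction="mean")
b = F.l1_loss(x, y, weight=10.0 * w, reduction="mean")  # rescaled weights
# sum(10w * e) / sum(10w) == sum(w * e) / sum(w)
assert torch.allclose(a, b)
```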
```diff
@@ -3772,22 +3833,38 @@ def mse_loss(
     size_average: Optional[bool] = None,
     reduce: Optional[bool] = None,
     reduction: str = "mean",
-) -> Tensor:  # noqa: D400,D402
-    r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
+    weight: Optional[Tensor] = None,
+) -> Tensor:
+    r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean', weight=None) -> Tensor
 
-    Measures the element-wise mean squared error.
-    See :class:`~torch.nn.MSELoss` for details.
+    Measures the element-wise mean squared error, with optional weighting.
+
+    Args:
+        input (Tensor): Predicted values.
+        target (Tensor): Ground truth values.
+        size_average (bool, optional): Deprecated (use reduction).
+        reduce (bool, optional): Deprecated (use reduction).
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            'none' | 'mean' | 'sum'. 'mean': the mean of the output is taken.
+            'sum': the output will be summed. 'none': no reduction will be applied.
+            Default: 'mean'.
+        weight (Tensor, optional): Weights for each sample. Default: None.
+
+    Returns:
+        Tensor: Mean Squared Error loss (optionally weighted).
     """
-    if has_torch_function_variadic(input, target):
+    if has_torch_function_variadic(input, target, weight):
         return handle_torch_function(
             mse_loss,
-            (input, target),
+            (input, target, weight),
             input,
             target,
             size_average=size_average,
             reduce=reduce,
             reduction=reduction,
+            weight=weight,
         )
 
     if not (target.size() == input.size()):
         warnings.warn(
             f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). "
```
```diff
@@ -3795,13 +3872,34 @@ def mse_loss(
             "Please ensure they have the same size.",
             stacklevel=2,
         )
 
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
 
     expanded_input, expanded_target = torch.broadcast_tensors(input, target)
-    return torch._C._nn.mse_loss(
-        expanded_input, expanded_target, _Reduction.get_enum(reduction)
-    )
+
+    if weight is not None:
+        if weight.size() != input.size():
+            raise ValueError("Weights and input must have the same size.")
+
+        # Perform weighted MSE loss manually
+        squared_errors = torch.pow(expanded_input - expanded_target, 2)
+        weighted_squared_errors = squared_errors * weight
+
+        if reduction == "none":
+            return weighted_squared_errors
+        elif reduction == "sum":
+            return torch.sum(weighted_squared_errors)
+        elif reduction == "mean":
+            return torch.sum(weighted_squared_errors) / torch.sum(weight)
+        else:
+            raise ValueError(
+                f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', 'sum'."
+            )
+    else:
+        return torch._C._nn.mse_loss(
+            expanded_input, expanded_target, _Reduction.get_enum(reduction)
+        )
 
 
 def margin_ranking_loss(
```
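The `legacy_get_string` branch above maps the deprecated `size_average`/`reduce` flags onto a reduction string. Roughly, per the documented legacy semantics (a sketch, not the internal implementation):

```python
def legacy_reduction(size_average: bool = True, reduce: bool = True) -> str:
    # reduce=False disables reduction entirely; otherwise size_average
    # chooses between averaging ('mean') and summing ('sum').
    if not reduce:
        return "none"
    return "mean" if size_average else "sum"

assert legacy_reduction() == "mean"
assert legacy_reduction(size_average=False) == "sum"
assert legacy_reduction(reduce=False) == "none"
```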
**`torch/overrides.py`**

```diff
@@ -901,7 +901,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1  # noqa: B950
     ),
     torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,  # noqa: B950
-    torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
+    torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", weight=None: -1,
     torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
     torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
     torch.nn.functional.linear: lambda input, weight, bias=None: -1,
@@ -935,7 +935,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
     torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
     torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
-    torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
+    torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", weight=None: -1,
     torch.nn.functional.multi_head_attention_forward: (
         lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1  # noqa: B950
     ),
@@ -968,7 +968,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.nn.functional.mish: lambda input, inplace=False: -1,
     torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
     torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1,  # noqa: B950
-    torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0: -1,
+    torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0, weight=None: -1,
     torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,  # noqa: B950
     torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
     torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
```
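These override entries are what let `__torch_function__` subclasses see the new `weight` keyword; for `mse_loss` and `huber_loss` the dispatch check now includes it via `has_torch_function_variadic(input, target, weight)`. A minimal sketch of how the keyword surfaces during dispatch (the `LoggingTensor` class is hypothetical, for illustration only):

```python
import torch
import torch.nn.functional as F

class LoggingTensor(torch.Tensor):
    # Hypothetical subclass that inspects intercepted functional calls.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.mse_loss:
            print("mse_loss intercepted, weight =", kwargs.get("weight"))
        return super().__torch_function__(func, types, args, kwargs)

x = torch.randn(4).as_subclass(LoggingTensor)
y = torch.randn(4)
loss = F.mse_loss(x, y, weight=torch.ones(4))  # routed through __torch_function__
```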