mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[primTorch] Add ref for triplet_margin_loss, improve triplet_margin_with_distance_loss (#85614)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85614 Approved by: https://github.com/lezcano, https://github.com/mruberry
This commit is contained in:
parent
ce56ee11fd
commit
d56017a14f
|
|
@ -157,15 +157,17 @@ Tensor triplet_margin_loss(const Tensor& anchor, const Tensor& positive, const T
|
|||
auto n_dim = negative.dim();
|
||||
TORCH_CHECK(
|
||||
a_dim == p_dim && p_dim == n_dim,
|
||||
"All inputs should have same dimension but got ",
|
||||
a_dim,
|
||||
"D, ",
|
||||
p_dim,
|
||||
"D and ",
|
||||
n_dim,
|
||||
"D inputs.")
|
||||
"The anchor, positive, and negative tensors are expected to have "
|
||||
"the same number of dimensions, but got: anchor ", a_dim, "D, "
|
||||
"positive ", p_dim, "D, and negative ", n_dim, "D inputs")
|
||||
|
||||
auto dist_pos = at::pairwise_distance(anchor, positive, p, eps);
|
||||
auto dist_neg = at::pairwise_distance(anchor, negative, p, eps);
|
||||
// The distance swap is described in the paper "Learning shallow
|
||||
// convolutional feature descriptors with triplet losses" by V. Balntas, E.
|
||||
// Riba et al. If True, and if the positive example is closer to the
|
||||
// negative example than the anchor is, swaps the positive example and the
|
||||
// anchor in the loss computation.
|
||||
if (swap) {
|
||||
auto dist_swap = at::pairwise_distance(positive, negative, p, eps);
|
||||
dist_neg = at::min(dist_neg, dist_swap);
|
||||
|
|
|
|||
|
|
@ -9202,21 +9202,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
|
|||
self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
|
||||
loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))
|
||||
|
||||
def test_triplet_margin_loss_invalid(self):
|
||||
input1 = torch.randn(5, 10, requires_grad=True)
|
||||
input2 = torch.randn(5, 10, requires_grad=True)
|
||||
input3 = torch.randn(5, 10, requires_grad=True)
|
||||
input_1d = torch.randn(10, requires_grad=True)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
|
||||
F.triplet_margin_loss(input1, input2, input_1d)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
|
||||
F.triplet_margin_loss(input1, input_1d, input3)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
|
||||
F.triplet_margin_loss(input_1d, input2, input3)
|
||||
|
||||
def test_pointwise_loss_target_grad_none_reduction(self):
|
||||
i = torch.randn(5, 10)
|
||||
t = torch.randn(5, 10, requires_grad=True)
|
||||
|
|
|
|||
|
|
@ -1623,6 +1623,7 @@ class TestRefsOpsInfo(TestCase):
|
|||
'_refs.broadcast_shapes',
|
||||
'_refs.broadcast_tensors',
|
||||
'_refs.nn.functional.tanhshrink',
|
||||
'_refs.nn.functional.triplet_margin_loss',
|
||||
'_refs.rfloordiv',
|
||||
'_refs.rtruediv',
|
||||
'_refs.rpow',
|
||||
|
|
|
|||
|
|
@ -114,6 +114,7 @@ __all__ = [
|
|||
"bitwise_or",
|
||||
"bitwise_right_shift",
|
||||
"bitwise_xor",
|
||||
"clamp_min",
|
||||
# "complex",
|
||||
"copysign",
|
||||
"div",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -46,6 +46,7 @@ __all__ = [
|
|||
"softshrink",
|
||||
"tanhshrink",
|
||||
"threshold",
|
||||
"triplet_margin_loss",
|
||||
"glu",
|
||||
"pairwise_distance",
|
||||
"pdist",
|
||||
|
|
@ -362,7 +363,8 @@ def l1_loss(
|
|||
Reference implementation of torch.nn.functional.l1_loss
|
||||
"""
|
||||
if size_average is not None or reduce is not None:
|
||||
# TODO: raise exception instead of converting value
|
||||
# TODO: Raise exception instead of converting value. This is only for
|
||||
# primTorch since it can drop support for deprecated arguments.
|
||||
# msg = "size_average and reduce args are deprecated, please use reduction argument."
|
||||
reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
|
||||
_check_reduction_value(reduction)
|
||||
|
|
@ -406,7 +408,8 @@ def mse_loss(
|
|||
reduction: str = "mean",
|
||||
) -> TensorLikeType:
|
||||
if size_average is not None or reduce is not None:
|
||||
# TODO: raise exception instead of converting value
|
||||
# TODO: Raise exception instead of converting value. This is only for
|
||||
# primTorch since it can drop support for deprecated arguments.
|
||||
# msg = "size_average and reduce args are deprecated, please use reduction argument."
|
||||
reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
|
||||
_check_reduction_value(reduction)
|
||||
|
|
@ -501,6 +504,84 @@ def threshold(
|
|||
return torch.where(a <= threshold, value, a)
|
||||
|
||||
|
||||
# CompositeImplicitAutograd - don't register decomp
|
||||
# No elementwise type promotion - core op doesn't explicitly type promote
|
||||
def triplet_margin_loss(
|
||||
anchor: TensorLikeType,
|
||||
positive: TensorLikeType,
|
||||
negative: TensorLikeType,
|
||||
margin: float = 1.0,
|
||||
p: float = 2,
|
||||
eps: float = 1e-6,
|
||||
swap: bool = False,
|
||||
size_average: Optional[bool] = None,
|
||||
reduce: Optional[bool] = None,
|
||||
reduction: str = "mean",
|
||||
) -> TensorLikeType:
|
||||
if size_average is not None or reduce is not None:
|
||||
# TODO: Raise exception instead of converting value. This is only for
|
||||
# primTorch since it can drop support for deprecated arguments.
|
||||
# msg = "size_average and reduce args are deprecated, please use reduction argument."
|
||||
reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
|
||||
|
||||
# torch.nn.functional.triplet_margin_with_distance_loss has no ref defined
|
||||
# since it's a pure Python implementation. Use this helper instead.
|
||||
return _triplet_margin_with_distance_loss(
|
||||
anchor=anchor,
|
||||
positive=positive,
|
||||
negative=negative,
|
||||
distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps),
|
||||
margin=margin,
|
||||
swap=swap,
|
||||
reduction=reduction,
|
||||
)
|
||||
|
||||
|
||||
# Pure Python impl - don't register decomp and don't add a ref. Defined as a
|
||||
# helper here since triplet_margin_loss can be nicely implemented with it.
|
||||
def _triplet_margin_with_distance_loss(
|
||||
anchor: TensorLikeType,
|
||||
positive: TensorLikeType,
|
||||
negative: TensorLikeType,
|
||||
*,
|
||||
distance_function: Optional[
|
||||
Callable[[TensorLikeType, TensorLikeType], TensorLikeType]
|
||||
] = None,
|
||||
margin: float = 1.0,
|
||||
swap: bool = False,
|
||||
reduction: str = "mean",
|
||||
) -> TensorLikeType:
|
||||
_check_reduction_value(reduction)
|
||||
|
||||
a_dim = anchor.ndim
|
||||
p_dim = positive.ndim
|
||||
n_dim = negative.ndim
|
||||
check(
|
||||
a_dim == p_dim and p_dim == n_dim,
|
||||
lambda: (
|
||||
f"The anchor, positive, and negative tensors are expected to have "
|
||||
f"the same number of dimensions, but got: anchor {a_dim}D, "
|
||||
f"positive {p_dim}D, and negative {n_dim}D inputs"
|
||||
),
|
||||
)
|
||||
|
||||
if distance_function is None:
|
||||
distance_function = torch.pairwise_distance
|
||||
|
||||
dist_pos = distance_function(anchor, positive)
|
||||
dist_neg = distance_function(anchor, negative)
|
||||
# The distance swap is described in the paper "Learning shallow
|
||||
# convolutional feature descriptors with triplet losses" by V. Balntas, E.
|
||||
# Riba et al. If True, and if the positive example is closer to the
|
||||
# negative example than the anchor is, swaps the positive example and the
|
||||
# anchor in the loss computation.
|
||||
if swap:
|
||||
dist_swap = distance_function(positive, negative)
|
||||
dist_neg = torch.minimum(dist_neg, dist_swap)
|
||||
loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
|
||||
return _apply_loss_reduction(loss, reduction)
|
||||
|
||||
|
||||
@register_decomposition(torch.ops.aten.hardtanh)
|
||||
@elementwise_unary_scalar_wrapper
|
||||
@elementwise_type_promotion_wrapper(
|
||||
|
|
@ -582,7 +663,8 @@ def poisson_nll_loss(
|
|||
Reference implementation of torch.nn.functional.poisson_nll_loss
|
||||
"""
|
||||
if size_average is not None or reduce is not None:
|
||||
# TODO: raise exception instead of converting value
|
||||
# TODO: Raise exception instead of converting value. This is only for
|
||||
# primTorch since it can drop support for deprecated arguments.
|
||||
# msg = "size_average and reduce args are deprecated, please use reduction argument."
|
||||
reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
|
||||
_check_reduction_value(reduction)
|
||||
|
|
|
|||
|
|
@ -4583,24 +4583,43 @@ def triplet_margin_with_distance_loss(
|
|||
reduction=reduction,
|
||||
)
|
||||
|
||||
distance_function = distance_function if distance_function is not None else pairwise_distance
|
||||
# Check validity of reduction mode
|
||||
if reduction not in ("mean", "sum", "none"):
|
||||
raise ValueError(f"{reduction} is not a valid value for reduction")
|
||||
|
||||
positive_dist = distance_function(anchor, positive)
|
||||
negative_dist = distance_function(anchor, negative)
|
||||
# Check dimensions
|
||||
a_dim = anchor.ndim
|
||||
p_dim = positive.ndim
|
||||
n_dim = negative.ndim
|
||||
if not (a_dim == p_dim and p_dim == n_dim):
|
||||
raise RuntimeError(
|
||||
(f"The anchor, positive, and negative tensors are expected to have "
|
||||
f"the same number of dimensions, but got: anchor {a_dim}D, "
|
||||
f"positive {p_dim}D, and negative {n_dim}D inputs"))
|
||||
|
||||
# Calculate loss
|
||||
if distance_function is None:
|
||||
distance_function = torch.pairwise_distance
|
||||
|
||||
dist_pos = distance_function(anchor, positive)
|
||||
dist_neg = distance_function(anchor, negative)
|
||||
# The distance swap is described in the paper "Learning shallow
|
||||
# convolutional feature descriptors with triplet losses" by V. Balntas, E.
|
||||
# Riba et al. If True, and if the positive example is closer to the
|
||||
# negative example than the anchor is, swaps the positive example and the
|
||||
# anchor in the loss computation.
|
||||
if swap:
|
||||
swap_dist = distance_function(positive, negative)
|
||||
negative_dist = torch.min(negative_dist, swap_dist)
|
||||
dist_swap = distance_function(positive, negative)
|
||||
dist_neg = torch.minimum(dist_neg, dist_swap)
|
||||
loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
|
||||
|
||||
output = torch.clamp(positive_dist - negative_dist + margin, min=0.0)
|
||||
|
||||
reduction_enum = _Reduction.get_enum(reduction)
|
||||
if reduction_enum == 1:
|
||||
return output.mean()
|
||||
elif reduction_enum == 2:
|
||||
return output.sum()
|
||||
else:
|
||||
return output
|
||||
# Apply reduction
|
||||
if reduction == "sum":
|
||||
return torch.sum(loss)
|
||||
elif reduction == "mean":
|
||||
return torch.mean(loss)
|
||||
else: # reduction == "none"
|
||||
return loss
|
||||
|
||||
|
||||
def normalize(input: Tensor, p: float = 2.0, dim: int = 1, eps: float = 1e-12, out: Optional[Tensor] = None) -> Tensor:
|
||||
|
|
|
|||
|
|
@ -7386,6 +7386,58 @@ def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, wit
|
|||
kwargs["distance_function"] = torch.nn.PairwiseDistance()
|
||||
yield SampleInput(input, args=args, kwargs=kwargs)
|
||||
|
||||
def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
|
||||
make_input = partial(make_tensor, device=device, dtype=torch.float32)
|
||||
|
||||
samples = (
|
||||
# input, args, kwargs, error_type, error_regex
|
||||
# invalid reduction
|
||||
(make_input(3, 4), (make_input(3, 4), make_input(3, 4)),
|
||||
dict(reduction="abc"),
|
||||
ValueError, "abc is not a valid value for reduction"),
|
||||
|
||||
# shape mismatch
|
||||
(make_input(3, 5), (make_input(3, 4), make_input(3, 4)),
|
||||
dict(),
|
||||
RuntimeError,
|
||||
(r"The size of tensor a \(5\) must match the size of tensor b \(4\) "
|
||||
r"at non-singleton dimension 1")),
|
||||
(make_input(3, 4), (make_input(3, 5), make_input(3, 4)),
|
||||
dict(),
|
||||
RuntimeError,
|
||||
(r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
|
||||
r"at non-singleton dimension 1")),
|
||||
(make_input(3, 4), (make_input(3, 4), make_input(3, 5)),
|
||||
dict(),
|
||||
RuntimeError,
|
||||
(r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
|
||||
r"at non-singleton dimension 1")),
|
||||
|
||||
# different dimensions
|
||||
(make_input(3,), (make_input(3, 4), make_input(3, 4)),
|
||||
dict(),
|
||||
RuntimeError,
|
||||
(r"The anchor, positive, and negative tensors are expected to have "
|
||||
r"the same number of dimensions, but got: anchor 1D, positive 2D, "
|
||||
r"and negative 2D inputs")),
|
||||
(make_input(3, 4), (make_input(3,), make_input(3, 4)),
|
||||
dict(),
|
||||
RuntimeError,
|
||||
(r"The anchor, positive, and negative tensors are expected to have "
|
||||
r"the same number of dimensions, but got: anchor 2D, positive 1D, "
|
||||
r"and negative 2D inputs")),
|
||||
(make_input(3, 4), (make_input(3, 4), make_input(3,)),
|
||||
dict(),
|
||||
RuntimeError,
|
||||
(r"The anchor, positive, and negative tensors are expected to have "
|
||||
r"the same number of dimensions, but got: anchor 2D, positive 2D, "
|
||||
r"and negative 1D inputs")),
|
||||
)
|
||||
|
||||
for input, args, kwargs, error_type, error_regex in samples:
|
||||
yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
|
||||
error_type=error_type, error_regex=error_regex)
|
||||
|
||||
|
||||
def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
|
||||
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
|
@ -12101,6 +12153,7 @@ op_db: List[OpInfo] = [
|
|||
OpInfo(
|
||||
"nn.functional.triplet_margin_loss",
|
||||
sample_inputs_func=sample_inputs_triplet_margin_loss,
|
||||
error_inputs_func=error_inputs_triplet_margin_loss,
|
||||
dtypes=all_types_and_complex_and(torch.bfloat16),
|
||||
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
|
||||
supports_out=False,
|
||||
|
|
@ -12110,6 +12163,7 @@ op_db: List[OpInfo] = [
|
|||
OpInfo(
|
||||
"nn.functional.triplet_margin_with_distance_loss",
|
||||
sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
|
||||
error_inputs_func=error_inputs_triplet_margin_loss,
|
||||
dtypes=all_types_and_complex_and(torch.bfloat16),
|
||||
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
|
||||
supports_out=False,
|
||||
|
|
@ -17572,6 +17626,20 @@ python_ref_db = [
|
|||
torch_opinfo_name="clamp",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.nn.functional.triplet_margin_loss",
|
||||
torch_opinfo_name="nn.functional.triplet_margin_loss",
|
||||
supports_out=False,
|
||||
# TODO: Uses minimum and clamp, which don't support nvfuser.
|
||||
supports_nvfuser=False,
|
||||
skips=(
|
||||
# AssertionError: Tensor-likes are not close!
|
||||
# Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed)
|
||||
# Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed)
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
|
||||
dtypes=(torch.uint8,), device_type="cpu"),
|
||||
)
|
||||
),
|
||||
#
|
||||
# Data Conversion & Data Movement Opinfos
|
||||
#
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user