[primTorch] Add ref for triplet_margin_loss, improve triplet_margin_with_distance_loss (#85614)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85614
Approved by: https://github.com/lezcano, https://github.com/mruberry
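As context for the diff below, a minimal sketch (not part of this commit; assumes a build containing it) of comparing the new ref against the eager op. The diff itself skips a uint8 CPU case for tolerance reasons, so "matches" here means within assert_close's default tolerances:

    import torch
    import torch.nn.functional as F
    import torch._refs.nn.functional as refs_F  # module this commit extends

    a, p, n = (torch.randn(5, 10) for _ in range(3))
    eager = F.triplet_margin_loss(a, p, n, margin=1.0, swap=True)
    ref = refs_F.triplet_margin_loss(a, p, n, margin=1.0, swap=True)
    torch.testing.assert_close(eager, ref)  # expected to match within tolerance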
This commit is contained in:
parent ce56ee11fd
commit d56017a14f
@@ -157,15 +157,17 @@ Tensor triplet_margin_loss(const Tensor& anchor, const Tensor& positive, const T
   auto n_dim = negative.dim();
   TORCH_CHECK(
       a_dim == p_dim && p_dim == n_dim,
-      "All inputs should have same dimension but got ",
-      a_dim,
-      "D, ",
-      p_dim,
-      "D and ",
-      n_dim,
-      "D inputs.")
+      "The anchor, positive, and negative tensors are expected to have "
+      "the same number of dimensions, but got: anchor ", a_dim, "D, "
+      "positive ", p_dim, "D, and negative ", n_dim, "D inputs")
   auto dist_pos = at::pairwise_distance(anchor, positive, p, eps);
   auto dist_neg = at::pairwise_distance(anchor, negative, p, eps);
+  // The distance swap is described in the paper "Learning shallow
+  // convolutional feature descriptors with triplet losses" by V. Balntas, E.
+  // Riba et al. If True, and if the positive example is closer to the
+  // negative example than the anchor is, swaps the positive example and the
+  // anchor in the loss computation.
   if (swap) {
     auto dist_swap = at::pairwise_distance(positive, negative, p, eps);
     dist_neg = at::min(dist_neg, dist_swap);
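The reworded TORCH_CHECK message now names each input's rank. A hypothetical repro from Python (the message in the comment is the one assembled above):

    import torch
    import torch.nn.functional as F

    anchor = torch.randn(5, 10)
    positive = torch.randn(5, 10)
    negative_1d = torch.randn(10)  # one dimension fewer than the others

    # RuntimeError: The anchor, positive, and negative tensors are expected
    # to have the same number of dimensions, but got: anchor 2D, positive 2D,
    # and negative 1D inputs
    F.triplet_margin_loss(anchor, positive, negative_1d)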
@@ -9202,21 +9202,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
         self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
                          loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))

-    def test_triplet_margin_loss_invalid(self):
-        input1 = torch.randn(5, 10, requires_grad=True)
-        input2 = torch.randn(5, 10, requires_grad=True)
-        input3 = torch.randn(5, 10, requires_grad=True)
-        input_1d = torch.randn(10, requires_grad=True)
-
-        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
-            F.triplet_margin_loss(input1, input2, input_1d)
-
-        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
-            F.triplet_margin_loss(input1, input_1d, input3)
-
-        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
-            F.triplet_margin_loss(input_1d, input2, input3)
-
     def test_pointwise_loss_target_grad_none_reduction(self):
         i = torch.randn(5, 10)
         t = torch.randn(5, 10, requires_grad=True)
@@ -1623,6 +1623,7 @@ class TestRefsOpsInfo(TestCase):
         '_refs.broadcast_shapes',
         '_refs.broadcast_tensors',
         '_refs.nn.functional.tanhshrink',
+        '_refs.nn.functional.triplet_margin_loss',
         '_refs.rfloordiv',
         '_refs.rtruediv',
         '_refs.rpow',
@@ -114,6 +114,7 @@ __all__ = [
     "bitwise_or",
     "bitwise_right_shift",
     "bitwise_xor",
+    "clamp_min",
     # "complex",
     "copysign",
     "div",
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Callable, Optional, Union

 import torch

@@ -46,6 +46,7 @@ __all__ = [
     "softshrink",
     "tanhshrink",
     "threshold",
+    "triplet_margin_loss",
     "glu",
     "pairwise_distance",
     "pdist",
@@ -362,7 +363,8 @@ def l1_loss(
     Reference implementation of torch.nn.functional.l1_loss
     """
     if size_average is not None or reduce is not None:
-        # TODO: raise exception instead of converting value
+        # TODO: Raise exception instead of converting value. This is only for
+        # primTorch since it can drop support for deprecated arguments.
         # msg = "size_average and reduce args are deprecated, please use reduction argument."
         reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
     _check_reduction_value(reduction)
@@ -406,7 +408,8 @@ def mse_loss(
     reduction: str = "mean",
 ) -> TensorLikeType:
     if size_average is not None or reduce is not None:
-        # TODO: raise exception instead of converting value
+        # TODO: Raise exception instead of converting value. This is only for
+        # primTorch since it can drop support for deprecated arguments.
         # msg = "size_average and reduce args are deprecated, please use reduction argument."
         reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
     _check_reduction_value(reduction)
@@ -501,6 +504,84 @@ def threshold(
     return torch.where(a <= threshold, value, a)


+# CompositeImplicitAutograd - don't register decomp
+# No elementwise type promotion - core op doesn't explicitly type promote
+def triplet_margin_loss(
+    anchor: TensorLikeType,
+    positive: TensorLikeType,
+    negative: TensorLikeType,
+    margin: float = 1.0,
+    p: float = 2,
+    eps: float = 1e-6,
+    swap: bool = False,
+    size_average: Optional[bool] = None,
+    reduce: Optional[bool] = None,
+    reduction: str = "mean",
+) -> TensorLikeType:
+    if size_average is not None or reduce is not None:
+        # TODO: Raise exception instead of converting value. This is only for
+        # primTorch since it can drop support for deprecated arguments.
+        # msg = "size_average and reduce args are deprecated, please use reduction argument."
+        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
+
+    # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined
+    # since it's a pure Python implementation. Use this helper instead.
+    return _triplet_margin_with_distance_loss(
+        anchor=anchor,
+        positive=positive,
+        negative=negative,
+        distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps),
+        margin=margin,
+        swap=swap,
+        reduction=reduction,
+    )
+
+
+# Pure Python impl - don't register decomp and don't add a ref. Defined as a
+# helper here since triplet_margin_loss can be nicely implemented with it.
+def _triplet_margin_with_distance_loss(
+    anchor: TensorLikeType,
+    positive: TensorLikeType,
+    negative: TensorLikeType,
+    *,
+    distance_function: Optional[
+        Callable[[TensorLikeType, TensorLikeType], TensorLikeType]
+    ] = None,
+    margin: float = 1.0,
+    swap: bool = False,
+    reduction: str = "mean",
+) -> TensorLikeType:
+    _check_reduction_value(reduction)
+
+    a_dim = anchor.ndim
+    p_dim = positive.ndim
+    n_dim = negative.ndim
+    check(
+        a_dim == p_dim and p_dim == n_dim,
+        lambda: (
+            f"The anchor, positive, and negative tensors are expected to have "
+            f"the same number of dimensions, but got: anchor {a_dim}D, "
+            f"positive {p_dim}D, and negative {n_dim}D inputs"
+        ),
+    )
+
+    if distance_function is None:
+        distance_function = torch.pairwise_distance
+
+    dist_pos = distance_function(anchor, positive)
+    dist_neg = distance_function(anchor, negative)
+    # The distance swap is described in the paper "Learning shallow
+    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
+    # Riba et al. If True, and if the positive example is closer to the
+    # negative example than the anchor is, swaps the positive example and the
+    # anchor in the loss computation.
+    if swap:
+        dist_swap = distance_function(positive, negative)
+        dist_neg = torch.minimum(dist_neg, dist_swap)
+    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
+    return _apply_loss_reduction(loss, reduction)
+
+
 @register_decomposition(torch.ops.aten.hardtanh)
 @elementwise_unary_scalar_wrapper
 @elementwise_type_promotion_wrapper(
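Restated in math, the triplet_margin_loss ref added in this hunk computes the standard triplet margin loss, where d is the distance function (torch.pairwise_distance with norm p and eps by default):

    \ell(a, p, n) = \max\{\, \mathrm{margin} + d(a, p) - d(a, n),\ 0 \,\}

With swap=True, d(a, n) is first tightened to \min\{d(a, n),\ d(p, n)\}; the mean/sum/none reduction is then applied by _apply_loss_reduction.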
@@ -582,7 +663,8 @@ def poisson_nll_loss(
     Reference implementation of torch.nn.functional.poisson_nll_loss
     """
     if size_average is not None or reduce is not None:
-        # TODO: raise exception instead of converting value
+        # TODO: Raise exception instead of converting value. This is only for
+        # primTorch since it can drop support for deprecated arguments.
         # msg = "size_average and reduce args are deprecated, please use reduction argument."
         reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
     _check_reduction_value(reduction)
@@ -4583,24 +4583,43 @@ def triplet_margin_with_distance_loss(
             reduction=reduction,
         )

-    distance_function = distance_function if distance_function is not None else pairwise_distance
-
-    positive_dist = distance_function(anchor, positive)
-    negative_dist = distance_function(anchor, negative)
+    # Check validity of reduction mode
+    if reduction not in ("mean", "sum", "none"):
+        raise ValueError(f"{reduction} is not a valid value for reduction")
+
+    # Check dimensions
+    a_dim = anchor.ndim
+    p_dim = positive.ndim
+    n_dim = negative.ndim
+    if not (a_dim == p_dim and p_dim == n_dim):
+        raise RuntimeError(
+            (f"The anchor, positive, and negative tensors are expected to have "
+             f"the same number of dimensions, but got: anchor {a_dim}D, "
+             f"positive {p_dim}D, and negative {n_dim}D inputs"))

+    # Calculate loss
+    if distance_function is None:
+        distance_function = torch.pairwise_distance
+
+    dist_pos = distance_function(anchor, positive)
+    dist_neg = distance_function(anchor, negative)
+    # The distance swap is described in the paper "Learning shallow
+    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
+    # Riba et al. If True, and if the positive example is closer to the
+    # negative example than the anchor is, swaps the positive example and the
+    # anchor in the loss computation.
     if swap:
-        swap_dist = distance_function(positive, negative)
-        negative_dist = torch.min(negative_dist, swap_dist)
+        dist_swap = distance_function(positive, negative)
+        dist_neg = torch.minimum(dist_neg, dist_swap)
+    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)

-    output = torch.clamp(positive_dist - negative_dist + margin, min=0.0)
-
-    reduction_enum = _Reduction.get_enum(reduction)
-    if reduction_enum == 1:
-        return output.mean()
-    elif reduction_enum == 2:
-        return output.sum()
-    else:
-        return output
+    # Apply reduction
+    if reduction == "sum":
+        return torch.sum(loss)
+    elif reduction == "mean":
+        return torch.mean(loss)
+    else:  # reduction == "none"
+        return loss


 def normalize(input: Tensor, p: float = 2.0, dim: int = 1, eps: float = 1e-12, out: Optional[Tensor] = None) -> Tensor:
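A minimal usage sketch of the rewritten functional (shapes and the cosine distance are illustrative, not from the diff):

    import torch
    import torch.nn.functional as F

    anchor = torch.randn(8, 128, requires_grad=True)
    positive = torch.randn(8, 128)
    negative = torch.randn(8, 128)

    # Default distance: torch.pairwise_distance
    loss = F.triplet_margin_with_distance_loss(anchor, positive, negative)

    # Custom distance, e.g. cosine distance, with the distance swap enabled
    cosine_distance = lambda x, y: 1.0 - F.cosine_similarity(x, y)
    loss = F.triplet_margin_with_distance_loss(
        anchor, positive, negative,
        distance_function=cosine_distance, swap=True)
    loss.backward()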
@@ -7386,6 +7386,58 @@ def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, wit
             kwargs["distance_function"] = torch.nn.PairwiseDistance()
         yield SampleInput(input, args=args, kwargs=kwargs)

+
+def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=torch.float32)
+
+    samples = (
+        # input, args, kwargs, error_type, error_regex
+        # invalid reduction
+        (make_input(3, 4), (make_input(3, 4), make_input(3, 4)),
+         dict(reduction="abc"),
+         ValueError, "abc is not a valid value for reduction"),
+
+        # shape mismatch
+        (make_input(3, 5), (make_input(3, 4), make_input(3, 4)),
+         dict(),
+         RuntimeError,
+         (r"The size of tensor a \(5\) must match the size of tensor b \(4\) "
+          r"at non-singleton dimension 1")),
+        (make_input(3, 4), (make_input(3, 5), make_input(3, 4)),
+         dict(),
+         RuntimeError,
+         (r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
+          r"at non-singleton dimension 1")),
+        (make_input(3, 4), (make_input(3, 4), make_input(3, 5)),
+         dict(),
+         RuntimeError,
+         (r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
+          r"at non-singleton dimension 1")),
+
+        # different dimensions
+        (make_input(3,), (make_input(3, 4), make_input(3, 4)),
+         dict(),
+         RuntimeError,
+         (r"The anchor, positive, and negative tensors are expected to have "
+          r"the same number of dimensions, but got: anchor 1D, positive 2D, "
+          r"and negative 2D inputs")),
+        (make_input(3, 4), (make_input(3,), make_input(3, 4)),
+         dict(),
+         RuntimeError,
+         (r"The anchor, positive, and negative tensors are expected to have "
+          r"the same number of dimensions, but got: anchor 2D, positive 1D, "
+          r"and negative 2D inputs")),
+        (make_input(3, 4), (make_input(3, 4), make_input(3,)),
+         dict(),
+         RuntimeError,
+         (r"The anchor, positive, and negative tensors are expected to have "
+          r"the same number of dimensions, but got: anchor 2D, positive 2D, "
+          r"and negative 1D inputs")),
+    )
+
+    for input, args, kwargs, error_type, error_regex in samples:
+        yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
+                         error_type=error_type, error_regex=error_regex)
+
+
 def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
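Each tuple above maps to a concrete failing call; for instance, the invalid-reduction sample corresponds to this hypothetical repro:

    import torch
    import torch.nn.functional as F

    a, p, n = (torch.randn(3, 4) for _ in range(3))
    # ValueError: abc is not a valid value for reduction
    F.triplet_margin_with_distance_loss(a, p, n, reduction="abc")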
@@ -12101,6 +12153,7 @@ op_db: List[OpInfo] = [
     OpInfo(
         "nn.functional.triplet_margin_loss",
         sample_inputs_func=sample_inputs_triplet_margin_loss,
+        error_inputs_func=error_inputs_triplet_margin_loss,
         dtypes=all_types_and_complex_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
         supports_out=False,
@@ -12110,6 +12163,7 @@ op_db: List[OpInfo] = [
     OpInfo(
         "nn.functional.triplet_margin_with_distance_loss",
         sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
+        error_inputs_func=error_inputs_triplet_margin_loss,
         dtypes=all_types_and_complex_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
         supports_out=False,
@@ -17572,6 +17626,20 @@ python_ref_db = [
         torch_opinfo_name="clamp",
         supports_nvfuser=False,
     ),
+    PythonRefInfo(
+        "_refs.nn.functional.triplet_margin_loss",
+        torch_opinfo_name="nn.functional.triplet_margin_loss",
+        supports_out=False,
+        # TODO: Uses minimum and clamp, which don't support nvfuser.
+        supports_nvfuser=False,
+        skips=(
+            # AssertionError: Tensor-likes are not close!
+            # Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed)
+            # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed)
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.uint8,), device_type="cpu"),
+        )
+    ),
     #
     # Data Conversion & Data Movement Opinfos
     #