[primTorch] Add ref for triplet_margin_loss, improve triplet_margin_with_distance_loss (#85614)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85614
Approved by: https://github.com/lezcano, https://github.com/mruberry
Nikita Karetnikov 2022-10-12 11:20:04 +02:00 committed by PyTorch MergeBot
parent ce56ee11fd
commit d56017a14f
7 changed files with 198 additions and 40 deletions


@@ -157,15 +157,17 @@ Tensor triplet_margin_loss(const Tensor& anchor, const Tensor& positive, const T
  auto n_dim = negative.dim();
  TORCH_CHECK(
      a_dim == p_dim && p_dim == n_dim,
      "All inputs should have same dimension but got ",
      a_dim,
      "D, ",
      p_dim,
      "D and ",
      n_dim,
      "D inputs.")
      "The anchor, positive, and negative tensors are expected to have "
      "the same number of dimensions, but got: anchor ", a_dim, "D, "
      "positive ", p_dim, "D, and negative ", n_dim, "D inputs")
  auto dist_pos = at::pairwise_distance(anchor, positive, p, eps);
  auto dist_neg = at::pairwise_distance(anchor, negative, p, eps);
  // The distance swap is described in the paper "Learning shallow
  // convolutional feature descriptors with triplet losses" by V. Balntas, E.
  // Riba et al. If swap is true and the positive example is closer to the
  // negative example than the anchor is, the positive example and the anchor
  // are swapped in the loss computation.
  if (swap) {
    auto dist_swap = at::pairwise_distance(positive, negative, p, eps);
    dist_neg = at::min(dist_neg, dist_swap);
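For reference, the kernel above implements the standard triplet margin loss, max(d(a, p) - d(a, n) + margin, 0), reduced over the batch. A minimal sketch mirroring the computation through the public Python API; the shapes and the default "mean" reduction are illustrative assumptions, not part of the diff:

import torch
import torch.nn.functional as F

anchor, positive, negative = (torch.randn(5, 10) for _ in range(3))
margin, p, eps = 1.0, 2.0, 1e-6

# Mirror the kernel: per-sample distances, hinge, then mean reduction.
dist_pos = F.pairwise_distance(anchor, positive, p, eps)
dist_neg = F.pairwise_distance(anchor, negative, p, eps)
expected = torch.clamp_min(margin + dist_pos - dist_neg, 0).mean()

actual = F.triplet_margin_loss(anchor, positive, negative, margin=margin, p=p, eps=eps)
torch.testing.assert_close(actual, expected)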


@@ -9202,21 +9202,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
        self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
                         loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))

    def test_triplet_margin_loss_invalid(self):
        input1 = torch.randn(5, 10, requires_grad=True)
        input2 = torch.randn(5, 10, requires_grad=True)
        input3 = torch.randn(5, 10, requires_grad=True)
        input_1d = torch.randn(10, requires_grad=True)

        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
            F.triplet_margin_loss(input1, input2, input_1d)
        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
            F.triplet_margin_loss(input1, input_1d, input3)
        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
            F.triplet_margin_loss(input_1d, input2, input3)

    def test_pointwise_loss_target_grad_none_reduction(self):
        i = torch.randn(5, 10)
        t = torch.randn(5, 10, requires_grad=True)
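This dedicated unit test is removed because the invalid-dimension cases now live in the shared OpInfo error inputs further down. A quick way to see the new message by hand; the shapes here are illustrative:

import torch
import torch.nn.functional as F

anchor = torch.randn(5, 10)
positive = torch.randn(5, 10)
negative_1d = torch.randn(10)  # deliberately 1D to trigger the check

try:
    F.triplet_margin_loss(anchor, positive, negative_1d)
except RuntimeError as e:
    # Prints the new message: "The anchor, positive, and negative tensors
    # are expected to have the same number of dimensions, but got:
    # anchor 2D, positive 2D, and negative 1D inputs"
    print(e)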


@@ -1623,6 +1623,7 @@ class TestRefsOpsInfo(TestCase):
        '_refs.broadcast_shapes',
        '_refs.broadcast_tensors',
        '_refs.nn.functional.tanhshrink',
        '_refs.nn.functional.triplet_margin_loss',
        '_refs.rfloordiv',
        '_refs.rtruediv',
        '_refs.rpow',


@@ -114,6 +114,7 @@ __all__ = [
    "bitwise_or",
    "bitwise_right_shift",
    "bitwise_xor",
    "clamp_min",
    # "complex",
    "copysign",
    "div",


@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Callable, Optional, Union

import torch
@@ -46,6 +46,7 @@ __all__ = [
    "softshrink",
    "tanhshrink",
    "threshold",
    "triplet_margin_loss",
    "glu",
    "pairwise_distance",
    "pdist",
@@ -362,7 +363,8 @@ def l1_loss(
    Reference implementation of torch.nn.functional.l1_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: raise exception instead of converting value
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
@@ -406,7 +408,8 @@ def mse_loss(
    reduction: str = "mean",
) -> TensorLikeType:
    if size_average is not None or reduce is not None:
        # TODO: raise exception instead of converting value
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
@@ -501,6 +504,84 @@ def threshold(
    return torch.where(a <= threshold, value, a)


# CompositeImplicitAutograd - don't register decomp
# No elementwise type promotion - core op doesn't explicitly type promote
def triplet_margin_loss(
    anchor: TensorLikeType,
    positive: TensorLikeType,
    negative: TensorLikeType,
    margin: float = 1.0,
    p: float = 2,
    eps: float = 1e-6,
    swap: bool = False,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)

    # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined
    # since it's a pure Python implementation. Use this helper instead.
    return _triplet_margin_with_distance_loss(
        anchor=anchor,
        positive=positive,
        negative=negative,
        distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps),
        margin=margin,
        swap=swap,
        reduction=reduction,
    )


# Pure Python impl - don't register decomp and don't add a ref. Defined as a
# helper here since triplet_margin_loss can be nicely implemented with it.
def _triplet_margin_with_distance_loss(
    anchor: TensorLikeType,
    positive: TensorLikeType,
    negative: TensorLikeType,
    *,
    distance_function: Optional[
        Callable[[TensorLikeType, TensorLikeType], TensorLikeType]
    ] = None,
    margin: float = 1.0,
    swap: bool = False,
    reduction: str = "mean",
) -> TensorLikeType:
    _check_reduction_value(reduction)

    a_dim = anchor.ndim
    p_dim = positive.ndim
    n_dim = negative.ndim
    check(
        a_dim == p_dim and p_dim == n_dim,
        lambda: (
            f"The anchor, positive, and negative tensors are expected to have "
            f"the same number of dimensions, but got: anchor {a_dim}D, "
            f"positive {p_dim}D, and negative {n_dim}D inputs"
        ),
    )

    if distance_function is None:
        distance_function = torch.pairwise_distance

    dist_pos = distance_function(anchor, positive)
    dist_neg = distance_function(anchor, negative)
    # The distance swap is described in the paper "Learning shallow
    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
    # Riba et al. If swap is True and the positive example is closer to the
    # negative example than the anchor is, the positive example and the anchor
    # are swapped in the loss computation.
    if swap:
        dist_swap = distance_function(positive, negative)
        dist_neg = torch.minimum(dist_neg, dist_swap)
    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
    return _apply_loss_reduction(loss, reduction)
@register_decomposition(torch.ops.aten.hardtanh)
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
@@ -582,7 +663,8 @@ def poisson_nll_loss(
    Reference implementation of torch.nn.functional.poisson_nll_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: raise exception instead of converting value
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
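A quick consistency check of the new ref against the eager implementation, up to floating-point tolerance; the torch._refs.nn.functional import path matches where the ref is defined, and the shapes are illustrative:

import torch
import torch._refs.nn.functional as refs_F
import torch.nn.functional as F

anchor, positive, negative = (torch.randn(5, 10) for _ in range(3))

ref_out = refs_F.triplet_margin_loss(anchor, positive, negative, swap=True)
eager_out = F.triplet_margin_loss(anchor, positive, negative, swap=True)
torch.testing.assert_close(ref_out, eager_out)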


@@ -4583,24 +4583,43 @@ def triplet_margin_with_distance_loss(
            reduction=reduction,
        )

    distance_function = distance_function if distance_function is not None else pairwise_distance

    # Check validity of reduction mode
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"{reduction} is not a valid value for reduction")

    positive_dist = distance_function(anchor, positive)
    negative_dist = distance_function(anchor, negative)
    # Check dimensions
    a_dim = anchor.ndim
    p_dim = positive.ndim
    n_dim = negative.ndim
    if not (a_dim == p_dim and p_dim == n_dim):
        raise RuntimeError(
            f"The anchor, positive, and negative tensors are expected to have "
            f"the same number of dimensions, but got: anchor {a_dim}D, "
            f"positive {p_dim}D, and negative {n_dim}D inputs"
        )

    # Calculate loss
    if distance_function is None:
        distance_function = torch.pairwise_distance

    dist_pos = distance_function(anchor, positive)
    dist_neg = distance_function(anchor, negative)
    # The distance swap is described in the paper "Learning shallow
    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
    # Riba et al. If swap is True and the positive example is closer to the
    # negative example than the anchor is, the positive example and the anchor
    # are swapped in the loss computation.
    if swap:
        swap_dist = distance_function(positive, negative)
        negative_dist = torch.min(negative_dist, swap_dist)
        dist_swap = distance_function(positive, negative)
        dist_neg = torch.minimum(dist_neg, dist_swap)
    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
    output = torch.clamp(positive_dist - negative_dist + margin, min=0.0)

    reduction_enum = _Reduction.get_enum(reduction)
    if reduction_enum == 1:
        return output.mean()
    elif reduction_enum == 2:
        return output.sum()
    else:
        return output
    # Apply reduction
    if reduction == "sum":
        return torch.sum(loss)
    elif reduction == "mean":
        return torch.mean(loss)
    else:  # reduction == "none"
        return loss
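A usage sketch for the rewritten function with a custom distance and the swap option; the cosine distance and the shapes are illustrative choices, not part of this change:

import torch
import torch.nn.functional as F

anchor, positive, negative = (torch.randn(8, 16) for _ in range(3))

# Any callable mapping two batches of vectors to per-sample distances works.
def cosine_distance(x, y):
    return 1.0 - F.cosine_similarity(x, y)

loss = F.triplet_margin_with_distance_loss(
    anchor, positive, negative,
    distance_function=cosine_distance,
    margin=0.5,
    swap=True,  # distance swap from Balntas et al., per the comment above
    reduction="mean",
)
print(loss)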
def normalize(input: Tensor, p: float = 2.0, dim: int = 1, eps: float = 1e-12, out: Optional[Tensor] = None) -> Tensor:


@@ -7386,6 +7386,58 @@ def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, wit
            kwargs["distance_function"] = torch.nn.PairwiseDistance()
        yield SampleInput(input, args=args, kwargs=kwargs)


def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=torch.float32)

    samples = (
        # input, args, kwargs, error_type, error_regex
        # invalid reduction
        (make_input(3, 4), (make_input(3, 4), make_input(3, 4)),
         dict(reduction="abc"),
         ValueError, "abc is not a valid value for reduction"),

        # shape mismatch
        (make_input(3, 5), (make_input(3, 4), make_input(3, 4)),
         dict(),
         RuntimeError,
         (r"The size of tensor a \(5\) must match the size of tensor b \(4\) "
          r"at non-singleton dimension 1")),
        (make_input(3, 4), (make_input(3, 5), make_input(3, 4)),
         dict(),
         RuntimeError,
         (r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
          r"at non-singleton dimension 1")),
        (make_input(3, 4), (make_input(3, 4), make_input(3, 5)),
         dict(),
         RuntimeError,
         (r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
          r"at non-singleton dimension 1")),

        # different dimensions
        (make_input(3,), (make_input(3, 4), make_input(3, 4)),
         dict(),
         RuntimeError,
         (r"The anchor, positive, and negative tensors are expected to have "
          r"the same number of dimensions, but got: anchor 1D, positive 2D, "
          r"and negative 2D inputs")),
        (make_input(3, 4), (make_input(3,), make_input(3, 4)),
         dict(),
         RuntimeError,
         (r"The anchor, positive, and negative tensors are expected to have "
          r"the same number of dimensions, but got: anchor 2D, positive 1D, "
          r"and negative 2D inputs")),
        (make_input(3, 4), (make_input(3, 4), make_input(3,)),
         dict(),
         RuntimeError,
         (r"The anchor, positive, and negative tensors are expected to have "
          r"the same number of dimensions, but got: anchor 2D, positive 2D, "
          r"and negative 1D inputs")),
    )

    for input, args, kwargs, error_type, error_regex in samples:
        yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
                         error_type=error_type, error_regex=error_regex)
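The first sample above exercises the new reduction check; it can be reproduced directly, with illustrative shapes:

import torch
import torch.nn.functional as F

a, p, n = (torch.randn(3, 4) for _ in range(3))
try:
    F.triplet_margin_with_distance_loss(a, p, n, reduction="abc")
except ValueError as e:
    print(e)  # abc is not a valid value for reduction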
def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -12101,6 +12153,7 @@ op_db: List[OpInfo] = [
    OpInfo(
        "nn.functional.triplet_margin_loss",
        sample_inputs_func=sample_inputs_triplet_margin_loss,
        error_inputs_func=error_inputs_triplet_margin_loss,
        dtypes=all_types_and_complex_and(torch.bfloat16),
        dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
        supports_out=False,
@@ -12110,6 +12163,7 @@ op_db: List[OpInfo] = [
    OpInfo(
        "nn.functional.triplet_margin_with_distance_loss",
        sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
        error_inputs_func=error_inputs_triplet_margin_loss,
        dtypes=all_types_and_complex_and(torch.bfloat16),
        dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
        supports_out=False,
@@ -17572,6 +17626,20 @@ python_ref_db = [
        torch_opinfo_name="clamp",
        supports_nvfuser=False,
    ),
    PythonRefInfo(
        "_refs.nn.functional.triplet_margin_loss",
        torch_opinfo_name="nn.functional.triplet_margin_loss",
        supports_out=False,
        # TODO: Uses minimum and clamp, which don't support nvfuser.
        supports_nvfuser=False,
        skips=(
            # AssertionError: Tensor-likes are not close!
            # Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed)
            # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed)
            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
                         dtypes=(torch.uint8,), device_type="cpu"),
        )
    ),
    #
    # Data Conversion & Data Movement Opinfos
    #