diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp
index 027af18aadc..52569ba6b49 100644
--- a/aten/src/ATen/native/Loss.cpp
+++ b/aten/src/ATen/native/Loss.cpp
@@ -157,15 +157,17 @@ Tensor triplet_margin_loss(const Tensor& anchor, const Tensor& positive, const T
   auto n_dim = negative.dim();
   TORCH_CHECK(
       a_dim == p_dim && p_dim == n_dim,
-      "All inputs should have same dimension but got ",
-      a_dim,
-      "D, ",
-      p_dim,
-      "D and ",
-      n_dim,
-      "D inputs.")
+      "The anchor, positive, and negative tensors are expected to have "
+      "the same number of dimensions, but got: anchor ", a_dim, "D, "
+      "positive ", p_dim, "D, and negative ", n_dim, "D inputs")
+
   auto dist_pos = at::pairwise_distance(anchor, positive, p, eps);
   auto dist_neg = at::pairwise_distance(anchor, negative, p, eps);
+  // The distance swap is described in the paper "Learning shallow
+  // convolutional feature descriptors with triplet losses" by V. Balntas, E.
+  // Riba et al. If True, and if the positive example is closer to the
+  // negative example than the anchor is, swaps the positive example and the
+  // anchor in the loss computation.
   if (swap) {
     auto dist_swap = at::pairwise_distance(positive, negative, p, eps);
     dist_neg = at::min(dist_neg, dist_swap);
diff --git a/test/test_nn.py b/test/test_nn.py
index 084d5935145..681efab9e82 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -9202,21 +9202,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
         self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
                          loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))
 
-    def test_triplet_margin_loss_invalid(self):
-        input1 = torch.randn(5, 10, requires_grad=True)
-        input2 = torch.randn(5, 10, requires_grad=True)
-        input3 = torch.randn(5, 10, requires_grad=True)
-        input_1d = torch.randn(10, requires_grad=True)
-
-        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
-            F.triplet_margin_loss(input1, input2, input_1d)
-
-        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
-            F.triplet_margin_loss(input1, input_1d, input3)
-
-        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
-            F.triplet_margin_loss(input_1d, input2, input3)
-
     def test_pointwise_loss_target_grad_none_reduction(self):
         i = torch.randn(5, 10)
         t = torch.randn(5, 10, requires_grad=True)
diff --git a/test/test_ops.py b/test/test_ops.py
index 4ecbc59f5d7..1f6dbd15c18 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1623,6 +1623,7 @@ class TestRefsOpsInfo(TestCase):
         '_refs.broadcast_shapes',
         '_refs.broadcast_tensors',
         '_refs.nn.functional.tanhshrink',
+        '_refs.nn.functional.triplet_margin_loss',
        '_refs.rfloordiv',
         '_refs.rtruediv',
         '_refs.rpow',
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 9a0d391b8a1..b6ddad311ba 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -114,6 +114,7 @@ __all__ = [
     "bitwise_or",
     "bitwise_right_shift",
     "bitwise_xor",
+    "clamp_min",
     # "complex",
     "copysign",
     "div",
diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py
index b3f698f7178..bd146e96c49 100644
--- a/torch/_refs/nn/functional/__init__.py
+++ b/torch/_refs/nn/functional/__init__.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 
@@ -46,6 +46,7 @@ __all__ = [
     "softshrink",
     "tanhshrink",
     "threshold",
+    "triplet_margin_loss",
     "glu",
     "pairwise_distance",
     "pdist",
@@ -362,7 +363,8 @@ def l1_loss(
     Reference implementation of torch.nn.functional.l1_loss
     """
     if size_average is not None or reduce is not None:
-        # TODO: raise exception instead of converting value
+        # TODO: Raise exception instead of converting value. This is only for
+        # primTorch since it can drop support for deprecated arguments.
         # msg = "size_average and reduce args are deprecated, please use reduction argument."
         reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
     _check_reduction_value(reduction)
@@ -406,7 +408,8 @@ def mse_loss(
     reduction: str = "mean",
 ) -> TensorLikeType:
     if size_average is not None or reduce is not None:
-        # TODO: raise exception instead of converting value
+        # TODO: Raise exception instead of converting value. This is only for
+        # primTorch since it can drop support for deprecated arguments.
         # msg = "size_average and reduce args are deprecated, please use reduction argument."
         reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
     _check_reduction_value(reduction)
@@ -501,6 +504,84 @@ def threshold(
     return torch.where(a <= threshold, value, a)
 
 
+# CompositeImplicitAutograd - don't register decomp
+# No elementwise type promotion - core op doesn't explicitly type promote
+def triplet_margin_loss(
+    anchor: TensorLikeType,
+    positive: TensorLikeType,
+    negative: TensorLikeType,
+    margin: float = 1.0,
+    p: float = 2,
+    eps: float = 1e-6,
+    swap: bool = False,
+    size_average: Optional[bool] = None,
+    reduce: Optional[bool] = None,
+    reduction: str = "mean",
+) -> TensorLikeType:
+    if size_average is not None or reduce is not None:
+        # TODO: Raise exception instead of converting value. This is only for
+        # primTorch since it can drop support for deprecated arguments.
+        # msg = "size_average and reduce args are deprecated, please use reduction argument."
+        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
+
+    # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined
+    # since it's a pure Python implementation. Use this helper instead.
+    return _triplet_margin_with_distance_loss(
+        anchor=anchor,
+        positive=positive,
+        negative=negative,
+        distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps),
+        margin=margin,
+        swap=swap,
+        reduction=reduction,
+    )
+
+
+# Pure Python impl - don't register decomp and don't add a ref. Defined as a
+# helper here since triplet_margin_loss can be nicely implemented with it.
+def _triplet_margin_with_distance_loss(
+    anchor: TensorLikeType,
+    positive: TensorLikeType,
+    negative: TensorLikeType,
+    *,
+    distance_function: Optional[
+        Callable[[TensorLikeType, TensorLikeType], TensorLikeType]
+    ] = None,
+    margin: float = 1.0,
+    swap: bool = False,
+    reduction: str = "mean",
+) -> TensorLikeType:
+    _check_reduction_value(reduction)
+
+    a_dim = anchor.ndim
+    p_dim = positive.ndim
+    n_dim = negative.ndim
+    check(
+        a_dim == p_dim and p_dim == n_dim,
+        lambda: (
+            f"The anchor, positive, and negative tensors are expected to have "
+            f"the same number of dimensions, but got: anchor {a_dim}D, "
+            f"positive {p_dim}D, and negative {n_dim}D inputs"
+        ),
+    )
+
+    if distance_function is None:
+        distance_function = torch.pairwise_distance
+
+    dist_pos = distance_function(anchor, positive)
+    dist_neg = distance_function(anchor, negative)
+    # The distance swap is described in the paper "Learning shallow
+    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
+    # Riba et al. If True, and if the positive example is closer to the
+    # negative example than the anchor is, swaps the positive example and the
+    # anchor in the loss computation.
+    if swap:
+        dist_swap = distance_function(positive, negative)
+        dist_neg = torch.minimum(dist_neg, dist_swap)
+    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
+    return _apply_loss_reduction(loss, reduction)
+
+
 @register_decomposition(torch.ops.aten.hardtanh)
 @elementwise_unary_scalar_wrapper
 @elementwise_type_promotion_wrapper(
@@ -582,7 +663,8 @@ def poisson_nll_loss(
     Reference implementation of torch.nn.functional.poisson_nll_loss
     """
     if size_average is not None or reduce is not None:
-        # TODO: raise exception instead of converting value
+        # TODO: Raise exception instead of converting value. This is only for
+        # primTorch since it can drop support for deprecated arguments.
         # msg = "size_average and reduce args are deprecated, please use reduction argument."
         reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
     _check_reduction_value(reduction)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index a7de8b2b29c..f0697659c6c 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4583,24 +4583,43 @@ def triplet_margin_with_distance_loss(
             reduction=reduction,
         )
 
-    distance_function = distance_function if distance_function is not None else pairwise_distance
+    # Check validity of reduction mode
+    if reduction not in ("mean", "sum", "none"):
+        raise ValueError(f"{reduction} is not a valid value for reduction")
 
-    positive_dist = distance_function(anchor, positive)
-    negative_dist = distance_function(anchor, negative)
+    # Check dimensions
+    a_dim = anchor.ndim
+    p_dim = positive.ndim
+    n_dim = negative.ndim
+    if not (a_dim == p_dim and p_dim == n_dim):
+        raise RuntimeError(
+            (f"The anchor, positive, and negative tensors are expected to have "
+             f"the same number of dimensions, but got: anchor {a_dim}D, "
+             f"positive {p_dim}D, and negative {n_dim}D inputs"))
 
+    # Calculate loss
+    if distance_function is None:
+        distance_function = torch.pairwise_distance
+
+    dist_pos = distance_function(anchor, positive)
+    dist_neg = distance_function(anchor, negative)
+    # The distance swap is described in the paper "Learning shallow
+    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
+    # Riba et al. If True, and if the positive example is closer to the
+    # negative example than the anchor is, swaps the positive example and the
+    # anchor in the loss computation.
     if swap:
-        swap_dist = distance_function(positive, negative)
-        negative_dist = torch.min(negative_dist, swap_dist)
+        dist_swap = distance_function(positive, negative)
+        dist_neg = torch.minimum(dist_neg, dist_swap)
+    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
 
-    output = torch.clamp(positive_dist - negative_dist + margin, min=0.0)
-
-    reduction_enum = _Reduction.get_enum(reduction)
-    if reduction_enum == 1:
-        return output.mean()
-    elif reduction_enum == 2:
-        return output.sum()
-    else:
-        return output
+    # Apply reduction
+    if reduction == "sum":
+        return torch.sum(loss)
+    elif reduction == "mean":
+        return torch.mean(loss)
+    else:  # reduction == "none"
+        return loss
 
 
 def normalize(input: Tensor, p: float = 2.0, dim: int = 1, eps: float = 1e-12, out: Optional[Tensor] = None) -> Tensor:
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4cb8e8025a8..71a9dabe273 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7386,6 +7386,58 @@ def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, wit
             kwargs["distance_function"] = torch.nn.PairwiseDistance()
         yield SampleInput(input, args=args, kwargs=kwargs)
 
+def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=torch.float32)
+
+    samples = (
+        # input, args, kwargs, error_type, error_regex
+        # invalid reduction
+        (make_input(3, 4), (make_input(3, 4), make_input(3, 4)),
+         dict(reduction="abc"),
+         ValueError, "abc is not a valid value for reduction"),
+
+        # shape mismatch
+        (make_input(3, 5), (make_input(3, 4), make_input(3, 4)),
+         dict(),
+         RuntimeError,
+         (r"The size of tensor a \(5\) must match the size of tensor b \(4\) "
+          r"at non-singleton dimension 1")),
+        (make_input(3, 4), (make_input(3, 5), make_input(3, 4)),
+         dict(),
+         RuntimeError,
+         (r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
+          r"at non-singleton dimension 1")),
+        (make_input(3, 4), (make_input(3, 4), make_input(3, 5)),
+         dict(),
+         RuntimeError,
+         (r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
+          r"at non-singleton dimension 1")),
+
+        # different dimensions
+        (make_input(3,), (make_input(3, 4), make_input(3, 4)),
+         dict(),
+         RuntimeError,
+         (r"The anchor, positive, and negative tensors are expected to have "
+          r"the same number of dimensions, but got: anchor 1D, positive 2D, "
+          r"and negative 2D inputs")),
+        (make_input(3, 4), (make_input(3,), make_input(3, 4)),
+         dict(),
+         RuntimeError,
+         (r"The anchor, positive, and negative tensors are expected to have "
+          r"the same number of dimensions, but got: anchor 2D, positive 1D, "
+          r"and negative 2D inputs")),
+        (make_input(3, 4), (make_input(3, 4), make_input(3,)),
+         dict(),
+         RuntimeError,
+         (r"The anchor, positive, and negative tensors are expected to have "
+          r"the same number of dimensions, but got: anchor 2D, positive 2D, "
+          r"and negative 1D inputs")),
+    )
+
+    for input, args, kwargs, error_type, error_regex in samples:
+        yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
+                         error_type=error_type, error_regex=error_regex)
+
 
 def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -12101,6 +12153,7 @@ op_db: List[OpInfo] = [
     OpInfo(
         "nn.functional.triplet_margin_loss",
         sample_inputs_func=sample_inputs_triplet_margin_loss,
+        error_inputs_func=error_inputs_triplet_margin_loss,
         dtypes=all_types_and_complex_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
         supports_out=False,
@@ -12110,6 +12163,7 @@ op_db: List[OpInfo] = [
     OpInfo(
         "nn.functional.triplet_margin_with_distance_loss",
         sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
+        error_inputs_func=error_inputs_triplet_margin_loss,
         dtypes=all_types_and_complex_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
         supports_out=False,
@@ -17572,6 +17626,20 @@ python_ref_db = [
         torch_opinfo_name="clamp",
         supports_nvfuser=False,
     ),
+    PythonRefInfo(
+        "_refs.nn.functional.triplet_margin_loss",
+        torch_opinfo_name="nn.functional.triplet_margin_loss",
+        supports_out=False,
+        # TODO: Uses minimum and clamp, which don't support nvfuser.
+        supports_nvfuser=False,
+        skips=(
+            # AssertionError: Tensor-likes are not close!
+            # Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed)
+            # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed)
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.uint8,), device_type="cpu"),
+        )
+    ),
     #
     # Data Conversion & Data Movement Opinfos
     #
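
For reference, a minimal sketch of the user-visible behavior this patch introduces, assuming a build with the patch applied. The error strings are quoted from the messages added in the diff above; the tensor shapes are illustrative only.

    import torch
    import torch.nn.functional as F

    anchor = torch.randn(5, 10)
    positive = torch.randn(5, 10)
    negative = torch.randn(5, 10)

    # The happy path is unchanged: a per-triplet hinge loss, reduced with "mean".
    loss = F.triplet_margin_with_distance_loss(anchor, positive, negative)

    # With swap=True, the negative distance used in the hinge becomes
    # min(d(anchor, negative), d(positive, negative)), per the distance swap
    # of Balntas et al. referenced in the code comments above.
    loss_swap = F.triplet_margin_with_distance_loss(
        anchor, positive, negative, swap=True)

    # An unrecognized reduction mode is now rejected up front.
    try:
        F.triplet_margin_with_distance_loss(
            anchor, positive, negative, reduction="abc")
    except ValueError as e:
        print(e)  # abc is not a valid value for reduction

    # Inputs with mismatched numbers of dimensions are now rejected with the
    # same message the ATen kernel raises for triplet_margin_loss.
    try:
        F.triplet_margin_with_distance_loss(anchor, positive, torch.randn(10))
    except RuntimeError as e:
        print(e)  # The anchor, positive, and negative tensors are expected
                  # to have the same number of dimensions, but got: anchor
                  # 2D, positive 2D, and negative 1D inputs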