From 334339a3d27d17d955c63b97f32dd728d7212f8c Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Tue, 8 Feb 2022 14:36:26 +0100
Subject: [PATCH] add `OpInfo`s for
 `torch.nn.functional.triplet_margin(_with_distance)?_loss`

ghstack-source-id: bbc38b4b859dd116bf604ae5a44ed40d642005b1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67079
---
 .../_internal/common_methods_invocations.py | 50 +++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index cc089a3a5d0..0c8b6f85c1b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -10,6 +10,7 @@ import unittest
 import math
 
 import torch
+from torch import nn
 import numpy as np
 from torch._six import inf
 import collections.abc
@@ -7770,6 +7771,25 @@ def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, lo
         for shape, kwargs in shapes_and_kwargs
     ]
 
+def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, with_distance=False, **kwargs):
+    make = partial(make_tensor, (S, M), device=device, dtype=dtype, requires_grad=requires_grad)
+
+    kwargss = (
+        *[dict(margin=margin) for margin in (1e-6, 1.0, 10.0)],
+        dict(swap=True),
+        *[dict(reduction=reduction) for reduction in ("mean", "sum", "none")],
+    )
+
+    sample_inputs = []
+    for kwargs in kwargss:
+        input = make()
+        args = (make(), make())
+        if with_distance:
+            kwargs["distance_function"] = nn.PairwiseDistance()
+        sample_inputs.append(SampleInput(input, args=args, kwargs=kwargs))
+
+    return sample_inputs
+
 
 foreach_unary_op_db: List[OpInfo] = [
     ForeachFuncInfo('exp'),
@@ -15702,6 +15722,36 @@ op_db: List[OpInfo] = [
             ),
         ),
     ),
+    OpInfo(
+        "nn.functional.triplet_margin_loss",
+        sample_inputs_func=sample_inputs_triplet_margin_loss,
+        dtypes=all_types_and_complex_and(torch.bfloat16),
+        dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+    ),
+    OpInfo(
+        "nn.functional.triplet_margin_with_distance_loss",
+        sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
+        dtypes=all_types_and_complex_and(torch.bfloat16),
+        dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        skips=(
+            # This test cannot handle a callable passed to `distance_function`. With
+            # `distance_function=None`, the test would pass.
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestJit",
+                "test_variant_consistency_jit",
+            ),
+            # This test raises a plain AssertionError at
+            # https://github.com/pytorch/pytorch/blob/840fe8e4e6efee5d8197bd6987757f29d72dd162/test/test_fx_experimental.py#L1559
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestNormalizeOperators",
+                "test_normalize_operator_exhaustive",
+            ),
+        ),
+    ),
 ]
 
 # Common operator groupings
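
For reference, a minimal standalone sketch (not part of the diff) of the calls these sample inputs exercise. The tensor shapes below are illustrative stand-ins for the (S, M) convention used by make_tensor above; only public torch.nn.functional APIs are used.

import torch
from torch import nn
import torch.nn.functional as F

# Anchor/positive/negative embeddings; shapes are placeholders, not the exact S/M values.
anchor = torch.randn(5, 10, requires_grad=True)
positive = torch.randn(5, 10)
negative = torch.randn(5, 10)

# Plain variant: one of the margin/swap/reduction combinations generated by
# sample_inputs_triplet_margin_loss.
loss = F.triplet_margin_loss(anchor, positive, negative, margin=1.0, reduction="mean")
loss.backward()

# Variant with an explicit callable distance function, i.e. the with_distance=True path.
loss_dist = F.triplet_margin_with_distance_loss(
    anchor, positive, negative, distance_function=nn.PairwiseDistance()
)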