add OpInfos for torch.nn.functional.triplet_margin(_with_distance)?_loss

ghstack-source-id: bbc38b4b85
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67079
This commit is contained in:
Philip Meier 2022-02-08 14:36:26 +01:00
parent 5ada829c4b
commit 334339a3d2

View File

@ -10,6 +10,7 @@ import unittest
import math
import torch
from torch import nn
import numpy as np
from torch._six import inf
import collections.abc
@ -7770,6 +7771,25 @@ def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, lo
for shape, kwargs in shapes_and_kwargs
]
def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, with_distance=False, **kwargs):
make = partial(make_tensor, (S, M), device=device, dtype=dtype, requires_grad=requires_grad)
kwargss = (
*[dict(margin=margin) for margin in (1e-6, 1.0, 10.0)],
dict(swap=True),
*[dict(reduction=reduction) for reduction in ("mean", "sum", "none")],
)
sample_inputs = []
for kwargs in kwargss:
input = make()
args = (make(), make())
if with_distance:
kwargs["distance_function"] = nn.PairwiseDistance()
sample_inputs.append(SampleInput(input, args=args, kwargs=kwargs))
return sample_inputs
foreach_unary_op_db: List[OpInfo] = [
ForeachFuncInfo('exp'),
@ -15702,6 +15722,36 @@ op_db: List[OpInfo] = [
),
),
),
OpInfo(
"nn.functional.triplet_margin_loss",
sample_inputs_func=sample_inputs_triplet_margin_loss,
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
supports_out=False,
),
OpInfo(
"nn.functional.triplet_margin_with_distance_loss",
sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
supports_out=False,
skips=(
# This test cannot handle a callable passed to `distance_function`. If we would use
# `distance_function=None`, the test would pass fine.
DecorateInfo(
unittest.expectedFailure,
"TestJit",
"test_variant_consistency_jit",
),
# This tests raises a plain AssertionError at
# https://github.com/pytorch/pytorch/blob/840fe8e4e6efee5d8197bd6987757f29d72dd162/test/test_fx_experimental.py#L1559
DecorateInfo(
unittest.expectedFailure,
"TestNormalizeOperators",
"test_normalize_operator_exhaustive",
),
),
),
]
# Common operator groupings