mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
add OpInfos for torch.nn.functional.triplet_margin(_with_distance)?_loss
ghstack-source-id: bbc38b4b85
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67079
This commit is contained in:
parent
5ada829c4b
commit
334339a3d2
|
|
@ -10,6 +10,7 @@ import unittest
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch._six import inf
|
from torch._six import inf
|
||||||
import collections.abc
|
import collections.abc
|
||||||
|
|
@ -7770,6 +7771,25 @@ def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, lo
|
||||||
for shape, kwargs in shapes_and_kwargs
|
for shape, kwargs in shapes_and_kwargs
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, with_distance=False, **kwargs):
|
||||||
|
make = partial(make_tensor, (S, M), device=device, dtype=dtype, requires_grad=requires_grad)
|
||||||
|
|
||||||
|
kwargss = (
|
||||||
|
*[dict(margin=margin) for margin in (1e-6, 1.0, 10.0)],
|
||||||
|
dict(swap=True),
|
||||||
|
*[dict(reduction=reduction) for reduction in ("mean", "sum", "none")],
|
||||||
|
)
|
||||||
|
|
||||||
|
sample_inputs = []
|
||||||
|
for kwargs in kwargss:
|
||||||
|
input = make()
|
||||||
|
args = (make(), make())
|
||||||
|
if with_distance:
|
||||||
|
kwargs["distance_function"] = nn.PairwiseDistance()
|
||||||
|
sample_inputs.append(SampleInput(input, args=args, kwargs=kwargs))
|
||||||
|
|
||||||
|
return sample_inputs
|
||||||
|
|
||||||
|
|
||||||
foreach_unary_op_db: List[OpInfo] = [
|
foreach_unary_op_db: List[OpInfo] = [
|
||||||
ForeachFuncInfo('exp'),
|
ForeachFuncInfo('exp'),
|
||||||
|
|
@ -15702,6 +15722,36 @@ op_db: List[OpInfo] = [
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
OpInfo(
|
||||||
|
"nn.functional.triplet_margin_loss",
|
||||||
|
sample_inputs_func=sample_inputs_triplet_margin_loss,
|
||||||
|
dtypes=all_types_and_complex_and(torch.bfloat16),
|
||||||
|
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
|
||||||
|
supports_out=False,
|
||||||
|
),
|
||||||
|
OpInfo(
|
||||||
|
"nn.functional.triplet_margin_with_distance_loss",
|
||||||
|
sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
|
||||||
|
dtypes=all_types_and_complex_and(torch.bfloat16),
|
||||||
|
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
|
||||||
|
supports_out=False,
|
||||||
|
skips=(
|
||||||
|
# This test cannot handle a callable passed to `distance_function`. If we would use
|
||||||
|
# `distance_function=None`, the test would pass fine.
|
||||||
|
DecorateInfo(
|
||||||
|
unittest.expectedFailure,
|
||||||
|
"TestJit",
|
||||||
|
"test_variant_consistency_jit",
|
||||||
|
),
|
||||||
|
# This tests raises a plain AssertionError at
|
||||||
|
# https://github.com/pytorch/pytorch/blob/840fe8e4e6efee5d8197bd6987757f29d72dd162/test/test_fx_experimental.py#L1559
|
||||||
|
DecorateInfo(
|
||||||
|
unittest.expectedFailure,
|
||||||
|
"TestNormalizeOperators",
|
||||||
|
"test_normalize_operator_exhaustive",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Common operator groupings
|
# Common operator groupings
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user