add OpInfos for nn.functional.binary_cross_entropy(_with_logits)?

ghstack-source-id: 740eeff117
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67023
Philip Meier 2022-02-08 14:36:26 +01:00
parent 45cdfbeeab
commit 5ada829c4b

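For reference, the two functionals covered by these OpInfos can be exercised as follows. This is a minimal sketch with arbitrary shapes that only mirror the configurations the new sample-inputs function generates:

import torch
import torch.nn.functional as F

# binary_cross_entropy consumes probabilities in [0, 1] for both input and target.
input = torch.rand(5, 5, requires_grad=True)
target = torch.rand(5, 5)
weight = torch.rand(5, 5)
loss = F.binary_cross_entropy(input, target, weight=weight, reduction="mean")

# binary_cross_entropy_with_logits consumes raw scores and additionally accepts pos_weight.
logits = torch.randn(5, 5, requires_grad=True)
pos_weight = torch.rand(5)
loss_with_logits = F.binary_cross_entropy_with_logits(logits, target, reduction="sum", pos_weight=pos_weight)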

@@ -7744,6 +7744,32 @@ def sample_inputs_pdist(op_info, device, dtype, requires_grad, **kwargs):
        *[SampleInput(make_input((S, S)), kwargs=dict(p=p)) for p in (0.0, 1.0, 2.0, 10.0, float("inf"))],
    ]

def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, logits=False, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype)
    make_prob = partial(make, low=0, high=1)

    reductions = ("mean", "sum", "none")

    shapes_and_kwargs = [
        *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))],
        *[((S, S), dict(reduction=reduction)) for reduction in reductions],
        *[((S, S), dict(reduction=reduction, weight=make((S, S)))) for reduction in reductions],
    ]

    if logits:
        shapes_and_kwargs.extend(
            [((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions]
        )

    return [
        SampleInput(
            (make if logits else make_prob)(shape, requires_grad=requires_grad),
            args=(make_prob(shape, requires_grad=requires_grad),),
            kwargs=kwargs,
        )
        for shape, kwargs in shapes_and_kwargs
    ]

foreach_unary_op_db: List[OpInfo] = [
    ForeachFuncInfo('exp'),
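For orientation: the generator above produces SampleInput objects that the OpInfo test machinery hands to the op under test. A rough sketch of that consumption (assuming the helpers defined in this file, such as make_tensor, S, and SampleInput, are in scope; the real tests iterate via op.sample_inputs):

for sample in sample_inputs_binary_cross_entropy(None, "cpu", torch.float32, requires_grad=True):
    # Each sample packs the first operand as .input and the target as the single positional arg.
    torch.nn.functional.binary_cross_entropy(sample.input, *sample.args, **sample.kwargs)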
@@ -15586,6 +15612,96 @@ op_db: List[OpInfo] = [
            ),
        ),
    ),
    OpInfo(
        "nn.functional.binary_cross_entropy",
        sample_inputs_func=sample_inputs_binary_cross_entropy,
        dtypes=floating_types(),
        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
        supports_out=False,
        gradcheck_fast_mode=False,
        decorators=(
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}),
                "TestJit",
                "test_variant_consistency_jit",
            ),
        ),
        skips=(
            # RuntimeError: expected int at position 0, but got: Tensor
            DecorateInfo(
                unittest.expectedFailure,
                "TestJit",
                "test_variant_consistency_jit",
            ),
            # NotImplementedError: the derivative for 'binary_cross_entropy_backward wrt `target`' is not implemented.
            DecorateInfo(
                unittest.expectedFailure,
                "TestGradients",
                "test_fn_gradgrad",
            ),
            # AssertionError: Found a sampled tensor of floating-point dtype torch.float32 sampled with
            # requires_grad=False.
            # `weight` input does not support gradient.
            DecorateInfo(
                unittest.expectedFailure,
                "TestCommon",
                "test_floating_inputs_are_differentiable",
            ),
        ),
    ),
    OpInfo(
        "nn.functional.binary_cross_entropy_with_logits",
        sample_inputs_func=partial(sample_inputs_binary_cross_entropy, logits=True),
        dtypes=floating_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
        supports_out=False,
        supports_forward_ad=True,
        gradcheck_fast_mode=False,
        decorators=(
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}),
                "TestJit",
                "test_variant_consistency_jit",
            ),
        ),
        skips=(
            # torch.autograd.gradcheck.GradcheckError: Jacobian computed with forward mode mismatch for output 0 with
            # respect to input 0
            DecorateInfo(
                unittest.expectedFailure,
                "TestGradients",
                "test_fn_fwgrad_bwgrad",
            ),
            # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace
            # operation: [torch.DoubleTensor [5, 5]], which is output 0 of SigmoidBackward0, is at version 1;
            # expected version 0 instead.
            DecorateInfo(
                unittest.expectedFailure,
                "TestGradients",
                "test_fn_gradgrad",
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestJit",
                "test_variant_consistency_jit",
            ),
            # AssertionError: Found a sampled tensor of floating-point dtype torch.float32 sampled with
            # requires_grad=False.
            # `weight` input does not support gradient.
            DecorateInfo(
                unittest.expectedFailure,
                "TestCommon",
                "test_floating_inputs_are_differentiable",
            ),
            # RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor obtained using .clone()
            # if you want a mutable tensor.
            DecorateInfo(
                unittest.expectedFailure,
                "TestGradients",
                "test_forward_mode_AD",
            ),
        ),
    ),
]
# Common operator groupings
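Note that the TestGradients skips above concern second-order and forward-mode AD; plain backward gradients are still expected to pass, since there is no skip for test_fn_grad. An illustrative standalone check for the logits variant, not part of the actual harness:

import torch
import torch.nn.functional as F

inp = torch.randn(3, 3, dtype=torch.double, requires_grad=True)
target = torch.rand(3, 3, dtype=torch.double)
# gradcheck in double precision over the input only; the target does not require grad.
torch.autograd.gradcheck(lambda x: F.binary_cross_entropy_with_logits(x, target), (inp,))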