add OpInfos for nn.functional.binary_cross_entropy(_with_logits)?
ghstack-source-id: 740eeff117
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67023
This commit is contained in:
parent
45cdfbeeab
commit
5ada829c4b
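
For reference, the two functionals covered by this commit differ only in whether the input is a probability or a raw logit; a minimal sketch of that relationship using the public torch.nn.functional API (shapes here are illustrative, not taken from the diff):

import torch
import torch.nn.functional as F

# binary_cross_entropy expects probabilities in [0, 1];
# binary_cross_entropy_with_logits takes raw scores and applies the sigmoid internally.
logits = torch.randn(5, 5, requires_grad=True)
target = torch.rand(5, 5)

loss_prob = F.binary_cross_entropy(torch.sigmoid(logits), target)
loss_logits = F.binary_cross_entropy_with_logits(logits, target)
print(torch.allclose(loss_prob, loss_logits))  # True up to floating-point error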
@@ -7744,6 +7744,32 @@ def sample_inputs_pdist(op_info, device, dtype, requires_grad, **kwargs):
        *[SampleInput(make_input((S, S)), kwargs=dict(p=p)) for p in (0.0, 1.0, 2.0, 10.0, float("inf"))],
    ]


def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, logits=False, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype)
    make_prob = partial(make, low=0, high=1)

    reductions = ("mean", "sum", "none")

    shapes_and_kwargs = [
        *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))],
        *[((S, S), dict(reduction=reduction)) for reduction in reductions],
        *[((S, S), dict(reduction=reduction, weight=make((S, S)))) for reduction in reductions],
    ]

    if logits:
        shapes_and_kwargs.extend(
            [((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions]
        )

    return [
        SampleInput(
            (make if logits else make_prob)(shape, requires_grad=requires_grad),
            args=(make_prob(shape, requires_grad=requires_grad),),
            kwargs=kwargs,
        )
        for shape, kwargs in shapes_and_kwargs
    ]


foreach_unary_op_db: List[OpInfo] = [
    ForeachFuncInfo('exp'),
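
Each SampleInput produced by the helper above unpacks into a call of the form op(input, target, **kwargs). Spelled out for one of the weighted samples (a sketch with illustrative values; S stands in for the test suite's small-dimension constant):

import torch
import torch.nn.functional as F

S = 5  # illustrative stand-in for the test suite's small-dimension constant

input = torch.rand(S, S, requires_grad=True)  # probabilities, as make_prob generates
target = torch.rand(S, S)                     # targets in [0, 1]
weight = torch.rand(S, S)                     # per-element rescaling weight

loss = F.binary_cross_entropy(input, target, weight=weight, reduction="sum")
loss.backward()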
@@ -15586,6 +15612,96 @@ op_db: List[OpInfo] = [
            ),
        ),
    ),
    OpInfo(
        "nn.functional.binary_cross_entropy",
        sample_inputs_func=sample_inputs_binary_cross_entropy,
        dtypes=floating_types(),
        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
        supports_out=False,
        gradcheck_fast_mode=False,
        decorators=(
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}),
                "TestJit",
                "test_variant_consistency_jit",
            ),
        ),
        skips=(
            # RuntimeError: expected int at position 0, but got: Tensor
            DecorateInfo(
                unittest.expectedFailure,
                "TestJit",
                "test_variant_consistency_jit",
            ),
            # NotImplementedError: the derivative for 'binary_cross_entropy_backward wrt `target`' is not implemented.
            DecorateInfo(
                unittest.expectedFailure,
                "TestGradients",
                "test_fn_gradgrad",
            ),
            # AssertionError: Found a sampled tensor of floating-point dtype torch.float32 sampled with
            # requires_grad=False.
            # `weight` input does not support gradient.
            DecorateInfo(
                unittest.expectedFailure,
                "TestCommon",
                "test_floating_inputs_are_differentiable",
            ),
        ),
    ),
    OpInfo(
        "nn.functional.binary_cross_entropy_with_logits",
        sample_inputs_func=partial(sample_inputs_binary_cross_entropy, logits=True),
        dtypes=floating_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
        supports_out=False,
        supports_forward_ad=True,
        gradcheck_fast_mode=False,
        decorators=(
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}),
                "TestJit",
                "test_variant_consistency_jit",
            ),
        ),
        skips=(
            # torch.autograd.gradcheck.GradcheckError: Jacobian computed with forward mode mismatch for output 0 with
            # respect to input 0
            DecorateInfo(
                unittest.expectedFailure,
                "TestGradients",
                "test_fn_fwgrad_bwgrad",
            ),
            # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace
            # operation: [torch.DoubleTensor [5, 5]], which is output 0 of SigmoidBackward0, is at version 1;
            # expected version 0 instead.
            DecorateInfo(
                unittest.expectedFailure,
                "TestGradients",
                "test_fn_gradgrad",
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestJit",
                "test_variant_consistency_jit",
            ),
            # AssertionError: Found a sampled tensor of floating-point dtype torch.float32 sampled with
            # requires_grad=False.
            # `weight` input does not support gradient.
            DecorateInfo(
                unittest.expectedFailure,
                "TestCommon",
                "test_floating_inputs_are_differentiable",
            ),
            # RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor obtained using .clone()
            # if you want a mutable tensor.
            DecorateInfo(
                unittest.expectedFailure,
                "TestGradients",
                "test_forward_mode_AD",
            ),
        ),
    ),
]

# Common operator groupings
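
The logits=True branch of the sample-input helper additionally exercises pos_weight, which only the _with_logits variant accepts; roughly, those samples correspond to calls like the following (shapes illustrative):

import torch
import torch.nn.functional as F

S = 5  # illustrative stand-in for the test suite's small-dimension constant

logits = torch.randn(S, S, requires_grad=True)
target = torch.rand(S, S)
pos_weight = torch.rand(S)  # non-negative, broadcast over the last dimension

loss = F.binary_cross_entropy_with_logits(logits, target, pos_weight=pos_weight, reduction="mean")
loss.backward()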