add decomposition for nll_loss2d_backward (#77198)

Adds a decomposition for `nll_loss2d_backward`

This lets us run the full set of jvpvjp tests ([see this functorch PR](https://github.com/pytorch/functorch/pull/792)). I confirmed locally that those tests pass with this change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77198
Approved by: https://github.com/Chillee
This commit is contained in:
samdow 2022-05-11 20:41:20 +00:00 committed by PyTorch MergeBot
parent 2fcd5808a3
commit d694cf60fe

View File

@ -419,6 +419,39 @@ def huber_loss_backward(
) )
def _nll_loss_backward(
grad_output: Tensor,
self: Tensor,
target: Tensor,
weight: Optional[Tensor],
reduction: int,
ignore_index: int,
total_weight: Tensor,
) -> Tensor:
channel_dim = 0 if self.dim() < 2 else 1
if reduction == Reduction.MEAN.value:
grad_output = grad_output / total_weight
target = target.unsqueeze(channel_dim)
grad_input = torch.zeros_like(self)
grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
if grad_input.dim() > grad_output.dim() > 0:
grad_output = grad_output.unsqueeze(channel_dim)
if weight is not None:
new_shape = [1 for _ in range(self.dim())]
new_shape[channel_dim] = weight.shape[0]
weight = weight.reshape(new_shape)
grad_output = grad_output * weight
has_ignore_index = ignore_index >= 0
if has_ignore_index:
ignore_index_mask = target != ignore_index
grad_output = grad_output * ignore_index_mask
return grad_input * grad_output
@register_decomposition(aten.nll_loss_backward) @register_decomposition(aten.nll_loss_backward)
def nll_loss_backward( def nll_loss_backward(
grad_output: Tensor, grad_output: Tensor,
@ -457,29 +490,39 @@ def nll_loss_backward(
grad_output.dim() <= 1 and grad_output.numel() == 1 grad_output.dim() <= 1 and grad_output.numel() == 1
), f"Expected a single element grad_output tensor, but got: {grad_output.shape}" ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"
channel_dim = 0 if self.dim() < 2 else 1 return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)
if reduction == Reduction.MEAN.value:
grad_output = grad_output / total_weight
target = target.unsqueeze(channel_dim)
grad_input = torch.zeros_like(self)
grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
if grad_input.dim() > grad_output.dim() > 0: @register_decomposition(aten.nll_loss2d_backward)
grad_output = grad_output.unsqueeze(channel_dim) def nll_loss2d_backward(
grad_output: Tensor,
self: Tensor,
target: Tensor,
weight: Optional[Tensor],
reduction: int,
ignore_index: int,
total_weight: Tensor,
) -> Tensor:
assert (
self.dim() == 4
), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"
if weight is not None: assert (
new_shape = [1 for _ in range(self.dim())] target.dim() == 3
new_shape[channel_dim] = weight.shape[0] ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"
weight.reshape(new_shape)
grad_output = grad_output * weight
has_ignore_index = ignore_index >= 0 assert(
if has_ignore_index: self.shape[0] == target.shape[0] and self.shape[2] == target.shape[1] and self.shape[3] == target.shape[2]
ignore_index_mask = target != ignore_index ), f"size mismatch (got input: {self.shape}, target: {target.shape}"
grad_output = grad_output * ignore_index_mask
return grad_input * grad_output assert (
total_weight.numel() == 1
), (
"expected total_weight to be a single element tensor, "
f"got: {total_weight.shape} ( {total_weight.numel()}, elements)"
)
return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)
@register_decomposition(aten.binary_cross_entropy) @register_decomposition(aten.binary_cross_entropy)