add decomposition for nll_loss2d_backward (#77198)
Adds a decomposition for `nll_loss2d_backward`. This will let us actually run all the tests for jvpvjp ([see this functorch PR](https://github.com/pytorch/functorch/pull/792)). I confirmed locally that this made those tests pass too.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77198
Approved by: https://github.com/Chillee
Parent: 2fcd5808a3
Commit: d694cf60fe
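For context, a decomposition like this one can be sanity-checked by comparing the eager kernel against the Python decomposition directly. The sketch below is illustrative only: the shapes, the choice of MEAN reduction, and the `torch._decomp.decompositions` import path are assumptions, not part of this commit.

```python
# Hedged sketch: compare the eager nll_loss2d_backward kernel with the
# decomposition added in this commit. All shapes here are arbitrary.
import torch
from torch._decomp.decompositions import nll_loss2d_backward  # assumed path

N, C, H, W = 2, 3, 4, 5
logp = torch.log_softmax(torch.randn(N, C, H, W), dim=1)  # log-probabilities
target = torch.randint(0, C, (N, H, W))

# reduction=1 is Reduction.MEAN in the aten calling convention;
# ignore_index=-100 keeps the ignore_index branch disabled.
out, total_weight = torch.ops.aten.nll_loss2d_forward(logp, target, None, 1, -100)
grad_out = torch.ones_like(out)

ref = torch.ops.aten.nll_loss2d_backward(
    grad_out, logp, target, None, 1, -100, total_weight
)
dec = nll_loss2d_backward(grad_out, logp, target, None, 1, -100, total_weight)
torch.testing.assert_close(ref, dec)
```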
```diff
@@ -419,6 +419,39 @@ def huber_loss_backward(
     )
 
 
+def _nll_loss_backward(
+    grad_output: Tensor,
+    self: Tensor,
+    target: Tensor,
+    weight: Optional[Tensor],
+    reduction: int,
+    ignore_index: int,
+    total_weight: Tensor,
+) -> Tensor:
+    channel_dim = 0 if self.dim() < 2 else 1
+    if reduction == Reduction.MEAN.value:
+        grad_output = grad_output / total_weight
+
+    target = target.unsqueeze(channel_dim)
+    grad_input = torch.zeros_like(self)
+    grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
+
+    if grad_input.dim() > grad_output.dim() > 0:
+        grad_output = grad_output.unsqueeze(channel_dim)
+
+    if weight is not None:
+        new_shape = [1 for _ in range(self.dim())]
+        new_shape[channel_dim] = weight.shape[0]
+        weight = weight.reshape(new_shape)
+        grad_output = grad_output * weight
+
+    has_ignore_index = ignore_index >= 0
+    if has_ignore_index:
+        ignore_index_mask = target != ignore_index
+        grad_output = grad_output * ignore_index_mask
+
+    return grad_input * grad_output
+
+
 @register_decomposition(aten.nll_loss_backward)
 def nll_loss_backward(
     grad_output: Tensor,
```
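Editor's note on the helper above: `_nll_loss_backward` builds the input gradient by scattering `-1.0` into a zero tensor at the target indices along the channel dimension, then multiplying by the (optionally weighted and masked) incoming gradient. A tiny check of that scatter trick for the unreduced, unweighted 1-D batch case, with autograd as the reference (names and shapes below are illustrative, not from this diff):

```python
import torch
import torch.nn.functional as F

N, C = 3, 5
x = torch.randn(N, C, requires_grad=True)
t = torch.randint(0, C, (N,))

loss = F.nll_loss(x, t, reduction="none")  # loss[n] = -x[n, t[n]]
g = torch.randn(N)                         # incoming grad_output
(ref,) = torch.autograd.grad(loss, x, g)

# The scatter trick from _nll_loss_backward, specialized to this case:
grad_input = torch.scatter(torch.zeros_like(x), 1, t.unsqueeze(1), -1.0)
dec = grad_input * g.unsqueeze(1)          # broadcast g over the channel dim
torch.testing.assert_close(ref, dec)
```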
```diff
@@ -457,29 +490,39 @@ def nll_loss_backward(
         grad_output.dim() <= 1 and grad_output.numel() == 1
     ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"
 
-    channel_dim = 0 if self.dim() < 2 else 1
-    if reduction == Reduction.MEAN.value:
-        grad_output = grad_output / total_weight
-
-    target = target.unsqueeze(channel_dim)
-    grad_input = torch.zeros_like(self)
-    grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
-
-    if grad_input.dim() > grad_output.dim() > 0:
-        grad_output = grad_output.unsqueeze(channel_dim)
-
-    if weight is not None:
-        new_shape = [1 for _ in range(self.dim())]
-        new_shape[channel_dim] = weight.shape[0]
-        weight.reshape(new_shape)
-        grad_output = grad_output * weight
-
-    has_ignore_index = ignore_index >= 0
-    if has_ignore_index:
-        ignore_index_mask = target != ignore_index
-        grad_output = grad_output * ignore_index_mask
-
-    return grad_input * grad_output
+    return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)
+
+
+@register_decomposition(aten.nll_loss2d_backward)
+def nll_loss2d_backward(
+    grad_output: Tensor,
+    self: Tensor,
+    target: Tensor,
+    weight: Optional[Tensor],
+    reduction: int,
+    ignore_index: int,
+    total_weight: Tensor,
+) -> Tensor:
+    assert (
+        self.dim() == 4
+    ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"
+
+    assert (
+        target.dim() == 3
+    ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"
+
+    assert (
+        self.shape[0] == target.shape[0] and self.shape[2] == target.shape[1] and self.shape[3] == target.shape[2]
+    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
+
+    assert (
+        total_weight.numel() == 1
+    ), (
+        "expected total_weight to be a single element tensor, "
+        f"got: {total_weight.shape} ({total_weight.numel()} elements)"
+    )
+
+    return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)
 
 
 @register_decomposition(aten.binary_cross_entropy)
```
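One more usage note (again a sketch, not part of the diff): targets equal to `ignore_index` must receive zero gradient, which the shared helper implements with the `ignore_index_mask` multiply. Eager mode shows the semantics the decomposition has to reproduce:

```python
import torch
import torch.nn.functional as F

N, C, H, W = 2, 3, 4, 5
x = torch.randn(N, C, H, W, requires_grad=True)
t = torch.randint(0, C, (N, H, W))
t[0, 0, 0] = 1  # mark one spatial position with the to-be-ignored class

F.nll_loss(x, t, reduction="sum", ignore_index=1).backward()
assert torch.all(x.grad[0, :, 0, 0] == 0)  # ignored target -> zero gradient
```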