add decomposition for nll_loss2d_backward (#77198)
Adds a decomposition for `nll_loss2d_backward`. This will let us actually run all the tests for jvpvjp ([see this functorch PR](https://github.com/pytorch/functorch/pull/792)). I confirmed locally that this made those tests pass too.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77198
Approved by: https://github.com/Chillee
parent 2fcd5808a3
commit d694cf60fe
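For intuition about what the decomposition computes, here is a sketch that is not part of the commit: the NLL-loss backward is just `-1.0` scattered into the target class positions along the channel dimension, scaled by `grad_output` (divided by `total_weight` under mean reduction). With arbitrary shapes and no class weights, this can be checked against autograd for the spatial (2d) case:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 4, 5, requires_grad=True)  # (N, C, H, W) log-probabilities
target = torch.randint(0, 3, (2, 4, 5))          # (N, H, W) class indices

loss = F.nll_loss(x, target, reduction="mean")   # spatial NLL loss, the 2d case
(grad_autograd,) = torch.autograd.grad(loss, x)

# Same formula the decomposition uses: scatter -1.0 at the target class along
# the channel dim, then scale by 1/total_weight (= N*H*W here, since there are
# no class weights and nothing is ignored).
grad_manual = torch.zeros_like(x)
grad_manual = torch.scatter(grad_manual, 1, target.unsqueeze(1), -1.0)
grad_manual = grad_manual / target.numel()

torch.testing.assert_close(grad_autograd, grad_manual)
```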
```diff
@@ -419,6 +419,39 @@ def huber_loss_backward(
     )
 
 
+def _nll_loss_backward(
+    grad_output: Tensor,
+    self: Tensor,
+    target: Tensor,
+    weight: Optional[Tensor],
+    reduction: int,
+    ignore_index: int,
+    total_weight: Tensor,
+) -> Tensor:
+    channel_dim = 0 if self.dim() < 2 else 1
+    if reduction == Reduction.MEAN.value:
+        grad_output = grad_output / total_weight
+
+    target = target.unsqueeze(channel_dim)
+    grad_input = torch.zeros_like(self)
+    grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
+
+    if grad_input.dim() > grad_output.dim() > 0:
+        grad_output = grad_output.unsqueeze(channel_dim)
+
+    if weight is not None:
+        new_shape = [1 for _ in range(self.dim())]
+        new_shape[channel_dim] = weight.shape[0]
+        weight = weight.reshape(new_shape)
+        grad_output = grad_output * weight
+
+    has_ignore_index = ignore_index >= 0
+    if has_ignore_index:
+        ignore_index_mask = target != ignore_index
+        grad_output = grad_output * ignore_index_mask
+
+    return grad_input * grad_output
+
+
 @register_decomposition(aten.nll_loss_backward)
 def nll_loss_backward(
     grad_output: Tensor,
```
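A side note on the weight handling in `_nll_loss_backward` above: the per-class `weight` of shape `(C,)` is reshaped to all-ones dimensions except at `channel_dim`, so that multiplying it into `grad_output` broadcasts across the batch and spatial dimensions. A standalone sketch (tensor shapes are made up for illustration):

```python
import torch

# Stand-ins for the decomposition's inputs; shapes are arbitrary.
self = torch.randn(2, 3, 4, 5)          # (N, C, H, W) input
weight = torch.tensor([0.2, 0.3, 0.5])  # per-class weight, shape (C,)

# Same reshape logic as _nll_loss_backward: ones everywhere except channel_dim.
channel_dim = 0 if self.dim() < 2 else 1
new_shape = [1 for _ in range(self.dim())]
new_shape[channel_dim] = weight.shape[0]
print(weight.reshape(new_shape).shape)  # torch.Size([1, 3, 1, 1]) -> broadcasts over N, H, W
```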
```diff
@@ -457,29 +490,39 @@ def nll_loss_backward(
         grad_output.dim() <= 1 and grad_output.numel() == 1
     ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"
 
-    channel_dim = 0 if self.dim() < 2 else 1
-    if reduction == Reduction.MEAN.value:
-        grad_output = grad_output / total_weight
-
-    target = target.unsqueeze(channel_dim)
-    grad_input = torch.zeros_like(self)
-    grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
-
-    if grad_input.dim() > grad_output.dim() > 0:
-        grad_output = grad_output.unsqueeze(channel_dim)
-
-    if weight is not None:
-        new_shape = [1 for _ in range(self.dim())]
-        new_shape[channel_dim] = weight.shape[0]
-        weight.reshape(new_shape)
-        grad_output = grad_output * weight
-
-    has_ignore_index = ignore_index >= 0
-    if has_ignore_index:
-        ignore_index_mask = target != ignore_index
-        grad_output = grad_output * ignore_index_mask
-
-    return grad_input * grad_output
+    return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)
+
+
+@register_decomposition(aten.nll_loss2d_backward)
+def nll_loss2d_backward(
+    grad_output: Tensor,
+    self: Tensor,
+    target: Tensor,
+    weight: Optional[Tensor],
+    reduction: int,
+    ignore_index: int,
+    total_weight: Tensor,
+) -> Tensor:
+    assert (
+        self.dim() == 4
+    ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"
+
+    assert (
+        target.dim() == 3
+    ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"
+
+    assert (
+        self.shape[0] == target.shape[0] and self.shape[2] == target.shape[1] and self.shape[3] == target.shape[2]
+    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
+
+    assert (
+        total_weight.numel() == 1
+    ), (
+        "expected total_weight to be a single element tensor, "
+        f"got: {total_weight.shape} ({total_weight.numel()} elements)"
+    )
+
+    return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)
 
 
 @register_decomposition(aten.binary_cross_entropy)
```
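To see the new decomposition in action, one could compare it against the eager `aten` kernel. The following is a minimal sketch, assuming `decomposition_table` from `torch._decomp` is keyed by the `OpOverload`; the arguments follow the `nll_loss2d_backward` schema from the diff (`grad_output, self, target, weight, reduction, ignore_index, total_weight`), where `reduction=1` is mean and `ignore_index=-100` is the usual default sentinel:

```python
import torch
from torch._decomp import decomposition_table

aten = torch.ops.aten

x = torch.randn(2, 3, 4, 5)                         # (N, C, H, W) log-probabilities
target = torch.randint(0, 3, (2, 4, 5))             # (N, H, W) class indices
grad_output = torch.tensor(1.0)                     # scalar grad, mean reduction
total_weight = torch.tensor(float(target.numel()))  # no class weights, nothing ignored

# Reference result from the eager kernel.
ref = aten.nll_loss2d_backward(
    grad_output, x, target, None, 1, -100, total_weight
)
# Same call routed through the Python decomposition registered in this commit.
decomp = decomposition_table[aten.nll_loss2d_backward.default](
    grad_output, x, target, None, 1, -100, total_weight
)
torch.testing.assert_close(ref, decomp)
```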