mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Fix https://github.com/pytorch/pytorch/issues/16012 It replaces https://github.com/pytorch/pytorch/pull/20684 that has gone stale and simply adds tests on top of it. These calls used to crash, they now work and return the same value as the backward using the autograd engine. Pull Request resolved: https://github.com/pytorch/pytorch/pull/33872 Differential Revision: D20148360 Pulled By: albanD fbshipit-source-id: 1113f1a25be238570fa8900fc1be658b61a47802
314 lines
14 KiB
Python
314 lines
14 KiB
Python
"""Gradient interface"""
|
|
|
|
import torch
|
|
from .modules.utils import _single, _pair, _triple
|
|
import warnings
|
|
|
|
|
|
def _grad_input_padding(grad_output, input_size, stride, padding, kernel_size, dilation=None):
|
|
if dilation is None:
|
|
# For backward compatibility
|
|
warnings.warn("_grad_input_padding 'dilation' argument not provided. Default of 1 is used.")
|
|
dilation = [1] * len(stride)
|
|
|
|
input_size = list(input_size)
|
|
k = grad_output.dim() - 2
|
|
|
|
if len(input_size) == k + 2:
|
|
input_size = input_size[-k:]
|
|
if len(input_size) != k:
|
|
raise ValueError("input_size must have {} elements (got {})"
|
|
.format(k + 2, len(input_size)))
|
|
|
|
def dim_size(d):
|
|
return ((grad_output.size(d + 2) - 1) * stride[d] - 2 * padding[d] + 1
|
|
+ dilation[d] * (kernel_size[d] - 1))
|
|
|
|
min_sizes = [dim_size(d) for d in range(k)]
|
|
max_sizes = [min_sizes[d] + stride[d] - 1 for d in range(k)]
|
|
for size, min_size, max_size in zip(input_size, min_sizes, max_sizes):
|
|
if size < min_size or size > max_size:
|
|
raise ValueError(
|
|
("requested an input grad size of {}, but valid sizes range "
|
|
"from {} to {} (for a grad_output of {})").format(
|
|
input_size, min_sizes, max_sizes,
|
|
grad_output.size()[2:]))
|
|
|
|
return tuple(input_size[d] - min_sizes[d] for d in range(k))
|
|
|
|
|
|
def conv1d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""
    Computes the gradient of conv1d with respect to the input of the convolution.
    Under the hood this is the 1D transposed convolution operator, but the shape
    of the gradient w.r.t. the input must be supplied explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1,1,3, requires_grad=True)
        >>> weight = torch.randn(1,1,1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv1d_input(input.shape, weight, grad_output)

    """
    if input_size is None:
        raise ValueError("grad.conv1d_input requires specifying an input_size")

    # Normalize scalar hyper-parameters into length-1 sequences.
    stride_w = _single(stride)
    padding_w = _single(padding)
    dilation_w = _single(dilation)
    kernel_w = [weight.shape[2]]

    # Extra one-sided padding so the transposed convolution reproduces
    # exactly the requested input shape.
    output_padding = _grad_input_padding(grad_output, input_size, stride_w,
                                         padding_w, kernel_w, dilation_w)

    return torch.conv_transpose1d(
        grad_output, weight, None, stride_w, padding_w, output_padding,
        groups, dilation_w)
|
|
|
|
|
|
def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""
    Computes the gradient of conv1d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1,1,3, requires_grad=True)
        >>> weight = torch.randn(1,1,1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv1d_weight(input, weight.shape, grad_output)

    """
    stride_w = _single(stride)
    padding_w = _single(padding)
    dilation_w = _single(dilation)

    batch = input.shape[0]
    in_ch = input.shape[1]
    out_ch = grad_output.shape[1]

    # Turn grad_output into one single-channel "kernel" per
    # (example, output-channel, input-channel/groups) combination ...
    repeated = grad_output.contiguous().repeat(1, in_ch // groups, 1)
    kernels = repeated.contiguous().view(
        repeated.shape[0] * repeated.shape[1], 1, repeated.shape[2])

    # ... and fold the batch of inputs into channels of a single example, so
    # that one grouped convolution correlates every input channel with every
    # output-gradient channel.
    flat_input = input.contiguous().view(1, batch * in_ch, input.shape[2])

    # Note the swapped roles: dilation acts as stride and vice versa, because
    # the weight gradient of a convolution is itself a (dilated) correlation.
    corr = torch.conv1d(flat_input, kernels, None, dilation_w, padding_w,
                        stride_w, in_ch * batch)

    # Separate the batch dimension back out and sum the per-example grads.
    corr = corr.contiguous().view(batch, corr.shape[1] // batch, corr.shape[2])
    summed = corr.sum(dim=0)

    # Reshape to (out_channels, in_channels/groups, kW) and trim the spatial
    # dim down to the requested kernel width.
    grad_w = summed.view(in_ch // groups, out_ch, corr.shape[2])
    return grad_w.transpose(0, 1).narrow(2, 0, weight_size[2])
|
|
|
|
|
|
def conv2d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""
    Computes the gradient of conv2d with respect to the input of the convolution.
    Under the hood this is the 2D transposed convolution operator, but the shape
    of the gradient w.r.t. the input must be supplied explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weight tensor (out_channels x in_channels/groups x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1,1,3,3, requires_grad=True)
        >>> weight = torch.randn(1,1,1,2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv2d_input(input.shape, weight, grad_output)

    """
    if input_size is None:
        raise ValueError("grad.conv2d_input requires specifying an input_size")

    # Normalize scalar hyper-parameters into (H, W) pairs.
    stride_hw = _pair(stride)
    padding_hw = _pair(padding)
    dilation_hw = _pair(dilation)
    kernel_hw = (weight.shape[2], weight.shape[3])

    # Extra one-sided padding so the transposed convolution reproduces
    # exactly the requested input shape.
    output_padding = _grad_input_padding(grad_output, input_size, stride_hw,
                                         padding_hw, kernel_hw, dilation_hw)

    return torch.conv_transpose2d(
        grad_output, weight, None, stride_hw, padding_hw, output_padding,
        groups, dilation_hw)
|
|
|
|
|
|
def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""
    Computes the gradient of conv2d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1,1,3,3, requires_grad=True)
        >>> weight = torch.randn(1,1,1,2, requires_grad=True)
        >>> output = F.conv2d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv2d_weight(input, weight.shape, grad_output)

    """
    stride_hw = _pair(stride)
    padding_hw = _pair(padding)
    dilation_hw = _pair(dilation)

    batch = input.shape[0]
    in_ch = input.shape[1]
    out_ch = grad_output.shape[1]

    # Turn grad_output into one single-channel "kernel" per
    # (example, output-channel, input-channel/groups) combination ...
    repeated = grad_output.contiguous().repeat(1, in_ch // groups, 1, 1)
    kernels = repeated.contiguous().view(
        repeated.shape[0] * repeated.shape[1], 1, repeated.shape[2],
        repeated.shape[3])

    # ... and fold the batch of inputs into channels of a single example, so
    # that one grouped convolution correlates every input channel with every
    # output-gradient channel.
    flat_input = input.contiguous().view(1, batch * in_ch, input.shape[2],
                                         input.shape[3])

    # Note the swapped roles: dilation acts as stride and vice versa, because
    # the weight gradient of a convolution is itself a (dilated) correlation.
    corr = torch.conv2d(flat_input, kernels, None, dilation_hw, padding_hw,
                        stride_hw, in_ch * batch)

    # Separate the batch dimension back out and sum the per-example grads.
    corr = corr.contiguous().view(batch, corr.shape[1] // batch,
                                  corr.shape[2], corr.shape[3])
    summed = corr.sum(dim=0)

    # Reshape to (out_channels, in_channels/groups, kH, kW) and trim the
    # spatial dims down to the requested kernel size.
    grad_w = summed.view(in_ch // groups, out_ch, corr.shape[2],
                         corr.shape[3]).transpose(0, 1)
    return grad_w.narrow(2, 0, weight_size[2]).narrow(3, 0, weight_size[3])
|
|
|
|
|
|
def conv3d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""
    Computes the gradient of conv3d with respect to the input of the convolution.
    Under the hood this is the 3D transposed convolution operator, but the shape
    of the gradient w.r.t. the input must be supplied explicitly.

    Args:
        input_size : Shape of the input gradient tensor
        weight: weights tensor (out_channels x in_channels/groups x kT x kH x kW)
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_input = torch.autograd.grad(output, input, grad_output)
        >>> F.grad.conv3d_input(input.shape, weight, grad_output)

    """
    if input_size is None:
        raise ValueError("grad.conv3d_input requires specifying an input_size")

    # Normalize scalar hyper-parameters into (T, H, W) triples.
    stride_thw = _triple(stride)
    padding_thw = _triple(padding)
    dilation_thw = _triple(dilation)
    kernel_thw = (weight.shape[2], weight.shape[3], weight.shape[4])

    # Extra one-sided padding so the transposed convolution reproduces
    # exactly the requested input shape.
    output_padding = _grad_input_padding(grad_output, input_size, stride_thw,
                                         padding_thw, kernel_thw, dilation_thw)

    return torch.conv_transpose3d(
        grad_output, weight, None, stride_thw, padding_thw, output_padding,
        groups, dilation_thw)
|
|
|
|
|
|
def conv3d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1):
    r"""
    Computes the gradient of conv3d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iT x iH x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
        >>> output = F.conv3d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv3d_weight(input, weight.shape, grad_output)

    """
    stride_thw = _triple(stride)
    padding_thw = _triple(padding)
    dilation_thw = _triple(dilation)

    batch = input.shape[0]
    in_ch = input.shape[1]
    out_ch = grad_output.shape[1]

    # Turn grad_output into one single-channel "kernel" per
    # (example, output-channel, input-channel/groups) combination ...
    repeated = grad_output.repeat(1, in_ch // groups, 1, 1, 1)
    kernels = repeated.contiguous().view(
        repeated.shape[0] * repeated.shape[1], 1, repeated.shape[2],
        repeated.shape[3], repeated.shape[4])

    # ... and fold the batch of inputs into channels of a single example, so
    # that one grouped convolution correlates every input channel with every
    # output-gradient channel.
    flat_input = input.contiguous().view(1, batch * in_ch, input.shape[2],
                                         input.shape[3], input.shape[4])

    # Note the swapped roles: dilation acts as stride and vice versa, because
    # the weight gradient of a convolution is itself a (dilated) correlation.
    corr = torch.conv3d(flat_input, kernels, None, dilation_thw, padding_thw,
                        stride_thw, in_ch * batch)

    # Separate the batch dimension back out and sum the per-example grads.
    corr = corr.contiguous().view(batch, corr.shape[1] // batch,
                                  corr.shape[2], corr.shape[3],
                                  corr.shape[4])
    summed = corr.sum(dim=0)

    # Reshape to (out_channels, in_channels/groups, kT, kH, kW) and trim the
    # spatial dims down to the requested kernel size.
    grad_w = summed.view(in_ch // groups, out_ch, corr.shape[2],
                         corr.shape[3], corr.shape[4]).transpose(0, 1)
    return (grad_w.narrow(2, 0, weight_size[2])
                  .narrow(3, 0, weight_size[3])
                  .narrow(4, 0, weight_size[4]))
|