mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
103 lines
4.1 KiB
Python
103 lines
4.1 KiB
Python
import torch
|
|
from torch.autograd import Function
|
|
from torch.autograd.function import once_differentiable
|
|
from torch._thnn import type2backend
|
|
from .thnn.auto import function_by_name
|
|
import torch.backends.cudnn as cudnn
|
|
|
|
|
|
class GridSampler(Function):
    """Autograd function that samples ``input`` at the locations given by ``grid``.

    Work is dispatched to the cuDNN bindings whenever ``cudnn.is_acceptable(input)``
    holds; otherwise the ``SpatialGridSamplerBilinear`` routines of the backend
    registered for ``type(input)`` in ``type2backend`` are used.
    """

    @staticmethod
    def forward(ctx, input, grid):
        """Compute the sampled output tensor.

        ``input`` and ``grid`` are saved on ``ctx`` for the backward pass.  The
        result is allocated with ``input.new`` and has size
        ``(grid.size(0), input.size(1), grid.size(1), grid.size(2))``.
        """
        ctx.save_for_backward(input, grid)
        grid_shape = grid.size()

        if not cudnn.is_acceptable(input):
            # Fallback: backend selected by the concrete tensor type.
            thnn = type2backend[type(input)]
            sampled = input.new(grid_shape[0], input.size(1), grid_shape[1], grid_shape[2])
            thnn.SpatialGridSamplerBilinear_updateOutput(thnn.library_state, input, grid, sampled)
            return sampled

        sampled = input.new(grid_shape[0], input.size(1), grid_shape[1], grid_shape[2])
        grid = grid.contiguous()
        # NOTE(review): presumably cuDNN cannot consume zero-stride (expanded)
        # inputs, hence the copy below -- confirm against the binding.
        if 0 in input.stride():
            input = input.contiguous()
        torch._C._cudnn_grid_sampler_forward(input, grid, sampled)
        return sampled

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients with respect to ``input`` and ``grid``.

        Returns a ``(grad_input, grad_grid)`` pair whose buffers are allocated
        with the sizes of the tensors saved during ``forward``.
        """
        input, grid = ctx.saved_tensors

        if not cudnn.is_acceptable(input):
            # Fallback: backend selected by the concrete tensor type.
            thnn = type2backend[type(input)]
            d_input = input.new(input.size())
            d_grid = grid.new(grid.size())
            thnn.SpatialGridSamplerBilinear_updateGradInput(
                thnn.library_state, input, d_input,
                grid, d_grid, grad_output)
            return d_input, d_grid

        d_input = input.new(input.size())
        d_grid = grid.new(grid.size())
        grid = grid.contiguous()
        # Same zero-stride handling as in forward().
        if 0 in input.stride():
            input = input.contiguous()
        torch._C._cudnn_grid_sampler_backward(input, d_input,
                                              grid, d_grid,
                                              grad_output)
        return d_input, d_grid
|
|
|
|
|
|
class AffineGridGenerator(Function):
    """Autograd function that turns affine matrices into a sampling grid.

    ``forward`` takes ``theta`` of size ``(N, 2, 3)`` and a target
    ``torch.Size`` ``(N, C, H, W)`` and returns a grid of size
    ``(N, H, W, 2)``.  CUDA tensors are handled by the cuDNN bindings; all
    other tensors are handled by a batched matrix product against a base grid
    of homogeneous coordinates spanning ``[-1, 1]`` along each axis.
    """

    @staticmethod
    def _enforce_cudnn(input):
        """Fail unless cuDNN is enabled and accepts ``input``.

        Raises:
            RuntimeError: if ``cudnn.enabled`` is false.
            AssertionError: if ``cudnn.is_acceptable(input)`` is false.
        """
        if not cudnn.enabled:
            raise RuntimeError("AffineGridGenerator needs CuDNN for "
                               "processing CUDA inputs, but CuDNN is not enabled")
        assert cudnn.is_acceptable(input)

    @staticmethod
    def forward(ctx, theta, size):
        """Build the ``(N, H, W, 2)`` grid for ``theta`` and target ``size``.

        ``size`` must be a ``torch.Size`` of four entries ``(N, C, H, W)``.
        ``ctx`` records ``size``, whether the CUDA path was taken, and (on the
        non-CUDA path) the base grid needed by ``backward``.
        """
        assert type(size) == torch.Size
        N, C, H, W = size
        ctx.size = size
        if theta.is_cuda:
            ctx.is_cuda = True
            AffineGridGenerator._enforce_cudnn(theta)
            grid = theta.new(N, H, W, 2)
            theta = theta.contiguous()
            torch._C._cudnn_affine_grid_generator_forward(theta, grid, N, C, H, W)
        else:
            ctx.is_cuda = False
            # Homogeneous coordinates (x, y, 1) for every output location.
            base_grid = theta.new(N, H, W, 3)
            # A single-element axis is pinned to -1.
            linear_points = torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1])
            base_grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(base_grid[:, :, :, 0])
            linear_points = torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])
            base_grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(base_grid[:, :, :, 1])
            base_grid[:, :, :, 2] = 1
            ctx.base_grid = base_grid
            # (N, H*W, 3) x (N, 3, 2) -> (N, H*W, 2)
            grid = torch.bmm(base_grid.view(N, H * W, 3), theta.transpose(1, 2))
            grid = grid.view(N, H, W, 2)
        return grid

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_grid):
        """Return ``(grad_theta, None)`` for an incoming ``(N, H, W, 2)`` gradient.

        ``grad_theta`` has size ``(N, 2, 3)``; the ``size`` argument of
        ``forward`` receives no gradient.
        """
        N, C, H, W = ctx.size
        assert grad_grid.size() == torch.Size([N, H, W, 2])
        assert ctx.is_cuda == grad_grid.is_cuda
        if grad_grid.is_cuda:
            AffineGridGenerator._enforce_cudnn(grad_grid)
            grad_theta = grad_grid.new(N, 2, 3)
            grad_grid = grad_grid.contiguous()
            torch._C._cudnn_affine_grid_generator_backward(grad_theta, grad_grid,
                                                           N, C, H, W)
        else:
            base_grid = ctx.base_grid
            # The incoming gradient may be a strided view (e.g. produced by a
            # transpose); view() cannot flatten H and W in that case, so make
            # it contiguous first, as the CUDA branch already does.
            grad_grid = grad_grid.contiguous()
            # (N, 3, H*W) x (N, H*W, 2) -> (N, 3, 2), then transposed to (N, 2, 3)
            grad_theta = torch.bmm(
                base_grid.view(N, H * W, 3).transpose(1, 2),
                grad_grid.view(N, H * W, 2))
            grad_theta = grad_theta.transpose(1, 2)

        return grad_theta, None
|