Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Summary:
fix nearest upsample dgrad bug, where window computation was wrong previously;
fix python test where previously GPU implementation was not tested;

Pull Request resolved: https://github.com/pytorch/pytorch/pull/39055
Differential Revision: D21763242
Pulled By: albanD
fbshipit-source-id: 9b1d5365f40176450f529136110542fd36bd7f58
This commit is contained in:
parent 5702a28b26
commit bfcb687b9c
aten/src/ATen/native/cuda/UpSample.cuh

@@ -136,24 +136,52 @@ static inline void upsample_3d_shape_check(
   }
 }
 
+// NOTE [ Nearest neighbor upsampling kernel implementation ]
+//
+// The nearest neighbor upsampling kernel implementation is symmetrical as
+// expected. We launch kernels with threads mapping to destination tensors where
+// kernels write data to, each thread reads data from the source tensor, this
+// means:
+// 1. In the forward kernel,
+//      src_xxx refers to properties of input tensors;
+//      dst_xxx refers to properties of output tensors;
+//      scale_factor is the ratio of src_size to dst_size;
+// 2. In the backward kernel,
+//      src_xxx refers to properties of grad_output tensors;
+//      dst_xxx refers to properties of grad_input tensors;
+//      scale_factor is the ratio of src_size to dst_size;
+//
+// Because of this, we need to take the reciprocal of the scale defined by
+// upsample layer during forward path. The motivation is to avoid slow
+// division in the kernel code, so we can use faster multiplication instead.
+// This is not necessary during backward path, since the scale_factor is already
+// the reciprocal of corresponding scale_factor used in the forward path due to
+// the swap of source and destination tensor.
+//
+// Similarly, since the mapping from grad_input to grad_output during backward
+// is the reverse of the mapping of output to input, we need to have opposite
+// mapping functions to compute the source index.
+
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
 template <typename accscalar_t>
 __host__ __forceinline__ static accscalar_t compute_scales_value(
     const c10::optional<double> scale,
-    int64_t input_size,
-    int64_t output_size) {
+    int64_t src_size,
+    int64_t dst_size) {
   // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
   return (scale.has_value() && scale.value() > 0.) ? (accscalar_t)(1.0 / scale.value())
-                                                   : (accscalar_t)input_size / output_size;
+                                                   : (accscalar_t)src_size / dst_size;
 }
 
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
 template <typename accscalar_t>
 __host__ __forceinline__ static accscalar_t compute_scales_value_backwards(
     const c10::optional<double> scale,
-    int64_t input_size,
-    int64_t output_size) {
+    int64_t src_size,
+    int64_t dst_size) {
   // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
   return (scale.has_value() && scale.value() > 0.) ? (accscalar_t)scale.value()
-                                                   : (accscalar_t)input_size / output_size;
+                                                   : (accscalar_t)src_size / dst_size;
 }
 
 template <typename accscalar_t>
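To make the reciprocal trick in the note concrete: for an upsample from size 3 to size 5 (scale_factor 5/3), the forward helper returns 0.6 (a destination index is multiplied by 0.6 to walk back to the source), while the backward helper returns 5/3 directly, because source and destination swap roles. A minimal host-side sketch, with std::optional standing in for c10::optional and float for accscalar_t; it mirrors the helpers above for illustration and is not the PyTorch source:

#include <cstdint>
#include <cstdio>
#include <optional>

// Sketch of compute_scales_value: forward path stores 1/scale so the
// kernel maps dst -> src by multiplying instead of dividing.
static float compute_scales_value(std::optional<double> scale,
                                  int64_t src_size, int64_t dst_size) {
  return (scale.has_value() && scale.value() > 0.)
      ? (float)(1.0 / scale.value())
      : (float)src_size / dst_size;
}

// Sketch of compute_scales_value_backwards: src/dst are swapped
// (grad_output -> grad_input), so the user scale is already the
// right multiplier and is used directly.
static float compute_scales_value_backwards(std::optional<double> scale,
                                            int64_t src_size, int64_t dst_size) {
  return (scale.has_value() && scale.value() > 0.)
      ? (float)scale.value()
      : (float)src_size / dst_size;
}

int main() {
  // Upsample a 1-D input of size 3 to size 5 (scale_factor 5/3).
  printf("forward  multiplier: %f\n",
         compute_scales_value({5.0 / 3.0}, 3, 5));            // ~0.6
  printf("backward multiplier: %f\n",
         compute_scales_value_backwards({5.0 / 3.0}, 5, 3));  // ~1.667
}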
@@ -188,6 +216,7 @@ __device__ __forceinline__ static accscalar_t area_pixel_compute_source_index(
   }
 }
 
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
 __device__ __forceinline__ static int nearest_neighbor_compute_source_index(
     const float scale,
     int dst_index,
@@ -197,6 +226,16 @@ __device__ __forceinline__ static int nearest_neighbor_compute_source_index(
   return src_index;
 }
 
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
+__device__ __forceinline__ static int nearest_neighbor_bw_compute_source_index(
+    const float scale,
+    int dst_index,
+    int output_size) {
+  const int src_index =
+      static_cast<int>(ceilf(dst_index * scale));
+  return src_index;
+}
+
 /* Used by UpSampleBicubic2d.cu */
 template <typename scalar_t>
 __device__ __forceinline__ static scalar_t upsample_get_value_bounded(
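This new ceil-based helper is the heart of the dgrad fix. In the backward kernels, each grad_input index dst owns the half-open window [bw(dst), bw(dst+1)) of grad_output indices, and that window must equal exactly the set of outputs the forward floor-based mapping sent to dst. The old code reused the forward mapping for both window ends, which mis-covers grad_output whenever the effective scale is non-integer. A standalone sketch that checks this for an upsample from 3 to 5; it assumes the usual floor-plus-clamp body for the forward helper (not shown in full in this diff) and is an illustration, not the kernel code:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Forward mapping, as in nearest_neighbor_compute_source_index:
// output index -> input index.
static int fwd_index(float scale, int dst_index, int input_size) {
  return std::min(static_cast<int>(floorf(dst_index * scale)), input_size - 1);
}

// Old (buggy) backward window end-point: reused the floor-based mapping.
static int bw_index_old(float scale, int dst_index, int output_size) {
  return std::min(static_cast<int>(floorf(dst_index * scale)), output_size - 1);
}

// New backward end-point, as in nearest_neighbor_bw_compute_source_index.
static int bw_index_new(float scale, int dst_index, int /*output_size*/) {
  return static_cast<int>(ceilf(dst_index * scale));
}

int main() {
  const int in = 3, out = 5;                // upsample 3 -> 5
  const float fwd_scale = (float)in / out;  // forward multiplier, 0.6
  const float bw_scale = (float)out / in;   // backward multiplier, ~1.667

  // Ground truth: which input each output reads in the forward pass.
  for (int o = 0; o < out; o++)
    printf("forward: output %d reads input %d\n", o, fwd_index(fwd_scale, o, in));

  // Each input's gradient window over the outputs, old vs. new; the
  // kernels pass (dst+1, size+1) for the upper end, mimicked here.
  for (int i = 0; i < in; i++) {
    printf("input %d: old window [%d, %d)  new window [%d, %d)\n", i,
           bw_index_old(bw_scale, i, out), bw_index_old(bw_scale, i + 1, out + 1),
           bw_index_new(bw_scale, i, out), bw_index_new(bw_scale, i + 1, out + 1));
  }
  // Forward maps outputs {0,1}->0, {2,3}->1, {4}->2. The new windows
  // [0,2), [2,4), [4,5) match; the old windows [0,1), [1,3), [3,5) do not,
  // so output gradients were dropped or credited to the wrong input.
  return 0;
}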
aten/src/ATen/native/cuda/UpSampleNearest1d.cu

@@ -13,6 +13,7 @@ namespace {
 
 #define MAX_THREADS 512
 
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
 template <typename scalar_t>
 C10_LAUNCH_BOUNDS_1(1024)
 __global__ void upsample_nearest1d_out_frame(

@@ -43,6 +44,7 @@ __global__ void upsample_nearest1d_out_frame(
   }
 }
 
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
 // Backward operation
 template <typename scalar_t, typename accscalar_t>
 C10_LAUNCH_BOUNDS_1(1024)

@@ -62,8 +64,8 @@ __global__ void upsample_nearest1d_backward_out_frame(
   int c = (dst_idx / (dst_dim_w)) % dim_c;
 
   int dst_x = dst_idx % dst_dim_w;
-  int src_x = nearest_neighbor_compute_source_index(scale_factor, dst_x, src_dim_w);
-  int src_x_up = nearest_neighbor_compute_source_index(scale_factor, dst_x+1, src_dim_w+1);
+  int src_x = nearest_neighbor_bw_compute_source_index(scale_factor, dst_x, src_dim_w);
+  int src_x_up = nearest_neighbor_bw_compute_source_index(scale_factor, dst_x+1, src_dim_w+1);
 
   for (int b = 0; b < dim_b; b++) {
     accscalar_t grad = 0;
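For reference, this is how the window computed above is consumed: each grad_input element accumulates grad_output over [src_x, src_x_up). A serial host-side sketch of the same accumulation for a single (batch, channel) slice, with made-up sizes; the real kernel parallelizes over dst_idx and loops over batch and channel:

#include <cmath>
#include <cstdio>
#include <vector>

// Serial version of the inner logic of upsample_nearest1d_backward_out_frame.
int main() {
  const int dst_dim_w = 3;  // grad_input width
  const int src_dim_w = 5;  // grad_output width
  const float scale_factor = (float)src_dim_w / dst_dim_w;

  std::vector<float> grad_output(src_dim_w, 1.0f);  // pretend upstream grads
  std::vector<float> grad_input(dst_dim_w, 0.0f);

  for (int dst_x = 0; dst_x < dst_dim_w; dst_x++) {
    // Window end-points via the fixed ceil-based mapping.
    int src_x = (int)ceilf(dst_x * scale_factor);
    int src_x_up = (int)ceilf((dst_x + 1) * scale_factor);

    float grad = 0;
    for (int x = src_x; x < src_x_up; x++)  // accumulate over the window
      grad += grad_output[x];
    grad_input[dst_x] = grad;
  }

  // Windows are [0,2), [2,4), [4,5): every grad_output element is counted
  // exactly once, so the two sums below agree.
  float in_sum = 0, out_sum = 0;
  for (float g : grad_input) in_sum += g;
  for (float g : grad_output) out_sum += g;
  printf("sum(grad_input) = %g, sum(grad_output) = %g\n", in_sum, out_sum);
  return 0;
}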
aten/src/ATen/native/cuda/UpSampleNearest2d.cu

@@ -14,6 +14,7 @@ namespace {
 
 #define MAX_THREADS 512
 
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
 template <typename scalar_t, typename accscalar_t>
 C10_LAUNCH_BOUNDS_1(1024)
 __global__ void upsample_nearest2d_out_frame(

@@ -57,6 +58,7 @@ __global__ void upsample_nearest2d_out_frame(
   }
 }
 
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
 template <typename scalar_t, typename accscalar_t>
 C10_LAUNCH_BOUNDS_1(1024)
 __global__ void upsample_nearest2d_backward_out_frame(

@@ -81,14 +83,14 @@ __global__ void upsample_nearest2d_backward_out_frame(
 
   int dst_y = (dst_idx / dst_dim_w) % dst_dim_h;
   int src_y =
-      nearest_neighbor_compute_source_index(height_scale, dst_y, src_dim_h);
-  int src_y_up = nearest_neighbor_compute_source_index(
+      nearest_neighbor_bw_compute_source_index(height_scale, dst_y, src_dim_h);
+  int src_y_up = nearest_neighbor_bw_compute_source_index(
       height_scale, dst_y + 1, src_dim_h + 1);
 
   int dst_x = dst_idx % dst_dim_w;
   int src_x =
-      nearest_neighbor_compute_source_index(width_scale, dst_x, src_dim_w);
-  int src_x_up = nearest_neighbor_compute_source_index(
+      nearest_neighbor_bw_compute_source_index(width_scale, dst_x, src_dim_w);
+  int src_x_up = nearest_neighbor_bw_compute_source_index(
       width_scale, dst_x + 1, src_dim_w + 1);
 
   for (int b = 0; b < dim_b; b++) {
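The 2d (and 3d) backward kernels apply the same fix once per axis; the per-axis windows combine into a rectangle (a box in 3d), so one grad_input cell sums a contiguous grad_output patch. A toy 2x2 to 3x3 sketch, again an illustration with made-up sizes rather than the kernel itself:

#include <cmath>
#include <cstdio>

// Ceil-based backward end-point, one axis at a time.
static int bw(float scale, int i) { return (int)ceilf(i * scale); }

int main() {
  const int dst_h = 2, dst_w = 2, src_h = 3, src_w = 3;
  const float hs = (float)src_h / dst_h;  // 1.5
  const float ws = (float)src_w / dst_w;  // 1.5
  float go[3][3] = {{1, 1, 1}, {1, 1, 1}, {1, 1, 1}};  // pretend upstream grads

  for (int y = 0; y < dst_h; y++)
    for (int x = 0; x < dst_w; x++) {
      float grad = 0;
      // Sum the rectangle [bw(y), bw(y+1)) x [bw(x), bw(x+1)).
      for (int sy = bw(hs, y); sy < bw(hs, y + 1); sy++)
        for (int sx = bw(ws, x); sx < bw(ws, x + 1); sx++)
          grad += go[sy][sx];
      // Per-axis windows: index 0 -> [0,2), index 1 -> [2,3);
      // rectangle areas are 4, 2, 2, 1, covering all 9 outputs once.
      printf("grad_input[%d][%d] = %g\n", y, x, grad);
    }
  return 0;
}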
aten/src/ATen/native/cuda/UpSampleNearest3d.cu

@@ -14,6 +14,7 @@ namespace {
 
 #define MAX_THREADS 512
 
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
 template <typename scalar_t>
 C10_LAUNCH_BOUNDS_1(1024)
 __global__ void upsample_nearest3d_out_frame(

@@ -57,6 +58,7 @@ __global__ void upsample_nearest3d_out_frame(
   }
 }
 
+// see NOTE [ Nearest neighbor upsampling kernel implementation ]
 // Backward operation
 template <typename scalar_t, typename accscalar_t>
 C10_LAUNCH_BOUNDS_1(1024)

@@ -85,16 +87,16 @@ __global__ void upsample_nearest3d_backward_out_frame(
   int c = (dst_idx / (dst_c_stride)) % dim_c;
 
   int dst_z = (dst_idx / dst_dim_h / dst_dim_w) % dst_dim_d;
-  int src_z = nearest_neighbor_compute_source_index(depth_scale, dst_z, src_dim_d);
-  int src_z_up = nearest_neighbor_compute_source_index(depth_scale, dst_z+1, src_dim_d+1);
+  int src_z = nearest_neighbor_bw_compute_source_index(depth_scale, dst_z, src_dim_d);
+  int src_z_up = nearest_neighbor_bw_compute_source_index(depth_scale, dst_z+1, src_dim_d+1);
 
   int dst_y = (dst_idx / dst_dim_w) % dst_dim_h;
-  int src_y = nearest_neighbor_compute_source_index(height_scale, dst_y, src_dim_h);
-  int src_y_up = nearest_neighbor_compute_source_index(height_scale, dst_y+1, src_dim_h+1);
+  int src_y = nearest_neighbor_bw_compute_source_index(height_scale, dst_y, src_dim_h);
+  int src_y_up = nearest_neighbor_bw_compute_source_index(height_scale, dst_y+1, src_dim_h+1);
 
   int dst_x = dst_idx % dst_dim_w;
-  int src_x = nearest_neighbor_compute_source_index(width_scale, dst_x, src_dim_w);
-  int src_x_up = nearest_neighbor_compute_source_index(width_scale, dst_x+1, src_dim_w+1);
+  int src_x = nearest_neighbor_bw_compute_source_index(width_scale, dst_x, src_dim_w);
+  int src_x_up = nearest_neighbor_bw_compute_source_index(width_scale, dst_x+1, src_dim_w+1);
 
   for (int b = 0; b < dim_b; b++) {
     accscalar_t grad = 0;
test/test_nn.py

@@ -7625,7 +7625,7 @@ class TestNN(NNTestCase):
             dim = len(in_t.shape) - 2
             out_shape = [1, 1] + [out_size] * dim
             with warnings.catch_warnings(record=True) as w:
-                out_t = m(in_t)
+                out_t = layer(in_t)
             self.assertEqual(torch.ones(out_shape), out_t)
 
             self.assertEqual(

@@ -7634,10 +7634,10 @@ class TestNN(NNTestCase):
             gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [in_t])
             gradgradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [in_t])
 
-        def _make_input(dim):
+        def _make_input(dim, device):
             size = [1, 1]
             size += [2] * dim
-            return torch.ones(size, requires_grad=True)
+            return torch.ones(size, requires_grad=True, device=device)
 
         device_list = ['cpu']
         if TEST_CUDA:

@@ -7648,27 +7648,27 @@ class TestNN(NNTestCase):
             for mode in ['nearest', 'area']:
                 kwargs = dict(mode=mode)
                 m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
-                for input in [_make_input(1), _make_input(2), _make_input(3)]:
+                for input in [_make_input(1, device), _make_input(2, device), _make_input(3, device)]:
                     _test_interpolate_helper(input, scale_factor, m)
 
             for align_corners in [True, False]:
                 kwargs = dict(mode='linear', align_corners=align_corners)
                 m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
-                _test_interpolate_helper(_make_input(1), scale_factor, m)
+                _test_interpolate_helper(_make_input(1, device), scale_factor, m)
 
                 kwargs = dict(mode='bilinear', align_corners=align_corners)
                 m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
-                _test_interpolate_helper(_make_input(2), scale_factor, m)
+                _test_interpolate_helper(_make_input(2, device), scale_factor, m)
 
                 kwargs = dict(mode='bicubic', align_corners=align_corners)
 
                 def m(t):
                     return F.interpolate(t, scale_factor=scale_factor, **kwargs).to(device)
-                _test_interpolate_helper(_make_input(2), scale_factor, m)
+                _test_interpolate_helper(_make_input(2, device), scale_factor, m)
 
                 kwargs = dict(mode='trilinear', align_corners=align_corners)
                 m = nn.Upsample(scale_factor=scale_factor, **kwargs).to(device)
-                _test_interpolate_helper(_make_input(3), scale_factor, m)
+                _test_interpolate_helper(_make_input(3, device), scale_factor, m)
 
     def test_linear_broadcasting(self):
         m = nn.Linear(5, 8)