mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[64-bit][CUDA] Upsample2D 64-bit indexing fix attempt 2 (#141923)
#141831 Block/thread math requires a cast... Pull Request resolved: https://github.com/pytorch/pytorch/pull/141923 Approved by: https://github.com/ngimel
This commit is contained in:
parent
1d091e47d6
commit
dbdda654af
|
|
@ -49,14 +49,14 @@ __global__ void upsample_nearest2d_out_frame(
|
|||
float height_scale,
|
||||
float width_scale) {
|
||||
size_t nc_iter = threadIdx.z + blockIdx.z * blockDim.z;
|
||||
int w2 = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int h2 = threadIdx.y + blockIdx.y * blockDim.y;
|
||||
int64_t w2 = ((int64_t) threadIdx.x) + blockIdx.x * blockDim.x;
|
||||
int64_t h2 = threadIdx.y + blockIdx.y * blockDim.y;
|
||||
|
||||
if (w2 >= width2 || h2 >= height2) {
|
||||
return;
|
||||
}
|
||||
|
||||
int nc_stride = blockDim.z * gridDim.z;
|
||||
int64_t nc_stride = ((int64_t) blockDim.z) * gridDim.z;
|
||||
|
||||
const size_t h1 = height1 == height2
|
||||
? h2
|
||||
|
|
@ -93,9 +93,9 @@ __global__ void upsample_nearest2d_nhwc_out_frame(
|
|||
float width_scale,
|
||||
const size_t out_numel) {
|
||||
|
||||
const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
|
||||
if (index < out_numel) {
|
||||
if (index < out_numel) {
|
||||
const auto c = index % channels;
|
||||
const auto w2 = (index / channels) % width2;
|
||||
const auto h2 = (index / channels / width2) % height2;
|
||||
|
|
@ -126,8 +126,8 @@ __global__ void upsample_nearest2d_backward_out_frame(
|
|||
if (dst_idx >= dim_c * dst_dim_h * dst_dim_w)
|
||||
return;
|
||||
|
||||
int dst_c_stride = dst_dim_h * dst_dim_w;
|
||||
int src_c_stride = src_dim_h * src_dim_w;
|
||||
int64_t dst_c_stride = dst_dim_h * dst_dim_w;
|
||||
int64_t src_c_stride = src_dim_h * src_dim_w;
|
||||
|
||||
int c = (dst_idx / (dst_c_stride)) % dim_c;
|
||||
|
||||
|
|
@ -178,7 +178,7 @@ __global__ void upsample_nearest2d_backward_nhwc_out_frame(
|
|||
// 1 is for grad_output (src)
|
||||
// 2 is for grad_input (dst)
|
||||
|
||||
const int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
|
||||
if (index < gi_numel) {
|
||||
const int c = index % channels;
|
||||
|
|
@ -250,7 +250,6 @@ static void upsample_nearest2d_out_cuda_template(
|
|||
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_nhwc_out_frame", [&] {
|
||||
const scalar_t* idata = input.const_data_ptr<scalar_t>();
|
||||
scalar_t* odata = output.mutable_data_ptr<scalar_t>();
|
||||
|
||||
upsample_nearest2d_nhwc_out_frame<scalar_t, nn_compute_source_index_fn>
|
||||
<<<ceil_div(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
idata,
|
||||
|
|
@ -272,7 +271,7 @@ static void upsample_nearest2d_out_cuda_template(
|
|||
Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
|
||||
Tensor input = input_.contiguous();
|
||||
|
||||
int nc = nbatch * channels;
|
||||
int64_t nc = nbatch * channels;
|
||||
|
||||
const int max_threads = std::min<int>(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS);
|
||||
|
|
@ -293,7 +292,7 @@ static void upsample_nearest2d_out_cuda_template(
|
|||
int grid_x = ceil_div(output_width, block_x);
|
||||
int grid_y = ceil_div(output_height, block_y);
|
||||
int grid_z = std::min<int>(
|
||||
maxGridSize[2], ceil_div(nc, block_z * 4));
|
||||
maxGridSize[2], ceil_div(nc, (int64_t) block_z * 4));
|
||||
const dim3 grid(grid_x, grid_y, grid_z);
|
||||
// Error out on cases where grid_x & grid_y exceeds limit of launch config, as
|
||||
// the current kernel implementation doesn't loop over the two dimensions.
|
||||
|
|
@ -303,7 +302,6 @@ static void upsample_nearest2d_out_cuda_template(
|
|||
TORCH_CHECK(
|
||||
grid_x <= maxGridSize[0] && grid_y <= maxGridSize[1],
|
||||
"input tensor has spatial dimension larger than the kernel capacity");
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_out_frame", [&] {
|
||||
using accscalar_t = at::acc_type<scalar_t, true>;
|
||||
|
|
@ -404,10 +402,10 @@ static void upsample_nearest2d_backward_out_cuda_template(
|
|||
Tensor grad_output = grad_output_.contiguous();
|
||||
|
||||
// upsample_nearest2d meta call makes sure `nbatch != 0`
|
||||
unsigned int n = grad_input.numel() / nbatch;
|
||||
size_t n = grad_input.numel() / nbatch;
|
||||
dim3 bdim{std::min<unsigned int>(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
|
||||
dim3 gdim{ceil_div(n, bdim.x)};
|
||||
dim3 gdim{(unsigned int) ceil_div(n, (size_t) bdim.x)};
|
||||
// safe check for int64 indexing; implicitly restrict launch config for kernel
|
||||
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int64_t>::max(), "upsample2d grad_input.numel() <= std::numeric_limits<int64_t>::max(), but got ", grad_input.sizes());
|
||||
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int64_t>::max(), "upsample2d grad_output.numel() <= std::numeric_limits<int64_t>::max(), but got ", grad_output.sizes());
|
||||
|
|
|
|||
|
|
@ -9961,7 +9961,8 @@ class TestNNDeviceType(NNTestCase):
|
|||
gradgradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.half)
|
||||
@skipCUDAIfRocm(msg="launch bounds error out on ROCM")
|
||||
@dtypes(torch.half, torch.bfloat16)
|
||||
@largeTensorTest('40GB')
|
||||
def test_upsampling_64bit_indexing_channels_last(self, device, dtype):
|
||||
x = torch.rand((32, 64, 512, 512), dtype=dtype, device=device)
|
||||
|
|
@ -9970,6 +9971,10 @@ class TestNNDeviceType(NNTestCase):
|
|||
del x
|
||||
self.assertTrue(torch.allclose(out, out_ref))
|
||||
|
||||
x = torch.ones((17, 256, 512, 512), dtype=dtype).cuda().to(memory_format=torch.channels_last)
|
||||
out = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
|
||||
self.assertEqual(out[0], out[-1])
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.half)
|
||||
@largeTensorTest('40GB')
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user