From dbdda654af2ef8044d4818689a9022f9cb678a95 Mon Sep 17 00:00:00 2001 From: eqy Date: Sat, 4 Jan 2025 02:30:38 +0000 Subject: [PATCH] [64-bit][CUDA] Upsample2D 64-bit indexing fix attempt 2 (#141923) #141831 Block/thread math requires a cast... Pull Request resolved: https://github.com/pytorch/pytorch/pull/141923 Approved by: https://github.com/ngimel --- .../src/ATen/native/cuda/UpSampleNearest2d.cu | 26 +++++++++---------- test/test_nn.py | 7 ++++- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu index bdf5ad1ae33..e9470b3a6b4 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu @@ -49,14 +49,14 @@ __global__ void upsample_nearest2d_out_frame( float height_scale, float width_scale) { size_t nc_iter = threadIdx.z + blockIdx.z * blockDim.z; - int w2 = threadIdx.x + blockIdx.x * blockDim.x; - int h2 = threadIdx.y + blockIdx.y * blockDim.y; + int64_t w2 = ((int64_t) threadIdx.x) + blockIdx.x * blockDim.x; + int64_t h2 = threadIdx.y + blockIdx.y * blockDim.y; if (w2 >= width2 || h2 >= height2) { return; } - int nc_stride = blockDim.z * gridDim.z; + int64_t nc_stride = ((int64_t) blockDim.z) * gridDim.z; const size_t h1 = height1 == height2 ? h2 @@ -93,9 +93,9 @@ __global__ void upsample_nearest2d_nhwc_out_frame( float width_scale, const size_t out_numel) { - const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; + const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; - if (index < out_numel) { + if (index < out_numel) { const auto c = index % channels; const auto w2 = (index / channels) % width2; const auto h2 = (index / channels / width2) % height2; @@ -126,8 +126,8 @@ __global__ void upsample_nearest2d_backward_out_frame( if (dst_idx >= dim_c * dst_dim_h * dst_dim_w) return; - int dst_c_stride = dst_dim_h * dst_dim_w; - int src_c_stride = src_dim_h * src_dim_w; + int64_t dst_c_stride = dst_dim_h * dst_dim_w; + int64_t src_c_stride = src_dim_h * src_dim_w; int c = (dst_idx / (dst_c_stride)) % dim_c; @@ -178,7 +178,7 @@ __global__ void upsample_nearest2d_backward_nhwc_out_frame( // 1 is for grad_output (src) // 2 is for grad_input (dst) - const int index = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; if (index < gi_numel) { const int c = index % channels; @@ -250,7 +250,6 @@ static void upsample_nearest2d_out_cuda_template( AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_nhwc_out_frame", [&] { const scalar_t* idata = input.const_data_ptr(); scalar_t* odata = output.mutable_data_ptr(); - upsample_nearest2d_nhwc_out_frame <<>>( idata, @@ -272,7 +271,7 @@ static void upsample_nearest2d_out_cuda_template( Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options()); Tensor input = input_.contiguous(); - int nc = nbatch * channels; + int64_t nc = nbatch * channels; const int max_threads = std::min( at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS); @@ -293,7 +292,7 @@ static void upsample_nearest2d_out_cuda_template( int grid_x = ceil_div(output_width, block_x); int grid_y = ceil_div(output_height, block_y); int grid_z = std::min( - maxGridSize[2], ceil_div(nc, block_z * 4)); + maxGridSize[2], ceil_div(nc, (int64_t) block_z * 4)); const dim3 grid(grid_x, grid_y, grid_z); // Error out on cases where grid_x & grid_y exceeds limit of launch config, as // the current kernel implementation doesn't loop over the two dimensions. @@ -303,7 +302,6 @@ static void upsample_nearest2d_out_cuda_template( TORCH_CHECK( grid_x <= maxGridSize[0] && grid_y <= maxGridSize[1], "input tensor has spatial dimension larger than the kernel capacity"); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_out_frame", [&] { using accscalar_t = at::acc_type; @@ -404,10 +402,10 @@ static void upsample_nearest2d_backward_out_cuda_template( Tensor grad_output = grad_output_.contiguous(); // upsample_nearest2d meta call makes sure `nbatch != 0` - unsigned int n = grad_input.numel() / nbatch; + size_t n = grad_input.numel() / nbatch; dim3 bdim{std::min( at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)}; - dim3 gdim{ceil_div(n, bdim.x)}; + dim3 gdim{(unsigned int) ceil_div(n, (size_t) bdim.x)}; // safe check for int64 indexing; implicitly restrict launch config for kernel TORCH_CHECK(grad_input.numel() <= std::numeric_limits::max(), "upsample2d grad_input.numel() <= std::numeric_limits::max(), but got ", grad_input.sizes()); TORCH_CHECK(grad_output.numel() <= std::numeric_limits::max(), "upsample2d grad_output.numel() <= std::numeric_limits::max(), but got ", grad_output.sizes()); diff --git a/test/test_nn.py b/test/test_nn.py index b97744ccb70..d79b1ba00de 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -9961,7 +9961,8 @@ class TestNNDeviceType(NNTestCase): gradgradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input]) @onlyCUDA - @dtypes(torch.half) + @skipCUDAIfRocm(msg="launch bounds error out on ROCM") + @dtypes(torch.half, torch.bfloat16) @largeTensorTest('40GB') def test_upsampling_64bit_indexing_channels_last(self, device, dtype): x = torch.rand((32, 64, 512, 512), dtype=dtype, device=device) @@ -9970,6 +9971,10 @@ class TestNNDeviceType(NNTestCase): del x self.assertTrue(torch.allclose(out, out_ref)) + x = torch.ones((17, 256, 512, 512), dtype=dtype).cuda().to(memory_format=torch.channels_last) + out = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest') + self.assertEqual(out[0], out[-1]) + @onlyCUDA @dtypes(torch.half) @largeTensorTest('40GB')