[64-bit][CUDA] Upsample2D 64-bit indexing fix attempt 2 (#141923)

#141831
Block/thread math requires a cast...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141923
Approved by: https://github.com/ngimel
This commit is contained in:
eqy 2025-01-04 02:30:38 +00:00 committed by PyTorch MergeBot
parent 1d091e47d6
commit dbdda654af
2 changed files with 18 additions and 15 deletions

View File

@ -49,14 +49,14 @@ __global__ void upsample_nearest2d_out_frame(
float height_scale,
float width_scale) {
size_t nc_iter = threadIdx.z + blockIdx.z * blockDim.z;
int w2 = threadIdx.x + blockIdx.x * blockDim.x;
int h2 = threadIdx.y + blockIdx.y * blockDim.y;
int64_t w2 = ((int64_t) threadIdx.x) + blockIdx.x * blockDim.x;
int64_t h2 = threadIdx.y + blockIdx.y * blockDim.y;
if (w2 >= width2 || h2 >= height2) {
return;
}
int nc_stride = blockDim.z * gridDim.z;
int64_t nc_stride = ((int64_t) blockDim.z) * gridDim.z;
const size_t h1 = height1 == height2
? h2
@ -93,9 +93,9 @@ __global__ void upsample_nearest2d_nhwc_out_frame(
float width_scale,
const size_t out_numel) {
const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
if (index < out_numel) {
if (index < out_numel) {
const auto c = index % channels;
const auto w2 = (index / channels) % width2;
const auto h2 = (index / channels / width2) % height2;
@ -126,8 +126,8 @@ __global__ void upsample_nearest2d_backward_out_frame(
if (dst_idx >= dim_c * dst_dim_h * dst_dim_w)
return;
int dst_c_stride = dst_dim_h * dst_dim_w;
int src_c_stride = src_dim_h * src_dim_w;
int64_t dst_c_stride = dst_dim_h * dst_dim_w;
int64_t src_c_stride = src_dim_h * src_dim_w;
int c = (dst_idx / (dst_c_stride)) % dim_c;
@ -178,7 +178,7 @@ __global__ void upsample_nearest2d_backward_nhwc_out_frame(
// 1 is for grad_output (src)
// 2 is for grad_input (dst)
const int index = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
if (index < gi_numel) {
const int c = index % channels;
@ -250,7 +250,6 @@ static void upsample_nearest2d_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_nhwc_out_frame", [&] {
const scalar_t* idata = input.const_data_ptr<scalar_t>();
scalar_t* odata = output.mutable_data_ptr<scalar_t>();
upsample_nearest2d_nhwc_out_frame<scalar_t, nn_compute_source_index_fn>
<<<ceil_div(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
idata,
@ -272,7 +271,7 @@ static void upsample_nearest2d_out_cuda_template(
Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
Tensor input = input_.contiguous();
int nc = nbatch * channels;
int64_t nc = nbatch * channels;
const int max_threads = std::min<int>(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS);
@ -293,7 +292,7 @@ static void upsample_nearest2d_out_cuda_template(
int grid_x = ceil_div(output_width, block_x);
int grid_y = ceil_div(output_height, block_y);
int grid_z = std::min<int>(
maxGridSize[2], ceil_div(nc, block_z * 4));
maxGridSize[2], ceil_div(nc, (int64_t) block_z * 4));
const dim3 grid(grid_x, grid_y, grid_z);
// Error out on cases where grid_x & grid_y exceeds limit of launch config, as
// the current kernel implementation doesn't loop over the two dimensions.
@ -303,7 +302,6 @@ static void upsample_nearest2d_out_cuda_template(
TORCH_CHECK(
grid_x <= maxGridSize[0] && grid_y <= maxGridSize[1],
"input tensor has spatial dimension larger than the kernel capacity");
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_out_frame", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
@ -404,10 +402,10 @@ static void upsample_nearest2d_backward_out_cuda_template(
Tensor grad_output = grad_output_.contiguous();
// upsample_nearest2d meta call makes sure `nbatch != 0`
unsigned int n = grad_input.numel() / nbatch;
size_t n = grad_input.numel() / nbatch;
dim3 bdim{std::min<unsigned int>(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
dim3 gdim{ceil_div(n, bdim.x)};
dim3 gdim{(unsigned int) ceil_div(n, (size_t) bdim.x)};
// safe check for int64 indexing; implicitly restrict launch config for kernel
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int64_t>::max(), "upsample2d grad_input.numel() <= std::numeric_limits<int64_t>::max(), but got ", grad_input.sizes());
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int64_t>::max(), "upsample2d grad_output.numel() <= std::numeric_limits<int64_t>::max(), but got ", grad_output.sizes());

View File

@ -9961,7 +9961,8 @@ class TestNNDeviceType(NNTestCase):
gradgradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [input])
@onlyCUDA
@dtypes(torch.half)
@skipCUDAIfRocm(msg="launch bounds error out on ROCM")
@dtypes(torch.half, torch.bfloat16)
@largeTensorTest('40GB')
def test_upsampling_64bit_indexing_channels_last(self, device, dtype):
x = torch.rand((32, 64, 512, 512), dtype=dtype, device=device)
@ -9970,6 +9971,10 @@ class TestNNDeviceType(NNTestCase):
del x
self.assertTrue(torch.allclose(out, out_ref))
x = torch.ones((17, 256, 512, 512), dtype=dtype).cuda().to(memory_format=torch.channels_last)
out = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
self.assertEqual(out[0], out[-1])
@onlyCUDA
@dtypes(torch.half)
@largeTensorTest('40GB')