[CUDA] Workaround shmem limit for certain input sizes in AdaptiveAvgPool1D (#115231)

Reference issue #68248

CC @ptrblck @malfet @xwang233

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115231
Approved by: https://github.com/mikaylagawarecki
Author: eqy
Date: 2023-12-19 22:40:10 +00:00
Committed by: PyTorch MergeBot
Parent: 7d92449171
Commit: d55365dc05
2 changed files with 65 additions and 39 deletions
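The change replaces a hard AT_ASSERT on the kernel's shared-memory request with a retry loop: the launch configuration is computed from a thread budget (max_threads), the required shmem_size is derived from the resulting block, and if it exceeds the device's sharedMemPerBlock the budget is halved and the configuration recomputed, until the request fits or the budget runs out. Below is a minimal standalone sketch of that pattern, mirroring the loop in the diff; the hard-coded thread budget of 1024, warp size of 32, and the 48 KiB limit are illustrative assumptions (the real code reads all three from the device properties), and the maxThreadsDim clipping is omitted for brevity.

// Standalone sketch of the halve-and-retry pattern (illustrative; the
// constants are assumptions, see above — not the PR's literal code).
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Largest power of two <= n (same role as lastPow2 in the real file).
static int last_pow2(int n) {
  int p = 1;
  while (p * 2 <= n) p *= 2;
  return p;
}

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

struct LaunchConfig {
  int block_x, block_y, block_z;
  size_t shmem_bytes;
};

// Shrink the thread budget until the kernel's shared-memory request fits
// under the per-block limit; returns false if no budget works.
static bool pick_launch_config(int sizeC, int isizeH, int isizeW,
                               int osizeH, int osizeW, size_t elem_size,
                               size_t shmem_per_block, LaunchConfig* out) {
  int max_threads = 1024;  // assumed maxThreadsPerBlock
  const int warp = 32;     // assumed warp size
  while (max_threads >= 1) {
    // Same shape heuristics as the diff: C -> block.x, W -> block.y,
    // H -> block.z, each clamped to at least 1 (the new std::max calls)
    // so a shrinking budget cannot produce a zero block dimension.
    int block_x = std::max(std::min(last_pow2(sizeC), warp), 1);
    int block_y = std::max(std::min(last_pow2(isizeW), max_threads / block_x), 1);
    int block_z = std::max(std::min(last_pow2(isizeH), max_threads / block_x / block_y), 1);
    block_x = std::max(std::min(last_pow2(sizeC), max_threads / block_y / block_z), 1);
    int kernel_stride_C = ceil_div(sizeC, block_x * 4);
    int kernel_size_C = ceil_div(sizeC, block_x * kernel_stride_C);
    // Shared-memory request, exactly the shmem_size formula from the diff.
    size_t shmem =
        (size_t)(kernel_size_C * block_x * block_y * block_z + osizeH + osizeW) * elem_size
        + 2 * (size_t)isizeW * sizeof(int32_t);
    if (shmem <= shmem_per_block) {
      *out = {block_x, block_y, block_z, shmem};
      return true;     // fits: launch with this config
    }
    max_threads /= 2;  // too much shared memory: halve the budget and retry
  }
  return false;
}

int main() {
  // Shape from the regression test: N=1, C=256, H=1, W=5000, float32, out (1, 256).
  LaunchConfig cfg;
  if (pick_launch_config(256, 1, 5000, 1, 256, sizeof(float), 48 * 1024, &cfg)) {
    std::printf("block=(%d,%d,%d), shmem=%zu bytes\n",
                cfg.block_x, cfg.block_y, cfg.block_z, cfg.shmem_bytes);
  }
}

Under the assumed 48 KiB limit this settles on block = (32, 8, 1) after two halvings; the arithmetic is spelled out after the first file's diff below.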

@@ -639,7 +639,7 @@ namespace {
       {sizeC*isizeH*isizeW, 1, isizeW*sizeC, sizeC});
   }
-  const int max_threads = std::min<int>(
+  int max_threads = std::min<int>(
       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
   int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
   int* maxGridSize = at::cuda::getCurrentDeviceProperties()->maxGridSize;
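This first hunk drops const from max_threads because the retry loop added below mutates it (max_threads /= 2) whenever the shared-memory request does not fit.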
@@ -654,46 +654,57 @@ namespace {
   // C -> block.x
   // encourage larger block_y & block_z for better cache hit while maintaining
   // reasonable block_x for coalesced memory access;
-  int block_x = std::min<int>(
-      maxThreadsDim[0], std::min<int>(lastPow2(sizeC), at::cuda::warp_size()));
-  int block_y = std::min<int>(
-      maxThreadsDim[1], std::min<int>(lastPow2(isizeW), max_threads / block_x));
-  int block_z = std::min<int>(
-      maxThreadsDim[2], std::min<int>(lastPow2(isizeH), max_threads / block_x / block_y));
-  block_x = std::min<int>(
-      maxThreadsDim[0], std::min<int>(lastPow2(sizeC), max_threads / block_y / block_z));
-  const dim3 block(block_x, block_y, block_z);
-  int kernel_stride_C = ceil_div(sizeC, block_x * 4);
-  int kernel_size_C = ceil_div(sizeC, block_x * kernel_stride_C);
+  bool done = false;
+  do {
+    int block_x = std::max<int>(std::min<int>(
+        maxThreadsDim[0], std::min<int>(lastPow2(sizeC), at::cuda::warp_size())), 1);
+    int block_y = std::max<int>(std::min<int>(
+        maxThreadsDim[1], std::min<int>(lastPow2(isizeW), max_threads / block_x)), 1);
+    int block_z = std::max<int>(std::min<int>(
+        maxThreadsDim[2], std::min<int>(lastPow2(isizeH), max_threads / block_x / block_y)), 1);
+    block_x = std::max<int>(std::min<int>(
+        maxThreadsDim[0], std::min<int>(lastPow2(sizeC), max_threads / block_y / block_z)), 1);
+    const dim3 block(block_x, block_y, block_z);
+    int kernel_stride_C = ceil_div(sizeC, block_x * 4);
+    int kernel_size_C = ceil_div(sizeC, block_x * kernel_stride_C);
 
-  // Do NOT clip grid_x, striding on Batch dimension is not in the kernel,
-  // although it could be easily implemented given current kernel.
-  int grid_x = sizeB*kernel_stride_C;
-  // it's OK to clip grid_y & grid_z, as we block the two dimensions in the kernel;
-  int grid_y = std::min<int>(
-      maxGridSize[1], ceil_div(isizeW, block_y*BLOCK_STRIDE));
-  int grid_z = std::min<int>(
-      maxGridSize[2], ceil_div(isizeH, block_z*BLOCK_STRIDE));
-  const dim3 grid(grid_x, grid_y, grid_z);
+    // Do NOT clip grid_x; striding on the batch dimension is not done in the
+    // kernel, although it could easily be added given the current kernel.
+    int grid_x = sizeB*kernel_stride_C;
+    // it's OK to clip grid_y & grid_z, as we block the two dimensions in the kernel;
+    int grid_y = std::min<int>(
+        maxGridSize[1], ceil_div(isizeW, block_y*BLOCK_STRIDE));
+    int grid_z = std::min<int>(
+        maxGridSize[2], ceil_div(isizeH, block_z*BLOCK_STRIDE));
+    const dim3 grid(grid_x, grid_y, grid_z);
 
-  // we are dealing with packed tensor here. max index is the same as numel.
-  // TODO: to really support input tensor large enought to go beyond int32,
-  // we will need to restrict out shared memory usage and adjust the launch
-  // config;
-  AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());
-  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
-      input.scalar_type(), "adaptive_avg_pool2d_backward_nhwc_cuda", [&] {
-        size_t shmem_size = (kernel_size_C * block_x * block_y * block_z + osizeH + osizeW) * sizeof(scalar_t) + 2 * isizeW * sizeof(int32_t);
-        AT_ASSERT(shmem_size <= sharedMemPerBlock);
-        adaptive_average_gradinput_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
-          gradInput.mutable_data_ptr<scalar_t>(),
-          gradOutput.const_data_ptr<scalar_t>(),
-          sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
-          kernel_stride_C, kernel_size_C,
-          ostrideB, ostrideC, ostrideH, ostrideW);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      }
-  );
+    // we are dealing with a packed tensor here; the max index is the same as numel.
+    // TODO: to really support input tensors large enough to go beyond int32,
+    // we will need to restrict our shared memory usage and adjust the launch
+    // config;
+    AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());
+    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
+        input.scalar_type(), "adaptive_avg_pool2d_backward_nhwc_cuda", [&] {
+          size_t shmem_size = (kernel_size_C * block_x * block_y * block_z + osizeH + osizeW) * sizeof(scalar_t) + 2 * isizeW * sizeof(int32_t);
+          if (shmem_size <= sharedMemPerBlock) {
+            adaptive_average_gradinput_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
+              gradInput.mutable_data_ptr<scalar_t>(),
+              gradOutput.const_data_ptr<scalar_t>(),
+              sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
+              kernel_stride_C, kernel_size_C,
+              ostrideB, ostrideC, ostrideH, ostrideW);
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+            done = true;
+          } else {
+            TORCH_WARN_ONCE("Requested shmem_size exceeds sharedMemPerBlock limit! Reducing max_threads...");
+            max_threads /= 2;
+          }
+        }
+    );
+  } while (!done && max_threads);
+  if (!done) {
+    TORCH_INTERNAL_ASSERT(false, "Couldn't reduce launch bounds to accommodate the sharedMemPerBlock limit");
+  }
   break;
 }
 case at::MemoryFormat::Contiguous: {
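Why this input overflows: for the shape in the new test (N=1, C=256, H=1, W=5000, channels-last, float32, output size (1, 256)), and assuming a warp size of 32, maxThreadsPerBlock = 1024, and sharedMemPerBlock = 49152 bytes (typical values; they vary by GPU), the first iteration picks block = (32, 32, 1) with kernel_size_C = 4, giving shmem_size = (4*32*32 + 1 + 256) * 4 + 2 * 5000 * 4 = 57412 bytes. That is where the old AT_ASSERT fired. The new loop halves max_threads to 512 (shmem_size = 49220, still over the limit) and then to 256, where block = (32, 8, 1) gives shmem_size = 45124 bytes and the launch succeeds.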

@@ -12799,6 +12799,21 @@ class TestNNDeviceType(NNTestCase):
         out = torch._softmax_backward_data(x, x, 2, x.dtype)
         self.assertEqual(out[0, 0, 0], 1 / numel)
 
+    # reference issue: https://github.com/pytorch/pytorch/issues/68248
+    @onlyCUDA
+    def test_adaptiveavg_pool1d_shmem(self, device):
+        x = torch.randn(1, 256, 1, 5000, device=device).to(memory_format=torch.channels_last)
+        x_cpu = x.cpu()
+        x_cpu.requires_grad_()
+        x.requires_grad_()
+        y = torch.nn.functional.adaptive_avg_pool2d(x, (1, 256))
+        y_cpu = torch.nn.functional.adaptive_avg_pool2d(x_cpu, (1, 256))
+        grad = torch.randn_like(y)
+        grad_cpu = grad.cpu()
+        y.backward(grad)
+        y_cpu.backward(grad_cpu)
+        self.assertEqual(x.grad, x_cpu.grad)
+
     @skipMeta
     def test_channel_shuffle(self, device):
         # 3D tensor
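Because TestNNDeviceType is a device-generic test class, the new test is instantiated per backend; with @onlyCUDA it runs only against CUDA devices (as, e.g., TestNNDeviceTypeCUDA.test_adaptiveavg_pool1d_shmem_cuda under the usual instantiation scheme; the instantiated name is an assumption). The CPU pooling path has no shared-memory constraint, so its gradient serves as the reference that the retried CUDA launch must match; before this change the same input tripped the hard AT_ASSERT on shmem_size instead of retrying with a smaller block.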