[CUDA] Workaround shmem limit for certain input sizes in AdaptiveAvgPool1D (#115231)
Reference issue #68248

CC @ptrblck @malfet @xwang233

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115231
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
parent 7d92449171
commit d55365dc05
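For context, a minimal sketch of the kind of workload this commit targets, using the same shapes as the regression test added at the end of this diff (the exact sizes reported in issue #68248 may differ). On a channels_last input like this, the backward pass of adaptive_avg_pool2d could request more dynamic shared memory than sharedMemPerBlock allows and trip the AT_ASSERT(shmem_size <= sharedMemPerBlock) visible in the old code below:

    import torch

    # Channels-last (NHWC) input with many channels and a wide spatial extent;
    # shapes mirror the regression test added in this commit.
    x = torch.randn(1, 256, 1, 5000, device="cuda").to(memory_format=torch.channels_last)
    x.requires_grad_()

    y = torch.nn.functional.adaptive_avg_pool2d(x, (1, 256))
    # Before this change, the backward kernel's shared-memory request could
    # exceed the per-block limit and fail the assert for inputs like this.
    y.backward(torch.randn_like(y))
    print(x.grad.shape)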
@@ -639,7 +639,7 @@ namespace {
           {sizeC*isizeH*isizeW, 1, isizeW*sizeC, sizeC});
       }
 
-      const int max_threads = std::min<int>(
+      int max_threads = std::min<int>(
           at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, CUDA_MAX_THREADS);
       int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
       int* maxGridSize = at::cuda::getCurrentDeviceProperties()->maxGridSize;
@@ -654,46 +654,57 @@ namespace {
       // C -> block.x
       // encourage larger block_y & block_z for better cache hit while maintain
       // reasonable block_x for coalesced memory access;
-      int block_x = std::min<int>(
-          maxThreadsDim[0], std::min<int>(lastPow2(sizeC), at::cuda::warp_size()));
-      int block_y = std::min<int>(
-          maxThreadsDim[1], std::min<int>(lastPow2(isizeW), max_threads / block_x));
-      int block_z = std::min<int>(
-          maxThreadsDim[2], std::min<int>(lastPow2(isizeH), max_threads / block_x / block_y));
-      block_x = std::min<int>(
-          maxThreadsDim[0], std::min<int>(lastPow2(sizeC), max_threads / block_y / block_z));
-      const dim3 block(block_x, block_y, block_z);
-      int kernel_stride_C = ceil_div(sizeC, block_x * 4);
-      int kernel_size_C = ceil_div(sizeC, block_x * kernel_stride_C);
+      bool done = false;
+      do {
+        int block_x = std::max<int>(std::min<int>(
+            maxThreadsDim[0], std::min<int>(lastPow2(sizeC), at::cuda::warp_size())), 1);
+        int block_y = std::max<int>(std::min<int>(
+            maxThreadsDim[1], std::min<int>(lastPow2(isizeW), max_threads / block_x)), 1);
+        int block_z = std::max<int>(std::min<int>(
+            maxThreadsDim[2], std::min<int>(lastPow2(isizeH), max_threads / block_x / block_y)), 1);
+        block_x = std::max<int>(std::min<int>(
+            maxThreadsDim[0], std::min<int>(lastPow2(sizeC), max_threads / block_y / block_z)), 1);
+        const dim3 block(block_x, block_y, block_z);
+        int kernel_stride_C = ceil_div(sizeC, block_x * 4);
+        int kernel_size_C = ceil_div(sizeC, block_x * kernel_stride_C);
 
-      // Do NOT clip grid_x, striding on Batch dimension is not in the kernel,
-      // although it could be easily implemented given current kernel.
-      int grid_x = sizeB*kernel_stride_C;
-      // it's OK to clip grid_y & grid_z, as we block the two dimensions in the kernel;
-      int grid_y = std::min<int>(
-          maxGridSize[1], ceil_div(isizeW, block_y*BLOCK_STRIDE));
-      int grid_z = std::min<int>(
-          maxGridSize[2], ceil_div(isizeH, block_z*BLOCK_STRIDE));
-      const dim3 grid(grid_x, grid_y, grid_z);
+        // Do NOT clip grid_x, striding on Batch dimension is not in the kernel,
+        // although it could be easily implemented given current kernel.
+        int grid_x = sizeB*kernel_stride_C;
+        // it's OK to clip grid_y & grid_z, as we block the two dimensions in the kernel;
+        int grid_y = std::min<int>(
+            maxGridSize[1], ceil_div(isizeW, block_y*BLOCK_STRIDE));
+        int grid_z = std::min<int>(
+            maxGridSize[2], ceil_div(isizeH, block_z*BLOCK_STRIDE));
+        const dim3 grid(grid_x, grid_y, grid_z);
 
-      // we are dealing with packed tensor here. max index is the same as numel.
-      // TODO: to really support input tensor large enought to go beyond int32,
-      // we will need to restrict out shared memory usage and adjust the launch
-      // config;
-      AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());
-      AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
-          input.scalar_type(), "adaptive_avg_pool2d_backward_nhwc_cuda", [&] {
-            size_t shmem_size = (kernel_size_C * block_x * block_y * block_z + osizeH + osizeW) * sizeof(scalar_t) + 2 * isizeW * sizeof(int32_t);
-            AT_ASSERT(shmem_size <= sharedMemPerBlock);
-            adaptive_average_gradinput_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
-              gradInput.mutable_data_ptr<scalar_t>(),
-              gradOutput.const_data_ptr<scalar_t>(),
-              sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
-              kernel_stride_C, kernel_size_C,
-              ostrideB, ostrideC, ostrideH, ostrideW);
-            C10_CUDA_KERNEL_LAUNCH_CHECK();
-          }
-      );
+        // we are dealing with packed tensor here. max index is the same as numel.
+        // TODO: to really support input tensor large enought to go beyond int32,
+        // we will need to restrict out shared memory usage and adjust the launch
+        // config;
+        AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());
+        AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
+            input.scalar_type(), "adaptive_avg_pool2d_backward_nhwc_cuda", [&] {
+              size_t shmem_size = (kernel_size_C * block_x * block_y * block_z + osizeH + osizeW) * sizeof(scalar_t) + 2 * isizeW * sizeof(int32_t);
+              if (shmem_size <= sharedMemPerBlock) {
+                adaptive_average_gradinput_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
+                  gradInput.mutable_data_ptr<scalar_t>(),
+                  gradOutput.const_data_ptr<scalar_t>(),
+                  sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
+                  kernel_stride_C, kernel_size_C,
+                  ostrideB, ostrideC, ostrideH, ostrideW);
+                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                done = true;
+              } else {
+                TORCH_WARN_ONCE("Requested shmem_size exceeds sharedMemPerBlock limit! Reducing max_threads...");
+                max_threads /= 2;
+              }
+            }
+        );
+      } while (!done && max_threads);
+      if (!done) {
+        TORCH_INTERNAL_ASSERT(false, "Couldn't reduce launch bounds to accomodate shaedMemPerBlock limit");
+      }
       break;
     }
     case at::MemoryFormat::Contiguous: {
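Read together, the new launch logic amounts to: compute the block dimensions (now clamped to at least 1), derive the dynamic shared-memory request from that block, and if the request exceeds sharedMemPerBlock, halve max_threads and retry, asserting only when the budget can never be met. A rough Python sketch of that arithmetic follows; it is an illustration only, not PyTorch code. The helper name pick_launch and the default values (512 threads, 32-wide warps, 48 KiB of shared memory, 4-byte scalars) are stand-ins for the cudaDeviceProp fields and dtype queried in the CUDA source, and the maxThreadsDim clamps are omitted for brevity:

    # Illustrative sketch of the retry strategy above; not part of PyTorch.
    def last_pow2(n):
        # largest power of two <= n (mirrors the lastPow2 helper)
        p = 1
        while p * 2 <= n:
            p *= 2
        return p

    def ceil_div(a, b):
        return (a + b - 1) // b

    def pick_launch(sizeC, isizeH, isizeW, osizeH, osizeW,
                    max_threads=512, warp=32, scalar_bytes=4,
                    shared_mem_per_block=48 * 1024):
        while max_threads:
            # block dims, each clamped to >= 1 so tiny extents cannot yield 0
            block_x = max(min(last_pow2(sizeC), warp), 1)
            block_y = max(min(last_pow2(isizeW), max_threads // block_x), 1)
            block_z = max(min(last_pow2(isizeH), max_threads // block_x // block_y), 1)
            block_x = max(min(last_pow2(sizeC), max_threads // block_y // block_z), 1)
            kernel_stride_C = ceil_div(sizeC, block_x * 4)
            kernel_size_C = ceil_div(sizeC, block_x * kernel_stride_C)
            # same shared-memory formula as in the dispatch body above
            shmem = ((kernel_size_C * block_x * block_y * block_z + osizeH + osizeW)
                     * scalar_bytes + 2 * isizeW * 4)
            if shmem <= shared_mem_per_block:
                return (block_x, block_y, block_z), shmem
            max_threads //= 2  # the workaround: shrink the block and retry
        raise RuntimeError("cannot satisfy the shared-memory budget")

For the shapes in the new test (float32, sizeC=256, isizeW=5000, osizeW=256), the 2 * isizeW * sizeof(int32_t) term alone is about 40 KB, so the first attempt can land just above a 48 KiB sharedMemPerBlock; one or two halvings of max_threads are then enough to make the request fit.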
@@ -12799,6 +12799,21 @@ class TestNNDeviceType(NNTestCase):
         out = torch._softmax_backward_data(x, x, 2, x.dtype)
         self.assertEqual(out[0, 0, 0], 1 / numel)
 
+    # reference issue: https://github.com/pytorch/pytorch/issues/68248
+    @onlyCUDA
+    def test_adaptiveavg_pool1d_shmem(self, device):
+        x = torch.randn(1, 256, 1, 5000, device=device).to(memory_format=torch.channels_last)
+        x_cpu = x.cpu()
+        x_cpu.requires_grad_()
+        x.requires_grad_()
+        y = torch.nn.functional.adaptive_avg_pool2d(x, (1, 256))
+        y_cpu = torch.nn.functional.adaptive_avg_pool2d(x_cpu, (1, 256))
+        grad = torch.randn_like(y)
+        grad_cpu = grad.cpu()
+        y.backward(grad)
+        y_cpu.backward(grad_cpu)
+        self.assertEqual(x.grad, x_cpu.grad)
+
     @skipMeta
     def test_channel_shuffle(self, device):
         # 3D tensor