CUDA BFloat Pooling (#44836)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44836

Reviewed By: mruberry

Differential Revision: D23800992

Pulled By: ngimel

fbshipit-source-id: 2945a27874345197cbd1d8a4fbd20816afc02c86
Author: Xiang Gao, 2020-09-19 15:37:39 -07:00 (committed by Facebook GitHub Bot)
Parent: 7ecfaef7ec
Commit: faef89c89f
7 changed files with 332 additions and 371 deletions
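In short, the change removes the AT_SKIP_BFLOAT16_IF_NOT_ROCM guards from the CUDA pooling kernels (adaptive and plain average/max pooling, 2-D and 3-D) and drops the @skipCUDAIfNotRocm decorator from test_pooling_bfloat16, so the BFloat16 variants registered via AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, ...) now dispatch on CUDA as well as ROCm. A minimal sketch of what this enables (not part of the PR; assumes a CUDA build of PyTorch at or after this commit):

import torch

# Each of these pooling ops routes into kernels edited in this diff; before this
# change, BFloat16 inputs on a non-ROCm CUDA device were blocked by the
# AT_SKIP_BFLOAT16_IF_NOT_ROCM guard.
x = torch.randn(8, 4, 16, 16, device="cuda", dtype=torch.bfloat16)
torch.nn.AvgPool2d(3, stride=2)(x)
torch.nn.AdaptiveAvgPool2d((4, 4))(x)
torch.nn.AdaptiveMaxPool2d((4, 4))(x)
torch.nn.functional.avg_pool3d(x.unsqueeze(2), kernel_size=(1, 3, 3))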

View File

@@ -512,17 +512,15 @@ namespace {
       AT_ASSERT(input_.numel() < std::numeric_limits<int32_t>::max());
       AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
           input_.scalar_type(), "adaptive_avg_pool2d_nhwc_cuda", [&] {
-            AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool2d_nhwc_cuda", [&] {
-              size_t shmem_size = (kernel_size_C * block_x * block_y * block_z) * sizeof(scalar_t);
-              AT_ASSERT(shmem_size <= sharedMemPerBlock);
-              adaptive_average_pool_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
-                input_.data_ptr<scalar_t>(),
-                output.data_ptr<scalar_t>(),
-                sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
-                kernel_stride_C, kernel_size_C,
-                istrideB, istrideC, istrideH, istrideW);
-            });
-          }
+            size_t shmem_size = (kernel_size_C * block_x * block_y * block_z) * sizeof(scalar_t);
+            AT_ASSERT(shmem_size <= sharedMemPerBlock);
+            adaptive_average_pool_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
+              input_.data_ptr<scalar_t>(),
+              output.data_ptr<scalar_t>(),
+              sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
+              kernel_stride_C, kernel_size_C,
+              istrideB, istrideC, istrideH, istrideW);
+          }
       );
       break;
     }
@@ -551,22 +549,20 @@
     }
       AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
           input_.scalar_type(), "adaptive_avg_pool2d_cuda", [&] {
-            AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool2d_cuda", [&] {
-              scalar_t *input_data = input_.data_ptr<scalar_t>();
-              scalar_t *output_data = output.data_ptr<scalar_t>();
+            scalar_t *input_data = input_.data_ptr<scalar_t>();
+            scalar_t *output_data = output.data_ptr<scalar_t>();

-              // cuda blocks & threads:
-              int blocksH = std::max<int64_t>((int)(16L / sizeD), 1);
-              dim3 blocks(grid_x, blocksH);
-              dim3 threads(32, 8);
+            // cuda blocks & threads:
+            int blocksH = std::max<int64_t>((int)(16L / sizeD), 1);
+            dim3 blocks(grid_x, blocksH);
+            dim3 threads(32, 8);

-              // run averagepool kernel
-              adaptive_average_pool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
-                input_data, output_data,
-                isizeH, isizeW, osizeH, osizeW,
-                istrideD, istrideH, istrideW);
-            });
-          }
+            // run averagepool kernel
+            adaptive_average_pool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
+              input_data, output_data,
+              isizeH, isizeW, osizeH, osizeW,
+              istrideD, istrideH, istrideW);
+          }
       );
       break;
     }
@@ -661,17 +657,15 @@ namespace {
       AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());
       AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
           input.scalar_type(), "adaptive_avg_pool2d_backward_nhwc_cuda", [&] {
-            AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool2d_backward_nhwc_cuda", [&] {
-              size_t shmem_size = (kernel_size_C * block_x * block_y * block_z + osizeH + osizeW) * sizeof(scalar_t) + 2 * isizeW * sizeof(int32_t);
-              AT_ASSERT(shmem_size <= sharedMemPerBlock);
-              adaptive_average_gradinput_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
-                gradInput.data_ptr<scalar_t>(),
-                gradOutput.data_ptr<scalar_t>(),
-                sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
-                kernel_stride_C, kernel_size_C,
-                ostrideB, ostrideC, ostrideH, ostrideW);
-            });
-          }
+            size_t shmem_size = (kernel_size_C * block_x * block_y * block_z + osizeH + osizeW) * sizeof(scalar_t) + 2 * isizeW * sizeof(int32_t);
+            AT_ASSERT(shmem_size <= sharedMemPerBlock);
+            adaptive_average_gradinput_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
+              gradInput.data_ptr<scalar_t>(),
+              gradOutput.data_ptr<scalar_t>(),
+              sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
+              kernel_stride_C, kernel_size_C,
+              ostrideB, ostrideC, ostrideH, ostrideW);
+          }
       );
       break;
     }
@@ -693,30 +687,28 @@ namespace {
       //bool atomic = (isizeW%osizeW != 0) || (isizeH%osizeH != 0);
       AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
           input.scalar_type(), "adaptive_avg_pool2d_backward_cuda", [&] {
-            AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool2d_backward_cuda", [&] {
-              scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
-              scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
+            scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
+            scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();

-              // cuda blocks & threads:
-              int blocksH = std::max((int)(16L / sizeD), 1);
-              dim3 blocks(grid_x, blocksH);
-              dim3 threads(32, 8);
+            // cuda blocks & threads:
+            int blocksH = std::max((int)(16L / sizeD), 1);
+            dim3 blocks(grid_x, blocksH);
+            dim3 threads(32, 8);

-              if(atomic)
-              {
-                // run updateGradInput kernel, accumulate gradients atomically
-                atomic_adaptive_average_gradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
-                  gradInput_data, gradOutput_data,
-                  isizeH, isizeW, osizeH, osizeW);
-              }
-              else
-              {
-                // run updateGradInput kernel
-                adaptive_average_gradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
-                  gradInput_data, gradOutput_data,
-                  isizeH, isizeW, osizeH, osizeW);
-              }
-            });
+            if(atomic)
+            {
+              // run updateGradInput kernel, accumulate gradients atomically
+              atomic_adaptive_average_gradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
+                gradInput_data, gradOutput_data,
+                isizeH, isizeW, osizeH, osizeW);
+            }
+            else
+            {
+              // run updateGradInput kernel
+              adaptive_average_gradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
+                gradInput_data, gradOutput_data,
+                isizeH, isizeW, osizeH, osizeW);
+            }
           }
       );
       break;
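The hunks above cover both the plain and the channels-last (NHWC) adaptive_avg_pool2d kernels. A sketch only (assumes a CUDA device; the tolerance is illustrative, not a value from the test suite) that drives both layouts in BFloat16:

import torch

pool = torch.nn.AdaptiveAvgPool2d((7, 7))
x = torch.randn(2, 8, 32, 32, device="cuda").to(torch.bfloat16)
x_nhwc = x.contiguous(memory_format=torch.channels_last)

out = pool(x)            # contiguous input: adaptive_average_pool
out_nhwc = pool(x_nhwc)  # channels_last input: should take the adaptive_average_pool_nhwc path above
torch.testing.assert_allclose(out_nhwc.float(), out.float(), atol=1e-2, rtol=0)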

View File

@@ -390,17 +390,15 @@ void adaptive_avg_pool3d_out_cuda_template(
   AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
       input.scalar_type(), "adaptive_avg_pool3d_cuda", [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool3d_cuda", [&] {
-          scalar_t* input_data = input.data_ptr<scalar_t>();
-          scalar_t* output_data = output.data_ptr<scalar_t>();
+        scalar_t* input_data = input.data_ptr<scalar_t>();
+        scalar_t* output_data = output.data_ptr<scalar_t>();

-          adaptiveaveragepool_loop(
-              input_data, output_data,
-              totalZ,
-              isizeT, isizeH, isizeW,
-              osizeT, osizeH, osizeW,
-              istrideD, istrideT, istrideH, istrideW);
-        });
+        adaptiveaveragepool_loop(
+            input_data, output_data,
+            totalZ,
+            isizeT, isizeH, isizeW,
+            osizeT, osizeH, osizeW,
+            istrideD, istrideT, istrideH, istrideW);
       });
 }
@@ -457,30 +455,26 @@ void adaptive_avg_pool3d_backward_out_cuda_template(
   if (atomic) {
     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
         input.scalar_type(), "adaptive_avg_pool3d_backward_cuda", [&] {
-          AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool3d_backward_cuda", [&] {
-            scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
-            scalar_t* gradOutput_data = gradOutput.data_ptr<scalar_t>();
+          scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
+          scalar_t* gradOutput_data = gradOutput.data_ptr<scalar_t>();

-            atomicadaptiveaveragegradinput_loop(
-                gradInput_data, gradOutput_data,
-                totalZ,
-                isizeT, isizeH, isizeW,
-                osizeT, osizeH, osizeW);
-          });
+          atomicadaptiveaveragegradinput_loop(
+              gradInput_data, gradOutput_data,
+              totalZ,
+              isizeT, isizeH, isizeW,
+              osizeT, osizeH, osizeW);
        });
   } else {
     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
         input.scalar_type(), "adaptive_avg_pool3d_backward_cuda", [&] {
-          AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool3d_backward_cuda", [&] {
-            scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
-            scalar_t* gradOutput_data = gradOutput.data_ptr<scalar_t>();
+          scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
+          scalar_t* gradOutput_data = gradOutput.data_ptr<scalar_t>();

-            adaptiveaveragegradinput_loop(
-                gradInput_data, gradOutput_data,
-                totalZ,
-                isizeT, isizeH, isizeW,
-                osizeT, osizeH, osizeW);
-          });
+          adaptiveaveragegradinput_loop(
+              gradInput_data, gradOutput_data,
+              totalZ,
+              isizeT, isizeH, isizeW,
+              osizeT, osizeH, osizeW);
        });
   }
 }
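For the 3-D variant above, the backward pass is dispatched to either an atomic or a plain grad-input loop; judging by the (commented-out) criterion visible in the 2-D file, the atomic path is roughly the case where the output size does not evenly divide the input size. A sketch that should touch both paths in BFloat16 (assumes a CUDA device; the path selection is my reading of the code, not stated in the PR):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8, 8, device="cuda", dtype=torch.bfloat16, requires_grad=True)
for osize in (4, 5):  # 8 % 4 == 0 (plain loop), 8 % 5 != 0 (presumably the atomic loop)
    out = F.adaptive_avg_pool3d(x, osize)
    out.sum().backward()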

View File

@@ -232,27 +232,25 @@ void adaptive_max_pool2d_out_cuda_template(
     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
       "adaptive_max_pool2d_cuda",
       [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool2d_cuda", [&] {
-          output.resize_({sizeD, osizeH, osizeW});
-          indices.resize_({sizeD, osizeH, osizeW});
+        output.resize_({sizeD, osizeH, osizeW});
+        indices.resize_({sizeD, osizeH, osizeW});

-          scalar_t *input_data = input.data_ptr<scalar_t>();
-          scalar_t *output_data = output.data_ptr<scalar_t>();
-          int64_t *indices_data = indices.data_ptr<int64_t>();
+        scalar_t *input_data = input.data_ptr<scalar_t>();
+        scalar_t *output_data = output.data_ptr<scalar_t>();
+        int64_t *indices_data = indices.data_ptr<int64_t>();

-          // cuda blocks & threads:
-          int blocksH = (int)(16L / sizeD);
-          blocksH = blocksH < 1 ? 1 : blocksH;
-          dim3 blocks(sizeD, blocksH);
-          dim3 threads(32, 8);
+        // cuda blocks & threads:
+        int blocksH = (int)(16L / sizeD);
+        blocksH = blocksH < 1 ? 1 : blocksH;
+        dim3 blocks(sizeD, blocksH);
+        dim3 threads(32, 8);

-          // run maxpool kernel
-          adaptivemaxpool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
-            input_data, output_data,
-            indices_data,
-            isizeH, isizeW, osizeH, osizeW,
-            istrideD, istrideH, istrideW);
-        });
+        // run maxpool kernel
+        adaptivemaxpool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
+          input_data, output_data,
+          indices_data,
+          isizeH, isizeW, osizeH, osizeW,
+          istrideD, istrideH, istrideW);
       }
     );
     AT_CUDA_CHECK(cudaGetLastError());
@@ -271,27 +269,25 @@ void adaptive_max_pool2d_out_cuda_template(
     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input_.scalar_type(),
       "adaptive_max_pool2d_cuda",
       [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool2d_cuda", [&] {
-          output.resize_({sizeB, sizeD, osizeH, osizeW});
-          indices.resize_({sizeB, sizeD, osizeH, osizeW});
+        output.resize_({sizeB, sizeD, osizeH, osizeW});
+        indices.resize_({sizeB, sizeD, osizeH, osizeW});

-          scalar_t *input_data = input_.data_ptr<scalar_t>();
-          scalar_t *output_data = output.data_ptr<scalar_t>();
-          int64_t *indices_data = indices.data_ptr<int64_t>();
+        scalar_t *input_data = input_.data_ptr<scalar_t>();
+        scalar_t *output_data = output.data_ptr<scalar_t>();
+        int64_t *indices_data = indices.data_ptr<int64_t>();

-          // cuda blocks & threads:
-          int blocksH = (int)(16L / sizeD);
-          blocksH = blocksH < 1 ? 1 : blocksH;
-          dim3 blocks(sizeB*sizeD, blocksH);
-          dim3 threads(32, 8);
+        // cuda blocks & threads:
+        int blocksH = (int)(16L / sizeD);
+        blocksH = blocksH < 1 ? 1 : blocksH;
+        dim3 blocks(sizeB*sizeD, blocksH);
+        dim3 threads(32, 8);

-          // run maxpool kernel
-          adaptivemaxpool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
-            input_data, output_data,
-            indices_data,
-            isizeH, isizeW, osizeH, osizeW,
-            istrideD, istrideH, istrideW);
-        });
+        // run maxpool kernel
+        adaptivemaxpool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
+          input_data, output_data,
+          indices_data,
+          isizeH, isizeW, osizeH, osizeW,
+          istrideD, istrideH, istrideW);
       }
     );
     AT_CUDA_CHECK(cudaGetLastError());
@@ -333,34 +329,32 @@ void adaptive_max_pool2d_backward_out_cuda_template(
     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
       "adaptive_max_pool2d_backward_cuda",
       [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool2d_backward_cuda", [&] {
-          scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
-          scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
-          int64_t *indices_data = indices.data_ptr<int64_t>();
+        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
+        scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
+        int64_t *indices_data = indices.data_ptr<int64_t>();

-          // cuda blocks & threads:
-          int blocksH = (int)(16L / sizeD);
-          blocksH = blocksH < 1 ? 1 : blocksH;
-          dim3 blocks(sizeD, blocksH);
-          dim3 threads(32, 8);
+        // cuda blocks & threads:
+        int blocksH = (int)(16L / sizeD);
+        blocksH = blocksH < 1 ? 1 : blocksH;
+        dim3 blocks(sizeD, blocksH);
+        dim3 threads(32, 8);

-          if(atomic)
-          {
-            // run updateGradInput kernel, accumulate gradients atomically
-            atomicadaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
-              gradInput_data, gradOutput_data,
-              indices_data,
-              isizeH, isizeW, osizeH, osizeW);
-          }
-          else
-          {
-            // run updateGradInput kernel
-            atomicadaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
-              gradInput_data, gradOutput_data,
-              indices_data,
-              isizeH, isizeW, osizeH, osizeW);
-          }
-        });
+        if(atomic)
+        {
+          // run updateGradInput kernel, accumulate gradients atomically
+          atomicadaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
+            gradInput_data, gradOutput_data,
+            indices_data,
+            isizeH, isizeW, osizeH, osizeW);
+        }
+        else
+        {
+          // run updateGradInput kernel
+          atomicadaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
+            gradInput_data, gradOutput_data,
+            indices_data,
+            isizeH, isizeW, osizeH, osizeW);
+        }
       }
     );
     AT_CUDA_CHECK(cudaGetLastError());
@@ -381,34 +375,32 @@ void adaptive_max_pool2d_backward_out_cuda_template(
     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
       "adaptive_max_pool2d_backward_cuda",
       [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool2d_backward_cuda", [&] {
-          scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
-          scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
-          int64_t *indices_data = indices.data_ptr<int64_t>();
+        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
+        scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
+        int64_t *indices_data = indices.data_ptr<int64_t>();

-          // cuda blocks & threads:
-          int blocksH = (int)(16L / sizeD);
-          blocksH = blocksH < 1 ? 1 : blocksH;
-          dim3 blocks(sizeB*sizeD, blocksH);
-          dim3 threads(32, 8);
+        // cuda blocks & threads:
+        int blocksH = (int)(16L / sizeD);
+        blocksH = blocksH < 1 ? 1 : blocksH;
+        dim3 blocks(sizeB*sizeD, blocksH);
+        dim3 threads(32, 8);

-          if(atomic)
-          {
-            // run updateGradInput kernel, accumulate gradients atomically
-            atomicadaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
-              gradInput_data, gradOutput_data,
-              indices_data,
-              isizeH, isizeW, osizeH, osizeW);
-          }
-          else
-          {
-            // run updateGradInput kernel, accumulate gradients atomically
-            adaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
-              gradInput_data, gradOutput_data,
-              indices_data,
-              isizeH, isizeW, osizeH, osizeW);
-          }
-        });
+        if(atomic)
+        {
+          // run updateGradInput kernel, accumulate gradients atomically
+          atomicadaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
+            gradInput_data, gradOutput_data,
+            indices_data,
+            isizeH, isizeW, osizeH, osizeW);
+        }
+        else
+        {
+          // run updateGradInput kernel, accumulate gradients atomically
+          adaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
+            gradInput_data, gradOutput_data,
+            indices_data,
+            isizeH, isizeW, osizeH, osizeW);
+        }
       }
     );
     AT_CUDA_CHECK(cudaGetLastError());
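adaptive_max_pool2d returns BFloat16 values alongside int64 indices (note indices.data_ptr<int64_t>() in the hunks above), and its backward scatters gradients through those indices. A small sketch (not from the PR; assumes a CUDA device):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 16, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out, idx = F.adaptive_max_pool2d(x, (7, 7), return_indices=True)
assert out.dtype == torch.bfloat16 and idx.dtype == torch.int64
out.sum().backward()  # exercises the (atomic)adaptivemaxgradinput kernels above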

View File

@@ -366,15 +366,13 @@ void adaptive_max_pool3d_out_cuda_template(
   AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
     "adaptive_max_pool3d_cuda",
     [&] {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool3d_cuda", [&] {
-        scalar_t *input_data = input.data_ptr<scalar_t>();
-        scalar_t *output_data = output.data_ptr<scalar_t>();
-        int64_t *indices_data = indices.data_ptr<int64_t>();
+      scalar_t *input_data = input.data_ptr<scalar_t>();
+      scalar_t *output_data = output.data_ptr<scalar_t>();
+      int64_t *indices_data = indices.data_ptr<int64_t>();

-        adaptivemaxpool_loop(
-          input_data, output_data, indices_data, totalZ, isizeT, isizeH, isizeW,
-          osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW);
-      });
+      adaptivemaxpool_loop(
+        input_data, output_data, indices_data, totalZ, isizeT, isizeH, isizeW,
+        osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW);
     }
   );
 }
@@ -435,32 +433,28 @@ void adaptive_max_pool3d_backward_out_cuda_template(
   AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
     "adaptive_max_pool3d_backward_cuda",
     [&] {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool3d_backward_cuda", [&] {
-        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
-        scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
-        int64_t *indices_data = indices.data_ptr<int64_t>();
+      scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
+      scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
+      int64_t *indices_data = indices.data_ptr<int64_t>();

-        atomicadaptivemaxgradinput_loop(
-          gradInput_data, gradOutput_data, indices_data,
-          totalZ,
-          isizeT, isizeH, isizeW, osizeT, osizeH, osizeW);
-      });
+      atomicadaptivemaxgradinput_loop(
+        gradInput_data, gradOutput_data, indices_data,
+        totalZ,
+        isizeT, isizeH, isizeW, osizeT, osizeH, osizeW);
     }
   );
 } else {
   AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
     "adaptive_max_pool3d_backward_cuda",
     [&] {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool3d_backward_cuda", [&] {
-        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
-        scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
-        int64_t *indices_data = indices.data_ptr<int64_t>();
+      scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
+      scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
+      int64_t *indices_data = indices.data_ptr<int64_t>();

-        adaptivemaxgradinput_loop(
-          gradInput_data, gradOutput_data, indices_data,
-          totalZ,
-          isizeT, isizeH, isizeW, osizeT, osizeH, osizeW);
-      });
+      adaptivemaxgradinput_loop(
+        gradInput_data, gradOutput_data, indices_data,
+        totalZ,
+        isizeT, isizeH, isizeW, osizeT, osizeH, osizeW);
     }
   );
 }

View File

@@ -303,33 +303,15 @@ void avg_pool2d_out_cuda_template(
     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
       "avg_pool2d_out_cuda_frame",
       [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "avg_pool2d_out_cuda_frame", [&] {
-          using accscalar_t = acc_type<scalar_t, true>;
+        using accscalar_t = acc_type<scalar_t, true>;

-          scalar_t *output_data = output.data_ptr<scalar_t>();
-          scalar_t *input_data = input.data_ptr<scalar_t>();
+        scalar_t *output_data = output.data_ptr<scalar_t>();
+        scalar_t *input_data = input.data_ptr<scalar_t>();

-          switch (memory_format){
-            case MemoryFormat::ChannelsLast: {
-              output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
-              avg_pool2d_out_cuda_frame_nhwc<scalar_t, accscalar_t>
-                <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  count,
-                  input_data,
-                  nbatch,
-                  nInputPlane,
-                  inputHeight, inputWidth,
-                  outputHeight, outputWidth,
-                  kH, kW,
-                  dH, dW,
-                  padH, padW,
-                  output_data,
-                  divisor_override_value,
-                  count_include_pad, use_divisor);
-              break;
-            }
-            case MemoryFormat::Contiguous: {
-              avg_pool2d_out_cuda_frame<scalar_t, accscalar_t>
+        switch (memory_format){
+          case MemoryFormat::ChannelsLast: {
+            output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
+            avg_pool2d_out_cuda_frame_nhwc<scalar_t, accscalar_t>
+              <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+                count,
+                input_data,
@@ -343,11 +325,27 @@ void avg_pool2d_out_cuda_template(
                 output_data,
                 divisor_override_value,
                 count_include_pad, use_divisor);
-              break;
-            }
-            default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
-              break;
-          }
-        });
+            break;
+          }
+          case MemoryFormat::Contiguous: {
+            avg_pool2d_out_cuda_frame<scalar_t, accscalar_t>
+              <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+                count,
+                input_data,
+                nbatch,
+                nInputPlane,
+                inputHeight, inputWidth,
+                outputHeight, outputWidth,
+                kH, kW,
+                dH, dW,
+                padH, padW,
+                output_data,
+                divisor_override_value,
+                count_include_pad, use_divisor);
+            break;
+          }
+          default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+        }
       }
     );
@@ -437,51 +435,49 @@ Tensor& avg_pool2d_backward_out_cuda_template(
     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
       "avg_pool2d_backward_out_cuda_frame",
      [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "avg_pool2d_backward_out_cuda_frame", [&] {
-          using accscalar_t = acc_type<scalar_t, true>;
+        using accscalar_t = acc_type<scalar_t, true>;

-          scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
-          scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
+        scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
+        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();

-          switch (memory_format) {
-            case MemoryFormat::ChannelsLast: {
-              gradInput.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
-              avg_pool2d_backward_out_cuda_frame_nhwc<scalar_t, accscalar_t>
-                <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  count,
-                  gradOutput_data,
-                  nbatch,
-                  nInputPlane,
-                  inputHeight, inputWidth,
-                  outputHeight, outputWidth,
-                  kH, kW,
-                  dH, dW,
-                  padH, padW,
-                  gradInput_data,
-                  divisor_override_value,
-                  count_include_pad, use_divisor);
-              break;
-            }
-            case MemoryFormat::Contiguous: {
-              avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t>
-                <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  count,
-                  gradOutput_data,
-                  nbatch,
-                  nInputPlane,
-                  inputHeight, inputWidth,
-                  outputHeight, outputWidth,
-                  kH, kW,
-                  dH, dW,
-                  padH, padW,
-                  gradInput_data,
-                  divisor_override_value,
-                  count_include_pad, use_divisor);
-              break;
-            }
-            default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
-          }
-        });
+        switch (memory_format) {
+          case MemoryFormat::ChannelsLast: {
+            gradInput.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
+            avg_pool2d_backward_out_cuda_frame_nhwc<scalar_t, accscalar_t>
+              <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+                count,
+                gradOutput_data,
+                nbatch,
+                nInputPlane,
+                inputHeight, inputWidth,
+                outputHeight, outputWidth,
+                kH, kW,
+                dH, dW,
+                padH, padW,
+                gradInput_data,
+                divisor_override_value,
+                count_include_pad, use_divisor);
+            break;
+          }
+          case MemoryFormat::Contiguous: {
+            avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t>
+              <<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+                count,
+                gradOutput_data,
+                nbatch,
+                nInputPlane,
+                inputHeight, inputWidth,
+                outputHeight, outputWidth,
+                kH, kW,
+                dH, dW,
+                padH, padW,
+                gradInput_data,
+                divisor_override_value,
+                count_include_pad, use_divisor);
+            break;
+          }
+          default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
+        }
       }
    );
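Both avg_pool2d hunks switch on memory_format and reject anything other than Contiguous or ChannelsLast, while count_include_pad and divisor_override are threaded through to the kernels. A sketch covering both supported formats in BFloat16 (not from the PR; assumes a CUDA device):

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 16, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
for fmt in (torch.contiguous_format, torch.channels_last):
    inp = x.contiguous(memory_format=fmt)
    out = F.avg_pool2d(inp, kernel_size=3, stride=2, padding=1,
                       count_include_pad=False, divisor_override=4)
    out.sum().backward()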

View File

@@ -415,44 +415,42 @@ void avg_pool3d_out_cuda_template(
     input.scalar_type(),
     "avg_pool3d_out_cuda",
     [&] {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "avg_pool3d_out_cuda", [&] {
-        using accscalar_t = acc_type<scalar_t, true>;
-        int64_t totalZ = otime * nslices * nbatch;
-        int64_t offsetZ = 0;
-        dim3 block(32, 8);
+      using accscalar_t = acc_type<scalar_t, true>;
+      int64_t totalZ = otime * nslices * nbatch;
+      int64_t offsetZ = 0;
+      dim3 block(32, 8);

-        while (totalZ > 0) {
-          dim3 grid(cuda::ATenCeilDiv(owidth, static_cast<int64_t>(block.x)),
-                    cuda::ATenCeilDiv(oheight, static_cast<int64_t>(block.y)),
-                    totalZ > 65535 ? 65535 : totalZ);
+      while (totalZ > 0) {
+        dim3 grid(cuda::ATenCeilDiv(owidth, static_cast<int64_t>(block.x)),
+                  cuda::ATenCeilDiv(oheight, static_cast<int64_t>(block.y)),
+                  totalZ > 65535 ? 65535 : totalZ);

-          switch (kW) {
-            LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1);
-            LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(2);
-            LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(3);
-            LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(4);
-            LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(5);
-            LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(6);
-            LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(7);
-            default:
-              avg_pool3d_cuda_update_output<scalar_t, accscalar_t>
-                <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  work_input.packed_accessor64<scalar_t, 4>(),
-                  work_output.packed_accessor64<scalar_t, 4>(),
-                  kT, kH, kW,
-                  dT, dH, dW,
-                  padT, padH, padW,
-                  count_include_pad,
-                  offsetZ, divisor);
-              break;
-          }
-          AT_CUDA_CHECK(cudaGetLastError());
-          totalZ -= 65535;
-          offsetZ += 65535;
-        }
-      });
+        switch (kW) {
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1);
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(2);
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(3);
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(4);
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(5);
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(6);
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(7);
+          default:
+            avg_pool3d_cuda_update_output<scalar_t, accscalar_t>
+              <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+                work_input.packed_accessor64<scalar_t, 4>(),
+                work_output.packed_accessor64<scalar_t, 4>(),
+                kT, kH, kW,
+                dT, dH, dW,
+                padT, padH, padW,
+                count_include_pad,
+                offsetZ, divisor);
+            break;
+        }
+        AT_CUDA_CHECK(cudaGetLastError());
+        totalZ -= 65535;
+        offsetZ += 65535;
+      }
     }
   );
 }
@@ -559,38 +557,36 @@ void avg_pool3d_backward_out_cuda_template(
     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
       "avg_pool3d_backward_out_frame_stride1",
       [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "avg_pool3d_backward_out_frame_stride1", [&] {
-          using accscalar_t = acc_type<scalar_t, true>;
-          int64_t totalZ = itime * nslices * nbatch;
-          int64_t offsetZ = 0;
-          dim3 block(32, 8);
+        using accscalar_t = acc_type<scalar_t, true>;
+        int64_t totalZ = itime * nslices * nbatch;
+        int64_t offsetZ = 0;
+        dim3 block(32, 8);

-          accscalar_t divide_factor;
-          if (divisor) {
-            divide_factor = static_cast<accscalar_t>(divisor);
-          } else {
-            divide_factor = static_cast<accscalar_t>(kT * kH * kW);
-          }
+        accscalar_t divide_factor;
+        if (divisor) {
+          divide_factor = static_cast<accscalar_t>(divisor);
+        } else {
+          divide_factor = static_cast<accscalar_t>(kT * kH * kW);
+        }

-          while (totalZ > 0) {
-            dim3 grid(cuda::ATenCeilDiv(iwidth, static_cast<int64_t>(block.x)),
-                      cuda::ATenCeilDiv(iheight, static_cast<int64_t>(block.y)),
-                      totalZ > 65535 ? 65535 : totalZ);
+        while (totalZ > 0) {
+          dim3 grid(cuda::ATenCeilDiv(iwidth, static_cast<int64_t>(block.x)),
+                    cuda::ATenCeilDiv(iheight, static_cast<int64_t>(block.y)),
+                    totalZ > 65535 ? 65535 : totalZ);

-            avg_pool3d_single_backward_out_frame_stride1<scalar_t, accscalar_t>
-              <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
-                work_grad_output.packed_accessor64<scalar_t, 4>(),
-                work_grad_input.packed_accessor64<scalar_t, 4>(),
-                kT, kH, kW,
-                1.0f/divide_factor,
-                offsetZ);
+          avg_pool3d_single_backward_out_frame_stride1<scalar_t, accscalar_t>
+            <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+              work_grad_output.packed_accessor64<scalar_t, 4>(),
+              work_grad_input.packed_accessor64<scalar_t, 4>(),
+              kT, kH, kW,
+              1.0f/divide_factor,
+              offsetZ);

-            AT_CUDA_CHECK(cudaGetLastError());
+          AT_CUDA_CHECK(cudaGetLastError());

-            totalZ -= 65535;
-            offsetZ += 65535;
-          }
-        });
+          totalZ -= 65535;
+          offsetZ += 65535;
+        }
       }
   );
 }
@@ -598,46 +594,44 @@ void avg_pool3d_backward_out_cuda_template(
     AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
       "avg_pool3d_backward_out_frame",
       [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "avg_pool3d_backward_out_frame", [&] {
-          using accscalar_t = acc_type<scalar_t, true>;
-          int64_t totalZ = otime * nslices * nbatch;
-          int64_t offsetZ = 0;
-          dim3 block(32, 8);
+        using accscalar_t = acc_type<scalar_t, true>;
+        int64_t totalZ = otime * nslices * nbatch;
+        int64_t offsetZ = 0;
+        dim3 block(32, 8);

-          while (totalZ > 0) {
-            dim3 grid(cuda::ATenCeilDiv(owidth, static_cast<int64_t>(block.x)),
-                      cuda::ATenCeilDiv(oheight, static_cast<int64_t>(block.y)),
-                      totalZ > 65535 ? 65535 : totalZ);
+        while (totalZ > 0) {
+          dim3 grid(cuda::ATenCeilDiv(owidth, static_cast<int64_t>(block.x)),
+                    cuda::ATenCeilDiv(oheight, static_cast<int64_t>(block.y)),
+                    totalZ > 65535 ? 65535 : totalZ);

-            if (kernelsOverlap) {
-              avg_pool3d_cuda_update_grad_input_atomic<scalar_t, accscalar_t>
-                <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  work_grad_output.packed_accessor64<scalar_t, 4>(),
-                  work_grad_input.packed_accessor64<scalar_t, 4>(),
-                  kT, kH, kW,
-                  dT, dH, dW,
-                  padT, padH, padW,
-                  count_include_pad,
-                  offsetZ, divisor);
-            }
-            else {
-              avg_pool3d_cuda_update_grad_input<scalar_t, accscalar_t>
-                <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  work_grad_output.packed_accessor64<scalar_t, 4>(),
-                  work_grad_input.packed_accessor64<scalar_t, 4>(),
-                  kT, kH, kW,
-                  dT, dH, dW,
-                  padT, padH, padW,
-                  count_include_pad,
-                  offsetZ, divisor);
-            }
-            AT_CUDA_CHECK(cudaGetLastError());
-            totalZ -= 65535;
-            offsetZ += 65535;
-          }
-        });
+          if (kernelsOverlap) {
+            avg_pool3d_cuda_update_grad_input_atomic<scalar_t, accscalar_t>
+              <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+                work_grad_output.packed_accessor64<scalar_t, 4>(),
+                work_grad_input.packed_accessor64<scalar_t, 4>(),
+                kT, kH, kW,
+                dT, dH, dW,
+                padT, padH, padW,
+                count_include_pad,
+                offsetZ, divisor);
+          }
+          else {
+            avg_pool3d_cuda_update_grad_input<scalar_t, accscalar_t>
+              <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+                work_grad_output.packed_accessor64<scalar_t, 4>(),
+                work_grad_input.packed_accessor64<scalar_t, 4>(),
+                kT, kH, kW,
+                dT, dH, dW,
+                padT, padH, padW,
+                count_include_pad,
+                offsetZ, divisor);
+          }
+          AT_CUDA_CHECK(cudaGetLastError());
+          totalZ -= 65535;
+          offsetZ += 65535;
+        }
       }
   );
 }
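The avg_pool3d backward above has three launch paths: a specialized stride-1 kernel, an atomic kernel for overlapping windows (kernelsOverlap), and a plain kernel otherwise. The sketch below should hit the different cases in BFloat16; the mapping of strides to kernels is my reading of the code, not stated in the PR (assumes a CUDA device):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 16, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
for stride in (3, 2, 1):  # non-overlapping, overlapping (atomic), and presumably the stride-1 special case
    out = F.avg_pool3d(x, kernel_size=3, stride=stride, count_include_pad=False)
    out.sum().backward()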

View File

@@ -11842,7 +11842,6 @@ class TestNNDeviceType(NNTestCase):
         self._test_bfloat16_ops(torch.nn.LeakyReLU(), device, inp_dims=(5), prec=1e-2)

     @onlyCUDA
-    @skipCUDAIfNotRocm
     def test_pooling_bfloat16(self, device):
         self._test_bfloat16_ops(torch.nn.AvgPool1d(3, stride=2), device, inp_dims=(8, 4, 16), prec=0.05)
         self._test_bfloat16_ops(torch.nn.AvgPool2d(3, stride=2), device, inp_dims=(8, 4, 16, 16), prec=0.05)
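test_pooling_bfloat16 relies on the _test_bfloat16_ops helper defined elsewhere in test_nn.py, which is not shown in this diff. Roughly, a helper like that compares a BFloat16 forward/backward pass against a float32 reference within prec; a hypothetical standalone version might look like the following (names and tolerances are illustrative, not the test suite's):

import torch

def compare_bfloat16_to_float32(op, device, inp_dims, prec):
    inp = torch.randn(inp_dims, device=device, requires_grad=True)
    inp_bf16 = inp.detach().to(torch.bfloat16).requires_grad_(True)

    out = op(inp)
    out_bf16 = op(inp_bf16)
    out.sum().backward()
    out_bf16.sum().backward()

    torch.testing.assert_allclose(out_bf16.float(), out, atol=prec, rtol=0)
    torch.testing.assert_allclose(inp_bf16.grad.float(), inp.grad, atol=prec, rtol=0)

# e.g. compare_bfloat16_to_float32(torch.nn.AvgPool2d(3, stride=2), "cuda", (8, 4, 16, 16), 0.05)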