CUDA BFloat Pooling (#44836)
Summary: Enables BFloat16 for the CUDA pooling kernels. The AT_SKIP_BFLOAT16_IF_NOT_ROCM guard (and its closing "});" line) is removed from each AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, ...) lambda in the adaptive and non-adaptive average/max pooling kernels, so the BFloat16 instantiations that previously ran only under ROCm now also build and run on CUDA, and the corresponding test drops its ROCm-only skip.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44836
Reviewed By: mruberry
Differential Revision: D23800992
Pulled By: ngimel
fbshipit-source-id: 2945a27874345197cbd1d8a4fbd20816afc02c86
This commit is contained in:
parent
7ecfaef7ec
commit
faef89c89f
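
As a usage illustration only (Python; not part of the commit, and the shapes and sizes are made up), the operators touched by the kernels below can then be called directly on bfloat16 CUDA tensors:

    import torch
    import torch.nn.functional as F

    # Illustrative only: any CUDA device with a build that contains this change.
    x = torch.randn(8, 4, 16, 16, device="cuda", dtype=torch.bfloat16)

    out_avg = F.avg_pool2d(x, kernel_size=3, stride=2)
    out_adaptive = F.adaptive_avg_pool2d(x, output_size=(4, 4))
    out_max, idx = F.adaptive_max_pool2d(x, output_size=(4, 4), return_indices=True)

    print(out_avg.dtype, out_adaptive.dtype, out_max.dtype)  # torch.bfloat16 for all three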
@@ -512,17 +512,15 @@ namespace {
AT_ASSERT(input_.numel() < std::numeric_limits<int32_t>::max());
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input_.scalar_type(), "adaptive_avg_pool2d_nhwc_cuda", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool2d_nhwc_cuda", [&] {
size_t shmem_size = (kernel_size_C * block_x * block_y * block_z) * sizeof(scalar_t);
AT_ASSERT(shmem_size <= sharedMemPerBlock);
adaptive_average_pool_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
input_.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(),
sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
kernel_stride_C, kernel_size_C,
istrideB, istrideC, istrideH, istrideW);
});
}
size_t shmem_size = (kernel_size_C * block_x * block_y * block_z) * sizeof(scalar_t);
AT_ASSERT(shmem_size <= sharedMemPerBlock);
adaptive_average_pool_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
input_.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(),
sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
kernel_stride_C, kernel_size_C,
istrideB, istrideC, istrideH, istrideW);
}
);
break;
}

@@ -551,22 +549,20 @@ namespace {
}
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input_.scalar_type(), "adaptive_avg_pool2d_cuda", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool2d_cuda", [&] {
scalar_t *input_data = input_.data_ptr<scalar_t>();
scalar_t *output_data = output.data_ptr<scalar_t>();
scalar_t *input_data = input_.data_ptr<scalar_t>();
scalar_t *output_data = output.data_ptr<scalar_t>();
// cuda blocks & threads:
int blocksH = std::max<int64_t>((int)(16L / sizeD), 1);
dim3 blocks(grid_x, blocksH);
dim3 threads(32, 8);
// cuda blocks & threads:
int blocksH = std::max<int64_t>((int)(16L / sizeD), 1);
dim3 blocks(grid_x, blocksH);
dim3 threads(32, 8);
// run averagepool kernel
adaptive_average_pool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
input_data, output_data,
isizeH, isizeW, osizeH, osizeW,
istrideD, istrideH, istrideW);
});
}
// run averagepool kernel
adaptive_average_pool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
input_data, output_data,
isizeH, isizeW, osizeH, osizeW,
istrideD, istrideH, istrideW);
}
);
break;
}

@@ -661,17 +657,15 @@ namespace {
AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "adaptive_avg_pool2d_backward_nhwc_cuda", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool2d_backward_nhwc_cuda", [&] {
size_t shmem_size = (kernel_size_C * block_x * block_y * block_z + osizeH + osizeW) * sizeof(scalar_t) + 2 * isizeW * sizeof(int32_t);
AT_ASSERT(shmem_size <= sharedMemPerBlock);
adaptive_average_gradinput_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
gradInput.data_ptr<scalar_t>(),
gradOutput.data_ptr<scalar_t>(),
sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
kernel_stride_C, kernel_size_C,
ostrideB, ostrideC, ostrideH, ostrideW);
});
}
size_t shmem_size = (kernel_size_C * block_x * block_y * block_z + osizeH + osizeW) * sizeof(scalar_t) + 2 * isizeW * sizeof(int32_t);
AT_ASSERT(shmem_size <= sharedMemPerBlock);
adaptive_average_gradinput_nhwc<int32_t><<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>> (
gradInput.data_ptr<scalar_t>(),
gradOutput.data_ptr<scalar_t>(),
sizeB, sizeC, isizeH, isizeW, osizeH, osizeW,
kernel_stride_C, kernel_size_C,
ostrideB, ostrideC, ostrideH, ostrideW);
}
);
break;
}

@@ -693,30 +687,28 @@ namespace {
//bool atomic = (isizeW%osizeW != 0) || (isizeH%osizeH != 0);
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "adaptive_avg_pool2d_backward_cuda", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool2d_backward_cuda", [&] {
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
// cuda blocks & threads:
int blocksH = std::max((int)(16L / sizeD), 1);
dim3 blocks(grid_x, blocksH);
dim3 threads(32, 8);
// cuda blocks & threads:
int blocksH = std::max((int)(16L / sizeD), 1);
dim3 blocks(grid_x, blocksH);
dim3 threads(32, 8);
if(atomic)
{
// run updateGradInput kernel, accumulate gradients atomically
atomic_adaptive_average_gradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
isizeH, isizeW, osizeH, osizeW);
}
else
{
// run updateGradInput kernel
adaptive_average_gradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
isizeH, isizeW, osizeH, osizeW);
}
});
if(atomic)
{
// run updateGradInput kernel, accumulate gradients atomically
atomic_adaptive_average_gradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
isizeH, isizeW, osizeH, osizeW);
}
else
{
// run updateGradInput kernel
adaptive_average_gradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
isizeH, isizeW, osizeH, osizeW);
}
}
);
break;

@@ -390,17 +390,15 @@ void adaptive_avg_pool3d_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "adaptive_avg_pool3d_cuda", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool3d_cuda", [&] {
scalar_t* input_data = input.data_ptr<scalar_t>();
scalar_t* output_data = output.data_ptr<scalar_t>();
scalar_t* input_data = input.data_ptr<scalar_t>();
scalar_t* output_data = output.data_ptr<scalar_t>();
adaptiveaveragepool_loop(
input_data, output_data,
totalZ,
isizeT, isizeH, isizeW,
osizeT, osizeH, osizeW,
istrideD, istrideT, istrideH, istrideW);
});
adaptiveaveragepool_loop(
input_data, output_data,
totalZ,
isizeT, isizeH, isizeW,
osizeT, osizeH, osizeW,
istrideD, istrideT, istrideH, istrideW);
});
}

@@ -457,30 +455,26 @@ void adaptive_avg_pool3d_backward_out_cuda_template(
if (atomic) {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "adaptive_avg_pool3d_backward_cuda", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool3d_backward_cuda", [&] {
scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t* gradOutput_data = gradOutput.data_ptr<scalar_t>();
scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t* gradOutput_data = gradOutput.data_ptr<scalar_t>();
atomicadaptiveaveragegradinput_loop(
gradInput_data, gradOutput_data,
totalZ,
isizeT, isizeH, isizeW,
osizeT, osizeH, osizeW);
});
atomicadaptiveaveragegradinput_loop(
gradInput_data, gradOutput_data,
totalZ,
isizeT, isizeH, isizeW,
osizeT, osizeH, osizeW);
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "adaptive_avg_pool3d_backward_cuda", [&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_avg_pool3d_backward_cuda", [&] {
scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t* gradOutput_data = gradOutput.data_ptr<scalar_t>();
scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t* gradOutput_data = gradOutput.data_ptr<scalar_t>();
adaptiveaveragegradinput_loop(
gradInput_data, gradOutput_data,
totalZ,
isizeT, isizeH, isizeW,
osizeT, osizeH, osizeW);
});
adaptiveaveragegradinput_loop(
gradInput_data, gradOutput_data,
totalZ,
isizeT, isizeH, isizeW,
osizeT, osizeH, osizeW);
});
}
}
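
A small sketch of how the 3d adaptive average pooling forward and backward kernels above get exercised from Python (illustrative only; not part of the diff):

    import torch
    import torch.nn.functional as F

    # Forward and backward go through the adaptiveaveragepool_loop and
    # *adaptiveaveragegradinput_loop helpers shown in the hunks above.
    x = torch.randn(2, 3, 8, 16, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    y = F.adaptive_avg_pool3d(x, output_size=(4, 4, 4))
    y.sum().backward()
    print(y.dtype, x.grad.dtype)  # both torch.bfloat16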
@@ -232,27 +232,25 @@ void adaptive_max_pool2d_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"adaptive_max_pool2d_cuda",
[&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool2d_cuda", [&] {
output.resize_({sizeD, osizeH, osizeW});
indices.resize_({sizeD, osizeH, osizeW});
output.resize_({sizeD, osizeH, osizeW});
indices.resize_({sizeD, osizeH, osizeW});
scalar_t *input_data = input.data_ptr<scalar_t>();
scalar_t *output_data = output.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
scalar_t *input_data = input.data_ptr<scalar_t>();
scalar_t *output_data = output.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
// cuda blocks & threads:
int blocksH = (int)(16L / sizeD);
blocksH = blocksH < 1 ? 1 : blocksH;
dim3 blocks(sizeD, blocksH);
dim3 threads(32, 8);
// cuda blocks & threads:
int blocksH = (int)(16L / sizeD);
blocksH = blocksH < 1 ? 1 : blocksH;
dim3 blocks(sizeD, blocksH);
dim3 threads(32, 8);
// run maxpool kernel
adaptivemaxpool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
input_data, output_data,
indices_data,
isizeH, isizeW, osizeH, osizeW,
istrideD, istrideH, istrideW);
});
// run maxpool kernel
adaptivemaxpool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
input_data, output_data,
indices_data,
isizeH, isizeW, osizeH, osizeW,
istrideD, istrideH, istrideW);
}
);
AT_CUDA_CHECK(cudaGetLastError());

@@ -271,27 +269,25 @@ void adaptive_max_pool2d_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input_.scalar_type(),
"adaptive_max_pool2d_cuda",
[&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool2d_cuda", [&] {
output.resize_({sizeB, sizeD, osizeH, osizeW});
indices.resize_({sizeB, sizeD, osizeH, osizeW});
output.resize_({sizeB, sizeD, osizeH, osizeW});
indices.resize_({sizeB, sizeD, osizeH, osizeW});
scalar_t *input_data = input_.data_ptr<scalar_t>();
scalar_t *output_data = output.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
scalar_t *input_data = input_.data_ptr<scalar_t>();
scalar_t *output_data = output.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
// cuda blocks & threads:
int blocksH = (int)(16L / sizeD);
blocksH = blocksH < 1 ? 1 : blocksH;
dim3 blocks(sizeB*sizeD, blocksH);
dim3 threads(32, 8);
// cuda blocks & threads:
int blocksH = (int)(16L / sizeD);
blocksH = blocksH < 1 ? 1 : blocksH;
dim3 blocks(sizeB*sizeD, blocksH);
dim3 threads(32, 8);
// run maxpool kernel
adaptivemaxpool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
input_data, output_data,
indices_data,
isizeH, isizeW, osizeH, osizeW,
istrideD, istrideH, istrideW);
});
// run maxpool kernel
adaptivemaxpool <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
input_data, output_data,
indices_data,
isizeH, isizeW, osizeH, osizeW,
istrideD, istrideH, istrideW);
}
);
AT_CUDA_CHECK(cudaGetLastError());
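
For the adaptive max pooling dispatches above, only the values follow the input dtype; the returned indices stay int64 (the kernels take an int64_t* indices buffer). A hedged sketch, not from the commit:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 32, 32, device="cuda").bfloat16()
    out, indices = F.adaptive_max_pool2d(x, output_size=(8, 8), return_indices=True)

    assert out.dtype == torch.bfloat16
    assert indices.dtype == torch.int64  # indices buffer is int64_t in the kernels above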
@@ -333,34 +329,32 @@ void adaptive_max_pool2d_backward_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"adaptive_max_pool2d_backward_cuda",
[&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool2d_backward_cuda", [&] {
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
// cuda blocks & threads:
int blocksH = (int)(16L / sizeD);
blocksH = blocksH < 1 ? 1 : blocksH;
dim3 blocks(sizeD, blocksH);
dim3 threads(32, 8);
// cuda blocks & threads:
int blocksH = (int)(16L / sizeD);
blocksH = blocksH < 1 ? 1 : blocksH;
dim3 blocks(sizeD, blocksH);
dim3 threads(32, 8);
if(atomic)
{
// run updateGradInput kernel, accumulate gradients atomically
atomicadaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
indices_data,
isizeH, isizeW, osizeH, osizeW);
}
else
{
// run updateGradInput kernel
atomicadaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
indices_data,
isizeH, isizeW, osizeH, osizeW);
}
});
if(atomic)
{
// run updateGradInput kernel, accumulate gradients atomically
atomicadaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
indices_data,
isizeH, isizeW, osizeH, osizeW);
}
else
{
// run updateGradInput kernel
atomicadaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
indices_data,
isizeH, isizeW, osizeH, osizeW);
}
}
);
AT_CUDA_CHECK(cudaGetLastError());

@@ -381,34 +375,32 @@ void adaptive_max_pool2d_backward_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"adaptive_max_pool2d_backward_cuda",
[&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool2d_backward_cuda", [&] {
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
// cuda blocks & threads:
int blocksH = (int)(16L / sizeD);
blocksH = blocksH < 1 ? 1 : blocksH;
dim3 blocks(sizeB*sizeD, blocksH);
dim3 threads(32, 8);
// cuda blocks & threads:
int blocksH = (int)(16L / sizeD);
blocksH = blocksH < 1 ? 1 : blocksH;
dim3 blocks(sizeB*sizeD, blocksH);
dim3 threads(32, 8);
if(atomic)
{
// run updateGradInput kernel, accumulate gradients atomically
atomicadaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
indices_data,
isizeH, isizeW, osizeH, osizeW);
}
else
{
// run updateGradInput kernel, accumulate gradients atomically
adaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
indices_data,
isizeH, isizeW, osizeH, osizeW);
}
});
if(atomic)
{
// run updateGradInput kernel, accumulate gradients atomically
atomicadaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
indices_data,
isizeH, isizeW, osizeH, osizeW);
}
else
{
// run updateGradInput kernel, accumulate gradients atomically
adaptivemaxgradinput <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> (
gradInput_data, gradOutput_data,
indices_data,
isizeH, isizeW, osizeH, osizeW);
}
}
);
AT_CUDA_CHECK(cudaGetLastError());

@@ -366,15 +366,13 @@ void adaptive_max_pool3d_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"adaptive_max_pool3d_cuda",
[&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool3d_cuda", [&] {
scalar_t *input_data = input.data_ptr<scalar_t>();
scalar_t *output_data = output.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
scalar_t *input_data = input.data_ptr<scalar_t>();
scalar_t *output_data = output.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
adaptivemaxpool_loop(
input_data, output_data, indices_data, totalZ, isizeT, isizeH, isizeW,
osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW);
});
adaptivemaxpool_loop(
input_data, output_data, indices_data, totalZ, isizeT, isizeH, isizeW,
osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW);
}
);
}

@@ -435,32 +433,28 @@ void adaptive_max_pool3d_backward_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"adaptive_max_pool3d_backward_cuda",
[&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool3d_backward_cuda", [&] {
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
atomicadaptivemaxgradinput_loop(
gradInput_data, gradOutput_data, indices_data,
totalZ,
isizeT, isizeH, isizeW, osizeT, osizeH, osizeW);
});
atomicadaptivemaxgradinput_loop(
gradInput_data, gradOutput_data, indices_data,
totalZ,
isizeT, isizeH, isizeW, osizeT, osizeH, osizeW);
}
);
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"adaptive_max_pool3d_backward_cuda",
[&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "adaptive_max_pool3d_backward_cuda", [&] {
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
int64_t *indices_data = indices.data_ptr<int64_t>();
adaptivemaxgradinput_loop(
gradInput_data, gradOutput_data, indices_data,
totalZ,
isizeT, isizeH, isizeW, osizeT, osizeH, osizeW);
});
adaptivemaxgradinput_loop(
gradInput_data, gradOutput_data, indices_data,
totalZ,
isizeT, isizeH, isizeW, osizeT, osizeH, osizeW);
}
);
}

@@ -303,33 +303,15 @@ void avg_pool2d_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"avg_pool2d_out_cuda_frame",
[&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "avg_pool2d_out_cuda_frame", [&] {
using accscalar_t = acc_type<scalar_t, true>;
using accscalar_t = acc_type<scalar_t, true>;
scalar_t *output_data = output.data_ptr<scalar_t>();
scalar_t *input_data = input.data_ptr<scalar_t>();
scalar_t *output_data = output.data_ptr<scalar_t>();
scalar_t *input_data = input.data_ptr<scalar_t>();
switch (memory_format){
case MemoryFormat::ChannelsLast: {
output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
avg_pool2d_out_cuda_frame_nhwc<scalar_t, accscalar_t>
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
count,
input_data,
nbatch,
nInputPlane,
inputHeight, inputWidth,
outputHeight, outputWidth,
kH, kW,
dH, dW,
padH, padW,
output_data,
divisor_override_value,
count_include_pad, use_divisor);
break;
}
case MemoryFormat::Contiguous: {
avg_pool2d_out_cuda_frame<scalar_t, accscalar_t>
switch (memory_format){
case MemoryFormat::ChannelsLast: {
output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
avg_pool2d_out_cuda_frame_nhwc<scalar_t, accscalar_t>
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
count,
input_data,

@@ -343,11 +325,27 @@ void avg_pool2d_out_cuda_template(
output_data,
divisor_override_value,
count_include_pad, use_divisor);
break;
}
default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
break;
}
});
case MemoryFormat::Contiguous: {
avg_pool2d_out_cuda_frame<scalar_t, accscalar_t>
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
count,
input_data,
nbatch,
nInputPlane,
inputHeight, inputWidth,
outputHeight, outputWidth,
kH, kW,
dH, dW,
padH, padW,
output_data,
divisor_override_value,
count_include_pad, use_divisor);
break;
}
default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}
}
);
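
The avg_pool2d forward dispatch above switches between an NHWC ("channels last") kernel and a contiguous (NCHW) kernel on the same bfloat16 path; a sketch of what that looks like from Python (illustrative assumptions: a CUDA device and made-up sizes):

    import torch
    import torch.nn.functional as F

    x = torch.randn(8, 4, 16, 16, device="cuda", dtype=torch.bfloat16)
    x_cl = x.to(memory_format=torch.channels_last)

    y = F.avg_pool2d(x, kernel_size=3, stride=2)       # MemoryFormat::Contiguous branch
    y_cl = F.avg_pool2d(x_cl, kernel_size=3, stride=2) # MemoryFormat::ChannelsLast branch

    print(y_cl.is_contiguous(memory_format=torch.channels_last))  # output keeps channels_last
    print(torch.allclose(y.float(), y_cl.float()))                # values should agree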
@@ -437,51 +435,49 @@ Tensor& avg_pool2d_backward_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"avg_pool2d_backward_out_cuda_frame",
[&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "avg_pool2d_backward_out_cuda_frame", [&] {
using accscalar_t = acc_type<scalar_t, true>;
using accscalar_t = acc_type<scalar_t, true>;
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
switch (memory_format) {
case MemoryFormat::ChannelsLast: {
gradInput.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
avg_pool2d_backward_out_cuda_frame_nhwc<scalar_t, accscalar_t>
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
count,
gradOutput_data,
nbatch,
nInputPlane,
inputHeight, inputWidth,
outputHeight, outputWidth,
kH, kW,
dH, dW,
padH, padW,
gradInput_data,
divisor_override_value,
count_include_pad, use_divisor);
break;
}
case MemoryFormat::Contiguous: {
avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t>
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
count,
gradOutput_data,
nbatch,
nInputPlane,
inputHeight, inputWidth,
outputHeight, outputWidth,
kH, kW,
dH, dW,
padH, padW,
gradInput_data,
divisor_override_value,
count_include_pad, use_divisor);
break;
}
default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
switch (memory_format) {
case MemoryFormat::ChannelsLast: {
gradInput.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast);
avg_pool2d_backward_out_cuda_frame_nhwc<scalar_t, accscalar_t>
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
count,
gradOutput_data,
nbatch,
nInputPlane,
inputHeight, inputWidth,
outputHeight, outputWidth,
kH, kW,
dH, dW,
padH, padW,
gradInput_data,
divisor_override_value,
count_include_pad, use_divisor);
break;
}
});
case MemoryFormat::Contiguous: {
avg_pool2d_backward_out_cuda_frame<scalar_t, accscalar_t>
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
count,
gradOutput_data,
nbatch,
nInputPlane,
inputHeight, inputWidth,
outputHeight, outputWidth,
kH, kW,
dH, dW,
padH, padW,
gradInput_data,
divisor_override_value,
count_include_pad, use_divisor);
break;
}
default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}
}
);

@@ -415,44 +415,42 @@ void avg_pool3d_out_cuda_template(
input.scalar_type(),
"avg_pool3d_out_cuda",
[&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "avg_pool3d_out_cuda", [&] {
using accscalar_t = acc_type<scalar_t, true>;
int64_t totalZ = otime * nslices * nbatch;
int64_t offsetZ = 0;
dim3 block(32, 8);
using accscalar_t = acc_type<scalar_t, true>;
int64_t totalZ = otime * nslices * nbatch;
int64_t offsetZ = 0;
dim3 block(32, 8);
while (totalZ > 0) {
dim3 grid(cuda::ATenCeilDiv(owidth, static_cast<int64_t>(block.x)),
cuda::ATenCeilDiv(oheight, static_cast<int64_t>(block.y)),
totalZ > 65535 ? 65535 : totalZ);
while (totalZ > 0) {
dim3 grid(cuda::ATenCeilDiv(owidth, static_cast<int64_t>(block.x)),
cuda::ATenCeilDiv(oheight, static_cast<int64_t>(block.y)),
totalZ > 65535 ? 65535 : totalZ);
switch (kW) {
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1);
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(2);
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(3);
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(4);
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(5);
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(6);
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(7);
default:
avg_pool3d_cuda_update_output<scalar_t, accscalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
work_input.packed_accessor64<scalar_t, 4>(),
work_output.packed_accessor64<scalar_t, 4>(),
kT, kH, kW,
dT, dH, dW,
padT, padH, padW,
count_include_pad,
offsetZ, divisor);
break;
}
AT_CUDA_CHECK(cudaGetLastError());
totalZ -= 65535;
offsetZ += 65535;
switch (kW) {
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1);
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(2);
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(3);
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(4);
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(5);
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(6);
LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(7);
default:
avg_pool3d_cuda_update_output<scalar_t, accscalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
work_input.packed_accessor64<scalar_t, 4>(),
work_output.packed_accessor64<scalar_t, 4>(),
kT, kH, kW,
dT, dH, dW,
padT, padH, padW,
count_include_pad,
offsetZ, divisor);
break;
}
});
AT_CUDA_CHECK(cudaGetLastError());
totalZ -= 65535;
offsetZ += 65535;
}
}
);
}

@@ -559,38 +557,36 @@ void avg_pool3d_backward_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"avg_pool3d_backward_out_frame_stride1",
[&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "avg_pool3d_backward_out_frame_stride1", [&] {
using accscalar_t = acc_type<scalar_t, true>;
int64_t totalZ = itime * nslices * nbatch;
int64_t offsetZ = 0;
dim3 block(32, 8);
using accscalar_t = acc_type<scalar_t, true>;
int64_t totalZ = itime * nslices * nbatch;
int64_t offsetZ = 0;
dim3 block(32, 8);
accscalar_t divide_factor;
if (divisor) {
divide_factor = static_cast<accscalar_t>(divisor);
} else {
divide_factor = static_cast<accscalar_t>(kT * kH * kW);
}
accscalar_t divide_factor;
if (divisor) {
divide_factor = static_cast<accscalar_t>(divisor);
} else {
divide_factor = static_cast<accscalar_t>(kT * kH * kW);
}
while (totalZ > 0) {
dim3 grid(cuda::ATenCeilDiv(iwidth, static_cast<int64_t>(block.x)),
cuda::ATenCeilDiv(iheight, static_cast<int64_t>(block.y)),
totalZ > 65535 ? 65535 : totalZ);
while (totalZ > 0) {
dim3 grid(cuda::ATenCeilDiv(iwidth, static_cast<int64_t>(block.x)),
cuda::ATenCeilDiv(iheight, static_cast<int64_t>(block.y)),
totalZ > 65535 ? 65535 : totalZ);
avg_pool3d_single_backward_out_frame_stride1<scalar_t, accscalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
work_grad_output.packed_accessor64<scalar_t, 4>(),
work_grad_input.packed_accessor64<scalar_t, 4>(),
kT, kH, kW,
1.0f/divide_factor,
offsetZ);
avg_pool3d_single_backward_out_frame_stride1<scalar_t, accscalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
work_grad_output.packed_accessor64<scalar_t, 4>(),
work_grad_input.packed_accessor64<scalar_t, 4>(),
kT, kH, kW,
1.0f/divide_factor,
offsetZ);
AT_CUDA_CHECK(cudaGetLastError());
AT_CUDA_CHECK(cudaGetLastError());
totalZ -= 65535;
offsetZ += 65535;
}
});
totalZ -= 65535;
offsetZ += 65535;
}
}
);
}

@@ -598,46 +594,44 @@ void avg_pool3d_backward_out_cuda_template(
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(),
"avg_pool3d_backward_out_frame",
[&] {
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "avg_pool3d_backward_out_frame", [&] {
using accscalar_t = acc_type<scalar_t, true>;
int64_t totalZ = otime * nslices * nbatch;
int64_t offsetZ = 0;
dim3 block(32, 8);
using accscalar_t = acc_type<scalar_t, true>;
int64_t totalZ = otime * nslices * nbatch;
int64_t offsetZ = 0;
dim3 block(32, 8);
while (totalZ > 0) {
dim3 grid(cuda::ATenCeilDiv(owidth, static_cast<int64_t>(block.x)),
cuda::ATenCeilDiv(oheight, static_cast<int64_t>(block.y)),
totalZ > 65535 ? 65535 : totalZ);
while (totalZ > 0) {
dim3 grid(cuda::ATenCeilDiv(owidth, static_cast<int64_t>(block.x)),
cuda::ATenCeilDiv(oheight, static_cast<int64_t>(block.y)),
totalZ > 65535 ? 65535 : totalZ);
if (kernelsOverlap) {
avg_pool3d_cuda_update_grad_input_atomic<scalar_t, accscalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
work_grad_output.packed_accessor64<scalar_t, 4>(),
work_grad_input.packed_accessor64<scalar_t, 4>(),
kT, kH, kW,
dT, dH, dW,
padT, padH, padW,
count_include_pad,
offsetZ, divisor);
}
else {
avg_pool3d_cuda_update_grad_input<scalar_t, accscalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
work_grad_output.packed_accessor64<scalar_t, 4>(),
work_grad_input.packed_accessor64<scalar_t, 4>(),
kT, kH, kW,
dT, dH, dW,
padT, padH, padW,
count_include_pad,
offsetZ, divisor);
}
AT_CUDA_CHECK(cudaGetLastError());
totalZ -= 65535;
offsetZ += 65535;
if (kernelsOverlap) {
avg_pool3d_cuda_update_grad_input_atomic<scalar_t, accscalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
work_grad_output.packed_accessor64<scalar_t, 4>(),
work_grad_input.packed_accessor64<scalar_t, 4>(),
kT, kH, kW,
dT, dH, dW,
padT, padH, padW,
count_include_pad,
offsetZ, divisor);
}
});
else {
avg_pool3d_cuda_update_grad_input<scalar_t, accscalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
work_grad_output.packed_accessor64<scalar_t, 4>(),
work_grad_input.packed_accessor64<scalar_t, 4>(),
kT, kH, kW,
dT, dH, dW,
padT, padH, padW,
count_include_pad,
offsetZ, divisor);
}
AT_CUDA_CHECK(cudaGetLastError());
totalZ -= 65535;
offsetZ += 65535;
}
}
);
}
@@ -11842,7 +11842,6 @@ class TestNNDeviceType(NNTestCase):
self._test_bfloat16_ops(torch.nn.LeakyReLU(), device, inp_dims=(5), prec=1e-2)

@onlyCUDA
@skipCUDAIfNotRocm
def test_pooling_bfloat16(self, device):
self._test_bfloat16_ops(torch.nn.AvgPool1d(3, stride=2), device, inp_dims=(8, 4, 16), prec=0.05)
self._test_bfloat16_ops(torch.nn.AvgPool2d(3, stride=2), device, inp_dims=(8, 4, 16, 16), prec=0.05)
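
The test change above lets the existing bfloat16 pooling test run on regular CUDA as well as ROCm. A rough standalone approximation of what _test_bfloat16_ops checks for these two cases (my sketch, not the helper's actual code): run the module in float32 and in bfloat16 on the same CUDA input and require the results to agree within the given tolerance.

    import torch

    def check_pooling_bfloat16(module, inp_dims, prec=0.05, device="cuda"):
        x = torch.randn(*inp_dims, device=device)
        ref = module(x)                     # float32 reference
        out = module(x.bfloat16()).float()  # bfloat16 result, upcast for comparison
        assert torch.allclose(ref, out, atol=prec, rtol=0), \
            f"max abs diff {(ref - out).abs().max().item():.4f}"

    check_pooling_bfloat16(torch.nn.AvgPool1d(3, stride=2), (8, 4, 16), prec=0.05)
    check_pooling_bfloat16(torch.nn.AvgPool2d(3, stride=2), (8, 4, 16, 16), prec=0.05)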