[ROCM] Navi21 Enablement 3: Embedding kernels (#72809)

Summary:
This PR is a follow up to the following prs.
https://github.com/pytorch/pytorch/pull/69942
https://github.com/pytorch/pytorch/pull/72682

We are adding support to Navi21 GPUs which have a warpsize of 32. We cannot rely on a constant so we have to dynamically look up the warpsize when launching the kernel on the host side. Inside device functions this is not needed and the compiler can correctly detect the correct warpsize to replace the C10_WARP_SIZE constant.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72809

Reviewed By: mruberry

Differential Revision: D34400737

Pulled By: ngimel

fbshipit-source-id: 1a1374465d4006e485d4d11531a4c78ddb178cdf
(cherry picked from commit 94211fe1f0)
This commit is contained in:
Michael Melesse 2022-02-22 19:41:01 -08:00 committed by PyTorch MergeBot
parent 299b40de50
commit 785ebb9d6d
3 changed files with 13 additions and 10 deletions

View File

@ -11,6 +11,7 @@
#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/thread_constants.h>
#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <thrust/iterator/reverse_iterator.h>
@ -249,8 +250,9 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
auto indices_contig = indices.contiguous();
auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
int64_t stride = grad_weight.stride(0);
dim3 grid(ceil_div(stride, (int64_t)C10_WARP_SIZE));
dim3 block(C10_WARP_SIZE, BLOCKDIMY);
int warp_size = at::cuda::warp_size();
dim3 grid(ceil_div(stride, (int64_t)warp_size));
dim3 block(warp_size, BLOCKDIMY);
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
@ -263,7 +265,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
embedding_backward_feature_kernel<scalar_t, accscalar_t, index_t>
<<<grid,
block,
sizeof(accscalar_t)*C10_WARP_SIZE*BLOCKDIMY + sizeof(int)*C10_WARP_SIZE*BLOCKDIMY,
sizeof(accscalar_t)*warp_size*BLOCKDIMY + sizeof(int)*warp_size*BLOCKDIMY,
stream>>>
(indices_contig.data_ptr<index_t>(),
grad.data_ptr<scalar_t>(),
@ -352,18 +354,18 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
num_indices
);
constexpr int num_threads = 128;
static_assert(num_threads % C10_WARP_SIZE == 0 &&
num_threads <= cuda_utils::kCUDABlockReduceMaxThreads,
int warp_size = at::cuda::warp_size();
TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 &&
num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads,
"BlockReduceSum requires all warps be active");
int64_t *num_unique_indices_ptr = num_unique_indices.data_ptr<int64_t>();
dim3 grid = unique_indices.numel();
dim3 block = num_threads;
dim3 block = num_threads();
int dim = self.stride(0);
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_renorm_cuda_", [&] {
using accscalar_t = acc_type<scalar_t, true>;
renorm_kernel<<<grid, block, (block.x / C10_WARP_SIZE) * sizeof(accscalar_t), stream>>>(
renorm_kernel<<<grid, block, (block.x / warp_size) * sizeof(accscalar_t), stream>>>(
self.data_ptr<scalar_t>(),
unique_indices.data_ptr<index_t>(),
static_cast<accscalar_t>(max_norm),

View File

@ -244,7 +244,8 @@ Tensor embedding_backward_cuda_kernel(
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
const int stride_warped = ceil_div(stride, C10_WARP_SIZE)*C10_WARP_SIZE;
const int warp_size = at::cuda::warp_size();
const int stride_warped = ceil_div(stride, warp_size)*warp_size;
const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
const int grid = ceil_div(num_of_partial_segments*stride_warped, block);

View File

@ -515,7 +515,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda(
AT_ASSERT(weight.size(1) == embedding_features);
const int threads_per_block = 512;
const int warps_per_block = threads_per_block / C10_WARP_SIZE;
const int warps_per_block = threads_per_block / at::cuda::warp_size();
dim3 block(threads_per_block);
dim3 grid((num_samples + warps_per_block - 1) / warps_per_block);