[ROCM] Navi21 Enablement 3: Embedding kernels (#72809)
Summary:
This PR is a follow-up to the following PRs:
https://github.com/pytorch/pytorch/pull/69942
https://github.com/pytorch/pytorch/pull/72682
We are adding support for Navi21 GPUs, which have a warp size of 32. We cannot rely on a compile-time constant, so we have to look up the warp size dynamically when launching kernels from the host side. Inside device functions this is not needed: the compiler detects the correct warp size and substitutes it for the C10_WARP_SIZE constant.
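As a minimal sketch of the pattern (illustrative names, not code from this PR): the host queries the active device's warp size via at::cuda::warp_size() and sizes the launch from it, while device code keeps using C10_WARP_SIZE, which the compiler resolves for the target architecture.

    #include <ATen/cuda/CUDAContext.h>  // at::cuda::warp_size(), getCurrentCUDAStream()
    #include <c10/macros/Macros.h>      // C10_WARP_SIZE

    // Device side: the compiler substitutes the correct warp size for the
    // architecture being compiled (32 on CUDA and Navi, 64 on CDNA-class
    // wavefronts), so C10_WARP_SIZE remains safe to use here.
    __global__ void scale_kernel(float* data, int n) {
      const int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) data[i] *= 2.0f;
    }

    // Host side: the warp size of the device actually present is only known
    // at run time, so it is queried rather than taken from a constant.
    void launch_scale(float* data, int n) {
      const int warp_size = at::cuda::warp_size();
      const dim3 block(warp_size * 4);
      const dim3 grid((n + block.x - 1) / block.x);
      scale_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(data, n);
    }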
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72809
Reviewed By: mruberry
Differential Revision: D34400737
Pulled By: ngimel
fbshipit-source-id: 1a1374465d4006e485d4d11531a4c78ddb178cdf
(cherry picked from commit 94211fe1f0)
parent 299b40de50
commit 785ebb9d6d
@@ -11,6 +11,7 @@
 #include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
 #include <ATen/native/cuda/SortingCommon.cuh>
 #include <ATen/native/cuda/block_reduce.cuh>
+#include <ATen/native/cuda/thread_constants.h>
 
 #if CUB_SUPPORTS_SCAN_BY_KEY()
 #include <thrust/iterator/reverse_iterator.h>
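The one added include supplies the num_threads() helper that the embedding_renorm_cuda_ hunk below calls in place of its old file-local constexpr. The header's contents are not part of this diff, so the sketch below is an assumption about its shape, not a copy of it.

    // Hypothetical sketch of a thread_constants.h-style helper: the block size
    // becomes a function, so call sites stop hard-coding 128 next to C10_WARP_SIZE.
    constexpr int num_threads() {
    #if defined(USE_ROCM)
      return 256;  // assumption: four 64-wide wavefronts on CDNA-class builds
    #else
      return 128;  // assumption: four 32-wide warps on CUDA builds
    #endif
    }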
@@ -249,8 +250,9 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
   auto indices_contig = indices.contiguous();
   auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
   int64_t stride = grad_weight.stride(0);
-  dim3 grid(ceil_div(stride, (int64_t)C10_WARP_SIZE));
-  dim3 block(C10_WARP_SIZE, BLOCKDIMY);
+  int warp_size = at::cuda::warp_size();
+  dim3 grid(ceil_div(stride, (int64_t)warp_size));
+  dim3 block(warp_size, BLOCKDIMY);
 
   AT_DISPATCH_FLOATING_TYPES_AND2(
     at::ScalarType::Half, at::ScalarType::BFloat16,
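For a concrete sense of the new launch math (values below are made up): ceil_div rounds up, so the grid covers the full row stride whichever warp size the device reports.

    #include <cstdint>
    #include <cstdio>
    #include <initializer_list>

    // Round-up integer division, matching the ceil_div used in the hunk above.
    template <typename T>
    constexpr T ceil_div(T a, T b) { return (a + b - 1) / b; }

    int main() {
      const int64_t stride = 300;             // hypothetical embedding row stride
      for (int64_t warp_size : {32, 64}) {    // Navi21 vs. CDNA-class wavefronts
        std::printf("warp_size=%lld -> grid.x=%lld, block=(%lld, BLOCKDIMY)\n",
                    (long long)warp_size,
                    (long long)ceil_div(stride, warp_size),
                    (long long)warp_size);
      }
      return 0;
    }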
@@ -263,7 +265,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
         embedding_backward_feature_kernel<scalar_t, accscalar_t, index_t>
           <<<grid,
              block,
-             sizeof(accscalar_t)*C10_WARP_SIZE*BLOCKDIMY + sizeof(int)*C10_WARP_SIZE*BLOCKDIMY,
+             sizeof(accscalar_t)*warp_size*BLOCKDIMY + sizeof(int)*warp_size*BLOCKDIMY,
              stream>>>
           (indices_contig.data_ptr<index_t>(),
            grad.data_ptr<scalar_t>(),
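Only the dynamic shared-memory argument of the launch changes here. Assuming accscalar_t = float and BLOCKDIMY = 32 (both assumptions for illustration), the request works out as below; a 64-wide compile-time constant on a warp-32 device would ask for twice the memory the block can actually use.

    #include <cstdio>
    #include <initializer_list>

    int main() {
      const int BLOCKDIMY = 32;  // assumption for illustration
      for (int warp_size : {32, 64}) {
        // One accumulator and one int per thread of the
        // (warp_size x BLOCKDIMY) block, as in the launch above.
        const size_t smem = sizeof(float) * warp_size * BLOCKDIMY +
                            sizeof(int) * warp_size * BLOCKDIMY;
        std::printf("warp_size=%d -> %zu bytes of dynamic shared memory\n",
                    warp_size, smem);
      }
      return 0;
    }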
@@ -352,18 +354,18 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
       num_indices
     );
 
-  constexpr int num_threads = 128;
-  static_assert(num_threads % C10_WARP_SIZE == 0 &&
-                num_threads <= cuda_utils::kCUDABlockReduceMaxThreads,
+  int warp_size = at::cuda::warp_size();
+  TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 &&
+                        num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads,
                 "BlockReduceSum requires all warps be active");
   int64_t *num_unique_indices_ptr = num_unique_indices.data_ptr<int64_t>();
   dim3 grid = unique_indices.numel();
-  dim3 block = num_threads;
+  dim3 block = num_threads();
   int dim = self.stride(0);
 
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_renorm_cuda_", [&] {
     using accscalar_t = acc_type<scalar_t, true>;
-    renorm_kernel<<<grid, block, (block.x / C10_WARP_SIZE) * sizeof(accscalar_t), stream>>>(
+    renorm_kernel<<<grid, block, (block.x / warp_size) * sizeof(accscalar_t), stream>>>(
       self.data_ptr<scalar_t>(),
       unique_indices.data_ptr<index_t>(),
       static_cast<accscalar_t>(max_norm),
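Because warp_size is now a run-time value, the old static_assert can no longer compile, so the check becomes a TORCH_INTERNAL_ASSERT. A stripped-down sketch of the same guard (simplified: the thread count and the limit are passed in rather than taken from thread_constants.h and cuda_utils):

    #include <ATen/cuda/CUDAContext.h>  // at::cuda::warp_size()
    #include <c10/util/Exception.h>     // TORCH_INTERNAL_ASSERT

    // Run-time replacement for the compile-time check: the block must consist
    // of whole warps and stay within the block-reduction thread limit.
    void check_block_reduce_shape(int num_threads, int max_block_reduce_threads) {
      const int warp_size = at::cuda::warp_size();
      TORCH_INTERNAL_ASSERT(
          num_threads % warp_size == 0 &&
              num_threads <= max_block_reduce_threads,
          "BlockReduceSum requires all warps be active");
    }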
@@ -244,7 +244,8 @@ Tensor embedding_backward_cuda_kernel(
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }
 
-    const int stride_warped = ceil_div(stride, C10_WARP_SIZE)*C10_WARP_SIZE;
+    const int warp_size = at::cuda::warp_size();
+    const int stride_warped = ceil_div(stride, warp_size)*warp_size;
     const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
     const int grid = ceil_div(num_of_partial_segments*stride_warped, block);
@@ -515,7 +515,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda(
   AT_ASSERT(weight.size(1) == embedding_features);
 
   const int threads_per_block = 512;
-  const int warps_per_block = threads_per_block / C10_WARP_SIZE;
+  const int warps_per_block = threads_per_block / at::cuda::warp_size();
 
   dim3 block(threads_per_block);
   dim3 grid((num_samples + warps_per_block - 1) / warps_per_block);
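The embedding-bag hunk keeps a fixed 512-thread block and instead carves it into warps at run time; the grid expression is the same ceil-div written out by hand. With made-up numbers:

    #include <cstdio>
    #include <initializer_list>

    int main() {
      const int threads_per_block = 512;
      const int num_samples = 1000;  // hypothetical sample count
      for (int warp_size : {32, 64}) {
        const int warps_per_block = threads_per_block / warp_size;  // 16 or 8
        // (a + b - 1) / b is the hand-written ceil_div from the hunk above.
        const int grid = (num_samples + warps_per_block - 1) / warps_per_block;
        std::printf("warp_size=%d -> warps_per_block=%d, grid.x=%d\n",
                    warp_size, warps_per_block, grid);
      }
      return 0;
    }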