[ROCm] Improve vectorized elementwise kernel performance in MI300X (#153634)

* Use non-temporal loads to improve the vectorized elementwise kernel performance on MI300
* Use thread_work_size of 8 or 16 for vectorized elementwise kernel

Co-author: @amd-hhashemi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153634
Approved by: https://github.com/jeffdaily
This commit is contained in:
Jerry Mannil 2025-05-27 20:49:32 +00:00 committed by PyTorch MergeBot
parent 555fc05868
commit 6be829535f
2 changed files with 42 additions and 6 deletions

View File

@ -226,9 +226,15 @@ C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
using traits = function_traits<func_t>;
constexpr auto io_size = calc_io_size<func_t>();
int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
#ifdef __gfx942__
constexpr int tws = (io_size >= 2) ? 8 : 16;
#else
constexpr int tws = elems_per_thread<io_size>();
#endif
constexpr int bws = tws * num_threads();
int remaining = N - bws * blockIdx.x;
if (remaining < io_block_work_size<io_size>()) { // if this block handles the reminder,
if (remaining < bws) { // if this block handles the reminder,
// just do a naive unrolled loop
auto input_calc = TrivialOffsetCalculator<traits::arity>();
auto output_calc = TrivialOffsetCalculator<1>();
@ -240,14 +246,14 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
decltype(output_calc),
memory::LoadWithoutCast,
memory::StoreWithoutCast,
elems_per_thread<io_size>()>(
tws>(
data, remaining, input_calc, output_calc, loader, storer);
elementwise_kernel_helper(f, policy);
} else { // if this block has a full `block_work_size` data to handle, use
// vectorized memory access
constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
elementwise_kernel_helper(
f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
f, memory::policies::vectorized<optimal_vec_size, array_t, tws>(data));
}
}
#endif // USE_ROCM
@ -285,10 +291,12 @@ static inline void launch_vectorized_kernel(
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
using traits = function_traits<func_t>;
constexpr auto io_size = calc_io_size<func_t>();
int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
auto stream = at::cuda::getCurrentCUDAStream();
#ifdef USE_ROCM
int vec_size = memory::can_vectorize_up_to<func_t>(data);
c10::DeviceIndex curDevice = -1;
AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice));
int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice) ? ((io_size >= 2) ? 8 : 16) : elems_per_thread<io_size>();
#else
using cpp_type = typename function_traits<func_t>::result_type;
const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data);
@ -305,7 +313,10 @@ static inline void launch_vectorized_kernel(
if constexpr (sizeof(cpp_type) < 2) {
vec_size = std::min<uint16_t>(vec_size, 4);
}
int tws = elems_per_thread<io_size>();
#endif
int bws = tws * num_threads();
int64_t grid = (N + bws - 1) / bws;
switch (vec_size) {
#ifdef USE_ROCM
case 16:
@ -334,8 +345,9 @@ static inline void launch_vectorized_kernel(
auto output_calc = TrivialOffsetCalculator<1>();
auto loader = memory::LoadWithoutCast();
auto storer = memory::StoreWithoutCast();
int64_t grid_unrolled = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
unrolled_elementwise_kernel<func_t, array_t, elems_per_thread<io_size>()>
<<<grid, num_threads(), 0, stream>>>(
<<<grid_unrolled, num_threads(), 0, stream>>>(
N, f, data, input_calc, output_calc, loader, storer);
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;

View File

@ -187,6 +187,30 @@ template <int vec_size, typename scalar_t>
__device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
using vec_t = aligned_vector<scalar_t, vec_size>;
auto *from = reinterpret_cast<const vec_t *>(base_ptr);
#if defined(USE_ROCM) && defined(__gfx942__)
using longx2 = __attribute__((__vector_size__(4*sizeof(int)))) int;
if constexpr (sizeof(vec_t) == sizeof(int)) {
union {
vec_t v;
int i;
} tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast<const int *>(&(from[offset]))) };
return tmpt.v;
}
else if constexpr (sizeof(vec_t) == sizeof(long)) {
union {
vec_t v;
long i;
} tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast<const long *>(&(from[offset]))) };
return tmpt.v;
}
else if constexpr (sizeof(vec_t) == sizeof(longx2)) {
union {
vec_t v;
longx2 i;
} tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast<const longx2 *>(&(from[offset]))) };
return tmpt.v;
}
#endif
return from[offset];
}