[ROCm] Improve vectorized elementwise kernel performance in MI300X (#153634)
* Use non-temporal loads to improve the vectorized elementwise kernel performance on MI300
* Use a thread_work_size of 8 or 16 for the vectorized elementwise kernel

Co-author: @amd-hhashemi
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153634
Approved by: https://github.com/jeffdaily
This commit is contained in:
parent
555fc05868
commit
6be829535f
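To see how the new thread work size feeds the launch configuration, here is a small worked example. The num_threads() value of 256 is an assumption for illustration (warp size 64 × 4 warps on ROCm); the actual constant comes from the surrounding headers.

// Illustration only; num_threads_assumed stands in for PyTorch's num_threads().
constexpr int num_threads_assumed = 256;
constexpr int tws = 8;                          // gfx942 with io_size >= 2
constexpr int bws = tws * num_threads_assumed;  // 2048 elements per block
// For N = 1'000'000 elements the vectorized launch needs
//   grid = (N + bws - 1) / bws = (1'000'000 + 2047) / 2048 = 489 blocks.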
@@ -226,9 +226,15 @@ C10_LAUNCH_BOUNDS_1(num_threads())
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
   using traits = function_traits<func_t>;
   constexpr auto io_size = calc_io_size<func_t>();
-  int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+#ifdef __gfx942__
+  constexpr int tws = (io_size >= 2) ? 8 : 16;
+#else
+  constexpr int tws = elems_per_thread<io_size>();
+#endif
+  constexpr int bws = tws * num_threads();
+  int remaining = N - bws * blockIdx.x;

-  if (remaining < io_block_work_size<io_size>()) { // if this block handles the reminder,
+  if (remaining < bws) { // if this block handles the reminder,
     // just do a naive unrolled loop
     auto input_calc = TrivialOffsetCalculator<traits::arity>();
     auto output_calc = TrivialOffsetCalculator<1>();
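As a simplified illustration of what the hunk above does, the sketch below hard-codes the architecture-dependent thread work size into a toy elementwise kernel. foo_kernel, THREADS, and the add-one body are hypothetical; the real code dispatches through elementwise_kernel_helper and the memory policies shown in the diff.

#include <hip/hip_runtime.h>

// Each thread handles tws elements; one block covers bws = tws * THREADS elements.
template <typename scalar_t, int THREADS>
__global__ void foo_kernel(int N, const scalar_t* in, scalar_t* out) {
#ifdef __gfx942__
  constexpr int tws = 8;              // more elements per thread on MI300-class GPUs
#else
  constexpr int tws = 4;              // stand-in for elems_per_thread<io_size>()
#endif
  constexpr int bws = tws * THREADS;  // elements handled by one block
  int base = blockIdx.x * bws + threadIdx.x;
#pragma unroll
  for (int i = 0; i < tws; ++i) {
    int idx = base + i * THREADS;     // stride by THREADS so accesses stay coalesced
    if (idx < N) out[idx] = in[idx] + scalar_t(1);
  }
}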
@@ -240,14 +246,14 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
         decltype(output_calc),
         memory::LoadWithoutCast,
         memory::StoreWithoutCast,
-        elems_per_thread<io_size>()>(
+        tws>(
         data, remaining, input_calc, output_calc, loader, storer);
     elementwise_kernel_helper(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use
     // vectorized memory access
     constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
     elementwise_kernel_helper(
-        f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
+        f, memory::policies::vectorized<optimal_vec_size, array_t, tws>(data));
   }
 }
 #endif // USE_ROCM
@@ -285,10 +291,12 @@ static inline void launch_vectorized_kernel(
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
   using traits = function_traits<func_t>;
   constexpr auto io_size = calc_io_size<func_t>();
-  int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
   auto stream = at::cuda::getCurrentCUDAStream();
 #ifdef USE_ROCM
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
+  c10::DeviceIndex curDevice = -1;
+  AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice));
+  int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice) ? ((io_size >= 2) ? 8 : 16) : elems_per_thread<io_size>();
 #else
   using cpp_type = typename function_traits<func_t>::result_type;
   const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data);
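The host-side counterpart of the device-side #ifdef __gfx942__ is the runtime architecture query above via at::detail::getCUDAHooks().isGPUArch. Outside PyTorch, a plain-HIP equivalent might look like the sketch below (is_gfx942 is a hypothetical helper; prefix matching on gcnArchName is an assumption about its format).

#include <hip/hip_runtime.h>
#include <cstring>

// Returns true when the given device reports a gfx942 (MI300-class) architecture.
static bool is_gfx942(int device) {
  hipDeviceProp_t prop;
  if (hipGetDeviceProperties(&prop, device) != hipSuccess) return false;
  // gcnArchName typically looks like "gfx942:sramecc+:xnack-", so match the prefix.
  return std::strncmp(prop.gcnArchName, "gfx942", 6) == 0;
}

// Host-side thread-work-size selection mirroring the diff:
//   int tws = is_gfx942(dev) ? ((io_size >= 2) ? 8 : 16) : default_tws;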
@@ -305,7 +313,10 @@ static inline void launch_vectorized_kernel(
   if constexpr (sizeof(cpp_type) < 2) {
     vec_size = std::min<uint16_t>(vec_size, 4);
   }
+  int tws = elems_per_thread<io_size>();
 #endif
+  int bws = tws * num_threads();
+  int64_t grid = (N + bws - 1) / bws;
   switch (vec_size) {
 #ifdef USE_ROCM
     case 16:
@@ -334,8 +345,9 @@ static inline void launch_vectorized_kernel(
       auto output_calc = TrivialOffsetCalculator<1>();
       auto loader = memory::LoadWithoutCast();
       auto storer = memory::StoreWithoutCast();
+      int64_t grid_unrolled = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
       unrolled_elementwise_kernel<func_t, array_t, elems_per_thread<io_size>()>
-          <<<grid, num_threads(), 0, stream>>>(
+          <<<grid_unrolled, num_threads(), 0, stream>>>(
              N, f, data, input_calc, output_calc, loader, storer);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
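Note that the unrolled fallback kernel is still instantiated with elems_per_thread<io_size>(), so it keeps a grid derived from io_block_work_size, while the vectorized path now uses the bws-based grid. A rough illustration, reusing assumed values from the earlier sketch:

// Assumed values, for illustration only.
constexpr int num_threads_assumed = 256;
constexpr int elems_per_thread_assumed = 4;  // hypothetical elems_per_thread<io_size>()
constexpr int tws_gfx942 = 8;                // from the diff, io_size >= 2 on gfx942
// Vectorized path:   grid          = ceil(N / (tws_gfx942 * num_threads_assumed))
// Unrolled fallback: grid_unrolled = ceil(N / (elems_per_thread_assumed * num_threads_assumed))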
@@ -187,6 +187,30 @@ template <int vec_size, typename scalar_t>
 __device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
   using vec_t = aligned_vector<scalar_t, vec_size>;
   auto *from = reinterpret_cast<const vec_t *>(base_ptr);
+#if defined(USE_ROCM) && defined(__gfx942__)
+  using longx2 = __attribute__((__vector_size__(4*sizeof(int)))) int;
+  if constexpr (sizeof(vec_t) == sizeof(int)) {
+    union {
+      vec_t v;
+      int i;
+    } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast<const int *>(&(from[offset]))) };
+    return tmpt.v;
+  }
+  else if constexpr (sizeof(vec_t) == sizeof(long)) {
+    union {
+      vec_t v;
+      long i;
+    } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast<const long *>(&(from[offset]))) };
+    return tmpt.v;
+  }
+  else if constexpr (sizeof(vec_t) == sizeof(longx2)) {
+    union {
+      vec_t v;
+      longx2 i;
+    } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast<const longx2 *>(&(from[offset]))) };
+    return tmpt.v;
+  }
+#endif
   return from[offset];
 }

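The branches above type-pun the aligned vector through a union so that __builtin_nontemporal_load (a Clang builtin) can be applied to a plain integer or integer-vector type of the same size, keeping the bits unchanged while hinting that the data will not be reused from cache. A standalone sketch of the 16-byte case, with illustrative names (vec4f, nontemporal_load4) rather than PyTorch's:

#include <hip/hip_runtime.h>

// 16-byte aligned vector of four floats, analogous to aligned_vector<float, 4>.
struct alignas(16) vec4f { float val[4]; };

__device__ inline vec4f nontemporal_load4(const vec4f* src) {
  // Clang vector type matching the 16-byte payload (the diff calls this longx2).
  using intx4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
  union { vec4f v; intx4 i; } tmp;
  tmp.i = __builtin_nontemporal_load(reinterpret_cast<const intx4*>(src));
  return tmp.v;  // same bits as a plain load; only the cache hint differs
}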