diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index a4f05852ae5..d9d9ce658e2 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -226,9 +226,15 @@ C10_LAUNCH_BOUNDS_1(num_threads()) __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { using traits = function_traits; constexpr auto io_size = calc_io_size(); - int remaining = N - io_block_work_size() * blockIdx.x; +#ifdef __gfx942__ + constexpr int tws = (io_size >= 2) ? 8 : 16; +#else + constexpr int tws = elems_per_thread(); +#endif + constexpr int bws = tws * num_threads(); + int remaining = N - bws * blockIdx.x; - if (remaining < io_block_work_size()) { // if this block handles the reminder, + if (remaining < bws) { // if this block handles the reminder, // just do a naive unrolled loop auto input_calc = TrivialOffsetCalculator(); auto output_calc = TrivialOffsetCalculator<1>(); @@ -240,14 +246,14 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { decltype(output_calc), memory::LoadWithoutCast, memory::StoreWithoutCast, - elems_per_thread()>( + tws>( data, remaining, input_calc, output_calc, loader, storer); elementwise_kernel_helper(f, policy); } else { // if this block has a full `block_work_size` data to handle, use // vectorized memory access constexpr auto optimal_vec_size = calc_optimal_vec_size(); elementwise_kernel_helper( - f, memory::policies::vectorized()>(data)); + f, memory::policies::vectorized(data)); } } #endif // USE_ROCM @@ -285,10 +291,12 @@ static inline void launch_vectorized_kernel( TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits::max()); using traits = function_traits; constexpr auto io_size = calc_io_size(); - int64_t grid = (N + io_block_work_size() - 1) / io_block_work_size(); auto stream = at::cuda::getCurrentCUDAStream(); #ifdef USE_ROCM int vec_size = memory::can_vectorize_up_to(data); + c10::DeviceIndex curDevice = -1; + AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice)); + int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice) ? ((io_size >= 2) ? 8 : 16) : elems_per_thread(); #else using cpp_type = typename function_traits::result_type; const uint16_t max_vec_size = memory::can_vectorize_up_to(data); @@ -305,7 +313,10 @@ static inline void launch_vectorized_kernel( if constexpr (sizeof(cpp_type) < 2) { vec_size = std::min(vec_size, 4); } + int tws = elems_per_thread(); #endif + int bws = tws * num_threads(); + int64_t grid = (N + bws - 1) / bws; switch (vec_size) { #ifdef USE_ROCM case 16: @@ -334,8 +345,9 @@ static inline void launch_vectorized_kernel( auto output_calc = TrivialOffsetCalculator<1>(); auto loader = memory::LoadWithoutCast(); auto storer = memory::StoreWithoutCast(); + int64_t grid_unrolled = (N + io_block_work_size() - 1) / io_block_work_size(); unrolled_elementwise_kernel()> - <<>>( + <<>>( N, f, data, input_calc, output_calc, loader, storer); C10_CUDA_KERNEL_LAUNCH_CHECK(); break; diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh index a84e49f15ff..d29ba35393a 100644 --- a/aten/src/ATen/native/cuda/MemoryAccess.cuh +++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh @@ -187,6 +187,30 @@ template __device__ aligned_vector load_vector(const scalar_t *base_ptr, uint32_t offset) { using vec_t = aligned_vector; auto *from = reinterpret_cast(base_ptr); +#if defined(USE_ROCM) && defined(__gfx942__) + using longx2 = __attribute__((__vector_size__(4*sizeof(int)))) int; + if constexpr (sizeof(vec_t) == sizeof(int)) { + union { + vec_t v; + int i; + } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast(&(from[offset]))) }; + return tmpt.v; + } + else if constexpr (sizeof(vec_t) == sizeof(long)) { + union { + vec_t v; + long i; + } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast(&(from[offset]))) }; + return tmpt.v; + } + else if constexpr (sizeof(vec_t) == sizeof(longx2)) { + union { + vec_t v; + longx2 i; + } tmpt = { .i = __builtin_nontemporal_load(reinterpret_cast(&(from[offset]))) }; + return tmpt.v; + } +#endif return from[offset]; }