From a1cb67b69eb83a5552feb2c48ba704c31982a717 Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Tue, 11 Mar 2025 19:02:44 +0000
Subject: [PATCH] [ROCm] Improve backwards indexing when stride is not one
 (#147630)

Improve backwards indexing when stride is not one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147630
Approved by: https://github.com/jeffdaily
---
 aten/src/ATen/native/cuda/Indexing.cu | 443 ++++++++++++--------------
 test/test_indexing.py                 |  22 ++
 2 files changed, 230 insertions(+), 235 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index 6531ef01ee1..6f52e5e5b02 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -55,6 +55,205 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
 #endif
 }
 
+#ifdef USE_ROCM
+#define SKIP_SORTED_INDICES 32
+template <typename scalar_t, int SZ>
+__global__ void indexing_backward_kernel(
+  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
+  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
+  using opmath_t = at::opmath_type<scalar_t>;
+
+  extern __shared__ unsigned char smem[];
+  auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);
+
+  int smem_offset = threadIdx.y * C10_WARP_SIZE;
+
+  int laneIdx = threadIdx.x % C10_WARP_SIZE;
+  int64_t grad_row = 0;
+
+  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
+    // Init duplicates every time we compute a new set of entries:
+    smem_dups_cache[smem_offset + laneIdx] = 0;
+    WARP_SYNC();
+
+    int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
+    int64_t idx = base_idx + laneIdx;
+
+    if (idx < numel) {
+      int64_t crnt_sorted_idx = sorted_indices[idx];
+
+      if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
+        // Determine the number of duplicates in advance:
+        int64_t num_duplicates = 1;
+
+        // Lookahead in case there is a large number of duplicates. Once that is done, handle the tail.
+        while ((idx + num_duplicates + SKIP_SORTED_INDICES - 1) < numel) {
+          if (sorted_indices[idx + num_duplicates + SKIP_SORTED_INDICES - 1] != crnt_sorted_idx) break;
+          num_duplicates += SKIP_SORTED_INDICES;
+        }
+        while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
+          num_duplicates++;
+        }
+
+        smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
+      }
+    }
+
+    WARP_SYNC();
+
+    // All lanes in the warp are still active here. Use them all to reduce duplicates when
+    // large number of duplicates are present:
+    for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
+      // All lanes read the shared memory entry for number of duplicates
+      int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];
+
+      // Check if the original sub-warp had duplicates to eliminate, if not skip.
+      if (new_num_duplicates == 0)
+        continue;
+
+      // There are duplicates that need eliminating:
+      int64_t new_idx = base_idx + subwarp;
+      int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
+      const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;
+
+      if (!accumulate) {
+        const int64_t grad_row = ((int64_t)indices[new_idx + new_num_duplicates - 1]) * stride + z * numel * stride;
+        int64_t feature_dim = blockIdx.y * blockDim.x + threadIdx.x;
+        while (feature_dim < stride) {
+          grad_weight[new_weight_row + feature_dim] = grad_output[grad_row + feature_dim];
+          feature_dim += gridDim.y * blockDim.x;
+        }
+        continue;
+      }
+
+      for (int dup = 0; dup < new_num_duplicates; dup++) {
+        const int64_t grad_row = ((int64_t) indices[new_idx + dup]) * stride + z * numel * stride;
+
+        // All lanes do the same thing up to here.
+        int64_t feature_dim = blockIdx.y * blockDim.x + threadIdx.x;
+
+        // Each lane has a different feature_dim.
+        while (feature_dim < stride) {
+          grad_weight[new_weight_row + feature_dim] += grad_output[grad_row + feature_dim];
+          feature_dim += gridDim.y * blockDim.x;
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+__global__ void indexing_backward_kernel_stride_1(
+  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
+  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
+  using opmath_t = at::opmath_type<scalar_t>;
+
+  int laneIdx = threadIdx.x % C10_WARP_SIZE;
+
+  const opmath_t scale = (opmath_t)1.0;
+  int64_t grad_row = 0;
+
+  extern __shared__ unsigned char smem[];
+  auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);
+
+  // Each warp gets a different section of the share memory allocation:
+  int smem_offset = threadIdx.y * C10_WARP_SIZE;
+
+  // Number of values processed by each thread (grain size)
+  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
+    // Init duplicates every time we compute a new set of entries:
+    smem_dups_cache[smem_offset + laneIdx] = 0;
+
+    int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
+    int64_t idx = base_idx + laneIdx;
+
+    // Each lane calculates the number of duplicates:
+    if (idx < numel) {
+      int64_t crnt_sorted_idx = sorted_indices[idx];
+
+      if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
+        // Determine the number of duplicates in advance:
+        int64_t num_duplicates = 1;
+
+        // Lookahead in case there is a large number of duplicates. Once that is done, handle the tail.
+        while ((idx + num_duplicates + SKIP_SORTED_INDICES - 1) < numel) {
+          if (sorted_indices[idx + num_duplicates + SKIP_SORTED_INDICES - 1] != crnt_sorted_idx) break;
+          num_duplicates += SKIP_SORTED_INDICES;
+        }
+        while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
+          num_duplicates++;
+        }
+
+        if (!accumulate) {
+          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
+          grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
+          grad_weight[weight_row] =
+            static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
+          continue;
+        }
+
+        // Each lane sequentially handles the duplicate elimination:
+        if (num_duplicates < C10_WARP_SIZE) {
+          opmath_t gradient = (opmath_t)0.0;
+          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
+          for (int64_t i = 0; i < num_duplicates; ++i) {
+            grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
+            gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
+          }
+
+          grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
+        } else {
+          // Add duplicate to the cache:
+          smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
+        }
+      }
+    }
+
+    WARP_SYNC();
+
+    // All lanes in the warp are still active here. Use them all to reduce duplicates when
+    // large number of duplicates are present:
+    for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
+      // All lanes read the shared memory entry for number of duplicates
+      int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];
+
+      // Check if the original sub-warp had duplicates to eliminate, if not skip.
+      if (new_num_duplicates == 0)
+        continue;
+
+      // There are duplicates that need eliminating:
+      int64_t new_idx = base_idx + subwarp;
+      int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
+      const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;
+
+      // Result of the reduction will be in this variable:
+      opmath_t gradient = (opmath_t)0.0;
+
+      int64_t num_warp_passes = new_num_duplicates / C10_WARP_SIZE;
+      // Parallel reduction across the array of duplicates using all the lanes in the warp:
+      for (int64_t i = 0; i < num_warp_passes; ++i) {
+        grad_row = ((int64_t) indices[new_idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
+        gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
+      }
+
+      // Reduce across the lanes of the warp:
+      WARP_SYNC();
+      for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
+        gradient += WARP_SHFL_DOWN(gradient, offset);
+      }
+
+      if (laneIdx == 0) {
+        for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < new_num_duplicates; ++i) {
+          grad_row = ((int64_t) indices[new_idx + i]) * stride + z * numel * stride;
+          gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
+        }
+
+        grad_weight[new_weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[new_weight_row]) + gradient);
+      }
+    }
+  }
+}
+#else
 template <typename scalar_t, int SZ>
 __global__ void indexing_backward_kernel(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -133,169 +332,6 @@ __global__ void indexing_backward_kernel(
   }
 }
 
-#ifdef USE_ROCM
-template <typename scalar_t, bool accumulate>
-__global__ void indexing_backward_kernel_rocm(
-  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
-  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim) {
-
-  // This implementation is adopted from indexing_backward_kernel above.
-  using opmath_t = at::opmath_type<scalar_t>;
-  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
-    int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
-    if (idx < numel && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){
-      do {
-        // if not accumulate, we only keep the last duplicate index so skip those before it
-        if constexpr (!accumulate) {
-          if ((idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) {
-            idx++;
-            continue;
-          }
-        }
-        const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before;
-        const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride;
-
-        opmath_t gradient;
-        opmath_t weight;
-
-        int64_t feature_dim = threadIdx.x + blockIdx.y * blockDim.x;
-        while (feature_dim < stride) {
-          gradient = static_cast<opmath_t>(grad_output[grad_row + feature_dim]);
-          if constexpr (accumulate) {
-            weight = static_cast<opmath_t>(grad_weight[weight_row + feature_dim]);
-          }
-
-          if constexpr (accumulate) {
-            weight += gradient;
-          } else {
-            weight = gradient;
-          }
-
-          grad_weight[weight_row + feature_dim] = static_cast<scalar_t>(weight);
-          feature_dim += gridDim.y * blockDim.x;
-        }
-
-        idx++;
-      } while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]);
-    }
-  }
-}
-#endif
-
-#ifdef USE_ROCM
-#define SKIP 32
-template <typename scalar_t>
-__global__ void indexing_backward_kernel_stride_1(
-  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
-  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
-  using opmath_t = at::opmath_type<scalar_t>;
-
-  int laneIdx = threadIdx.x % C10_WARP_SIZE;
-
-  const opmath_t scale = (opmath_t)1.0;
-  int64_t grad_row = 0;
-
-  extern __shared__ unsigned char smem[];
-  auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);
-
-  // Each warp gets a different section of the share memory allocation:
-  int smem_offset = threadIdx.y * C10_WARP_SIZE;
-
-  // Number of values processed by each thread (grain size)
-  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
-    // Init duplicates every time we compute a new set of entries:
-    smem_dups_cache[smem_offset + laneIdx] = 0;
-
-    int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
-    int64_t idx = base_idx + laneIdx;
-
-    // Each lane calculates the number of duplicates:
-    if (idx < numel) {
-      int64_t crnt_sorted_idx = sorted_indices[idx];
-
-      if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
-        // Determine the number of duplicates in advance:
-        int64_t num_duplicates = 1;
-
-        // Lookahead in case there is a large number of duplicates. Once that is done, handle the tail.
-        while ((idx + num_duplicates + SKIP - 1) < numel) {
-          if (sorted_indices[idx + num_duplicates + SKIP - 1] != crnt_sorted_idx) break;
-          num_duplicates += SKIP;
-        }
-        while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
-          num_duplicates++;
-        }
-
-        if (!accumulate) {
-          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
-          grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
-          grad_weight[weight_row] =
-            static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
-          continue;
-        }
-
-        // Each lane sequentially handles the duplicate elimination:
-        if (num_duplicates < C10_WARP_SIZE) {
-          opmath_t gradient = (opmath_t)0.0;
-          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
-          for (int64_t i = 0; i < num_duplicates; ++i) {
-            grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
-            gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
-          }
-
-          grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
-        } else {
-          // Add duplicate to the cache:
-          smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
-        }
-      }
-    }
-
-    WARP_SYNC();
-
-    // All lanes in the warp are still active here. Use them all to reduce duplicates when
-    // large number of duplicates are present:
-    for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
-      // All lanes read the shared memory entry for number of duplicates
-      int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];
-
-      // Check if the original sub-warp had duplicates to eliminate, if not skip.
-      if (new_num_duplicates == 0)
-        continue;
-
-      // There are duplicates that need eliminating:
-      int64_t new_idx = base_idx + subwarp;
-      int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
-      const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;
-
-      // Result of the reduction will be in this variable:
-      opmath_t gradient = (opmath_t)0.0;
-
-      int64_t num_warp_passes = new_num_duplicates / C10_WARP_SIZE;
-      // Parallel reduction across the array of duplicates using all the lanes in the warp:
-      for (int64_t i = 0; i < num_warp_passes; ++i) {
-        grad_row = ((int64_t) indices[new_idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
-        gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
-      }
-
-      // Reduce across the lanes of the warp:
-      WARP_SYNC();
-      for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
-        gradient += WARP_SHFL_DOWN(gradient, offset);
-      }
-
-      if (laneIdx == 0) {
-        for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < new_num_duplicates; ++i) {
-          grad_row = ((int64_t) indices[new_idx + i]) * stride + z * numel * stride;
-          gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
-        }
-
-        grad_weight[new_weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[new_weight_row]) + gradient);
-      }
-    }
-  }
-}
-#else
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -664,11 +700,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
       std::min(std::max<int>(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2]));
   dim3 block(warp_size, indices_per_block);
-
-  if (sliceSize == 1) {
 #ifdef USE_ROCM
-    // Adapt grid size to smaller virtual warp size:
-    dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
-    size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
+  dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
+  size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
 #define KERNEL_GRID new_grid
 #define KERNEL_SMEM smem_dups_size
 #else
 #define KERNEL_GRID grid
 #define KERNEL_SMEM 0
 #endif
+
+  if (sliceSize == 1) {
     // This implementation is faster with high amounts of duplicates but could overflow
     // if FP16 / BF16 is used
     AT_DISPATCH_V2(
@@ -719,8 +751,6 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
-          indexing_backward_kernel_rocm<scalar_t, true><<<grid, block, 0, stream>>>(
-            sorted_indices.const_data_ptr<int64_t>(),
-            orig_indices.const_data_ptr<int64_t>(),
-            expandedValue.const_data_ptr<scalar_t>(),
-            src_.mutable_data_ptr<scalar_t>(),
-            num_indices,
-            sliceSize,
-            strideBefore,
-            nElemBefore);
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
-        }),
-        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-        // AT_EXPAND(AT_FLOAT8_TYPES),
-        // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
-        // should not be supported here, then reenable AT_FLOAT8_DTYPES
-        kFloat8_e4m3fn,
-        kFloat8_e5m2,
-        kFloat8_e4m3fnuz,
-        kFloat8_e5m2fnuz,
-        kComplexHalf,
-        kHalf,
-        kBool,
-        kBFloat16);
-    } else {
-      AT_DISPATCH_V2(
-        expandedValue.scalar_type(),
-        "indexing_backward",
-        AT_WRAP([&] {
-          indexing_backward_kernel_rocm<scalar_t, false><<<grid, block, 0, stream>>>(
-            sorted_indices.const_data_ptr<int64_t>(),
-            orig_indices.const_data_ptr<int64_t>(),
-            expandedValue.const_data_ptr<scalar_t>(),
-            src_.mutable_data_ptr<scalar_t>(),
-            num_indices,
-            sliceSize,
-            strideBefore,
-            nElemBefore);
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
-        }),
-        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-        // AT_EXPAND(AT_FLOAT8_TYPES),
-        // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
-        // should not be supported here, then reenable AT_FLOAT8_DTYPES
-        kFloat8_e4m3fn,
-        kFloat8_e5m2,
-        kFloat8_e4m3fnuz,
-        kFloat8_e5m2fnuz,
-        kComplexHalf,
-        kHalf,
-        kBool,
-        kBFloat16);
-    }
-#endif
   } else {
     AT_DISPATCH_V2(
       expandedValue.scalar_type(),
       "indexing_backward",
       AT_WRAP([&] {
-        indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
+        indexing_backward_kernel<scalar_t, UNROLL><<<KERNEL_GRID, block, KERNEL_SMEM, stream>>>(
           sorted_indices.const_data_ptr<int64_t>(),
           orig_indices.const_data_ptr<int64_t>(),
           expandedValue.const_data_ptr<scalar_t>(),
@@ -843,6 +813,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List
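
The test_indexing.py hunk is cut off above. Purely as an illustration of the workload this code path serves (this is not the test added by the PR), the rewritten kernel is reached through an accumulating index_put_ / indexing backward with many duplicate indices where each index addresses a slice wider than one element. A minimal sketch, assuming a CUDA or ROCm build of PyTorch; the shapes, sizes, and tolerances are arbitrary:

# Illustrative sketch only -- not the test added by this PR (its hunk is truncated above).
# Exercises index_put_ with accumulate=True, many duplicate indices, and a slice size
# greater than one, which is the case the rewritten backward-indexing kernel targets.
import torch

if torch.cuda.is_available():  # ROCm builds also expose the "cuda" device string
    device = "cuda"
    src = torch.zeros(16, 256, device=device)            # each index addresses a 256-wide row (stride > 1)
    idx = torch.randint(0, 16, (10000,), device=device)  # heavy duplication across only 16 rows
    val = torch.randn(10000, 256, device=device)

    # Accumulating put: every duplicate index must add its row into the same output row.
    src.index_put_((idx,), val, accumulate=True)

    # CPU reference for a sanity check; loose tolerances because the GPU accumulation order differs.
    ref = torch.zeros(16, 256).index_put_((idx.cpu(),), val.cpu(), accumulate=True)
    torch.testing.assert_close(src.cpu(), ref, rtol=1e-3, atol=1e-3)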