[ROCm] Adjust grid size for non-unit stride backwards indexing (#165026)

Adjust grid size for non-unit stride backwards indexing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165026
Approved by: https://github.com/jeffdaily
This commit is contained in:
Gheorghe-Teodor Bercea 2025-10-10 16:36:34 +00:00 committed by PyTorch MergeBot
parent 3f27100d3e
commit 01a2812f48

View File

@ -710,6 +710,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
dim3 block(warp_size, indices_per_block);
#ifdef USE_ROCM
dim3 new_grid_many_indices(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)),
grid.y == 1 ? std::min<int>(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size))) : grid.y,
grid.z);
dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
#define KERNEL_GRID new_grid
@ -788,7 +791,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
expandedValue.scalar_type(),
"indexing_backward_many_indices",
AT_WRAP([&] {
indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid_many_indices, block, smem_dups_size, stream>>>(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),