[ROCm] Adjust grid size for non-unit stride backwards indexing (#165026)
Adjust grid size for non-unit stride backwards indexing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165026
Approved by: https://github.com/jeffdaily
This commit is contained in:
parent 3f27100d3e
commit 01a2812f48
@@ -710,6 +710,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
   dim3 block(warp_size, indices_per_block);
 
 #ifdef USE_ROCM
+  dim3 new_grid_many_indices(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)),
+    grid.y == 1 ? std::min<int>(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size))) : grid.y,
+    grid.z);
   dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
   size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
 #define KERNEL_GRID new_grid
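For orientation: the added new_grid_many_indices keeps the x dimension (one block per warp-sized group of indices) but, when the pre-existing grid had collapsed to a single y row, expands y so each slice is covered in warp-sized chunks, capped at the device's maxGridSize[1]. Below is a minimal host-side C++ sketch of that arithmetic; the concrete values for num_indices, sliceSize, warp_size, indices_per_block and the y-dimension cap are assumptions for illustration, and ceil_div / Dim3 are local stand-ins for the helpers used in the source.

// Standalone sketch of the grid-size arithmetic added above; all numbers are
// illustrative assumptions, not values taken from PyTorch.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }  // stand-in helper

struct Dim3 { int64_t x, y, z; };  // stand-in for CUDA/HIP dim3

int main() {
  const int64_t num_indices       = 1 << 20;  // assumed: number of indices in the backward scatter
  const int64_t sliceSize         = 4096;     // assumed: elements per indexed slice
  const int64_t warp_size         = 64;       // assumed: ROCm wavefront size
  const int64_t indices_per_block = 4;        // assumed: warps per block
  const int64_t max_grid_y        = 65535;    // assumed: device maxGridSize[1]

  Dim3 grid{ceil_div(num_indices, indices_per_block * warp_size), 1, 1};  // pre-existing grid

  // Mirrors new_grid_many_indices: same x, but if y was 1, spread the slice
  // across y in warp-sized chunks, capped by the device limit.
  Dim3 new_grid_many_indices{
      grid.x,
      grid.y == 1 ? std::min<int64_t>(max_grid_y, ceil_div(sliceSize, warp_size)) : grid.y,
      grid.z};

  std::printf("new grid = (%lld, %lld, %lld)\n",
              (long long)new_grid_many_indices.x,
              (long long)new_grid_many_indices.y,
              (long long)new_grid_many_indices.z);
  return 0;
}

With these assumed numbers the grid grows from 4096x1 to 4096x64 blocks, giving the many-indices backward kernel parallelism across the slice dimension instead of a single block row.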
@@ -788,7 +791,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       expandedValue.scalar_type(),
       "indexing_backward_many_indices",
       AT_WRAP([&] {
-        indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
+        indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid_many_indices, block, smem_dups_size, stream>>>(
          sorted_indices.const_data_ptr<int64_t>(),
          orig_indices.const_data_ptr<int64_t>(),
          expandedValue.const_data_ptr<scalar_t>(),
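This second hunk only swaps which grid the many-indices backward kernel is launched with; the block shape, dynamic shared memory size, and stream are unchanged. The toy CUDA/HIP-style sketch below shows that launch form; the kernel body is a placeholder, and every name and size in it is an assumption rather than the real indexing_backward_kernel_many_indices.

// Toy illustration of the <<<grid, block, shared-memory-bytes, stream>>> launch
// shape used at this call site; the kernel is a placeholder, not the real one.
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>

__global__ void toy_many_indices_kernel(const int64_t* sorted_indices, int64_t n) {
  extern __shared__ int64_t smem_dups[];  // dynamic shared memory, one slot per thread
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int64_t i = (int64_t)blockIdx.x * (blockDim.x * blockDim.y) + tid;
  if (i < n) smem_dups[tid] = sorted_indices[i];
}

int main() {
  const int warp_size = 64, indices_per_block = 4;         // assumed values
  const int64_t num_indices = 1 << 20, sliceSize = 4096;   // assumed values

  dim3 block(warp_size, indices_per_block);
  dim3 new_grid_many_indices(
      (num_indices + warp_size * indices_per_block - 1) / (warp_size * indices_per_block),
      std::min<int64_t>(65535, (sliceSize + warp_size - 1) / warp_size),
      1);
  size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);

  int64_t* d_sorted = nullptr;
  cudaMalloc((void**)&d_sorted, num_indices * sizeof(int64_t));
  // Same launch form as the patched call site: grid, block, dynamic smem, stream.
  toy_many_indices_kernel<<<new_grid_many_indices, block, smem_dups_size, /*stream=*/0>>>(
      d_sorted, num_indices);
  cudaDeviceSynchronize();
  cudaFree(d_sorted);
  return 0;
}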