[ROCm] Improve reduction sum performance (#160466)

* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128

**Reproducer:**
```
import time
import torch

shapes = [
    (5079670, 128)
]

dims = [
    (1)
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    start_time = time.time()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    end_time = time.time()
    mean_time = (end_time - start_time)/50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```

**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us

**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160466
Approved by: https://github.com/petrex, https://github.com/jeffdaily
This commit is contained in:
Jerry Mannil 2025-08-13 18:46:54 +00:00 committed by PyTorch MergeBot
parent db0b7f1cc9
commit 70ccdec44b

View File

@ -1062,7 +1062,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
// In such case, values in each loaded vector always correspond to different outputs.
if (fastest_moving_stride == sizeof(scalar_t)) {
#ifdef USE_ROCM
if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1) {
if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) {
#else
if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) {
#endif