mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] Improve reduction sum performance (#160466)
* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 128
**Reproducer:**
```
# Benchmark torch.sum reduction latency for each (shape, dim) pair.
# NOTE(review): indentation was flattened by the page extraction; restored here
# so the script actually runs.
import time

import torch

# Shapes and their corresponding reduction dims; the two lists are indexed
# in lockstep by `i` below.
shapes = [
    (5079670, 128),
]
dims = [
    1,  # originally written as (1), which is just the int 1, not a tuple
]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.float)

    # Warm-up: excludes kernel launch/compilation overhead from the timing.
    for _ in range(10):
        w = torch.sum(x, dims[i])
    torch.cuda.synchronize()
    print(w.size())

    # perf_counter() is monotonic and higher-resolution than time.time(),
    # so it is the appropriate clock for benchmarking.
    start_time = time.perf_counter()
    for _ in range(50):
        _ = torch.sum(x, dims[i])
    # CUDA kernels launch asynchronously; synchronize before reading the clock.
    torch.cuda.synchronize()
    end_time = time.perf_counter()

    mean_time = (end_time - start_time) / 50
    print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```
**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us
**After (MI300X):**
Avg time for shape (5079670, 128): 1008.59 us
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160466
Approved by: https://github.com/petrex, https://github.com/jeffdaily
This commit is contained in:
parent
db0b7f1cc9
commit
70ccdec44b
|
|
@ -1062,7 +1062,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
|
||||||
// In such case, values in each loaded vector always correspond to different outputs.
|
// In such case, values in each loaded vector always correspond to different outputs.
|
||||||
if (fastest_moving_stride == sizeof(scalar_t)) {
|
if (fastest_moving_stride == sizeof(scalar_t)) {
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1) {
|
if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) {
|
||||||
#else
|
#else
|
||||||
if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) {
|
if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) {
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user