mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
parallelize sort (#142391)
- use __gnu_parallel::sort for gcc compilations
- add a parallelized version of std::sort and std::stable_sort for non gcc compilations
Using __gnu_parallel::sort:
provides ~3.7x speed up for length 50000 sorts with NUM_THREADS=16 and NUM_THREADS=4 on aarch64
The performance is measured using the following script:
```python
import torch
import torch.autograd.profiler as profiler
torch.manual_seed(0)
N = 50000
x = torch.randn(N, dtype=torch.float)
with profiler.profile(with_stack=True, profile_memory=False, record_shapes=True) as prof:
for i in range(1000):
_, _ = torch.sort(x)
print(prof.key_averages(group_by_input_shape=True).table(sort_by='self_cpu_time_total', row_limit=10))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142391
Approved by: https://github.com/malfet
This commit is contained in:
parent
7725d0ba12
commit
49082f9dba
|
|
@ -389,6 +389,12 @@ if(INTERN_BUILD_ATEN_OPS)
|
|||
else(MSVC)
|
||||
set(EXTRA_FLAGS "-DCPU_CAPABILITY=${CPU_CAPABILITY} -DCPU_CAPABILITY_${CPU_CAPABILITY}")
|
||||
endif(MSVC)
|
||||
|
||||
# Only parallelize the SortingKernel for now to avoid side effects
|
||||
if(${NAME} STREQUAL "native/cpu/SortingKernel.cpp" AND NOT MSVC AND USE_OMP)
|
||||
string(APPEND EXTRA_FLAGS " -D_GLIBCXX_PARALLEL")
|
||||
endif()
|
||||
|
||||
# Disable certain warnings for GCC-9.X
|
||||
if(CMAKE_COMPILER_IS_GNUCXX)
|
||||
if(("${NAME}" STREQUAL "native/cpu/GridSamplerKernel.cpp") AND ("${CPU_CAPABILITY}" STREQUAL "DEFAULT"))
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user