mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Calling C++ from Python via ctypes is notoriously slow. This switches to generating our own C++ bindings directly, which is a >5x speedup on this kernel-launch-bound microbenchmark:
```python
from ctypes import c_void_p
import torch
from torch import empty
from torch._inductor.codecache import AsyncCompile
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
from torch._inductor.wrapper_benchmark import compiled_module_main
# Compile helper used to build two variants of the same C++ kernel below.
async_compile = AsyncCompile()
# Kernel source: reads in_ptr0[0], adds 1.0, writes the result to out_ptr0[0].
src = '''
#include "/tmp/torchinductor_jansel/gb/cgbau5vlj6cetmcjbjbtw6x4rrivaln6f45s5d72gy2bfx5foz3k.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
auto tmp0 = in_ptr0[static_cast<long>(0L)];
auto tmp1 = static_cast<float>(1.0);
auto tmp2 = decltype(tmp0)(tmp0 + tmp1);
out_ptr0[static_cast<long>(0L)] = tmp2;
}
}
'''
# Old path: kernel invoked through ctypes; the caller passes c_void_p data
# pointers to it.
cpp_fused_add_ctypes = async_compile.cpp(src)
# New path: generated binding. The list gives the C parameter types and must
# match the kernel signature (const float* in_ptr0, float* out_ptr0).
cpp_fused_add_cpython = async_compile.cpp_pybinding(["const float*", "float*"], src)
# NOTE(review): presumably blocks until the asynchronous compiles finish and
# binds the resulting callables into module globals; confirm against
# AsyncCompile.wait.
async_compile.wait(globals())
del async_compile
def call(arg0_1):
    """Invoke the add kernel 100 times on ``arg0_1`` and return ``(buf,)``.

    The module-level flag ``use_ctypes`` chooses between the ctypes-backed
    kernel and the newly generated binding, so both variants run through the
    same harness.
    """
    out = empty((1,), device='cpu', dtype=torch.float32)
    if not use_ctypes:
        # New bindings take the tensors themselves.
        for _ in range(100):
            cpp_fused_add_cpython(arg0_1, out)
    else:
        # ctypes path: the raw data pointers are re-wrapped on every call so
        # the per-call marshalling cost is part of what gets measured.
        for _ in range(100):
            cpp_fused_add_ctypes(c_void_p(arg0_1.data_ptr()), c_void_p(out.data_ptr()))
    del arg0_1
    return (out,)
def benchmark_compiled_module(times=1000, repeat=100):
    """Time ``call`` on a fresh one-element float32 CPU input and print it.

    ``times`` and ``repeat`` are forwarded unchanged to ``print_performance``,
    whose return value is passed back to the caller.
    """
    inp = rand_strided((1,), (1,), device='cpu', dtype=torch.float32)

    def run_once():
        return call(inp)

    return print_performance(run_once, times=times, repeat=repeat)
# Run the identical benchmark twice, flipping the global `use_ctypes` flag
# (read by `call`) between runs. Each print uses end='' so the label prefixes
# the timing line that compiled_module_main emits (see "Output" below).
print("old ctypes bindings: ", end='')
use_ctypes = True
compiled_module_main('None', benchmark_compiled_module)
print("new bindings: ", end='')
use_ctypes = False
compiled_module_main('None', benchmark_compiled_module)
```
Output:
```
old ctypes bindings: 0.000073
new bindings: 0.000013
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117500
Approved by: https://github.com/desertfire
ghstack dependencies: #117409, #116667, #117591
31 lines · 510 B · Python
import time
|
|
import timeit
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
|
def add1(x):
    """Return ``x + 1`` — the tiny workload whose call overhead is benchmarked."""
    incremented = x + 1
    return incremented
|
|
|
|
|
|
def bench(name, fn, *, number=1000, repeat=100, warmup=3):
    """Time ``fn`` on a one-element random tensor and print a summary line.

    ``fn`` is first called ``warmup`` times and the wall-clock time of those
    calls is reported separately. One-off costs, such as a compiled
    function's first-call compilation, therefore do not pollute the timed
    trials. Then ``timeit.repeat`` runs ``repeat`` trials of ``number`` calls
    each, and the median trial is reported as microseconds per call.

    Args:
        name: label printed at the start of the line.
        fn: callable taking a single tensor argument.
        number: calls per timing trial (default 1000, as before).
        repeat: number of timing trials (default 100, as before).
        warmup: warmup calls made before timing (default 3, as before).

    Raises:
        ValueError: if ``number`` or ``repeat`` is less than 1 (the per-call
            division and the median would be undefined).
    """
    if number < 1 or repeat < 1:
        raise ValueError(
            f"number and repeat must be >= 1, got number={number}, repeat={repeat}"
        )
    x = torch.randn(1)

    start = time.perf_counter()
    for _ in range(warmup):
        fn(x)
    warmup_s = time.perf_counter() - start

    trials = timeit.repeat(lambda: fn(x), number=number, repeat=repeat)
    # Each trial is the total seconds for `number` calls; convert the median
    # to microseconds per call. With number=1000 this equals the old `*1000`.
    per_call_us = np.median(trials) / number * 1e6
    print(f"{name} {per_call_us:.1f}us (warmup={warmup_s:.1f}s)")
|
|
|
|
|
|
def main():
    """Benchmark ``add1`` run eagerly, then wrapped with ``torch.compile``."""
    # Baseline: the plain Python function.
    bench("eager ", add1)
    # The compiled wrapper is only created once the eager run has finished.
    compiled_add1 = torch.compile(add1)
    bench("compiled", compiled_add1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the benchmarks only when executed as a script, not on import.
    main()
|