mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Take this example:
```
def _mul2(x):
    y = torch.empty_like(x)
    mul2_kernel[(10,)](
        in_ptr0=x, out_ptr=y,
        n_elements=x.numel(), BLOCK_SIZE=1,
    )
    return y

def f(x):
    for _ in range(4):
        x = _mul2(x)
    return x + 1
```
Currently, the generated code looks like this. Notice how we allocate 5 buffers of the same size.
```
# Source Nodes: [triton_kernel_wrapper_mutation], Original ATen: []
buf0 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=arg0_1, out_ptr=reinterpret_tensor(buf0, (10, ), (1, ), 0) ...)
# Source Nodes: [triton_kernel_wrapper_mutation_1], Original ATen: []
buf4 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf0, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf4, (10, ), (1, ), 0) ...)
# Source Nodes: [triton_kernel_wrapper_mutation_2], Original ATen: []
buf8 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf4, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf8, (10, ), (1, ), 0) ...)
# Source Nodes: [triton_kernel_wrapper_mutation_3], Original ATen: []
buf12 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf8, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf12, (10, ), (1, ), 0) ...)
# Source Nodes: [add], Original ATen: [aten.add]
buf16 = empty_strided_cuda((10, ), (1, ), torch.float32)
triton_poi_fused_add_0.run(buf12, buf16, 10, grid=grid(10), stream=stream0)
return (buf16, )
```
With this PR, we want to see the following instead. Notice how we allocate only 2 buffers this time; the other 3 allocations are replaced by reusing existing buffers.
```
# Source Nodes: [triton_kernel_wrapper_mutation], Original ATen: []
buf0 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=arg0_1, out_ptr=reinterpret_tensor(buf0, (10, ), (1, ), 0), ...)
del arg0_1
# Source Nodes: [triton_kernel_wrapper_mutation_1], Original ATen: []
buf2 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf0, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf2, (10, ), (1, ), 0) ...)
# Source Nodes: [triton_kernel_wrapper_mutation_2], Original ATen: []
buf4 = buf0; del buf0 # reuse
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf2, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf4, (10, ), (1, ), 0) ...)
# Source Nodes: [triton_kernel_wrapper_mutation_3], Original ATen: []
buf6 = buf2; del buf2 # reuse
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf4, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf6, (10, ), (1, ), 0) ...)
del buf4
# Source Nodes: [add], Original ATen: [aten.add]
buf8 = buf6; del buf6 # reuse
triton_poi_fused_add_0.run(buf8, 10, grid=grid(10), stream=stream0)
return (buf8, )
```
Differential Revision: [D56379307](https://our.internmc.facebook.com/intern/diff/D56379307)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124425
Approved by: https://github.com/oulgen
| Name |
|---|
| .. |
| __init__.py |
| auto_functionalize.py |
| cond.py |
| effects.py |
| map.py |
| out_dtype.py |
| strict_mode.py |
| templated_attention.py |
| torchbind.py |
| triton_kernel_wrap.py |
| utils.py |
| while_loop.py |
| wrap.py |