pytorch/torch/_higher_order_ops
Colin Peppler cbf420b67a [inductor] for UserDefinedTritonKernels don't mark all inputs as mutating (#124425)
Take this example:
```
def _mul2(x):
    y = torch.empty_like(x)
    mul2_kernel[(10,)](
        in_ptr0=x, out_ptr=y,
        n_elements=x.numel(), BLOCK_SIZE=1,
    )
    return y

def f(x):
    for _ in range(4):
        x = _mul2(x)
    return x + 1
```
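The body of `mul2_kernel` isn't shown in the PR description. For completeness, a minimal Triton kernel matching the calling convention above could look like the sketch below; the kernel body is an assumption, only the parameter names `in_ptr0`, `out_ptr`, `n_elements`, and `BLOCK_SIZE` come from the example.
```
import triton
import triton.language as tl

# Hypothetical definition of the mul2_kernel referenced above: an elementwise
# multiply-by-2 where out_ptr is the only argument the kernel writes to.
@triton.jit
def mul2_kernel(in_ptr0, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)
```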

Currently, the generated code looks like this. Notice how we allocate 5 buffers of the same size.
```
# Source Nodes: [triton_kernel_wrapper_mutation], Original ATen: []
buf0 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=arg0_1, out_ptr=reinterpret_tensor(buf0, (10, ), (1, ), 0) ...)

# Source Nodes: [triton_kernel_wrapper_mutation_1], Original ATen: []
buf4 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf0, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf4, (10, ), (1, ), 0) ...)

# Source Nodes: [triton_kernel_wrapper_mutation_2], Original ATen: []
buf8 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf4, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf8, (10, ), (1, ), 0) ...)

# Source Nodes: [triton_kernel_wrapper_mutation_3], Original ATen: []
buf12 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf8, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf12, (10, ), (1, ), 0) ...)

# Source Nodes: [add], Original ATen: [aten.add]
buf16 = empty_strided_cuda((10, ), (1, ), torch.float32)
triton_poi_fused_add_0.run(buf12, buf16, 10, grid=grid(10), stream=stream0)
return (buf16, )
```

With this PR, we want to see this instead. Notice how we only allocate 2 buffers this time; the other 3 are reused. Because the kernel's inputs are no longer all marked as mutated, Inductor can tell that each input buffer is dead after the kernel call and reuse its allocation for a later output.
```
# Source Nodes: [triton_kernel_wrapper_mutation], Original ATen: []
buf0 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=arg0_1, out_ptr=reinterpret_tensor(buf0, (10, ), (1, ), 0), ...)
del arg0_1

# Source Nodes: [triton_kernel_wrapper_mutation_1], Original ATen: []
buf2 = empty_strided_cuda((10, ), (1, ), torch.float32)
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf0, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf2, (10, ), (1, ), 0) ...)

# Source Nodes: [triton_kernel_wrapper_mutation_2], Original ATen: []
buf4 = buf0; del buf0  # reuse
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf2, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf4, (10, ), (1, ), 0) ...)

# Source Nodes: [triton_kernel_wrapper_mutation_3], Original ATen: []
buf6 = buf2; del buf2  # reuse
mul2_kernel_0.run(in_ptr0=reinterpret_tensor(buf4, (10, ), (1, ), 0), out_ptr=reinterpret_tensor(buf6, (10, ), (1, ), 0) ...)
del buf4

# Source Nodes: [add], Original ATen: [aten.add]
buf8 = buf6; del buf6  # reuse
triton_poi_fused_add_0.run(buf8, 10, grid=grid(10), stream=stream0)
return (buf8, )
```
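
As a rough way to reproduce the generated wrapper code above (assuming a CUDA device and the hypothetical `mul2_kernel` sketch from earlier; exact buffer names vary across builds), one can compile `f` and dump Inductor's output code:
```
import torch

# Run with TORCH_LOGS="output_code" to print the generated wrapper code and
# count how many buffers are allocated vs. reused.
compiled_f = torch.compile(f)
x = torch.ones(10, device="cuda")
print(compiled_f(x))  # four doublings plus one: expect every element == 17.0
```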

Differential Revision: [D56379307](https://our.internmc.facebook.com/intern/diff/D56379307)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124425
Approved by: https://github.com/oulgen
2024-04-21 06:00:14 +00:00
| Name | Last commit message | Last commit date |
| --- | --- | --- |
| __init__.py | ScoreMod API (#121845) | 2024-04-06 01:10:44 +00:00 |
| auto_functionalize.py | Rename impl_abstract to register_fake, part 1/2 (#123937) | 2024-04-17 12:46:01 +00:00 |
| cond.py | Don't record autograd state ops while torch.compile in pre-dispatch export (#121736) | 2024-03-14 23:06:10 +00:00 |
| effects.py | [effects] Add way to register effectful op (#122348) | 2024-04-09 03:22:32 +00:00 |
| map.py | Clean up mode handling in python dispatcher (#121083) | 2024-03-08 00:30:34 +00:00 |
| out_dtype.py | Support higher order op functionalization in predispatch IR (#115314) | 2024-03-01 09:13:47 +00:00 |
| strict_mode.py | [torch.export] Support is_compiling() flag for non-strict mode (#119602) | 2024-02-29 05:52:51 +00:00 |
| templated_attention.py | Adds LSE output for templated-attention-hop if inputs require grad (#124308) | 2024-04-20 05:45:56 +00:00 |
| torchbind.py | Add torch._library.register_fake_class to fakify torchBind class (#122622) | 2024-04-02 23:52:17 +00:00 |
| triton_kernel_wrap.py | [inductor] for UserDefinedTritonKernels don't mark all inputs as mutating (#124425) | 2024-04-21 06:00:14 +00:00 |
| utils.py | Support higher order op functionalization in predispatch IR (#115314) | 2024-03-01 09:13:47 +00:00 |
| while_loop.py | [while_loop] support closures (#123018) | 2024-04-03 19:35:15 +00:00 |
| wrap.py | [export] add replace_set_grad_with_hop_pass (#119810) | 2024-02-17 02:18:19 +00:00 |