Commit Graph

576 Commits

Author SHA1 Message Date
Oguz Ulgen
d1947a8707 Migrate from lru_cache to cache (#155613)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155613
Approved by: https://github.com/ezyang
ghstack dependencies: #155612
2025-06-11 19:44:18 +00:00
PyTorch MergeBot
95448b2ce6 Revert "[Inductor] Improve typing, and prepare for ABI-compatible AOTI C-shim dispatching (#154371)"
This reverts commit 65b1aedd09.

Reverted https://github.com/pytorch/pytorch/pull/154371 on behalf of https://github.com/clee2000 due to see henry's comment above.  This was reverted internally because it causes a memory leak and OOMs on AMD? ([comment](https://github.com/pytorch/pytorch/pull/154371#issuecomment-2954192879))
2025-06-08 17:37:29 +00:00
PyTorch MergeBot
7e4c097b07 Revert "[inductor] Add typing to _inductor/ir.py (#149958)"
This reverts commit 529e0357c6.

Reverted https://github.com/pytorch/pytorch/pull/149958 on behalf of https://github.com/malfet due to Looks like it broke inductor_torchbind tests, due to more graphbreaks, see b0fbbef136/1 ([comment](https://github.com/pytorch/pytorch/pull/149958#issuecomment-2949583209))
2025-06-06 15:19:16 +00:00
Tom Ritchford
529e0357c6 [inductor] Add typing to _inductor/ir.py (#149958)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149958
Approved by: https://github.com/Skylion007
2025-06-06 14:15:01 +00:00
eellison
0827464002 Replace runtime type parameterization (#155221)
See:

```
>>> import timeit; print(f"OrderedSet[str](): {timeit.timeit('OrderedSet[str]()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s, OrderedSet(): {timeit.timeit('OrderedSet()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s")
```
> `OrderedSet[str]()`: 0.354622s, OrderedSet(): 0.095376s

Type parameterization should be on type hint, not in runtime.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155221
Approved by: https://github.com/Skylion007, https://github.com/jansel
2025-06-05 21:43:54 +00:00
Boyuan Feng
be16f21ca6 [Graph Partition] add symints to get_graph_inputs (#154679)
During `codegen_inputs`, we check whether there are undefined symbols:
65b1aedd09/torch/_inductor/codegen/wrapper.py (L1668-L1674)

Previously, for graph partition inputs, we do not explicitly add symints.
65b1aedd09/torch/_inductor/codegen/wrapper.py (L3265-L3272)
We relied on sizes/strides of TensorBox for codegen symint inputs.  For example, a tensor with shape `[s0, 2]` will implicitly codegen `s0` as an input here. This works fine in most cases since backed symint has to come from some tensor shapes.
65b1aedd09/torch/_inductor/codegen/wrapper.py (L1624-L1632)

In rare cases, this does not work. One example is saved tensors for backward where a tensor may have shape `[2*s0, 2]`. Since `2*s0` is an expression but not a symbol, `codegen_input_symbol_assignment` would not handle `s0` and later there would be an error when `_verify_input_symbol_assignment`.

The fix is add symints to `get_graph_inputs`. An alternative way is to update `codegen_input_symbol_assignment` but I want to minimize the change to graph partition only.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154679
Approved by: https://github.com/eellison
2025-06-05 06:46:28 +00:00
Boyuan Feng
a4da1d4a47 [Graph Partition] support standalone_compile (#154698)
For graph partition, `write_get_raw_stream_header_once` is done once so the autotune code may not have the header. This PR additionally calls `write_get_raw_stream_header` in `codegen_device_guard_enter` before `get_raw_stream` is used.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154698
Approved by: https://github.com/oulgen
2025-06-03 07:40:42 +00:00
Paul Zhang
22a4cabd19 [Inductor] Add NaN assert to returned values from generated code (#154455)
Summary: It is possible to have `reinterpret_tensor` in the output of inductor codegen, e.g. `reinterpret_tensor(buf366, (1024, ), (1, ), 0)` in the return tuple. This adds assertions to all return values from inductor codegen to prevent nans from slipping through and being hard to trace.

Test Plan:
NaN asserts properly generated in example gemm script:

    vars = (buf1, primals_2, buf2, primals_1, )
    for var in vars:
        if isinstance(var, torch.Tensor):
            assert not var.isnan().any().item()
            assert not var.isinf().any().item()

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154455
Approved by: https://github.com/eellison
2025-05-30 20:32:56 +00:00
PyTorch MergeBot
fb67fa9968 Revert "[Inductor] Add NaN assert to returned values from generated code (#154455)"
This reverts commit aec3ef1008.

Reverted https://github.com/pytorch/pytorch/pull/154455 on behalf of https://github.com/malfet due to Looks like it broke inductor/test_compile_subprocess.py::CpuTests::test_AllenaiLongformerBase, see 35fc5c49b4/1(default%2C%20&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/154455#issuecomment-2923154249))
2025-05-30 18:45:01 +00:00
Paul Zhang
aec3ef1008 [Inductor] Add NaN assert to returned values from generated code (#154455)
Summary: It is possible to have `reinterpret_tensor` in the output of inductor codegen, e.g. `reinterpret_tensor(buf366, (1024, ), (1, ), 0)` in the return tuple. This adds assertions to all return values from inductor codegen to prevent nans from slipping through and being hard to trace.

Test Plan:
NaN asserts properly generated in example gemm script:

    vars = (buf1, primals_2, buf2, primals_1, )
    for var in vars:
        if isinstance(var, torch.Tensor):
            assert not var.isnan().any().item()
            assert not var.isinf().any().item()

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154455
Approved by: https://github.com/eellison
2025-05-30 08:53:24 +00:00
PyTorch MergeBot
639f459cb6 Revert "[Inductor] Add NaN assert to returned values from generated code (#154455)"
This reverts commit c3de2c7c6b.

Reverted https://github.com/pytorch/pytorch/pull/154455 on behalf of https://github.com/huydhn due to Sorry for reverting your change, I am trying to see if it help fix the broken trunk below.  It it does not help, I will reland the PR ([comment](https://github.com/pytorch/pytorch/pull/154455#issuecomment-2921562089))
2025-05-30 08:11:22 +00:00
Paul Zhang
c3de2c7c6b [Inductor] Add NaN assert to returned values from generated code (#154455)
Summary: It is possible to have `reinterpret_tensor` in the output of inductor codegen, e.g. `reinterpret_tensor(buf366, (1024, ), (1, ), 0)` in the return tuple. This adds assertions to all return values from inductor codegen to prevent nans from slipping through and being hard to trace.

Test Plan:
NaN asserts properly generated in example gemm script:

    vars = (buf1, primals_2, buf2, primals_1, )
    for var in vars:
        if isinstance(var, torch.Tensor):
            assert not var.isnan().any().item()
            assert not var.isinf().any().item()

Differential Revision: D74691131

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154455
Approved by: https://github.com/eellison
2025-05-30 03:09:37 +00:00
Benjamin Glass
65b1aedd09 [Inductor] Improve typing, and prepare for ABI-compatible AOTI C-shim dispatching (#154371)
Prepares for the next PR in the stack by tightening up typing on a `cpp_wrapper` interface that's only used in one (well-typed) place, as well as downstream effects of that change. In particular, this enabled:

1. removing a number of now clearly unnecessary asserts
2. adding a few more targeted asserts to validate the code's current assumptions
3. removing some unneeded control flow in several functions

As far as I can tell, this PR should be functionally neutral. One argument was removed from a `cpp_wrapper` public API, but that argument was unused, and only had a single callsite.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154371
Approved by: https://github.com/desertfire
2025-05-28 23:25:17 +00:00
PyTorch MergeBot
555fc05868 Revert "[Inductor] Improve typing, and prepare for ABI-compatible AOTI C-shim dispatching (#154371)"
This reverts commit 6169ca0b65.

Reverted https://github.com/pytorch/pytorch/pull/154371 on behalf of https://github.com/benjaminglass1 due to Appears to have broken main ([comment](https://github.com/pytorch/pytorch/pull/154371#issuecomment-2913975736))
2025-05-27 20:39:09 +00:00
Benjamin Glass
6169ca0b65 [Inductor] Improve typing, and prepare for ABI-compatible AOTI C-shim dispatching (#154371)
Prepares for the next PR in the stack by tightening up typing on a `cpp_wrapper` interface that's only used in one (well-typed) place, as well as downstream effects of that change. In particular, this enabled:

1. removing a number of now clearly unnecessary asserts
2. adding a few more targeted asserts to validate the code's current assumptions
3. removing some unneeded control flow in several functions

As far as I can tell, this PR should be functionally neutral. One argument was removed from a `cpp_wrapper` public API, but that argument was unused, and only had a single callsite.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154371
Approved by: https://github.com/desertfire
2025-05-27 19:17:41 +00:00
Boyuan Feng
669b176d4c [Graph Partition] support removed arguments, NoneLayout, and mutation (#153899)
Graph partition relies on `read_writes` to collect partition inputs and outputs. There are three edge cases:

1. `NoneLayout` is not allocated so it cannot become a partition input or output.
2. Codegen may decide a buffer to be internal to a kernel (e.g., triton kernel). One example is some buffers internal to a FusedSchedulerNode. These buffers are never actually allocated as `buf_id`.
3. We should use mutation_real_name for graph partition inputs and outputs to match the behavior of other codegen.

This PR supports these 3 cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153899
Approved by: https://github.com/eellison
2025-05-22 04:24:31 +00:00
Colin Peppler
81b6920c68 [aoti] skip input symbol codegen for sympy expr w/ many symbols (#152579)
Issue was that
- symbol-ids appeared out-of-order w.r.t to the order of the forward inputs
```
def forward(arg0 # [(s3 - 1) + s4, 32], arg1 #[(s3 - 1)] ..)
```
- this causes codegen to fail because it expects all the base symbols `s4,s3` to have been codegen-ed already.
- well, we can skip codegen-ing sympy expr with many symbols e.g. `(s3 - 1) + s4` because `s3` and `s4` will be codegen-ed by other inputs.

```
# for example
s3 = arg1.size(0) + 1
s4 = argN.size(0)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152579
Approved by: https://github.com/jingsh, https://github.com/desertfire
2025-05-07 01:18:09 +00:00
Blaine Burton Rister
bc11afd41f [Inductor] FX backend via Wrapper IR (#146942)
# Sub-PRs

These PRs contain refactors from the main one. They should be reviewed and merged first.

- https://github.com/pytorch/pytorch/pull/150458
- https://github.com/pytorch/pytorch/pull/152391
- https://github.com/pytorch/pytorch/pull/152587

# Feature

The goals of this PR are twofold.

## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.

In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components.

This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.

## Goal 2: Convert Wrapper IR into FX IR.

One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc.

It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes.

The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source.

Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.

# Current status

Things that seem to work:

- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
     - Handled the following cases, in addition to what was already in the Memory Planning IR:
         - Comments
         - Triton kernels
         - Extern/fallback kernels
         - Freeing tensors (`del buf0`)
         - MultiOutput
         - Graph outputs
         - ReinterpretView / StorageBox, for both call args and outputs.
     - FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases.
   - Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
   - Calling wrapped Triton kernels.
   - Calling extern kernels and certain types of fallback kernels.
       - Support both `extern_kernels.*` and `aten.*`.
       - Support multi-output kernels like `torch.topk`.
   - Graphs with multiple inputs/outputs.
   - Training i.e. calling `Tensor.backward()` in a compiled function.
   - Graph breaks (training).
- Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.

Things that don't work:
- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections.
         - Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
         - Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
         - CommBuffer's and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this.

# Out-of-tree compilers

With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below.

```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

class MyCustomBackend(WrapperFxCodegen):
     def compile_graph(self, gm):
         # Add 1 to the graph's outputs
         def compiled_fn(*args):
             return [x + 1 for x in gm.graph.forward(*args)]
         return compiled_fn
```

# Example FX graphs

This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.

Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    return (buf0,)
```

Here's a more complicated graph that calls a `torch.addmm` extern kernel.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
    %buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    return (buf2,)
```

Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
    %buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
    %triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
    return (buf1, buf2)
```

Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    %buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
    return (buf0_view_buf0_0,)
```

Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`.
```
graph():
    %s6 : [num_users=0] = placeholder[target=s6]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
    return buf0
```

Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.
```
graph():
    %s10 : [num_users=0] = placeholder[target=s10]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
    return buf0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942
Approved by: https://github.com/jansel
2025-05-06 10:06:39 +00:00
PyTorch MergeBot
99dac7005f Revert "[Inductor] FX backend via Wrapper IR (#146942)"
This reverts commit a7691140a0.

Reverted https://github.com/pytorch/pytorch/pull/146942 on behalf of https://github.com/malfet due to Looks like it indeed breaks lint, see a7691140a0/1 ([comment](https://github.com/pytorch/pytorch/pull/146942#issuecomment-2852192778))
2025-05-05 20:01:29 +00:00
Blaine Burton Rister
a7691140a0 [Inductor] FX backend via Wrapper IR (#146942)
# Sub-PRs

These PRs contain refactors from the main one. They should be reviewed and merged first.

- https://github.com/pytorch/pytorch/pull/150458
- https://github.com/pytorch/pytorch/pull/152391
- https://github.com/pytorch/pytorch/pull/152587

# Feature

The goals of this PR are twofold.

## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen.

In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components.

This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR.

## Goal 2: Convert Wrapper IR into FX IR.

One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc.

It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes.

The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source.

Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa.

# Current status

Things that seem to work:

- Converted a lot of the most common Python codegen lines to Wrapper IR lines.
     - Handled the following cases, in addition to what was already in the Memory Planning IR:
         - Comments
         - Triton kernels
         - Extern/fallback kernels
         - Freeing tensors (`del buf0`)
         - MultiOutput
         - Graph outputs
         - ReinterpretView / StorageBox, for both call args and outputs.
     - FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code.
- Prototype FX converter which can handle some of the most common use cases.
   - Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565).
   - Calling wrapped Triton kernels.
   - Calling extern kernels and certain types of fallback kernels.
       - Support both `extern_kernels.*` and `aten.*`.
       - Support multi-output kernels like `torch.topk`.
   - Graphs with multiple inputs/outputs.
   - Training i.e. calling `Tensor.backward()` in a compiled function.
   - Graph breaks (training).
- Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen.

Things that don't work:
- Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections.
         - Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX.
         - Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR.
         - CommBuffer's and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this.

# Out-of-tree compilers

With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below.

```
from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen

class MyCustomBackend(WrapperFxCodegen):
     def compile_graph(self, gm):
         # Add 1 to the graph's outputs
         def compiled_fn(*args):
             return [x + 1 for x in gm.graph.forward(*args)]
         return compiled_fn
```

# Example FX graphs

This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`.

Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table.
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    return (buf0,)
```

Here's a more complicated graph that calls a `torch.addmm` extern kernel.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
    %buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}})
    %buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    return (buf2,)
```

Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {})
    %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {})
    %buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {})
    %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}})
    %triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}})
    return (buf1, buf2)
```

Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op.

```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}})
    %buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {})
    return (buf0_view_buf0_0,)
```

Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`.
```
graph():
    %s6 : [num_users=0] = placeholder[target=s6]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}})
    return buf0
```

Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions.
```
graph():
    %s10 : [num_users=0] = placeholder[target=s10]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0})
    %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}})
    return buf0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942
Approved by: https://github.com/jansel
2025-05-05 19:34:49 +00:00
PaulZhang12
84aa0985fb [Inductor] Add decomposeK as an autotuning choice for mm (#150654)
As a result of adding subgraph as a choice to inductor https://github.com/pytorch/pytorch/pull/149761 and enabling FP32 output from PyTorch GEMMs from FP16/BF16 inputs: https://github.com/pytorch/pytorch/pull/150812, this PR enables decompose_k as an autotuning choice for Inductor in generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.

Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton Templates as well without requiring tons of more compile time, async compilation. Anecdotal evidence shows that Triton BMM performs better usually than aten BMM
* Add for addmm
* Enable for Inference and AOTI

Below are the results of running TritonBench for Split-K shapes, comparing the aten performance versus pt2_triton, which now autotunes on decompose_k, seeing >10% speedup compared to aten on average, and for some shapes over 3x the performance of the best Triton mm previously:

<img width="929" alt="Screenshot 2025-04-28 at 9 15 39 PM" src="https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999" />

TorchInductor Benchmark Dashboard:
<img width="1727" alt="Screenshot 2025-04-30 at 2 02 53 PM" src="https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3" />

We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.

Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150654
Approved by: https://github.com/eellison
2025-05-03 02:23:54 +00:00
Animesh Jain
73b6b1ded4 [inductor][invoke_subgraph] Free the buffers before the subgraph call (#152494)
Before
![image](https://github.com/user-attachments/assets/62b24c14-69e6-40fb-94e3-223930132ef6)

After
![image](https://github.com/user-attachments/assets/9f340d4e-80a9-45aa-9400-626fff5b5ecd)

tlparse - https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmph5dwWt/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152494
Approved by: https://github.com/Skylion007, https://github.com/eellison
2025-05-03 00:38:08 +00:00
Blaine Burton Rister
add4702ebc [Inductor] Introduce Wrapper IR line for symbolic call args (#152587)
Preparatory refactor for https://github.com/pytorch/pytorch/pull/146942.

This PR introduces a new wrapper IR line to represent symbolic call args. This deletes a little bit of duplicated code between the Python and C++ backends. In the main PR, having a Wrapper IR line for this also tells the FX backend what this part of the wrapper code is doing. Before this PR, symbolic call args generated raw Python lines, which confuse the FX converter.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152587
Approved by: https://github.com/jansel
2025-05-02 20:37:00 +00:00
Animesh Jain
ea4b7e0e1d [invoke_subgraph] Simplify output code for subgraph output node (#152490)
Before - [manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmppQg3F8/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmppQg3F8/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000)
![image](https://github.com/user-attachments/assets/8fecdc23-eb78-4e15-9d03-c4bae4b49434)

After fix - https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp9a5EM0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
![image](https://github.com/user-attachments/assets/8e98120c-d82e-42dc-bc50-a6bfd4f9923c)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152490
Approved by: https://github.com/eellison
ghstack dependencies: #152383
2025-05-02 16:31:25 +00:00
Animesh Jain
f6761f2968 [inductor][subgraph] Simplify the resulting output code for subgraph (#152383)
Check out output code

Before this PR -  - https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp3iXDVs/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
![image](https://github.com/user-attachments/assets/ef86eb8f-e8b9-47dd-8609-f90481f018b8)

After this PR - https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpRgUJvq/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

![image](https://github.com/user-attachments/assets/10e22c60-7fb9-4519-9d54-019beff5333b)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152383
Approved by: https://github.com/eellison
2025-05-02 15:52:34 +00:00
PyTorch MergeBot
7c3e679ddd Revert "[Inductor] Add decomposeK as an autotuning choice for mm (#150654)"
This reverts commit fdcfc6a61a.

Reverted https://github.com/pytorch/pytorch/pull/150654 on behalf of https://github.com/wdvr due to Failing ROCM tests: inductor/test_subgraph_choice.py::TestSubgraphChoice::test_subgraph_decompose_k [GH job link](https://github.com/pytorch/pytorch/actions/runs/14786111108/job/41515742446) [HUD commit link](3c54e0c216) ([comment](https://github.com/pytorch/pytorch/pull/150654#issuecomment-2846470409))
2025-05-02 06:31:38 +00:00
PaulZhang12
fdcfc6a61a [Inductor] Add decomposeK as an autotuning choice for mm (#150654)
As a result of adding subgraph as a choice to inductor https://github.com/pytorch/pytorch/pull/149761 and enabling FP32 output from PyTorch GEMMs from FP16/BF16 inputs: https://github.com/pytorch/pytorch/pull/150812, this PR enables decompose_k as an autotuning choice for Inductor in generating the fastest matmuls with Triton. DecomposeK is currently only enabled for `torch.compile`.

Followups:
* decompose_k does not currently support epilogue fusion, which will take some work to enable
* Enable autotuning the bmm with Triton Templates as well without requiring tons of more compile time, async compilation. Anecdotal evidence shows that Triton BMM performs better usually than aten BMM
* Add for addmm
* Enable for Inference and AOTI

Below are the results of running TritonBench for Split-K shapes, comparing the aten performance versus pt2_triton, which now autotunes on decompose_k, seeing >10% speedup compared to aten on average, and for some shapes over 3x the performance of the best Triton mm previously:

<img width="929" alt="Screenshot 2025-04-28 at 9 15 39 PM" src="https://github.com/user-attachments/assets/27d85bbc-4f3a-43a6-a8fa-d4a5bbb8c999" />

TorchInductor Benchmark Dashboard:
<img width="1727" alt="Screenshot 2025-04-30 at 2 02 53 PM" src="https://github.com/user-attachments/assets/4acd7ffc-407f-4cfd-98bb-2e3d8b1f00b3" />

We see speedups across all runs for training. Compile time increased as expected, with more `mm` options to tune over.

Differential Revision: [D73820115](https://our.internmc.facebook.com/intern/diff/D73820115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150654
Approved by: https://github.com/eellison
2025-05-01 23:01:30 +00:00
PyTorch MergeBot
f7b60456cc Revert "[inductor][subgraph] Simplify the resulting output code for subgraph (#152383)"
This reverts commit 98eb7c8cb1.

Reverted https://github.com/pytorch/pytorch/pull/152383 on behalf of https://github.com/malfet due to Broke CI, see 52cbcac640/1 ([comment](https://github.com/pytorch/pytorch/pull/152384#issuecomment-2845099985))
2025-05-01 15:46:08 +00:00
PyTorch MergeBot
2f1800bc3d Revert "[invoke_subgraph] Simplify output code for subgraph output node (#152490)"
This reverts commit 5fe335810a.

Reverted https://github.com/pytorch/pytorch/pull/152490 on behalf of https://github.com/malfet due to Broke CI, see 52cbcac640/1 ([comment](https://github.com/pytorch/pytorch/pull/152384#issuecomment-2845099985))
2025-05-01 15:46:07 +00:00
PyTorch MergeBot
2fa39e60ed Revert "[inductor][invoke_subgraph] Free the buffers before the subgraph call (#152494)"
This reverts commit 5236a8506c.

Reverted https://github.com/pytorch/pytorch/pull/152494 on behalf of https://github.com/malfet due to Broke CI, see 52cbcac640/1 ([comment](https://github.com/pytorch/pytorch/pull/152384#issuecomment-2845099985))
2025-05-01 15:46:07 +00:00
Blaine Burton Rister
7c63ddd817 [Inductor] Wrapper code refactors to prepare for FX codegen (#152391)
This PR contains some refactors from https://github.com/pytorch/pytorch/pull/146942, which help to enable Wrapper FX codegen:
1. Remove `OutputLine`, which is unused.
2. Add an attribute to the backend classes specifying whether they support caching.
3. Before compiling a graph, query the registered backends and check whether caching is supported.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152391
Approved by: https://github.com/jansel
2025-05-01 09:14:55 +00:00
Animesh Jain
5236a8506c [inductor][invoke_subgraph] Free the buffers before the subgraph call (#152494)
Before
![image](https://github.com/user-attachments/assets/62b24c14-69e6-40fb-94e3-223930132ef6)

After
![image](https://github.com/user-attachments/assets/9f340d4e-80a9-45aa-9400-626fff5b5ecd)

tlparse - https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmph5dwWt/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152494
Approved by: https://github.com/Skylion007, https://github.com/eellison
ghstack dependencies: #152357, #152384, #152383, #152490
2025-05-01 02:04:10 +00:00
Animesh Jain
5fe335810a [invoke_subgraph] Simplify output code for subgraph output node (#152490)
Before - [manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmppQg3F8/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmppQg3F8/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000)
![image](https://github.com/user-attachments/assets/8fecdc23-eb78-4e15-9d03-c4bae4b49434)

After fix - https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp9a5EM0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
![image](https://github.com/user-attachments/assets/8e98120c-d82e-42dc-bc50-a6bfd4f9923c)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152490
Approved by: https://github.com/eellison
ghstack dependencies: #152357, #152384, #152383
2025-05-01 02:04:10 +00:00
Animesh Jain
98eb7c8cb1 [inductor][subgraph] Simplify the resulting output code for subgraph (#152383)
Check out output code

Before this PR -  - https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp3iXDVs/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
![image](https://github.com/user-attachments/assets/ef86eb8f-e8b9-47dd-8609-f90481f018b8)

After this PR - https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpRgUJvq/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000

![image](https://github.com/user-attachments/assets/10e22c60-7fb9-4519-9d54-019beff5333b)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152383
Approved by: https://github.com/eellison
ghstack dependencies: #152357, #152384
2025-05-01 02:04:10 +00:00
Anthony Shoumikhin
e2f9759bd0 Fix broken URLs (#152237)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152237
Approved by: https://github.com/huydhn, https://github.com/malfet
2025-04-27 09:56:42 +00:00
Rachel Guo
c729f7dbee [provenance_tracking][reland] Fix UT error and re-land ExternKernel support (#151709)
Summary:
ATT.

reverted previous diff :  D72572050

Test Plan:
```
 TORCH_LOGS="+inductor, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:provenance_tracing -- -r test_triton_kernel_to_post_grad_tracing_extern_kernel
```

Differential Revision: D73281217

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151709
Approved by: https://github.com/jingsh
2025-04-22 15:44:56 +00:00
Blaine Burton Rister
c0a0761871 [Inductor] Refactor wrapper codegen to use Wrapper IR. (#150458)
Preparatory refactor for https://github.com/pytorch/pytorch/pull/146942.

# Feature

This PR refactors the existing wrapper codegen into `WrapperLine` subclasses, extending the existing Memory Planning IR into a fully-fledged Wrapper IR. See the diagram below.

![wrapper_ir](https://github.com/user-attachments/assets/a61db21b-caf3-45d2-bfdb-91066ae4ba6b)

The IR currently supports the following ops:
- All existing memory planning IR ops (`AllocateLine`, `FreeIfNotReusedLine`, etc.)
- Reinterpret views (`ReinterpretLine`)
- Kernel definitions (`KernelDefinitionLine`)
- Calls to defined kernels (`KernelCallLine`)
- Calls to extern kernels (`ExternKernelLine`, `ExternKernelAllocLine`)
- Ops with multiple outputs (`MultiOutputLine`)
- Tensor cleanup at the end of a graph (`FreeLine`)
- Leaving comments in code (`CommentLine`)

There are two main motivations for this refactor:
1. Unlike free-form C++ and and Python code, Wrapper IR lines provide structured information about what the wrapper code does. This serves as a natural extension point for other types of wrapper codegen. For example, the parent PR generates FX IR from Wrapper IR. Wrapper IR aims to give new backends enough information to generate wrapper code without needing to modify core Inductor files such as `ir.py`.
2. This design will hopefully promote stronger modularity and encapsulation.
   a. Inductor's core compilation passes don't need to worry about whether they're targeting Python, C++, FX or anything else. They can simply focus on generating Wrapper IR, and target-specific code can be refactored into the various backends.
   b. Backends do not need to know about all the details and internal state of `V.graph` IR. For example, they don't need to consider whether a buffer has been removed from the graph when generating code. Wrapper IR will hopefully provide a simpler interface for generating wrapper code, which abstracts away the details of device code.

# Implementation details

The implementation mainly consists of separating direct C++/Python codegen into two phases:
 1. Emit Wrapper IR lines describing what the wrapper code is supposed to do.
 2. Inside the `codegen()` method of each `WrapperLine`, call backend methods which generate pure Python/C++ code using the information stored in the Wrapper IR line. For example, `KernelCallLine` calls `wrapper._generate_kernel_call_helper`, which is overriden by the various Python and C++ backends to generate the final wrapper code.

The main difficulty in implementing this is that we need to be careful that code is generated in the correct order. Wrapper codegen happens in two passes: first we write code into `self.lines` which mainly contains wrapper IR, but can also contain raw Python or C++ lines in some situations. Then, we convert the wrapper IR into the final Python/C++ code in `self.wrapper_call`. Since the same macros may be used in both passes, it's difficult to ensure that code is written to the correct buffer. The easiest solution for this was to implement a context manager overriding the `writeline` method to write to  `self.wrapper_call` after memory planning is finished. This way, `writeline` writes to `self.lines` in the first pass, and `self.wrapper_call` in the second. This obviated the need to pass `code` or `writeline` variables all the way through the call stack, which would have touched most of the existing macros.

# Test plan

Since this refactor touches all the existing wrapper codegen classes, the existing CI provides good coverage.

The parent PR introduces new tests for the FX IR backend. Among other things, these tests assert that `self.lines` only contains Wrapper IR lines, and no free-form code. While this would not be true of all programs today, the tests suggests that the IR implemented in this PR is sufficient to cover basic PyTorch usage.

# Future directions

These two goals are only partially realized by this PR. These are several important steps which still undergo direct Python/C++ codegen in core files:
 - User-defined Triton kernels.
 - Reinterpret views on outputs, from `gen_output_refs()`. (In the parent PR, the FX converter has a custom way of handling this. This can eventually be ported into Wrapper IR.)
 -  Fallback ops with custom `codegen()` methods, e.g. `ScatterFallback`.
 -  Misc. C++ lines emitted by the various cpp backends, e.g. declaring constants.

These cases will gradually be handled in subsequent PRs, as the Inductor->FX converter expands its coverage. Given that these refactors are pretty tricky to do, it seems wiser to execute them in stages, as opposed to porting everything to Wrapper IR at once.Some Python and codegen still lives in core files such as `ir.py`, as described in previous sections. Hopefully, this PR will serve as a starting point which moves the codebase towards a more modular design. Over time, we can gradually refactor the remaining codegen (mainly in `ir.py`) into backend classes.

One limitation of this PR is that codegen still happens in two phases during `PythonWrapperCodegen`. First, we generate Wrapper IR into `self.lines`, and from there we generate Python or C++ code into `self.wrapper_call`, `self.header`, etc. In the long term, it would be cleaner to split wrapper IR into its own class which doesn't deal with Python/C++ codegen at all. (See the diagram at the top.) That would strictly enforce the boundary between Wrapper IR and Python/C++ wrapper code. However, this would probably be a much larger refactor.

Another limitation of the current code is that the helper functions have a lot of call args. It's also possible to clean this up by passing Wrapper IR ops e.g. `KernelCallLine` into helper functions like `_generate_kernel_call_helper`, since they store all the arguments. However, that change would likely be prone to merge conflicts, so I would like to save it for follow-up PRs if possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150458
Approved by: https://github.com/eellison
2025-04-15 17:28:36 +00:00
PyTorch MergeBot
8157e76b79 Revert "[Inductor] Refactor wrapper codegen to use Wrapper IR. (#150458)"
This reverts commit fe7f425de7.

Reverted https://github.com/pytorch/pytorch/pull/150458 on behalf of https://github.com/clee2000 due to broke a lot of tests internally? D72906459 ([comment](https://github.com/pytorch/pytorch/pull/150458#issuecomment-2799578597))
2025-04-13 03:52:42 +00:00
Blaine Burton Rister
fe7f425de7 [Inductor] Refactor wrapper codegen to use Wrapper IR. (#150458)
Preparatory refactor for https://github.com/pytorch/pytorch/pull/146942.

# Feature

This PR refactors the existing wrapper codegen into `WrapperLine` subclasses, extending the existing Memory Planning IR into a fully-fledged Wrapper IR. See the diagram below.

![wrapper_ir](https://github.com/user-attachments/assets/a61db21b-caf3-45d2-bfdb-91066ae4ba6b)

The IR currently supports the following ops:
- All existing memory planning IR ops (`AllocateLine`, `FreeIfNotReusedLine`, etc.)
- Reinterpret views (`ReinterpretLine`)
- Kernel definitions (`KernelDefinitionLine`)
- Calls to defined kernels (`KernelCallLine`)
- Calls to extern kernels (`ExternKernelLine`, `ExternKernelAllocLine`)
- Ops with multiple outputs (`MultiOutputLine`)
- Tensor cleanup at the end of a graph (`FreeLine`)
- Leaving comments in code (`CommentLine`)

There are two main motivations for this refactor:
1. Unlike free-form C++ and and Python code, Wrapper IR lines provide structured information about what the wrapper code does. This serves as a natural extension point for other types of wrapper codegen. For example, the parent PR generates FX IR from Wrapper IR. Wrapper IR aims to give new backends enough information to generate wrapper code without needing to modify core Inductor files such as `ir.py`.
2. This design will hopefully promote stronger modularity and encapsulation.
   a. Inductor's core compilation passes don't need to worry about whether they're targeting Python, C++, FX or anything else. They can simply focus on generating Wrapper IR, and target-specific code can be refactored into the various backends.
   b. Backends do not need to know about all the details and internal state of `V.graph` IR. For example, they don't need to consider whether a buffer has been removed from the graph when generating code. Wrapper IR will hopefully provide a simpler interface for generating wrapper code, which abstracts away the details of device code.

# Implementation details

The implementation mainly consists of separating direct C++/Python codegen into two phases:
 1. Emit Wrapper IR lines describing what the wrapper code is supposed to do.
 2. Inside the `codegen()` method of each `WrapperLine`, call backend methods which generate pure Python/C++ code using the information stored in the Wrapper IR line. For example, `KernelCallLine` calls `wrapper._generate_kernel_call_helper`, which is overriden by the various Python and C++ backends to generate the final wrapper code.

The main difficulty in implementing this is that we need to be careful that code is generated in the correct order. Wrapper codegen happens in two passes: first we write code into `self.lines` which mainly contains wrapper IR, but can also contain raw Python or C++ lines in some situations. Then, we convert the wrapper IR into the final Python/C++ code in `self.wrapper_call`. Since the same macros may be used in both passes, it's difficult to ensure that code is written to the correct buffer. The easiest solution for this was to implement a context manager overriding the `writeline` method to write to  `self.wrapper_call` after memory planning is finished. This way, `writeline` writes to `self.lines` in the first pass, and `self.wrapper_call` in the second. This obviated the need to pass `code` or `writeline` variables all the way through the call stack, which would have touched most of the existing macros.

# Test plan

Since this refactor touches all the existing wrapper codegen classes, the existing CI provides good coverage.

The parent PR introduces new tests for the FX IR backend. Among other things, these tests assert that `self.lines` only contains Wrapper IR lines, and no free-form code. While this would not be true of all programs today, the tests suggests that the IR implemented in this PR is sufficient to cover basic PyTorch usage.

# Future directions

These two goals are only partially realized by this PR. These are several important steps which still undergo direct Python/C++ codegen in core files:
 - User-defined Triton kernels.
 - Reinterpret views on outputs, from `gen_output_refs()`. (In the parent PR, the FX converter has a custom way of handling this. This can eventually be ported into Wrapper IR.)
 -  Fallback ops with custom `codegen()` methods, e.g. `ScatterFallback`.
 -  Misc. C++ lines emitted by the various cpp backends, e.g. declaring constants.

These cases will gradually be handled in subsequent PRs, as the Inductor->FX converter expands its coverage. Given that these refactors are pretty tricky to do, it seems wiser to execute them in stages, as opposed to porting everything to Wrapper IR at once.Some Python and codegen still lives in core files such as `ir.py`, as described in previous sections. Hopefully, this PR will serve as a starting point which moves the codebase towards a more modular design. Over time, we can gradually refactor the remaining codegen (mainly in `ir.py`) into backend classes.

One limitation of this PR is that codegen still happens in two phases during `PythonWrapperCodegen`. First, we generate Wrapper IR into `self.lines`, and from there we generate Python or C++ code into `self.wrapper_call`, `self.header`, etc. In the long term, it would be cleaner to split wrapper IR into its own class which doesn't deal with Python/C++ codegen at all. (See the diagram at the top.) That would strictly enforce the boundary between Wrapper IR and Python/C++ wrapper code. However, this would probably be a much larger refactor.

Another limitation of the current code is that the helper functions have a lot of call args. It's also possible to clean this up by passing Wrapper IR ops e.g. `KernelCallLine` into helper functions like `_generate_kernel_call_helper`, since they store all the arguments. However, that change would likely be prone to merge conflicts, so I would like to save it for follow-up PRs if possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150458
Approved by: https://github.com/eellison
2025-04-12 01:15:19 +00:00
Shunting Zhang
252029b294 [Inductor] assert fallback output alignment (#150804)
Previous PR (https://github.com/pytorch/pytorch/pull/150777) fixes the alignment problem for fallback kernel assuming meta kernel is correct. This PR handles the case that meta kernel is incorrect. Assertion is added if the compiler assumes a  fallback kernel output is aligned.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150804
Approved by: https://github.com/jansel, https://github.com/eellison
ghstack dependencies: #150777
2025-04-10 20:01:06 +00:00
Yuanhao Ji
a6933a1c42 [Inductor] Remove triton dtype patch which has landed (#149611)
As this [pr][0] has already landed, we should remove its patch.

Having [mentioned][1] this before, I am making this change now to avoid omissions.

[0]: https://github.com/triton-lang/triton/pull/3342
[1]: https://github.com/pytorch/pytorch/pull/147583/files#r1970440062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149611
Approved by: https://github.com/eellison
2025-04-10 03:42:55 +00:00
Shunting Zhang
901b02cf16 [Inductor] fix alignement assumption for fallback (#150777)
Inductor right now only works properly for fallback kernels producing aligned output.
When Inductor create layout for fallback kernel output, Inductor does not add the tensor offset to the layout [link](2a1e2b88ed/torch/_inductor/ir.py (L6935-L6941)). Thus unaligned output will be treated as aligned. Adding the offset to the layout directly does not work since that change the index expression in the generated kernel and we may 'double' applying the offset. Triton already considers the offset when passing in the data_ptr.

To solve this issue, we track the unaligned buffer names instead.

This potentially can fix the internal issues we are debugging here: https://fb.workplace.com/groups/1075192433118967/permalink/1618308128807392/

Differential Revision: [D72600784](https://our.internmc.facebook.com/intern/diff/D72600784)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150777
Approved by: https://github.com/eellison, https://github.com/jansel
2025-04-08 18:49:44 +00:00
Mu-Chu Lee
f363fe616d [AOTInductor] Fix autotuning code's codegen (#150522)
Summary:
Codegen used to generate tmp_arg_{index} as temporary args, and index is the position of the caller.
We changed the logic of codegen such that we can reuse previous generated samples, and only delete after arg is no longer used. In this case, we need to make {index} unique, since different functions could reuse the same "tmp_arg_{index}" name string, but corresponds to different args.

Test Plan: `python test/inductor/test_aot_inductor.py -k test_autotuning_args_reuse`

Differential Revision: D72297084

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150522
Approved by: https://github.com/desertfire, https://github.com/22quinn
2025-04-03 00:08:19 +00:00
Boyuan Feng
039ebdc192 [Graph Partition] Support symbol inputs (#149458)
This PR supports symbol inputs to graph partition functions. Before this PR, we rely on `node.read_writes` to get partition inputs. However, this does not cover symbol inputs.

In this PR, for each graph partition, we collect all symbol inputs which are required to be in scope to successfully         perform codegen, including:
- free symbols used in partition nodes.
- free symbols in partition input/node shapes, strides, and offsets. This is needed for recording cudagraphs for tensors with dynamic shapes.

### Note1: MutationLayout
In this example, node.layout is MutationLayoutSHOULDREMOVE. The symint from index `n` does not appear in the size, offset, stridese of node.layout. This symint appear in node.layout.target. So we need extra handle for it.

```python
x = torch.zeros(7, device="cuda")

def fn(n, a):
    a[n] = -1
    return a

opt_fn = torch.compile(fn, fullgraph=True)

for n in range(2, x.shape[0]):
    opt_fn(n, x)
```

### Note2: Composability with Padded Tensor Subclass

W/o graph partition, Padded Tensor subclass lifts outer shapes to input arguments (i.e., arg0_1 for s0, arg1_1 for s1) but does not lift inner shapes (i.e., s2 and s3). Since cudagraph cache relies on integer inputs, it will cache on outer shapes and ignore inner shapes, which is bad.

```
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
    args.clear()
    s0 = arg0_1
    s1 = arg1_1
    arg2_1_size = arg2_1.size()
    s2 = arg2_1_size[0]
    s3 = arg2_1_size[1]
    assert_size_stride(arg2_1, (s2, s3), (s3, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((s2, s3), (s3, 1), torch.float32)
        # Topologically Sorted Source Nodes: [x1, mul], Original ATen: [aten.add, aten.mul]
        triton_poi_fused_add_mul_0_xnumel = s2*s3
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_mul_0.run(arg2_1, buf0, triton_poi_fused_add_mul_0_xnumel, stream=stream0)
        del arg2_1
    return (buf0, s0, s1, s1, )
```

w/ graph partition, the partition function only includes tensor and inner shapes as inputs, to make sure the cudagraph caching is correct. Full Comparison: [code](https://www.internalfb.com/intern/diffing/?paste_number=1761674743)
```python
   def call(self, args):
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
        args.clear()
        s0 = arg0_1
        s1 = arg1_1
        arg2_1_size = arg2_1.size()
        s2 = arg2_1_size[0]
        s3 = arg2_1_size[1]
        assert_size_stride(arg2_1, (s2, s3), (s3, 1))
        partition0_args = [arg2_1, s2, s3]
        del arg2_1
        (buf0,) = self.partitions[0](partition0_args)
        del partition0_args
        return (buf0, s0, s1, s1, )
```

The number of cudagraphs is validated below: (also added to test)
```python
import torch

from padded_tensor import PaddedTensor

# Turning off graph_partition leads to
# torch._inductor.cudagraph_trees.get_container(0).tree_manager.new_graph_id().id=6
# at the end, which is wrong.
# torch._inductor.config.graph_partition = False

# Turning on graph_partition leads to
# torch._inductor.cudagraph_trees.get_container(0).tree_manager.new_graph_id().id=4
# at the end, which is correct.
torch._inductor.config.graph_partition = True

def f(x):
    x1 = x + 1
    return x1 * 2

compiled_f = torch.compile(f, mode="reduce-overhead")

def run(shape):
    x = torch.randn(*shape, device="cuda")
    pad_x = PaddedTensor.from_tensor(x, multipliers={0:4, 1:4})
    assert hasattr(pad_x, "multipliers"), breakpoint()
    eager_out = f(pad_x)

    for _ in range(3):
        compiled_out = compiled_f(pad_x)
    compiled_out = compiled_f(pad_x)

    assert eager_out.shape == compiled_out.shape
    assert eager_out.tensor.shape == compiled_out.tensor.shape
    assert torch.allclose(eager_out.tensor, compiled_out.tensor)

# static shape. record a NEW cudagraph. 1 cudagraph in total now.
run((2,3))
# outer shape is dynamic, leading to a new dynamo graph
# this new dynamo graph forces a NEW cudagraph. 2 cudagraphs in total now
run((3,4))
# outer shape changed but inner shape does not change
# so NO new cudagraph is recorded
run((2,2))
# inner shape is dynamic now, leading to a new dynamo graph
# this new dynamo graph forces a NEW cudagraph. 3 cudagraphs in total now
run((5,6))
# does NOT record a new cudagraph
run((7,8))
# record a NEW cudagraph. 4 cudagraphs in total now
run((10,11))

assert torch._inductor.cudagraph_trees.get_container(0).tree_manager.new_graph_id().id == 4
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149458
Approved by: https://github.com/eellison
2025-03-26 17:21:30 +00:00
Mu-Chu Lee
a0253d2840 [Inductor] Use real input to autotune user defined triton kernels (#149553)
Summary:
User defined Triton kernel sometimes rely on real inputs to determine
the path of execution. We need real inputs to invoke the correct
behavior of the user defined triton kernels (see example in test case,
where we have an early return for random inputs)

Test Plan:
Included in the commit.
python test/inductor/test_aot_inductor.py -k triton_autotuning
python test/inductor/test_aot_inductor.py -k triton_mutated_autotuning

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149553
Approved by: https://github.com/davidberard98, https://github.com/eellison
2025-03-26 16:42:48 +00:00
Benjamin Glass
0f1aaeb62e cpp_wrapper: persist autotune example tensors until last use (#146706)
Patches over an issue where randomly generated example tensors can cause kernel autotuning to fail, when those tensors would not be possible outputs from previous kernels in the sequence. This fixes a failure in `test_torchinductor_opinfo.py` when run with compile-time autotuning, `test_comprehensive_nanquantile_cuda_float64`.

For clarity, the situation triggering this PR looks like kernels `A -> BCDE -> F` (`BCDE` is fused), where one of the outputs from `A` is a boolean tensor describing some of the input data. Previously, we randomly regenerated that boolean tensor and the input data before passing them to `BCDE`, so that they no longer matched. This caused a `tl.device_assert` call in `BCDE` to fail. With this PR, we reuse the random data input to `A` and the output Boolean tensor, such that they match and pass the device assertion in `BCDE`.

Fixes #147799.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146706
Approved by: https://github.com/desertfire
2025-03-25 17:58:40 +00:00
Bin Bao
3647711a89 [AOTI][refactor] Remove dead code (#149287)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149287
Approved by: https://github.com/cyyever, https://github.com/yushangdi
2025-03-20 07:29:27 +00:00
Jason Ansel
b040dc3a53 Reland: [inductor] Simplify grid handling (#148305)
Summary:
Relands D69965761 / https://github.com/pytorch/pytorch/pull/147583

Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg.  This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
instead now the grid computation is included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
    grid_0 = ((xnumel + 1023) >> 10)
    grid_1 = 1
    grid_2 = 1
    runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```

This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.

It also allows us to unify the handling of grids between the Python and C++ wrapper code.  Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.

This unification allows this PR to be a net deletion of code.

Differential [disconnected] Revision: D70471332

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148305
Approved by: https://github.com/shunting314, https://github.com/eellison
2025-03-12 15:52:16 +00:00
PyTorch MergeBot
5ada4e6a53 Revert "Reland: [inductor] Simplify grid handling (#148305)"
This reverts commit 8d08b49015.

Reverted https://github.com/pytorch/pytorch/pull/148305 on behalf of https://github.com/jithunnair-amd due to Broke ROCm CI ([comment](https://github.com/pytorch/pytorch/pull/148305#issuecomment-2718177044))
2025-03-12 14:58:43 +00:00
Jason Ansel
8d08b49015 Reland: [inductor] Simplify grid handling (#148305)
Summary:
Relands D69965761 / https://github.com/pytorch/pytorch/pull/147583

Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg.  This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
instead now the grid computation is included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
    grid_0 = ((xnumel + 1023) >> 10)
    grid_1 = 1
    grid_2 = 1
    runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```

This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.

It also allows us to unify the handling of grids between the Python and C++ wrapper code.  Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.

This unification allows this PR to be a net deletion of code.

Differential Revision: D70471332

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148305
Approved by: https://github.com/shunting314, https://github.com/eellison
2025-03-11 18:51:06 +00:00