mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
master
34 Commits
| Author | SHA1 | Message | Date | |
|---|---|---|---|---|
|
|
7d67a41db4 |
make FXConverter.generate use V.fake_mode instead of _detect_fake_mode_from_gm (#166591)
Summary:
FXConverter configurs _node_metadata_hook passing in `fake_mode` explicitly, which is relevant for cases down the line like `_generate_triton_call` that inserts a `triton_kernel_wrapper_mutation` node.
This `fake_mode` is obtained from `_detect_fake_mode_from_gm`, which can be different from inductor set `V.fake_mode`.
For example, while `V.fake_mode` is not None, `_detect_fake_mode_from_gm` can be **None** for a parent graph containing only a submodule which has no input args and only constants
```
parent graph():
%sub : [num_users=1] = call_module[target=sub](args = (), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%sub, slice(None, None, None)), kwargs = {})
return (getitem,)
submodule graph():
%randn : [num_users=1] = call_function[target=torch.ops.aten.randn.default](args = ([5, 10],), kwargs = {device: cuda, pin_memory: False})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%randn, 1), kwargs = {})
return (add,)
```
Getting this discrepnancy is flawed, it makes `_node_metadata_hook` try running inputs in a different "fake_mode" or no fake_mode when the rest of lowering uses `V.fake_mode`. In some cases where input is placed on custom non-gpu device, it can even complain with "requires device to be started" or tensor device mismatch.
So this diff updates FXConverter.generate to use `V.fake_mode` which is populated by inductor properly.
Test Plan:
added a test `test_const_folded_subgraph` in `test_fxir_backend.py`, this test:
- creates a graph module that calls a subgraph with no inputs and containing only const-foldable ops
- const fold the subgraph
- run FXConverter.generate, expect `fake_mode` used to code-generate is not None
On the prior implementation when `_detect_fake_mode_from_gm` was used, this test would fail as fake_mode would be `None`.
With this change, the test passes, `fake_mode` is properly collected from `V.fake_mode` which is not None.
Differential Revision: D85767475
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166591
Approved by: https://github.com/blaine-rister, https://github.com/mlazos, https://github.com/eellison
|
||
|
|
695cb0d342 |
[2/N][Fix] Fix typo in test folder (#166374)
Fix typo in test folder. _typos.toml ```bash [default.extend-words] nd = "nd" arange = "arange" Nd = "Nd" GLOBALs = "GLOBALs" hte = "hte" iy = "iy" PN = "PN" Dout = "Dout" optin = "optin" gam = "gam" PTD = "PTD" ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166374 Approved by: https://github.com/cyyever, https://github.com/ezyang |
||
|
|
9a91486e45 |
[Inductor-FX] Don't flatten constant args (#166144)
Summary: Fallback kernels are created with flattened constant args and an `unflatten` utility to unflatten them when needed. Apply it in FXConverter to preserve the original structure Test Plan: added new CI tests Differential Revision: D85347589 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166144 Approved by: https://github.com/blaine-rister |
||
|
|
f2bb22ff84 |
[Inductor-FX] Support Tensor.item (#165599)
# Feature This PR supports compiling `Tensor.item` with Inductor's FX backend. This maps to a custom WrapperCodeGen method called `codegen_dynamic_scalar`. # Implementation The implementation is fairly mechanical, following the usual flow for these types of PRs. 1. Introduce a new Wrapper IR line for this, called `DynamicScalarLine`. 2. Split `PythonWrapperCodegen.codegen_dynamic_scalar` into 2 parts: a public method which generates the Wrapper IR line, and a private one generating Python from Wrapper IR. 3. Implement an FX codegen method for the wrapper IR line. This one calls `aten.where.Scalar` to handle code like `1 if x.item() else 0`, which is a bit tricky. It also calls `aten.item.default` to convert tensors to scalars. # Test plan Added CI tests mirroring the AOTI ones. They test float, int and bool types, the latter taking a distinct codegen path. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165599 Approved by: https://github.com/angelayi, https://github.com/jansel |
||
|
|
7edd18f0fd |
[Inductor-FX] Generalize FloorDiv conversion to handle more complex launch grids. Remove python_slow grid mode. (#163828)
# Problem Inductor's FX backend receives sympy expressions for Triton launch grids, and passes these to a tracer to generate equivalent FX IR. However, the tracer does not support all possible sympy expressions. In particular, it can't handle ops like `floor` and `Pow` which would be found in an expression like `floor(x / y)`. Instead, it expects `FloorDiv(x, y)`, which has the advantage that all intermediate values are integers, unlike `x / y`. Inductor's Python backend uses a trick where `ceil(x / y)` is computed in Python as `-(x // -y)`, which is faster when evaluating Python launch grids at runtime. However, this trick generates more complex sympy expressions, so the FX backend introduced a `"python_slow"` mode using a more familiar form of ceil division. However, this mode is slower to evaluate, which increased production CPU usage. (Internal reviewers see T237853632.) # Solution To get the best of both worlds, this PR removes `"python_slow"` mode, and generalizes the `replace_floor_div` function to handle the more complex expressions resulting from the `"python"` grid mode. The new algorithm is conceptually similar to the existing one, except instead of analyzing only the first argument to a `sympy.Mul` op, it checks all factors, so it can handle expressions containing both `Rational` and `Pow` ops, among other cases. It also uses `Mul.make_args` to handle the case when the argument to `floor` is not a `Mul`. Finally, it uses `expr.is_positive` to check the sign of symbolic exponents. This new algorithm is guaranteed to convert all `floor` ops to an equivalent expression using `FloorDiv`. (To see this, consider that `floor(x) == FloorDiv(x, 1)`.) Note it may not remove all `Pow` ops, with a counterexample being `floor(x / (2 + z ** y))`, but it covers everything we've seen in practice for symbolic launch grids. In particular, it covers the typical case where `Pow` is a factor of the argument to `floor`, and the exponent is `-1`. Is this situation, we move the `Pow` to the denominator of `FloorDiv` and the exponent becomes `1`, eliminating the `Pow` op. # Test plan This PR adds an end-to-end test for static padding with dynamic outer dimensions, which creates a difficult sympy expression that the existing algorithm would not be able to handle. This PR also adds some unit tests for the `replace_floor_div` function. It can be difficult to construct end-to-end tests that expose all the trickiest expressions, as those tests have to pass through a number of other systems handling dynamic shapes. Therefore, it's easier to expose the edge cases with these new unit tests. The tests check that we can replace all `floor` ops in the input expression with `FloorDiv`, then they expand `FloorDiv` back to `floor` and check equality with the original expression. Note this PR also requires some MTIA changes to pass internal tests. Those will be stacked onto the imported diff. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163828 Approved by: https://github.com/nandesuka, https://github.com/angelayi, https://github.com/jansel |
||
|
|
604da4bb9a |
[Inductor-FX] Support unbacked symbol definitions (#163729)
# Problem Inductor sometimes generates unbacked symints to handle things like mismatched branches of `torch.cond`. This code is represented by `pytree.KeyPath`, with special codegen logic to convert it to Python and C++. This was not previously supported by the FX backend. # Feature This PR adds support for unbacked symbol declarations to the FX backend. The implementation is fairly straightforward. 1. Instead of raw Python/C++, update the wrapper codegen method to emit a new Wrapper IR line called `UnbackedSymbolDefsLine`. This contains all the information needed to generate the Python and C++ code. 2. Move the existing Python/C++ codegen to a private method, which is invoked by `UnbackedSymbolDefsLine.codegen()`. 3. Implement a method to generate FX IR from unbacked symbol definitions. The implementation is based on recursive descent, consuming some keypath entries, emitting an FX IR node, and recursing to the rest of the keypath. It is conceptually identical to the existing algorithm for Python and C++, except it generates FX nodes. 4. The FX backend currently relies on size hints to generate autotuning arguments, and consequently autotuning does not support unbacked SymInts. At some point, we would like to generalize the autotuning logic to support these. But for now, simply emit a warning and skip autotuning when we see them. 5. The new test case exposed some tricky issues reconciling Triton call args with constants stored in `triton_meta`. This PR rewrites the relevant helper function to do this in a more principled way. # Test plan This PR imports an existing control flow test to the FX backend's test suite. The test uses unbacked symbol definitions to handle mismatched dynamic shapes coming from `torch.cond` branches. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163729 Approved by: https://github.com/jansel |
||
|
|
a8c528c105 |
[1/N] Apply UP035 rule in tests (#163947)
Apply UP035 `ruff` rule in tests, but some tests for `fx` and `dynamo` are excluded in case the old typing is the test target. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163947 Approved by: https://github.com/ezyang |
||
|
|
f68de58c9d |
[Inductor-FX] Support symbol and dynamic scalar graph inputs and outputs (#163596)
# Problems This PR fixes a few edge cases that the FX converter missed related to dynamic shapes. 1. Inductor graphs can sometimes take `sympy.Symbol` inputs. We have logic to convert these to FX placeholder nodes. However, this logic did not update the `self.expr_to_proxy` table mapping symbols to proxy nodes. (There was existing logic to do this for `ir.TensorBox` inputs, but not `sympy.Symbol`.) This caused sympy tracing to fail when these symbol inputs were used in other expressions. 2. We lacked codegen for `ShapeAsConstantBuffer`. This IR node is seen when the graph input or output is a scalar computed from dynamic shapes. # Fixes a. Update `self.expr_to_proxy` when generating placeholders for `sympy.Symbol` inputs. Change `SymbolBuffer.get_example` to convert the symbol to a `torch.SymInt`, so we can populate `meta["val"]` correctly and use the value in other computations. b. Support `ShapeAsConstantBuffer` by tracing the sympy expression. c. Move output generation inside the metadata hook, allowing us to populate `meta["val"]` for the nodes computing `ShapeAsConstantBuffer`. # Test plan Added several new CI tests: 1. `torch.cond` with dynamic shapes. This exposes both issues, as the predicate is a `ShapeAsConstantBuffer` and one of the subgraphs uses a symbol input, due to the closure. Also tests when the parent and subgraphs have different input shapes. 2. Output dynamic shape scalar. This tests `ShapeAsConstantBuffer` as an output. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163596 Approved by: https://github.com/angelayi, https://github.com/jansel |
||
|
|
e56dd5d770 |
[Inductor-FX] Support torch.cond (#163234)
# Feature
Support `torch.cond` in the FX converter. The generated FX IR is conceptually indentical to what would come from `torch.export`:
- Submodules as stored as attributes, and accessed via `getattr`.
- The conditional is represented as `torch.ops.higher_order.cond`, which takes in the subgraphs, a predicate and submodule inputs.
# Implementation overview
The FX backend generates code for subgraphs using the following steps:
1. When `codegen_conditional` is called in `WrapperFxCodegen`, we emit a `ConditionalLine`.
a. We also codegen the true/false subgraphs at this time, storing their subgms for later.
2. At the beginning of FX conversion, generate `get_attr` nodes accessing each subgraph. It's important to do this at the start, before registering the node metadata hook. This also matches the convention followed by torch.export.
3. When we see the `ConditionalLine` in the FX converter, we generate a corresponding `torch.ops.higher_order.cond`.
# Implementation details
This ended up being a substantial change, as wrapper codegen has some special logic for subgraphs.
Certain methods of `PythonWrapperCodegen` are overridden by `SubgraphPythonWrapperCodegen`. To apply these overrides, we use multiple inheritance with the registered subclass of `WrapperFxCodegen`.
Unlike most other wrapper codegen methods, which map 1:1 to Wrapper IR lines, subgraph codegen generates a number of wrapper lines including `EnterSubgraphLine` and `ExitSubgraphLine`, along with Python or C++ code calling the subgraph as a function. These lines are used for some backends' memory planning.
In contrast, FX IR typically represents a subgraph call as a single HOP node, or a `call_module` op. To account for this difference, this PR introduces a new wrapper IR line called `ConditionalLine`, which is only used by the FX backend. We override the `codegen_conditional` method to emit this line. This sidesteps having to port the existing subgraph codegen and associated memory planning to Wrapper IR. (In principle, it seems possible to adapt the existing backends to `ConditionalLine`, but it could be a larger refactor, since we'd also have to update the memory planning.)
Some of the lower-level subgraph codegen methods are still shared between the FX and Python backends, such as `generate_subgraph_common`. Those were easier to port to Wrapper IR.
This also required generalizing the way the FX converter handles graph inputs and outputs. Previously, it assumed the IO signature was the same as `V.graph.module`, but this is only true for the parent graph, and not subgraphs. Instead, we need to call `get_graph_inputs` and `get_graph_outputs` to populate the inputs and outputs for subgraphs.
# Test plan
This PR adds a couple of tests using torch.cond. Here's an example graph generated by one of them:
```
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%arg0_1, %true_graph_0, %false_graph_0, (%arg1_1,)), kwargs = {})
%buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 6, constant_args_idx: 6, grid: [(1, 1, 1)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 6, XBLOCK: 8}})
return buf1
```
It also removes an existing negative test which checked that a certain error was raised when subgraphs were encountered.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163234
Approved by: https://github.com/angelayi, https://github.com/jansel
|
||
|
|
3bfa35d62e |
[AOTI-FX] Solve for undefined symbols in dynamic input shapes (#163044)
# Problem
When dynamic shapes are passed to AOTInductor, they usually have a very basic form like `(s0, 5, 27)`. In these cases it's straightforward to generate code defining the symbol `s0` as a specific dimension of the input tensor. However, AOTI can handle slightly more generic expressions than this, such as `(2 * s0, 5, 27)`. In these cases, we don't immediately know the value of `s0`, but we need to solve for it, since it may be referenced in other parts of the program such as kernel call arguments, launch grids, etc.
# Feature
This PR adds support for more generic dynamic input expressions in the FX backend, following the implementation already present in AOTI's C++ backend:
1. Check if the expression contains *one* undefined symbol, as multiple variables would make the equation underdetermined. Let's call this `s0`. (We could potentially generalize this, but this PR focuses on cases AOTI can already handle.)
2. Generate a new symbol for the relevant size or stride of the input tensor. Let's call this `size`. This is computed with FX nodes just as a normal symbol would be.
3. Use sympy to solve for `s0` in terms of `size`. Let's call the resulting expression `solution`.
4. Since we know `s0` is an integer, `solution == floor(solution)`. Take the floor and then convert division to `FloorDiv`. This is required to trace through the expression, since the return value of regular division is not guaranteed to be an integer.
5. Generate FX for the modified `solution`, which defines the value `s0`.
6. Override the relevant method of `PythonWrapperCodegen` to a no-op, since the FX converter handles the above on its own.
# Test plan
In addition to the existing dynamic shapes tests, this PR adds new test cases where the input shape contains a non-trivial expression. This dynamic input dimension is then multiplied by other dimensions to form the argument to a `reshape`.
Here's an example graph from one of the CI tests. In this case, the input expression was `2*x + 1`, and the solution is `x = (sym_size_int - 1) / 2`:
```
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%sym_size_int : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%arg0_1, 0), kwargs = {})
%sym_sum : [num_users=1] = call_function[target=torch.sym_sum](args = ([-1, %sym_size_int],), kwargs = {})
%floordiv : [num_users=1] = call_function[target=operator.floordiv](args = (%sym_sum, 2), kwargs = {})
%mul : [num_users=2] = call_function[target=operator.mul](args = (8, %floordiv), kwargs = {})
%sym_sum_1 : [num_users=2] = call_function[target=torch.sym_sum](args = ([4, %mul],), kwargs = {})
%buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([%sym_sum_1], [1]), kwargs = {dtype: torch.float32, device: cuda:0})
%sym_sum_2 : [num_users=1] = call_function[target=torch.sym_sum](args = ([35, %mul],), kwargs = {})
%floordiv_1 : [num_users=1] = call_function[target=operator.floordiv](args = (%sym_sum_2, 32), kwargs = {})
%triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(%floordiv_1, 1, 1)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, out_ptr0: %buf0, xnumel: %sym_sum_1, XBLOCK: 32}})
return buf0
```
The `sym_size_int` node returns the first dimension of the input tensor. Next, `floordiv` computes the input symbol in terms of the input size. Then, the launch grid is computed by `floordiv_1`, the kernel argument by `sym_sum_1`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163044
Approved by: https://github.com/jansel
|
||
|
|
9aca0ba027 |
[Inductor-FX] Support IndexPutFallback (#162863)
# Feature This PR supports lowering `IndexPutFallback` through Inductor's FX converter. The approach is very similar to the one taken in https://github.com/pytorch/pytorch/pull/162686. Compared to `ScatterFallback`, this required one additional change: the value of `self.op_overload` for `IndexPutFallback` was inaccurate. Previously, it used `aten.index_put`, which would result in unsound FX IR. The existing Python/C++ codegen use `aten.index_put_`, since the fallback mutates its input. This PR changes `self.op_overload` to match that. # Test plan Added a CI test lowering deterministic index put via the FX converter. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162863 Approved by: https://github.com/angelayi |
||
|
|
a7bbc5fea7 |
[Inductor-FX] Support ScatterFallback (#162686)
# Problem Inductor has a `ScatterFallback` op with custom Python and C++ wrapper codegen macros. This is used in certain situations where the default Triton codegen doesn't apply, and especially for reductions which need to be deterministic. Since this op used direct Python/C++ codegen, it wasn't compatible with the FX backend. # Feature This PR refactors the associated wrapper codegen to support `ScatterFallback`. This follows the same basic steps that were used for other fallback ops including `MultiOutput` and `ExternKernel`: 1. Create a new wrapper IR op called `ScatterFallbackLine`. Move the logic in `ScatterFallback.cogeden` to `ScatterFallbackLine.codegen`, to prevent it from affecting the FX backend. This logic is unsafe for FX because it may generate Python or C++ strings with methods like `codegen_reference()`. 2. To eleminate the dependence on `V.graph`, move language-specific logic to the respective wrapper codegen subclasses. In this case, C++ codegen has some special logic, which is moved to `CppWrapperCpu`. 3. Create a new method in `FXWrapperCodegen` to handle `ScatterFallbackLine`. # Test plan Added a couple of CI tests for the FX backend with scatter fallbacks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162686 Approved by: https://github.com/jansel |
||
|
|
8508651477 |
Fix flaky AOTFxirTestCase (#162472)
Fixes https://github.com/pytorch/pytorch/issues/162357 Fixes https://github.com/pytorch/pytorch/issues/160970 Fixes https://github.com/pytorch/pytorch/issues/161038 Fixes https://github.com/pytorch/pytorch/issues/160951 Fixes https://github.com/pytorch/pytorch/issues/161698 These tests were introduced in https://github.com/pytorch/pytorch/pull/160765 and they are all flaky when `torch._inductor.aot_compile` uses multiple threads (the default option). The issue could be reproduced by running them locally multiple times. For example, ``` pytest --flake-runs 10 --flake-finder -v inductor/test_fxir_backend.py -k test_aoti_fx_add (output logs at P1938386961) ... --------------------------------------------------------------------------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------------------------------------------------------------------------- inductor [('async_compile_cache_miss', 1)] graph_break [] --------------------------------------------------------------------------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------------------------------------------------------------------------- inductor [('async_compile_cache_miss', 1)] graph_break [] --------------------------------------------------------------------------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------------------------------------------------------------------------- inductor [('async_compile_cache_miss', 1)] graph_break [] --------------------------------------------------------------------------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------------------------------------------------------------------------- inductor [('async_compile_cache_miss', 1)] graph_break [] --------------------------------------------------------------------------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------------------------------------------------------------------------- inductor [('async_compile_cache_miss', 1)] graph_break [] --------------------------------------------------------------------------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------------------------------------------------------------------------- inductor [('async_compile_cache_miss', 1)] graph_break [] --------------------------------------------------------------------------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------------------------------------------------------------------------- inductor [('async_compile_cache_miss', 1)] graph_break [] --------------------------------------------------------------------------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------------------------------------------------------------------------- inductor [('async_compile_cache_miss', 2), ('async_compile_cache_hit', 1)] graph_break [] --------------------------------------------------------------------------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------------------------------------------------------------------------- inductor [('async_compile_cache_miss', 2), ('async_compile_cache_hit', 1)] graph_break [] --------------------------------------------------------------------------------------------------------------------------------------------------- Captured stdout call --------------------------------------------------------------------------------------------------------------------------------------------------- inductor [('async_compile_cache_miss', 2), ('async_compile_cache_hit', 1)] graph_break [] ================================================================================================================================================= short test summary info ================================================================================================================================================== FAILED [0.4834s] inductor/test_fxir_backend.py::AOTFxirTestCase::test_aoti_fx_add - AttributeError: 'NoneType' object has no attribute '__code__' FAILED [0.4576s] inductor/test_fxir_backend.py::AOTFxirTestCase::test_aoti_fx_add - AttributeError: 'NoneType' object has no attribute '__code__' FAILED [0.4613s] inductor/test_fxir_backend.py::AOTFxirTestCase::test_aoti_fx_add - AttributeError: 'NoneType' object has no attribute '__code__' =============================================================================================================================================== 3 failed, 7 passed in 12.89s =============================================================================================================================================== ``` Setting `compile_threads` to 1 will get rid of the test flakiness, but there might be underlying issues from https://github.com/pytorch/pytorch/pull/160765. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162472 Approved by: https://github.com/angelayi, https://github.com/Skylion007 |
||
|
|
9aedb3cd87 |
[AOTI-FX] Support registering custom FX backends (#162317)
# Feature Currently, `torch._inductor.compile_aot` always uses the `WrapperFxCodegen` class. In contrast, Python and C++ codegen allow users to register custom backends. This PR brings that feature to FX codegen. # Test plan Added a CI test registering a custom FX backend. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162317 Approved by: https://github.com/jansel |
||
|
|
9491d289b3 |
Support generic dynamic shape with padding (#160997)
Summary: Inductor has the following configurations: config.comprehensive_padding config.padding_alignment_bytes config.padding_stride_threshold In the case of static shape by enabling these three options Inductor will generate code for Flexible layout tensors that tries to pad up all stride dimension to be a multiple of config.padding_alignment_bytes for strides above: config.padding_stride_threshold. In the case where dynamic shapes is enabled no padding is done today. This PR introduces the following configuration which allows the user to specify they wish to generated a padded stride even in the case of dynamic shape operations. This is mainly done so we don't break the previous behaviour of not padding up dynamic shape use cases. The config.padding_stride_threshold does not apply since the values of the strides are dynamic. config.pad_dynamic_shapes In addition to this a new mode "python_slow" has been added to launch grid calculation which achieves the same ceildiv behaviour that is generally applicable to integer division. This is done to prevent test regressions and make wrapper_fxir codegen more generic. Test Plan: CI Rollback Plan: Differential Revision: D80468808 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160997 Approved by: https://github.com/blaine-rister, https://github.com/jansel |
||
|
|
81b7b16618 |
Reland "[Fix XPU CI][Inductor UT] Fix test cases broken by community. (#161142)" (#161949)
This PR reland #161142 which is reverted to be able to revert other PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161949 Approved by: https://github.com/jansel |
||
|
|
54e275e0d8 |
Revert "[Fix XPU CI][Inductor UT] Fix test cases broken by community. (#161142)"
This reverts commit
|
||
|
|
c83cbd2f2a |
[Fix XPU CI][Inductor UT] Fix test cases broken by community. (#161142)
Fixes #161384, Fixes #161162, Fixes #160946, Fixes #160947, Fixes #160948 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161142 Approved by: https://github.com/jansel |
||
|
|
26d0ff1cba |
[AOTI-FX] Enhance launch grid FloorDiv replacement using sympy.together. (#161582)
# Feature 2d launch grids with dynamic shapes can contain sympy expressions like `floor(x / 128 + y / 128)`. This breaks the dynamic shapes tracer which only supports `FloorDiv`, and not `floor`. To handle this case, call `sympy.together` prior to pattern matching to convert this to `floor((x + y) / 128)`. Then, we can recognize the pattern and map it to `FloorDiv(x + y, 128)`. # Test plan Added a custom Triton test exposing this. The test calls a 2d autotuned kernel with dynamic shapes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161582 Approved by: https://github.com/nandesuka |
||
|
|
9de9d25f8d |
[Inductor-FX] Support custom triton kernels (#161474)
# Feature Add support for custom Triton kernels to the FX backend. This turned out not to require any new features, except for a minor change to handle `tl.constexpr` arguments which are not part of the autotuning config. # Caveat This may not cover every possible case. For example, we might need more features for autotuning custom Triton code. This PR entirely skips the [custom codegen ](https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/triton_kernel_wrap.py#L1034-L1039) for user-defined grid functions, but there may be edge cases requiring this logic. However, this PR seems to do a reasonable job as many of the grids end up being written into Inductor/Triton metadata and don't require special codegen. As a follow up, I'm planning to test this against all of AOTI's custom Triton kernel tests. # Test plan Added a CI test using a custom Triton kernel. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161474 Approved by: https://github.com/angelayi |
||
|
|
c081481bbe |
[aoti-fx] Output OpOverload fallbacks (#161195)
Updates the inductor-wrapper-fxir code to use the kernel.op_overload when generating extern kernel calls. This way we can keep the IR consistent with using ATen ops. TODO: we're also inserting torch.empty_strided calls -- need to turn this into aten too Pull Request resolved: https://github.com/pytorch/pytorch/pull/161195 Approved by: https://github.com/blaine-rister |
||
|
|
d228a776e9 |
[Inductor-FX] Support Tensorbox outputs (#161245)
# Problem The FX converter previously supported graph outputs which were `StorageBox`, but not `TensorBox`. The latter seems to show up in certain cases when the output is a slice/view of the input. # Fix This PR generalizes the code to handle `MutableBox` instead of `StorageBox` specifically. # Test Added a CI test exposing the issue. The test case was found by intentionally breaking `TensorBox(ReinterpretView` support in https://github.com/pytorch/pytorch/pull/161258. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161245 Approved by: https://github.com/angelayi |
||
|
|
3dacaf0e1e |
[aoti-fx] Add meta["val"] metadata (#161019)
Summary: Added a `_set_node_metadata_hook` which automatically adds node.meta["val"] to every new node that gets created under this context. Test Plan: ` buck2 test //mtia/host_runtime/afg/tests:test_dynamic_shapes_advanced_ops` https://www.internalfb.com/buck2/866439a2-2ba6-42d1-8e43-508d60456e2e `buck2 test //mtia/host_runtime/afg/tests:test_dynamic_shapes_basic_ops` https://www.internalfb.com/intern/testinfra/testrun/11540474149662857 Rollback Plan: Differential Revision: D80579336 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161019 Approved by: https://github.com/blaine-rister |
||
|
|
6ac9035a84 |
[aoti-fx] Dynamic shapes support (#160766)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160766 Approved by: https://github.com/jansel ghstack dependencies: #160765 |
||
|
|
bab79824cb |
[aoti-fx] Initial AOTInductor FX (#160765)
Using the existing WrapperFxCodegen backend, this PR prototypes an AOT version of it which will directly return a graph module.
How to use:
```python
exported_gm = torch.export.export(model, inp, dynamic_shapes=dynamic_shapes).module()
compiled_gm = torch._inductor.aot_compile(
exported_gm, inp, options={"fx_wrapper": True, "compile_threads": 1}
)
assert torch.allclose(model(*inp), compiled_gm(*inp))
```
The motivation behind this is that backends like ExecuTorch/MTIA would like to use inductor's optimization technologies, but might have their own graph lowering pipelines so they might not want to use AOTI (which generates an so).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160765
Approved by: https://github.com/jansel
|
||
|
|
426f249f20 |
Fix launch grid calculation (#159497)
Summary: The launch grid calculation code is using a python trick to achieve CeilDiv() through negative integer division with FloorDiv(). This is language dependent behaviour that doesn't apply to all languages. In the FXIR backend we negate this behaviour and replace the experssion with CeilDiv() operation so the computation is correct regardless of language used. Not directly directly changing the orginal computation as it leads to a performance degredation. Test Plan: CI Rollback Plan: Differential Revision: D79275534 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159497 Approved by: https://github.com/blaine-rister |
||
|
|
2004f8aa10 |
FXConverter handling of generic output in inductor fallback kernel (#159002) (#159297)
Summary:
A fallback kernel's output may be a non-list/tuple but a `MultiOutput` with empty indices. Allow the `FXConverter` to handle such case.
Test Plan:
Modified the fxir test for fallbacks, then ran `buck2 test mode/dev-nosan caffe2/test/inductor:fxir_backend -- test_fallback`.
Before this diff the modified test would fail with
```
File "/re_cwd/buck-out/v2/gen/fbcode/e2105f7329ead90a/caffe2/test/inductor/__fxir_backend__/fxir_backend#link-tree/torch/_inductor/codegen/wrapper_fxir.py", line 341, in generate
line.codegen_fx(self)(line)
File "/re_cwd/buck-out/v2/gen/fbcode/e2105f7329ead90a/caffe2/test/inductor/__fxir_backend__/fxir_backend#link-tree/torch/_inductor/codegen/wrapper_fxir.py", line 489, in _generate_multi_output
inds = line.indices[0][1:]
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
IndexError: list index out of range
```
(Full error paste in P1878839403)
With this diff the error is no longer present.
Rollback Plan:
Differential Revision: [D79126619](https://our.internmc.facebook.com/intern/diff/D79126619)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159297
Approved by: https://github.com/blaine-rister
|
||
|
|
92f41ccc26 |
[Inductor] Support precomputed size args in the FX backend. (#157758)
# Feature
If a Triton kernel has a complicated indexing expression, Inductor may decide to precompute it on the host and pass it to the kernel as an argument. This happens in situations like broadcasts with dynamic shapes.
This PR adds support for this feature to Inductor's FX IR backend.
We generate FX IR for precomputed size args in 3 steps:
1. In `PythonWrapperCodegen`, this PR refactors the relevant code to use a `SymbolicCallArgLine` instead of raw Python strings. This stores a (symbol, expr) pair. (Prior to this PR, it was (str, expr), but changing this to a symbol makes it easier to do substitutions later on.)
2. In `WrapperFxCodegen`, keep a dict of {symbol: expr} arg defs which gets updated whenever we see a `SymbolicCallArgLine`.
3. When the FX backend sees a `KernelCallLine`, it uses this dict to replace symbolic call args with their definitions.
In the longer run, it might be desirable to emit FX nodes defining these symbolic call args. That way, we could reuse the size computation when the same kernel is called multiple times. However, I wasn't sure if there was an existing way to generate FX nodes from a sympy expression, and implementing that seemed like overkill for the present purposes.
# Test plan
Added a new CI test exercising this feature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157758
Approved by: https://github.com/jansel
|
||
|
|
f5e6e52f25 |
[BE][PYFMT] migrate PYFMT for test/inductor/ to ruff format (#148186)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148186 Approved by: https://github.com/jansel |
||
|
|
2e2ea7290a |
[Inductor] Support autotuning in the FX backend. (#155049)
# Feature If `config.triton.autotune_at_compile_time` is set to `True`, autotune Triton kernels during FX conversion. Else, stick with the existing behavior of using the first precompiled config. # Test plan Added CI tests verifying that the tuner is called iff this flag is set, with and without dynamic shapes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155049 Approved by: https://github.com/jansel |
||
|
|
437df54cc8 |
[Inductor] Fix a few FX conversion bugs. (#154958)
# Feature This PR fixes two bugs with Inductor's FX backend. 1. When extracting offsets from `ReinterpretView`'s, we accidentally took the offset of the parent layout instead of the view's layout. This case is triggered when multiple kernels write into the same buffer due to `torch.cat`. 2. In certain rare cases, `V.graph.graph_inputs` can contain a constant input value. In case this happens, create a new `sympy.Symbol` for the input, for compatibility with the existing `SymbolBuffer` abstraction mapping to an FX placeholder. This case is triggered when calling `torch._inductor.compile` on certain modules coming from `torch.export`. # Test plan Added a couple of tests exposing these bugs. 1. Concat with multiple kernels writing to the same buffer. 3. `Export` -> `torch._inductor.compile` with a constant input. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154958 Approved by: https://github.com/jansel |
||
|
|
bc11afd41f |
[Inductor] FX backend via Wrapper IR (#146942)
# Sub-PRs These PRs contain refactors from the main one. They should be reviewed and merged first. - https://github.com/pytorch/pytorch/pull/150458 - https://github.com/pytorch/pytorch/pull/152391 - https://github.com/pytorch/pytorch/pull/152587 # Feature The goals of this PR are twofold. ## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen. In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components. This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR. ## Goal 2: Convert Wrapper IR into FX IR. One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc. It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes. The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source. Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa. # Current status Things that seem to work: - Converted a lot of the most common Python codegen lines to Wrapper IR lines. - Handled the following cases, in addition to what was already in the Memory Planning IR: - Comments - Triton kernels - Extern/fallback kernels - Freeing tensors (`del buf0`) - MultiOutput - Graph outputs - ReinterpretView / StorageBox, for both call args and outputs. - FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code. - Prototype FX converter which can handle some of the most common use cases. - Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565). - Calling wrapped Triton kernels. - Calling extern kernels and certain types of fallback kernels. - Support both `extern_kernels.*` and `aten.*`. - Support multi-output kernels like `torch.topk`. - Graphs with multiple inputs/outputs. - Training i.e. calling `Tensor.backward()` in a compiled function. - Graph breaks (training). - Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen. Things that don't work: - Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections. - Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX. - Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR. - CommBuffer's and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this. # Out-of-tree compilers With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below. ``` from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen class MyCustomBackend(WrapperFxCodegen): def compile_graph(self, gm): # Add 1 to the graph's outputs def compiled_fn(*args): return [x + 1 for x in gm.graph.forward(*args)] return compiled_fn ``` # Example FX graphs This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`. Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table. ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}}) return (buf0,) ``` Here's a more complicated graph that calls a `torch.addmm` extern kernel. ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=2] = placeholder[target=arg1_1] %buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}}) %buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0}) %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2}) %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {}) return (buf2,) ``` Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation. ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {}) %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {}) %buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {}) %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}}) %triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}}) return (buf1, buf2) ``` Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op. ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}}) %buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {}) return (buf0_view_buf0_0,) ``` Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`. ``` graph(): %s6 : [num_users=0] = placeholder[target=s6] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %arg2_1 : [num_users=1] = placeholder[target=arg2_1] %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}}) return buf0 ``` Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions. ``` graph(): %s10 : [num_users=0] = placeholder[target=s10] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %arg2_1 : [num_users=1] = placeholder[target=arg2_1] %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}}) return buf0 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942 Approved by: https://github.com/jansel |
||
|
|
99dac7005f |
Revert "[Inductor] FX backend via Wrapper IR (#146942)"
This reverts commit |
||
|
|
a7691140a0 |
[Inductor] FX backend via Wrapper IR (#146942)
# Sub-PRs These PRs contain refactors from the main one. They should be reviewed and merged first. - https://github.com/pytorch/pytorch/pull/150458 - https://github.com/pytorch/pytorch/pull/152391 - https://github.com/pytorch/pytorch/pull/152587 # Feature The goals of this PR are twofold. ## Goal 1: Introduce Wrapper IR as an intermediate step in wrapper codegen. In addition to Triton/C++/Halide kernels, Inductor also generates "wrapper" code which allocates memory and calls the kernels. Originally, this wrapper code was fairly standard Python which resembled a user-written PyTorch program. Over time, various wrapper code generators have been added to accommodate things like AOTInductor, which prefers C++ code for static compilation. This complexity has bled into other parts of the codebase, as we now need if/else statements to choose between Python and C++ macros. (See an example [here](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L5515-L5522).) Since most of these code generation steps are conceptually identical across target languages, it seems reasonable to refactor them into some kind of intermediate representation which can be shared between the various backends. This might also make it easier to develop out-of-tree backends which cannot put their own macros in core Inductor components. This PR takes some initial steps to formalize Inductor's wrapper codegen by generalizing the existing Memory Planning IR into a fully fledged Wrapper IR. This is pretty much identical to the existing Memory Planning IR, but it supports a richer set of ops for things like kernel definitions and calls. This refactor could help encapsulate wrapper codegen. Ideally, we don't need to worry about direct Python/C++ codegen in the main compiler files such as `ir.py`, and can instead defer these to classes like `PythonWrapperCodegen` and `CppWrapperCpu`, which operate on the Wrapper IR. ## Goal 2: Convert Wrapper IR into FX IR. One of the main benefits of Wrapper IR is to enable more diverse Inductor backends. This PR introduces a converter from Wrapper IR into [FX IR](https://pytorch.org/docs/stable/fx.html), which is the intermediate representation most commonly used in PyTorch graph compilers. The purpose of this is to enable out-of-tree backends to consume Inductor's output in FX IR, which would hopefully make Inductor easier to leverage in novel compilers, hardware accelerators, etc. It's not trivial to generate Python or C++ code which Inductor can compile and run, and doing so may require changes to other core Inductor files, for the reasons outlined in the previous section. The goal of supporting FX output is to enable something like `torch.compile`'s [custom backend](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html) system, in which an out-of-tree backend can receive an optimized FX graph from Inductor, and compile and run it however it likes. The typical users of this feature would likely not be part of PyTorch, and may or may not support running a kernel in eager mode. However, they can understand what `torch.empty_strided` means, compile and run Triton kernels, etc. So we just need to present them with an FX graph saying what code Inductor wants to run, which should be easier to analyze and transform in a third party system than Python or C++ source. Since FX IR is fairly stable, this mechanism should hopefully isolate third-party backends, hardware accelerators, etc. from the implementation details of Inductor, and vice versa. # Current status Things that seem to work: - Converted a lot of the most common Python codegen lines to Wrapper IR lines. - Handled the following cases, in addition to what was already in the Memory Planning IR: - Comments - Triton kernels - Extern/fallback kernels - Freeing tensors (`del buf0`) - MultiOutput - Graph outputs - ReinterpretView / StorageBox, for both call args and outputs. - FX conversion asserts that the program only contains Wrapper IR lines, and not strings of Python/C++ code. - Prototype FX converter which can handle some of the most common use cases. - Defining Triton kernels, and putting them in a side table using TorchDynamo's existing [utilities](https://dev-discuss.pytorch.org/t/higher-order-operators-2023-10/1565). - Calling wrapped Triton kernels. - Calling extern kernels and certain types of fallback kernels. - Support both `extern_kernels.*` and `aten.*`. - Support multi-output kernels like `torch.topk`. - Graphs with multiple inputs/outputs. - Training i.e. calling `Tensor.backward()` in a compiled function. - Graph breaks (training). - Run the `torch.fx.GraphModule` on GPU using the standard `__call__` method. This makes it easy to test the correctness of FX codegen. Things that don't work: - Both Wrapper IR and Wrapper -> FX coverage are currently best effort. There are still features which aren't captured as Wrapper IR lines, and fall back to plain strings. This representation is functionally correct but probably not rich enough to achieve the goals outlined in the previous sections. - Fallback kernels seem like the most difficult thing to fully cover, since they each define their own Python/C++ macros that would need to be converted to FX. - Size/alignment asserts are currently disabled via the config file. It's possible to generate FX IR for these, but it seems reasonable to defer these sanity checks to a later PR. - CommBuffer's and distributed communication are not yet supported. An earlier version of this PR attempted to implement this by calling `empty_strided_p2p`. However, building and testing distributed support seems non-trivial, so it's probably better to defer this. # Out-of-tree compilers With this PR, out of tree backends will be able to do further compilation on the FX graphs by subclassing `WrapperFxCodegen` and overriding the `compile_graph` function. This follows the same API as torch.compile's [custom backends](https://pytorch.org/docs/stable/torch.compiler_custom_backends.html), where the user simply returns a callable running the graph. The callable need not be a method of `GraphModule` or any other PyTorch class. See an example below. ``` from torch._inductor.codegen.wrapper_fxir import WrapperFxCodegen class MyCustomBackend(WrapperFxCodegen): def compile_graph(self, gm): # Add 1 to the graph's outputs def compiled_fn(*args): return [x + 1 for x in gm.graph.forward(*args)] return compiled_fn ``` # Example FX graphs This section contains some example FX graphs generated by Inductor. The correctness of these graphs was verified against eager mode by calling the corresponding `GraphModule`. Here's an FX graph calling a basic Triton kernel. Notice how outputs are allocated with `torch.empty_strided`, and the Triton kernel is called by reference to Dynamo's triton side table. ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((8,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, in_ptr1: %arg0_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}}) return (buf0,) ``` Here's a more complicated graph that calls a `torch.addmm` extern kernel. ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=2] = placeholder[target=arg1_1] %buf0 : [num_users=3] = call_function[target=torch.empty_strided](args = ((), ()), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(1,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg1_1, out_ptr0: %buf0, xnumel: 1, r0_numel: 129, XBLOCK: 1}}) %buf2 : [num_users=2] = call_function[target=torch.empty_strided](args = ((129, 1), (1, 1)), kwargs = {dtype: torch.float32, device: cuda:0}) %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%buf0, %arg0_1, %arg1_1), kwargs = {alpha: 1, beta: 1, out: %buf2}) %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {}) return (buf2,) ``` Here's a graph which indexes into a tuple using `operator.getitem`. This is necessary to use the output of the `torch.topk` operation. ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %buf0 : [num_users=3] = call_function[target=torch.ops.aten.topk.default](args = (%arg0_1, 2), kwargs = {}) %buf1 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 0), kwargs = {}) %buf2 : [num_users=2] = call_function[target=operator.getitem](args = (%buf0, 1), kwargs = {}) %delete : [num_users=0] = call_function[target=torch._inductor.codegen.wrapper_fxir.delete](args = (%buf0,), kwargs = {}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf1, xnumel: 2, XBLOCK: 2}}) %triton_kernel_wrapper_mutation_1 : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 1, constant_args_idx: 1, grid: [(2,)], tma_descriptor_metadata: {}, kwargs: {in_out_ptr0: %buf2, xnumel: 2, XBLOCK: 2}}) return (buf1, buf2) ``` Here's a graph that reinterprets an output tensor using `torch.as_strided`. This is one way to handle Inductor's `ReinterpretView` op. ``` graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((2, 4), (4, 1)), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [(8,)], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg0_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: 8, XBLOCK: 8}}) %buf0_view_buf0_0 : [num_users=1] = call_function[target=torch.as_strided](args = (%buf0, (8,), (1,), 0), kwargs = {}) return (buf0_view_buf0_0,) ``` Here's a graph with dynamic shapes. This one is a little bit funky. Inductor provides a graph input for each shape symbol, which we map to a placeholder, in this example `s6`. Then, shape expressions in the generated code can refer to the symbol `s6`. The size hint for `s6` is stored in `node.meta["val"]` where `node` is the placeholder defining it. This works out in the generated python code because the placeholder defines a Python variable with the name `s6`. ``` graph(): %s6 : [num_users=0] = placeholder[target=s6] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %arg2_1 : [num_users=1] = placeholder[target=arg2_1] %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ((s6,), (1,)), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((-s6)//8)), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s6, XBLOCK: 8}}) return buf0 ``` Here's another graph, this time with dynamic shapes and strides. The grid expression is more complex since the numel is a product of dimensions. ``` graph(): %s10 : [num_users=0] = placeholder[target=s10] %arg1_1 : [num_users=1] = placeholder[target=arg1_1] %arg2_1 : [num_users=1] = placeholder[target=arg2_1] %buf0 : [num_users=2] = call_function[target=torch.empty_strided](args = ([s10, s10], [s10, 1]), kwargs = {dtype: torch.float32, device: cuda:0}) %triton_kernel_wrapper_mutation : [num_users=0] = call_function[target=torch.ops.higher_order.triton_kernel_wrapper_mutation](args = (), kwargs = {kernel_idx: 0, constant_args_idx: 0, grid: [[-(((s10**2)//(-64))), 1, 1]], tma_descriptor_metadata: {}, kwargs: {in_ptr0: %arg2_1, in_ptr1: %arg1_1, out_ptr0: %buf0, xnumel: s10**2, XBLOCK: 64}}) return buf0 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/146942 Approved by: https://github.com/jansel |