Commit Graph

337 Commits

Jason Ansel
0093735ccd [inductor] Use compile time config values in runtime (#124561)
This removes usage of torch._inductor.config from `torch._inductor.runtime`, fixing two issues (see the sketch below):
1) If configs change we should really use the compile time ones
2) In compile workers, we want to use the parent process config
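
A minimal, hypothetical sketch of the idea (the config names are real Inductor options, but the snapshotting structure is illustrative only):

```
import torch._inductor.config as inductor_config

# Snapshot the needed values when the kernel is compiled; the runtime (and any
# compile worker) consults only this snapshot, never the live config object.
compile_time_opts = {
    "max_autotune": inductor_config.max_autotune,
    "coordinate_descent_tuning": inductor_config.coordinate_descent_tuning,
}
```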

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124561
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552, #124553, #124557, #124559, #124560, #124569
2024-04-22 18:46:40 +00:00
Jason Ansel
480585fd2b [inductor] Refactor runtime files into torch._inductor.runtime (part 1) (#124552)
I am planning to make the compile_worker process not import torch so it can start up much faster.  This stack is prep for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124552
Approved by: https://github.com/yanboliang
2024-04-22 18:41:12 +00:00
PyTorch MergeBot
16eea7c6a5 Revert "[inductor] Refactor runtime files into torch._inductor.runtime (part 1) (#124552)"
This reverts commit a7035cc11a.

Reverted https://github.com/pytorch/pytorch/pull/124552 on behalf of https://github.com/jeanschmidt due to There are internal breakages, already discussed with author and he'll FF ([comment](https://github.com/pytorch/pytorch/pull/124552#issuecomment-2070548223))
2024-04-22 18:28:05 +00:00
PyTorch MergeBot
30dec1da84 Revert "[inductor] Use compile time config values in runtime (#124561)"
This reverts commit 3af12447f8.

Reverted https://github.com/pytorch/pytorch/pull/124561 on behalf of https://github.com/jeanschmidt due to There are internal breakages, already discussed with author and he'll FF ([comment](https://github.com/pytorch/pytorch/pull/124561#issuecomment-2070537634))
2024-04-22 18:24:38 +00:00
Jason Ansel
3af12447f8 [inductor] Use compile time config values in runtime (#124561)
This removes usage of torch._inductor.config from `torch._inductor.runtime`.  Fixing two issues:
1) If configs change we should really use the compile time ones
2) In compile workers, we want to use the parent process config

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124561
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552, #124553, #124557, #124559, #124560, #124569
2024-04-22 04:51:30 +00:00
Jason Ansel
a7035cc11a [inductor] Refactor runtime files into torch._inductor.runtime (part 1) (#124552)
I am planning to make the compile_worker process not import torch so it can start up much faster.  This stack is prep for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124552
Approved by: https://github.com/yanboliang
2024-04-22 04:51:05 +00:00
Edward Z. Yang
afa78ad08c Call writeline from writelines (#124515)
This makes it more convenient to add a breakpoint here.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124515
Approved by: https://github.com/albanD
2024-04-20 15:45:30 +00:00
Zhuoran Zhao
b0d83726bd [5/x][AMD][Lowering Enablement] Hipifying aoti code_wrapper (#124241)
Summary: as title

Test Plan:
CI & unit test

patch on top of https://www.internalfb.com/phabricator/paste/view/P1214895953 to test

Differential Revision: D56223917

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124241
Approved by: https://github.com/jansel, https://github.com/desertfire
2024-04-19 18:57:38 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.
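
For example, a minimal before/after of what the rule flags (the function is illustrative):

```
def checked_sqrt(x: float) -> float:
    if x < 0:
        raise ValueError  # was: raise ValueError()  (useless parentheses removed)
    return x ** 0.5
```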

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Adnan Akhundov
03a05e791a Don't add non-integer Triton kernel arg 1 to equal_to_1 (#123886)
Summary: The Triton compiler adds constant argument 1 to `equal_to_1` [only when it's an int](8c5e33c77e/python/triton/runtime/jit.py (L275)). Here we restrict Inductor's `equal_to_1` in the same way.
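
A minimal, hypothetical kernel illustrating the distinction (names are invented for this example):

```
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * scale, mask=mask)

# Launching with scale=1.0 (a float) must NOT place `scale` in equal_to_1;
# only an int argument equal to 1 (e.g. n_elements == 1) qualifies for that
# specialization, matching the Triton compiler's own behavior.
```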

Test Plan:

```
$ python test/inductor/test_triton_kernels.py -k test_triton_kernel_equal_to_1_float_arg
...
----------------------------------------------------------------------
Ran 1 test in 6.528s

OK

$ python test/inductor/test_triton_kernels.py -k test_triton_kernel_equal_to_1_arg
...
----------------------------------------------------------------------
Ran 2 tests in 10.142s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123886
Approved by: https://github.com/oulgen
ghstack dependencies: #123703
2024-04-14 20:34:05 +00:00
Adnan Akhundov
3d9dc976ae Handle unqualified imports in custom Triton kernels (#123703)
Summary: If an externally imported symbol is used directly in a custom (user-written) Triton kernel, we need to codegen the corresponding import outside the kernel body in the Python wrapper. E.g., if the user code has this:

```
    from triton.language.extra.cuda.libdevice import fast_dividef

    @triton.jit
    def my_kernel(...):
        ...
        x = fast_dividef(...)
        ...
```

The `from triton.language.extra.cuda.libdevice import fast_dividef` line needs to be carried over together with the `my_kernel` function. The PR adds this.

Test Plan:

```
$ python test/inductor/test_triton_kernels.py
...
----------------------------------------------------------------------
Ran 464 tests in 113.512s

OK
```

Differential Revision: [D55953241](https://our.internmc.facebook.com/intern/diff/D55953241)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123703
Approved by: https://github.com/jansel, https://github.com/oulgen
2024-04-11 23:49:25 +00:00
Brian Hirsh
134e56fa33 inductor: log unique id to match output_code to aot graphs (#118647)
I found it helpful to be able to see, given some inductor output code, which AOT graph it came from. When you have large models with multiple graphs floating around this can be difficult, so I added the aot_config.aot_id to the printed inductor output.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118647
Approved by: https://github.com/ezyang
2024-04-11 14:37:07 +00:00
Adnan Akhundov
c773913407 Add torch.while_loop support to AOT Inductor (#123586)
Summary: Previously, `torch.while_loop` was supported only in JIT inductor (added in https://github.com/pytorch/pytorch/pull/122069). Here we extend the support to AOT Inductor.

Test Plan:

```
$ python test/inductor/test_aot_inductor.py -k test_while_loop
...
----------------------------------------------------------------------
Ran 24 tests in 129.236s

OK (skipped=8)

$ python test/inductor/test_control_flow.py
...
----------------------------------------------------------------------
Ran 50 tests in 136.199s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123586
Approved by: https://github.com/jansel, https://github.com/chenyang78
2024-04-09 22:53:10 +00:00
Bin Bao
064a650b63 [AOTI][refactor] Improve generate_extern_kernel_out's signature (#123351)
Summary: Annotate types and make the names more readable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123351
Approved by: https://github.com/chenyang78
ghstack dependencies: #123346
2024-04-04 23:23:50 +00:00
ydwu4
a4035bea5c [while_loop] support closures (#123018)
We add an additional_inputs argument to the while_loop HOP and rename the operands to carried_inputs, based on an offline discussion with @zou3519. This allows us to support closures, parameters, and buffers.

The alternative is to pass the lifted inputs directly to the outputs of body_fn. But since we want body_fn's outputs not to alias its inputs, we'd need to copy the inputs and remove the copies later, which is more work.
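
A minimal sketch of the kind of closure this enables, assuming the HOP signature `while_loop(cond_fn, body_fn, carried_inputs)` and the `torch._higher_order_ops` import path:

```
import torch
from torch._higher_order_ops.while_loop import while_loop  # assumed import path

class Accumulate(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.step = torch.nn.Parameter(torch.tensor(1.0))  # closed-over parameter

    def forward(self, it, x):
        def cond_fn(it, x):
            return it > 0                       # scalar bool tensor

        def body_fn(it, x):
            # `self.step` is not a carried input; it is lifted as an additional input
            return it - 1, x + self.step

        return while_loop(cond_fn, body_fn, (it, x))

m = Accumulate()
out = m(torch.tensor(3), torch.randn(4))
```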

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123018
Approved by: https://github.com/aakhundov
ghstack dependencies: #123217
2024-04-03 19:35:15 +00:00
Mu-Chu Lee
4b725e1619 [AOTInductor] Support quantized linear on CPU with fbgemm (#123069)
Summary:
Added support for quantized linear on CPU with fbgemm. Specifically, for torch.ops.quantized.linear_unpacked_dynamic_fp16, we decompose it into two steps: packing the weight, and running fbgemm's qlinear with the packed weight.

Test Plan:
Included in commit.
test_aot_inductor::test_quantized_linear

Differential Revision: [D55577959](https://our.internmc.facebook.com/intern/diff/D55577959)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123069
Approved by: https://github.com/hl475
2024-04-01 09:15:05 +00:00
Edward Z. Yang
10bdf64427 Properly pexpr the actual sympy.Expression, don't repr it. (#122893)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122893
Approved by: https://github.com/albanD, https://github.com/desertfire, https://github.com/jansel
2024-03-29 06:40:19 +00:00
Mu-Chu Lee
091a24495b [AOTInductor] Support use_runtime_constant_folding for CPU. (#122563)
Summary:
We allow CPU to use the config use_runtime_constant_folding.
Changes include:
1. Rearranging USE_CUDA flags and adding CPU sections that consume memory directly.
2. Codegen changes to accommodate cpp fusions for CPU only. Specifically, we shouldn't generate two headers that would cause re-declaration.

Test Plan: Activate tests that were deactivated for CPU before.

Reviewed By: khabinov

Differential Revision: D55234300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122563
Approved by: https://github.com/chenyang78
2024-03-28 17:49:05 +00:00
Bin Bao
537cd66e73 [Inductor] Support custom op in JIT with cpp wrapper (#122554)
Summary: Calling custom ops in an ABI-compatible way requires doing a boxed call with varargs across the C shim. In JIT mode, we can get around it by calling into Python.  https://gist.github.com/desertfire/be2a65b0a9b47780bb716b53ac2cd2b3 is an example of generated code.

Differential Revision: [D55326556](https://our.internmc.facebook.com/intern/diff/D55326556)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122554
Approved by: https://github.com/jansel, https://github.com/chenyang78
2024-03-26 18:48:45 +00:00
Adnan Akhundov
9223b2cb31 Pop codegened parent graph from wrapper in GraphLowering (#122469)
Summary: Previously, we kept a reference to `V.graph` in the `codegened_graph_stack` of the wrapper. Memory regression analysis of https://github.com/pytorch/pytorch/issues/121887 shows that this has led to a slightly higher memory utilization during lowering of the `llama_v2_7b_16h` model. Here we refactor the code to pop the parent subgraph from the `codegened_graph_stack` when codegen-ing is done.
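
A hypothetical, self-contained sketch of the pop-on-exit pattern (class and method names are invented):

```
class Wrapper:
    def __init__(self):
        self.codegened_graph_stack = []

    def codegen_subgraph(self, subgraph, emit):
        self.codegened_graph_stack.append(subgraph)
        try:
            emit(subgraph)
        finally:
            # Drop the reference once codegen is done so the subgraph (and
            # everything it holds) does not stay alive for the whole lowering.
            self.codegened_graph_stack.pop()

w = Wrapper()
w.codegen_subgraph({"name": "true_graph_0"}, lambda g: print("codegen", g["name"]))
```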

Fixes https://github.com/pytorch/pytorch/issues/121887.

Test Plan: CI, also see https://github.com/pytorch/pytorch/issues/121887#issuecomment-2014209104.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122469
Approved by: https://github.com/eellison
2024-03-25 20:27:59 +00:00
chilli
4d8a3f8bb3 changed aliasing checks to properly recurse for computing last usage (#122444)
Fixes https://github.com/pytorch/pytorch/issues/122457

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122444
Approved by: https://github.com/yifuwang, https://github.com/jansel
ghstack dependencies: #121624, #122474
2024-03-23 01:43:21 +00:00
chilli
d34514f8db Renamed mutationlayout/aliasedlayout (#122474)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122474
Approved by: https://github.com/jansel
ghstack dependencies: #121624
2024-03-22 08:32:14 +00:00
Adnan Akhundov
e419011471 [inductor] Add torch.while_loop support to JIT Inductor (#122069)
Summary: `torch.while_loop` HOP support is added to JIT Inductor. The test coverage is limited due to the functionality constraints of the upstream `torch.while_loop` op in Dynamo / Export. When those are lifted, we'll add more tests (see TODO-s in the test file).

AOT Inductor support will be added in a follow-up PR.

Test Plan:

```
$ python test/inductor/test_control_flow.py
...
----------------------------------------------------------------------
Ran 38 tests in 159.387s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122069
Approved by: https://github.com/jansel, https://github.com/eellison
2024-03-22 02:45:27 +00:00
Kai Londenberg
2cd0a5d516 [Inductor] Fix for WrapperCodeGen.statically_known_int_or_none (#121808)
There's a small typo in WrapperCodeGen.statically_known_int_or_none: the return value of a call to V.graph._shape_env._maybe_evaluate_static is discarded.

This fix changes the method to use that value, as was likely intended.
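
A hypothetical paraphrase of the fixed helper (context objects stubbed out; the real method lives on the wrapper codegen class):

```
def statically_known_int_or_none(shape_env, x):
    try:
        # Before the fix, the result of _maybe_evaluate_static was discarded
        # and the original expression was used instead; now we keep it.
        x = shape_env._maybe_evaluate_static(x)
        return int(x)
    except Exception:
        return None
```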

Test Plan:
CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121808
Approved by: https://github.com/lezcano, https://github.com/jansel, https://github.com/aakhundov
2024-03-21 21:15:32 +00:00
Adnan Akhundov
456b112dca [inductor] Support non-Tensor predicate in torch.cond (#122378)
Summary: Previously, we only supported a torch.Tensor boolean scalar predicate in `torch.cond` in Inductor. This PR adds support for SymBool and Python bool predicates, to match the `torch.cond` [semantics](https://pytorch.org/docs/stable/generated/torch.cond.html) in Dynamo / Export.
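
A minimal sketch of the newly supported case (module and shapes are illustrative): the predicate is a plain Python / symbolic bool derived from a shape rather than a tensor.

```
import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

@torch.compile
def f(x):
    pred = x.size(0) > 4          # SymBool under dynamic shapes, plain bool otherwise
    return torch.cond(pred, true_fn, false_fn, (x,))

out = f(torch.randn(8, 3))
```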

Test Plan:

```
$ python test/inductor/test_control_flow.py
...
----------------------------------------------------------------------
Ran 34 tests in 56.980s

OK

$ python test/inductor/test_aot_inductor.py -k test_cond
...
----------------------------------------------------------------------
Ran 54 tests in 460.093s

OK (skipped=4)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122378
Approved by: https://github.com/jansel, https://github.com/chenyang78
2024-03-21 14:35:01 +00:00
Oguz Ulgen
c0b2e56c8f Support triton.language.dtype with torch.compile -- Second Attempt (#122141)
This PR is the second attempt at supporting `triton.language.dtype`; instead of putting it on the graph, we now put it in the side table, since it is a constant.
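
A minimal, hypothetical example of the user code this enables (kernel and wrapper are invented): a `triton.language.dtype` value is passed to a user-defined Triton kernel from `torch.compile`d code.

```
import torch
import triton
import triton.language as tl

@triton.jit
def cast_kernel(x_ptr, out_ptr, n_elements, DTYPE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x.to(DTYPE), mask=mask)

@torch.compile
def to_fp32(x):
    out = torch.empty_like(x, dtype=torch.float32)
    n = x.numel()
    # tl.float32 is a triton.language.dtype constant; Dynamo records it in a
    # side table instead of placing it on the fx graph.
    cast_kernel[(triton.cdiv(n, 1024),)](x, out, n, DTYPE=tl.float32, BLOCK=1024)
    return out
```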

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122141
Approved by: https://github.com/jansel
ghstack dependencies: #122140
2024-03-19 19:40:52 +00:00
Oguz Ulgen
7c5e29ae71 Back out "Support triton.language.dtype with torch.compile (#121690)" (#122108)
Summary: There are some hard-to-deal-with package import/export related problems. Let's revert and start with a clean slate.

Test Plan: CI

Differential Revision: D55024877

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122108
Approved by: https://github.com/ezyang
2024-03-18 20:50:28 +00:00
Bin Bao
818b14025a [AOTI][refactor] Remove is_legacy_abi_kernel and abi_compatible_kernel (#121523)
Summary: is_legacy_abi_kernel was used for _scaled_dot_product_flash_attention fallback. It is only needed for C shim kernel name matching now, and the name matching is done with a direct string comparison. Also consolidate the fallback cpp kernel naming logic in CppWrapperCpu.

Differential Revision: [D54727789](https://our.internmc.facebook.com/intern/diff/D54727789)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121523
Approved by: https://github.com/chenyang78
2024-03-14 22:05:38 +00:00
Oguz Ulgen
79ee6bbde3 Support triton.language.dtype with torch.compile (#121690)
Putting this PR up as an RFC, since I have resorted to some horrible hacks to make this work.
```
(Pdb) p triton.language.float32
triton.language.fp32
(Pdb) p str(triton.language.float32)
'fp32'
(Pdb) p repr(triton.language.float32)
'triton.language.fp32'
```
This means that we need to "rewrite" them for fx graph and inductor execution.

This PR allows Mamba2 to work with `torch.compile`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121690
Approved by: https://github.com/Skylion007
2024-03-12 23:21:46 +00:00
Peter Bell
459c5bca58 [inductor] Refactor common triton imports into one function (#121438)
This means that when codegen depends on a particular import, we only need to add it in one place, and it applies to all Triton kernels.

This also changes codegen slightly: instead of generating `@pointwise`, we now generate `@triton_heuristics.pointwise`, so that the imports are the same for all kernel types.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121438
Approved by: https://github.com/lezcano
2024-03-09 18:17:36 +00:00
Adnan Akhundov
3d089de851 Add torch.cond support to AOT Inductor (#121120)
Summary: In this PR, `torch.cond` support and the necessary codegening infrastructure is added to C++ wrapper (AOTInductor and friends).

Notable additions:

- A new mechanism in the Python wrapper codegen to precompile and save the Triton kernels (generated and user-defined) which haven't been covered by the active path through the control flow given the sample inputs. As we can't do the runtime autotuning of the kernels outside the active path, we precompile and save them with the `launchers[0]` (corresponding to the first config).

- Codegen infra for `torch.cond` in the C++ wrapper (ABI- and non-ABI-compatible). The `torch.cond` codegen has been slightly refactored to avoid duplication across the Python and C++ wrappers.

- More extensions of the caching sites in the wrapper code to cache per codegened graph (e.g., `codegen_int_array_var`) + some infra for tracking the current codegened graph in the wrapper (both during codegen-ing in the `Scheduler.codegen` and in the `WrapperCodeGen.generate` functions).

- New unit tests to cover the added AOT Inductor + `torch.cond` functionality.

Codegen examples from the new unit tests:

- [`test_cond_simple_abi_compatible_cpu`](https://gist.github.com/aakhundov/862d5de9aa460f5df399e1387f7b342e)
- [`test_cond_simple_abi_compatible_cuda`](https://gist.github.com/aakhundov/d70b81f95fa8cc768cedef9acacb25bb)
- [`test_cond_simple_non_abi_compatible_cpu`](https://gist.github.com/aakhundov/c0ae7a8cbb6fa311c838e1b580f9a3f6)
- [`test_cond_simple_non_abi_compatible_cuda`](https://gist.github.com/aakhundov/08b945d4e8a32c97b7f9ff6272f4a223)
- [`test_cond_nested_abi_compatible_cuda`](https://gist.github.com/aakhundov/ce664f433c53e010ce4c0d96a6c13711)
- [`test_cond_with_parameters_abi_compatible_cuda`](https://gist.github.com/aakhundov/77afbeb8eaab5c5b930a3f922a7baf12)
- [`test_cond_with_multiple_outputs_abi_compatible_cuda`](https://gist.github.com/aakhundov/8cc06105ec8a3fe88be09b3f6e32c690)

Test Plan:

```
$ python test/inductor/test_aot_inductor.py -k test_cond
...
----------------------------------------------------------------------
Ran 42 tests in 170.619s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121120
Approved by: https://github.com/jansel, https://github.com/chenyang78
2024-03-07 22:39:57 +00:00
Oguz Ulgen
18d574a07a [Inductor] Use indices for constants in triton_meta (#121427)
@bertmaher pointed out that constants are passed with their indices, not their names. Looking at the Triton source, this appears to be true: 392370b303/python/triton/runtime/jit.py (L381-L385)
I'm guessing both indices and names work here, but let's be consistent.
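
A hypothetical illustration of the resulting metadata (argument names, values, and indices are invented): the `constants` map is keyed by positional index rather than by name.

```
# For a kernel kernel(in_ptr, out_ptr, n_elements, BLOCK) where BLOCK is a
# compile-time constant at position 3:
triton_meta = {
    "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
    "constants": {3: 1024},   # previously: {"BLOCK": 1024}
}
```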

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121427
Approved by: https://github.com/aakhundov
2024-03-07 21:59:43 +00:00
Bin Bao
7e598c0053 [Inductor] Enable ABI-compatible mode for cpp-wrapper JIT (#121309)
Differential Revision: [D54617284](https://our.internmc.facebook.com/intern/diff/D54617284)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121309
Approved by: https://github.com/chenyang78
2024-03-07 14:22:06 +00:00
Oguz Ulgen
6566b3db67 Add an autotune cache for inductor generated kernels (#120963)
Summary: Inductor currently has a best-config cache for the kernels it generates. This is a local cache implemented by writing to the file system. This diff takes the cache remote by reusing the existing Triton caching mechanism, built on Memcache internally and Redis externally.

Test Plan:
tested locally using `TORCH_INDUCTOR_AUTOTUNE_REMOTE_CACHE=1`

Look at scuba to verify the local testing: https://fburl.com/scuba/triton_remote_cache/z6pypznk

The plan is to land this diff with this turned off and gradually introduce this.

Differential Revision: D54398076

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120963
Approved by: https://github.com/jansel
2024-03-04 16:58:37 +00:00
Oguz Ulgen
558316b5f4 Emit grid wrapper inlined with the user defined triton kernel (#120824)
Fixes #120801

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120824
Approved by: https://github.com/chenyang78, https://github.com/jansel
ghstack dependencies: #120809
2024-02-29 16:17:45 +00:00
Oguz Ulgen
84e2accd6c Make triton_meta be part of user defined triton kernel cache (#120809)
Tensors with different shapes will generate different triton_meta (divisibility rules), so this needs to be part of the cache key.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120809
Approved by: https://github.com/chenyang78, https://github.com/jansel
2024-02-29 16:17:45 +00:00
Adnan Akhundov
0a46102b37 Add equal_to_1 to triton_meta for user-written Triton kernels (#120579)
Summary: Previously, we omitted `equal_to_1` from the `triton_meta` part of the `@user_autotune` decorator. For user-written Triton kernels, this could lead to perf regressions, as the kernel in the Inductor codegen is compiled without `equal_to_1` specialization.

Fixes #120478. The repro from the issue, on A100:

Before this PR:

```
Triton matmul:           0.0167 seconds
Triton matmul compiled:  0.0751 seconds
```

After this PR:

```
Triton matmul:           0.0168 seconds
Triton matmul compiled:  0.0072 seconds
```

Test Plan:

```
$ python test/dynamo/test_triton_kernels.py -k  test_triton_kernel_equal_to_1_arg
...
----------------------------------------------------------------------
Ran 3 tests in 3.545s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120579
Approved by: https://github.com/oulgen, https://github.com/jansel, https://github.com/chenyang78
2024-02-29 05:19:39 +00:00
Adnan Akhundov
2d17230212 [inductor] Do not reuse buffers across scopes in mem planning (#120777)
Summary: Previously, `memory_plan_reuse` assumed that the generated code is flat, in the sense that it can't have nested scopes. However, with nested control flow codegen-ing, this is no longer the case, which caused bugs where buffers were reused across visibility boundaries in different nested scopes.

In this PR, we add nested planning states in `memory_plan_reuse` on entering and exiting a scope in the codegen. This restricts buffer reusability to the currently active (top-of-stack) scope / planning state.
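
A hypothetical, self-contained sketch of the per-scope planning states (all names invented): freed buffers can only be reused inside the scope that freed them.

```
class ScopedReusePool:
    def __init__(self):
        self.scopes = [[]]                     # one free-list per open scope

    def enter_scope(self):
        self.scopes.append([])

    def exit_scope(self):
        self.scopes.pop()                      # freed buffers don't leak outward

    def free(self, buf):
        self.scopes[-1].append(buf)

    def allocate(self, nbytes):
        # Reuse only within the innermost (active) scope.
        for i, buf in enumerate(self.scopes[-1]):
            if buf["nbytes"] >= nbytes:
                return self.scopes[-1].pop(i)
        return {"nbytes": nbytes}              # otherwise allocate a fresh buffer
```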

Test Plan:

```
python test/inductor/test_control_flow.py -k test_subgraphs_with_parameters
...
----------------------------------------------------------------------
Ran 27 tests in 149.413s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120777
Approved by: https://github.com/chenyang78, https://github.com/desertfire, https://github.com/jansel
ghstack dependencies: #120665
2024-02-29 03:52:02 +00:00
Jason Ansel
01ec8df6d8 [Compiled Autograd] Introduce BackwardState capture (#120382)
This adds support for backwards hooks that are *both*:
1) Interior to the graph; and
2) Dynamically generated (e.g. lambdas)

We do this by creating a BackwardState object that is used to register the hooks in the forward, then populated by dynamo *after* the forward runs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120382
Approved by: https://github.com/xmfan
2024-02-28 20:36:47 +00:00
Yang Chen
b96ea097ee [aotinductor] rename CppWrapperCodeGen and CudaWrapperCodeGen (#120391)
Make the WrapperCodeGen subclass names consistent with the
file names:

CppWrapperCodeGen -> CppWrapperCpu
CudaWrapperCodeGen -> CppWrapperCuda

Differential Revision: [D54074938](https://our.internmc.facebook.com/intern/diff/D54074938)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120391
Approved by: https://github.com/aakhundov
2024-02-23 10:41:50 +00:00
Oguz Ulgen
29b2131c62 [Inductor] Fix bug around out of order constexprs in inductor (#120287)
Inductor's signature/config generation code assumes that all constexprs come as the last arguments of the function. This is not always true for user-defined kernels.
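
A minimal example of the case this fixes (kernel is illustrative): a `tl.constexpr` parameter that is not the last argument.

```
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, BLOCK: tl.constexpr, y_ptr, out_ptr, n_elements):
    # BLOCK is a constexpr that appears *before* non-constexpr arguments,
    # which the previous signature/config generation did not handle.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)
```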

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120287
Approved by: https://github.com/jansel
2024-02-21 17:39:41 +00:00
wangjiangben-hw
26610175d2 pass device_str for async_compile.triton function (#120202)
Fixes #120203

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120202
Approved by: https://github.com/jansel
2024-02-21 03:48:57 +00:00
Adnan Akhundov
badf84bd6b [inductor] Add torch.cond support to JIT Inductor (#119759)
Summary: `torch.cond` is already supported in Dynamo and Export: the `true_fn` and `false_fn` subgraphs are traced as child fx graphs of the main graph and passed to the `torch.cond` higher-order operator in the fx graph. However, this breaks in Inductor, as the latter doesn't have a way of dealing with child fx subgraphs and properly lowering and codegen-ing them.

In this PR, we add `torch.cond` support in Inductor. This is achieved by adding subgraph lowering and codegen-ing infrastructure as well as new `Conditional` IR node type weaving the parent graph with the true and false child subgraphs.

Here we only implement `torch.cond` support in JIT Inductor (Python wrapper codegen). The implementation in AOT Inductor (C++ wrapper codegen), including ABI-compatibility mode, will follow.
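
A minimal example of the pattern that JIT Inductor can now lower (the functions are illustrative): `true_fn` / `false_fn` are traced as child fx graphs and dispatched through the `torch.cond` higher-order operator.

```
import torch

def true_fn(x):
    return (x + 1).relu()

def false_fn(x):
    return (x - 1).relu()

@torch.compile
def f(pred, x):
    # pred is a boolean scalar tensor; each branch becomes a child subgraph
    # that Inductor lowers and codegens into the same Python wrapper.
    return torch.cond(pred, true_fn, false_fn, (x,))

out = f(torch.tensor(True), torch.randn(4, 4))
```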

Test Plan:

```
$ python test/inductor/test_control_flow.py
...
----------------------------------------------------------------------
Ran 24 tests in 86.790s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119759
Approved by: https://github.com/jansel, https://github.com/eellison
2024-02-17 07:25:27 +00:00
Yang Chen
bc7f3efb09 [aot_inductor] move CppWrapperCodeGen into a separate file (#119871)
This reverts commit d8e319a961.

Differential Revision: [D53817853](https://our.internmc.facebook.com/intern/diff/D53817853)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119871
Approved by: https://github.com/albanD, https://github.com/khabinov
ghstack dependencies: #119870
2024-02-16 08:14:20 +00:00
Yang Chen
78c9b2948a [aot_inductor] move CudaWrapperCodeGen into a separate file (#119870)
This reverts commit 3ab08946d5.

Differential Revision: [D53817852](https://our.internmc.facebook.com/intern/diff/D53817852)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119870
Approved by: https://github.com/khabinov
2024-02-16 08:10:51 +00:00
Adnan Akhundov
c2a835d710 [inductor] Refactor device guard Python codegen to allow nested indentation (#119673)
Summary: The codegen of `with torch.cuda._DeviceGuard` context manager in the Python wrapper code is implemented via `device_cm_stack: contextlib.ExitStack()`. As the context managers in the stack are `code.indent()`, this means that the whole stack is unindented at once on `device_cm_stack.close()`. This becomes problematic when attempting to codegen indented code (e.g., for control flow in Python and / or nested subgraph codegen-ing).

In this PR, we refactor the device guard codegen-ing in Python by replacing the `device_cm_stack` by explicit indent and unindent calls for entering and exiting the `with torch.cuda._DeviceGuard` context manager. This allows for nested device guard context managers and better aligns with other indented codegen-ing intertwined with it (e.g., for nested subgraph codegen-ing).

This is necessary for the upcoming support for `torch.cond` (and other control flow operators) in Inductor. Before that, the only change in the Python wrapper codegen is that the `return outputs` is now happening outside the `with torch.cuda._DeviceGuard` context manager.
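
A hypothetical, runnable sketch of the difference in indentation handling (the writer class is invented): an ExitStack can only unwind every indent level at once, whereas explicit indent/unindent calls let nested scopes be opened and closed independently, with the final `return` emitted outside the device guard.

```
import io

class Writer:
    def __init__(self):
        self.buf, self.level = io.StringIO(), 0
    def writeline(self, line):
        self.buf.write("    " * self.level + line + "\n")
    def indent(self):
        self.level += 1
    def unindent(self):
        self.level -= 1

w = Writer()
w.writeline("with torch.cuda._DeviceGuard(0):")
w.indent()
w.writeline("torch.cuda.set_device(0)")
w.writeline("if pred:")
w.indent()
w.writeline("buf0 = true_graph_0(args)")
w.unindent()                    # close only the inner scope ...
w.unindent()                    # ... then the device guard
w.writeline("return (buf0,)")   # emitted outside the _DeviceGuard block
print(w.buf.getvalue())
```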

Test Plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119673
Approved by: https://github.com/peterbell10
2024-02-13 15:05:30 +00:00
Bin Bao
70c93c6097 [inductor] Update JIT Inductor cpp wrapper entry function signature (#119280)
Summary: Change the JIT Inductor cpp wrapper entry function to use a signature similar to AOTInductor's, i.e. using an array of AtenTensorHandle instead of a vector of at::Tensor as the inputs, and returning outputs through a pointer. This makes it easier to consolidate the ABI-compatible and non-compatible modes.

Differential Revision: [D53478825](https://our.internmc.facebook.com/intern/diff/D53478825)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119280
Approved by: https://github.com/chenyang78
2024-02-12 22:24:35 +00:00
Adnan Akhundov
e5f46a1d35 Check alignment of ReinterpretView args of custom Triton kernels (#119649)
Summary: Currently, when a custom (user-written) Triton kernel has a ReinterpretView argument in IR, we're always skipping the alignment checking for this argument when preparing the `signature_of` for the AOT compilation of the Triton kernel (via setting `TensorArg.check_alignment` to `False`). This is problematic for user-written kernels where, albeit reinterpreted, the argument of the Triton kernel (the data pointer) can still be aligned to 16. When we skip alignment checking, the performance of the AOT-compiled internal Triton kernels can degrade 2x--3x.

In this PR, we replace `TensorArg.check_alignment` by `TensorArg.offset`, in which we specify the offset of the `ReinterpretView.layout` relative to the underlying `ir.Buffer` (corresponding to the data pointer before reinterpretation). As the size and stride of the layout don't change the alignment properties, those can be skipped. Importantly, for `ReinterpretView` arguments of custom Triton kernels, we use `arg.data.get_name()` as the buffer name. That, together with the offset, is used to check the alignment.

Bonus: the namedtuples in `codegen/common.py` are refactored as `dataclass`es, with nicer type hints and default values (for the newly added `TensorArg.offset`).
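
A hypothetical sketch of the refactored argument record (field set approximated from the description above): the boolean `check_alignment` flag becomes an `offset` relative to the underlying buffer, so alignment of reinterpreted views can still be checked.

```
from dataclasses import dataclass, field

import sympy
import torch

@dataclass
class TensorArg:
    name: str                      # name used in the kernel signature
    buffer: str                    # name of the underlying ir.Buffer (arg.data.get_name())
    dtype: torch.dtype
    # Offset of the ReinterpretView's layout relative to the base buffer;
    # 0 means the argument points at the start of the buffer.
    offset: sympy.Expr = field(default_factory=lambda: sympy.Integer(0))
```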

Test Plan:

```
$ python test/inductor/test_aot_inductor.py -k test_triton_kernel_reinterpret_view
...
----------------------------------------------------------------------
Ran 6 tests in 27.952s

OK (skipped=4)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119649
Approved by: https://github.com/oulgen
2024-02-11 20:21:17 +00:00
Adnan Akhundov
0bed0501fa Don't skip register-spilling configs in custom Triton kernel auto-tuning (#119634)
Summary: There has been some empirical evidence that, for (non-trivial) custom (user-written) Triton kernels, a register-spilling config yields the best result in auto-tuning. For this reason, we no longer skip register-spilling configs when auto-tuning custom Triton kernels.

<details>
<summary>An example of auto-tuning result with the register-spilling config outperforming others</summary>

```
BLOCK_M: 16, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.748896, nreg 255, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.723424, nreg 249, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 2.202656, nreg 190, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.748256, nreg 255, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.724896, nreg 249, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 2.201632, nreg 190, nspill 0, #shared-mem 8704
BLOCK_M: 16, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.651664, nreg 255, nspill 56, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.846368, nreg 255, nspill 14, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.841792, nreg 243, nspill 0, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.651584, nreg 255, nspill 56, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.846432, nreg 255, nspill 14, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.841904, nreg 243, nspill 0, #shared-mem 13312
BLOCK_M: 16, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.236448, nreg 255, nspill 254, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.484384, nreg 255, nspill 174, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.131168, nreg 255, nspill 6, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.236544, nreg 255, nspill 254, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.483648, nreg 255, nspill 174, #shared-mem 22528
BLOCK_M: 16, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.131408, nreg 255, nspill 6, #shared-mem 22528
BLOCK_M: 32, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.516112, nreg 255, nspill 28, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.737792, nreg 255, nspill 0, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.411632, nreg 193, nspill 0, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.515904, nreg 255, nspill 28, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.736608, nreg 255, nspill 0, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.409808, nreg 193, nspill 0, #shared-mem 13312
BLOCK_M: 32, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.553536, nreg 255, nspill 130, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.569792, nreg 255, nspill 56, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.892448, nreg 255, nspill 4, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.553584, nreg 255, nspill 130, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.569568, nreg 255, nspill 56, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.892240, nreg 255, nspill 4, #shared-mem 18432
BLOCK_M: 32, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.332928, nreg 255, nspill 366, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.922256, nreg 255, nspill 228, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.758400, nreg 255, nspill 26, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.333440, nreg 255, nspill 366, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.922336, nreg 255, nspill 228, #shared-mem 28672
BLOCK_M: 32, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.758496, nreg 255, nspill 26, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.231648, nreg 255, nspill 292, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.639424, nreg 255, nspill 90, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.917952, nreg 240, nspill 0, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.230624, nreg 255, nspill 292, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.639168, nreg 255, nspill 90, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 16, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.917440, nreg 240, nspill 0, #shared-mem 22528
BLOCK_M: 64, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.838080, nreg 255, nspill 354, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.569184, nreg 255, nspill 178, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.614720, nreg 255, nspill 28, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.838048, nreg 255, nspill 354, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.569472, nreg 255, nspill 178, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 32, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.615104, nreg 255, nspill 28, #shared-mem 28672
BLOCK_M: 64, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 1.012128, nreg 255, nspill 522, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.861536, nreg 255, nspill 378, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 1, enable_warp_specialization: False, enable_persistent: False: 0.771584, nreg 255, nspill 134, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 2, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 1.012512, nreg 255, nspill 522, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 4, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.861024, nreg 255, nspill 378, #shared-mem 40960
BLOCK_M: 64, BLOCK_N: 64, num_warps: 8, num_ctas: 1, num_stages: 2, enable_warp_specialization: False, enable_persistent: False: 0.771712, nreg 255, nspill 134, #shared-mem 40960
```

</details>

In the above, the winning config is `BLOCK_M: 32, BLOCK_N: 16, num_warps: 2, num_ctas: 1, num_stages: 2`, although it has non-zero `nspill 28`. This is an example where we need to consider all configs, including the register-spilling ones, to obtain the best result from auto-tuning.

In the worst case, this will just make auto-tuning longer, but can't regress the results. And, as the number of custom Triton kernels in the model is normally much smaller than the number of Inductor-generated ones, this should be acceptable.

Test Plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119634
Approved by: https://github.com/oulgen
2024-02-11 02:13:25 +00:00
PyTorch MergeBot
3ab08946d5 Revert "[aot_inductor] move CudaWrapperCodeGen into a separate file (#119448)"
This reverts commit 0597dab523.

Reverted https://github.com/pytorch/pytorch/pull/119448 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/119448#issuecomment-1937345167))
2024-02-10 23:04:36 +00:00