Commit Graph

57 Commits

Author SHA1 Message Date
Max Podkorytov
7ef2c62fd3 [ROCm][Inductor][CK] Add ck-tile based universal gemm kernels to torch.mm autotune choices (#152341)
This PR adds code generation for CK-tile based universal gemm kernels to the CK backend for Inductor, and adds these kernels to autotune choices.

Unlike the legacy-CK based kernels (which are generated by parsing the CK instances from the CK library), we generate the set of instances by manually specifying the tuning parameters.

This PR introduces a new template for code generation, and compilation/autotuning is handled by the existing infrastructure.

Points of discussion:

* For simplicity and reduced coupling with CK, the instance filter checks only data type and layout, and doesn't check the alignment requirement. This means more instances are compiled than necessary, but it keeps the code generation independent from the internal CK logic that checks alignment validity at runtime.
* CK-tile instances are enabled whenever legacy-CK instances are enabled. A config knob could be introduced to differentiate between the instance types if that's needed.
* Whether the gemm problem size K is ever dynamic: whenever it's not a compile-time constant, we need to perform a runtime dispatch between several kernels.

**Testing**

Use the existing tests in `test/inductor/test_ck_backend.py`
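For a quick local check, a hedged sketch along the lines of those tests (config values and shapes are illustrative, and a ROCm build with the CK backend available is assumed):

```py
import torch
from torch._inductor import config as inductor_config

# Hedged sketch: exercise the CK autotune choices (including the new CK-tile
# instances) for torch.mm; settings are illustrative only.
with inductor_config.patch(
    {
        "max_autotune_gemm_backends": "CK",
        "compile_threads": 8,
    }
):
    a = torch.randn(256, 512, dtype=torch.float16, device="cuda")
    b = torch.randn(512, 128, dtype=torch.float16, device="cuda")
    compiled_mm = torch.compile(torch.mm, mode="max-autotune")
    torch.testing.assert_close(compiled_mm(a, b), torch.mm(a, b), rtol=1e-2, atol=1e-2)
```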

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152341
Approved by: https://github.com/chenyang78
2025-05-21 23:59:16 +00:00
Bin Bao
33a5179269 [AOTI][reland2] Remove typedef for half and bfloat16 (#153467)
Summary:
Reland https://github.com/pytorch/pytorch/pull/151109 after fixing cutlass AOTI build issues.

typedefs are prone to name collisions. Explicitly spell out the actual ATen types, which are needed for the standalone AOTI codegen.

Differential Revision: D74398762

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153467
Approved by: https://github.com/jingsh, https://github.com/henrylhtsang, https://github.com/cyyever
2025-05-14 02:37:18 +00:00
Sam Larsen
f1de3f9f07 Rename "output_tensor" -> "out" in autotune_process.py (#153169)
Summary: This change is to support remote autotuning. I want to use all the same benchmarking utilities in select_algorithm.py. For remote autotuning, I'll reuse the TritonBenchmarkRequest class used for subprocess autotuning because it's already serializable. That class is also used in standard, in-process autotuning, but via TritonTemplateCaller.benchmark(), which sets the output_tensor param when calling the underlying TritonBenchmarkRequest. For remote autotuning, I'll be calling the TritonBenchmarkRequest directly, so I want the parameter to be named 'out' to avoid "got an unexpected keyword argument 'out'".
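A minimal illustration of the keyword mismatch (the function below is a hypothetical stand-in, not the actual class):

```py
# Hypothetical stand-in for invoking a benchmark request directly with keyword args.
def benchmark_request(*input_tensors, out=None):
    return out  # the real request would run the kernel and write the result into `out`

# Call sites that go through TritonTemplateCaller.benchmark() used `output_tensor=`;
# calling the request directly with the old name would raise
#   TypeError: ... got an unexpected keyword argument 'output_tensor'
benchmark_request("a", "b", out=None)  # works after the rename
```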

Test Plan: Existing unit tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153169
Approved by: https://github.com/aorenste, https://github.com/eellison
2025-05-13 14:18:29 +00:00
Jason Ansel
b040dc3a53 Reland: [inductor] Simplify grid handling (#148305)
Summary:
Relands D69965761 / https://github.com/pytorch/pytorch/pull/147583

Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg.  This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
Instead, the grid computation is now included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
    grid_0 = ((xnumel + 1023) >> 10)
    grid_1 = 1
    grid_2 = 1
    runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```
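For reference, a minimal sketch of what the shift in the generated launcher computes, assuming XBLOCK=1024 (the constant implied by the example above):

```py
# Minimal sketch: (xnumel + 1023) >> 10 is a ceiling division by 1024, i.e. one
# grid "program" per XBLOCK=1024 elements along the x dimension.
def grid_0(xnumel: int, xblock: int = 1024) -> int:
    return (xnumel + xblock - 1) // xblock

assert grid_0(1) == 1
assert grid_0(1024) == 1
assert grid_0(1025) == 2
assert grid_0(2048) == ((2048 + 1023) >> 10)  # matches the generated expression
```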

This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.

It also allows us to unify the handling of grids between the Python and C++ wrapper code.  Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.

This unification allows this PR to be a net deletion of code.

Differential [disconnected] Revision: D70471332

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148305
Approved by: https://github.com/shunting314, https://github.com/eellison
2025-03-12 15:52:16 +00:00
PyTorch MergeBot
5ada4e6a53 Revert "Reland: [inductor] Simplify grid handling (#148305)"
This reverts commit 8d08b49015.

Reverted https://github.com/pytorch/pytorch/pull/148305 on behalf of https://github.com/jithunnair-amd due to Broke ROCm CI ([comment](https://github.com/pytorch/pytorch/pull/148305#issuecomment-2718177044))
2025-03-12 14:58:43 +00:00
Jason Ansel
8d08b49015 Reland: [inductor] Simplify grid handling (#148305)
Summary:
Relands D69965761 / https://github.com/pytorch/pytorch/pull/147583

Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg.  This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
Instead, the grid computation is now included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
    grid_0 = ((xnumel + 1023) >> 10)
    grid_1 = 1
    grid_2 = 1
    runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```

This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.

It also allows us to unify the handling of grids between the Python and C++ wrapper code.  Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.

This unification allows this PR to be a net deletion of code.

Differential Revision: D70471332

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148305
Approved by: https://github.com/shunting314, https://github.com/eellison
2025-03-11 18:51:06 +00:00
Ruben Rodriguez Buchillon
32715a2311 [inductor][ck] add kBatch_sweep to config.rocm (#148223)
Summary:
# Why

Allow tests and users to specify a set of kBatch values to try, rather than relying on our hand-written heuristic.

# What

Add rocm.kBatch_sweep as a list of kBatch values to try out. This generates a product of CK instances, one per kBatch for each existing op, though instances are filtered out if they are likely to fail at runtime.
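A hedged usage sketch (the sweep values are illustrative, not a recommendation):

```py
from torch._inductor import config as inductor_config

# Hedged sketch: ask the CK backend to generate one instance per kBatch value in
# the sweep for each existing op; invalid combinations get filtered out.
with inductor_config.patch(
    {
        "max_autotune_gemm_backends": "CK",
        "rocm.kBatch_sweep": [1, 4, 16],
    }
):
    ...  # compile and benchmark the model under max-autotune here
```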

Test Plan: n/a

Reviewed By: chenyang78

Differential Revision: D70226055

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148223
Approved by: https://github.com/ColinPeppler
2025-03-06 01:14:33 +00:00
PyTorch MergeBot
608377d341 Revert "[import][inductor] Simplify grid handling (#147583)"
This reverts commit b59776d857.

Reverted https://github.com/pytorch/pytorch/pull/147583 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/147583#issuecomment-2693016036))
2025-03-03 00:49:32 +00:00
Jason Ansel
b59776d857 [import][inductor] Simplify grid handling (#147583)
Before this PR, calling a triton kernel would look like:
```py
kernel.run(a, b, xnumel, grid=grid(xnumel), stream=stream0)
```
where the `grid=` was passed as a callable (function closure) arg.  This PR removes the grid arg:
```py
kernel.run(a, b, xnumel, stream=stream0)
```
Instead, the grid computation is now included in the kernel launcher, with something like:
```py
def launcher(in_ptr0, out_ptr0, xnumel, stream):
    grid_0 = ((xnumel + 1023) >> 10)
    grid_1 = 1
    grid_2 = 1
    runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, in_ptr0, out_ptr0, xnumel)
```

This should be faster, since we remove multiple function/dict calls and are able to specialize the grid computation for each `triton.Config`.

It also allows us to unify the handling of grids between the Python and C++ wrapper code.  Before this, C++ wrapper code didn't actually support dynamic grid sizes and instead burned in a static grid.

This unification allows this PR to be a net deletion of code.

Note the attached diff contains some minor fbcode-only changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147583
Approved by: https://github.com/eellison, https://github.com/shunting314
2025-03-02 07:31:07 +00:00
Ruben Rodriguez Buchillon
6f91720e1c [inductor][ck] manual kBatch heuristic (#148118)
Summary:
# Why

Leverage the kBatch parameter for large splitK cases in CK to get better-than-ATEN performance.

# What

Replace the default kBatch = 1 with a manual heuristic:

- applies only if K > 16 * max(M, N)
- leverages k_per_block, K, and the number of SMs on the chip
- upper bound of 128, lower bound of 1

This is better than defaulting to 1, is cheap to calculate, and shows performance beyond ATEN.

This is of course subject to change and improvement.
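A hedged sketch of the heuristic described above (the exact formula and names are assumptions, not the actual implementation):

```py
# Hedged sketch: pick kBatch only for large split-K shapes, clamped to [1, 128].
def pick_k_batch(M: int, N: int, K: int, k_per_block: int, num_sms: int) -> int:
    if K <= 16 * max(M, N):
        return 1  # not a split-K shape; keep the old default
    # Assumed formula: spread K over the SMs in units of k_per_block.
    k_batch = K // (k_per_block * num_sms)
    return max(1, min(128, k_batch))

# Example: the shape from the test plan, with illustrative k_per_block and SM count.
print(pick_k_batch(2048, 2048, 524288, k_per_block=32, num_sms=304))
```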

Test Plan:
with minor modifications to run torch.mm on the shape `M, N, K = 2048, 2048, 524288`

```
buck2 run -c fbcode.re_gpu_tests=False mode/opt-amd-gpu  fbcode//deeplearning/aot_inductor/benchmark/sampling:test_gemm_autotune_benchmark_AMD_block_0
```

```
AUTOTUNE mm(2048x524288, 524288x2048)
  rocm_ck_gemm_template_49 10.4972 ms 100.0%
  rocm_ck_gemm_template_8 10.6132 ms 98.9%
  rocm_ck_gemm_template_9 10.6907 ms 98.2%
[...]
  mm 18.9880 ms 55.3%
```

Reviewed By: ColinPeppler

Differential Revision: D70224591

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148118
Approved by: https://github.com/ColinPeppler
2025-02-28 20:36:16 +00:00
Xuehai Pan
1cb4e2df65 [BE][PYFMT] migrate PYFMT for torch._inductor to ruff format (#144550)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144550
Approved by: https://github.com/jansel
2025-02-28 13:33:19 +00:00
Ruben Rodriguez Buchillon
f0d00421cf [inductor][ck] kBatch filtering with gen_ops (#148004)
Summary:
# Why

Not all choices of kBatch are valid; invalid choices lead to a runtime error when CK checks the validity of the args:

c9bcfd755e/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp (L1020)

# What

- move kBatch inside the gen_ops to have more control over it, and be able to filter it
- expand filtering based on the cpp logic
- refactor the padding checks to be more readable
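A hedged sketch of the idea (names and the validity rule are stand-ins, not the actual CK check):

```py
# Hedged sketch: generate (instance, kBatch) pairs inside gen_ops and drop
# combinations that the cpp-side argument check would reject at runtime.
def is_valid_kbatch(k_per_block: int, K: int, k_batch: int) -> bool:
    return K % (k_per_block * k_batch) == 0  # stand-in validity rule

def gen_ops(k_per_blocks: list[int], K: int, kbatch_candidates: list[int]):
    for k_per_block in k_per_blocks:
        for k_batch in kbatch_candidates:
            if is_valid_kbatch(k_per_block, K, k_batch):
                yield (k_per_block, k_batch)

# With this stand-in rule, kBatch=1738 gets filtered out while 1 and 128 pass.
print(list(gen_ops([32, 64], K=524288, kbatch_candidates=[1, 128, 1738])))
```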

Test Plan:
```
buck2 run -c fbcode.re_gpu_tests=False mode/opt-amd-gpu  fbcode//deeplearning/aot_inductor/benchmark/sampling:test_gemm_autotune_benchmark_AMD_block_0
```

with

kBatch = 128: some filtering
kBatch = 1: no filtering
kBatch = 1738: all options filtered out

Reviewed By: henrylhtsang

Differential Revision: D70211442

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148004
Approved by: https://github.com/ColinPeppler, https://github.com/tenpercent
2025-02-27 20:13:58 +00:00
Ruben Rodriguez Buchillon
7a06bfdd1c [inductor][ck] kBatch parametrized (#147885)
Summary:
# Why

Enable us to set the kBatch parameter, rather than bake it in

Especially for larger splitK scenarios, this can yield very good performance (up to 1.5x vs hipblaslt from initial tests)

## Why like this

The obvious question is: why not add this to the op itself, and maybe even into the template/kernel? That would simplify the code.

The choice to have it as a "runtime" param that we fix is to be able to reuse the compiled CK `.so` libraries: multiple choices of kBatch can now be used with the exact same `.so`, since the shared library does not depend on kBatch but takes it as a parameter.
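A purely illustrative sketch of that design choice (no real CK or AOTI APIs are used here):

```py
# Hedged illustration: baking kBatch into the kernel would need one compiled
# artifact per value, while a call-time parameter lets one .so serve them all.

# Baked-in: a separate compiled kernel per kBatch value.
compiled_baked = {k: f"gemm_kbatch_{k}.so" for k in (1, 4, 16)}  # three artifacts

# Runtime parameter: one artifact, kBatch chosen per call (hypothetical launcher).
def launch(kernel_so: str, k_batch: int) -> None:
    print(f"launching {kernel_so} with kBatch={k_batch}")

launch("gemm_universal.so", k_batch=1)
launch("gemm_universal.so", k_batch=16)  # same .so, different kBatch
```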

# What

- Copy the cutlass approach for swizzle: a "runtime" arg that we pass in but that is really choice-dependent
- Pipe it through everywhere from the template and kernel
- Hard-code it to kBatch=1 for now (same as before, just now settable)

This is part of a series of diffs, where next we need to figure out
1. how to filter out op + kBatch combinations that don't work
2. how to set kBatch better for splitK scenarios (hand-written heuristic)

Test Plan:
(with minor modifications)

```
# show it working with AOTI
buck2 run mode/opt-amd-gpu //scripts/henrylhtsang/repros:aot
```

```
# show it working with inductor only
buck2 run -c fbcode.re_gpu_tests=False mode/opt-amd-gpu  fbcode//deeplearning/aot_inductor/benchmark/sampling:test_gemm_autotune_benchmark_AMD_block_0
```

Differential Revision: D70200008

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147885
Approved by: https://github.com/ColinPeppler
2025-02-26 07:28:19 +00:00
Aviral Goel
866dc45d3c [Inductor][ROCm][CK] Unhardedcoded kernel shapes for ck_conv_template codegen (#147504)
## [Inductor][ROCm][CK] Parameterize `ck_conv_template` Codegen

### Description
Previously, ROCm CK kernel codegen templates were hardcoded with fixed values for convolution parameters:

- `index_t GroupCount`
- `index_t NBatch`
- `index_t NOutChannels`
- `index_t NInChannels`
- `vector<index_t> FilterSize`
- `vector<index_t> InputSize`
- `vector<index_t> ConvolutionStrides`
- `vector<index_t> Dilations`
- `vector<index_t> LeftPads`
- `vector<index_t> RightPads`

This PR updates `ck_conv_template` to accept these parameters dynamically from Inductor. By doing so, we reduce the number of generated templates, improving flexibility and maintainability.
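A hypothetical sketch of the data now supplied by Inductor at codegen time (the parameter names mirror the list above; the render call is illustrative, not the actual template API):

```py
# Hypothetical example values for a grouped 2D convolution; only the structure
# mirrors the parameters listed above, the call itself is illustrative.
conv_params = {
    "GroupCount": 1,
    "NBatch": 8,
    "NOutChannels": 64,
    "NInChannels": 32,
    "FilterSize": [3, 3],
    "InputSize": [56, 56],
    "ConvolutionStrides": [1, 1],
    "Dilations": [1, 1],
    "LeftPads": [1, 1],
    "RightPads": [1, 1],
}
# ck_conv_template.render(**conv_params)  # hypothetical; real codegen wires these in from Inductor IR
```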

### Testing
- Verified correctness by running the relevant test cases, i.e., `test/inductor/test_ck_backend.py`
- Ensured generated kernels reflect the updated parameterization, i.e., the generated templates in `/tmp/torchinductor_root/`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147504
Approved by: https://github.com/jansel, https://github.com/eellison, https://github.com/tenpercent

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
2025-02-25 07:48:07 +00:00
Jason Ansel
e9f6e273e7 [inductor] Add typing to common.CSE (#145993)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145993
Approved by: https://github.com/yanboliang
ghstack dependencies: #145916
2025-02-04 16:05:39 +00:00
PyTorch MergeBot
d3c7e4bb9c Revert "[inductor] Add typing to common.CSE (#145993)"
This reverts commit 8c657ae4be.

Reverted https://github.com/pytorch/pytorch/pull/145993 on behalf of https://github.com/atalman due to Sorry need to revert https://github.com/pytorch/pytorch/pull/145916 ([comment](https://github.com/pytorch/pytorch/pull/145993#issuecomment-2632712384))
2025-02-04 03:04:01 +00:00
Jason Ansel
8c657ae4be [inductor] Add typing to common.CSE (#145993)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145993
Approved by: https://github.com/yanboliang
ghstack dependencies: #145913, #145914, #145915, #145916
2025-02-01 16:34:18 +00:00
Jason Ansel
e90cf4abcf [inductor] Add some typing to common.py (#145691)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145691
Approved by: https://github.com/malfet
ghstack dependencies: #145690
2025-01-27 06:27:13 +00:00
Aaron Orenstein
2bf772d1ba PEP585 update - torch/_inductor/codegen (#145106)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145106
Approved by: https://github.com/bobrenjc93
2025-01-18 06:56:03 +00:00
Max Podkorytov
99600789c3 [ROCm][Inductor][CK] hackfix for segfault in addmm op (#144519)
This snippet used to cause a segfault on the GPU due to incorrect input order when invoking the kernel:

```
import os
import torch
import torch.nn as nn

from torch._inductor import config as inductor_config
from torch._inductor.utils import fresh_inductor_cache

M, N, K = 128, 128, 4096
dtype = torch.float16

X = torch.randn(M, N, dtype=dtype).cuda()
A = torch.randn(M, K, dtype=dtype).cuda()
B = torch.randn(K, N, dtype=dtype).cuda()

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, b, x, y):
        return torch.addmm(b, x, y)

import ck4inductor
ck_dir = os.path.dirname(ck4inductor.__file__)

with fresh_inductor_cache():
    with inductor_config.patch(
        {
            "max_autotune_gemm_backends": "CK",
            "autotune_fallback_to_aten": False,
            "compile_threads": 144,
            "rocm.ck_dir": ck_dir,
        }
    ):
        compiled_model = torch.compile(SimpleModel(), mode="max-autotune")
        res = compiled_model(X, A, B)
        res_eager = torch.addmm(X, A, B)
        torch.testing.assert_close(res, res_eager)

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144519
Approved by: https://github.com/chenyang78
2025-01-10 19:29:14 +00:00
bobrenjc93
a3ab27b8e0 Migrate from Tuple -> tuple in torch/_inductor (#144264)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144264
Approved by: https://github.com/eellison
2025-01-07 03:27:27 +00:00
bobrenjc93
c17d767686 remove allow-untyped-defs from _inductor/codegen/rocm/rocm_template_buffer.py (#143870)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143870
Approved by: https://github.com/aorenste, https://github.com/Skylion007
2024-12-27 23:28:51 +00:00
bobrenjc93
a42ca5a45b remove allow-untyped-defs for _inductor/codegen/rocm/rocm_template_buffer.py (#143272)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143272
Approved by: https://github.com/aorenste
2024-12-17 05:34:22 +00:00
Tom Ritchford
da67a6a7bb [inductor] Replace set by OrderedSet (#138466)
Uses the set_linter from https://github.com/pytorch/pytorch/pull/138454
and considerable manual editing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138466
Approved by: https://github.com/eellison
2024-12-13 16:08:45 +00:00
eellison
b731ced91f Prologue Fusion (#134532)
This PR extends our ability to fuse pointwise nodes onto the outputs of triton templates (epilogue fusion) with the ability to fuse pointwise nodes into the inputs of triton templates - prologue fusion.

Similar to the store_output api:
`{{store_output(("idx_m", "idx_n"), "acc", "mask")}}`

And the modification api:

```
{{ modification(
    subgraph_number=0,
    output_name="post_mod_scores",
    score="qk",
    out="qk"
) | indent_except_first(1) }}
```

We have:

```{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}```

Because we are now loading the input with explicit indices and a mask, I needed to rewrite the mm kernel to no longer update the [pointers by BLOCK_K](bb03ef7aca/torch/_inductor/kernel/mm.py (L110-L111)) on every iteration, and instead compute indices from the k_idx of each loop iteration. This made no perf difference.
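A hedged, pure-Python analogy of that indexing change (NumPy stands in for the Triton loads; this is not the actual template code):

```py
import numpy as np

# Analogy: compute the K-tile indices from the loop counter each iteration
# (with a mask for the ragged tail) instead of advancing a pointer by BLOCK_K.
M, N, K, BLOCK_K = 4, 3, 10, 4
A, B = np.random.rand(M, K), np.random.rand(K, N)

acc = np.zeros((M, N))
for k_start in range(0, K, BLOCK_K):
    k_idx = k_start + np.arange(BLOCK_K)      # explicit indices, no pointer state
    mask = k_idx < K                          # mask out-of-range lanes
    safe = np.where(mask, k_idx, 0)           # clamp masked lanes to a valid index
    a_tile = A[:, safe] * mask                # "load_input"-style masked gather
    b_tile = B[safe, :] * mask[:, None]
    acc += a_tile @ b_tile

assert np.allclose(acc, A @ B)
```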

There are a couple main use cases for prologue fusion:

- Fusing dequants into a matmul, particularly for more bandwidth-bound scenarios.
- Fusing a gather into a matmul. This is particularly useful in MoE. See https://github.com/pytorch/pytorch/issues/134535 for more details.

Prologue fusion is generally much less profitable than epilogue fusion, because it must be applied to an element of an input on each loop of the matmul, compared to only once in the epilogue (gather into matmul is a potential exception). Accordingly, we are much less aggressive in attempting prologue fusion. We only attempt fusion if it does not increase the number of memory bytes read inside the triton template, multiplied by a small factor to allow gathers. This rules out reliably unprofitable fusions like fp32->fp16 casts inside the kernel. In a future PR we could potentially add an API for being more aggressive if we know we are in a bandwidth-bound regime. See: https://github.com/pytorch/pytorch/pull/134532/files#diff-d2539c9c8dc6a3d7e457767a880612e96d3c85752a77ead49a9e4e00a3e4c3c7R3060-R3066

Other notes:

By default we will upcast to fp32 inside every kernel. This matches eager numerics. That is fine for the epilogue because it is only done once (although it is probably unnecessary for, say, a relu), but it tanks perf for the prologue. I am currently using the `codegen_upcast_to_fp32` option to avoid it, but that will not work for libdevice calls that require fp32. We will need https://github.com/pytorch/pytorch/pull/136778/ and dtype-aware codegen to upcast fp16 ops into libdevice calls.

With prologue fusion, we now have essentially separate kernels for each input and for the output. I had to increase the number of fields that are swapped out in `set_subgraph_body` by a large number :/ I also updated the fusion logic because the inputs will have a different group than the outputs. Maybe this could get cleaned up a bit as part of enabling multiple outputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134532
Approved by: https://github.com/jansel
2024-12-13 04:18:25 +00:00
Tom Ritchford
dc23f1944a Remove unused Python variables in torch/[_-a]* (#133492)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492
Approved by: https://github.com/albanD
2024-12-12 17:39:14 +00:00
PyTorch MergeBot
233853a66f Revert "Prologue Fusion (#134532)"
This reverts commit 59ab3825e7.

Reverted https://github.com/pytorch/pytorch/pull/134532 on behalf of https://github.com/clee2000 due to A couple of PRs in this stack are breaking internally on different tests ([comment](https://github.com/pytorch/pytorch/pull/134532#issuecomment-2536643675))
2024-12-11 17:32:26 +00:00
PyTorch MergeBot
5c97ac9721 Revert "Remove unused Python variables in torch/[_-a]* (#133492)"
This reverts commit fda975a7b3.

Reverted https://github.com/pytorch/pytorch/pull/133492 on behalf of https://github.com/clee2000 due to Sorry, I need to revert this in order to revert something else.  The only thing you need to do is rebase and remerge ([comment](https://github.com/pytorch/pytorch/pull/133492#issuecomment-2536635516))
2024-12-11 17:29:12 +00:00
Tom Ritchford
fda975a7b3 Remove unused Python variables in torch/[_-a]* (#133492)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492
Approved by: https://github.com/albanD
2024-12-10 21:48:44 +00:00
eellison
59ab3825e7 Prologue Fusion (#134532)
This PR extends our ability to fuse pointwise nodes onto the outputs of triton templates (epilogue fusion) with the ability to fuse pointwise nodes into the inputs of triton templates - prologue fusion.

Similar to the store_output api:
`{{store_output(("idx_m", "idx_n"), "acc", "mask")}}`

And the modification api:

```
{{ modification(
    subgraph_number=0,
    output_name="post_mod_scores",
    score="qk",
    out="qk"
) | indent_except_first(1) }}
```

We have:

```{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}```

Because we are now loading the input with explicit indices and a mask, I needed to rewrite the mm kernel to no longer update the [pointers by BLOCK_K](bb03ef7aca/torch/_inductor/kernel/mm.py (L110-L111)) on every iteration, and instead compute indices from the k_idx of each loop iteration. This made no perf difference.

There are a couple main use cases for prologue fusion:

- Fusing dequants into a matmul, particularly for more bandwidth-bound scenarios.
- Fusing a gather into a matmul. This is particularly useful in MoE. See https://github.com/pytorch/pytorch/issues/134535 for more details.

Prologue fusion is generally much less profitable than epilogue fusion, because it must be applied to an element of an input on each loop of the matmul, compared to only once in the epilogue (gather into matmul is a potential exception). Accordingly, we are much less aggressive in attempting prologue fusion. We only attempt fusion if it does not increase the number of memory bytes read inside the triton template, multiplied by a small factor to allow gathers. This rules out reliably unprofitable fusions like fp32->fp16 casts inside the kernel. In a future PR we could potentially add an API for being more aggressive if we know we are in a bandwidth-bound regime. See: https://github.com/pytorch/pytorch/pull/134532/files#diff-d2539c9c8dc6a3d7e457767a880612e96d3c85752a77ead49a9e4e00a3e4c3c7R3060-R3066

Other notes:

By default we will upcast to fp32 inside every kernel. This matches eager numerics. That is fine for the epilogue because it is only done once (although it is probably unnecessary for, say, a relu), but it tanks perf for the prologue. I am currently using the `codegen_upcast_to_fp32` option to avoid it, but that will not work for libdevice calls that require fp32. We will need https://github.com/pytorch/pytorch/pull/136778/ and dtype-aware codegen to upcast fp16 ops into libdevice calls.

With prologue fusion, we now have essentially separate kernels for each input and for the output. I had to increase the number of fields that are swapped out in `set_subgraph_body` by a large number :/ I also updated the fusion logic because the inputs will have a different group than the outputs. Maybe this could get cleaned up a bit as part of enabling multiple outputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134532
Approved by: https://github.com/jansel
2024-12-10 16:25:57 +00:00
Max Podkorytov
822e8a01c6 [ROCm][Inductor][CK] Add batched gemms into gemm max autotune with CK backend (#141520)
## Testing
```
TORCH_LOGS=+torch._inductor pytest --capture=no test/inductor/test_ck_backend.py -k bmm
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141520
Approved by: https://github.com/chenyang78
2024-12-05 16:03:12 +00:00
Max Podkorytov
d64827dc35 [ROCm][Inductor][CK] Enable scaled mm with bias in gemm max autotune with CK backend (#140674)
## Testing
```
pytest test/inductor/test_ck_backend.py -k scaled_mm
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140674
Approved by: https://github.com/chenyang78
2024-11-15 22:08:38 +00:00
Max Podkorytov
ca30704f0b [Inductor][ROCm][CK] Add standalone runner (#139441)
Generate a standalone executable to debug and profile CK gemm instances.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139441
Approved by: https://github.com/ColinPeppler
2024-11-07 06:21:27 +00:00
Aaron Orenstein
06f619d999 typing ir.py - part 2 (#131846)
See #131852

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131846
Approved by: https://github.com/eellison
ghstack dependencies: #139238
2024-11-06 00:01:15 +00:00
Max Podkorytov
2dab4ccb65 [Inductor][ROCm][CK] add CK grouped conv2d fwd kernels to ROCm codegen (#137947)
Lowering integration and an end-to-end test will come in a later PR.

Instance parsing companion PR https://github.com/ROCm/composable_kernel/pull/1585

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137947
Approved by: https://github.com/ColinPeppler, https://github.com/chenyang78
2024-10-22 18:25:23 +00:00
Colin Peppler
89067402d4 [easy] in ROCmTemplate set kwargs when creating Buffer (#138521)
Summary: https://github.com/pytorch/pytorch/pull/137768 makes Inductor IR kw only

Test Plan: CI

Differential Revision: D64723804

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138521
Approved by: https://github.com/tenpercent, https://github.com/chenyang78
2024-10-22 03:13:16 +00:00
Jason Ansel
4632594546 [inductor] Move V.graph.scheduler.current_device to V.graph.current_device (#138252)
There are some places where it would be nice to use this, but the scheduler hasn't yet been created.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138252
Approved by: https://github.com/eellison
ghstack dependencies: #138170
2024-10-18 23:05:54 +00:00
Jason Ansel
85a6a782e5 [inductor] Generalize WorkspaceArg for graph-level semaphores (#138170)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138170
Approved by: https://github.com/Chillee
2024-10-18 23:05:54 +00:00
chilli
1cf78bbf62 Refactored debug_extra to be on ChoiceCaller (and called description) (#137857)
Before:
<img width="644" alt="image" src="https://github.com/user-attachments/assets/17b0fa8a-37c8-494b-8914-9d42c3db4bef">

After:
<img width="1292" alt="image" src="https://github.com/user-attachments/assets/5ee59747-a34f-4dd6-b943-cb5a53d52080">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137857
Approved by: https://github.com/ezyang, https://github.com/jansel, https://github.com/masnesral
ghstack dependencies: #137768
2024-10-15 00:48:14 +00:00
Max Podkorytov
52ba40c6f6 [ROCm][AOTI] add CK backend (#135641)
Companion to #134379

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135641
Approved by: https://github.com/ColinPeppler, https://github.com/chenyang78

Co-authored-by: Colin Peppler <colinpeppler@meta.com>
2024-10-07 23:53:58 +00:00
Colin Peppler
d117ec1d6e [3/3][Inductor] Make CK work in FBCode (#136234)
Summary:
# Context
Goal: Enable CK for Inductor in FBCode

We split this stack into three diffs to help with review & in case we need to revert anything.

# This Diff
* Makes CK kernels available as an option for GEMM autotuning in Inductor.

Reviewed By: zjing14

Differential Revision: D62662705

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136234
Approved by: https://github.com/tenpercent, https://github.com/chenyang78
2024-10-02 12:17:38 +00:00
Jez Ng
71aac59e93 Add Triton CPU as an Inductor backend (#133408)
The goal is to use Inductor-generated kernels to stress test the new Triton CPU backend.
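A hedged usage sketch (the config knob name is an assumption about how the Inductor CPU backend is selected; a Triton build with CPU support is assumed):

```py
import torch
import torch._inductor.config as inductor_config

# Assumed knob: route Inductor's CPU codegen through the Triton CPU backend
# instead of the default C++ backend.
inductor_config.cpu_backend = "triton"

fn = torch.compile(lambda x: torch.relu(x) + 1.0)
print(fn(torch.randn(8)))
```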

Differential Revision: [D63298968](https://our.internmc.facebook.com/intern/diff/D63298968)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133408
Approved by: https://github.com/jansel, https://github.com/blaine-rister, https://github.com/malfet
2024-09-30 20:24:52 +00:00
PyTorch MergeBot
36428f91e9 Revert "Add Triton CPU as an Inductor backend (#133408)"
This reverts commit 31c0467594.

Reverted https://github.com/pytorch/pytorch/pull/133408 on behalf of https://github.com/int3 due to internal tests failing ([comment](https://github.com/pytorch/pytorch/pull/133408#issuecomment-2379692517))
2024-09-27 16:54:27 +00:00
Jez Ng
31c0467594 Add Triton CPU as an Inductor backend (#133408)
The goal is to use Inductor-generated kernels to stress test the new Triton CPU backend.

Differential Revision: [D63298968](https://our.internmc.facebook.com/intern/diff/D63298968)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133408
Approved by: https://github.com/jansel, https://github.com/blaine-rister, https://github.com/malfet
2024-09-26 15:35:26 +00:00
Bin Bao
95c0f7493f [Inductor] Rename WrapperCodeGen to PythonWrapperCodegen (#136062)
Summary: Rename WrapperCodeGen to PythonWrapperCodegen to make its meaning more explicit.

Differential Revision: [D63300358](https://our.internmc.facebook.com/intern/diff/D63300358)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136062
Approved by: https://github.com/angelayi, https://github.com/chenyang78
2024-09-24 21:02:51 +00:00
Max Podkorytov
7283530db2 [ROCm][Inductor][CK] FP8 gemm (#136337)
At the moment, this lowers torch._scaled_mm with tensorwise scaling and rowwise scaling for both A and B.

We probably also want to support mixed combinations of tensorwise and rowwise scaling for A and B, as well as bias support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136337
Approved by: https://github.com/chenyang78
2024-09-24 05:19:45 +00:00
PyTorch MergeBot
d0cebedb31 Revert "Add Triton CPU as an Inductor backend (#133408)"
This reverts commit e498b02b47.

Reverted https://github.com/pytorch/pytorch/pull/133408 on behalf of https://github.com/jeanschmidt due to Broke internal signals, see D62737208 for more details ([comment](https://github.com/pytorch/pytorch/pull/133408#issuecomment-2353623816))
2024-09-16 18:33:33 +00:00
Jez Ng
e498b02b47 Add Triton CPU as an Inductor backend (#133408)
The goal is to use Inductor-generated kernels to stress test the new Triton CPU backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133408
Approved by: https://github.com/jansel
2024-09-14 21:45:19 +00:00
xinan.lin
13ee85ca5e [Inductor] Generalize cuda cpp wrapper as common triton based GPU cpp wrapper, will be reused by xpu in next PR. (#135312)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135312
Approved by: https://github.com/jansel, https://github.com/desertfire, https://github.com/eellison
2024-09-11 23:59:54 +00:00
Max Podkorytov
ef0f5919c7 [ROCm][Inductor][CK] Fix codegen after ck signature change (#134483)
The MakeArgument signature was changed in https://github.com/ROCm/composable_kernel/pull/1453, adding a splitK argument to the universal gemm templates which are used to codegen addmm and matmul.

(part of the series started at #125453 )

# Testing
`pytest test/inductor/test_ck_backend.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134483
Approved by: https://github.com/ColinPeppler
2024-08-27 23:25:42 +00:00