Commit Graph

389 Commits

Author SHA1 Message Date
Bin Bao
a597a00c87 [AOTI][refactor][3/n] Declare python_kernel_name and cpp_kernel_name in ExternKernel (#115972)
Summary: Both ExternKernelAlloc and ExternKernelOut need the two fields, so declare them in the base class. Also add cpp codegen for IndexPutFallback and InplaceBernoulliFallback in this PR.
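
A minimal sketch of the shape of the refactor (the class hierarchy is simplified and the defaults are assumed, not the verbatim source):

```python
from typing import Optional

class ExternKernel:  # simplified; the real class lives in torch/_inductor/ir.py
    # both subclasses below need these, so they live on the base class
    python_kernel_name: Optional[str] = None
    cpp_kernel_name: Optional[str] = None

class ExternKernelOut(ExternKernel):
    pass

class ExternKernelAlloc(ExternKernel):
    pass
```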

This is a reland of https://github.com/pytorch/pytorch/pull/115831

Differential Revision: [D52290900](https://our.internmc.facebook.com/intern/diff/D52290900)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115972
Approved by: https://github.com/chenyang78
2023-12-20 03:22:03 +00:00
PyTorch MergeBot
d5115bfb06 Revert "[AOTI][refactor][3/n] Declare python_kernel_name and cpp_kernel_name in ExternKernel (#115831)"
This reverts commit 287a865677.

Reverted https://github.com/pytorch/pytorch/pull/115831 on behalf of https://github.com/desertfire due to rocm CI failure ([comment](https://github.com/pytorch/pytorch/pull/115831#issuecomment-1858322270))
2023-12-15 18:34:55 +00:00
Bin Bao
287a865677 [AOTI][refactor][3/n] Declare python_kernel_name and cpp_kernel_name in ExternKernel (#115831)
Summary: Both ExternKernelAlloc and ExternKernelOut need the two fields, so declare them in the base class. Also add cpp codegen for IndexPutFallback and InplaceBernoulliFallback in this PR.

Differential Revision: [D52189999](https://our.internmc.facebook.com/intern/diff/D52189999)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115831
Approved by: https://github.com/chenyang78
2023-12-15 14:40:44 +00:00
Bin Bao
7d4ccd7b9e [AOTI][refactor][2/n] Rename kernel to python_kernel_name (#115766)
Differential Revision: [D52164940](https://our.internmc.facebook.com/intern/diff/D52164940)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115766
Approved by: https://github.com/chenyang78
ghstack dependencies: #115783
2023-12-15 03:08:13 +00:00
Bin Bao
f90a5f891b [AOTI][refactor][1/n] Rename cpp_kernel to cpp_kernel_name (#115783)
Differential Revision: [D52142184](https://our.internmc.facebook.com/intern/diff/D52142184)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115783
Approved by: https://github.com/chenyang78, https://github.com/jansel
2023-12-15 00:50:17 +00:00
Peter Bell
02196c21ac [inductor] Parameterize ir.Scan on combine_fn (#109132)
This replaces `tl.cumsum` and `tl.cumprod` with calls to `tl.associative_scan`
where the combine function is generated from inductor IR.

So before we had:
```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 20
    rnumel = 30
    RBLOCK: tl.constexpr = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (30*x0)), rmask & xmask, other=0).to(tl.float32)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp2 = tl.where(rmask & xmask, tmp1, 0)
    tmp3 = tl.cumsum(tmp2, 1)
    tl.store(out_ptr0 + (r1 + (30*x0)), tmp3, rmask & xmask)
```

Now we have:
```python
@triton.jit
def _triton_helper_fn0(arg0, arg1):
    tmp0 = arg0 + arg1
    return tmp0

@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 20
    rnumel = 30
    RBLOCK: tl.constexpr = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (30*x0)), rmask & xmask, other=0).to(tl.float32)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp2 = tl.where(rmask & xmask, tmp1, 0)
    tmp3 = tl.associative_scan(tmp2, 1, _triton_helper_fn0)
    tl.store(out_ptr0 + (r1 + (30*x0)), tmp3, rmask & xmask)
```
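
For context, a hedged example of user code that could produce a kernel like the one above (shapes chosen to match `xnumel = 20` / `rnumel = 30`; the fp16 input matches the `.to(tl.float32)` upcast):

```python
import torch

@torch.compile
def f(x):
    # lowers to a single scan kernel; the combine function for cumsum
    # is generated as a helper like _triton_helper_fn0 above
    return torch.cumsum(x, dim=1)

f(torch.randn(20, 30, device="cuda", dtype=torch.float16))
```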

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109132
Approved by: https://github.com/lezcano
2023-12-12 16:30:50 +00:00
Scott Wolchok
ff6f987adc [PyTorch] Replace cached thread_locals with stack allocation in AOTI (#112116)
This changes cached thread_local tensors to stack-allocated buffers. Since we were incidentally caching outputs in a thread_local, I had to add manual thread_local caching of outputs: I cache a buffer and a Tensor whose storage is that buffer, and memcpy the result into the cached buffer every time. Ideally, memory planning would be able to identify allocations that are the backing storage for outputs, but this should be good enough in the absence of planning.

Differential Revision: [D50416438](https://our.internmc.facebook.com/intern/diff/D50416438/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112116
Approved by: https://github.com/jansel, https://github.com/desertfire
2023-12-12 06:19:45 +00:00
Bin Bao
2e6b809d6b [AOTI] Fix a missing declaration for the result of item() (#115175)
Differential Revision: [D51968539](https://our.internmc.facebook.com/intern/diff/D51968539)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115175
Approved by: https://github.com/chenyang78
2023-12-10 22:49:45 +00:00
leslie-fang-intel
f6291a5e93 [Quant] [Inductor] Enable QLinear weight prepack when input dimension size exceeds 2 (#113928)
**Summary**
Enable the qlinear weight prepack when the input dimension size exceeds 2. In that case there are extra reshape nodes before and after the `addmm` or `mm` node.
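
A hedged illustration of the decomposed pattern (shapes and names are illustrative, not the test's code):

```python
import torch

# with a 3-D input, linear decomposes to reshape -> addmm -> reshape,
# which the weight-prepack pattern matcher now has to account for
x = torch.randn(4, 8, 16)
w = torch.randn(32, 16)
b = torch.randn(32)
out = torch.addmm(b, x.reshape(-1, 16), w.t()).reshape(4, 8, 32)
```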

**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k input_dim_exceeds_2
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113928
Approved by: https://github.com/jgong5, https://github.com/eellison
ghstack dependencies: #113733, #113912
2023-12-06 01:24:15 +00:00
leslie-fang-intel
4a624d1f8a [Quant] [PT2] Enable QLinear input with multi dims (#113733)
**Summary**
In the previous QLinear implementation, inputs were assumed to be 2-dimensional. In this update, we have modified QLinear to accept inputs with more than 2 dimensions, reshaping the input and output accordingly.

**Test Plan**
```
python -u -m pytest -s -v test_quantized_op.py -k test_qlinear_pt2e
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113733
Approved by: https://github.com/jgong5, https://github.com/eellison
2023-12-06 01:16:51 +00:00
Peter Bell
7aac689b19 [inductor] Add ir.Scan and lower aten.cumsum on CUDA (#106581)
This adds the `ir.Scan` node (currently only supported on CUDA) which re-uses the existing reduction kernel machinery to support different kinds of non-pointwise ops. Just like reductions it supports prologue and epilogue fusions and has both persistent and non-persistent kernel generation.

Currently this doesn't support the equivalent of `Reduction.create_multilayer` and will instead fall back to eager in those cases. This is because splitting into multiple kernel invocations ends up being far slower than cub's single-kernel strategy, which matches the performance of a copy kernel.

Fixes https://github.com/pytorch/pytorch/issues/93631

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106581
Approved by: https://github.com/lezcano, https://github.com/atalman
2023-12-05 23:31:49 +00:00
Bin Bao
e06bff8bbe [AOTI] Handle empty input args (#114682)
Summary: When the model takes no inputs, AOTInductor relies on checking weights to figure out which device to compile the model for. Currently, recording the buffer device type happens too late; this PR fixes that.
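
A hedged repro sketch of the situation described (the module is illustrative):

```python
import torch

class NoInputModel(torch.nn.Module):  # hypothetical example
    def __init__(self):
        super().__init__()
        # with no runtime inputs, this buffer's device is the only
        # signal telling AOTInductor which device to compile for
        self.register_buffer("w", torch.randn(3, device="cuda"))

    def forward(self):
        return self.w * 2
```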

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114682
Approved by: https://github.com/chenyang78
2023-12-05 15:02:17 +00:00
Will Feng
1d0e70ad65 Add get_mutation_names to ir.Wait (#115104)
`ir.Wait` generates the last 2 lines of this code:
```python
buf1_work = dist.all_gather_into_tensor(buf1[0], buf1_inputs[0], async_op=True, group=buf1_pg)
fun_col_impl._register_tensor_work(buf1, buf1_work)
buf2 = buf1[0]
del buf1

buf2 = _wait_tensor(buf2)  #  <- generated by ir.Wait
buf3 = buf2;  # reuse  <- generated by ir.Wait
```
`_wait_tensor` technically is a "mutation" op that changes `buf2` in place. So we should mark `ir.Wait` as a mutation op (by overriding its `get_mutation_names()`).

This fixes a very peculiar issue when inductor comm reordering is used for the llama model: downstream nodes that use the all-gather comm output sometimes take a dependency on `buf2` (the node before `ir.Wait`) instead of on `buf3` (`ir.Wait`); it's still unclear why it behaves like this. To work around the issue, we add the missing annotation that `buf3` is a mutation of `buf2`, so that the scheduler knows to schedule `buf3` before any of `buf2`'s users.
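
A minimal standalone sketch of the override (simplified, not the verbatim source; `ExternKernelAlloc` here stands in for the real inductor base class):

```python
class ExternKernelAlloc:  # stand-in for the real base class in ir.py
    def __init__(self, inputs):
        self.inputs = inputs

class Wait(ExternKernelAlloc):
    def get_mutation_names(self):
        # declare this node's output (buf3) a mutation of its input
        # (buf2), so the scheduler orders it before buf2's other users
        return [self.inputs[0].get_name()]
```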

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115104
Approved by: https://github.com/wanchaol
2023-12-05 03:54:33 +00:00
Yang Chen
4d8b9964e1 [aotinductor] support at::convolution for AOTInductor (#114961)
This PR adds support for at::convolution to AOTInductor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114961
Approved by: https://github.com/desertfire
2023-12-03 07:52:28 +00:00
Jez Ng
f1fd02503b Reland #113487 and #112527 (sdpa shim & fp8 AOTInductor support) (#114974)
This is a backout of #113747 which reverted the above two commits. Now that
#113997 has landed, this diff can be landed safely without breaking ABI compatibility.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114974
Approved by: https://github.com/chenyang78
2023-12-02 03:25:51 +00:00
colinpeppler
5262484ece [easy][aotinductor] fix typos & add static typing (#114728)
```
// check all references
$ grep -rl 'cpp_kernel_overlad_name' *
ir.py
```

```
$ lintrunner --take MYPYINDUCTOR torch/_inductor/codegen/wrapper.py torch/_inductor/ir.py
ok No lint issues.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114728
Approved by: https://github.com/Skylion007, https://github.com/chenyang78
2023-11-30 02:10:56 +00:00
chundian
74e10f0f60 [inductor] Fix torch.split bug on unbacked symint (#113406)
torch.split(x, l) fails when the sizes in l are unbacked symints.

E.g. `l = y.tolist()` makes `l` unbacked, because its values depend on
the data of `y`. The downstream call `SliceView.create()` evaluates the
shape even when the input shape is an unbacked symint, which triggers
the bug.
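
A hedged repro sketch based on the description above:

```python
import torch

torch._dynamo.config.capture_scalar_outputs = True  # allow .tolist() capture

@torch.compile(fullgraph=True)
def f(x, y):
    sizes = y.tolist()  # unbacked symints: values depend on y's data
    return torch.split(x, sizes)

f(torch.randn(20), torch.tensor([3, 7, 10]))
```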

Test Plan:
python test/inductor/test_unbacked_symints.py -k test_split_with_sizes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113406
Approved by: https://github.com/aakhundov, https://github.com/ezyang
2023-11-28 20:45:13 +00:00
Michael Lazos
4c794f2ef1 Reinplace foreach when safe and allow aliasing during lowering (#112440)
This reduces compile time of Adam on 1k parameters from 180s to 140s (a ~22% reduction), the main reason being that thousands of buffers no longer get sent to the scheduler.

The idea behind this is that if a destination buffer (from a copy_) has no users, it shouldn't matter if dst aliases src.

This is implemented by reinplacing copy_ nodes when safe.
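
A hedged illustration of the pattern (an optimizer-style foreach update ending in `copy_`; names are illustrative):

```python
import torch

params = [torch.randn(10) for _ in range(3)]
steps = [torch.randn(10) for _ in range(3)]

def update():
    new_vals = torch._foreach_add(params, steps)  # out-of-place foreach op
    for p, v in zip(params, new_vals):
        # the copy_ destination has no other users, so the foreach op
        # can be reinplaced even though dst may alias src
        p.copy_(v)
```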

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112440
Approved by: https://github.com/jansel
2023-11-27 21:32:42 +00:00
PyTorch MergeBot
ccb1de3595 Revert "[inductor] Fix torch.split bug on unbacked symint (#113406)"
This reverts commit cd7d6938c1.

Reverted https://github.com/pytorch/pytorch/pull/113406 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/113406#issuecomment-1827727411))
2023-11-27 12:20:52 +00:00
Oguz Ulgen
c6d88604d5 [Inductor] Fix mutation tracking of ConvolutionBinaryInplace (#114501)
The init function reorders the arguments, so the mutation actually happens on
argument input[0].

I am not sure if there's a good way to test this, unfortunately. Tests
were added in https://github.com/pytorch/pytorch/pull/114436

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114501
Approved by: https://github.com/leslie-fang-intel, https://github.com/aakhundov
2023-11-24 19:32:41 +00:00
Adnan Akhundov
0a063ad2c0 [inductor] Pass None and skip constexpr in custom Triton kernel calls from C++ (#114475)
Summary: `None` arguments are codegened as `*i8` in the `triton_meta` of the generated or user-defined Triton kernels:

85aa372374/torch/_inductor/codegen/triton_utils.py (L33-L36)

Due to this, contrary to conventional Triton, we actually should pass `nullptr` to the Triton kernels in the C++ wrapper codegen instead of passing nothing (as normally `None` doesn't make it into the generated PTX parameters, just like `tl.constexpr` args).

This PR adds two things:

1. Proper C++ wrapper codegen (ABI and non-ABI) of `nullptr` and `c10::nullopt`, as the prior approach of codegening `c10::nullopt` as a tensor breaks (and `c10` itself breaks in ABI mode).

2. Skipping `tl.constexpr` args when calling the loaded-from-cubin compiled Triton kernel in the C++ wrapper codegen. As a side effect, this also resolves an issue with string arguments: now they are simply omitted in the C++ wrapper codegen.
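
A hedged sketch of the kind of user-defined kernel involved (illustrative, not the test's kernel):

```python
import triton
import triton.language as tl

@triton.jit
def copy_kernel(in_ptr, maybe_ptr, out_ptr, numel, BLOCK: tl.constexpr):
    # maybe_ptr may be None on the Python side; per the above, the C++
    # wrapper must then forward nullptr rather than drop the argument
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < numel
    tl.store(out_ptr + offs, tl.load(in_ptr + offs, mask=mask), mask=mask)
```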

Test Plan:

```
$ python test/inductor/test_aot_inductor.py -k test_triton_kernel_with_none_input
...
----------------------------------------------------------------------
Ran 4 tests in 40.364s

OK (skipped=2)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114475
Approved by: https://github.com/oulgen
2023-11-24 12:51:56 +00:00
chundian
cd7d6938c1 [inductor] Fix torch.split bug on unbacked symint (#113406)
torch.split(x, l) fails when the sizes in l are unbacked symints.

E.g. `l = y.tolist()` makes `l` unbacked, because its values depend on
the data of `y`. The downstream call `SliceView.create()` evaluates the
shape even when the input shape is an unbacked symint, which triggers
the bug.

Test Plan:
python test/inductor/test_unbacked_symints.py -k test_split_with_sizes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113406
Approved by: https://github.com/aakhundov, https://github.com/ezyang
2023-11-24 07:21:00 +00:00
Oguz Ulgen
51390722e9 Fix ConvolutionBinaryInplace using target node (#114436)
This IR node mutates in place; it needs to use the argument, not the
target.

Fixes #113440

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114436
Approved by: https://github.com/jansel
ghstack dependencies: #114169
2023-11-24 06:25:11 +00:00
Edward Z. Yang
7ea184d7e3 Handle item() on boolean tensor (#114157)
This needs some special handling because we don't actually allocate
boolean symbols in sympy; we allocate an integer indicator variable.
See comment for more details.
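
A hedged sketch of the kind of program this affects (the config flag and shapes are illustrative):

```python
import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(pred, x):
    b = pred.item()  # bool: tracked via a 0/1 integer indicator symbol
    return x * b

f(torch.tensor(True), torch.randn(4))
```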

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114157
Approved by: https://github.com/ydwu4
2023-11-21 04:34:58 +00:00
Jez Ng
87925789ae Make V.graph properly typed (#114025)
Previously it lacked a type hint and so was treated as an Any type. This
resulted in a lot of untyped code downstream as V.graph is referenced in
many places in inductor code. I've typed it properly now as
GraphLowering, and fixed the numerous type errors this surfaced.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114025
Approved by: https://github.com/eellison
ghstack dependencies: #114013
2023-11-21 02:14:29 +00:00
Jez Ng
4667e20b3f Delete a bunch of type-ignores (#113990)
* Replaced `ignore[import]` by mypy config file entries
* Removed a bunch of ignores around previously-fixed attr-defined /
  call-arg issues
* Fixed some invalid / undefined types; added a few more type-ignores to
  squelch the downstream errors this exposed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113990
Approved by: https://github.com/eellison, https://github.com/Skylion007
ghstack dependencies: #113979
2023-11-18 02:48:38 +00:00
Sherlock Huang
8372983fe3 [AOTInductor] Use ProxyExecutor for aten op if c-shim is missing (#113918)
Summary:
As discussed in the meeting, we are inverting the policy on the use of proxy executor for aten fallbacks.
By default, aten fallback ops will use proxy executor, unless a c-shim is available.

Test Plan: CIs

Differential Revision: D51417683

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113918
Approved by: https://github.com/chenyang78
2023-11-18 00:04:21 +00:00
Jez Ng
4b1583fe57 type-ignore issues exposed by import following (#113979)
Some new errors were introduced in a land-race with
https://github.com/pytorch/pytorch/pull/113830. Silence them for now to
get the lintrunner job green again.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113979
Approved by: https://github.com/huydhn
2023-11-17 21:20:09 +00:00
eellison
a9134fa99a Skip cudagraphs when there is sparsity (#113791)
Fix for dlrm training

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113791
Approved by: https://github.com/Chillee
2023-11-17 01:36:03 +00:00
Bin Bao
1480c670a0 [AOTI] Delay the fallback kernel naming decision to the codegen time (#113660)
Summary: This is to prepare for a later change that changes AOTI's second-pass to perform codegen only.

Differential Revision: [D51382677](https://our.internmc.facebook.com/intern/diff/D51382677)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113660
Approved by: https://github.com/chenyang78
2023-11-16 23:07:30 +00:00
Jez Ng
df9acc61fb [inductor] Make {freezing,ir}.py pass follow-imports typechecking (#113534)
I used a couple of type-ignore comments in ir.py because it constructs
short-lived instances of FixedLayout and GraphModuleSerializer, just to
call a single method on them that doesn't use all their members. Making
those unused members optional would make the rest of the code a lot
messier with sprinkled `assert` statements.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113534
Approved by: https://github.com/albanD
2023-11-16 01:53:52 +00:00
Bin Bao
c99d88afa4 [AOTI] Remove try_find_schema (#113617)
Differential Revision: [D51350727](https://our.internmc.facebook.com/intern/diff/D51350727)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113617
Approved by: https://github.com/aakhundov, https://github.com/chenyang78, https://github.com/khabinov
2023-11-15 22:42:47 +00:00
Wei Wei
b19cf868e8 Back out "Support fp8 in AOTInductor + support optional<> in C ABI (#112527)" (#113747)
Test Plan: sandcastle

Differential Revision: D51330618

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113747
Approved by: https://github.com/chenyang78, https://github.com/khabinov
2023-11-15 22:42:22 +00:00
Yang Chen
a144eb502a [aotinductor] add versions for the sdpa shim api (#113487)
In our first implementation of the sdpa shim API, we didn't consider
the case where the optional scale argument could be None. It went
unnoticed because we always got a default argument for the CUDA backend.
The issue was detected with the CPU backend.

This PR implements versioning for shim kernels. Currently, only the
sdpa API has multiple versions. We expect to maintain only a very small
number of ABI-compatible shim APIs with multiple versions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113487
Approved by: https://github.com/int3, https://github.com/desertfire
2023-11-13 20:18:58 +00:00
Oguz Ulgen
06dc2f162d [AOTI] Implement support for user defined kernels that use triton.autotune (#113229)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113229
Approved by: https://github.com/chenyang78
2023-11-10 22:40:51 +00:00
PyTorch MergeBot
2cd8c0565c Revert "[AOTI] Implement support for user defined kernels that use triton.autotune (#113229)"
This reverts commit 1488bafb27.

Reverted https://github.com/pytorch/pytorch/pull/113229 on behalf of https://github.com/PaliC due to breaking test_aot_inductor.py tests though a forward fix is coming ([comment](https://github.com/pytorch/pytorch/pull/113229#issuecomment-1806159396))
2023-11-10 17:46:14 +00:00
Oguz Ulgen
1488bafb27 [AOTI] Implement support for user defined kernels that use triton.autotune (#113229)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113229
Approved by: https://github.com/chenyang78
2023-11-10 01:39:00 +00:00
Lucas Pasqualin
1d56e7b5af Adds broadcast to functional collectives (#112668)
Adds `broadcast` to functional collectives, including inductor support.
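
A hedged usage sketch (the exact signature here is an assumption):

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# functional collectives return a new tensor instead of mutating in
# place, which is what makes them traceable by inductor
def sync_from_rank0(t: torch.Tensor) -> torch.Tensor:
    return funcol.broadcast(t, src=0, group=dist.group.WORLD)
```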

Test with `python test_inductor_collectives.py -- TestCollectivesMultiProc.test_broadcast_inductor`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112668
Approved by: https://github.com/wanchaol, https://github.com/wconstab
2023-11-09 15:47:52 +00:00
Yifu Wang
625958d8bc Inductor support for native c10d_functional (#112439)
This PR adds Inductor support for [native c10d_functional ops](https://github.com/pytorch/pytorch/pull/110570).

The Inductor IRs introduced in this PR will replace the existing `CollectiveKernel` IR hierarchy. Compared to the existing collective IRs, the new IRs:
- Are target language agnostic and support AOTInductor.
- Express the constraints solely with read/write deps. This maximizes the potential for buffer reuse.
- Address an issue where an out-of-place collective's input buffers could be mutated while being volatilely read.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112439
Approved by: https://github.com/Chillee
2023-11-08 23:40:21 +00:00
Jez Ng
297c26bb8e Support fp8 in AOTInductor + support optional<> in C ABI (#112527)
This was originally ipiszy's PR: https://github.com/pytorch/pytorch/pull/112358

It turns out that we need to add support for optional types in order to
support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface
can't support optional<> directly, I am passing in optional types via
pointer instead.

`AtenTensorHandle`s are already pointers, so nothing needs to change
there. Only value types need to change.

We decided on this approach instead of adding an extra `bool` param to
the callee because this simplifies things. Having the same number of
arguments regardless of whether we are emitting Python / C++ /
ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have
optional-typed value parameters. Previously, they just assumed they
would never be passed a `nullopt` / `None` at runtime. Changing them to
use pointer types now would break ABI stability, so I have created an
exclude list for those functions.

Finally, I think the current implementation is kind of messy, and only
works for FallbackKernels, even though technically ExternKernels could
also have the same issue. It also doesn't support optional types nested
in lists. I've left FIXME comments for both issues.

Differential Revision: [D51084289](https://our.internmc.facebook.com/intern/diff/D51084289)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112527
Approved by: https://github.com/chenyang78, https://github.com/desertfire
2023-11-08 22:56:48 +00:00
Sherlock Huang
728ed37663 [AOTInductor] Allow using ProxyExecutor for ATen fallbacks (#112976)
Summary: Use ProxyExecutor for aten._scaled_dot_product_efficient_attention in ABI-mode

Test Plan: OSS CI

Differential Revision: D51005807

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112976
Approved by: https://github.com/chenyang78, https://github.com/jansel
2023-11-08 08:34:11 +00:00
Oguz Ulgen
8ba11bf79d [AOTI] Support non auto-tuned triton kernels in aoti (#113090)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113090
Approved by: https://github.com/aakhundov, https://github.com/chenyang78, https://github.com/desertfire
2023-11-08 07:48:15 +00:00
Oguz Ulgen
611a7457ca [Inductor] Kill MutationLayout from ir.py (#112925)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112925
Approved by: https://github.com/jansel
2023-11-07 17:03:52 +00:00
Oguz Ulgen
bfa717c6a6 [Inductor] Improve reinplace_scatters pass (#112801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112801
Approved by: https://github.com/Chillee, https://github.com/jansel
ghstack dependencies: #112752, #113008
2023-11-07 05:29:42 +00:00
Oguz Ulgen
dbf44dffc9 [Inductor] Cache generated user defined triton kernels on tensor dtype and non tensor parameters (#112752)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112752
Approved by: https://github.com/jansel
2023-11-07 05:29:16 +00:00
Kai Londenberg
bdfde62e54 [Inductor CUTLASS backend] Epilogue fusion codegen (Step 1) (#110890)
Summary:

This PR adds epilogue fusion code generation support for the new experimental
[Inductor Cutlass backend](https://github.com/pytorch/pytorch/pull/108015).

Details:

A fusion happens at the GEMM template level by taking a Cutlass 3.x GEMM Universal Matmul kernel template
and adding on top a custom template functor based on Cutlass' new “Epilogue Visitor Trees” (EVT), which
represents and performs the computation of the fused pointwise / elementwise computation nodes.

This is the approach dictated by [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu),
which is currently the only documentation and example of Cutlass Epilogue Visitor Trees.

This EVT functor in turn is a hierarchical template expression which represents an abstract syntax tree of the fused computation to perform.
A second codegen task is to create a hierarchical initializer expression, which provides potentially necessary arguments
to each of the functor subexpressions.

Step 1 functionality:

 * End to end code generation is possible using the above approach.
 * Supports simple elementwise expression fusion of chains of elementwise operations (with scalar constants)
   after a matmul.
 * Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum, etc.
 * Examples / unit tests include ReLU and ReLU6 fusion (see the sketch after this list).
 * Support for fp16 and fp16-with-fp32-accumulation data types.
 * Generates SM90 (Hopper) based CUDA kernels (as Cutlass up to 3.2.0 only supported EVT for SM90)
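
For orientation, a hedged sketch of the kind of fused pattern Step 1 targets (the function and shapes are illustrative, and the backend remains disabled by default):

```python
import torch

@torch.compile
def gemm_relu6(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a matmul followed by a chain of elementwise ops with scalar
    # constants; the clamp (ReLU6) epilogue is the EVT-fusion candidate
    return torch.clamp(a @ b, min=0.0, max=6.0)
```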

The following is not yet supported, and is left for future work:

 * Full operation support (e.g. the full set of ops usually handled via V.ops handlers)
 * Cutlass EVT with SM80 support (possible in Cutlass 3.2.1 according to the release notes, but not yet documented)
 * Add support for additional (auxiliary) inputs, which changes the Template Kernels' call signature
 * Add support for additional (auxiliary) outputs (requires support for full computation graphs)
 * Add support for reduction operations and operations which use different output layouts than the input
 * Add support for additional dtypes (as far as Cutlass allows)

This PR updates third_party/cutlass to v3.2.2, which has some important improvements and features
for the inductor backend.

See also Cutlass release notes:
https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1 and https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2

Notable changes in Cutlass 3.2.1 include:
 * Cutlass codegen python code has moved into a package with the "cutlass_library" namespace, which prevents
   namespace clashes without resorting to monkey-patching (as was done earlier).
 * Support for SM80 epilogue visitor trees ( according to the Release Notes, not tried yet )
 * Small API changes to the cutlass_library API ( requires adapting the inductor backend code )

Notable changes in Cutlass 3.2.2 include:
 * Bugfix that led to CUDA Illegal memory access in some Pytorch unit tests involving flash attention

Test Plan:
 * CI
 * pytest test/inductor/test_max_autotune.py

Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled.

Differential Revision: [D50988161](https://our.internmc.facebook.com/intern/diff/D50988161)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110890
Approved by: https://github.com/jansel
ghstack dependencies: #112762
2023-11-06 19:42:10 +00:00
Oguz Ulgen
67e8762e83 [Inductor] Kill has_aliasing (#112875)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112875
Approved by: https://github.com/Chillee
2023-11-03 23:22:22 +00:00
Oguz Ulgen
001573b687 [Inductor] Support one node creating multiple mutations in scheduler (#112547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112547
Approved by: https://github.com/Chillee
2023-11-03 16:01:31 +00:00
leslie-fang-intel
a53d29cc18 Enable oneDNN QLinear FP32/BF16 output (#112126)
**Summary**
- PR 2 for enabling Int8-Mixed-BF16 PT2E PTQ Quantization with Inductor https://github.com/pytorch/pytorch/issues/111640.
- Enable QLinear (relu) with BFloat16 or Float32 output.

**Test Plan**
```
python -u -m pytest -s -v test_quantized_op.py -k test_qlinear_pt2e
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112126
Approved by: https://github.com/jerryzh168, https://github.com/jgong5
ghstack dependencies: #112010
2023-11-03 08:20:54 +00:00
leslie-fang-intel
b6fc7af8a0 Enable oneDNN QConv FP32/BF16 output (#112010)
**Summary**

- PR 1 for enabling Int8-Mixed-BF16 PT2E PTQ Quantization with Inductor https://github.com/pytorch/pytorch/issues/111640.
- Enable QConv (relu, add, add_relu) with BFloat16 or Float32 output.

**Test Plan**
```
python -u -m pytest -s -v test_quantized_op.py -k test_qconv1d_pt2e
python -u -m pytest -s -v test_quantized_op.py -k test_qconv2d_pt2e
python -u -m pytest -s -v test_quantized_op.py -k test_qconv3d_pt2e
python -u -m pytest test_quantized_op.py -k test_qconv2d_relu_pt2e
python -u -m pytest test_quantized_op.py -k test_qconv2d_add_pt2e
python -u -m pytest test_quantized_op.py -k test_qconv2d_add_relu_pt2e
python -u -m pytest test_quantized_op.py -k test_qconv2d_add_relu_float_output_pt2e
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112010
Approved by: https://github.com/jerryzh168, https://github.com/jgong5
2023-11-03 08:16:45 +00:00