Commit Graph

326 Commits

Author SHA1 Message Date
eellison
b731ced91f Prologue Fusion (#134532)
This PR extends our existing ability to fuse pointwise nodes onto triton templates (epilogue fusion) with the ability to fuse pointwise nodes into triton templates - prologue fusion.

Similar to the store_output api:
`{{store_output(("idx_m", "idx_n"), "acc", "mask")}}`

And the modification api:

```
{{ modification(
    subgraph_number=0,
    output_name="post_mod_scores",
    score="qk",
    out="qk"
) | indent_except_first(1) }}
```

We have:

```{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}```

Because we are now loading the input with explicit indices and a mask, I needed to rewrite the mm kernel to no longer update the [pointers by BLOCK_K](bb03ef7aca/torch/_inductor/kernel/mm.py (L110-L111)) on every iteration and instead compute indices from the k_idx of each loop iteration. This did not make any perf difference.
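A rough sketch of that indexing change, assuming a heavily simplified template (this is not the actual `torch/_inductor/kernel/mm.py` code; argument names and strides are illustrative):

```py
import triton
import triton.language as tl

@triton.jit
def mm_loop_sketch(A, B, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    rm = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_idx in range(0, tl.cdiv(K, BLOCK_K)):
        # indices are recomputed from k_idx each iteration, rather than
        # advancing the A/B pointers by BLOCK_K, so a fused prologue can load
        # with explicit indices and a mask
        rk = k_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        a_mask = (rm[:, None] < M) & (rk[None, :] < K)
        b_mask = (rk[:, None] < K) & (rn[None, :] < N)
        a = tl.load(A + rm[:, None] * stride_am + rk[None, :] * stride_ak, mask=a_mask, other=0.0)
        b = tl.load(B + rk[:, None] * stride_bk + rn[None, :] * stride_bn, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
    # store of acc omitted
```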

There are a couple of main use cases for prologue fusion:

- Fusing dequants into a matmul, particularly for more bandwidth-bound scenarios.
- Fusing a gather into a matmul. This is particularly useful in MoE. See https://github.com/pytorch/pytorch/issues/134535 for more details.

Prologue fusion is generally much less profitable than epilogue fusion, because it must be applied to an element of an input on each loop iteration of the matmul, compared to only once in the epilogue (gather into matmul is a potential exception). Accordingly, we are much less aggressive in attempting prologue fusion. We only attempt fusion if it does not increase the number of memory bytes read inside the triton template, multiplied by a small factor to allow gathers. This rules out reliably unprofitable fusions like an fp32->fp16 cast inside the kernel. In a future PR we could potentially add an API for being more aggressive when we know we are in a bandwidth-bound regime. See: https://github.com/pytorch/pytorch/pull/134532/files#diff-d2539c9c8dc6a3d7e457767a880612e96d3c85752a77ead49a9e4e00a3e4c3c7R3060-R3066
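As a sketch of that heuristic (the real check lives in the scheduler's fusion logic; the function name and the exact gather factor below are assumptions):

```py
def prologue_fusion_allowed(bytes_read_with_fusion: int,
                            bytes_read_without_fusion: int,
                            gather_factor: float = 1.1) -> bool:
    # Only fuse the prologue if it does not increase the bytes read inside the
    # triton template, up to a small factor that still permits gathers.
    # Reliably unprofitable fusions (e.g. an fp32 input downcast to fp16 in the
    # kernel) read more bytes and are rejected.
    return bytes_read_with_fusion <= bytes_read_without_fusion * gather_factor
```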

Other notes:

By default we will upcast to fp32 inside every kernel. This matches eager numerics. That is fine for the epilogue because it is only done once (although it is probably unnecessary for, say, a relu) but it tanks perf for the prologue. I am currently using the `codegen_upcast_to_fp32` option to avoid it, but that will not work for libdevice calls that require fp32. We will need https://github.com/pytorch/pytorch/pull/136778/ and dtype-aware codegen to upcast fp16 ops into libdevice calls.
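For reference, toggling that option looks roughly like this (the exact config namespace is an assumption; check torch/_inductor/config.py):

```py
import torch._inductor.config as inductor_config

# assumed location of the flag mentioned above; disabling it keeps fp16/bf16
# prologue math in low precision instead of upcasting to fp32
inductor_config.triton.codegen_upcast_to_fp32 = False
```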

With prologue fusion, we now have essentially separate kernels for each input and for the output. I had to increase the number of fields that are swapped out in `set_subgraph_body` by a large number :/ I also updated the fusion logic because the inputs will have a different group than the outputs. Maybe this could get cleaned up a bit as part of enabling multiple outputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134532
Approved by: https://github.com/jansel
2024-12-13 04:18:25 +00:00
eellison
0b75b7ff2b [Easy] factor out inductor ophandler decompositions (#142400)
Factor out inductor operator decompositions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142400
Approved by: https://github.com/Chillee, https://github.com/jansel
2024-12-12 19:03:26 +00:00
PyTorch MergeBot
233853a66f Revert "Prologue Fusion (#134532)"
This reverts commit 59ab3825e7.

Reverted https://github.com/pytorch/pytorch/pull/134532 on behalf of https://github.com/clee2000 due to A couple of PRs in this stack are breaking internally on different tests ([comment](https://github.com/pytorch/pytorch/pull/134532#issuecomment-2536643675))
2024-12-11 17:32:26 +00:00
PyTorch MergeBot
829a93562a Revert "[Easy] factor out inductor ophandler decompositions (#142400)"
This reverts commit fa746e3eeb.

Reverted https://github.com/pytorch/pytorch/pull/142400 on behalf of https://github.com/clee2000 due to A couple of PRs in this stack are breaking internally on different tests ([comment](https://github.com/pytorch/pytorch/pull/134532#issuecomment-2536643675))
2024-12-11 17:32:26 +00:00
eellison
fa746e3eeb [Easy] factor out inductor ophandler decompositions (#142400)
Factor out inductor operator decompositions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142400
Approved by: https://github.com/Chillee, https://github.com/jansel
ghstack dependencies: #134532, #142350
2024-12-10 16:58:36 +00:00
eellison
59ab3825e7 Prologue Fusion (#134532)
This PR extends our existing ability to fuse pointwise nodes onto triton templates (epilogue fusion) with the ability to fuse pointwise nodes into triton templates - prologue fusion.

Similar to the store_output api:
`{{store_output(("idx_m", "idx_n"), "acc", "mask")}}`

And the modification api:

```
{{ modification(
    subgraph_number=0,
    output_name="post_mod_scores",
    score="qk",
    out="qk"
) | indent_except_first(1) }}
```

We have:

```{{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}}```

Because we are now loading the input with explicit indices and a mask, I needed to rewrite the mm kernel to no longer update the [pointers by BLOCK_K](bb03ef7aca/torch/_inductor/kernel/mm.py (L110-L111)) on every iteration and instead compute indices from the k_idx of each loop iteration. This did not make any perf difference.

There are a couple of main use cases for prologue fusion:

- Fusing dequants into a matmul, particularly for more bandwidth-bound scenarios.
- Fusing a gather into a matmul. This is particularly useful in MoE. See https://github.com/pytorch/pytorch/issues/134535 for more details.

Prologue fusion is generally much less profitable than epilogue fusion, because it must be applied to an element of an input on each loop iteration of the matmul, compared to only once in the epilogue (gather into matmul is a potential exception). Accordingly, we are much less aggressive in attempting prologue fusion. We only attempt fusion if it does not increase the number of memory bytes read inside the triton template, multiplied by a small factor to allow gathers. This rules out reliably unprofitable fusions like an fp32->fp16 cast inside the kernel. In a future PR we could potentially add an API for being more aggressive when we know we are in a bandwidth-bound regime. See: https://github.com/pytorch/pytorch/pull/134532/files#diff-d2539c9c8dc6a3d7e457767a880612e96d3c85752a77ead49a9e4e00a3e4c3c7R3060-R3066

Other notes:

By default we will upcast to fp32 inside every kernel. This matches eager numerics. That is fine for the epilogue because it is only done once (although it is probably unnecessary for, say, a relu) but it tanks perf for the prologue. I am currently using the `codegen_upcast_to_fp32` option to avoid it, but that will not work for libdevice calls that require fp32. We will need https://github.com/pytorch/pytorch/pull/136778/ and dtype-aware codegen to upcast fp16 ops into libdevice calls.

With prologue fusion, we now have essentially separate kernels for each input and for the output. I had to increase the number of fields that are swapped out in `set_subgraph_body` by a large number :/ I also updated the fusion logic because the inputs will have a different group than the outputs. Maybe this could get cleaned up a bit as part of enabling multiple outputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134532
Approved by: https://github.com/jansel
2024-12-10 16:25:57 +00:00
eellison
fd35be2fd3 TritonTemplate dtype fixes (#141991)
- Set the dtype of "acc" appropriately so that epilogue fusion will have args with a dtype
- Update dtype propagation to use `type_to_dtype` instead of instantiating a tensor
- Throw if we have a string arg where we should have a proper CSEVariable, unless we're doing the Modification Subgraph thing, which is NYI; everything else is appropriately typed (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @drisspg ).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141991
Approved by: https://github.com/drisspg
ghstack dependencies: #139945, #140057, #141495, #141882
2024-12-04 17:24:23 +00:00
Bin Bao
a51a048027 [AOTI][refactor] Move stack allocation related configs (#139093)
Summary: Move allow_stack_allocation and use_minimal_arrayref_interface configs into the aot_inductor subclass.
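Based on that summary, the new access path would look roughly like this (a sketch; verify the exact attribute names in torch/_inductor/config.py):

```py
import torch._inductor.config as inductor_config

# after the move, the knobs live under the aot_inductor sub-config
inductor_config.aot_inductor.allow_stack_allocation = True
inductor_config.aot_inductor.use_minimal_arrayref_interface = False
```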

Test Plan: CI

Differential Revision: D65064301

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139093
Approved by: https://github.com/chenyang78
2024-12-04 00:15:19 +00:00
eellison
f83361b274 inductor dtype propagation fixes (#141495)
- Add in upcast_compute_type on creation of new tensors (loads, constants)
- Fixes index_expr - right now we are sort of inconsistent in dtype and don't always respect the dtype specified; it would be nice to fix, but not doing it in this PR.
- Bug fix in view dtype where we were always upcasting back to fp32 when the input was in bf16/fp16; we should only be doing that if the output is also in bf16/fp16.
- for masked, avoid calling dtype propagation and just use output dtype.

Turns on the runtime dtype verification for opinfo tests. The separate test file is still useful because we can use it to test turning off codegen_upcast_to_fp32.
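A minimal sketch of the idea behind `upcast_compute_type` (not the actual inductor helper; behavior inferred from the description above):

```py
import torch

def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:
    # fp16/bf16 values are computed in fp32 when codegen_upcast_to_fp32 is on;
    # other dtypes pass through unchanged
    if dtype in (torch.float16, torch.bfloat16):
        return torch.float32
    return dtype

assert upcast_compute_type(torch.bfloat16) == torch.float32
assert upcast_compute_type(torch.int32) == torch.int32
```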

Follow ups:

- We could consider requiring fewer explicit upcast_compute_type calls and doing it automatically. That would potentially make things easier but be less flexible in the future. Maybe I should have done that in this PR.
- Be more consistent on our index expr dtype printing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141495
Approved by: https://github.com/blaine-rister, https://github.com/arui-meta, https://github.com/ezyang
ghstack dependencies: #139945, #140057
2024-11-28 11:39:38 +00:00
eellison
fd553b9817 Add remaining method and tests for dtype propagation (#140057)
Adds the remaining unimplemented ops as well as an assertion failure if someone adds a new op without a dtype rule.

We test all unique pointwise operators registered as lowerings which have an opinfo. There will be some follow-ups for this to work well with `codegen_upcast_to_fp32` set to both True and False.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140057
Approved by: https://github.com/arui-meta, https://github.com/blaine-rister, https://github.com/ezyang
ghstack dependencies: #139945
2024-11-27 17:06:44 +00:00
eellison
566ceb3e7e Refactor dtype propagation (#139945)
A couple of changes:

- Tries to reuse dtype propagation rules that were already registered in inductor. These were present both in `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later.

- Factors out `get_promoted_dtype`, which uses functools.lru_cache, to take in non-CSEVariable args because those will not work with the functools cache (see the sketch below).
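A minimal sketch of that caching pattern (not the real implementation; the real function handles more than plain dtypes):

```py
import functools
import torch

@functools.lru_cache(None)
def get_promoted_dtype(*dtypes: torch.dtype) -> torch.dtype:
    # dtypes are hashable, so lru_cache works; CSEVariable objects are not
    # passed in directly for exactly that reason
    result = dtypes[0]
    for dt in dtypes[1:]:
        result = torch.promote_types(result, dt)
    return result

assert get_promoted_dtype(torch.float16, torch.float32) == torch.float32
```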

Tests get added later in the stack when everything is implemented.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139945
Approved by: https://github.com/blaine-rister, https://github.com/arui-meta, https://github.com/ezyang
2024-11-27 16:57:02 +00:00
Isuru Fernando
44186a0a4e Move Sympy printers to torch/utils/_sympy/printers.py (#140597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2024-11-26 18:11:00 +00:00
xinan.lin
4742080ed9 [AOTI XPU] Enable Cpp wraper for Intel GPU. (#135318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135318
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/guangyey, https://github.com/desertfire
2024-11-26 11:51:32 +00:00
PyTorch MergeBot
f23621ec56 Revert "Move Sympy printers to torch/utils/_sympy/printers.py (#140597)"
This reverts commit c25b201583.

Reverted https://github.com/pytorch/pytorch/pull/140597 on behalf of https://github.com/huydhn due to Trunk is sad again after this lands, this looks like a landrace this time, so please do a rebase ([comment](https://github.com/pytorch/pytorch/pull/140597#issuecomment-2494052978))
2024-11-22 15:43:39 +00:00
Isuru Fernando
c25b201583 Move Sympy printers to torch/utils/_sympy/printers.py (#140597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2024-11-22 02:04:36 +00:00
Jason Ansel
6eca0aee76 [inductor] Refactor ir.Layout into ir.OutputSpec (#140910)
This separates the concepts of a Layout (size/stride/etc.) and an OutputSpec (which can include multiple outputs), which should make typing easier.
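A rough sketch of the conceptual split (field names are hypothetical, not the actual ir.py definitions):

```py
from dataclasses import dataclass
from typing import Sequence
import torch

@dataclass
class Layout:
    size: Sequence[int]
    stride: Sequence[int]
    dtype: torch.dtype
    device: torch.device

@dataclass
class OutputSpec:
    layouts: Sequence[Layout]  # one Layout per output; supports multi-output nodes
```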

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140910
Approved by: https://github.com/ezyang
ghstack dependencies: #140895
2024-11-21 20:01:57 +00:00
PyTorch MergeBot
701e06b643 Revert "Move Sympy printers to torch/utils/_sympy/printers.py (#140597)"
This reverts commit aefcdb3c9f.

Reverted https://github.com/pytorch/pytorch/pull/140597 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it fails inductor/test_padding in trunk. This is a target determination miss and that failed test was not run in your PR ([comment](https://github.com/pytorch/pytorch/pull/140597#issuecomment-2489641453))
2024-11-20 22:13:57 +00:00
Isuru Fernando
aefcdb3c9f Move Sympy printers to torch/utils/_sympy/printers.py (#140597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2024-11-20 20:26:49 +00:00
Jason Ansel
808f0f656d [inductor] Refactor MutableBox to make IRNode typing easier (#140895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140895
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2024-11-20 19:50:46 +00:00
eellison
eff22171d2 Add Current Mask Var To CSE Cache Key (#140838)
This torch.cat kernel has multiple subblocks which load from the same input. We were incorrectly reusing the mask vars from the first load for the second load.
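A sketch of the fix's idea (names are hypothetical, not the real CSE code):

```py
def cse_cache_key(expr: str, active_mask_vars: frozenset) -> tuple:
    # including the current mask variables in the key means an identical load
    # expression generated under a different mask is not reused
    return (expr, active_mask_vars)

k1 = cse_cache_key("tl.load(in_ptr0 + x0)", frozenset({"xmask"}))
k2 = cse_cache_key("tl.load(in_ptr0 + x0)", frozenset({"tmp3_mask"}))
assert k1 != k2  # the two subblocks no longer share a cached variable
```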

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140838
Approved by: https://github.com/jansel
ghstack dependencies: #140841
2024-11-20 00:55:56 +00:00
PyTorch MergeBot
d472a5f680 Revert "[inductor] Refactor MutableBox to make IRNode typing easier (#140895)"
This reverts commit c79e78b503.

Reverted https://github.com/pytorch/pytorch/pull/140895 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think test_torchbind_inductor is failing in trunk after this lands ([comment](https://github.com/pytorch/pytorch/pull/140895#issuecomment-2484679319))
2024-11-19 04:25:41 +00:00
Jason Ansel
c79e78b503 [inductor] Refactor MutableBox to make IRNode typing easier (#140895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140895
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2024-11-19 00:24:35 +00:00
chilli
e9fb2c6abe Add some error messages for flexattention (#138891)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138891
Approved by: https://github.com/Skylion007, https://github.com/drisspg
2024-11-13 04:05:29 +00:00
Jason Ansel
ed30fa74ab [inductor] sympy.Integer([01]) -> sympy.S.(Zero|One) (#139523)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139523
Approved by: https://github.com/ezyang
ghstack dependencies: #139364, #139365, #139370, #139452
2024-11-04 04:28:40 +00:00
Jason Ansel
b6fb135c2c [inductor] Simplify remove_kernel_local_buffers (#139452)
I plan to reuse `can_buffer_be_removed_through_fusion` in some heuristics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139452
Approved by: https://github.com/shunting314
ghstack dependencies: #139364, #139365, #139370
2024-11-04 04:28:40 +00:00
Jason Ansel
3d633f12ba [inductor] Move remove_kernel_local_buffers to Kernel (#139370)
This method mutates the kernel, so it fits better in that class.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139370
Approved by: https://github.com/shunting314
ghstack dependencies: #139364, #139365
2024-11-04 04:28:33 +00:00
Jason Ansel
d189f92eb1 [inductor] Remove SIMDKernel.last_usage (#139364)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139364
Approved by: https://github.com/eellison, https://github.com/shunting314
2024-11-04 04:28:18 +00:00
PyTorch MergeBot
0863d6a08e Revert "[inductor] Remove SIMDKernel.last_usage (#139364)"
This reverts commit 286d3ce266.

Reverted https://github.com/pytorch/pytorch/pull/139364 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing lots of internal tests in D65345157 ([comment](https://github.com/pytorch/pytorch/pull/139364#issuecomment-2452897337))
2024-11-02 06:49:11 +00:00
PyTorch MergeBot
dc4b459737 Revert "[inductor] Move remove_kernel_local_buffers to Kernel (#139370)"
This reverts commit b57b4b7f9b.

Reverted https://github.com/pytorch/pytorch/pull/139370 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing lots of internal tests in D65345157 ([comment](https://github.com/pytorch/pytorch/pull/139364#issuecomment-2452897337))
2024-11-02 06:49:10 +00:00
PyTorch MergeBot
66a401c9e1 Revert "[inductor] Simplify remove_kernel_local_buffers (#139452)"
This reverts commit 73c0762a34.

Reverted https://github.com/pytorch/pytorch/pull/139452 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing lots of internal tests in D65345157 ([comment](https://github.com/pytorch/pytorch/pull/139364#issuecomment-2452897337))
2024-11-02 06:49:10 +00:00
PyTorch MergeBot
98e11b0021 Revert "[inductor] sympy.Integer([01]) -> sympy.S.(Zero|One) (#139523)"
This reverts commit c53beab377.

Reverted https://github.com/pytorch/pytorch/pull/139523 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing lots of internal tests in D65345157 ([comment](https://github.com/pytorch/pytorch/pull/139364#issuecomment-2452897337))
2024-11-02 06:49:10 +00:00
Jason Ansel
c53beab377 [inductor] sympy.Integer([01]) -> sympy.S.(Zero|One) (#139523)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139523
Approved by: https://github.com/ezyang
ghstack dependencies: #139364, #139365, #139370, #139452
2024-11-02 03:04:22 +00:00
Jason Ansel
73c0762a34 [inductor] Simplify remove_kernel_local_buffers (#139452)
I plan to reuse `can_buffer_be_removed_through_fusion` in some heuristics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139452
Approved by: https://github.com/shunting314
ghstack dependencies: #139364, #139365, #139370
2024-11-01 20:36:39 +00:00
Jason Ansel
b57b4b7f9b [inductor] Move remove_kernel_local_buffers to Kernel (#139370)
This method mutates the kernel, so it fits better in that class.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139370
Approved by: https://github.com/shunting314
ghstack dependencies: #139364, #139365
2024-11-01 16:28:15 +00:00
Jason Ansel
286d3ce266 [inductor] Remove SIMDKernel.last_usage (#139364)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139364
Approved by: https://github.com/eellison, https://github.com/shunting314
2024-11-01 16:28:15 +00:00
drisspg
a884462bca Add workspace to TritonTemplates (#138050)
Here's a markdown summary for the PR:

# Add workspace buffer support for Triton templates

## Summary
Adds support for templates to allocate and use temporary workspace buffers

## Key Changes
- Add `WorkspaceArg` support in Triton template system
- Automatic workspace allocation/deallocation around kernel execution
- Zero-initialization support for workspace buffers
- Seamless integration with existing tensor management

## Example Usage
```python
def generate(self, ...):
    workspace_arg = WorkspaceArg(
        count=1024*1024,  # 1MB workspace
        zero_fill=True    # Zero-initialized
    )

    return TritonTemplateCaller(..., workspace_arg=workspace_arg)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138050
Approved by: https://github.com/Chillee, https://github.com/eellison
2024-10-29 18:17:54 +00:00
Jason Ansel
a762dc0357 [inductor] Multi-kernel + cooperative reductions (#138893)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138893
Approved by: https://github.com/shunting314
ghstack dependencies: #138533
2024-10-29 15:45:17 +00:00
Jason Ansel
2b937e4e6d [inductor] Cooperative reductions (#137756)
Example generated code for `(x+y).sum()`:
```py
@triton.jit
def triton_unk_fused_add_sum_0(in_ptr0, in_ptr1, out_ptr0, ws_ptr, semaphores_ptr, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr, RSPLIT : tl.constexpr):
    xnumel = 1
    rnumel = 1048576
    rsplit_id = tl.program_id(0)
    num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK
    rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK
    rsplit_start = rsplit_chunk * rsplit_id
    rsplit_end = rsplit_chunk * (rsplit_id + 1)
    xoffset = tl.program_id(1) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp4 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(rsplit_start, rsplit_end, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 + tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(rmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    if RSPLIT > 1:
        tmp4_ws = (ws_ptr + 0).to(tl.pointer_type(tl.float32))
        tl.store(tmp4_ws + (xindex * RSPLIT + rsplit_id), tmp4, None)
    if RSPLIT > 1:
        triton_helpers.gpu_barrier(semaphores_ptr + (2 * tl.program_id(1) + 0), RSPLIT, True)
    if RSPLIT > 1:
        tmp4_peers = tl.load(tmp4_ws + (xindex * RSPLIT + tl.arange(0, RSPLIT)[None,:]), None, eviction_policy='evict_first')
        tmp4 = tl.sum(tmp4_peers, 1)[:, None]
    if rsplit_id == (0 % RSPLIT):
        tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
```
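A usage sketch that would produce a kernel like the one above (the config flag name is an assumption; check torch/_inductor/config.py):

```py
import torch
import torch._inductor.config as inductor_config

inductor_config.triton.cooperative_reductions = True  # assumed flag name

@torch.compile
def f(x, y):
    return (x + y).sum()

x = torch.randn(1024, 1024, device="cuda")
y = torch.randn(1024, 1024, device="cuda")
print(f(x, y))
```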

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137756
Approved by: https://github.com/eellison
2024-10-29 00:45:53 +00:00
Adnan Akhundov
ab09c4d913 Add host-side TMA support to AOTInductor (#138878)
This adds host-side Triton TMA support to AOTInductor. Notes:

- Two helper functions, `init1DTMADescriptor` and `init2DTMADescriptor` are added to the C++ wrapper codegen on GPU, conditioned on the model having user-defined Triton kernels with host-side TMA (CUDA-specific).
- C++ wrapper codegen on GPU emits TMA descriptor initialization via the aforementioned helper functions.
- Special handling added for the TMA descriptors (in the Python wrapper codegen) during the compile-time autotuning, as the underlying tensor can't be passed directly to the user-defined Triton kernel. TMA descriptors are generated in-between the source tensor's buffer and the kernel call, like in the full Python wrapper codegen.
- This PR concludes the host-side Triton TMA support in PT2.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138878
Approved by: https://github.com/desertfire, https://github.com/chenyang78
ghstack dependencies: #138759, #138877
2024-10-28 23:39:53 +00:00
PyTorch MergeBot
60d1c7138d Revert "[inductor] Cooperative reductions (#137756)"
This reverts commit fed37dbfbc.

Reverted https://github.com/pytorch/pytorch/pull/137756 on behalf of https://github.com/jeanschmidt due to ROCM tests are timing out :( ([comment](https://github.com/pytorch/pytorch/pull/137756#issuecomment-2441579322))
2024-10-28 13:24:33 +00:00
Jason Ansel
fed37dbfbc [inductor] Cooperative reductions (#137756)
Example generated code for `(x+y).sum()`:
```py
@triton.jit
def triton_unk_fused_add_sum_0(in_ptr0, in_ptr1, out_ptr0, ws_ptr, semaphores_ptr, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr, RSPLIT : tl.constexpr):
    xnumel = 1
    rnumel = 1048576
    rsplit_id = tl.program_id(0)
    num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK
    rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK
    rsplit_start = rsplit_chunk * rsplit_id
    rsplit_end = rsplit_chunk * (rsplit_id + 1)
    xoffset = tl.program_id(1) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp4 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(rsplit_start, rsplit_end, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 + tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(rmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    if RSPLIT > 1:
        tmp4_ws = (ws_ptr + 0).to(tl.pointer_type(tl.float32))
        tl.store(tmp4_ws + (xindex * RSPLIT + rsplit_id), tmp4, None)
    if RSPLIT > 1:
        triton_helpers.gpu_barrier(semaphores_ptr + (2 * tl.program_id(1) + 0), RSPLIT, True)
    if RSPLIT > 1:
        tmp4_peers = tl.load(tmp4_ws + (xindex * RSPLIT + tl.arange(0, RSPLIT)[None,:]), None, eviction_policy='evict_first')
        tmp4 = tl.sum(tmp4_peers, 1)[:, None]
    if rsplit_id == (0 % RSPLIT):
        tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137756
Approved by: https://github.com/eellison
ghstack dependencies: #138970
2024-10-27 16:31:38 +00:00
Xinran / Allan Rui
ba6526814a Add dtype attribute to CSEVariable (#136778)
Summary:
- This diff introduces `dtype` attribute to `TritonCSEVariable` and a dtype propagation helper function to infer dtype from input to output for each op.

- There will be a follow-up diff that uses this `dtype` information in `TritonCSEVariable` to perform dtype-aware codegen.

Test Plan: CI

Differential Revision: D61815079

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136778
Approved by: https://github.com/eellison, https://github.com/blaine-rister
2024-10-25 18:00:30 +00:00
Adam Mainz
d0640b945b [inductor][nit] removing unnecessary else statements (#138789)
Summary: While reading through inductor template code I found a few places where else statements were driving me crazy. Fixing them as I read.

Test Plan: CI

Differential Revision: D64882385

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138789
Approved by: https://github.com/aakhundov
2024-10-25 17:59:25 +00:00
PyTorch MergeBot
3b186c5659 Revert "[AOTI] Fix test_index_put_with_none_index_cpu_with_stack_allocation (#138303)"
This reverts commit 1417b2cd05.

Reverted https://github.com/pytorch/pytorch/pull/138303 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/138303#issuecomment-2427991065))
2024-10-22 00:46:48 +00:00
Bin Bao
1417b2cd05 [AOTI] Fix test_index_put_with_none_index_cpu_with_stack_allocation (#138303)
Summary: The problem happened after splitting CppWrapperCpu and CppWrapperCpuArrayRef, because CppWrapperCpuArrayRef.generate_index_put_fallback missed a statement. Running test_aot_inductor.py as a whole didn't reveal the problem, but running test_index_put_with_none_index_cpu_with_stack_allocation individually did. Digging deeper, the root cause is that init_backend_registration had incorrectly cached the CPU CppWrapperCodegen class, which means CppWrapperCpuArrayRef was never picked when running test_aot_inductor.py as a whole.

Differential Revision: [D64598714](https://our.internmc.facebook.com/intern/diff/D64598714)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138303
Approved by: https://github.com/hl475
2024-10-21 13:47:50 +00:00
Jason Ansel
4632594546 [inductor] Move V.graph.scheduler.current_device to V.graph.current_device (#138252)
There are some places where it would be nice to use this, but the scheduler hasn't yet been created.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138252
Approved by: https://github.com/eellison
ghstack dependencies: #138170
2024-10-18 23:05:54 +00:00
Jason Ansel
85a6a782e5 [inductor] Generalize WorkspaceArg for graph-level semaphores (#138170)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138170
Approved by: https://github.com/Chillee
2024-10-18 23:05:54 +00:00
Adnan Akhundov
d116d007ee Add host-side Triton TMA support to Inductor (#137950)
This adds Dynamo tracing support for the host-side Triton TMA API (see `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes:

- Here we assume the availability of the host-side TMA API added to upstream Triton in https://github.com/triton-lang/triton/pull/4498. As of time of writing, this is not a part of the PT2 OSS Triton pin (although back-ported internally). OSS Triton pin update should be done in December 2024.
- Due to Dynamo support implemented in the previous PR, the `tma_descriptor_metadata` dict is delivered to the `triton_kernel_wrap_` lowering and passed to the `ir.UserDefinedTritonKernel` as an additional argument.
- Looking into the `tma_descriptor_metadata`, `ir.UserDefinedTritonKernel` substitutes the corresponding `TensorBox` arguments of the kernel (swapped upstream in Dynamo) by the new `ir.TMADescriptor` nodes implementing TMA descriptors in Inductor IR.
- `ir.TMADescriptor.__init__` provides the wiring between the upstream underlying `ir.TensorBox` and the downstream `ir.UserDefinedTritonKernel` kernel. In particular, we use `ir.NonOwnedLayout` wrapping `ir.ReinterpretView` to avoid the upstream tensor's buffer being deleted prematurely (before the TMA descriptor is used in the Triton kernel).
- Via `ir.TMADescriptor.codegen`, the Triton's `create_{1d,2d}_tma_descriptor` function call is codegened in the wrapper (in the host code).
- New `TMADescriptorArg` dataclass is added to handle the Triton kernel metadata pertinent to host-side TMA.
- AOT Inductor support will be implemented in a follow-up PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137950
Approved by: https://github.com/eellison
ghstack dependencies: #137677
2024-10-18 06:27:24 +00:00
Bin Bao
2e67d7cc35 [AOTI] Remove the non-ABI-compatible mode (part 1) (#138009)
Summary: The ABI-compatible mode has been turned on as default in https://github.com/pytorch/pytorch/pull/136534. Removing the non-ABI-compatible logic to greatly simplify the wrapper codegen logic.

Differential Revision: [D64439676](https://our.internmc.facebook.com/intern/diff/D64439676)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138009
Approved by: https://github.com/chenyang78
ghstack dependencies: #137982, #138016
2024-10-17 02:48:26 +00:00
Benjamin Glass
a968576777 Add lowering for aten.searchsorted (#135701)
Adds lowering for `aten.searchsorted` (a brief usage sketch follows the list). This entails:

1. Adding support for multi-dimensional bucket tensors to `ops.bucketize`.
2. Adding support for striding to `ops.bucketize`.
3. Adding support for sorting tensors to `ops.bucketize`.
4. Adding a lowering for `aten.searchsorted.Tensor`.
5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors.
6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions.
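The usage sketch mentioned above (standard `torch.searchsorted` semantics; the point of this PR is that the call is lowered by inductor under torch.compile instead of falling back):

```py
import torch

@torch.compile
def find_buckets(boundaries, values):
    # with this PR, searchsorted hits the inductor lowering
    return torch.searchsorted(boundaries, values)

boundaries = torch.tensor([0.1, 0.5, 0.9])
values = torch.tensor([0.05, 0.3, 0.7, 1.2])
print(find_buckets(boundaries, values))  # tensor([0, 1, 2, 3])
```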

Closes #135873

Differential Revision: [D63766514](https://our.internmc.facebook.com/intern/diff/D63766514)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135701
Approved by: https://github.com/amjames, https://github.com/eellison, https://github.com/davidberard98
2024-10-04 19:26:05 +00:00