Commit Graph

518 Commits

Author SHA1 Message Date
xinan.lin
04312293a2 [Inductor] Fix wrong CSEVariable dtype for reduction. Fix #141861 (#142189)
Fix #141861

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142189
Approved by: https://github.com/jansel
2024-12-09 21:07:35 +00:00
Blaine Burton Rister
0c66cee9a2 [Inductor] Expand dtype aware codegen for libdevice and tl.math ops (#140864)
# Feature
Previously, only the codegen for `torch.sqrt` was dtype aware. This PR updates most of the `libdevice`/`tl.math` ops to support dtype-aware codegen as well. This is often necessary to get correct code when `config.triton.codegen_upcast_to_fp32=False`, as most Triton math ops do not support float16/bfloat16.

This PR enables dtype aware codegen via the `maybe_upcast_float32` decorator. This wraps `TritonOverrides` macros to upcast arguments to float32, and downcast the result back to the original dtype. The exception is for ops that return booleans, in which case we set `convert_output=False` and skip the output cast.

# Test Plan
Added CI tests for all the new ops. The list of ops to test is automatically generated based on uses of the `maybe_upcast_float32` decorator, and stored in the new `OpDtypeSupport` class. In each new test, we search the generated code for upcasts/downcasts using a regex.

Also added a unit test for `OpDtypeSupport` which checks that we have correct dtype info for ops that require upcasts.

This PR also moves some existing tests around, to collect all the dtype aware codegen tests in one file.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140864
Approved by: https://github.com/eellison, https://github.com/arui-meta

Co-authored-by: eellison <elias.ellison@gmail.com>
2024-12-08 19:42:48 +00:00
Jason Ansel
0367a31401 [inductor] Minor typing changes (#142219)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142219
Approved by: https://github.com/Skylion007, https://github.com/yanboliang
2024-12-07 17:48:37 +00:00
PyTorch MergeBot
fc831f76f8 Revert "[Inductor] Represent size_hints as a dict (#142249)"
This reverts commit f870ee2cc4.

Reverted https://github.com/pytorch/pytorch/pull/142249 on behalf of https://github.com/blaine-rister due to would break internal tests ([comment](https://github.com/pytorch/pytorch/pull/142249#issuecomment-2524991008))
2024-12-07 07:43:51 +00:00
Blaine Burton Rister
f870ee2cc4 [Inductor] Represent size_hints as a dict (#142249)
Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243.

# Feature

Follow up to https://github.com/pytorch/pytorch/pull/141751. Since we now represent `numels` as a dict, it's natural to extend this to `size_hints`. The latter are basically just the former rounded up to the nearest power of 2. This simplifies various heuristics such as the coordinate descent tuner. Where we previously needed to determine which index in `size_hints` corresponds to each dimension, now we can just query by prefix. This will be especially important when we enable 2D reductions, as it becomes harder to keep track of these things when we have multiple reduction dimensions. (See the previous PR for some examples.)

# Test plan

The existing CI provides good coverage. This PR modifies a few tests which explicitly constructed size hints.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142249
Approved by: https://github.com/jansel
2024-12-07 06:43:05 +00:00
PyTorch MergeBot
f36cccba2e Revert "[Inductor] Expand dtype aware codegen for libdevice and tl.math ops (#140864)"
This reverts commit 80ca6dd892.

Reverted https://github.com/pytorch/pytorch/pull/140864 on behalf of https://github.com/atalman due to failing internally ([comment](https://github.com/pytorch/pytorch/pull/140864#issuecomment-2524168602))
2024-12-06 21:03:06 +00:00
Blaine Burton Rister
80ca6dd892 [Inductor] Expand dtype aware codegen for libdevice and tl.math ops (#140864)
# Feature
Previously, only the codegen for `torch.sqrt` was dtype aware. This PR updates most of the `libdevice`/`tl.math` ops to support dtype-aware codegen as well. This is often necessary to get correct code when `config.triton.codegen_upcast_to_fp32=False`, as most Triton math ops do not support float16/bfloat16.

This PR enables dtype aware codegen via the `maybe_upcast_float32` decorator. This wraps `TritonOverrides` macros to upcast arguments to float32, and downcast the result back to the original dtype. The exception is for ops that return booleans, in which case we set `convert_output=False` and skip the output cast.

# Test Plan
Added CI tests for all the new ops. The list of ops to test is automatically generated based on uses of the `maybe_upcast_float32` decorator, and stored in the new `OpDtypeSupport` class. In each new test, we search the generated code for upcasts/downcasts using a regex.

Also added a unit test for `OpDtypeSupport` which checks that we have correct dtype info for ops that require upcasts.

This PR also moves some existing tests around, to collect all the dtype aware codegen tests in one file.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140864
Approved by: https://github.com/eellison, https://github.com/arui-meta

Co-authored-by: eellison <elias.ellison@gmail.com>
2024-12-06 03:15:20 +00:00
Blaine Burton Rister
f9af86de01 [Inductor] Represent tiling as a dict (#141751)
# Summary

Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243. This makes it easier to generalize to multi-dimensional reductions.

This diff refactors `self.numels` from a tuple like `(8,16)` to a dict like `{"x": 8, "r": 16}`.

Note: this is based off of https://github.com/pytorch/pytorch/pull/141738, which enables `tree.is_reduction`. That PR should land first.

# Test plan
The existing CI provides good coverage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141751
Approved by: https://github.com/jansel
2024-12-05 02:28:16 +00:00
eellison
fd35be2fd3 TritonTemplate dtype fixes (#141991)
- Set the dtype of "acc" appropriately so that epilogue fusion will have args with dtype
- Update dtype propagation to use `type_to_dtype` instead of instantiating tensor
- Throw if we have a string arg where we should have a proper CSEVariable, unless we're doing the Modification Subgraph thing which is nyi. everything else is appropriately typed (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @drisspg ).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141991
Approved by: https://github.com/drisspg
ghstack dependencies: #139945, #140057, #141495, #141882
2024-12-04 17:24:23 +00:00
PyTorch MergeBot
38d10a1b17 Revert "[Inductor] Represent tiling as a dict (#141751)"
This reverts commit 5deca07c0d.

Reverted https://github.com/pytorch/pytorch/pull/141751 on behalf of https://github.com/atalman due to Failing internal builds ([comment](https://github.com/pytorch/pytorch/pull/141751#issuecomment-2517815899))
2024-12-04 15:43:16 +00:00
Mwiza Kunda
f0b33658f8 Dont use constant mask if ynumel potentially overflows ygrids (#139751)
If (ynumel / YBLOCK)  > get_max_ygrids(), the z dimension will be used if znumel is None. However, if (ynumel / YBLOCK) % get_max_ygrids() != 0, there will be program launches with inputs that require masking, and so this needs to be considered when determining if the y dimension has a constant mask.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139751
Approved by: https://github.com/eellison

Co-authored-by: George White <georgew@graphcore.ai>
2024-12-03 22:56:18 +00:00
Jason Ansel
b2fe1b9409 [inductor] Fix 3d tiling (#141709)
Fixes #141121

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141709
Approved by: https://github.com/eellison
2024-12-01 19:47:41 +00:00
Blaine Burton Rister
5deca07c0d [Inductor] Represent tiling as a dict (#141751)
# Summary

Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243. This makes it easier to generalize to multi-dimensional reductions.

This diff refactors `self.numels` from a tuple like `(8,16)` to a dict like `{"x": 8, "r": 16}`.

Note: this is based off of https://github.com/pytorch/pytorch/pull/141738, which enables `tree.is_reduction`. That PR should land first.

# Test plan
The existing CI provides good coverage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141751
Approved by: https://github.com/jansel
2024-12-01 09:54:34 +00:00
Blaine Burton Rister
c2fa544472 [Inductor] move block pointer analysis to a new module (#141733)
# Summary

Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243. This refactors the ModularIndexing block pointer analysis into its own module. That way, we can call it from other places besides Triton codegen. In the parent PR, we will use this to find tiling splits that simplify the indexing.

# Test plan

Tested by the existing CI.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141733
Approved by: https://github.com/jansel
2024-11-30 23:21:24 +00:00
Blaine Burton Rister
49fde426ba [Inductor] Use a helper function to tell if a tree or prefix is a reduction (#141738)
Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243. Previously, we would typically check for reductions by `tree.prefix == "r"`. This PR moves the check into a helper function. This makes it easier to generalize the code to multi-dimensional reductions, which could have multiple prefixes like `("r0_", "r1_")`.

Tested by the existing CI.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141738
Approved by: https://github.com/jansel
2024-11-30 22:38:13 +00:00
Nan Zhang
5aacfa037b [Inductor] fix broadcast logic for Triton (#141027) (#141693)
Summary:

Fix logic for inserting broadcast on kernel with load going directly to store. In the case where load is going directly to store, we insert a tl.broadcast on the store, regardless of the block size on the load. In the case where a broadcast is not required, the downstream Triton compiler is expected to remove this no-op broadcast instruction.

Test Plan: Added tests under test_torchinductor_strided_blocks.py:test_expand_broadcast in OSS and internal test cases.

Reviewed By: blaine-rister

Differential Revision: D65518033

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141693
Approved by: https://github.com/blaine-rister
2024-11-28 16:38:25 +00:00
eellison
f83361b274 inductor dtype propagation fixes (#141495)
- Add in upcast_compute_type on creation of new tensors (loads, constants)
- Fixes index_expr - right now we are sort of inconsistent in dtype and dont always respect the dtype specified. would be nice to fix but not doing in this pr.
- bug fix in view dtype where we were always upcasting back to fp32 when input was in bf16/fp16. we should only be doing that if the output is also in bf16/fp16.
- for masked, avoid calling dtype propagation and just use output dtype.

Turns on the runtime dtype verification for opinfo tests. The separate test file is still useful because we can use it for testing turning off codegen_upcast_to_fp32.

Follow ups:

- We could consider requiring less explicit upcast_compute_types calls and do it automatically. That would potentially make things easier but be less flexible in the future. Maybe I should have done it this pr.
- Be more consistent on our index expr dtype printing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141495
Approved by: https://github.com/blaine-rister, https://github.com/arui-meta, https://github.com/ezyang
ghstack dependencies: #139945, #140057
2024-11-28 11:39:38 +00:00
PyTorch MergeBot
b33f770574 Revert "[inductor] Fix 3d tiling (#141709)"
This reverts commit ca9bfa1a38.

Reverted https://github.com/pytorch/pytorch/pull/141709 on behalf of https://github.com/huydhn due to Sorry for reverting your change but there is one failed test showing up in trunk.  It was missed by target determination ([comment](https://github.com/pytorch/pytorch/pull/141709#issuecomment-2505213481))
2024-11-28 03:55:31 +00:00
Jason Ansel
ca9bfa1a38 [inductor] Fix 3d tiling (#141709)
Fixes #141121

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141709
Approved by: https://github.com/eellison
2024-11-28 01:34:28 +00:00
eellison
fd553b9817 Add remaining method and tests for dtype propagation (#140057)
Adds the remaining unimplemented ops as well as an assertion failure if someone adds a new op without a dtype rule.

We test all unique pointwise operators registered as lowerings which have an opinfo. There will be some follow ups for this to work well with both `codegen_upcast_to_fp32` as True and False.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140057
Approved by: https://github.com/arui-meta, https://github.com/blaine-rister, https://github.com/ezyang
ghstack dependencies: #139945
2024-11-27 17:06:44 +00:00
eellison
566ceb3e7e Refactor dtype propagation (#139945)
A couple changes.

- Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later.

- Factors out `get_promoted_dtype` which uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache.

Tests get added later in the stack when everything is implemented.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139945
Approved by: https://github.com/blaine-rister, https://github.com/arui-meta, https://github.com/ezyang
2024-11-27 16:57:02 +00:00
Isuru Fernando
44186a0a4e Move Sympy printers to torch/utils/_sympy/printers.py (#140597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2024-11-26 18:11:00 +00:00
Jason Ansel
995e3079c9 [inductor] Fix for "Failed to find static RBLOCK" (#141434)
Summary: I expect this to fix https://fb.workplace.com/groups/1075192433118967/permalink/1547962839175255/

Test Plan: Ask poster to confirm fix

Differential Revision: D66413828

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141434
Approved by: https://github.com/ezyang
2024-11-23 22:08:56 +00:00
PyTorch MergeBot
f23621ec56 Revert "Move Sympy printers to torch/utils/_sympy/printers.py (#140597)"
This reverts commit c25b201583.

Reverted https://github.com/pytorch/pytorch/pull/140597 on behalf of https://github.com/huydhn due to Trunk is sad again after this lands, this looks like a landrace this time, so please do a rebase ([comment](https://github.com/pytorch/pytorch/pull/140597#issuecomment-2494052978))
2024-11-22 15:43:39 +00:00
Isuru Fernando
c25b201583 Move Sympy printers to torch/utils/_sympy/printers.py (#140597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2024-11-22 02:04:36 +00:00
PyTorch MergeBot
701e06b643 Revert "Move Sympy printers to torch/utils/_sympy/printers.py (#140597)"
This reverts commit aefcdb3c9f.

Reverted https://github.com/pytorch/pytorch/pull/140597 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it fails inductor/test_padding in trunk. This is a target determination miss and that failed test was not run in your PR ([comment](https://github.com/pytorch/pytorch/pull/140597#issuecomment-2489641453))
2024-11-20 22:13:57 +00:00
Isuru Fernando
aefcdb3c9f Move Sympy printers to torch/utils/_sympy/printers.py (#140597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2024-11-20 20:26:49 +00:00
Jason Ansel
808f0f656d [inductor] Refactor MutableBox to make IRNode typing easier (#140895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140895
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2024-11-20 19:50:46 +00:00
eellison
eff22171d2 Add Current Mask Var To CSE Cache Key (#140838)
This torch.cat kernel has multiple subblocks which load from the same input. We were incorrectly reusing the mask vars from the first load for the second load.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140838
Approved by: https://github.com/jansel
ghstack dependencies: #140841
2024-11-20 00:55:56 +00:00
Jason Ansel
2c6bd9f6f6 [inductor] Support fixed triton configs defined at compile time (#140217)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140217
Approved by: https://github.com/shunting314
ghstack dependencies: #139585
2024-11-17 16:10:37 +00:00
Jason Ansel
318eaa2be7 [inductor] Refactor reduction type choices into V.choices (#139585)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139585
Approved by: https://github.com/shunting314
2024-11-17 16:10:37 +00:00
PyTorch MergeBot
069a71023b Revert "[inductor] Refactor reduction type choices into V.choices (#139585)"
This reverts commit 6438c8637a.

Reverted https://github.com/pytorch/pytorch/pull/139585 on behalf of https://github.com/kit1980 due to breaking internal builds, see D65800124 ([comment](https://github.com/pytorch/pytorch/pull/139585#issuecomment-2471392822))
2024-11-12 19:32:14 +00:00
PyTorch MergeBot
c0ddd10f6d Revert "[inductor] Support fixed triton configs defined at compile time (#140217)"
This reverts commit 29114e44fa.

Reverted https://github.com/pytorch/pytorch/pull/140217 on behalf of https://github.com/kit1980 due to breaking internal builds, see D65800124 ([comment](https://github.com/pytorch/pytorch/pull/139585#issuecomment-2471392822))
2024-11-12 19:32:14 +00:00
Jason Ansel
29114e44fa [inductor] Support fixed triton configs defined at compile time (#140217)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140217
Approved by: https://github.com/shunting314
ghstack dependencies: #139585
2024-11-12 00:56:02 +00:00
Jason Ansel
6438c8637a [inductor] Refactor reduction type choices into V.choices (#139585)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139585
Approved by: https://github.com/shunting314
2024-11-12 00:56:02 +00:00
Jason Ansel
ed30fa74ab [inductor] sympy.Integer([01]) -> sympy.S.(Zero|One) (#139523)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139523
Approved by: https://github.com/ezyang
ghstack dependencies: #139364, #139365, #139370, #139452
2024-11-04 04:28:40 +00:00
Jason Ansel
d189f92eb1 [inductor] Remove SIMDKernel.last_usage (#139364)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139364
Approved by: https://github.com/eellison, https://github.com/shunting314
2024-11-04 04:28:18 +00:00
PyTorch MergeBot
0863d6a08e Revert "[inductor] Remove SIMDKernel.last_usage (#139364)"
This reverts commit 286d3ce266.

Reverted https://github.com/pytorch/pytorch/pull/139364 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing lots of internal tests in D65345157 ([comment](https://github.com/pytorch/pytorch/pull/139364#issuecomment-2452897337))
2024-11-02 06:49:11 +00:00
PyTorch MergeBot
98e11b0021 Revert "[inductor] sympy.Integer([01]) -> sympy.S.(Zero|One) (#139523)"
This reverts commit c53beab377.

Reverted https://github.com/pytorch/pytorch/pull/139523 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing lots of internal tests in D65345157 ([comment](https://github.com/pytorch/pytorch/pull/139364#issuecomment-2452897337))
2024-11-02 06:49:10 +00:00
Jason Ansel
c53beab377 [inductor] sympy.Integer([01]) -> sympy.S.(Zero|One) (#139523)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139523
Approved by: https://github.com/ezyang
ghstack dependencies: #139364, #139365, #139370, #139452
2024-11-02 03:04:22 +00:00
Jason Ansel
286d3ce266 [inductor] Remove SIMDKernel.last_usage (#139364)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139364
Approved by: https://github.com/eellison, https://github.com/shunting314
2024-11-01 16:28:15 +00:00
Jason Ansel
f9ef880c0b [inductor] Refactor kernel args into SIMDKernelFeatures (#139327)
This is a refactor PR to move stuff around.  I'm planning to use the SIMDKernelFeatures class (in a future PR) to host new heuristics for selecting kernel types and block sizes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139327
Approved by: https://github.com/eellison, https://github.com/shunting314
2024-11-01 00:30:14 +00:00
drisspg
a884462bca Add workspace to TritonTemplates (#138050)
Here's a markdown summary for the PR:

# Add workspace buffer support for Triton templates

## Summary
Adds support for templates to allocate and use temporary workspace buffers

## Key Changes
- Add `WorkspaceArg` support in Triton template system
- Automatic workspace allocation/deallocation around kernel execution
- Zero-initialization support for workspace buffers
- Seamless integration with existing tensor management

## Example Usage
```python
def generate(self, ...):
    workspace_arg = WorkspaceArg(
        count=1024*1024,  # 1MB workspace
        zero_fill=True    # Zero-initialized
    )

    return TritonTemplateCaller(..., workspace_arg=workspace_arg)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138050
Approved by: https://github.com/Chillee, https://github.com/eellison
2024-10-29 18:17:54 +00:00
Jason Ansel
a762dc0357 [inductor] Multi-kernel + cooperative reductions (#138893)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138893
Approved by: https://github.com/shunting314
ghstack dependencies: #138533
2024-10-29 15:45:17 +00:00
Jason Ansel
77b0ae832d [inductor] Allow cooperative + persistent reductions (#138533)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138533
Approved by: https://github.com/shunting314, https://github.com/eellison
2024-10-29 15:45:17 +00:00
Jason Ansel
2b937e4e6d [inductor] Cooperative reductions (#137756)
Example generated code for `(x+y).sum()`:
```py
@triton.jit
def triton_unk_fused_add_sum_0(in_ptr0, in_ptr1, out_ptr0, ws_ptr, semaphores_ptr, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr, RSPLIT : tl.constexpr):
    xnumel = 1
    rnumel = 1048576
    rsplit_id = tl.program_id(0)
    num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK
    rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK
    rsplit_start = rsplit_chunk * rsplit_id
    rsplit_end = rsplit_chunk * (rsplit_id + 1)
    xoffset = tl.program_id(1) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp4 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(rsplit_start, rsplit_end, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 + tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(rmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    if RSPLIT > 1:
        tmp4_ws = (ws_ptr + 0).to(tl.pointer_type(tl.float32))
        tl.store(tmp4_ws + (xindex * RSPLIT + rsplit_id), tmp4, None)
    if RSPLIT > 1:
        triton_helpers.gpu_barrier(semaphores_ptr + (2 * tl.program_id(1) + 0), RSPLIT, True)
    if RSPLIT > 1:
        tmp4_peers = tl.load(tmp4_ws + (xindex * RSPLIT + tl.arange(0, RSPLIT)[None,:]), None, eviction_policy='evict_first')
        tmp4 = tl.sum(tmp4_peers, 1)[:, None]
    if rsplit_id == (0 % RSPLIT):
        tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137756
Approved by: https://github.com/eellison
2024-10-29 00:45:53 +00:00
PyTorch MergeBot
60d1c7138d Revert "[inductor] Cooperative reductions (#137756)"
This reverts commit fed37dbfbc.

Reverted https://github.com/pytorch/pytorch/pull/137756 on behalf of https://github.com/jeanschmidt due to ROCM tests are timing out :( ([comment](https://github.com/pytorch/pytorch/pull/137756#issuecomment-2441579322))
2024-10-28 13:24:33 +00:00
Jason Ansel
fed37dbfbc [inductor] Cooperative reductions (#137756)
Example generated code for `(x+y).sum()`:
```py
@triton.jit
def triton_unk_fused_add_sum_0(in_ptr0, in_ptr1, out_ptr0, ws_ptr, semaphores_ptr, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr, RSPLIT : tl.constexpr):
    xnumel = 1
    rnumel = 1048576
    rsplit_id = tl.program_id(0)
    num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK
    rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK
    rsplit_start = rsplit_chunk * rsplit_id
    rsplit_end = rsplit_chunk * (rsplit_id + 1)
    xoffset = tl.program_id(1) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp4 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(rsplit_start, rsplit_end, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 + tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(rmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    if RSPLIT > 1:
        tmp4_ws = (ws_ptr + 0).to(tl.pointer_type(tl.float32))
        tl.store(tmp4_ws + (xindex * RSPLIT + rsplit_id), tmp4, None)
    if RSPLIT > 1:
        triton_helpers.gpu_barrier(semaphores_ptr + (2 * tl.program_id(1) + 0), RSPLIT, True)
    if RSPLIT > 1:
        tmp4_peers = tl.load(tmp4_ws + (xindex * RSPLIT + tl.arange(0, RSPLIT)[None,:]), None, eviction_policy='evict_first')
        tmp4 = tl.sum(tmp4_peers, 1)[:, None]
    if rsplit_id == (0 % RSPLIT):
        tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137756
Approved by: https://github.com/eellison
ghstack dependencies: #138970
2024-10-27 16:31:38 +00:00
Xinran / Allan Rui
ba6526814a Add dtype attribute to CSEVariable (#136778)
Summary:
- This diff introduces `dtype` attribute to `TritonCSEVariable` and a dtype propagation helper function to infer dtype from input to output for each op.

- There will be a follow-up diff that uses this `dtype` information in `TritonCSEVariable` to perform dtype-aware codegen.

Test Plan: CI

Differential Revision: D61815079

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136778
Approved by: https://github.com/eellison, https://github.com/blaine-rister
2024-10-25 18:00:30 +00:00
Jason Ansel
4632594546 [inductor] Move V.graph.scheduler.current_device to V.graph.current_device (#138252)
There are some places where it would be nice to use this, but the scheduler hasn't yet been created.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138252
Approved by: https://github.com/eellison
ghstack dependencies: #138170
2024-10-18 23:05:54 +00:00