Jason Ansel
fed37dbfbc
[inductor] Cooperative reductions ( #137756 )
...
Example generated code for `(x+y).sum()`:
```py
@triton.jit
def triton_unk_fused_add_sum_0(in_ptr0, in_ptr1, out_ptr0, ws_ptr, semaphores_ptr, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr, RSPLIT : tl.constexpr):
    xnumel = 1
    rnumel = 1048576
    rsplit_id = tl.program_id(0)
    num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK
    rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK
    rsplit_start = rsplit_chunk * rsplit_id
    rsplit_end = rsplit_chunk * (rsplit_id + 1)
    xoffset = tl.program_id(1) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp4 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(rsplit_start, rsplit_end, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 + tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(rmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    if RSPLIT > 1:
        tmp4_ws = (ws_ptr + 0).to(tl.pointer_type(tl.float32))
        tl.store(tmp4_ws + (xindex * RSPLIT + rsplit_id), tmp4, None)
    if RSPLIT > 1:
        triton_helpers.gpu_barrier(semaphores_ptr + (2 * tl.program_id(1) + 0), RSPLIT, True)
    if RSPLIT > 1:
        tmp4_peers = tl.load(tmp4_ws + (xindex * RSPLIT + tl.arange(0, RSPLIT)[None,:]), None, eviction_policy='evict_first')
        tmp4 = tl.sum(tmp4_peers, 1)[:, None]
    if rsplit_id == (0 % RSPLIT):
        tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
```
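For reference, a minimal sketch of how one might exercise this codegen path, assuming the feature is gated behind an inductor config flag named `triton.cooperative_reductions` (the knob name is an assumption; check the PR for the exact flag):
```py
import torch
import torch._inductor.config as inductor_config

# Assumed knob name; the PR gates cooperative reductions behind a config flag.
inductor_config.triton.cooperative_reductions = True

@torch.compile
def fused_add_sum(x, y):
    # One large reduction: RSPLIT blocks each reduce a chunk, publish partial
    # sums to a workspace (ws_ptr), synchronize via gpu_barrier, then combine.
    return (x + y).sum()

x = torch.randn(2**20, device="cuda")  # 1048576 elements, matching rnumel above
y = torch.randn(2**20, device="cuda")
print(fused_add_sum(x, y))
```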
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137756
Approved by: https://github.com/eellison
ghstack dependencies: #138970
2024-10-27 16:31:38 +00:00
Alex Baden
487873f7ca
[Inductor]: Support updated Triton AttrsDescriptor ( #137757 )
...
The Triton `AttrsDescriptor` object was refactored in https://github.com/triton-lang/triton/pull/4734 . These changes add support for the new `AttrsDescriptor` while maintaining backwards compatibility with the existing version. The main changes are different names for initializing the descriptor parameters, and creation via a static method instead of the class constructor.
Depends on #137458, which removes some unused logic around the old descriptor. Those changes make this PR cleaner, but if for some reason that old logic is still used, I can make adjustments.
Use of the new `AttrsDescriptor` depends on https://github.com/triton-lang/triton/pull/4888
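A hedged sketch of the kind of compatibility shim described above; the import paths and the static-method name are assumptions for illustration, not the exact inductor code:
```py
# Support both AttrsDescriptor generations. Import paths and the
# static-method name below are assumptions for illustration.
try:
    from triton.backends.compiler import AttrsDescriptor  # newer Triton layout
except ImportError:
    from triton.compiler.compiler import AttrsDescriptor  # older Triton layout

def make_attrs_descriptor(divisible_by_16=(), equal_to_1=()):
    if hasattr(AttrsDescriptor, "from_dict"):
        # New API: create via a static method instead of the constructor
        return AttrsDescriptor.from_dict(
            {"divisible_by_16": divisible_by_16, "equal_to_1": equal_to_1}
        )
    # Old API: direct keyword-argument constructor
    return AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1)
```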
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137757
Approved by: https://github.com/jansel
2024-10-15 19:34:59 +00:00
Alex Baden
39d21ed803
[Inductor] Update AttrsDescriptor instantiation for Triton changes ( #137458 )
...
The `AttrsDescriptor` class has been present in Triton for almost a year now (introduced [here](72c9833927 )), so we should be able to rely on it existing. I am in the process of supporting the new `AttrsDescriptor` class, and @jansel suggested I split the changes to the existing class out separately to make sure nothing breaks when removing the legacy attribute descriptor attributes.
Initially I attempted to remove the branching around detecting whether `AttrsDescriptor` exists, but that breaks because PyTorch must build without Triton. So I went back and updated for the naming introduced in the commit linked above, and also removed two unused attributes, `divisible_by_8` and `ids_to_fold`, which were removed in Feb 2024 (https://github.com/triton-lang/triton/pull/3122 and https://github.com/triton-lang/triton/pull/3080 respectively).
With these changes only the internal workings of the `AttrsDescriptor` class will differ between supported Triton versions, but the data stored will remain consistent.
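For illustration, a hedged sketch of keeping the stored data consistent across Triton versions; the attribute names are illustrative, and the removed fields are simply no longer read:
```py
# Normalize a descriptor to plain data regardless of Triton version.
# divisible_by_8 and ids_to_fold are gone in newer Triton, so only the
# surviving fields are read.
def descriptor_to_dict(attrs_descriptor) -> dict:
    return {
        "divisible_by_16": tuple(getattr(attrs_descriptor, "divisible_by_16", ())),
        "equal_to_1": tuple(getattr(attrs_descriptor, "equal_to_1", ())),
    }
```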
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137458
Approved by: https://github.com/jansel
2024-10-14 20:20:29 +00:00
xinan.lin
0a26851601
[Inductor] Handle device property warp_size being None but used on XPU. ( #136834 )
...
Fix #136820
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136834
Approved by: https://github.com/EikanWang , https://github.com/jansel
2024-09-30 02:08:45 +00:00
David Berard
9c2c61d2dd
[inductor] ELEMENTS_PER_WARP_32 -> ONE_ELEMENT_PER_THREAD ( #136472 )
...
AMD devices have 64 threads per warp; this PR makes the handling of "ELEMENTS_PER_WARP_32" generic, using DeviceProperties.warp_size to determine the warp size instead of hard-coding it as 32. It also renames the enum value. Added a unit test for this.
Note: I left the old enum option (ELEMENTS_PER_WARP_32) as is instead of renaming it. I'm not sure whether we should expect caches to get invalidated here; if this concern is valid, then there's a risk that if the enum were renamed, some model could use cached inductor code that still references "ELEMENTS_PER_WARP_32", which would no longer exist.
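A hedged sketch of the generic handling, with illustrative names rather than the exact inductor internals:
```py
from enum import Enum, auto

class HeuristicType(Enum):
    ONE_ELEMENT_PER_THREAD = auto()  # generic successor to ELEMENTS_PER_WARP_32

def elements_per_block(heuristic, num_warps: int, device_props) -> int:
    if heuristic is HeuristicType.ONE_ELEMENT_PER_THREAD:
        # Previously hard-coded as num_warps * 32; warp size is 32 on NVIDIA
        # and 64 on AMD, so read it from the device properties instead.
        warp_size = device_props.warp_size or 32  # fall back if unreported
        return num_warps * warp_size
    raise NotImplementedError(heuristic)
```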
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136472
Approved by: https://github.com/jansel
2024-09-25 18:21:09 +00:00
Jack Taylor
a15774563b
[ROCm] Enable ROCm support for inductor's dynamic_rblock_scaling ( #129663 )
...
As of ROCm 6.1, [hipDeviceProp_t::regsPerMultiprocessor](https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/structhip_device_prop__t.html#a7390d5b180d63978c81aa971060270b4 ) is now available, allowing us to enable this attribute on ROCm.
```
>>> torch.cuda.get_device_properties(0)
_CudaDeviceProperties(name='AMD Instinct MI250X/MI250', major=9, minor=0, gcnArchName='gfx90a:sramecc+:xnack-', total_memory=65520MB, multi_processor_count=104)
>>> torch.cuda.get_device_properties(0).regs_per_multiprocessor
65536
```
With https://github.com/triton-lang/triton/pull/3962 we can extract n_regs and n_spills from a Triton binary built with the AMD backend, allowing us to enable inductor's dynamic_rblock_scaling on ROCm, initially implemented in https://github.com/pytorch/pytorch/pull/115094 (see the sketch after the list below).
Leaving this in draft until the following PRs have landed:
- https://github.com/pytorch/pytorch/pull/129361 to bump the triton commit pin
- https://github.com/pytorch/pytorch/pull/128449 to allow us to grab warp_size from device properties instead of hard coding 64 on ROCm.
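A hedged sketch of what dynamic_rblock_scaling does with these numbers; the names and the occupancy threshold are illustrative:
```py
# Shrink the reduction block when register pressure would hurt occupancy.
# n_regs / n_spills come from the compiled Triton binary; the register-file
# size comes from the device properties shown above.
def pick_rblock(kernel_bin, device_props, rblock: int, threads_per_block: int) -> int:
    n_regs = getattr(kernel_bin, "n_regs", None)
    n_spills = getattr(kernel_bin, "n_spills", None)
    if n_regs is None or n_spills is None:
        return rblock  # backend does not report register usage
    blocks_per_sm = device_props.regs_per_multiprocessor // max(1, n_regs * threads_per_block)
    if n_spills > 0 or blocks_per_sm < 2:  # illustrative threshold
        return max(1, rblock // 2)
    return rblock
```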
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129663
Approved by: https://github.com/jansel , https://github.com/shunting314
2024-09-13 16:45:39 +00:00
Yichen Yan
c0d2f991b1
Increase TRITON_MAX_BLOCK['X'] ( #135181 )
...
Fixes #135028
As title, increase `TRITON_MAX_BLOCK['X']` to 4096 and fix an error, thanks to @Chillee: https://github.com/pytorch/pytorch/pull/133300/files#r1744706189
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135181
Approved by: https://github.com/jansel
2024-09-10 05:54:37 +00:00
PyTorch MergeBot
5f981388ec
Revert "[ROCm] Enable ROCm support for inductor's dynamic_rblock_scaling ( #129663 )"
...
This reverts commit d7a78ec8b9 .
Reverted https://github.com/pytorch/pytorch/pull/129663 on behalf of https://github.com/atalman due to Breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/129663#issuecomment-2240011143 ))
2024-07-19 19:46:26 +00:00
Jack Taylor
d7a78ec8b9
[ROCm] Enable ROCm support for inductor's dynamic_rblock_scaling ( #129663 )
...
As of ROCm 6.1, [hipDeviceProp_t::regsPerMultiprocessor](https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/structhip_device_prop__t.html#a7390d5b180d63978c81aa971060270b4 ) is now available, allowing us to enable this attribute on ROCm.
```
>>> torch.cuda.get_device_properties(0)
_CudaDeviceProperties(name='AMD Instinct MI250X/MI250', major=9, minor=0, gcnArchName='gfx90a:sramecc+:xnack-', total_memory=65520MB, multi_processor_count=104)
>>> torch.cuda.get_device_properties(0).regs_per_multiprocessor
65536
```
With https://github.com/triton-lang/triton/pull/3962 we can extract n_regs and n_spills from a Triton binary built with the AMD backend, allowing us to enable inductor's dynamic_rblock_scaling on ROCm, initially implemented in https://github.com/pytorch/pytorch/pull/115094
Leaving this in draft until the following PRs have landed:
- https://github.com/pytorch/pytorch/pull/129361 to bump the triton commit pin
- https://github.com/pytorch/pytorch/pull/128449 to allow us to grab warp_size from device properties instead of hard coding 64 on ROCm.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129663
Approved by: https://github.com/jansel , https://github.com/shunting314
2024-07-19 09:45:03 +00:00
Xuehai Pan
973037be6a
[BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() ( #130199 )
...
This PR changes empty collection factory calls to Python literals:
- `list()` -> `[]`
- `tuple()` -> `()`
- `dict()` -> `{}`
The Python literals are more performant and safer. For example, the bytecode for building an empty dictionary:
```bash
$ python3 -m dis - <<EOS
import collections
d1 = {}
d2 = dict()
dict = collections.OrderedDict
d3 = dict()
EOS
```
```text
  0           0 RESUME                   0

  1           2 LOAD_CONST               0 (0)
              4 LOAD_CONST               1 (None)
              6 IMPORT_NAME              0 (collections)
              8 STORE_NAME               0 (collections)

  3          10 BUILD_MAP                0
             12 STORE_NAME               1 (d1)

  4          14 PUSH_NULL
             16 LOAD_NAME                2 (dict)
             18 CALL                     0
             26 STORE_NAME               3 (d2)

  6          28 LOAD_NAME                0 (collections)
             30 LOAD_ATTR                8 (OrderedDict)
             50 STORE_NAME               2 (dict)

  7          52 PUSH_NULL
             54 LOAD_NAME                2 (dict)
             56 CALL                     0
             64 STORE_NAME               5 (d3)
             66 RETURN_CONST             1 (None)
```
The dict literal `{}` compiles to a single bytecode instruction, `BUILD_MAP`, while the factory call `dict()` takes three: `PUSH_NULL + LOAD_NAME + CALL`. Also, the factory call is not safe if users override the `dict` name in `locals` or `globals` (see the example of replacing it with `OrderedDict` above).
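The same hazard as a runnable snippet, mirroring the `OrderedDict` rebinding from the dis example:
```py
from collections import OrderedDict

dict = OrderedDict  # rebinding the builtin name, as in the dis example

d2 = dict()  # silently builds an OrderedDict now
d1 = {}      # the literal is immune to the rebinding
print(type(d2), type(d1))  # <class 'collections.OrderedDict'> <class 'dict'>
```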
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130199
Approved by: https://github.com/malfet
2024-07-11 17:30:28 +00:00
Jason Ansel
0abcca85b7
[halide-backend] Support manual schedules ( #129321 )
...
Currently using this for some by-hand hacking, but might need to implement our own scheduler later.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129321
Approved by: https://github.com/shunting314
2024-07-03 05:56:40 +00:00
PyTorch MergeBot
a83eaf1c3a
Revert "[halide-backend] Support manual schedules ( #129321 )"
...
This reverts commit 9ae78a578c .
Reverted https://github.com/pytorch/pytorch/pull/129321 on behalf of https://github.com/jeanschmidt due to Reverting, as it is required to do so in order to revert #129320 ([comment](https://github.com/pytorch/pytorch/pull/129321#issuecomment-2200345664 ))
2024-07-01 14:42:33 +00:00
Jason Ansel
9ae78a578c
[halide-backend] Support manual schedules ( #129321 )
...
Currently using this for some by-hand hacking, but might need to implement our own scheduler later.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129321
Approved by: https://github.com/shunting314
ghstack dependencies: #126417 , #129025 , #129026 , #127506 , #129036 , #129320
2024-06-29 14:06:28 +00:00
Jason Ansel
4cb8cb04a7
[halide-backend] Enable bfloat16 support ( #129036 )
...
Requires https://github.com/halide/Halide/pull/8255
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129036
Approved by: https://github.com/shunting314 , https://github.com/eellison
ghstack dependencies: #126417 , #129025 , #129026 , #127506
2024-06-29 14:06:25 +00:00
Jason Ansel
da5f37515e
[halide-backend] Generate standalone runtime ( #129025 )
...
This puts the Halide runtime in a global shared object, rather than copying it into each kernel. Having many copies of the runtime causes many issues with CUDA.
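A minimal sketch of the single-shared-runtime idea; the library name and loading scheme are assumptions for illustration:
```py
import ctypes

_halide_runtime = None

def load_halide_runtime(path: str = "libHalideRuntime.so"):  # assumed name
    # Load the runtime once per process. RTLD_GLOBAL exposes its symbols to
    # every kernel shared object loaded afterwards, so kernels can be built
    # without a private embedded copy of the runtime.
    global _halide_runtime
    if _halide_runtime is None:
        _halide_runtime = ctypes.CDLL(path, mode=ctypes.RTLD_GLOBAL)
    return _halide_runtime
```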
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129025
Approved by: https://github.com/shunting314 , https://github.com/eellison
ghstack dependencies: #126417
2024-06-29 14:06:12 +00:00
Jason Ansel
e34b7e6af3
[halide-backend] Initial implementation of HalideKernel and HalideScheduling ( #126417 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126417
Approved by: https://github.com/shunting314 , https://github.com/eellison
2024-06-29 14:06:08 +00:00
PyTorch MergeBot
1a54bb0f96
Revert "[halide-backend] Initial implementation of HalideKernel and HalideScheduling ( #126417 )"
...
This reverts commit 4f9399bd0d .
Reverted https://github.com/pytorch/pytorch/pull/126417 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/126417#issuecomment-2186999121 ))
2024-06-24 16:50:15 +00:00
PyTorch MergeBot
063facf352
Revert "[halide-backend] Generate standalone runtime ( #129025 )"
...
This reverts commit 10c64c3b49 .
Reverted https://github.com/pytorch/pytorch/pull/129025 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/129025#issuecomment-2186995467 ))
2024-06-24 16:47:25 +00:00
Jason Ansel
10c64c3b49
[halide-backend] Generate standalone runtime ( #129025 )
...
This puts the Halide runtime in a global shared object, rather than copying it into each kernel. Having many copies of the runtime causes many issues with CUDA.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129025
Approved by: https://github.com/shunting314 , https://github.com/eellison
ghstack dependencies: #126417
2024-06-22 17:39:52 +00:00
Jason Ansel
4f9399bd0d
[halide-backend] Initial implementation of HalideKernel and HalideScheduling ( #126417 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126417
Approved by: https://github.com/shunting314 , https://github.com/eellison
2024-06-22 17:39:52 +00:00
Aaron Orenstein
afe15d2d2f
Flip default value for mypy disallow_untyped_defs [3/11] ( #127840 )
...
See #127836 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127840
Approved by: https://github.com/oulgen
2024-06-08 18:28:01 +00:00
Jason Ansel
b516de8cac
[halide-backend] Add HalideCodeCache ( #126416 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126416
Approved by: https://github.com/shunting314
ghstack dependencies: #126631 , #126655
2024-05-22 06:52:50 +00:00
Sam Larsen
254128c16e
[inductor] Remove usage of device_interface from _inductor.runtime ( #124592 )
...
Differential Revision: [D56723770](https://our.internmc.facebook.com/intern/diff/D56723770 )
Co-authored-by: Sam Larsen <slarsen@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124592
Approved by: https://github.com/masnesral
2024-04-30 16:54:16 +00:00
PyTorch MergeBot
f6ce94dca5
Revert "[inductor] Remove usage of device_interface from _inductor.runtime ( #124592 )"
...
This reverts commit 5d45eb77f1 .
Reverted https://github.com/pytorch/pytorch/pull/124592 on behalf of https://github.com/jeanschmidt due to breaking internal tests, check D56522594 ([comment](https://github.com/pytorch/pytorch/pull/124592#issuecomment-2076957668 ))
2024-04-25 11:28:23 +00:00
Jason Ansel
5d45eb77f1
[inductor] Remove usage of device_interface from _inductor.runtime ( #124592 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124592
Approved by: https://github.com/masnesral
2024-04-23 17:51:25 +00:00
Jason Ansel
0093735ccd
[inductor] Use compile time config values in runtime ( #124561 )
...
This removes usage of torch._inductor.config from `torch._inductor.runtime`, fixing two issues (see the sketch after this list):
1) If configs change, we should really use the compile-time values
2) In compile workers, we want to use the parent process's config
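A hedged sketch of the pattern; the config key is illustrative, not the real set of values that gets snapshotted:
```py
from torch._inductor import config as inductor_config

def freeze_config() -> dict:
    # Compile time: copy the needed values out of the live config object.
    return {"max_autotune": inductor_config.max_autotune}  # illustrative key

def run_kernel(kernel, frozen: dict):
    # Runtime (possibly a compile worker whose live config differs): consume
    # the frozen snapshot; never re-read torch._inductor.config here.
    if frozen["max_autotune"]:
        ...
```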
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124561
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552 , #124553 , #124557 , #124559 , #124560 , #124569
2024-04-22 18:46:40 +00:00
Jason Ansel
bb8815bc31
[inductor] Refactor runtime files into torch._inductor.runtime (part 2) ( #124553 )
...
I am planning to make the compile_worker process not import torch so it can start up much faster. This stack is prep for that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124553
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552
2024-04-22 18:46:20 +00:00
PyTorch MergeBot
56714cb497
Revert "[inductor] Refactor runtime files into torch._inductor.runtime (part 2) ( #124553 )"
...
This reverts commit f4d47f5bbb .
Reverted https://github.com/pytorch/pytorch/pull/124553 on behalf of https://github.com/jeanschmidt due to There are internal breakages, already discussed with author and he'll FF ([comment](https://github.com/pytorch/pytorch/pull/124552#issuecomment-2070548223 ))
2024-04-22 18:28:05 +00:00
PyTorch MergeBot
30dec1da84
Revert "[inductor] Use compile time config values in runtime ( #124561 )"
...
This reverts commit 3af12447f8 .
Reverted https://github.com/pytorch/pytorch/pull/124561 on behalf of https://github.com/jeanschmidt due to There are internal breakages, already discussed with author and he'll FF ([comment](https://github.com/pytorch/pytorch/pull/124561#issuecomment-2070537634 ))
2024-04-22 18:24:38 +00:00
Jason Ansel
3af12447f8
[inductor] Use compile time config values in runtime ( #124561 )
...
This removes usage of torch._inductor.config from `torch._inductor.runtime`, fixing two issues:
1) If configs change, we should really use the compile-time values
2) In compile workers, we want to use the parent process's config
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124561
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552 , #124553 , #124557 , #124559 , #124560 , #124569
2024-04-22 04:51:30 +00:00
Jason Ansel
f4d47f5bbb
[inductor] Refactor runtime files into torch._inductor.runtime (part 2) ( #124553 )
...
I am planning to make the compile_worker process not import torch so it can start up much faster. This stack is prep for that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124553
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552
2024-04-22 04:51:09 +00:00