# Add workspace buffer support for Triton templates
## Summary
Adds support for Triton templates to allocate and use temporary workspace buffers.
## Key Changes
- Add `WorkspaceArg` support in Triton template system
- Automatic workspace allocation/deallocation around kernel execution
- Zero-initialization support for workspace buffers
- Seamless integration with existing tensor management
## Example Usage
```python
def generate(self, ...):
    workspace_arg = WorkspaceArg(
        count=1024 * 1024,  # 1 MB workspace
        zero_fill=True,     # zero-initialized
    )
    return TritonTemplateCaller(..., workspace_arg=workspace_arg)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138050
Approved by: https://github.com/Chillee, https://github.com/eellison
Summary:
- This diff introduces a `dtype` attribute on `TritonCSEVariable` and a dtype propagation helper function that infers the output dtype from the input dtypes for each op.
- A follow-up diff will use this `dtype` information in `TritonCSEVariable` to perform dtype-aware codegen.
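A rough sketch of what such a propagation helper could look like (the helper name and op list are illustrative, not the actual Inductor implementation):
```python
import torch

# Hypothetical sketch of a dtype propagation rule, assuming each input carries
# a .dtype attribute (as TritonCSEVariable would after this diff). This is not
# the actual Inductor helper; names and the op list are illustrative.
def promote_op_dtype(op_name: str, *inputs) -> torch.dtype:
    input_dtypes = [x.dtype for x in inputs if hasattr(x, "dtype")]
    if op_name in ("eq", "ne", "lt", "le", "gt", "ge", "isnan"):
        return torch.bool  # comparison ops always produce booleans
    result = input_dtypes[0]
    for dt in input_dtypes[1:]:
        result = torch.promote_types(result, dt)  # standard type promotion
    return result
```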
Test Plan: CI
Differential Revision: D61815079
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136778
Approved by: https://github.com/eellison, https://github.com/blaine-rister
The Triton `AttrsDescriptor` object was refactored in https://github.com/triton-lang/triton/pull/4734. These changes add support for the new `AttrsDescriptor` while maintaining backwards compatibility with the existing version. The main changes are different names for the initialization of the descriptor parameters, and creation via a static method instead of the class constructor.
Depends on #137458 which removes some unused logic around the old descriptor. Those changes make this PR cleaner, but if for some reason that old logic is still used I can make adjustments.
Use of the new `AttrsDescriptor` depends on https://github.com/triton-lang/triton/pull/4888
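For illustration, the kind of version shim this implies might look like the following (a hedged sketch; the static-method name and keyword names are assumptions, not the exact Triton API):
```python
# Illustrative compatibility shim: prefer a static-method constructor when the
# installed Triton provides one, otherwise fall back to the legacy constructor.
# The method name "from_dict" and the keyword names are placeholders.
def make_attrs_descriptor(attrs_descriptor_cls, divisible_by_16=(), equal_to_1=()):
    if hasattr(attrs_descriptor_cls, "from_dict"):
        # Newer Triton: construction via a static method
        return attrs_descriptor_cls.from_dict(
            {"divisible_by_16": list(divisible_by_16), "equal_to_1": list(equal_to_1)}
        )
    # Older Triton: direct construction with the legacy keyword arguments
    return attrs_descriptor_cls(
        divisible_by_16=tuple(divisible_by_16), equal_to_1=tuple(equal_to_1)
    )
```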
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137757
Approved by: https://github.com/jansel
The `AttrsDescriptor` class has been present in Triton for almost a year now (introduced [here](72c9833927)), so we should be able to rely on it existing. I am in the process of supporting the new `AttrsDescriptor` class, and @jansel suggested I split the changes to the existing class out separately to make sure nothing breaks when removing the legacy attribute descriptor attributes.
Initially I attempted to remove the branching around detecting whether `AttrsDescriptor` exists, but that breaks because PyTorch must still build without Triton. So, I went back and updated to the naming introduced in the commit linked above, and also removed two unused attributes, `divisible_by_8` and `ids_to_fold`, which were dropped in Feb 2024 (https://github.com/triton-lang/triton/pull/3122 and https://github.com/triton-lang/triton/pull/3080 respectively).
With these changes only the internal workings of the `AttrsDescriptor` class will differ between supported Triton versions, but the data stored will remain consistent.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137458
Approved by: https://github.com/jansel
Previously, instances of `SchedulerNode` and `FusedSchedulerNode` would explicitly check whether the compilation target is Triton when codegen'ing debug strings. Generating debug Triton code is instead implemented as a callback set on scheduler nodes by `TritonScheduling`. This makes the codegen more device-agnostic and allows schedulers to customise the codegen output, rather than having it closely coupled to the debug string codegen.
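A minimal sketch of the callback pattern described here (class and attribute names are illustrative, not the actual scheduler API):
```python
# Illustrative sketch: move device-specific debug codegen behind a callback
# set by the backend scheduling class, instead of an explicit target check
# inside the scheduler node. Names are placeholders.
class SchedulerNodeSketch:
    def __init__(self):
        self._debug_str_extra = None  # backend-provided callback, if any

    def set_debug_str_callback(self, fn):
        self._debug_str_extra = fn

    def debug_str(self) -> str:
        lines = ["<node metadata would go here>"]
        if self._debug_str_extra is not None:
            # The backend (e.g. TritonScheduling) decides what extra code to show.
            lines.append(self._debug_str_extra())
        return "\n".join(lines)
```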
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135015
Approved by: https://github.com/jansel
Summary: Previously, `triton_reshape` would generate code containing the `Min` keyword, which is incorrect. This diff updates the `triton_reshape` function to properly expand the `Min` keyword into an expression using `<`.
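For illustration, expanding `Min` into a `<`-based expression amounts to something like the following (a hedged sketch operating on expression strings; the real change lives in the `triton_reshape` codegen path):
```python
# Hedged sketch: rewrite min(a, b) using only comparison operators, matching
# the style of expressions seen in Inductor-generated Triton code,
# e.g. (a) * ((a) <= (b)) + (b) * ((b) < (a)).
def expand_min(a: str, b: str) -> str:
    return f"(({a}) * (({a}) <= ({b})) + ({b}) * (({b}) < ({a})))"

print(expand_min("8", "XBLOCK"))
# ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))
```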
Test Plan:
```
buck2 run @//mode/{opt,mtia,inplace} //glow/fb/fx/fba/tests:test_fba_inductor -- -r test_Min_keyword_in_block_shape
```
Differential Revision: D63850158
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137357
Approved by: https://github.com/blaine-rister, https://github.com/eellison
Adds lowering for `aten.searchsorted`. This entails:
1. Adding support for multi-dimensional bucket tensors to `ops.bucketize`.
2. Adding support for striding to `ops.bucketize`.
3. Adding support for sorting tensors to `ops.bucketize`.
4. Adding a lowering for `aten.searchsorted.Tensor`.
5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors (a sketch follows this list).
6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions.
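A rough sketch of the scalar decomposition from item 5 (hedged; the actual decomposition lives in Inductor's decomposition table, and the helper name here is made up):
```python
import torch

# Hedged sketch of decomposing searchsorted with a scalar value into the
# tensor overload: wrap the scalar, call the tensor path, return a 0-d result.
def searchsorted_scalar_sketch(sorted_sequence: torch.Tensor, scalar, **kwargs):
    value = torch.tensor([scalar], dtype=sorted_sequence.dtype,
                         device=sorted_sequence.device)
    return torch.searchsorted(sorted_sequence, value, **kwargs).squeeze(0)

boundaries = torch.tensor([1.0, 3.0, 5.0, 7.0])
print(searchsorted_scalar_sketch(boundaries, 4.0))  # tensor(2)
```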
Closes #135873
Differential Revision: [D63766514](https://our.internmc.facebook.com/intern/diff/D63766514)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135701
Approved by: https://github.com/amjames, https://github.com/eellison, https://github.com/davidberard98
This is a retry of https://github.com/pytorch/pytorch/pull/136594, which is having trouble landing.
Summary: We have an internal report of a Triton compiler error `ValueError: Cannot broadcast, rank mismatch: [1], [1, 2048]` coming from a line like this:
`tmp25 = tl.broadcast_to(((tl.full([1], 1.00000000000000, tl.float64)) + ((ks0 // 3278).to(tl.float64))) / (((tl.full([1], 0.500000000000000, tl.float64))*(libdevice.sqrt((1 + ((ks0 // 3278)*(ks0 // 3278)) + ((-2)*(ks0 // 3278))).to(tl.float64).to(tl.float32)))) + ((tl.full([1], 0.500000000000000, tl.float64))*((1 + (ks0 // 3278)).to(tl.float64)))), [XBLOCK, RBLOCK])`
https://github.com/pytorch/pytorch/pull/135260 is the cause, presumably because we turn a constant into a 1-element tensor with: `(tl.full([1], const, tl.float64))`. It looks like changing the syntax to `(tl.full([], const, tl.float64))` gives us what we want?
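For illustration, the codegen change boils down to emitting a 0-d Triton scalar instead of a rank-1, one-element tensor (a hedged sketch; the helper name is made up):
```python
# Hedged sketch of the emitted constant: a 0-d scalar broadcasts cleanly,
# while the rank-1 form can trip tl.broadcast_to's rank check downstream.
def triton_constant(value, triton_dtype: str = "tl.float64") -> str:
    # Before: f"(tl.full([1], {value}, {triton_dtype}))"  -> shape [1] tensor
    # After:  f"(tl.full([], {value}, {triton_dtype}))"   -> 0-d scalar
    return f"(tl.full([], {value}, {triton_dtype}))"

print(triton_constant("1.00000000000000"))
# (tl.full([], 1.00000000000000, tl.float64))
```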
Differential Revision: [D63540693](https://our.internmc.facebook.com/intern/diff/D63540693)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136858
Approved by: https://github.com/atalman
Related issue: #125077
### Feature
Inductor tries to remove dimensions with stride 0 from block pointers. Rather than loading with stride 0, it's more efficient to load a smaller block pointer, then use `tl.broadcast_to` to broadcast it up to the desired size. This already worked for simpler block pointers, but it was disabled for more complex block pointers which used `tl.reshape` to change the dimensionality after loading.
This PR generalizes the approach to work for all block pointers. The idea is to first reshape, adding singleton dimensions, then broadcast those singletons up to something larger, then reshape again to the final output shape. For readability, we emit this code only if it actually does something. Simpler loads will just have `tl.load`.
Here's an example of a complicated kernel that uses `reshape` -> `load` -> `reshape`. (The first reshape is actually the slice `[None,None,:]`).
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 64
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = xindex
x1 = (xindex // 8)
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[((7 + XBLOCK) // 8)], order=[0], offsets=[(xoffset // 8)]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])
tmp2 = tmp0 + tmp1
tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tmp2.to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```
Before this PR, we would have stride-0 dimensions:
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 64
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = xindex
x1 = (xindex // 8)
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
tmp1 = tl.reshape(tl.load(tl.make_block_ptr(in_ptr1, shape=[8, 1, 8], strides=[8, 0, 0], block_shape=[((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))], order=[2, 1, 0], offsets=[(xoffset // 8), 0, xoffset % 8]), boundary_check=[0], eviction_policy='evict_last'), [XBLOCK])
tmp2 = tmp0 + tmp1
tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```
Here's a simpler example where we use 2D tiling. In this case we don't actually need the broadcast; it is implied by a slice adding a new singleton dimension. This code is not changed by this PR, but it's important to confirm that we don't accidentally insert unnecessary broadcasts.
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 8
xnumel = 8
yoffset = tl.program_id(1) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
x1 = xindex
y0 = yindex
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1])
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[None, :]
tmp2 = tmp0 + tmp1
tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tmp2.to(tl.float32), boundary_check=[0, 1])
''', device_str='cuda')
```
### Test Plan
Added a new expecttest to check the emitted code for broadcast addition. Looking at the test, we can see that stride 0 dimensions are removed. (This test generated the example kernels in the previous section.)
This change also removed a stride-0 dimension in an existing block pointer test. I updated the expected code accordingly.
Bonus: I noticed that the test parametrization for `config.prefer_nd_tiling` wasn't working as intended. It ended up always setting this option to `True`. Fixed it so we get the intended test coverage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135557
Approved by: https://github.com/shunting314, https://github.com/jansel
Co-authored-by: Yueming Hao <yhao@meta.com>
Summary: We have an internal report of a Triton compiler error `ValueError: Cannot broadcast, rank mismatch: [1], [1, 2048]` coming from a line like this:
`tmp25 = tl.broadcast_to(((tl.full([1], 1.00000000000000, tl.float64)) + ((ks0 // 3278).to(tl.float64))) / (((tl.full([1], 0.500000000000000, tl.float64))*(libdevice.sqrt((1 + ((ks0 // 3278)*(ks0 // 3278)) + ((-2)*(ks0 // 3278))).to(tl.float64).to(tl.float32)))) + ((tl.full([1], 0.500000000000000, tl.float64))*((1 + (ks0 // 3278)).to(tl.float64)))), [XBLOCK, RBLOCK])`
https://github.com/pytorch/pytorch/pull/135260 is the cause, presumably because we turn a constant into a 1-element tensor with: `(tl.full([1], const, tl.float64))`. It looks like changing the syntax to `(tl.full([], const, tl.float64))` gives us what we want?
Differential Revision: [D63465169](https://our.internmc.facebook.com/intern/diff/D63465169)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136594
Approved by: https://github.com/mengluy0125, https://github.com/jansel
AMD devices have a warp (wavefront) size of 64; this PR makes the handling of "ELEMENTS_PER_WARP_32" generic by using `DeviceProperties.warp_size` to determine the warp size instead of hard-coding it as 32. It also renames the enum value. Added a unit test for this.
Note: I left the old enum option (`ELEMENTS_PER_WARP_32`) in place rather than renaming it. I'm not sure whether we should expect caches to get invalidated here; if that concern is valid, there's a risk that this code gets updated while some model still uses cached Inductor code referencing `ELEMENTS_PER_WARP_32`, which would then no longer exist.
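A hedged sketch of the generic handling (the function name and fallback are placeholders; the real logic reads `DeviceProperties.warp_size` inside Inductor's Triton heuristics):
```python
# Illustrative only: derive the per-warp element heuristic from the device's
# reported warp size instead of hard-coding 32.
def elements_per_warp(device_props) -> int:
    warp_size = getattr(device_props, "warp_size", None)
    if warp_size is None:
        warp_size = 32  # conservative default if the property is unavailable
    return warp_size    # 32 on most NVIDIA GPUs, 64 on AMD (wavefront size)
```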
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136472
Approved by: https://github.com/jansel
Summary: The previous PR forgot to change two other places that also create `constants` and `signature`.
Test Plan:
Imported from GitHub, without a `Test Plan:` line.
Differential Revision: D63027728
Pulled By: Myrthan
Co-authored-by: Jokeren <robinho364@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136514
Approved by: https://github.com/jansel
Co-authored-by: Jokeren <robinho364@gmail.com>
Fixes#131337
- add `arg_type` for `workspace_arg`; the type is consistent with the one used in `generate_workspace_allocation()`.
- do not generate example tensors for `workspace`, and use `generate_workspace_allocation()` instead.
- add workspace allocation generation code to `kernel_autotune_calls`. e.g.
```python
workspace = empty_strided_cuda((1280, ), (1, ), torch.uint8)
workspace.zero_()
.....
triton_spl_fused_add_cumprod_0.run(buf2, arg0_1, arg1_1, workspace, 1, 10000, grid=split_scan_grid(1, 10000), stream=stream0)
del buf2, arg0_1, arg1_1, workspace
```
- add `empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda` to the header of triton autotune code.
The generated cpp has lines like the ones below, so we also implement `zero_()` for `AtenTensorHandle`.
```cpp
static constexpr int64_t int_array_0[] = {1280L, };
static constexpr int64_t int_array_1[] = {1L, };
AtenTensorHandle workspace_handle;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, int_array_0, int_array_1, cached_torch_dtype_uint8, cached_torch_device_type_cuda, 0, &workspace_handle));
RAIIAtenTensorHandle workspace(workspace_handle);
workspace.zero_();
```
- Fix handling of `grid_fn` for grid computation: pass "RBLOCK" to `split_scan_grid`.
- Fix dynamic shapes:
Without the fix, we generate code like `workspace = empty_strided_cuda((32*((255 + s0) // 256), ), (1, ), torch.uint8)` during Triton autotuning, where `s0` is not defined.
The fix is to use `V.graph.sizevars.size_hint(nbytes)` to realize the workspace size for the Triton autotune code. Note that we only realize it for the Triton autotune code, not for the cpp CUDA code (a sketch follows the code comparison below).
- We also generate slightly different cpp code depending on whether `abi_compatible` is turned on.
```cpp
RAIIAtenTensorHandle workspace(workspace_handle);
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get()));
```
vs
```cpp
at::Tensor workspace = at::detail::empty_strided_cuda({8L*(c10::div_floor_integer(static_cast<int64_t>((255L + s0)), static_cast<int64_t>(256L))), }, {1L, }, at::kByte, c10::DeviceType::CUDA);
workspace.zero_();
```
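As referenced in the dynamic-shapes item above, the autotune-only realization can be sketched roughly like this (`V.graph.sizevars.size_hint` is the helper named above; everything else is illustrative):
```python
# Rough sketch: a symbolic workspace size (e.g. one containing s0) cannot be
# emitted verbatim into the autotune harness, where s0 is undefined. For the
# autotune code path only, realize a concrete hint; the cpp/CUDA wrapper keeps
# the symbolic expression.
def autotune_workspace_nbytes(nbytes, sizevars):
    if isinstance(nbytes, int):
        return nbytes
    return sizevars.size_hint(nbytes)  # e.g. V.graph.sizevars.size_hint(nbytes)
```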
Test Plan:
```
TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda
python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper
python test/inductor/test_cuda_cpp_wrapper.py DynamicShapesCudaWrapperCudaTests.test_consecutive_split_cumprod_cuda_dynamic_shapes_cuda_wrapper
TORCHINDUCTOR_ABI_COMPATIBLE=1 python test/inductor/test_cuda_cpp_wrapper.py TestCudaWrapper.test_consecutive_split_cumprod_cuda_cuda_wrapper
TORCHINDUCTOR_CPP_WRAPPER=1 python test/inductor/test_torchinductor.py -k GPUTests.test_consecutive_split_cumprod_cuda
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135552
Approved by: https://github.com/desertfire