Blaine Burton Rister a2753e376b [Inductor] Support tiling reduction dimensions (#137243)
Fixes #134277 and https://github.com/pytorch/pytorch/issues/142317.

Sub-PRs containing refactors from this one:
 - https://github.com/pytorch/pytorch/pull/141733
 - https://github.com/pytorch/pytorch/pull/141738
 - https://github.com/pytorch/pytorch/pull/141751 (based off the former)
 - https://github.com/pytorch/pytorch/pull/142249
 - https://github.com/pytorch/pytorch/pull/142020
 - https://github.com/pytorch/pytorch/pull/143135

 These refactor PRs should land before the main one.

# Feature

*Note: to minimize risk, multi-dimensional reductions are gated by the flag `config.triton.tile_reductions`, which defaults to False.*

Instead of having a single reduction dimension called `"r"`, we can now support 2D reductions with `"r0_"` and `"r1_"` dimensions. 2D reductions generate two nested loops, with different block pointer advancements in each loop body. Most of the implementation is generic to ND reductions, but for now the tiling algorithm sets a hard limit at 2D.

Here's an example of a 2D persistent reduction kernel:
```
@triton.jit
def triton_per_fused_sum_0(in_ptr0, out_ptr0, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 15
    R0_BLOCK: tl.constexpr = 16
    r1_numel = 15
    R1_BLOCK: tl.constexpr = 16
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
    xmask = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[None, :, None]
    r0_offset = 0
    r0_mask = r0_index < r0_numel
    r1_index = tl.arange(0, R1_BLOCK)[None, None, :]
    r1_offset = 0
    r1_mask = r1_index < r1_numel
    rnumel = r0_numel * r1_numel
    RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
    roffset = r1_offset + (r0_offset*r1_numel)
    rindex = r1_index + (r0_index*r1_numel)
    r0_0 = r0_index
    r1_1 = r1_index
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[15, 15], strides=[30, 1], block_shape=[R0_BLOCK, R1_BLOCK], order=[1, 0], offsets=[r0_offset, r1_offset]), boundary_check=[0, 1], padding_option='zero')[None, :, :]
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
    tmp3 = tl.where(r0_mask & r1_mask, tmp1, 0)
    tmp4 = tl.reshape(tmp3, [XBLOCK, RBLOCK])
    tmp5 = tl.sum(tmp4, 1)[:, None, None]
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1, 1], 0, tl.int32)), tmp5, None)
''', device_str='cuda')
```

There are a few main differences between this kernel and what Inductor would generate without this PR.
 - Instead of an `r`/`RBLOCK` dimension, we have two reduction dimensions: `r0_`/`R0_BLOCK` and `r1_`/`R1_BLOCK`.
 - There are special size and indexing variables for reductions which don't directly correspond to any kernel dimension (`rindex`, `rnumel`, `RBLOCK`, and `roffset`). These collapse the N-D reduction sizes and indices into 1D. This simplifies the codegen for reductions, which sometimes needs to access linear indices instead of N-dimensional ones. Doing things this way allows us to generate N-D loads and stores, but access the data as if it were 1D, minimizing the blast radius of this PR. Although this makes the code more verbose, it shouldn't have a perf impact because the Triton compiler eliminates dead code.
 - We generate the line `tmp4 = tl.reshape(tmp3, [XBLOCK, RBLOCK])` before performing the actual reduction. This reshapes the N reduction dimensions into 1D, which lets us reduce over all N dimensions at once, simplifying the codegen and allowing the Triton compiler to decide the order of processing under the hood.
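
To make the collapse concrete, here is a small NumPy sketch (illustrative only, not Inductor code; block sizes are taken equal to the numels so masking can be ignored). It shows that the kernel's linearization formula is just a row-major flattening, and that reducing over the flattened axis matches reducing over both reduction axes:
```
import numpy as np

R0_BLOCK, R1_BLOCK = 16, 16
x = np.random.rand(R0_BLOCK, R1_BLOCK)

# The kernel's linearization: rindex = r1_index + r0_index * r1_numel
r0_index, r1_index = np.meshgrid(np.arange(R0_BLOCK), np.arange(R1_BLOCK), indexing="ij")
rindex = r1_index + r0_index * R1_BLOCK
assert (rindex == np.arange(R0_BLOCK * R1_BLOCK).reshape(R0_BLOCK, R1_BLOCK)).all()

# Reshaping to [RBLOCK] and reducing once matches reducing over both axes.
assert np.isclose(x.reshape(R0_BLOCK * R1_BLOCK).sum(), x.sum(axis=(0, 1)))
```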

Here's an example of a looped reduction:
```
@triton.jit
def triton_red_fused_sum_0(in_ptr0, out_ptr0, xnumel, r0_numel, r1_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr, R1_BLOCK : tl.constexpr):
    xnumel = 3
    r0_numel = 43
    r1_numel = 129
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :, None]
    r1_base = tl.arange(0, R1_BLOCK)[None, None, :]
    rnumel = r0_numel * r1_numel
    RBLOCK: tl.constexpr = R0_BLOCK*R1_BLOCK
    rbase = r1_base + (r0_base*r1_numel)
    x0 = xindex
    block_ptr0 = tl.make_block_ptr(in_ptr0, shape=[3, 43, 129], strides=[11094, 258, 1], block_shape=[XBLOCK, R0_BLOCK, R1_BLOCK], order=[2, 1, 0], offsets=[xoffset, 0, 0])
    _tmp2 = tl.full([XBLOCK, R0_BLOCK, R1_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        for r1_offset in range(0, r1_numel, R1_BLOCK):
            r1_index = r1_offset + r1_base
            r1_mask = r1_index < r1_numel
            roffset = r1_offset + (r0_offset*r1_numel)
            rindex = r1_index + (r0_index*r1_numel)
            r0_1 = r0_index
            r1_2 = r1_index
            tmp0 = tl.load(block_ptr0, boundary_check=[0, 1, 2], padding_option='zero', eviction_policy='evict_first')
            tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK, R1_BLOCK])
            tmp3 = _tmp2 + tmp1
            _tmp2 = tl.where(r0_mask & r1_mask & xmask, tmp3, _tmp2)
            block_ptr0 = tl.advance(block_ptr0, [0, 0, R1_BLOCK])
        block_ptr0 = tl.advance(block_ptr0, [0, R0_BLOCK, (-1)*R1_BLOCK*((128 + R1_BLOCK) // R1_BLOCK)])
    tmp4 = tl.reshape(_tmp2, [XBLOCK, RBLOCK])
    tmp2 = tl.sum(tmp4, 1)[:, None, None]
    tl.store(tl.make_block_ptr(out_ptr0, shape=[3], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.reshape(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])
''', device_str='cuda')
```

In addition to the aforementioned changes to the persistent reduction, multidimensional looped reductions have a few more lines of code:
 - They calculate indices inside the loop using `r0_base` and `r1_base`. For compatibility with existing codegen, these are collapsed to the 1D variant `rbase`.
 - Block pointer advancements are more nuanced for multidimensional loops. At the end of each loop body, we emit a `tl.advance` line which not only increments the pointer in its own dimension, but also rewinds the cumulative increments made by the loop nested inside it (see the sketch below). This is equivalent to the usual practice in nested loops of starting each level with a fresh iteration variable. Implementing this required refactoring the way we generate pointer advancements into a new `self.pointer_advancements` field of the kernel, which categorizes advancements by dimension.
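
The advance-and-rewind pattern can be illustrated with a plain-Python sketch that tracks only the linearized `roffset` rather than a real block pointer (the actual `tl.advance` calls work per dimension, with memory strides applied separately; names here are illustrative):
```
R0_BLOCK, R1_BLOCK = 8, 16
r0_numel, r1_numel = 43, 129

offset = 0  # stands in for roffset = r1_offset + r0_offset * r1_numel
for r0_offset in range(0, r0_numel, R0_BLOCK):
    for r1_offset in range(0, r1_numel, R1_BLOCK):
        # The loop body would load the block at `offset` here.
        assert offset == r0_offset * r1_numel + r1_offset
        # Inner advance: step along r1_ only.
        offset += R1_BLOCK
    # Outer advance: step along r0_ and rewind the inner loop's cumulative steps.
    # The generated code writes the rewind as (-1)*R1_BLOCK*((r1_numel - 1 + R1_BLOCK) // R1_BLOCK).
    inner_steps = (r1_numel + R1_BLOCK - 1) // R1_BLOCK
    offset += R0_BLOCK * r1_numel - R1_BLOCK * inner_steps
```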

The biggest difficulty in implementing this feature was that we represented tiling with a tuple like `(5,2)`. In the existing codebase, the compiler can infer that the reduction dimension of `(5,2)` is `2`, since reductions are always the last dimension. This became cumbersome now that we have to support multiple reduction dimensions, so I refactored tiling into a dict like `{"x": 5, "r0_": 2, "r1_": 4}`. This required quite a few code changes, but I don't think it makes the underlying logic much more complex. This will also make it easier to eventually support simultaneous pointwise and reduction tiling, like `{"x": 5, "y": 5, "r0_": 2, "r1_": 4}`. (This is not supported today, but we might want to do it eventually.)
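
A minimal sketch of the difference (illustrative only, not the actual Inductor data structures): with named dimensions, reduction dims are identified by their `r` prefix rather than by position.
```
# Old: a positional tuple; the reduction size is implicitly the last element.
tiling_tuple = (5, 2)
reduction_numel = tiling_tuple[-1]

# New: named dimensions; reduction dims are found by their "r" prefix.
tiling = {"x": 5, "r0_": 2, "r1_": 4}
reduction_dims = {k: v for k, v in tiling.items() if k.startswith("r")}
pointwise_dims = {k: v for k, v in tiling.items() if not k.startswith("r")}
assert reduction_dims == {"r0_": 2, "r1_": 4} and pointwise_dims == {"x": 5}
```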

The existing tiling algorithm generalized naturally to support reductions. For pointwise kernels, we tile the pointwise dimensions (`"x"`, `"y"`) as is. For reduction kernels, we never tile the `"x"` dimension, and only tile the reduction dimensions (`"r0_"`, `"r1_"`). Thus we only ever tile pointwise OR reduction dimensions, but not both. In principle it seems possible to support both, but it would likely require changes to the kernel fusion and autotuning logic. I thought it best to keep this PR as minimal as possible since it already touched a lot of different files.

Unfortunately, these changes weren't enough to get block pointers in some seemingly simple test cases. In some tests for `argmax` and `var_mean`, we already collapse reduction dimensions into 1D and generate modular indexing expressions, prior to tiling. So it's not trivial to figure out how to expand the collapsed reduction dimension back to a shape that would simplify the indexing.

To address these cases, this PR adds a new feature to the `config.prefer_nd_tiling` option, which analyzes reads and writes in the kernel using the same mod-div pattern matching logic that generates block pointers later on. By matching this pattern, we can solve for the tiling splits that *would* simplify the indexing expression, and then use that tiling to eliminate the modular indexing and emit a block pointer. This tiling mode is still off by default, but it's important for certain applications where we need to get as many block pointers as possible.
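
As a rough numeric illustration of what the analysis recovers (using the strides from the looped example above, not Inductor code): a collapsed 1D reduction index forces floor-div and modulo into the address expression, while the solved-for split turns the same addresses into a plain strided 2D access that a block pointer can express.
```
import numpy as np

r0_numel, r1_numel, stride0 = 43, 129, 258

# Collapsed 1D reduction index: the address expression needs floor-div and
# modulo, which defeats the block-pointer pattern match.
r = np.arange(r0_numel * r1_numel)
collapsed_addrs = stride0 * (r // r1_numel) + (r % r1_numel)

# With the split r -> (r0_, r1_), the same addresses become a plain strided
# 2D access, stride0 * r0_ + 1 * r1_, which a block pointer can express.
r0_ = np.arange(r0_numel)[:, None]
r1_ = np.arange(r1_numel)[None, :]
split_addrs = stride0 * r0_ + r1_
assert (collapsed_addrs.reshape(r0_numel, r1_numel) == split_addrs).all()
```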

# Test plan

This touches pretty much everything that uses the Triton and Halide backends, so the existing CI provides good coverage. However, 2D reductions are gated behind feature flags like `config.prefer_nd_tiling` and `config.triton.tile_reductions`, so existing CI really only checks that the PR doesn't break 1D reductions.
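
For local experimentation, the flags can be flipped through Inductor's config patching. The snippet below is a sketch; the exact config attribute paths are assumptions based on the flag names used in this description.
```
import torch
import torch._inductor.config as inductor_config

@torch.compile
def sum_last_two_dims(x):
    return x.sum(dim=(-2, -1))

# Assumed flag locations, following the names used above.
with inductor_config.patch({"triton.tile_reductions": True, "triton.prefer_nd_tiling": True}):
    out = sum_last_two_dims(torch.randn(3, 43, 129, device="cuda"))
```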

In addition to existing CI tests, this PR also adds some new tests that specifically stress 2D reductions:

- `test_2d_reduction_odd_shapes`: test 2D reductions with a variety of ops and sizes. This covers the typical persistent and looped reductions.
- `test_2d_reduce_no_x_dim`: test 2D reductions with no x dimension.
- `test_2d_welford_reduction`: test 2D Welford reductions with block pointers.
- `test_welford_non_block_pointer`: test a 2D Welford reduction when block pointer analysis fails.
- `test_reduction_multiple_discontiguous_dims`: test reducing over more than one discontiguous dimension. We won't get a block pointer for this case, since that would require 3D tiling, but we're currently limited to 2D.
- `test_2d_reduction_multi_kernel`: test multi kernel autotuning on a 2D softmax kernel.
- `test_enable_tiled_reductions`: test that `config.triton.tile_reductions` enables/disables this feature.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137243
Approved by: https://github.com/jansel

Co-authored-by: Yueming Hao <yhao@meta.com>
Co-authored-by: Jason Ansel <jansel@meta.com>
2024-12-31 05:06:46 +00:00