Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243.
# Feature
This PR changes the `RINDEX` symbol type and its `"r"` prefix to `(R0_INDEX, R1_INDEX)` and `("r0_", "r1_")`, respectively. This allows the relevant code to support 2D (and, more generally, ND) reductions. Unlike the parent PR, this one does not change the tiling algorithm, so `"r1_"` is never actually used; however, it prepares the rest of the system to handle `"r1_"` once we start emitting it. This should significantly reduce the chance of merge conflicts and make the parent PR much easier to land.
The only change to the generated Triton code is renaming `"rindex"` -> `"r0_index"`, `"RBLOCK"` -> `"R0_BLOCK"`, etc. To maintain compatibility with existing codegen, this PR also generates aliases for the old reduction variables, e.g. `rindex = r0_index`. If we generated 2D reductions (which this PR does not do), the aliases would be more complicated, collapsing 2D multi-indices into linear indices. See the example kernels in the parent PR.
These aliases can be eliminated by the Triton compiler and should not affect the final machine code running on the GPU; the perf testing in the parent PR confirms they have no performance impact.
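For illustration, a post-rename kernel body looks roughly like the following. This is a hand-written sketch, not actual Inductor output; it just shows the `r0_`-prefixed names and the compatibility alias described above (and it assumes `r0_numel <= R0_BLOCK`, i.e. a persistent reduction).
```
import triton
import triton.language as tl

# Hand-written sketch (not actual Inductor output) of a row-sum kernel using
# the renamed reduction variables plus the compatibility alias.
@triton.jit
def sum_kernel(in_ptr, out_ptr, xnumel, r0_numel, XBLOCK: tl.constexpr, R0_BLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]   # was: rindex / RBLOCK
    r0_mask = r0_index < r0_numel
    rindex = r0_index                            # alias for old-style codegen; eliminated by the Triton compiler
    vals = tl.load(in_ptr + r0_numel * xindex + rindex, xmask & r0_mask, other=0.0)
    tl.store(out_ptr + xindex, tl.sum(vals, 1)[:, None], xmask)
```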
# Test plan
The existing CI provides good coverage. This PR modifies the expected code in a few places, renaming reduction variables from `r.*` to `r0_.*`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142020
Approved by: https://github.com/jansel
Co-authored-by: Jason Ansel <jansel@meta.com>
**Summary**
Inductor currently uses modulo and division to compute indices into certain multi-dimensional tensors, such as those arising from row padding. This PR matches on that indexing pattern, replacing it with an N-D block pointer. This should be more efficient than computing indices with division and modulo, and it can easily map to DMAs on non-GPU hardware targets.
Because the 1D block size needs to map to an integer block shape in ND, we need to know that the ND block shape evenly divides the size of the iteration range. This PR only generates ND block pointers when it can guarantee that the iteration order and the number of elements loaded are unchanged. Concretely, the number of elements in each slice of the iteration range must be either:
- Powers of 2. Since Triton block sizes are powers of 2, any integer power of 2 either divides the block size or is greater than it. In the latter case, `CeilDiv(x, y)` rounds up to 1.
- Multiples of the maximum block size. Since block sizes are powers of 2, the maximum block size is a multiple of every possible block size.
Note that a *slice* of the iteration range does not include the leading dimension, so we can support arbitrary leading dimensions like `(5, 8)`. A sketch of this divisibility check follows.
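The sketch below spells out the condition with plain Python integers. `MAX_BLOCK` and the helper name are assumptions made for illustration; the actual Inductor check also has to handle symbolic shapes, which this sketch ignores.
```
import math

# Illustrative sketch of the divisibility condition; MAX_BLOCK and the helper
# name are assumptions, not Inductor's real API.
MAX_BLOCK = 4096  # assumed maximum Triton block size for this sketch

def can_use_nd_block_ptr(dims: list[int]) -> bool:
    """True if every slice of the iteration range (each trailing subspace,
    excluding the leading dimension) has a number of elements that is either
    a power of 2 or a multiple of the maximum block size."""
    for i in range(1, len(dims)):
        slice_numel = math.prod(dims[i:])
        is_pow2 = slice_numel > 0 and (slice_numel & (slice_numel - 1)) == 0
        if not (is_pow2 or slice_numel % MAX_BLOCK == 0):
            return False
    return True

print(can_use_nd_block_ptr([5, 8]))       # True: arbitrary leading dim, slice of 8 elements
print(can_use_nd_block_ptr([32, 16, 8]))  # True: slices of 128 and 8 elements
print(can_use_nd_block_ptr([16, 24]))     # False: 24 is neither a power of 2 nor a multiple of MAX_BLOCK
```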
Feature proposal and discussion: https://github.com/pytorch/pytorch/issues/125077
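For context, the example kernel below can be produced by compiling a simple pointwise op over a strided slice of a larger contiguous buffer. The snippet is an illustrative reproduction rather than the exact test from this PR:
```
import torch
import torch._inductor.config as inductor_config

# A (32, 16, 8) slice of a contiguous (32, 32, 32) buffer has strides
# (1024, 32, 1), matching the block pointer in the kernel below. Per the
# follow-ups section, this analysis currently applies when use_block_ptr=True.
inductor_config.triton.use_block_ptr = True

base = torch.randn(32, 32, 32, device="cuda")
x = base[:, :16, :8]

compiled = torch.compile(lambda t: t + t)
out = compiled(x)
```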
Example kernel:
```
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4096
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.reshape(tl.load(tl.make_block_ptr(in_ptr0, shape=[32, 16, 8], strides=[1024, 32, 1], block_shape=[32 * (32 <= ((127 + XBLOCK) // 128)) + ((127 + XBLOCK) // 128) * (((127 + XBLOCK) // 128) < 32), 16 * (16 <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < 16), 8 * (8 <= XBLOCK) + XBLOCK * (XBLOCK < 8)], order=[0, 1, 2], offsets=[(xoffset // 128), (xoffset // 8) % 16, xoffset % 8]), boundary_check=[0, 1, 2]), [XBLOCK])
    tmp1 = tmp0 + tmp0
    tl.store(tl.make_block_ptr(out_ptr0, shape=[4096], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp1, [XBLOCK]).to(tl.float32))
''', device_str='cuda')
```
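The unwieldy `block_shape` entries are branchless clamps: `a * (a <= b) + b * (b < a)` equals `min(a, b)`, so the first entry above is just `min(32, CeilDiv(XBLOCK, 128))`. A quick check of the identity:
```
# Sanity check: a * (a <= b) + b * (b < a) == min(a, b) for integers, which is
# how the block_shape entries above clamp CeilDiv(XBLOCK, inner_numel) to each
# dimension's size without branching.
def branchless_min(a: int, b: int) -> int:
    return a * (a <= b) + b * (b < a)

assert all(branchless_min(a, b) == min(a, b) for a in range(64) for b in range(64))
```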
**Test Plan**
This PR adds a new CI test script to cover this feature. The tests can be grouped into a few main categories:
- Can we generate strided block pointers for the appropriate shapes?
  - Powers of 2
  - Non-powers of 2 that are multiples of the maximum block size
  - Arbitrary leading dimensions, with power-of-2 inner dimensions
  - Weird strides and offsets
  - Reductions
  - Symbolic shapes that are multiples of the maximum block size (I wasn't able to trace this through dynamo)
  - Broadcasts (some variables are missing from the indexing expression)
- Do we still compile other cases correctly, even if we don't expect to be able to generate block pointers?
  - Unsupported static shapes
  - Unsupported symbolic shapes
- Mixing and matching these cases:
  - Pointwise and reduction in the same kernel
- Sanity checking the test harness:
  - Do we raise an exception if the expected and actual number of block pointers differ? (A sketch of this check follows the list.)
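A sketch of how that harness check might look. `run_and_get_code` is an existing Inductor test utility; the assertion style here is illustrative rather than the actual test code added in this PR.
```
import torch
from torch._inductor.utils import run_and_get_code

# Illustrative harness check (not the actual test code): compile a function,
# collect the generated source, and assert on the number of block pointers.
def check_block_ptr_count(fn, args, expected: int):
    result, code_list = run_and_get_code(torch.compile(fn), *args)
    actual = sum(code.count("tl.make_block_ptr(") for code in code_list)
    if actual != expected:
        raise AssertionError(f"expected {expected} block pointers, got {actual}")
    return result
```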
**Follow-ups**
There are a few important cases that this PR can't handle. I'm hoping these can be deferred to follow-up PRs:
- Handle non-divisible shapes:
  - Change the tiling algorithm to generate a 2D (X, Y) blocking, if doing so enables block pointers to be emitted.
  - Pad unsupported loads up to the nearest divisible size, then mask/slice out the extra elements? This is probably the best solution, but I'm not yet sure how to go about it in Triton.
- Take advantage of this analysis when `triton.use_block_ptr=False`. I'm guessing we can still avoid `%` and `/` without requiring block pointers. Maybe we could compute block indices with arange and broadcast instead?
Differential Revision: D56739375
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127342
Approved by: https://github.com/jansel, https://github.com/shunting314
The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is enabled by setting the `specialize_float` config variable to False.
The generated graph looks like this for the test `test_unspec_float_output`:
```
def forward(self, L_x_: "f32[3]", L_y_: "f32[]"):
    l_x_ = L_x_
    l_y_ = L_y_
    # File: /data/users/ezyang/a/pytorch/test/dynamo/test_unspec.py:511 in f, code: return x + 1, y * 2
    add: "f32[3]" = l_x_ + 1; l_x_ = None
    item: "Sym(zf0)" = l_y_.item(); l_y_ = None
    mul: "Sym(2*zf0)" = item * 2; item = None
    scalar_tensor: "f32[]" = torch.scalar_tensor(mul); mul = None
    return (add, scalar_tensor)
```
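For reference, the graph above comes from a function of roughly this shape (per the `File:` comment in the graph); the invocation below is an illustrative sketch, not the exact test:
```
import torch
import torch._dynamo  # for the config flag below

# Sketch of the kind of program that produces the graph above; the real test
# lives in test/dynamo/test_unspec.py. With specialize_float=False, the float
# argument is passed to the graph as a 0-d tensor, item()'d inside, and the
# float result is boxed back into a tensor at the graph output.
torch._dynamo.config.specialize_float = False

def f(x, y):
    return x + 1, y * 2

compiled = torch.compile(f, backend="eager")
out_tensor, out_float = compiled(torch.randn(3), 5.0)
```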
The ingredients:
* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up in the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, which we do the rest of the tracing with. Importantly, this SymNodeVariable is backed with the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable). This has to be done manually, because if you literally call item() on the tensor, you end up with an unbacked float. There is a bit of copy-paste from wrap_symint and wrap_unspecialized_primitive that we could try to factor out, but this really is its own thing, and you should review every line of code in the function.
* **torch/fx/experimental/symbolic_shapes.py** We now can generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols, and produce guards for them. Fairly straightforward generalization.
* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs from the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable for a float on the return stack, we convert it on the fly (via `as_tensor`) to a TensorVariable, which is the true output. We then special-case the output bytecode to call item() on it again. The tensor conversion is memoized on the SymNodeVariable, since we typically run the code generation process twice. A trivial illustration of the float-to-tensor-to-float round trip this relies on is sketched below.
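The values here are arbitrary; the point is just that boxing a float into a 0-d tensor and unboxing it with item() is lossless for values representable in the default dtype.
```
import torch

# The float -> 0-d tensor -> float round trip the codegen relies on:
# scalar_tensor boxes the value for the graph output ("f32[]"), and the
# emitted bytecode unboxes it again with item().
y = 2.5
boxed = torch.scalar_tensor(y)
assert isinstance(boxed.item(), float) and boxed.item() == y
```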
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125325
Approved by: https://github.com/lezcano, https://github.com/jansel
This provides utilities for creating and querying properties on sympy.Symbol. I want to use this refactor to get a better handle on how the 's' prefix is being used in Inductor. To start, I only cover the symbolic_shapes code, because that's what I'm familiar with.
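A hedged sketch of the idea, written against plain sympy; the enum and helper names below are illustrative stand-ins, not necessarily the exact API this PR adds.
```
import sympy
from enum import Enum, auto

# Illustrative sketch (names are stand-ins, not the actual PyTorch API):
# encode the kind of a symbol in its name prefix at creation time, then query
# the prefix through one helper instead of ad-hoc string matching.
class SymT(Enum):
    SIZE = auto()   # "s" symbols from symbolic_shapes
    FLOAT = auto()  # "zf" symbols (backed floats)

_PREFIX = {SymT.SIZE: "s", SymT.FLOAT: "zf"}

def make_symbol(kind: SymT, idx: int, **assumptions) -> sympy.Symbol:
    return sympy.Symbol(f"{_PREFIX[kind]}{idx}", **assumptions)

def symbol_is_type(sym: sympy.Symbol, kind: SymT) -> bool:
    return sym.name.startswith(_PREFIX[kind])

s0 = make_symbol(SymT.SIZE, 0, integer=True, positive=True)
assert symbol_is_type(s0, SymT.SIZE) and not symbol_is_type(s0, SymT.FLOAT)
```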
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125395
Approved by: https://github.com/Skylion007