Python's set is non deterministic. There is an internal failure which we recently ran into which did not consistently fail.
See, repro here: P1453035092.
Now, with these changes, it does consistently fail. In follow ups we could also consider adding a lintrule for uses of either set() or set literals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130004
Approved by: https://github.com/oulgen
This fixes a few instances where we assumed indexing expressions were
non-negative. This is not valid when we have more complicated
expressions involving masking e.g. pointwise cat.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131761
Approved by: https://github.com/ezyang
Python's set is non deterministic. There is an internal failure which we recently ran into which did not consistently fail.
See, repro here: P1453035092.
Now, with these changes, it does consistently fail. In follow ups we could also consider adding a lintrule for uses of either set() or set literals.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130004
Approved by: https://github.com/oulgen
```
# Mode to emulate pytorch eager numerics for lower precision (fp16, bf16)
# Pytorch eager computes bf16/fp16 by upcasting inputs to fp32 and downcasting after
# For multiple, fused pointwise nodes, inductor will elide the intermediary upcasts and downcasts
# Typically this should be closer to fp64 ref numerics. However, it can be useful for debugging
# to emulate the eager numerics.
```
We add extra upcasts and downcasts for pointwise nodes that correspond to casts that existed in the original user program (excluding pointwise nodes that are emitted during decomposition). Since this is mostly for debugging, I added this information in the `meta` so that this mode does not have unintended side effects like changing pattern matching.
in theory there could also be some other casts with fused reduction -> reduction, although i havent seen this in practice as much. could be done as follow up. note: only works with cuda backend right now.
This mode was sufficient to eliminate compile differences from https://fb.workplace.com/groups/385893200869952/posts/464263173032954/?comment_id=465199259606012&reply_comment_id=465676792891592.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131595
Approved by: https://github.com/shunting314, https://github.com/bdhirsh, https://github.com/jansel
This fixes a few instances where we assumed indexing expressions were
non-negative. This is not valid when we have more complicated
expressions involving masking e.g. pointwise cat.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131761
Approved by: https://github.com/ezyang
Closes#129507
This makes two changes to the sort kernel:
1. Use int16 for the indices since we only operate on small dims anyway
2. Instead of passing an explicit mask, we pass the rnumel and imply the
mask from that which saves an additional reduction in the sort
kernel's inner loop.
In my benchmarks, this gives enough of a perf improvement to bump up the
max rblock to 512.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131719
Approved by: https://github.com/eellison
This fixes a couple errors that come up when multi-kernel is used with
split-scan.
1. The split-scan was being marked as a persistent kernel, which allowed
a multi-kernel to be created but this isn't supported. Fix is to
never mark split-scan as persistent.
2. Benchmark codegen was not handling WorkspaceArg, and would raise a
KeyError during codegen.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131044
Approved by: https://github.com/shunting314
Fixes#126338
## Issue Summary
When torchinductor compiles the combination `functional_collective -> view.dtype -> wait`, a memory leak occurs. This happens because `view.dtype` is compiled into an out-of-place Triton kernel that copies the input data to a new tensor, even if the data hasn't completed collection via the wait operation. The tensor used by `collective` is only freed when the `wait` operation triggers the garbage collector, see [~WorkRegistry](https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/Functional.cpp#L41). However, since `wait` now waits for a new tensor, the previous one is never freed. The `view.dtype` should only check the metadata instead of creating a new tensor. The current lowering is against its semantics and causes memory leaks.
See more great discussions in the #126338
This kind of lowering also generates unnecessary triton kernels for `view.dtype` when it can't be fused with other operations.
## Fix
The function `aten.view.dtype` is a CPU operation that changes the metadata of its input. After discussions with @eellison and @bdhirsh, we decided to change the lowering of `aten.view.dtype` to ensure it fallback properly to the correct `aten.view.dtype` instead of generating a Triton kernel in some cases. This approach also preserves the same semantics of the view operation.
When the model calls `aten.view.dtype` with a data type whose bit width matches the input's original data type, we lower it to the newly added `DtypeView` in IR, acting like a `ReinterpretView`. When the operation can be fused, its `make_loader` is called to maintain the correct type conversion for each load instruction. When the operation can't be fused, it falls back to `aten.view.dtype` to avoid Triton kernel generation.
## Example
```python
@torch.compile
def fn(x, y):
x = x.view(torch.float16)
y = y.view(torch.float16) + 1
return x @ y
x = torch.randn((2, 2), device=self.device, dtype=torch.bfloat16)
y = torch.randn((2, 2), device=self.device, dtype=torch.bfloat16)
fn(x, y)
```
The output code generated before this fix is like the following.
```python
triton_poi_fused_add_view_0...
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 4
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
tmp1 = tmp0.to(tl.bfloat16).to(tl.float32, bitcast=True).to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, xmask)
triton_poi_fused_add_view_1...
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 4
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
tmp1 = tmp0.to(tl.bfloat16).to(tl.float32, bitcast=True).to(tl.float32)
tmp2 = 1.0
tmp3 = tmp1 + tmp2
tl.store(out_ptr0 + (x0), tmp3, xmask)
def call(args):
...
triton_poi_fused_view_0.run(arg0_1, buf0, 4, grid=grid(4), stream=stream0)
del arg0_1
buf1 = empty_strided_cuda((2, 2), (2, 1), torch.float16)
# Source Nodes: [view_1, y], Original ATen: [aten.add, aten.view]
triton_poi_fused_add_view_1.run(arg1_1, buf1, 4, grid=grid(4), stream=stream0)
del arg1_1
buf2 = empty_strided_cuda((2, 2), (2, 1), torch.float16)
# Source Nodes: [matmul, view_1, x, y], Original ATen: [aten.add, aten.mm, aten.view]
extern_kernels.mm(buf0, buf1, out=buf2)
```
As you can see, the two `view` operations are compiled to two kernels `triton_poi_fused_view_0` nad `triton_poi_fused_add_view_1`. Both of them has a line `tmp1 = tmp0.to(tl.bfloat16).to(tl.float32, bitcast=True).to(tl.float32)` which does the type conversion.
The main issue is that the first `view` operation didn't do anything to the actual data. But it generates a triton kernel with a new output tensor. Another small issue is that this triton kernel can't be compiled because `bitcast=True` only support type converstion with same bidwidth.
The following are output code generated after this PR.
```python
triton_poi_fused_add_0...
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 4
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
tmp1 = tmp0.to(tl.bfloat16).to(tl.float32)
tmp2 = 1.0
tmp3 = tmp1 + tmp2
tl.store(out_ptr0 + (x0), tmp3, xmask)
def call(args):
...
triton_poi_fused_add_0.run(arg1_1, buf0, 4, grid=grid(4), stream=stream0)
del arg1_1
buf1 = empty_strided_cuda((2, 2), (2, 1), torch.float16)
# Source Nodes: [matmul, y], Original ATen: [aten.add, aten.mm]
extern_kernels.mm(aten.view.dtype(arg0_1, torch.float16), buf0, out=buf1)
```
The first `view` operation has been replaced with the `aten.view.dtype` and it is directly passed as an argument. The second one is still there because it is fused with the following add operation. The invalid bitcast operation is removed too.
The following two code snippets is for the upcasts and downcasts. For dtype in `torch.float16, torch.bfloat16`, each load will be upcasted to float32, then downcast to its original dtype to ensure use values with the right precision.
7bda23ef84/torch/_inductor/codegen/triton.py (L1725-L1726)7bda23ef84/torch/_inductor/codegen/triton.py (L629-L642)
Huge thanks to @eellison, @bdhirsh, @shunting314, and @desertfire .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128883
Approved by: https://github.com/eellison
**Summary**
Inductor currently uses modulo and division to compute indices into certain multi-dimensional tensors, such as those arising from row padding. This PR matches on that indexing pattern, replacing it with an N-D block pointer. This should be more efficient than computing indices with division and modulo, and it can easily map to DMAs on non-GPU hardware targets.
Because the 1D block size needs to map to an integer block shape in ND, we need to know that the ND block size evenly divides the size of the iteration range. This PR only generates ND block pointers when it can guarantee that the iteration order and number of elements loaded are unchanged. This means that the number of elements in a slice of the iteration range must either be:
- Powers of 2. Since Triton block sizes are powers of 2, any integer power of 2 either divides the block size, or is greater than the block size. In the latter case, `CielDiv(x, y)` rounds up to 1.
- Multiples of the maximum block size. Since block sizes are powers of 2, the maximum block size is a multiple of every possible block size.
Note that a *slice* of the iteration range does not include the leading dimension. Thus we can support arbitrary leading dimensions like `(5,8)`.
Feature proposal and discussion: https://github.com/pytorch/pytorch/issues/125077
Example kernel:
```
triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 4096
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
tmp0 = tl.reshape(tl.load(tl.make_block_ptr(in_ptr0, shape=[32, 16, 8], strides=[1024, 32, 1], block_shape=[32 * (32 <= ((127 + XBLOCK) // 128)) + ((127 + XBLOCK) // 128) * (((127 + XBLOCK) // 128) < 32), 16 * (16 <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < 16), 8 * (8 <= XBLOCK) + XBLOCK * (XBLOCK < 8)], order=[0, 1, 2], offsets=[(xoffset // 128), (xoffset // 8) % 16, xoffset % 8]), boundary_check=[0, 1, 2]), [XBLOCK])
tmp1 = tmp0 + tmp0
tl.store(tl.make_block_ptr(out_ptr0, shape=[4096], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp1, [XBLOCK]).to(tl.float32))
''', device_str='cuda')
```
**Test Plan**
This PR adds a new CI test script to cover this feature. The tests can be grouped into a few main categories:
- Can we generate strided block pointers for the appropriate shapes?
- Powers of 2
- Non-power of 2, but multiple of the maximum block size
- Arbitrary leading dimensions, with power of 2 inner dimensions
- Weird strides and offsets
- Reductions
- Symbolic shapes that are multiples of the maximum block size (wasn't able to trace this through dynamo)
- Broadcasts (some variables are missing from the indexing expression)
- Do we still compile other cases correctly, even if we don't expect to be able to generate block pointers?
- Unsupported static shapes
- Unsupported symbolic shapes
- Mixing and matching these cases:
- Pointwise and reduction in the same kernel
- Sanity check the test harness
- Do we raise an exception if the expected number of block pointers and the actual number are different?
**Follow-ups**
There are a few important cases which this PR can't handle. I'm hoping these can be deferred to follow-up PRs:
- Handle non-divisible shapes
- Change the tiling algorithm to generate a 2D (X,Y) blocking, if doing so enables block pointers to be emitted.
- Pad unsupported loads up to the nearest divisible size, then mask/slice out the extra elements? This is probably the best solution, but I'm not yet sure how to go about it in triton.
- Take advantage of this analysis when `triton.use_block_ptr=False`. I'm guessing we can still avoid `%` and `/` without requiring block pointers. Maybe we could compute block indices with arange and broadcast instead?
Differential Revision: D56739375
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127342
Approved by: https://github.com/jansel, https://github.com/shunting314
At a high level, the idea behind this PR is:
* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.
The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:
* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)
In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.
We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:
* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`
These changes have consequences. First, we need to make some administrative changes:
* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
* In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
* TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.
In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:
* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type
The new asserts uncovered necessary bug fixes:
* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1
Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**
**Reland notes.** This requires this internal fbcode diff https://www.internalfb.com/phabricator/paste/view/P1403322587 but I cannot prepare the diff codev due to https://fb.workplace.com/groups/osssupport/posts/26343544518600814/
It also requires this Executorch PR https://github.com/pytorch/executorch/pull/3911 but the ET PR can be landed prior to this landing.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
Summary:
1. Integrate NaN and INF checker with existing config, controllable by env var.
2. Move inject point of NaN & INF checker earlier, this could prevent buffer freeing before check.
3. Inject debugging code in Kernel level, which prevents us trying to read buffers that are fused inplace and into a single kernel.
Test Plan:
Debugging utility.
Test and check by existing tests with env var:
```
TORCHINDUCTOR_NAN_ASSERTS=1 TORCHINDUCTOR_MAX_AUTOTUNE=0 python test/inductor/test_aot_inductor.py -k AOTInductorTestNonABICompatibleCuda.test_seq_non_abi_compatible_cuda
```
Reviewed By: ColinPeppler
Differential Revision: D57989176
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127574
Approved by: https://github.com/chenyang78, https://github.com/desertfire
Automated fixes to put imports that are only used in type hints into TYPE_CHECKING imports. This also enables the RUFF TCH rules which will automatically apply autofixes to move imports in and out of TYPE_CHECKING blocks as needed in the future, this will make the initial PyTorch import faster and will reduce cyclic dependencies.
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127688
Approved by: https://github.com/XuehaiPan, https://github.com/ezyang, https://github.com/malfet
At a high level, the idea behind this PR is:
* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.
The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:
* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)
In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.
We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:
* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`
These changes have consequences. First, we need to make some administrative changes:
* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
* In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
* TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.
In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:
* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type
The new asserts uncovered necessary bug fixes:
* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1
Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
At a high level, the idea behind this PR is:
* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.
The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:
* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)
In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.
We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:
* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`
These changes have consequences. First, we need to make some administrative changes:
* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
* In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
* TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.
In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:
* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type
The new asserts uncovered necessary bug fixes:
* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1
Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
Fixes#1232102f3d3ddd70/torch/_inductor/runtime/triton_heuristics.py (L1733-L1753)
If a kernel's y_grid is larger than 65535, it will be split into multiple z grids. The above grad_fn does this split before the kernel launch; however, the computations for yoffset and the y_grid are incorrect. For example, if we have xy numel of `(1*XBLOCK, 65537*YBLOCK)`, this function will return an [xyz]_grid with (1, 32768, 2). XBLOCK and YBLOCK here are used for the following `get_grid_dim`. Let's use their default values (4, 1024).
2f3d3ddd70/torch/_inductor/runtime/triton_heuristics.py (L1734)
[xyz]_grid = (1, 32768, 2) means the workload are divided to two z grids. Because the triton kernel generation still follows xy dimension, one of the exampled generated kernel is shown below.
```python
@triton.jit
def triton_(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 65537*1024
xnumel = 1*4
yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
x2 = xindex
y0 = yindex % 128
y1 = (yindex // 128)
y3 = yindex
tmp0 = tl.load(in_ptr0 + (y0 + (128*x2) + (512*y1)), xmask, eviction_policy='evict_last')
tl.store(out_ptr0 + (x2 + (4*y3)), tmp0, xmask)
```
For a trition block with xyz index (0, 0, 1), its yoffset and xoffset are both 0s based on the compuation `yoffset = tl.program_id(1) * (tl.program_id(2) + 1) * YBLOCK` and `xoffset = tl.program_id(0) * XBLOCK`. So, this triton block will access the very first elements of the input. However, the correct yoffset should be `(y_index + z_index * y_grid ) * YBLOCK` which is the starting position of the 2nd z grid.
At the same time, because we used `y_grid = y_grid // div` to compute the maximum number of element in y dimension, the y_grid is 32768. The total y grids is 32768*2 = 65536, which is less than the actual y grids 65537. So, we should use `y_grid = ceildiv(y_grid, div)` to compute the y grid to save the remaining grids.
#123210 is not about AOTInductor, the root cause is the triton kernel generated by torchinductor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127448
Approved by: https://github.com/eellison
For a masked `tl.load` operation, the Triton language specifies that values masked out (i.e. where the mask evaluates to false) are undefined in the output of the load. Triton provides an optional `other` parameter which, when included, provides an explicit value to use for masked out values from the load. If the output from a masked load without the `other` parameter is used in a conditional, unexpected behavior can occur.
Despite the language specification, all Triton backends currently in use by PyTorch Inductor (NVIDIA, AMD, and Intel) 0-initialize masked loads if `other` is not present (we recently changed the Intel backend behavior to match NVIDIA and AMD because that's what our users expect, even if we are not following the Triton spec to the tee). This PR attempts to "future-proof" Inductor for new backends (or perhaps changes in the current backends? - we did not see any performance change from 0-initializing in the Intel XPU backend but one could imagine compiler optimizations to remove paths that depend on undefined) to add an explicit `other` in instances where later conditionals depend on the `tl.load` output. I also removed an exception to `other` behavior for boolean loads, which was put in place for a Triton bug that should be fixed. I added `other` to the getting started documentation as a clue that masked load behavior requires explicit initialization if, even though I don't expect `undef` values to cause the example code to fail if the underlying output is not 0-initialized. Finally, I added other to the `make_load` function in `select_algorithm.py`, though I wasn't able to determine if that function was actually being called.
Fixes#126535
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127311
Approved by: https://github.com/jansel
Before, when calling ops.where, masks were not properly propagated. We
now restrict the optimisation to `ops.masked`, which I think it was what
the original code intended to do.
I'm not 100% sure that even in the masked case this code is not
introducing some bugs, but this is a strict improvement over the
previous state.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125574
Approved by: https://github.com/peterbell10
ghstack dependencies: #114471, #126783
This pass was broken in a number of ways, as we were not generating
asserts whenever we took it, even though we need to. While doing so,
we found that the analysis we were using for choosing
whether to generate asserts or not for dynamic shapes was completely
broken.
Eliminating indirect indexing in this way allows for a number of optimisations.
In particular, we can now fuse against these kernels (indirect indexing disallows fusions).
The new strategy is as follows:
- We always propagate sympy expressions if we can.
- If an expression was an indirect_indexing, we call `check_bounds`
- We also call `check_bounds` within `CSEProxy.indirect_indexing`
- The checks are issued in the buffer where they would go if the were used in a load
- This makes them always be codegen'd before the load and stores
- In the case of stores, they will be generated potentially much earlier than the stores themselves, which is fine.
We add quite a few asserts to preexisting tests to strengthen them. In particular, we make sure
that issuing an assert plays well with all kinds of C++ vectorisation.
For now, we rely on the logic within `_maybe_evaluate_static` to prove
these bounds. This logic is rather limited though. In the future, we might want
to rely on Z3 here to be able to prove bounds in a more general way.
Supersedes https://github.com/pytorch/pytorch/pull/113068
Fixes https://github.com/pytorch/pytorch/issues/121251
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114471
Approved by: https://github.com/peterbell10
Add `# mypy: disallow-untyped-defs` to scheduler.py and then fix the resulting fallout.
We probably should eventually add a new node between BaseSchedulerNode and all the non-FusedSchedulerNode types to indicate the split between nodes that have a valid `self.node` and ones that don't. That would cause a lot of the `assert self.node is not None` churn to go away - but was a bigger change because a lot of code makes assumptions about types that aren't reflected in the types themselves.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126656
Approved by: https://github.com/eellison