pytorch/torch
Colin Peppler fe285b9560 [aoti] fix corner case in unbacked replacements for atomically_apply_size_hint (#153768)
## PR
There are a few cases that my previous PR (#153220) didn't cover.
1. The LHS/RHS order matters. Today, if you do `torch._check(lhs == rhs)`, it will show up as a deferred runtime assert with `Eq(lhs, rhs)` (see the sketches after this list).
2. There can be transitive replacements. For example, expr1 -> expr2 -> u0. `test_size_with_unbacked_add_expr_transitive` tests for this.
3. An unbacked symint expr may not have a replacement that's purely a symbol; the replacement can itself be another expression. `test_size_with_unbacked_add_and_mul_expr` tests for this.
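
Here is a minimal sketch of case (1), not taken from the PR's tests; the function, its inputs, and the config flag are illustrative assumptions. It shows one way an unbacked size expression like `u0 + s0` and a deferred runtime assert can arise:
```
import torch

# Let n.item() become an unbacked symint (u0) under torch.compile.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True, dynamic=True)
def f(x, n):
    u0 = n.item()                      # unbacked symint u0
    y = x.new_zeros(u0 + x.shape[0])   # size expression u0 + s0
    # torch._check(lhs == rhs) is recorded as a deferred runtime assert Eq(lhs, rhs),
    # so which expression ends up on the LHS vs. RHS matters (case 1).
    torch._check(y.shape[0] == u0 + x.shape[0])
    return torch.cat([x, y])

f(torch.ones(8), torch.tensor(4))
```
And a toy illustration of case (2), following a transitive replacement chain expr1 -> expr2 -> u0 to a fixed point (this is not the PR's implementation; `resolve` and the replacement table are hypothetical):
```
import sympy

u0, u1, s0 = sympy.symbols("u0 u1 s0")
expr2 = u1 + s0
expr1 = 2 * expr2

# Hypothetical replacement table: expr1 -> expr2 -> u0.
replacements = {expr1: expr2, expr2: u0}

def resolve(expr, replacements, max_iters=10):
    # Follow replacements until the expression stops changing.
    for _ in range(max_iters):
        nxt = replacements.get(expr, expr)
        if nxt == expr:
            break
        expr = nxt
    return expr

print(resolve(expr1, replacements))  # u0
```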

## Device assertion msg

```
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
...
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
```

## Autotuning code setup
This is the autotuning code for a concat kernel, which takes input tensors (`in_buf0`, `in_buf1`, ...) and writes them into the output tensor (`out_buf`).

It's important to note that the sizes of `in_buf0` and `in_buf1` don't match along dim=0. This is bad because all concat inputs must share the same size on every dim except the concat dim (here that's dim=1).
```
in_buf0 = generate_example_value(size=(u1 + s0, 256))   # concrete size is (17900, 256)
in_buf1 = generate_example_value(size=(u0, 10))         # concrete size is (8192, 10)
...
out_buf = generate_example_value(size=(u1 + s0, 266))   # concrete size is (17900, 256+10)
triton_poi_fused_cat_1.run(in_buf0, in_buf1, ..., out_buf, xnumel=(u1 + s0) * 266 ...)
```
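
For reference, eager `torch.cat` rejects exactly this kind of dim=0 mismatch up front. Here's a minimal eager-mode sketch using the concrete sizes from above (the tensors are plain stand-ins for the generated example buffers, not the actual autotuning code):
```
import torch

a = torch.zeros(17900, 256)  # stands in for in_buf0
b = torch.zeros(8192, 10)    # stands in for in_buf1, mismatched along dim=0
# Raises a RuntimeError: concat inputs must match on every dim except the concat dim (dim=1 here).
torch.cat([a, b], dim=1)
```
The autotuning path, by contrast, runs the generated kernel directly on these inconsistently sized example buffers, which is what ultimately surfaces as the device-side assertion above.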

Looking at the kernel code below, you'll see that `tmp9` loads from `in_buf1` (our incorrectly shaped input tensor). There is also a mask to prevent OOB loads.
- `tmp6` makes sure we're only loading when the column index `x0` is between 256 and 264.
- `xmask` makes sure we're only loading when `xindex` is within `xnumel`.
- `tmp6 & xmask` together is essentially checking `0 ≤ x1 < u1 + s0` and `256 ≤ x0 < 264`.

The mask logic is correct; however, `in_buf1` has the shape `[8192, 10]`, which means any load where `8192 ≤ x1 < u1 + s0` will be an OOB load.
```
def triton_poi_fused_cat_1(in_buf0, in_buf1, ... out_buf, xnumel, XBLOCK):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel                       # stay within the flattened output
    x0 = (xindex % 264)                           # fast-moving (column) index
    x1 = xindex // 264                            # slow-moving (row) index
    ...
    tmp6 = x0 >= tl.full([1], value=256)          # columns written from in_buf1
    tmp9 = tl.load(in_buf1 + (x1), tmp6 & xmask)  # OOB once the row index x1 reaches 8192
    # device assertion is thrown here
    tl.device_assert(((0 <= tl.broadcast_to(tmp13, [XBLOCK])) & (tl.broadcast_to(tmp13, [XBLOCK]) < ks0)) | ~(xmask & tmp6), "index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153768
Approved by: https://github.com/jingsh
2025-05-22 02:05:37 +00:00
_awaits
_C [aoti] Add MPS runner and shim (#153964) 2025-05-21 21:55:59 +00:00
_C_flatbuffer
_custom_op
_decomp Fix torch.isin decomposition for scalar inputs (#153216) 2025-05-09 20:26:25 +00:00
_dispatch
_dynamo [dynamo] renamed _fn for more clarity and put a comment of user compiler user (#154026) 2025-05-21 21:12:51 +00:00
_export [aoti] Initial Metal support (#153959) 2025-05-21 21:55:59 +00:00
_functorch [aot] fix deepcopying of aot bwd containing real tensors (#153999) 2025-05-21 23:30:02 +00:00
_higher_order_ops [map] add inductor support by lowering to while_loop (#150971) 2025-05-21 22:19:47 +00:00
_inductor [aoti] fix corner case in unbacked replacements for atomically_apply_size_hint (#153768) 2025-05-22 02:05:37 +00:00
_lazy
_library Add torch._C.Tag.needs_contiguous_strides (#152859) 2025-05-08 04:49:59 +00:00
_logging Add flag _metrics_log_runtime to disable runtime metric logging by default (#153506) 2025-05-22 01:02:11 +00:00
_numpy Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
_prims
_prims_common [Intel GPU][Inductor] Fallback embedding_dense_backward on XPU (#151637) 2025-05-19 02:19:37 +00:00
_refs Treat dim=[] same as dim=None (#153570) 2025-05-20 22:44:29 +00:00
_strobelight
_subclasses Revert "Fix fake tensor caching when output has unbacked (#153034)" 2025-05-20 06:02:38 +00:00
_vendor
accelerator Add torch.accelerator.device_index as accelerator's device switch context (#148864) 2025-04-25 09:45:25 +00:00
amp [Intel GPU] skip a cuda api call in amp to save some host overhead on xpu (#151111) 2025-04-13 06:37:07 +00:00
ao [BE]: Type previously untyped decorators (#153726) 2025-05-21 15:56:19 +00:00
autograd Add memory reporting for XPU to Memory Profiler (#152842) 2025-05-21 01:19:19 +00:00
backends Revert "refine fp32 precision api (#125888)" 2025-05-11 00:35:46 +00:00
compiler [MegaCache] Make MegaCache generic to allow external plugins registration (#152977) 2025-05-21 18:18:47 +00:00
contrib
cpu [device_mesh] improve device selection logic (#150897) 2025-05-14 06:29:16 +00:00
csrc [c10d] Turn off default non-blocking API mode to work around hang in NCCL 2.26 (#154055) 2025-05-21 23:46:52 +00:00
cuda make use_mem_pool threadlocal (#153356) 2025-05-13 00:16:07 +00:00
distributed Render Example: and not Example:: in docs (#153978) 2025-05-21 01:03:26 +00:00
distributions Fix support of MixtureSameFamily [bugfix]. (#151317) 2025-05-14 19:24:36 +00:00
export [export] Move PT2 constants to torch::_export (#153206) 2025-05-17 08:21:59 +00:00
fft
func
futures Render Example: and not Example:: in docs (#153978) 2025-05-21 01:03:26 +00:00
fx auto functionalize base_hop (#151067) 2025-05-21 18:55:46 +00:00
jit Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
legacy
lib [1/N] Use internal linkage in torch/csrc C++ files. (#150930) 2025-04-11 02:19:31 +00:00
linalg Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
masked [BE]: Update ruff to 0.11.8 (#153249) 2025-05-12 18:30:52 +00:00
monitor
mps
mtia [Kineto] Enable OOM observer (#152160) 2025-04-27 15:56:44 +00:00
multiprocessing
nativert [nativert] Move GraphSignature to pytorch core (#152969) 2025-05-20 21:49:56 +00:00
nested [Torch][NT] Fix NestedTensor contiguous check condition. (#153237) (#153529) 2025-05-14 17:15:48 +00:00
nn docs: fix "should not to be" typo in register_buffer docstring (#153817) 2025-05-21 22:46:50 +00:00
onnx [ONNX] Support float4 (#151069) 2025-05-18 03:19:35 +00:00
optim Add load_state_dict hint doc about invoke order work with lr_scheduler (#149942) 2025-05-15 01:07:36 +00:00
package [BE]: Enable ruff YTT linter for Python version checks (#153547) 2025-05-14 21:09:16 +00:00
profiler [profiler][retry] don't disable CUPTI_LAZY_REINIT for cuda >= 12.6 (#151124) 2025-04-15 16:11:49 +00:00
quantization
signal
sparse Revert "has_triton: Use the device interface for detecting Triton availability (#139171)" 2025-05-10 14:46:23 +00:00
special Render Example: and not Example:: in docs (#153978) 2025-05-21 01:03:26 +00:00
testing Treat dim=[] same as dim=None (#153570) 2025-05-20 22:44:29 +00:00
utils [ROCm] improve sparse addmm, enable complex (#153262) 2025-05-19 22:23:18 +00:00
xpu Correct torch.xpu.is_bf16_supported return False if no XPU detected (#152317) 2025-05-06 10:03:17 +00:00
__config__.py
__future__.py
__init__.py [BE] Improve the typing related to model input argument of torch.compile() (#153559) 2025-05-15 04:49:26 +00:00
_appdirs.py Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
_classes.py
_compile.py
_custom_ops.py Render Example: and not Example:: in docs (#153978) 2025-05-21 01:03:26 +00:00
_deploy.py
_environment.py
_guards.py Revert "[BE]: Enable RUFF TRY400 rule - log.exception (#153473)" 2025-05-16 08:29:26 +00:00
_jit_internal.py [BE]: Type previously untyped decorators (#153726) 2025-05-21 15:56:19 +00:00
_linalg_utils.py
_lobpcg.py Fixed rerr computation in lobpcg (#152789) 2025-05-08 12:22:31 +00:00
_lowrank.py
_meta_registrations.py [Intel GPU][Inductor] Fallback embedding_dense_backward on XPU (#151637) 2025-05-19 02:19:37 +00:00
_namedtensor_internals.py
_ops.py Revert "Improve torch.ops typing (#153558)" 2025-05-19 23:32:36 +00:00
_python_dispatcher.py
_size_docs.py Render Example: and not Example:: in docs (#153978) 2025-05-21 01:03:26 +00:00
_sources.py
_storage_docs.py
_streambase.py
_tensor_docs.py Render Example: and not Example:: in docs (#153978) 2025-05-21 01:03:26 +00:00
_tensor_str.py
_tensor.py Avoid triggering ignored requires_grad warning in our code (#152686) 2025-05-05 23:56:40 +00:00
_thread_safe_fork.py
_torch_docs.py Render Example: and not Example:: in docs (#153978) 2025-05-21 01:03:26 +00:00
_utils_internal.py [reland] Add graph module runtime asserts to AOTI (#153182) 2025-05-09 22:56:19 +00:00
_utils.py
_VF.py
_vmap_internals.py Fix broken URLs (#152237) 2025-04-27 09:56:42 +00:00
_weights_only_unpickler.py
CMakeLists.txt Refactor torch/utils/data/datapipes/gen_pyi.py with torchgen (#150626) 2025-05-17 06:21:41 +00:00
custom_class_detail.h
custom_class.h
extension.h
functional.py Optimize cdist param description (#151178) 2025-04-14 13:53:10 +00:00
header_only_apis.txt Add torch/header_only_apis.txt and enforce they're tested (#153635) 2025-05-20 23:42:24 +00:00
hub.py
library.h Overload Library::def rather than templating it (#151626) 2025-04-18 22:51:16 +00:00
library.py Render Example: and not Example:: in docs (#153978) 2025-05-21 01:03:26 +00:00
overrides.py [CUDA][cuBLAS] Aten GEMM overload for FP32 output from FP16/BF16 inputs (#150812) 2025-04-18 01:53:26 +00:00
py.typed
quasirandom.py
random.py Update description for torch.random.fork_rng (#151881) 2025-04-23 16:59:29 +00:00
return_types.py
script.h
serialization.py Update serialization docs (#153631) 2025-05-19 20:22:07 +00:00
storage.py
torch_version.py
types.py
version.py.tpl