pytorch/torch/_inductor
Colin Peppler fe285b9560 [aoti] fix corner case in unbacked replacements for atomically_apply_size_hint (#153768)
## PR
There are a few cases that my previous PR (#153220) didn't cover.
1. The LHS/RHS order matters. Today, if you do `torch._check(lhs == rhs)`, it will show up as a deferred runtime assert `Eq(lhs, rhs)`, preserving which expression is on which side.
2. There can be transitive replacements, e.g. expr1 -> expr2 -> u0 (see the sketch after this list). `test_size_with_unbacked_add_expr_transitive` tests for this.
3. An unbacked symint expr may not have a replacement that's purely a symbol; for instance, the replacement could be another expression. `test_size_with_unbacked_add_and_mul_expr` tests for this.
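
To make case 2 concrete: a replacement chain has to be followed to a fixed point rather than applied once. Here is a minimal sketch of that idea (`resolve_transitive` is a hypothetical helper, not the actual `sizevars.py` code), with a cycle guard for safety:

```
import sympy

def resolve_transitive(expr, replacements):
    # Follow expr1 -> expr2 -> u0 style chains until no replacement applies.
    seen = {expr}
    while expr in replacements:
        expr = replacements[expr]
        if expr in seen:  # guard against accidental cycles
            break
        seen.add(expr)
    return expr

u0, u1, s0 = sympy.symbols("u0 u1 s0")
# Illustrative chain, loosely like what test_size_with_unbacked_add_expr_transitive covers.
replacements = {u1 + s0: u1 + 16, u1 + 16: u0}
assert resolve_transitive(u1 + s0, replacements) == u0
```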

## Device assertion msg

```
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
...
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
```

## Autotuning code setup
This is the autotuning code for a concat kernel, which takes input tensors (`in_buf`) and writes them to the output buffer (`out_buf`).

It's important to note that the sizes of `in_buf0` and `in_buf1` don't match along dim=0. This is bad because all concat inputs must share the same size along every dim except the concat dim (here, that's dim=1).
```
in_buf0 = generate_example_value(size=(u1 + s0, 256))   # concrete size is (17900, 256)
in_buf1 = generate_example_value(size=(u0, 10))         # concrete size is (8192, 10)
...
out_buf = generate_example_value(size=(u1 + s0, 266))   # concrete size is (17900, 256+10)
triton_poi_fused_cat_1.run(in_buf0, in_buf1, ..., out_buf, xnumel=(u1 + s0) * 266 ...)
```
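
Note that `u0` and `u1 + s0` are supposed to be equal here (that is exactly what the deferred runtime assert records), yet they received independent example sizes: 8192 vs. 17900. A rough sketch of the invariant `atomically_apply_size_hint` needs to maintain, assuming plain-integer hints and made-up concrete values for the split of 17900:

```
import sympy

u0, u1, s0 = sympy.symbols("u0 u1 s0")
replacements = {u0: u1 + s0}   # from a deferred runtime assert Eq(u0, u1 + s0)
hints = {u1: 9708, s0: 8192}   # illustrative values; only their sum matters

# Derive u0's hint from its replacement rather than assigning it an
# independent fallback hint (8192), which appears to be where in_buf1's
# example size came from.
hints[u0] = int(replacements[u0].subs(hints))
assert hints[u0] == hints[u1] + hints[s0] == 17900
```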

If you look at the kernel code below, you'll see that `tmp9` loads from `in_buf1` (our incorrectly shaped input tensor). There is also a mask to prevent OOB loads.
- `tmp6` makes sure we only load when the column index `x0` is in the range 256 to 264.
- `xmask` makes sure we only load when `xindex` is within `xnumel`.
- Together, `tmp6 & xmask` essentially check `0 ≤ x1 < u1 + s0` and `256 ≤ x0 < 264`.

The mask logic is correct; however, `in_buf1` has shape `[8192, 10]`, which means any load where `8192 ≤ x1 < u1 + s0` is an OOB load.
```
def triton_poi_fused_cat_1(in_buf0, in_buf1, ... out_buf, xnumel, XBLOCK):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    x0 = (xindex % 264)
    x1 = xindex // 264
    ...
    tmp6 = x0 >= tl.full([1], value=256)
    tmp9 = tl.load(in_buf1 + (x1), tmp6 & xmask)
    # tmp13 is computed in the elided code above; the device assertion is thrown here
    tl.device_assert(((0 <= tl.broadcast_to(tmp13, [XBLOCK])) & (tl.broadcast_to(tmp13, [XBLOCK]) < ks0)) | ~(xmask & tmp6), "index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0")
```
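
As a quick sanity check on the failure (plain arithmetic over the concrete sizes from the autotuning setup):

```
rows_expected = 17900   # u1 + s0: the row count the kernel assumes for every input
rows_actual = 8192      # u0: the actual row count of in_buf1
# Every x1 in [8192, 17900) passes tmp6 & xmask yet indexes past the end of
# in_buf1, which is what trips the device assert above.
print(rows_expected - rows_actual)   # 9708 rows can produce OOB loads
```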

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153768
Approved by: https://github.com/jingsh
2025-05-22 02:05:37 +00:00
autoheuristic
codegen Revert "cpp_wrapper: build non-performance-sensitive code at O1 (#148773)" 2025-05-22 00:11:14 +00:00
compile_worker Allow to set custom PYTHONPATH for torch.inductor (#152832) 2025-05-15 06:35:41 +00:00
fx_passes [map] add inductor support by lowering to while_loop (#150971) 2025-05-21 22:19:47 +00:00
kernel [ROCm][Inductor][CK] Add ck-tile based universal gemm kernels to torch.mm autotune choices (#152341) 2025-05-21 23:59:16 +00:00
package [export] Move PT2 constants to torch::_export (#153206) 2025-05-17 08:21:59 +00:00
runtime [MegaCache] Make MegaCache generic to allow external plugins registration (#152977) 2025-05-21 18:18:47 +00:00
__autotune_main__.py Improve subproc autotuning implementation (#149700) 2025-03-28 01:06:39 +00:00
__init__.py Add optional device index to AOTIModelPackageLoader (#152093) 2025-05-04 11:40:12 +00:00
analyze_preserves_zero_mask.py Revert two recent prologue prs (#151013) 2025-04-10 23:48:41 +00:00
aoti_eager.py
async_compile.py Pass inductor config for static cuda launcher to workers (#153382) 2025-05-14 20:01:32 +00:00
autotune_process.py [inductor][cutlass backend] Add 2 stage autotuning aka prescreening (#153335) 2025-05-21 17:12:05 +00:00
bounds.py [inductor] Refactor op handlers part 5 (#146257) 2025-02-08 18:00:30 +00:00
choices.py Reland "Introduce new template heuristic for triton autotune configs" (#147452) 2025-03-26 15:47:06 +00:00
codecache.py Revert "cpp_wrapper: build non-performance-sensitive code at O1 (#148773)" 2025-05-22 00:11:14 +00:00
comm_analysis.py
comm_lowering.py Fix an issue where functional collectives don't force fx stride on inputs when compiled (#146467) 2025-02-10 19:15:49 +00:00
comms.py Make assertion about pass callable print the bad pass (#152654) 2025-05-05 18:07:43 +00:00
compile_fx_async.py Use correct boxed_forward_device_index when running CompiledFxGraph.post_compile (#148130) 2025-03-23 02:57:58 +00:00
compile_fx_ext.py [NFC] [inductor] [compile async] Warn exception if pickler failed (#152401) 2025-05-06 07:12:35 +00:00
compile_fx_subproc.py async fx compile (#146135) 2025-03-19 14:07:51 +00:00
compile_fx.py Add flag _metrics_log_runtime to disable runtime metric logging by default (#153506) 2025-05-22 01:02:11 +00:00
compiler_bisector.py Add a couple config options to compiler bisector (#148450) 2025-03-04 23:23:21 +00:00
config.py [ROCm][Inductor][CK] Add ck-tile based universal gemm kernels to torch.mm autotune choices (#152341) 2025-05-21 23:59:16 +00:00
constant_folding.py Fix constant folding cloning constants (#152273) 2025-05-01 17:34:39 +00:00
cpp_builder.py [AOTI][reland] Add an option to specify custom op C shim (#153968) 2025-05-21 15:57:57 +00:00
cpu_vec_isa.py Allow to set custom PYTHONPATH for torch.inductor (#152832) 2025-05-15 06:35:41 +00:00
cudagraph_trees.py [BE]: Update ruff to 0.11.8 (#153249) 2025-05-12 18:30:52 +00:00
cudagraph_utils.py [CUDAGraph] support meta tensor (#150478) 2025-04-02 07:21:50 +00:00
custom_graph_pass.py
debug.py Revert "[BE]: Enable RUFF TRY400 rule - log.exception (#153473)" 2025-05-16 08:29:26 +00:00
decomposition.py Revert "Improve torch.ops typing (#153558)" 2025-05-19 23:32:36 +00:00
dependencies.py [Graph Partition] Support symbol inputs (#149458) 2025-03-26 17:21:30 +00:00
dtype_propagation.py Remove libdevice ops in inductor (#151562) 2025-04-17 22:18:00 +00:00
exc.py
extern_node_serializer.py Back out "[AOTI] Always use oss schema for ExternKernelNodes serialization" (#151026) 2025-04-10 22:36:35 +00:00
freezing_utils.py PEP585: More UP006 fixes (#146392) 2025-02-20 06:18:13 +00:00
freezing.py [cudagraphs] Fix issue in collecting static_input_idxs (#152287) 2025-04-30 03:24:05 +00:00
fuzzer.py [AOTI][reland] Add an option to specify custom op C shim (#153968) 2025-05-21 15:57:57 +00:00
fx_utils.py Scheduler Flops refactor (#152708) 2025-05-09 19:01:43 +00:00
graph.py Revert "cpp_wrapper: build non-performance-sensitive code at O1 (#148773)" 2025-05-22 00:11:14 +00:00
hooks.py
index_propagation.py [BE][PYFMT] migrate PYFMT for torch._inductor to ruff format (#144550) 2025-02-28 13:33:19 +00:00
inductor_prims.py [inductor] Lowerings for max_pool3d (#148210) 2025-04-02 14:13:01 +00:00
ir.py [AOTI][reland] Add an option to specify custom op C shim (#153968) 2025-05-21 15:57:57 +00:00
jagged_lowerings.py
loop_body.py [ez] fix typo in comment (#151755) 2025-04-21 14:52:39 +00:00
lowering.py [Intel GPU][Inductor] Fallback embedding_dense_backward on XPU (#151637) 2025-05-19 02:19:37 +00:00
memory.py [Graph Partition] reorder for minimal number of partitions (#151968) 2025-04-29 17:17:16 +00:00
metrics.py [Inductor] Support parallel reduction for GroupNorm (#144020) 2025-03-01 17:11:50 +00:00
mkldnn_ir.py [Quant][PT2E][X86] enable qconv1d-relu fusion (#150751) 2025-04-09 14:42:02 +00:00
mkldnn_lowerings.py [BE]: Update ruff to 0.11.8 (#153249) 2025-05-12 18:30:52 +00:00
mock_cache.py
ops_handler.py Remove libdevice ops in inductor (#151562) 2025-04-17 22:18:00 +00:00
optimize_indexing.py
output_code.py codecache: Remove cpp_prefix.h duplication per build, then precompile it (#144293) 2025-05-16 17:41:36 +00:00
pattern_matcher.py Rename node.meta["arg_kwarg_vals"] to node.meta["eager_input_vals"] (#148092) 2025-04-02 13:18:04 +00:00
quantized_lowerings.py Add AOTI shim for _weight_int4pack_mm_cpu_tensor (#149031) 2025-03-18 01:33:13 +00:00
remote_cache.py [Inductor Remote Cache] Raise an exception if redis module is required but not available (#151779) 2025-04-26 11:21:54 +00:00
scheduler.py [Easy][Inductor] Adds safety checks in get_estimated_runtime (#152821) 2025-05-14 21:46:59 +00:00
script.ld
select_algorithm.py Cache code generation during triton template expansion and enable it for mm_template. (#151773) 2025-05-21 18:55:41 +00:00
sizevars.py [aoti] fix corner case in unbacked replacements for atomically_apply_size_hint (#153768) 2025-05-22 02:05:37 +00:00
standalone_compile.py Add logging for guard miss failure (#153125) 2025-05-09 16:51:04 +00:00
subgraph_lowering.py [inductor] Refactor op handlers part 5 (#146257) 2025-02-08 18:00:30 +00:00
template_heuristics.py [Inductor] Add Additional Configs for persistent+TMA version of Triton mm and addmm (#150587) 2025-04-23 18:21:35 +00:00
test_case.py
test_operators.py [CI] Fix GPUTests.test_scheduler_vertical_fusion1 (#151166) 2025-04-13 00:41:51 +00:00
triton_bundler.py Keep raw cubin file around in case it gets deleted underneath us (#153064) 2025-05-08 14:29:19 +00:00
utils.py [aoti] Initial Metal support (#153959) 2025-05-21 21:55:59 +00:00
virtualized.py [inductor] Add a helper for convert index_dtype to torch dtype (#149531) 2025-03-20 21:33:29 +00:00
wrapper_benchmark.py [Inductor][NCU] Add kernel name filtering, and allow custom metrics (#150872) 2025-05-04 20:49:19 +00:00