## PR

There are a few cases that my previous PR (#153220) didn't cover.

1. The LHS/RHS matters. Today, if you do `torch._check(lhs == rhs)`, it will show up as a deferred runtime assert with `Eq(lhs, rhs)` (a minimal sketch of this pattern appears at the end of this description).
2. There can be transitive replacements, for example, expr1 -> expr2 -> u0. `test_size_with_unbacked_add_expr_transitive` tests for this.
3. An unbacked symint expr may not have a replacement that is purely a symbol; for instance, it could be another expression. `test_size_with_unbacked_add_and_mul_expr` tests for this.

## Device assertion msg

```
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [4,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
...
/tmp/tmp07mu50tx/6y/c6ym2jzadwfigu3yexredb7qofviusz3p7ozcdjywvayhxgcqxkp.py:40: unknown: block: [8681,0,0], thread: [6,0,0] Assertion `index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0` failed.
```

## Autotuning code setup

This is the autotuning code for a concat kernel, which takes input tensors (`in_buf`) and writes them to the output tensor (`out_buf`). It's important to note that the sizes of `in_buf0` and `in_buf1` don't match along dim=0. This is bad because all concat inputs must share the same size for every dim except the concat dim (here that's dim=1).

```
in_buf0 = generate_example_value(size=(u1 + s0, 256)) # concrete size is (17900, 256)
in_buf1 = generate_example_value(size=(u0, 10)) # concrete size is (8192, 10)
...
out_buf = generate_example_value(size=(u1 + s0, 266)) # concrete size is (17900, 256+10)
triton_poi_fused_cat_1.run(in_buf0, in_buf1, ..., out_buf, xnumel=(u1 + s0) * 266 ...)
```

If we look into the kernel code, we'll see that `tmp9` loads from `in_buf1` (our incorrectly shaped input tensor). There is also a mask to prevent OOB loads.

- `tmp6` makes sure we only load elements whose column index `x0` falls in the range from 256 to 264 (the columns that come from `in_buf1`).
- `xmask` makes sure we only load elements whose `xindex` is within `xnumel`.
- `tmp6 & xmask` together is essentially checking `0 ≤ x1 < u1 + s0` and `256 ≤ x0 < 264`.

The mask logic is correct; however, `in_buf1` has the shape `[8192, 10]`, which means any load where `8192 ≤ x1 < u1 + s0` will be an OOB load (the index math is sketched in plain Python at the end of this description).

```
def triton_poi_fused_cat_1(in_buf0, in_buf1, ... out_buf, xnumel, XBLOCK):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    x0 = (xindex % 264)
    x1 = xindex // 264
    ...
    tmp6 = x0 >= tl.full([1], value=256)
    tmp9 = tl.load(in_buf1 + (x1), tmp6 & xmask)  # device assertion is thrown here
    tl.device_assert(((0 <= tl.broadcast_to(tmp13, [XBLOCK])) & (tl.broadcast_to(tmp13, [XBLOCK]) < ks0)) | ~(xmask & tmp6), "index out of bounds: 0 <= tl.broadcast_to(tmp13, [XBLOCK]) < ks0")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153768
Approved by: https://github.com/jingsh
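
As referenced in the mask discussion above, here is a plain-Python sketch (not part of the PR or of the generated kernel) of how the flat `xindex` decomposes into a column `x0` and a row `x1`, and why the masked `in_buf1` load still runs out of bounds once the row index reaches 8192. The names `row_count`, `col_count`, `in_buf1_rows`, and `is_oob_load` are illustrative only.

```
# Plain-Python sketch of the kernel's index math; names are illustrative only.
row_count = 17900      # u1 + s0 (rows of out_buf)
col_count = 264        # columns indexed by this kernel (x0 = xindex % 264)
in_buf1_rows = 8192    # u0, the actual number of rows of the mis-sized in_buf1

def is_oob_load(xindex):
    x0 = xindex % col_count            # column
    x1 = xindex // col_count           # row
    tmp6 = x0 >= 256                   # this element is filled from in_buf1
    xmask = xindex < row_count * col_count
    # The load is masked by tmp6 & xmask, but neither bounds x1 by in_buf1's row count.
    return tmp6 and xmask and x1 >= in_buf1_rows

print(is_oob_load(8191 * col_count + 256))  # False: row 8191 still fits inside in_buf1
print(is_oob_load(8192 * col_count + 256))  # True: row 8192 reads past the end of in_buf1
```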
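
For reference, eager `torch.cat` enforces the invariant that the autotuning example inputs violate: all inputs must match on every dim except the concat dim. A minimal sketch with the concrete sizes from the setup above (the tensors `a` and `b` are stand-ins, not the PR's code):

```
import torch

a = torch.randn(17900, 256)  # stands in for in_buf0: (u1 + s0, 256)
b = torch.randn(8192, 10)    # stands in for in_buf1: (u0, 10) -- wrong row count

try:
    torch.cat([a, b], dim=1)
except RuntimeError as e:
    # All inputs must match on every dim except the concat dim (dim=1 here).
    print(e)
```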
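
Finally, a minimal sketch (not the PR's actual tests) of the `torch._check(lhs == rhs)` pattern from the list at the top: `.item()` produces unbacked symints, and equating the resulting expression with a backed size can only be recorded as a deferred runtime assert such as `Eq(u0 + u1, s0)`. This assumes `torch._dynamo.config.capture_scalar_outputs` is enabled; the function and variable names are illustrative.

```
import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True, dynamic=True)
def f(x, n0, n1):
    u0 = n0.item()             # unbacked symint u0
    u1 = n1.item()             # unbacked symint u1
    torch._check_is_size(u0)
    torch._check_is_size(u1)
    # The equality can't be proven at compile time, so it is recorded as a
    # deferred runtime assert relating the unbacked expr u0 + u1 to the backed size s0.
    torch._check(u0 + u1 == x.size(0))
    return x.new_zeros(u0 + u1)

out = f(torch.randn(10), torch.tensor(4), torch.tensor(6))
print(out.shape)  # torch.Size([10])
```

With the deferred runtime assert in place, the shape env can treat `u0 + u1` and `x.size(0)` as equal when sizing downstream buffers, which is the kind of replacement this PR extends to cover LHS/RHS ordering, transitive chains, and non-symbol replacements.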