This is enough to get @XilunWu 's stack in a state where his flex_attention DTensor implementations worked E2E for me. It also required these changes on the DTensor side, to properly add a DTensor rule for flex backward: P1789852198
There are two problems:
(1) in the normal dispatcher, we have a precedence ordering between modes and subclasses. Modes are dispatched to first, but modes are allowed to return NotImplemented, giving subclasses a chance to run.
This normally happens automatically in `FakeTensorMode.__torch_dispatch__` and `FunctionalTensorMode.__torch_dispatch__`. However, since HOPs implement these two modes themselves, HOPs do not get this benefit. For now, I ended up hardcoding this `NotImplemented` logic directly into the functional/fake rules for flex attention.
Having to do this for every HOP seems a bit painful. If we could plumb every HOP through `Fake[|Functional]TensorMode.__torch_dispatch__` then we would get this support. Another option could be to just assume that most HOP <> mode implementations want the same treatment by default, and hardcode this `NotImplemented` logic into `torch/_ops.py`. I'm not sure if we'd need a way for the HOP to opt out of this though.
(2) We were hardcoding a call to flex attention's fake implementation in dynamo to run fake prop. This is technically wrong for subclasses, because it doesn't give subclasses the chance to interpose on the op and desugar it before fake prop runs. I tweaked dynamo's logic to call the op, and let the dispatcher handle invoking the fake implementation.
**Testing** Xilun is adding some DTensor tests in his PR that will end up testing this logic. If folks would prefer, though, I can try to add a test that uses another subclass instead that is maybe more basic.
This is the tlparse that his DTensor test gnerated for me: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/hirsheybar/0196c1d3-a9a2-46ea-a46d-aa21618aa060/custom/rank_0/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151719
Approved by: https://github.com/ydwu4
Co-authored-by: drisspg <drisspguessous@gmail.com>
https://github.com/pytorch/pytorch/pull/152708 expanded support of `get_estimated_runtime` to many more types of `SchedulerNodes`. This caused an increase in compile time because we're always calling `get_estimated_runtime` to populate the metrics table. This PR adds a flag for this logging, which reduces the instruction count by 8%. Long term, we should probably merge metrics.py with TORCH_LOGS/tlparse (suggestion from @xmfan).
Update: added support for TORCH_LOGS for the metrics logging.
Test Plan:
mm_loop.py and many existing tests cover.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153506
Approved by: https://github.com/eellison
Fixes#147336
## Context
NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM.
Bringing this to the attention of triton developer @davidberard98 he identified the memory layout of the tensor in HBM to be causing non-pipelined loads into SRAM, causing the slowdown.
To summarize:
In flex attention when performing the FP8 GEMM `softmax_scores @ V` the right operand V must be in column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined if the V tensor isn't column-major in HBM already, leading to substantial performance degradation.
This is because triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see [here](81f93f2c8e/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp (L403))).
i.e., when loading 4 bytes of contiguous data from a tensor stored in row-major in HBM, we have to perform 4 separate non-contiguous writes to SRAM to place those bytes in their new location in the col-major layout in SRAM. Thus the load is not a candidate for pipelining w/ cp.async and just moves data to registers then performs a series of single byte stores.
## Fix summary
- To fix this, we should enforce memory layouts for Q, K, V in FlexAttention when fp8 is being used, to ensure they each exist in HBM in the necessary memory layout to facilitate pipelined loads into SRAM ahead of the FP8 GEMMs
## Benchmarks
Rerunning the repro we see fp8 runtime is reduced from 120% of bf16 to 76% of bf16 runtime.
Before fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us
```
After fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153357
Approved by: https://github.com/ngimel, https://github.com/davidberard98
Fixes#147336
## Context
NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM.
Bringing this to the attention of triton developer @davidberard98 he identified the memory layout of the tensor in HBM to be causing non-pipelined loads into SRAM, causing the slowdown.
To summarize:
In flex attention when performing the FP8 GEMM `softmax_scores @ V` the right operand V must be in column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined if the V tensor isn't column-major in HBM already, leading to substantial performance degradation.
This is because triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see [here](81f93f2c8e/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp (L403))).
i.e., when loading 4 bytes of contiguous data from a tensor stored in row-major in HBM, we have to perform 4 separate non-contiguous writes to SRAM to place those bytes in their new location in the col-major layout in SRAM. Thus the load is not a candidate for pipelining w/ cp.async and just moves data to registers then performs a series of single byte stores.
## Fix summary
- To fix this, we should enforce memory layouts for Q, K, V in FlexAttention when fp8 is being used, to ensure they each exist in HBM in the necessary memory layout to facilitate pipelined loads into SRAM ahead of the FP8 GEMMs
## Benchmarks
Rerunning the repro we see fp8 runtime is reduced from 120% of bf16 to 76% of bf16 runtime.
Before fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us
```
After fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153357
Approved by: https://github.com/ngimel, https://github.com/davidberard98
Flex Attention may have symints in subgraph inputs and outputs. Existing code implicitly captures these symints but does not explicitly store it in TritonTemplateBuffer. This leads to error when analyzing symints used in Flex Attention as a TritonTemplateBuffer. This PR fixes the issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152878
Approved by: https://github.com/drisspg
**Summary**
Fixes [#151290](https://github.com/pytorch/pytorch/issues/151290) and [#151523](https://github.com/pytorch/pytorch/issues/151523), which are regressions introduced by [#144020](https://github.com/pytorch/pytorch/pull/144020). That PR enabled parallelization at the inner loop level.
However, a currently unsupported case arises when parallel reduction occurs under the vectorization loop level, specifically in patterns like:
```
for vec_loop_level:
do_parallel_reduction
```
In such cases, a temporary buffer `tmp_acc_array` is allocated for tail scalar kernels, and another temporary buffer `tmp_acc_array` is also defined for parallel reduction. This results in a conflict due to overlapping temporary buffers. This PR disables the problematic case to avoid the conflict until proper support is implemented.
**Test Plan**
```
python test/inductor/test_flex_attention.py -k test_make_block_mask_cpu
python test/inductor/test_cpu_repro.py -k test_parallel_reduction_vectorization
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151887
Approved by: https://github.com/jansel
Triton codegen currently [sorts vars by divisor](ae6f6b8efb/torch/_inductor/codegen/simd.py (L233-L237)). When there are two vars with the same divisor, the order is undecided.
```python
nodes.sort(
key=lambda x: V.graph.sizevars.size_hint(
x.divisor, fallback=config.unbacked_symint_fallback
)
)
```
The test case leads to the following nodes:
```
(Pdb) nodes[0]
IterationRangesEntry(x1, ((s37 + 127)//128), 2, (xindex//ps0), {x0: ((s37 + 127)//128), x1: 2, x2: ((s12 + 127)//128), x4: 2*(((s12 + 127)//128))*(((s37 + 127)//128)), x5: 0, x6: 2, x7: (((s12 + 127)//128))*(((s37 + 127)//128))})
(Pdb) nodes[1]
IterationRangesEntry(x0, 1, ((s37 + 127)//128), ModularIndexing(xindex, 1, ps0), {x0: ((s37 + 127)//128), x1: 2, x2: ((s12 + 127)//128), x4: 2*(((s12 + 127)//128))*(((s37 + 127)//128)), x5: 0, x6: 2, x7: (((s12 + 127)//128))*(((s37 + 127)//128))})
(Pdb) nodes[2]
IterationRangesEntry(x2, 2*(((s37 + 127)//128)), ((s12 + 127)//128), (xindex//(2*(((s37 + 127)//128)))), {x0: ((s37 + 127)//128), x1: 2, x2: ((s12 + 127)//128), x4: 2*(((s12 + 127)//128))*(((s37 + 127)//128)), x5: 0, x6: 2, x7: (((s12 + 127)//128))*(((s37 + 127)//128))})
(Pdb) V.graph.sizevars.statically_known_equals(nodes[0].length, 2)
True
(Pdb) V.graph.sizevars.statically_known_equals(nodes[1].length, 1)
True
(Pdb) V.graph.sizevars.statically_known_equals(nodes[2].length, 1)
True
(Pdb) V.graph.sizevars.statically_known_equals(nodes[0].divisor, 1)
True
(Pdb) V.graph.sizevars.statically_known_equals(nodes[1].divisor, 1)
True
(Pdb) V.graph.sizevars.statically_known_equals(nodes[2].divisor, 2)
True
```
Since x1 and x0 both have divisor 1, the relative order is random across runs.
In some runs, we have order [x1, x0, x2] with divisors as [1,1,2] and lengths as [2,1,1]. After x1, we have [divisor = divisor * node.length](ae6f6b8efb/torch/_inductor/codegen/simd.py (L246)) = 1 * 2 = 2. Then, when processing x0, we have node.divisor=1, divisor=2, and [FloorDiv(node.divisor, divisor)](ae6f6b8efb/torch/_inductor/codegen/simd.py (L251)) = 0, which indicates an iteration length of 0 and leads errors later.
The fix is to sort by both divisor and length_is_one. So for two nodes with the same divisor, we process the node with length=1 first.
Fixes#149789
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151634
Approved by: https://github.com/Skylion007, https://github.com/drisspg
Finishes up the work started in #121686 + adds test
Update: this was not as straightforward as I originally imagined. Context below.
**TL;DR:** `TestFoo{CPU, CUDA}` now actually derive from `TestFoo`! Also, `{CPU, CUDA}TestBase` setup / teardown logic is now always called (it is required to set the primary device), regardless of whether `super().setUpClass()` / `super().tearDownClass()` are called or not.
**Background:** The typical way to get device-specific tests is to write a generic `TestFoo` and call `instantiate_device_type_tests(TestFoo, locals())` to get `TestFooCPU`, `TestFooCUDA`, etc. After this, generic tests (e.g. `TestFoo.test_bar()`) become `TestFooCPU.test_bar_cpu()` / `TestFooCUDA.test_bar_cuda()`.
Behind the scenes, this was historically accomplished by creating a `TestFooCUDA` that derives from both a `CUDATestBase` and an *empty class* called `TestFoo_base`. This `TestFoo_base` has the same bases as `TestFoo`, but none of the test functions (e.g. `test_bar()`). The documented reason for this is to avoid things like a derived `TestFooCUDA.test_bar()` being discovered in addition to the real device-specific test `TestFooCUDA.test_bar_cuda()`.
(1) A reason this matters is because it should be possible to call e.g. `super().setUpClass()` from a custom setup / teardown classmethod. If the generated TestFooCUDA does not derive from TestFoo, but instead derives from the empty class described above, this syntax does not work; in fact there is no way to form a proper `super()` call that works across the device-specific test variants. Here's an example that breaks in the OpInfo tests:
070f389745/test/test_ops.py (L218-L221)
(2) Further, there is some precedent within a custom `setUpClass()` impl for storing things on the `cls` object to be accessed at test time. This must be the device-specific test class (`TestFooCUDA`) and not `TestFoo` for this to work. As an example, the open device registration tests load a module during setup and use it in the test logic:
070f389745/test/test_cpp_extensions_open_device_registration.py (L63-L77)070f389745/test/test_cpp_extensions_open_device_registration.py (L79-L80)
To accomplish both (1) and (2) at the same time, I decided to revisit the idea of utilizing a proper inheritance hierarchy for `TestFoo` -> `{TestFooCPU, TestFooCUDA}`. That is: have TestFooCPU / TestFooCUDA **actually** derive from `TestFoo`. This achieves both (1) and (2). The only thing left is to make sure the generic tests (e.g. `TestFoo.test_bar()`) are not discoverable, as was the stated reason for diverging from this in the first place. It turns out we can simply `delattr()` these generic tests from `TestFoo` once `TestFooCPU` / `TestFooCUDA` have been setup with the device-specific variants, and all works well. The `instantiate_device_type_tests(...)` logic already deletes `TestFoo` from scope, so I don't see a problem with deleting generic tests from this base class as well (CI will prove me right or wrong ofc).
**Side note:** I was encountering a weird race condition where sometimes the custom `setUpClass()` / `tearDownClass()` defined & swapped in [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L940-L955)) would be used, and sometimes it wouldn't. This non-deterministic behavior was called out previously by @ngimel here:
4a47dd9b3f/test/inductor/test_torchinductor_dynamic_shapes.py (L128-L130)
To address this, I moved this block of logic to before the first call to `instantiate_test()`, as that method queries for the primary device, and the primary device identification logic may manually invoke `setUpClass()` (see [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L381-L384))). Goal: define the `setUpClass()` / `tearDownClass()` we want for correctness before they're ever called. This seems to work and the behavior is deterministic now AFAICT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151129
Approved by: https://github.com/janeyx99, https://github.com/masnesral, https://github.com/malfet
Summary:
Currently only `num_warps` and `num_stages` are supported as one of the kernel options for inductor auto-tuning using `TritonTemplate`.
In order to allow warp-specialization kernel options should allow specifying `num_consumer_groups` and `num_buffers_warp_spec` as well.
NOTE: Currently gating changes to FBCODE using HAS_WARP_SPEC which is only available on triton/release-3.3.x
Test Plan:
## Unit test
Added tests for `test_triton_template_warp_specialization` to verify generated kenrnel contains configs for `num_consumer_groups` and `num_buffers_warp_spec`.
## Functional Testing
Specific to flexattention.
```
import torch
from torch.nn.attention.flex_attention import flex_attention
from triton.testing import do_bench
make_tensor = lambda: torch.rand(8, 16, 8192, 128, device="cuda", dtype=torch.bfloat16)
q, k, v = make_tensor(), make_tensor(), make_tensor()
flex_compiled = torch.compile(flex_attention, fullgraph=True)
print(do_bench(lambda: flex_compiled(q, k, v, kernel_options={"num_warps": 4})))
```
triton do_bench results:
- default compile: 15.176783561706543
- with warp-spec: 9.452800750732422
## Extra notes
- generated triton kernel using `TORCH_LOGS=output_code`: P1740612877
- TTGIR for fused kernel: P1740614685
Differential Revision: D71982587
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150122
Approved by: https://github.com/eellison, https://github.com/zou3519, https://github.com/jansel
Summary:
Currently only `num_warps` and `num_stages` are supported as one of the kernel options for inductor auto-tuning using `TritonTemplate`. In order to allow warp-specialization kernel options should allow specifying `num_consumer_groups` and `num_buffers_warp_spec` as well.
Test Plan:
## Unit test
Added tests for `test_triton_template_warp_specialization` to verify generated kenrnel contains configs for `num_consumer_groups` and `num_buffers_warp_spec`.
## Functional Testing
Specific to flexattention.
```
import torch
from torch.nn.attention.flex_attention import flex_attention
from triton.testing import do_bench
make_tensor = lambda: torch.rand(8, 16, 8192, 128, device="cuda", dtype=torch.bfloat16)
q, k, v = make_tensor(), make_tensor(), make_tensor()
flex_compiled = torch.compile(flex_attention, fullgraph=True)
print(do_bench(lambda: flex_compiled(q, k, v, kernel_options={"num_warps": 4})))
```
triton do_bench results:
- default compile: 15.176783561706543
- with warp-spec: 9.452800750732422
## Extra notes
- generated triton kernel using `TORCH_LOGS=output_code`: P1740612877
- TTGIR for fused kernel: P1740614685
Differential Revision: D70212243
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148503
Approved by: https://github.com/eellison
python benchmarks/transformer/score_mod.py --dynamic --max-autotune
previously would crash with
```
"/home/bobren/local/a/pytorch/torch/_inductor/select_algorithm.py", line 2306, in key_of
node.get_device().type,
```
but with this change no longer does
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148991
Approved by: https://github.com/drisspg
# Summary
Fixes https://github.com/pytorch/pytorch/issues/146377
So what was the original problem: we were codegening a really weird epilogue:
```Python
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
xindex = index_k + 64*index_n + 64*off_hkv*ks2 + 128*off_zq*ks2
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 64*index_n + off_hkv*ks1, dk.shape)), dk, mask)
x5 = (xindex % ks3)
tmp2 = tl.load(out_ptr0 + (x5 + ks1*off_hkv), mask, eviction_policy='evict_last')
tl.store(out_ptr1 + (tl.broadcast_to(xindex, dk.shape)), tmp2, mask)
```
This epilogue was writing and then reading from overlapping regions of memory causing a race condition.
### Why were we generating this epilgoue
During the lowering we created a buffer w/ a different size/stride from the expected return strides. I :think this added an implicit node (for doing the permutation of this wrongly strided output to the the expected one from the meta func. The scheduler for some reason thought it was okay to fuse this into the epilogue, tbh I dont know why.
This fixes the broken meta func and the original repro. I will add a test but it is hard to pop, better than nothing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146563
Approved by: https://github.com/Chillee
# Fixes
https://github.com/pytorch/pytorch/issues/146624
### Updated
From offline discussion going w/ sizehint
However this does incur guards. I couldn't really think of a fancy way to do this. I was going to do `V.graph.sizevars.size_hint` w/ some default for num blocks, but we ultimately need some information about the input.
I am also not sure if size_hint is ALWAYS guaranteed to return the runtime value. I think it would be okay to not supported unbacked symints (maybe).
For instance, in the repro, we quickly hit the recompile limit.
```Shell
torch._dynamo hit config.recompile_limit (8)
function: 'flex_attention' (/home/drisspg/meta/pytorch/torch/nn/attention/flex_attention.py:1161)
last reason: 0/0: tensor 'L['key']' size mismatch at index 2. expected 1, actual 546
To log all recompilation reasons, use TORCH_LOGS="recompiles".
To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146657
Approved by: https://github.com/Chillee, https://github.com/yanboliang
# Summary
- Adds support for non-power of 2 headdim by launching blocks w/ head_dim rounded to the next valid power.
- Other option I considered was building up the final dot_products with smaller blocks (this would probably work but for sake of code complexity going with this option for now)
### Corollary
We had a bug in our backwards kernel where we were using index_k instead of index_v. This should have shown up for the qk_head_dim != v_head_dim cases..
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133495
Approved by: https://github.com/Chillee
Before #143552
```py
Traceback (most recent call last):
File "/home/jansel/pytorch/repro.py", line 51, in <module>
fp32_compiled = optimized_model(low_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 576, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 1381, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 1165, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 987, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 2870, in run
super().run()
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 1053, in run
while self.step():
^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 963, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 3050, in RETURN_VALUE
self._return(inst)
File "/home/jansel/pytorch/torch/_dynamo/symbolic_convert.py", line 3035, in _return
self.output.compile_subgraph(
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1101, in compile_subgraph
self.compile_and_call_fx_graph(
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1382, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1432, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1483, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1462, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/__init__.py", line 2314, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1880, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/backends/common.py", line 83, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 1145, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 754, in load
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 1131, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 580, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 830, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 676, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 489, in __call__
return self.compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1758, in fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 572, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/repro/after_aot.py", line 102, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 686, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1129, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1044, in codegen_and_compile
compiled_fn = graph.compile_to_module().call
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1975, in compile_to_module
return self._compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1981, in _compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1912, in codegen
self.scheduler = Scheduler(self.operations)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 1880, in __init__
self._init(nodes)
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 1955, in _init
self.nodes = self.fuse_nodes(self.nodes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 2461, in fuse_nodes
nodes = self.fuse_nodes_once(nodes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 2773, in fuse_nodes_once
assert False, "a fake error during fusion"
^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: a fake error during fusion
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```
Before this PR
```py
Traceback (most recent call last):
File "/home/jansel/pytorch/repro.py", line 51, in <module>
fp32_compiled = optimized_model(low_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1484, in _call_user_compiler
raise BackendCompilerFailed(
File "/home/jansel/pytorch/torch/_dynamo/output_graph.py", line 1463, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/__init__.py", line 2314, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1880, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/backends/common.py", line 83, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 1145, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 754, in load
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 1131, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 580, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 830, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 676, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_functorch/aot_autograd.py", line 489, in __call__
return self.compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1758, in fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 572, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/repro/after_aot.py", line 102, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 686, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1129, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1044, in codegen_and_compile
compiled_fn = graph.compile_to_module().call
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1975, in compile_to_module
return self._compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1981, in _compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1912, in codegen
self.scheduler = Scheduler(self.operations)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 1880, in __init__
self._init(nodes)
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 1955, in _init
self.nodes = self.fuse_nodes(self.nodes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 2461, in fuse_nodes
nodes = self.fuse_nodes_once(nodes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 2773, in fuse_nodes_once
assert False, "a fake error during fusion"
^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: a fake error during fusion
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```
After this PR
```py
Traceback (most recent call last):
File "/home/jansel/pytorch/repro.py", line 51, in <module>
fp32_compiled = optimized_model(low_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 704, in _compile_fx_inner
raise InductorError(e, currentframe()).with_traceback(
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 689, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1138, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 1053, in codegen_and_compile
compiled_fn = graph.compile_to_module().call
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1975, in compile_to_module
return self._compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1981, in _compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/graph.py", line 1912, in codegen
self.scheduler = Scheduler(self.operations)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 1880, in __init__
self._init(nodes)
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 1955, in _init
self.nodes = self.fuse_nodes(self.nodes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 2461, in fuse_nodes
nodes = self.fuse_nodes_once(nodes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/scheduler.py", line 2773, in fuse_nodes_once
assert False, "a fake error during fusion"
^^^^^
torch._inductor.exc.InductorError: AssertionError: a fake error during fusion
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```
A large numer of frames are removed between:
```py
File "/home/jansel/pytorch/torch/_dynamo/eval_frame.py", line 580, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_inductor/compile_fx.py", line 704, in _compile_fx_inner
raise InductorError(e, currentframe()).with_traceback(
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143610
Approved by: https://github.com/eellison
ghstack dependencies: #143552