mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
dbb55b448b
34 Commits
| Author | SHA1 | Message | Date | |
|---|---|---|---|---|
|
|
86792a5a8d |
[invoke_subgraph] User facing API to support arbitrary args and kwargs (#139162)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139162 Approved by: https://github.com/zou3519 |
||
|
|
75f3056c81 |
[hop-db] Import invoke_subgraph to avoid Dynamo error on mac (#140038)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140038 Approved by: https://github.com/ydwu4 |
||
|
|
c6bb9b53f4 |
[scan] better error handling and remove redundant tests (#137967)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137967 Approved by: https://github.com/zou3519 |
||
|
|
4dd4d38ca9 |
[hierarchical-compilation][hop] Introduce invoke_subgraph (#137538)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137538 Approved by: https://github.com/zou3519 |
||
|
|
840c6b7a68 |
[FlexAttention] Add Better error message for cpu tensors (#136673)
Partially address: #136525 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136673 Approved by: https://github.com/Chillee |
||
|
|
1d231ff8ba |
[HOO] add hints_wrapper to support passing context hints (#132860)
Fixes #126393 The implementation code is based on feedback here (https://github.com/pytorch/pytorch/pull/121639#issuecomment-2223948842). Hints are passed as kwargs of hints_wrapper op. It also supports nested hints. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132860 Approved by: https://github.com/ydwu4, https://github.com/zou3519 |
||
|
|
8ae1963a61 |
[Autograd] Cond Higher-Order Operation (#126911)
This is an updated PR to equip cond with the autograd feature and replaces the old [PR](https://github.com/pytorch/pytorch/pull/126007)
@ydwu4 I tried to incorporate your requests already.
Currently there are two problems that I struggle with solving:
1. There seems to be an import issue when trying to import cond in `torch/__init__.py`, see [here](
|
||
|
|
fb3674b1f4 |
Revert "[Autograd] Cond Higher-Order Operation (#126911)"
This reverts commit |
||
|
|
f7058b735e |
[Autograd] Cond Higher-Order Operation (#126911)
This is an updated PR to equip cond with the autograd feature and replaces the old [PR](https://github.com/pytorch/pytorch/pull/126007)
@ydwu4 I tried to incorporate your requests already.
Currently there are two problems that I struggle with solving:
1. There seems to be an import issue when trying to import cond in `torch/__init__.py`, see [here](
|
||
|
|
dd39dca034 |
Removing some cruff and updating signatures for consistency (#130871)
# Summary - This removes a bunch of example score mods that were primarily used for testing and places them directly in the test file. We should follow up with merging test_flex_decode and test_flash when the velocity slows down a little - Fixes a bug with indexing on block mask - Adds some doc strings to helper funcs and fixes some misc typing things - Forces functions passed to `create_block_mask` to mask_mods and updates tests files Pull Request resolved: https://github.com/pytorch/pytorch/pull/130871 Approved by: https://github.com/joydddd, https://github.com/Chillee |
||
|
|
2b43d339fe |
Make FlexAttention API public (#130755)
# Summary Makes the prototype API flex_attention public Pull Request resolved: https://github.com/pytorch/pytorch/pull/130755 Approved by: https://github.com/Chillee |
||
|
|
f9f85bfc0b |
[Inductor] FlexAttention supports partial masking (#130415) (#130626)
This is the new version of https://github.com/pytorch/pytorch/pull/130415 Updated test script: https://gist.github.com/yanboliang/7c34a82df611d4ea8869cb9e041bfbfc Updated perf numbers: ``` (pt) [ybliang@devgpu002.ash8 ~/local/debug]$ CUDA_VISIBLE_DEVICES=4 python debug7.py fwd speedup: 0.7166695598192317 bwd speedup: 0.7142133867805904 (pt) [ybliang@devgpu002.ash8 ~/local/debug]$ CUDA_VISIBLE_DEVICES=4 python debug7.py --partial-mask fwd speedup: 0.8428246087169973 bwd speedup: 0.8486261278030254 ``` Approved by: https://github.com/Chillee Pull Request resolved: https://github.com/pytorch/pytorch/pull/130626 Approved by: https://github.com/drisspg, https://github.com/yanboliang |
||
|
|
99c68f7bea |
Refactor TritonKernelVariable's logic so it can be shared (#130177)
TritonKernelVariable's logic tells us how to go from a user-defined triton kernel and a grid to a call to the triton_kernel_wrapper_mutation HOP. We want to re-use this in a setting without Dynamo; in the next PR up, we create a new decorator (capture_triton) that, when applied to a triton kernel, transforms a call to the triton kernel into a call to the triton_kernel_wrapper_mutation HOP. Test Plan: - existing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/130177 Approved by: https://github.com/oulgen, https://github.com/ydwu4 |
||
|
|
da66e50e6e |
Added compile option to create_block_mask (#130106)
Compiling the `create_block_mask` function allows us to "materialize" extremely large masks. This would have been a 1 *trillion* element tensor if fully materialized. ``` print(do_bench(lambda: create_block_mask(causal_mask, 1, 1, 2**20, 2**20, _compiled=True))) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/130106 Approved by: https://github.com/yanboliang ghstack dependencies: #130160 |
||
|
|
ec47d4d9a8 |
[Inductor] FlexAttention supports block sparse mask (#129216)
Benchmark script (causal mask): https://gist.github.com/yanboliang/c2010a1fd081d4e8ca94fadec9eef286 Initial perf number: * fwd speedup: 0.44 -> 0.72 * bwd speedup: 0.38 -> 0.71 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129216 Approved by: https://github.com/Chillee |
||
|
|
c43923a116 |
Revert "[Inductor] FlexAttention supports block sparse mask (#129216)"
This reverts commit |
||
|
|
b9d3cedd64 |
[Inductor] FlexAttention supports block sparse mask (#129216)
Benchmark script (causal mask): https://gist.github.com/yanboliang/c2010a1fd081d4e8ca94fadec9eef286 Initial perf number: * fwd speedup: 0.44 -> 0.72 * bwd speedup: 0.38 -> 0.71 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129216 Approved by: https://github.com/Chillee |
||
|
|
5ceba6a3cb |
Revert "[Inductor] FlexAttention supports block sparse mask (#129216)"
This reverts commit |
||
|
|
4082759925 |
[Inductor] FlexAttention supports block sparse mask (#129216)
Benchmark script (causal mask): https://gist.github.com/yanboliang/c2010a1fd081d4e8ca94fadec9eef286 Initial perf number: * fwd speedup: 0.44 -> 0.72 * bwd speedup: 0.38 -> 0.71 Pull Request resolved: https://github.com/pytorch/pytorch/pull/129216 Approved by: https://github.com/Chillee |
||
|
|
853081a8e7 |
Replace torch.library.impl_abstract with torch.library.register_fake (#126606)
To remove the disrupting warning
```
warnings.warn("torch.library.impl_abstract was renamed to "
"torch.library.register_fake. Please use that instead; "
"we will remove torch.library.impl_abstract in a future "
"version of PyTorch.",
DeprecationWarning, stacklevel=2)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126606
Approved by: https://github.com/ezyang
|
||
|
|
762ce6f062 |
Add Lowering for FlexAttention Backwards (#125515)
# Summary #### What does this PR do? It enables Inductor to actually generate the fused flex attention kernel for the backwards I did some other things along the way: - Abstract out the 'build_subgraph_buffer' subroutine and make it reusable between flex attention and flex_attention backwards. In total we need too build 3 subgraphs for fwd + bwd. 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes the parts of the forward (more efficiently since we already have the row_max via logsumexp), therefore we need to inline both the fwd graph and the joint graph in the bwds kernel. - The version of the backwards kernel is from a somewhat older version of the triton tutorial implementation. I think that we should update in a follow up to a newer version. Notably the blocks need to be square for this to work as currently implemented. I am sure there are many opportunities for optimization. - I didnt correctly register the decomp table + IndexMode when I landed: https://github.com/pytorch/pytorch/pull/123902, this remedies that. - The rel_bias helper func was reversed in terms of causality. I updated and then add a test specific for "future causal" attention. - This PRs but the main point that I think still needs to be worked out is the store_output call. I have it hacked up to be 'fake' but I dont think we want to land that and likely want to just have a mutated 'dq' and a stored_output 'dk' - I also needed to update the `TritonTemplateKernel` to actually accept multiple subgraphs (modifications) - I updated the benchmark to also profile bwds performance ### Benchmark Numbers: _The current implementation is not parallelizing over ctx length in the bwd_ FWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.991 | | | | | Max | 1.182 | (16, 16, 4096, 64) | noop | torch.bfloat16 | | Min | 0.796 | (2, 16, 512, 256) | head_bias | torch.bfloat16 | BWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.291 | | | | | Max | 0.652 | (8, 16, 512, 64) | head_bias | torch.bfloat16 | | Min | 0.073 | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | <details> <summary>Full Data</summary> | shape | score_mod | dtype | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------------|---------------|----------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | (2, 16, 512, 64) | noop | torch.bfloat16 | 19.936 | 19.092 | 57.851 | 193.564 | 1.044 | 0.299 | | (2, 16, 512, 64) | causal_mask | torch.bfloat16 | 19.955 | 19.497 | 57.662 | 206.278 | 1.024 | 0.280 | | (2, 16, 512, 64) | relative_bias | torch.bfloat16 | 19.455 | 21.297 | 57.674 | 195.219 | 0.913 | 0.295 | | (2, 16, 512, 64) | head_bias | torch.bfloat16 | 19.958 | 21.289 | 57.674 | 193.859 | 0.938 | 0.298 | | (2, 16, 512, 128) | noop | torch.bfloat16 | 28.157 | 28.615 | 82.831 | 454.211 | 0.984 | 0.182 | | (2, 16, 512, 128) | causal_mask | torch.bfloat16 | 28.154 | 28.444 | 83.091 | 432.083 | 0.990 | 0.192 | | (2, 16, 512, 128) | relative_bias | torch.bfloat16 | 28.722 | 27.897 | 83.175 | 446.789 | 1.030 | 0.186 | | (2, 16, 512, 128) | head_bias | torch.bfloat16 | 28.299 | 27.673 | 83.052 | 459.179 | 1.023 | 0.181 | | (2, 16, 512, 256) | noop | torch.bfloat16 | 41.167 | 50.504 | 175.019 | 1083.545 | 0.815 | 0.162 | | (2, 16, 512, 256) | causal_mask | torch.bfloat16 | 41.656 | 51.933 | 175.078 | 1171.176 | 0.802 | 0.149 | | (2, 16, 512, 256) | relative_bias | torch.bfloat16 | 41.697 | 50.722 | 175.159 | 1097.312 | 0.822 | 0.160 | | (2, 16, 512, 256) | head_bias | torch.bfloat16 | 41.690 | 52.387 | 175.184 | 1097.336 | 0.796 | 0.160 | | (2, 16, 1024, 64) | noop | torch.bfloat16 | 39.232 | 37.454 | 127.847 | 612.430 | 1.047 | 0.209 | | (2, 16, 1024, 64) | causal_mask | torch.bfloat16 | 39.930 | 39.599 | 127.755 | 665.359 | 1.008 | 0.192 | | (2, 16, 1024, 64) | relative_bias | torch.bfloat16 | 39.417 | 41.304 | 127.902 | 614.990 | 0.954 | 0.208 | | (2, 16, 1024, 64) | head_bias | torch.bfloat16 | 39.965 | 42.034 | 127.953 | 613.273 | 0.951 | 0.209 | | (2, 16, 1024, 128) | noop | torch.bfloat16 | 63.964 | 71.024 | 226.510 | 1637.669 | 0.901 | 0.138 | | (2, 16, 1024, 128) | causal_mask | torch.bfloat16 | 63.843 | 72.451 | 226.750 | 1558.949 | 0.881 | 0.145 | | (2, 16, 1024, 128) | relative_bias | torch.bfloat16 | 64.301 | 70.487 | 226.651 | 1610.063 | 0.912 | 0.141 | | (2, 16, 1024, 128) | head_bias | torch.bfloat16 | 64.033 | 71.394 | 226.676 | 1668.511 | 0.897 | 0.136 | | (2, 16, 1024, 256) | noop | torch.bfloat16 | 129.348 | 141.390 | 507.337 | 4405.175 | 0.915 | 0.115 | | (2, 16, 1024, 256) | causal_mask | torch.bfloat16 | 129.538 | 145.680 | 507.178 | 4768.874 | 0.889 | 0.106 | | (2, 16, 1024, 256) | relative_bias | torch.bfloat16 | 129.438 | 142.782 | 507.004 | 4401.002 | 0.907 | 0.115 | | (2, 16, 1024, 256) | head_bias | torch.bfloat16 | 129.058 | 146.242 | 507.547 | 4434.251 | 0.883 | 0.114 | | (2, 16, 4096, 64) | noop | torch.bfloat16 | 481.606 | 409.120 | 1440.890 | 14147.269 | 1.177 | 0.102 | | (2, 16, 4096, 64) | causal_mask | torch.bfloat16 | 480.227 | 438.847 | 1434.419 | 14973.386 | 1.094 | 0.096 | | (2, 16, 4096, 64) | relative_bias | torch.bfloat16 | 480.831 | 458.104 | 1432.935 | 14193.253 | 1.050 | 0.101 | | (2, 16, 4096, 64) | head_bias | torch.bfloat16 | 480.749 | 452.497 | 1437.040 | 14084.869 | 1.062 | 0.102 | | (2, 16, 4096, 128) | noop | torch.bfloat16 | 872.534 | 848.275 | 2600.895 | 35156.849 | 1.029 | 0.074 | | (2, 16, 4096, 128) | causal_mask | torch.bfloat16 | 872.647 | 868.279 | 2587.581 | 31919.531 | 1.005 | 0.081 | | (2, 16, 4096, 128) | relative_bias | torch.bfloat16 | 871.484 | 827.644 | 2593.989 | 34805.634 | 1.053 | 0.075 | | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | 871.422 | 856.437 | 2602.482 | 35708.591 | 1.017 | 0.073 | | (2, 16, 4096, 256) | noop | torch.bfloat16 | 1904.497 | 1758.183 | 6122.416 | 66754.593 | 1.083 | 0.092 | | (2, 16, 4096, 256) | causal_mask | torch.bfloat16 | 1911.174 | 1762.821 | 6113.207 | 72759.392 | 1.084 | 0.084 | | (2, 16, 4096, 256) | relative_bias | torch.bfloat16 | 1911.254 | 1727.108 | 6123.530 | 66577.988 | 1.107 | 0.092 | | (2, 16, 4096, 256) | head_bias | torch.bfloat16 | 1916.977 | 1801.804 | 6118.158 | 67359.680 | 1.064 | 0.091 | | (8, 16, 512, 64) | noop | torch.bfloat16 | 44.984 | 43.974 | 170.276 | 262.259 | 1.023 | 0.649 | | (8, 16, 512, 64) | causal_mask | torch.bfloat16 | 45.001 | 46.265 | 170.509 | 274.893 | 0.973 | 0.620 | | (8, 16, 512, 64) | relative_bias | torch.bfloat16 | 45.466 | 48.211 | 170.606 | 262.759 | 0.943 | 0.649 | | (8, 16, 512, 64) | head_bias | torch.bfloat16 | 45.481 | 48.435 | 170.267 | 261.265 | 0.939 | 0.652 | | (8, 16, 512, 128) | noop | torch.bfloat16 | 72.565 | 74.736 | 313.220 | 773.126 | 0.971 | 0.405 | | (8, 16, 512, 128) | causal_mask | torch.bfloat16 | 72.015 | 75.755 | 313.311 | 775.513 | 0.951 | 0.404 | | (8, 16, 512, 128) | relative_bias | torch.bfloat16 | 72.105 | 74.189 | 313.806 | 769.238 | 0.972 | 0.408 | | (8, 16, 512, 128) | head_bias | torch.bfloat16 | 72.005 | 74.364 | 313.509 | 775.237 | 0.968 | 0.404 | | (8, 16, 512, 256) | noop | torch.bfloat16 | 138.656 | 165.453 | 663.707 | 2672.067 | 0.838 | 0.248 | | (8, 16, 512, 256) | causal_mask | torch.bfloat16 | 139.096 | 172.613 | 663.593 | 2926.538 | 0.806 | 0.227 | | (8, 16, 512, 256) | relative_bias | torch.bfloat16 | 139.500 | 168.417 | 663.938 | 2658.629 | 0.828 | 0.250 | | (8, 16, 512, 256) | head_bias | torch.bfloat16 | 139.776 | 173.549 | 662.920 | 2667.266 | 0.805 | 0.249 | | (8, 16, 1024, 64) | noop | torch.bfloat16 | 134.883 | 125.004 | 484.706 | 1195.254 | 1.079 | 0.406 | | (8, 16, 1024, 64) | causal_mask | torch.bfloat16 | 134.297 | 132.875 | 485.420 | 1234.953 | 1.011 | 0.393 | | (8, 16, 1024, 64) | relative_bias | torch.bfloat16 | 134.839 | 139.231 | 485.470 | 1198.556 | 0.968 | 0.405 | | (8, 16, 1024, 64) | head_bias | torch.bfloat16 | 133.822 | 136.449 | 485.608 | 1189.198 | 0.981 | 0.408 | | (8, 16, 1024, 128) | noop | torch.bfloat16 | 235.470 | 234.765 | 886.094 | 2662.944 | 1.003 | 0.333 | | (8, 16, 1024, 128) | causal_mask | torch.bfloat16 | 236.305 | 241.382 | 886.293 | 2646.984 | 0.979 | 0.335 | | (8, 16, 1024, 128) | relative_bias | torch.bfloat16 | 236.414 | 233.980 | 885.250 | 2642.178 | 1.010 | 0.335 | | (8, 16, 1024, 128) | head_bias | torch.bfloat16 | 237.176 | 239.040 | 885.754 | 2665.242 | 0.992 | 0.332 | | (8, 16, 1024, 256) | noop | torch.bfloat16 | 504.445 | 517.855 | 1978.956 | 9592.906 | 0.974 | 0.206 | | (8, 16, 1024, 256) | causal_mask | torch.bfloat16 | 502.428 | 536.002 | 1978.611 | 10607.342 | 0.937 | 0.187 | | (8, 16, 1024, 256) | relative_bias | torch.bfloat16 | 503.396 | 523.960 | 1977.993 | 9539.284 | 0.961 | 0.207 | | (8, 16, 1024, 256) | head_bias | torch.bfloat16 | 503.818 | 536.014 | 1980.131 | 9576.262 | 0.940 | 0.207 | | (8, 16, 4096, 64) | noop | torch.bfloat16 | 1970.139 | 1674.930 | 5750.940 | 16724.134 | 1.176 | 0.344 | | (8, 16, 4096, 64) | causal_mask | torch.bfloat16 | 1959.036 | 1775.056 | 5780.512 | 17390.350 | 1.104 | 0.332 | | (8, 16, 4096, 64) | relative_bias | torch.bfloat16 | 1947.198 | 1773.869 | 5780.643 | 16779.699 | 1.098 | 0.345 | | (8, 16, 4096, 64) | head_bias | torch.bfloat16 | 1963.935 | 1829.502 | 5780.018 | 16703.259 | 1.073 | 0.346 | | (8, 16, 4096, 128) | noop | torch.bfloat16 | 3582.711 | 3362.623 | 10436.069 | 36415.565 | 1.065 | 0.287 | | (8, 16, 4096, 128) | causal_mask | torch.bfloat16 | 3581.504 | 3499.472 | 10346.869 | 36164.959 | 1.023 | 0.286 | | (8, 16, 4096, 128) | relative_bias | torch.bfloat16 | 3589.779 | 3337.849 | 10529.621 | 36261.696 | 1.075 | 0.290 | | (8, 16, 4096, 128) | head_bias | torch.bfloat16 | 3602.265 | 3436.444 | 10458.660 | 36507.790 | 1.048 | 0.286 | | (8, 16, 4096, 256) | noop | torch.bfloat16 | 7695.923 | 7126.275 | 24643.009 | 140949.081 | 1.080 | 0.175 | | (8, 16, 4096, 256) | causal_mask | torch.bfloat16 | 7679.939 | 7186.252 | 24538.105 | 157156.067 | 1.069 | 0.156 | | (8, 16, 4096, 256) | relative_bias | torch.bfloat16 | 7681.374 | 6994.832 | 24549.713 | 140077.179 | 1.098 | 0.175 | | (8, 16, 4096, 256) | head_bias | torch.bfloat16 | 7679.822 | 7212.278 | 24627.823 | 140675.003 | 1.065 | 0.175 | | (16, 16, 512, 64) | noop | torch.bfloat16 | 80.126 | 78.291 | 333.719 | 541.165 | 1.023 | 0.617 | | (16, 16, 512, 64) | causal_mask | torch.bfloat16 | 80.065 | 81.696 | 333.779 | 551.113 | 0.980 | 0.606 | | (16, 16, 512, 64) | relative_bias | torch.bfloat16 | 80.138 | 86.715 | 333.364 | 542.118 | 0.924 | 0.615 | | (16, 16, 512, 64) | head_bias | torch.bfloat16 | 80.415 | 85.204 | 333.294 | 536.840 | 0.944 | 0.621 | | (16, 16, 512, 128) | noop | torch.bfloat16 | 134.964 | 138.025 | 607.093 | 1333.102 | 0.978 | 0.455 | | (16, 16, 512, 128) | causal_mask | torch.bfloat16 | 134.192 | 141.523 | 606.269 | 1424.318 | 0.948 | 0.426 | | (16, 16, 512, 128) | relative_bias | torch.bfloat16 | 135.711 | 138.639 | 606.283 | 1327.974 | 0.979 | 0.457 | | (16, 16, 512, 128) | head_bias | torch.bfloat16 | 135.552 | 140.555 | 607.107 | 1347.370 | 0.964 | 0.451 | | (16, 16, 512, 256) | noop | torch.bfloat16 | 275.113 | 315.144 | 1301.583 | 5268.153 | 0.873 | 0.247 | | (16, 16, 512, 256) | causal_mask | torch.bfloat16 | 274.867 | 328.106 | 1302.513 | 5770.594 | 0.838 | 0.226 | | (16, 16, 512, 256) | relative_bias | torch.bfloat16 | 276.052 | 321.770 | 1302.904 | 5241.920 | 0.858 | 0.249 | | (16, 16, 512, 256) | head_bias | torch.bfloat16 | 271.409 | 328.839 | 1302.142 | 5266.037 | 0.825 | 0.247 | | (16, 16, 1024, 64) | noop | torch.bfloat16 | 260.489 | 237.463 | 955.884 | 1817.558 | 1.097 | 0.526 | | (16, 16, 1024, 64) | causal_mask | torch.bfloat16 | 262.378 | 254.350 | 955.280 | 1843.807 | 1.032 | 0.518 | | (16, 16, 1024, 64) | relative_bias | torch.bfloat16 | 261.338 | 268.253 | 956.038 | 1820.036 | 0.974 | 0.525 | | (16, 16, 1024, 64) | head_bias | torch.bfloat16 | 262.153 | 264.156 | 956.023 | 1810.076 | 0.992 | 0.528 | | (16, 16, 1024, 128) | noop | torch.bfloat16 | 476.475 | 461.413 | 1760.578 | 4306.521 | 1.033 | 0.409 | | (16, 16, 1024, 128) | causal_mask | torch.bfloat16 | 473.794 | 479.178 | 1761.277 | 4619.439 | 0.989 | 0.381 | | (16, 16, 1024, 128) | relative_bias | torch.bfloat16 | 473.839 | 463.282 | 1758.692 | 4290.562 | 1.023 | 0.410 | | (16, 16, 1024, 128) | head_bias | torch.bfloat16 | 472.979 | 472.896 | 1763.086 | 4367.931 | 1.000 | 0.404 | | (16, 16, 1024, 256) | noop | torch.bfloat16 | 1014.184 | 1026.764 | 3922.997 | 19104.147 | 0.988 | 0.205 | | (16, 16, 1024, 256) | causal_mask | torch.bfloat16 | 1013.217 | 1039.046 | 3928.382 | 21086.281 | 0.975 | 0.186 | | (16, 16, 1024, 256) | relative_bias | torch.bfloat16 | 1008.519 | 1015.278 | 3922.133 | 18980.652 | 0.993 | 0.207 | | (16, 16, 1024, 256) | head_bias | torch.bfloat16 | 1011.360 | 1047.542 | 3931.245 | 19069.172 | 0.965 | 0.206 | | (16, 16, 4096, 64) | noop | torch.bfloat16 | 3929.850 | 3325.667 | 11411.704 | 23344.280 | 1.182 | 0.489 | | (16, 16, 4096, 64) | causal_mask | torch.bfloat16 | 3885.262 | 3581.544 | 11390.515 | 23725.639 | 1.085 | 0.480 | | (16, 16, 4096, 64) | relative_bias | torch.bfloat16 | 3865.737 | 3537.308 | 11489.901 | 23406.330 | 1.093 | 0.491 | | (16, 16, 4096, 64) | head_bias | torch.bfloat16 | 3880.530 | 3665.249 | 11484.411 | 23299.496 | 1.059 | 0.493 | | (16, 16, 4096, 128) | noop | torch.bfloat16 | 7030.306 | 6745.715 | 20621.264 | 57464.096 | 1.042 | 0.359 | | (16, 16, 4096, 128) | causal_mask | torch.bfloat16 | 7095.414 | 7034.385 | 20410.656 | 61660.511 | 1.009 | 0.331 | | (16, 16, 4096, 128) | relative_bias | torch.bfloat16 | 7084.779 | 6686.497 | 20315.161 | 57243.969 | 1.060 | 0.355 | | (16, 16, 4096, 128) | head_bias | torch.bfloat16 | 7075.367 | 6863.305 | 20494.385 | 58481.953 | 1.031 | 0.350 | | (16, 16, 4096, 256) | noop | torch.bfloat16 | 15612.741 | 14297.482 | 55306.847 | 281161.865 | 1.092 | 0.197 | | (16, 16, 4096, 256) | causal_mask | torch.bfloat16 | 15326.592 | 14263.878 | 55227.806 | 313063.232 | 1.075 | 0.176 | | (16, 16, 4096, 256) | relative_bias | torch.bfloat16 | 15297.963 | 14007.379 | 54558.029 | 279529.175 | 1.092 | 0.195 | | (16, 16, 4096, 256) | head_bias | torch.bfloat16 | 15216.160 | 14276.027 | 55081.581 | 280996.826 | 1.066 | 0.196 | </details> Pull Request resolved: https://github.com/pytorch/pytorch/pull/125515 Approved by: https://github.com/Chillee |
||
|
|
0716f75cfb |
Revert "Add Lowering for FlexAttention Backwards (#125515)"
This reverts commit |
||
|
|
95b9e981c3 |
Add Lowering for FlexAttention Backwards (#125515)
# Summary #### What does this PR do? It enables Inductor to actually generate the fused flex attention kernel for the backwards I did some other things along the way: - Abstract out the 'build_subgraph_buffer' subroutine and make it reusable between flex attention and flex_attention backwards. In total we need too build 3 subgraphs for fwd + bwd. 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes the parts of the forward (more efficiently since we already have the row_max via logsumexp), therefore we need to inline both the fwd graph and the joint graph in the bwds kernel. - The version of the backwards kernel is from a somewhat older version of the triton tutorial implementation. I think that we should update in a follow up to a newer version. Notably the blocks need to be square for this to work as currently implemented. I am sure there are many opportunities for optimization. - I didnt correctly register the decomp table + IndexMode when I landed: https://github.com/pytorch/pytorch/pull/123902, this remedies that. - The rel_bias helper func was reversed in terms of causality. I updated and then add a test specific for "future causal" attention. - This PRs but the main point that I think still needs to be worked out is the store_output call. I have it hacked up to be 'fake' but I dont think we want to land that and likely want to just have a mutated 'dq' and a stored_output 'dk' - I also needed to update the `TritonTemplateKernel` to actually accept multiple subgraphs (modifications) - I updated the benchmark to also profile bwds performance ### Benchmark Numbers: _The current implementation is not parallelizing over ctx length in the bwd_ FWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.991 | | | | | Max | 1.182 | (16, 16, 4096, 64) | noop | torch.bfloat16 | | Min | 0.796 | (2, 16, 512, 256) | head_bias | torch.bfloat16 | BWD Speedups | Type | Speedup | shape | score_mod | dtype | |---------|-----------|--------------------|-------------|----------------| | Average | 0.291 | | | | | Max | 0.652 | (8, 16, 512, 64) | head_bias | torch.bfloat16 | | Min | 0.073 | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | <details> <summary>Full Data</summary> | shape | score_mod | dtype | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup | |---------------------|---------------|----------------|------------------|---------------------|------------------|---------------------|---------------|---------------| | (2, 16, 512, 64) | noop | torch.bfloat16 | 19.936 | 19.092 | 57.851 | 193.564 | 1.044 | 0.299 | | (2, 16, 512, 64) | causal_mask | torch.bfloat16 | 19.955 | 19.497 | 57.662 | 206.278 | 1.024 | 0.280 | | (2, 16, 512, 64) | relative_bias | torch.bfloat16 | 19.455 | 21.297 | 57.674 | 195.219 | 0.913 | 0.295 | | (2, 16, 512, 64) | head_bias | torch.bfloat16 | 19.958 | 21.289 | 57.674 | 193.859 | 0.938 | 0.298 | | (2, 16, 512, 128) | noop | torch.bfloat16 | 28.157 | 28.615 | 82.831 | 454.211 | 0.984 | 0.182 | | (2, 16, 512, 128) | causal_mask | torch.bfloat16 | 28.154 | 28.444 | 83.091 | 432.083 | 0.990 | 0.192 | | (2, 16, 512, 128) | relative_bias | torch.bfloat16 | 28.722 | 27.897 | 83.175 | 446.789 | 1.030 | 0.186 | | (2, 16, 512, 128) | head_bias | torch.bfloat16 | 28.299 | 27.673 | 83.052 | 459.179 | 1.023 | 0.181 | | (2, 16, 512, 256) | noop | torch.bfloat16 | 41.167 | 50.504 | 175.019 | 1083.545 | 0.815 | 0.162 | | (2, 16, 512, 256) | causal_mask | torch.bfloat16 | 41.656 | 51.933 | 175.078 | 1171.176 | 0.802 | 0.149 | | (2, 16, 512, 256) | relative_bias | torch.bfloat16 | 41.697 | 50.722 | 175.159 | 1097.312 | 0.822 | 0.160 | | (2, 16, 512, 256) | head_bias | torch.bfloat16 | 41.690 | 52.387 | 175.184 | 1097.336 | 0.796 | 0.160 | | (2, 16, 1024, 64) | noop | torch.bfloat16 | 39.232 | 37.454 | 127.847 | 612.430 | 1.047 | 0.209 | | (2, 16, 1024, 64) | causal_mask | torch.bfloat16 | 39.930 | 39.599 | 127.755 | 665.359 | 1.008 | 0.192 | | (2, 16, 1024, 64) | relative_bias | torch.bfloat16 | 39.417 | 41.304 | 127.902 | 614.990 | 0.954 | 0.208 | | (2, 16, 1024, 64) | head_bias | torch.bfloat16 | 39.965 | 42.034 | 127.953 | 613.273 | 0.951 | 0.209 | | (2, 16, 1024, 128) | noop | torch.bfloat16 | 63.964 | 71.024 | 226.510 | 1637.669 | 0.901 | 0.138 | | (2, 16, 1024, 128) | causal_mask | torch.bfloat16 | 63.843 | 72.451 | 226.750 | 1558.949 | 0.881 | 0.145 | | (2, 16, 1024, 128) | relative_bias | torch.bfloat16 | 64.301 | 70.487 | 226.651 | 1610.063 | 0.912 | 0.141 | | (2, 16, 1024, 128) | head_bias | torch.bfloat16 | 64.033 | 71.394 | 226.676 | 1668.511 | 0.897 | 0.136 | | (2, 16, 1024, 256) | noop | torch.bfloat16 | 129.348 | 141.390 | 507.337 | 4405.175 | 0.915 | 0.115 | | (2, 16, 1024, 256) | causal_mask | torch.bfloat16 | 129.538 | 145.680 | 507.178 | 4768.874 | 0.889 | 0.106 | | (2, 16, 1024, 256) | relative_bias | torch.bfloat16 | 129.438 | 142.782 | 507.004 | 4401.002 | 0.907 | 0.115 | | (2, 16, 1024, 256) | head_bias | torch.bfloat16 | 129.058 | 146.242 | 507.547 | 4434.251 | 0.883 | 0.114 | | (2, 16, 4096, 64) | noop | torch.bfloat16 | 481.606 | 409.120 | 1440.890 | 14147.269 | 1.177 | 0.102 | | (2, 16, 4096, 64) | causal_mask | torch.bfloat16 | 480.227 | 438.847 | 1434.419 | 14973.386 | 1.094 | 0.096 | | (2, 16, 4096, 64) | relative_bias | torch.bfloat16 | 480.831 | 458.104 | 1432.935 | 14193.253 | 1.050 | 0.101 | | (2, 16, 4096, 64) | head_bias | torch.bfloat16 | 480.749 | 452.497 | 1437.040 | 14084.869 | 1.062 | 0.102 | | (2, 16, 4096, 128) | noop | torch.bfloat16 | 872.534 | 848.275 | 2600.895 | 35156.849 | 1.029 | 0.074 | | (2, 16, 4096, 128) | causal_mask | torch.bfloat16 | 872.647 | 868.279 | 2587.581 | 31919.531 | 1.005 | 0.081 | | (2, 16, 4096, 128) | relative_bias | torch.bfloat16 | 871.484 | 827.644 | 2593.989 | 34805.634 | 1.053 | 0.075 | | (2, 16, 4096, 128) | head_bias | torch.bfloat16 | 871.422 | 856.437 | 2602.482 | 35708.591 | 1.017 | 0.073 | | (2, 16, 4096, 256) | noop | torch.bfloat16 | 1904.497 | 1758.183 | 6122.416 | 66754.593 | 1.083 | 0.092 | | (2, 16, 4096, 256) | causal_mask | torch.bfloat16 | 1911.174 | 1762.821 | 6113.207 | 72759.392 | 1.084 | 0.084 | | (2, 16, 4096, 256) | relative_bias | torch.bfloat16 | 1911.254 | 1727.108 | 6123.530 | 66577.988 | 1.107 | 0.092 | | (2, 16, 4096, 256) | head_bias | torch.bfloat16 | 1916.977 | 1801.804 | 6118.158 | 67359.680 | 1.064 | 0.091 | | (8, 16, 512, 64) | noop | torch.bfloat16 | 44.984 | 43.974 | 170.276 | 262.259 | 1.023 | 0.649 | | (8, 16, 512, 64) | causal_mask | torch.bfloat16 | 45.001 | 46.265 | 170.509 | 274.893 | 0.973 | 0.620 | | (8, 16, 512, 64) | relative_bias | torch.bfloat16 | 45.466 | 48.211 | 170.606 | 262.759 | 0.943 | 0.649 | | (8, 16, 512, 64) | head_bias | torch.bfloat16 | 45.481 | 48.435 | 170.267 | 261.265 | 0.939 | 0.652 | | (8, 16, 512, 128) | noop | torch.bfloat16 | 72.565 | 74.736 | 313.220 | 773.126 | 0.971 | 0.405 | | (8, 16, 512, 128) | causal_mask | torch.bfloat16 | 72.015 | 75.755 | 313.311 | 775.513 | 0.951 | 0.404 | | (8, 16, 512, 128) | relative_bias | torch.bfloat16 | 72.105 | 74.189 | 313.806 | 769.238 | 0.972 | 0.408 | | (8, 16, 512, 128) | head_bias | torch.bfloat16 | 72.005 | 74.364 | 313.509 | 775.237 | 0.968 | 0.404 | | (8, 16, 512, 256) | noop | torch.bfloat16 | 138.656 | 165.453 | 663.707 | 2672.067 | 0.838 | 0.248 | | (8, 16, 512, 256) | causal_mask | torch.bfloat16 | 139.096 | 172.613 | 663.593 | 2926.538 | 0.806 | 0.227 | | (8, 16, 512, 256) | relative_bias | torch.bfloat16 | 139.500 | 168.417 | 663.938 | 2658.629 | 0.828 | 0.250 | | (8, 16, 512, 256) | head_bias | torch.bfloat16 | 139.776 | 173.549 | 662.920 | 2667.266 | 0.805 | 0.249 | | (8, 16, 1024, 64) | noop | torch.bfloat16 | 134.883 | 125.004 | 484.706 | 1195.254 | 1.079 | 0.406 | | (8, 16, 1024, 64) | causal_mask | torch.bfloat16 | 134.297 | 132.875 | 485.420 | 1234.953 | 1.011 | 0.393 | | (8, 16, 1024, 64) | relative_bias | torch.bfloat16 | 134.839 | 139.231 | 485.470 | 1198.556 | 0.968 | 0.405 | | (8, 16, 1024, 64) | head_bias | torch.bfloat16 | 133.822 | 136.449 | 485.608 | 1189.198 | 0.981 | 0.408 | | (8, 16, 1024, 128) | noop | torch.bfloat16 | 235.470 | 234.765 | 886.094 | 2662.944 | 1.003 | 0.333 | | (8, 16, 1024, 128) | causal_mask | torch.bfloat16 | 236.305 | 241.382 | 886.293 | 2646.984 | 0.979 | 0.335 | | (8, 16, 1024, 128) | relative_bias | torch.bfloat16 | 236.414 | 233.980 | 885.250 | 2642.178 | 1.010 | 0.335 | | (8, 16, 1024, 128) | head_bias | torch.bfloat16 | 237.176 | 239.040 | 885.754 | 2665.242 | 0.992 | 0.332 | | (8, 16, 1024, 256) | noop | torch.bfloat16 | 504.445 | 517.855 | 1978.956 | 9592.906 | 0.974 | 0.206 | | (8, 16, 1024, 256) | causal_mask | torch.bfloat16 | 502.428 | 536.002 | 1978.611 | 10607.342 | 0.937 | 0.187 | | (8, 16, 1024, 256) | relative_bias | torch.bfloat16 | 503.396 | 523.960 | 1977.993 | 9539.284 | 0.961 | 0.207 | | (8, 16, 1024, 256) | head_bias | torch.bfloat16 | 503.818 | 536.014 | 1980.131 | 9576.262 | 0.940 | 0.207 | | (8, 16, 4096, 64) | noop | torch.bfloat16 | 1970.139 | 1674.930 | 5750.940 | 16724.134 | 1.176 | 0.344 | | (8, 16, 4096, 64) | causal_mask | torch.bfloat16 | 1959.036 | 1775.056 | 5780.512 | 17390.350 | 1.104 | 0.332 | | (8, 16, 4096, 64) | relative_bias | torch.bfloat16 | 1947.198 | 1773.869 | 5780.643 | 16779.699 | 1.098 | 0.345 | | (8, 16, 4096, 64) | head_bias | torch.bfloat16 | 1963.935 | 1829.502 | 5780.018 | 16703.259 | 1.073 | 0.346 | | (8, 16, 4096, 128) | noop | torch.bfloat16 | 3582.711 | 3362.623 | 10436.069 | 36415.565 | 1.065 | 0.287 | | (8, 16, 4096, 128) | causal_mask | torch.bfloat16 | 3581.504 | 3499.472 | 10346.869 | 36164.959 | 1.023 | 0.286 | | (8, 16, 4096, 128) | relative_bias | torch.bfloat16 | 3589.779 | 3337.849 | 10529.621 | 36261.696 | 1.075 | 0.290 | | (8, 16, 4096, 128) | head_bias | torch.bfloat16 | 3602.265 | 3436.444 | 10458.660 | 36507.790 | 1.048 | 0.286 | | (8, 16, 4096, 256) | noop | torch.bfloat16 | 7695.923 | 7126.275 | 24643.009 | 140949.081 | 1.080 | 0.175 | | (8, 16, 4096, 256) | causal_mask | torch.bfloat16 | 7679.939 | 7186.252 | 24538.105 | 157156.067 | 1.069 | 0.156 | | (8, 16, 4096, 256) | relative_bias | torch.bfloat16 | 7681.374 | 6994.832 | 24549.713 | 140077.179 | 1.098 | 0.175 | | (8, 16, 4096, 256) | head_bias | torch.bfloat16 | 7679.822 | 7212.278 | 24627.823 | 140675.003 | 1.065 | 0.175 | | (16, 16, 512, 64) | noop | torch.bfloat16 | 80.126 | 78.291 | 333.719 | 541.165 | 1.023 | 0.617 | | (16, 16, 512, 64) | causal_mask | torch.bfloat16 | 80.065 | 81.696 | 333.779 | 551.113 | 0.980 | 0.606 | | (16, 16, 512, 64) | relative_bias | torch.bfloat16 | 80.138 | 86.715 | 333.364 | 542.118 | 0.924 | 0.615 | | (16, 16, 512, 64) | head_bias | torch.bfloat16 | 80.415 | 85.204 | 333.294 | 536.840 | 0.944 | 0.621 | | (16, 16, 512, 128) | noop | torch.bfloat16 | 134.964 | 138.025 | 607.093 | 1333.102 | 0.978 | 0.455 | | (16, 16, 512, 128) | causal_mask | torch.bfloat16 | 134.192 | 141.523 | 606.269 | 1424.318 | 0.948 | 0.426 | | (16, 16, 512, 128) | relative_bias | torch.bfloat16 | 135.711 | 138.639 | 606.283 | 1327.974 | 0.979 | 0.457 | | (16, 16, 512, 128) | head_bias | torch.bfloat16 | 135.552 | 140.555 | 607.107 | 1347.370 | 0.964 | 0.451 | | (16, 16, 512, 256) | noop | torch.bfloat16 | 275.113 | 315.144 | 1301.583 | 5268.153 | 0.873 | 0.247 | | (16, 16, 512, 256) | causal_mask | torch.bfloat16 | 274.867 | 328.106 | 1302.513 | 5770.594 | 0.838 | 0.226 | | (16, 16, 512, 256) | relative_bias | torch.bfloat16 | 276.052 | 321.770 | 1302.904 | 5241.920 | 0.858 | 0.249 | | (16, 16, 512, 256) | head_bias | torch.bfloat16 | 271.409 | 328.839 | 1302.142 | 5266.037 | 0.825 | 0.247 | | (16, 16, 1024, 64) | noop | torch.bfloat16 | 260.489 | 237.463 | 955.884 | 1817.558 | 1.097 | 0.526 | | (16, 16, 1024, 64) | causal_mask | torch.bfloat16 | 262.378 | 254.350 | 955.280 | 1843.807 | 1.032 | 0.518 | | (16, 16, 1024, 64) | relative_bias | torch.bfloat16 | 261.338 | 268.253 | 956.038 | 1820.036 | 0.974 | 0.525 | | (16, 16, 1024, 64) | head_bias | torch.bfloat16 | 262.153 | 264.156 | 956.023 | 1810.076 | 0.992 | 0.528 | | (16, 16, 1024, 128) | noop | torch.bfloat16 | 476.475 | 461.413 | 1760.578 | 4306.521 | 1.033 | 0.409 | | (16, 16, 1024, 128) | causal_mask | torch.bfloat16 | 473.794 | 479.178 | 1761.277 | 4619.439 | 0.989 | 0.381 | | (16, 16, 1024, 128) | relative_bias | torch.bfloat16 | 473.839 | 463.282 | 1758.692 | 4290.562 | 1.023 | 0.410 | | (16, 16, 1024, 128) | head_bias | torch.bfloat16 | 472.979 | 472.896 | 1763.086 | 4367.931 | 1.000 | 0.404 | | (16, 16, 1024, 256) | noop | torch.bfloat16 | 1014.184 | 1026.764 | 3922.997 | 19104.147 | 0.988 | 0.205 | | (16, 16, 1024, 256) | causal_mask | torch.bfloat16 | 1013.217 | 1039.046 | 3928.382 | 21086.281 | 0.975 | 0.186 | | (16, 16, 1024, 256) | relative_bias | torch.bfloat16 | 1008.519 | 1015.278 | 3922.133 | 18980.652 | 0.993 | 0.207 | | (16, 16, 1024, 256) | head_bias | torch.bfloat16 | 1011.360 | 1047.542 | 3931.245 | 19069.172 | 0.965 | 0.206 | | (16, 16, 4096, 64) | noop | torch.bfloat16 | 3929.850 | 3325.667 | 11411.704 | 23344.280 | 1.182 | 0.489 | | (16, 16, 4096, 64) | causal_mask | torch.bfloat16 | 3885.262 | 3581.544 | 11390.515 | 23725.639 | 1.085 | 0.480 | | (16, 16, 4096, 64) | relative_bias | torch.bfloat16 | 3865.737 | 3537.308 | 11489.901 | 23406.330 | 1.093 | 0.491 | | (16, 16, 4096, 64) | head_bias | torch.bfloat16 | 3880.530 | 3665.249 | 11484.411 | 23299.496 | 1.059 | 0.493 | | (16, 16, 4096, 128) | noop | torch.bfloat16 | 7030.306 | 6745.715 | 20621.264 | 57464.096 | 1.042 | 0.359 | | (16, 16, 4096, 128) | causal_mask | torch.bfloat16 | 7095.414 | 7034.385 | 20410.656 | 61660.511 | 1.009 | 0.331 | | (16, 16, 4096, 128) | relative_bias | torch.bfloat16 | 7084.779 | 6686.497 | 20315.161 | 57243.969 | 1.060 | 0.355 | | (16, 16, 4096, 128) | head_bias | torch.bfloat16 | 7075.367 | 6863.305 | 20494.385 | 58481.953 | 1.031 | 0.350 | | (16, 16, 4096, 256) | noop | torch.bfloat16 | 15612.741 | 14297.482 | 55306.847 | 281161.865 | 1.092 | 0.197 | | (16, 16, 4096, 256) | causal_mask | torch.bfloat16 | 15326.592 | 14263.878 | 55227.806 | 313063.232 | 1.075 | 0.176 | | (16, 16, 4096, 256) | relative_bias | torch.bfloat16 | 15297.963 | 14007.379 | 54558.029 | 279529.175 | 1.092 | 0.195 | | (16, 16, 4096, 256) | head_bias | torch.bfloat16 | 15216.160 | 14276.027 | 55081.581 | 280996.826 | 1.066 | 0.196 | </details> Pull Request resolved: https://github.com/pytorch/pytorch/pull/125515 Approved by: https://github.com/Chillee |
||
|
|
461ffaaaf3 |
[dynamo] support torchbind object input (#124978)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124978 Approved by: https://github.com/jansel |
||
|
|
25691558d9 |
Change templated_attention -> flex_attention (#125251)
# Summary Change all the names Pull Request resolved: https://github.com/pytorch/pytorch/pull/125251 Approved by: https://github.com/Chillee, https://github.com/yanboliang |
||
|
|
8c219251c5 |
Add backwards support to FlexAttention (#123902)
# Summary This is part one of adding backwards support to FlexAttention. This PR focuses on the eager implementation and wiring up enough of the templated_attention_backward(name change soon 😉) to get through aot_eager. Notably this does not actually wire up the triton template just yet in order to make this PR easier to review. That will be the next follow up PR. #### Structure We pass both the forward and backward graph to the backwardsHOP since these are both needed to be inlined into the calculation for backwards: - the forward graph is needed in order to re-compute the scores - the joint graph is needed in order to construct the correct gradients post softmax_grad calc ### Attatched AOT Graph https://gist.github.com/drisspg/ce4c041f8df8a5a7983c5174705cf2b5 Pull Request resolved: https://github.com/pytorch/pytorch/pull/123902 Approved by: https://github.com/Chillee |
||
|
|
e979f45610 |
[while_loop] add a simiple op_info test (#123814)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123814 Approved by: https://github.com/tugsbayasgalan, https://github.com/zou3519 |
||
|
|
f4e2a226aa |
ScoreMod API (#121845)
# Summary This PR adds a new higher-order_op: `templated_attention`. This op is designed to extend the functionality of torch.nn.fucntional.scaled_dot_product_attention. PyTorch has efficient pre-written fused-attention kernels. However, users want to modify how scores are computed (a substep inside attention) -- this traditionally requires the user to write their own attention kernel. One such modification to attention scores that is not currently supported by the top level SDPA op is:[ Attention with Linear Biases (ALiBi](https://arxiv.org/abs/2108.12409)). This higher-order op will instead accept a callable( 'score_mod') function that is through torch.compile will be used to create an efficient attention kernel instantiation. ### Details This HOP utilizes the existing fx and HOP infra to capture and convert the User `score-mod` function and convert to an FX graph module. Inductor then consumes this HOP that has a `ir.Subgraph` input. It will inline this lowered subgraph into a triton kernel which performs fused attention with the modification to the scores matrix inlined. ### API The API for a score_mod function should be as follows: ```Python def score_mod(score: torch.Tensor, batch: torch.Tensor, head: torch.Tensor, token_1: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor ``` This function receives five parameters: - `score`: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors. - `batch`, `head`, `seq_len_q`, `seq_len_kv`: Scalar tensors indicating the batch index, head index, query index, and key/value index, respectively, with torch.int data type and located on the same device as the score tensor. Consider inputs query, key, value of shapes (2, 4, 16, 8), leading to an intermediate attention score matrix of shape (2, 4, 16, 16) The score_mod function will be vectorized over each element of this matrix. For instance, modifying the score at the position corresponding to the 0th batch, 2nd head, between the 8th query and the 9th key element, would be invoked as: ```Python score_mod(score[0,2,8,9], torch.tensor(0), torch.tensor(2), torch.tensor(8), torch.tensor(9)) ``` ### Examples ```Python import torch from torch.nn.attention.templated_attention import templated_attention torch.manual_seed(0) # Lets create some input tensors # The input tensor has shape (batch_size, num_heads, seq_len, head_dim) query = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32) key = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32) value = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32) # Lets create a fun new score_modification! I will call this # Checkerboard. It will reduce the score for neighboring tokens (1 step apart) # in the sequence. And increase the score for tokens 2 steps apart. For everything # else, the score will remain the same. def checkerboard(score, batch, head, token_q, token_kv): score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score) score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score) return score # Lets call templated_attention with this new score modification output = templated_attention(query, key, value, score_mod=checkerboard) compiled_templated_attention = torch.compile(templated_attention) out_compiled = compiled_templated_attention(query, key, value, score_mod=checkerboard) torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2) ``` ### Future Work - This PR is currently only forward only. However the triton kernel for backwards where score_modifications to not rely on external buffers has been explored here: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/flash/flash_attention.py - Kernel Improvements; There are has been some larger updates to the fused attention implementation that Triton uses in its tutorials. The implementation of this kernel is based on a prior version and should be updated. - We may want to unify this API under the top level SDPA API and leave that as a follow up once this is more stable - Should we error on CPU? - There are some issues with dynamic shapes - Capturing of free variables and lifting to inputs to the subgraph is not working correctly today ### Performance Comparisons generated by this benchmark: | Type | Speedup | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype | |---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------| | Average | 5.412 | | | | | | | | | Max | 8.882 | 16 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | | Min | 3.645 | 8 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | | Min | 0.345 | 1 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | For reference | Configuration | Forward Time (µ seconds) | Backend | Speedup | |-----------------------------------------------|--------------------------|------------------|---------| | Fastest Config in Sweep (`8 16 4096 4096 64 relative_bias torch.bfloat16`) | 3608 | Templated Attention | 1.0 | | Compiled SDPA (No Mask) | 9928 | Math | 2.75x | | Compiled SDPA (With Mask) | 11898 | Math | 3.29x | | Compiled SDPA (With Mask) | 8704 | Memory Efficient Attention | 2.42x | | Compiled SDPA (No Mask) | 2548 | FlashAttention2 | 0.706x | The speedups are measuring compiled templated attention speed versus different calls to torch.nn.functional.sdpa <details> <summary> FULL PERFORMANCE SWEEP NUMBERS </summary> | batch_size | num_heads | q_seq_len | k_seq_len | head_dim | score_mod | dtype | eager_time | compiled_time | speedup | |--------------|-------------|-------------|-------------|------------|---------------|----------------|--------------|-----------------|-----------| | 1 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | 331.444 | 67.221 | 4.931 | | 1 | 16 | 512 | 512 | 64 | relative_bias | torch.bfloat16 | 335.300 | 64.187 | 5.224 | | 1 | 16 | 512 | 512 | 64 | head_bias | torch.bfloat16 | 352.039 | 63.806 | 5.517 | | 1 | 16 | 512 | 512 | 64 | pathological | torch.bfloat16 | 371.699 | 711.349 | 0.523 | | 1 | 16 | 1024 | 1024 | 64 | causal_mask | torch.bfloat16 | 333.488 | 86.455 | 3.857 | | 1 | 16 | 1024 | 1024 | 64 | relative_bias | torch.bfloat16 | 322.363 | 82.469 | 3.909 | | 1 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | 349.967 | 82.233 | 4.256 | | 1 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | 486.359 | 1412.453 | 0.344 | | 1 | 16 | 4096 | 4096 | 64 | causal_mask | torch.bfloat16 | 2794.597 | 551.188 | 5.070 | | 1 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | 3965.150 | 513.101 | 7.728 | | 1 | 16 | 4096 | 4096 | 64 | head_bias | torch.bfloat16 | 2408.013 | 504.759 | 4.771 | | 1 | 16 | 4096 | 4096 | 64 | pathological | torch.bfloat16 | 6850.531 | 16733.675 | 0.409 | | 8 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | 441.939 | 123.576 | 3.576 | | 8 | 16 | 512 | 512 | 64 | relative_bias | torch.bfloat16 | 560.379 | 116.710 | 4.801 | | 8 | 16 | 512 | 512 | 64 | head_bias | torch.bfloat16 | 421.172 | 115.825 | 3.636 | | 8 | 16 | 512 | 512 | 64 | pathological | torch.bfloat16 | 994.492 | 2132.806 | 0.466 | | 8 | 16 | 1024 | 1024 | 64 | causal_mask | torch.bfloat16 | 1436.430 | 309.495 | 4.641 | | 8 | 16 | 1024 | 1024 | 64 | relative_bias | torch.bfloat16 | 1892.216 | 290.186 | 6.521 | | 8 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | 1360.665 | 282.956 | 4.809 | | 8 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | 3525.532 | 8359.702 | 0.422 | | 8 | 16 | 4096 | 4096 | 64 | causal_mask | torch.bfloat16 | 22026.839 | 3864.604 | 5.700 | | 8 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | 31262.746 | 3609.551 | 8.661 | | 8 | 16 | 4096 | 4096 | 64 | head_bias | torch.bfloat16 | 20219.079 | 3480.402 | 5.809 | | 8 | 16 | 4096 | 4096 | 64 | pathological | torch.bfloat16 | 54654.647 | 116652.357 | 0.469 | | 16 | 16 | 512 | 512 | 64 | causal_mask | torch.bfloat16 | 820.606 | 188.683 | 4.349 | | 16 | 16 | 512 | 512 | 64 | relative_bias | torch.bfloat16 | 1058.362 | 179.295 | 5.903 | | 16 | 16 | 512 | 512 | 64 | head_bias | torch.bfloat16 | 784.372 | 175.714 | 4.464 | | 16 | 16 | 512 | 512 | 64 | pathological | torch.bfloat16 | 1890.792 | 4212.877 | 0.449 | | 16 | 16 | 1024 | 1024 | 64 | causal_mask | torch.bfloat16 | 2781.830 | 557.017 | 4.994 | | 16 | 16 | 1024 | 1024 | 64 | relative_bias | torch.bfloat16 | 3694.050 | 525.249 | 7.033 | | 16 | 16 | 1024 | 1024 | 64 | head_bias | torch.bfloat16 | 2634.164 | 507.613 | 5.189 | | 16 | 16 | 1024 | 1024 | 64 | pathological | torch.bfloat16 | 6959.917 | 15331.116 | 0.454 | | 16 | 16 | 4096 | 4096 | 64 | causal_mask | torch.bfloat16 | 43889.096 | 7582.018 | 5.789 | | 16 | 16 | 4096 | 4096 | 64 | relative_bias | torch.bfloat16 | 62784.293 | 7075.846 | 8.873 | | 16 | 16 | 4096 | 4096 | 64 | head_bias | torch.bfloat16 | 40308.606 | 6829.587 | 5.902 | | 16 | 16 | 4096 | 4096 | 64 | pathological | torch.bfloat16 | 108892.137 | 233090.953 | 0.467 | </details> Pull Request resolved: https://github.com/pytorch/pytorch/pull/121845 Approved by: https://github.com/Chillee, https://github.com/zou3519 |
||
|
|
8a0436014d |
Support map in pre-dispatch functionalization (#121444)
When we enter map_autograd, we try to trace through fwd/bwd of a map operator that is wrapped in ctx.functionalize wrapper. This forces us to go through PreDispatch functionalization again (only the python part). As a result, it revealed our previous bug where pre-dispatch mode handling doesn't actually manage the local dispatch key set. (If there is no active mode, we need to turn off PreDispatch key). This PR fixes that. Also I shuffled some APIs around so that there is less code duplication as the setting/unsetting logic is quite hard to get it right. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121444 Approved by: https://github.com/bdhirsh |
||
|
|
25ad90adc0 |
Revert "Support map in pre-dispatch functionalization (#121444)"
This reverts commit
|
||
|
|
9288b27461 |
Support map in pre-dispatch functionalization (#121444)
When we enter map_autograd, we try to trace through fwd/bwd of a map operator that is wrapped in ctx.functionalize wrapper. This forces us to go through PreDispatch functionalization again (only the python part). As a result, it revealed our previous bug where pre-dispatch mode handling doesn't actually manage the local dispatch key set. (If there is no active mode, we need to turn off PreDispatch key). This PR fixes that. Also I shuffled some APIs around so that there is less code duplication as the setting/unsetting logic is quite hard to get it right. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121444 Approved by: https://github.com/bdhirsh |
||
|
|
6b8205d3de |
Revert "Support map in pre-dispatch functionalization (#121444)"
This reverts commit |
||
|
|
079feea337 |
Support map in pre-dispatch functionalization (#121444)
When we enter map_autograd, we try to trace through fwd/bwd of a map operator that is wrapped in ctx.functionalize wrapper. This forces us to go through PreDispatch functionalization again (only the python part). As a result, it revealed our previous bug where pre-dispatch mode handling doesn't actually manage the local dispatch key set. (If there is no active mode, we need to turn off PreDispatch key). This PR fixes that. Also I shuffled some APIs around so that there is less code duplication as the setting/unsetting logic is quite hard to get it right. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121444 Approved by: https://github.com/bdhirsh |
||
|
|
d9a08de9a4 |
Add Opinfo entries for HOP testing (#122265)
In this PR, we add a systematic way to test all HOPs to be exportable as export team has been running into various bugs related to newly added HOPs due to lack of tests. We do this by creating: - hop_db -> a list of HOP OpInfo tests which then used inside various flows including export functionalities: [aot-export, pre-dispatch export, retrace, and ser/der For now, we also create an allowlist so that people can bypass the failures for now. But we should discourage ppl to do that. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122265 Approved by: https://github.com/ydwu4, https://github.com/zou3519 |