Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
adedf26e21
170 Commits
e6ba4d0725 |
Back out "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)" (#165910)
Summary: Original commit changeset: d6d62d0c96dd Original Phabricator Diff: D84468451 and D84613184 D84468451 caused CUDA OutOfMemoryError in model. Test Plan: D84468451 was found through bisect. Also double checked on recent trunk 9866939225248c2adc307be7a804b26db0b9b555: f815887517 With this diff that backs out D84468451 and D84613184 : f816114560 Differential Revision: D85025378 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165910 Approved by: https://github.com/clee2000 |
51d0d8ee67 |
[ATen] Fix CUDA reduction warp shuffle order (#164790)
Typical warp shuffle reduction has the following pattern: <img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" /> which is exhibited in Triton generated by torch.compile: <img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" /> Switch the warp shuffle order to make bitwise equivalence between the 2 easier. PTX difference between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/ Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order: ``` Tensor Shape Operation New all dims (ms) New dim=0 (ms) New dim=1 (ms) Old all dims (ms) Old dim=0 (ms) Old dim=1 (ms) ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.015817 0.016259 0.013642 0.015990 0.016258 0.013631 (1024, 1024) sum 0.015917 0.015906 0.013359 0.015707 0.016266 0.013226 (1024, 1024) min 0.016021 0.024625 0.015631 0.015761 0.024485 0.015317 (1024, 1024) max 0.016349 0.024971 0.015972 0.015771 0.025001 0.015314 (1024, 1024) argmin 0.018070 0.024448 0.015578 0.018135 0.025370 0.015322 (1024, 1024) argmax 0.018427 0.024859 0.015932 0.018164 0.024452 0.015639 (1024, 1024) var 0.020078 0.026413 0.020295 0.020199 0.026381 0.020214 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.023826 0.023726 0.022273 0.023236 0.023776 0.022248 (2048, 2048) sum 0.023840 0.023355 0.021974 0.023294 0.023354 0.021884 (2048, 2048) min 0.024519 0.041263 0.024620 0.023292 0.041491 0.024358 (2048, 2048) max 0.024509 0.041670 0.024277 0.023334 0.041231 0.024395 (2048, 2048) argmin 0.026125 0.041282 0.024567 0.026772 0.041773 0.024296 (2048, 2048) argmax 0.026117 0.041487 0.024572 0.026412 0.041477 0.024273 (2048, 2048) var 0.026603 0.048581 0.031308 0.027587 0.048603 0.030860 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.053927 0.057070 0.054073 0.053028 0.057544 0.053935 (4096, 4096) sum 0.053604 0.057410 0.054451 0.053076 0.057033 0.054266 (4096, 4096) min 0.054293 0.109122 0.058363 0.053821 0.108689 0.058382 (4096, 4096) max 0.054258 0.108035 0.058703 0.053492 0.110552 0.058376 (4096, 4096) argmin 0.056805 0.111167 0.058301 0.056836 0.112325 0.058292 (4096, 4096) argmax 0.056488 0.110958 0.058636 0.056844 0.111000 0.057928 (4096, 4096) var 0.058936 0.141755 0.068693 0.059735 0.141284 0.068500 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.145552 0.148082 0.138647 0.145364 0.147818 0.138207 (8192, 8192) sum 0.145985 0.147900 0.138714 0.145755 0.148031 0.138616 (8192, 8192) min 0.146566 0.205359 0.192739 0.145611 0.205237 0.182335 (8192, 8192) max 0.146526 0.204844 0.193050 0.146073 0.205457 0.182697 (8192, 8192) argmin 0.150190 0.206605 0.192543 0.150654 0.206847 0.182007 (8192, 8192) argmax 0.150481 0.206368 0.192535 0.150845 0.206430 0.182022 (8192, 8192) var 0.150884 
0.184546 0.203900 0.151594 0.184172 0.197983 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1, 1024, 128) mean 0.014293 0.008119 0.014533 0.013861 0.008022 0.014449 (1, 1024, 128) sum 0.014039 0.007877 0.014111 0.014219 0.008227 0.014045 (1, 1024, 128) min 0.014159 0.011354 0.023493 0.014271 0.010862 0.023644 (1, 1024, 128) max 0.014154 0.011027 0.023368 0.014259 0.011234 0.023692 (1, 1024, 128) argmin 0.016403 0.005677 0.023328 0.016273 0.005683 0.024073 (1, 1024, 128) argmax 0.016734 0.005675 0.023437 0.016580 0.005318 0.023331 (1, 1024, 128) var 0.018338 0.009549 0.025538 0.018528 0.009391 0.024777 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (5, 1024, 128) mean 0.014873 0.010131 0.015546 0.015123 0.010131 0.015481 (5, 1024, 128) sum 0.015334 0.009673 0.015824 0.014736 0.009671 0.015438 (5, 1024, 128) min 0.015047 0.013252 0.024573 0.014803 0.013163 0.024551 (5, 1024, 128) max 0.015050 0.013339 0.024197 0.014810 0.013525 0.024230 (5, 1024, 128) argmin 0.017341 0.012737 0.024306 0.017471 0.012379 0.024991 (5, 1024, 128) argmax 0.017345 0.012411 0.024421 0.017422 0.012471 0.024237 (5, 1024, 128) var 0.019973 0.011453 0.026188 0.020050 0.011438 0.026282 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (10, 1024, 128) mean 0.016976 0.011575 0.016831 0.016722 0.011927 0.017173 (10, 1024, 128) sum 0.017039 0.011841 0.017159 0.016385 0.011860 0.016753 (10, 1024, 128) min 0.017036 0.015331 0.026770 0.016944 0.015205 0.027166 (10, 1024, 128) max 0.017369 0.015348 0.027077 0.016531 0.015716 0.026819 (10, 1024, 128) argmin 0.019203 0.014447 0.026813 0.018994 0.014497 0.027313 (10, 1024, 128) argmax 0.019563 0.014795 0.027140 0.019460 0.014912 0.026733 (10, 1024, 128) var 0.020529 0.014316 0.030405 0.020719 0.013960 0.029964 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (100, 1024, 128) mean 0.045046 0.039168 0.046082 0.044839 0.039217 0.045782 (100, 1024, 128) sum 0.045094 0.039150 0.045777 0.044496 0.039542 0.046083 (100, 1024, 128) min 0.045768 0.054466 0.076244 0.044915 0.053943 0.076599 (100, 1024, 128) max 0.045748 0.054459 0.076188 0.044931 0.053949 0.076856 (100, 1024, 128) argmin 0.048275 0.054046 0.076647 0.048694 0.054105 0.077004 (100, 1024, 128) argmax 0.048267 0.054395 0.077401 0.048691 0.054131 0.076751 (100, 1024, 128) var 0.049710 0.043254 0.083077 0.050971 0.043251 0.082378 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000, 100) mean 0.202312 0.196723 0.197765 0.201774 0.196641 0.197459 (1000, 1000, 100) sum 0.202651 0.196682 0.197736 0.202175 0.196313 0.197523 (1000, 1000, 100) min 0.203022 0.264762 0.269200 0.202729 0.264129 0.268694 (1000, 1000, 100) max 0.202864 0.264396 0.269388 0.202486 0.263896 0.268720 (1000, 1000, 100) argmin 0.226727 0.263781 0.268651 0.226597 0.264676 0.268983 (1000, 1000, 100) argmax 0.226412 0.264469 0.269090 0.226570 0.264595 0.269178 (1000, 1000, 100) var 0.243223 0.204079 0.216096 
0.241942 0.204079 0.215925 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (10000, 100) mean 0.016193 0.020277 0.014316 0.016152 0.020324 0.013712 (10000, 100) sum 0.016289 0.020237 0.014034 0.016168 0.020265 0.013708 (10000, 100) min 0.016046 0.030872 0.019609 0.016208 0.030867 0.018627 (10000, 100) max 0.016369 0.030835 0.019257 0.016218 0.030861 0.018209 (10000, 100) argmin 0.017957 0.031171 0.019517 0.018050 0.031556 0.018077 (10000, 100) argmax 0.017961 0.031658 0.019521 0.018060 0.031564 0.018087 (10000, 100) var 0.020393 0.035652 0.019339 0.020144 0.035987 0.019171 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (100000, 10) mean 0.015718 0.016576 0.016555 0.015999 0.016246 0.014869 (100000, 10) sum 0.015833 0.016247 0.016572 0.016007 0.016627 0.014872 (100000, 10) min 0.015888 0.020510 0.023920 0.015671 0.020821 0.021417 (100000, 10) max 0.015889 0.020479 0.023918 0.016077 0.020386 0.021421 (100000, 10) argmin 0.018233 0.020863 0.023647 0.017574 0.020864 0.021103 (100000, 10) argmax 0.017896 0.020527 0.023296 0.017569 0.020447 0.021098 (100000, 10) var 0.020005 0.024198 0.024372 0.020075 0.024167 0.022415 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1023, 1023, 1023) mean 1.874816 1.963506 1.903909 1.873279 1.963859 1.903230 (1023, 1023, 1023) sum 1.875030 1.965716 1.902458 1.873566 1.960730 1.901642 (1023, 1023, 1023) min 1.878563 2.473455 2.179092 1.875174 2.482086 2.183027 (1023, 1023, 1023) max 1.879128 2.474803 2.178895 1.874831 2.482253 2.183884 (1023, 1023, 1023) argmin 1.921800 2.476629 2.174831 1.923987 2.472641 2.170453 (1023, 1023, 1023) argmax 1.922605 2.476688 2.177927 1.923366 2.472808 2.172979 (1023, 1023, 1023) var 1.972606 3.088695 2.758797 1.978679 3.095658 2.762243 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1023, 1023, 255) mean 0.489984 0.500954 0.492957 0.489891 0.500654 0.491971 (1023, 1023, 255) sum 0.490228 0.500764 0.492289 0.489624 0.501089 0.492824 (1023, 1023, 255) min 0.491457 0.563560 0.553334 0.490355 0.564709 0.554754 (1023, 1023, 255) max 0.491396 0.563628 0.553345 0.490017 0.565004 0.554947 (1023, 1023, 255) argmin 0.503666 0.561512 0.551831 0.503845 0.560972 0.551017 (1023, 1023, 255) argmax 0.503602 0.561185 0.551407 0.504328 0.561267 0.551448 (1023, 1023, 255) var 0.510844 0.709452 0.701630 0.512693 0.710365 0.701965 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1023, 1023, 377) mean 0.707439 0.727646 0.712019 0.706769 0.727101 0.711632 (1023, 1023, 377) sum 0.707780 0.727453 0.711554 0.706807 0.726656 0.711729 (1023, 1023, 377) min 0.709423 0.819809 0.794379 0.707847 0.822086 0.796664 (1023, 1023, 377) max 0.709297 0.819780 0.794308 0.707566 0.821913 0.796690 (1023, 1023, 377) argmin 0.725028 0.817088 0.791695 0.726039 0.816445 0.790828 (1023, 1023, 377) argmax 0.725301 0.817011 0.791420 0.726040 0.816917 0.791143 (1023, 1023, 377) var 0.740859 1.034165 1.006712 0.743413 1.035506 
1.007638 ``` Differential Revision: [D85022826](https://our.internmc.facebook.com/intern/diff/D85022826) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790 Approved by: https://github.com/ngimel, https://github.com/eqy |
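For illustration only (not the kernel code from this PR), here is a small Python sketch of why the pairing order of a 32-lane warp sum matters: different summation trees can round differently in float32, so matching Triton's order is what buys bitwise equivalence.

```python
# Illustrative sketch, not ATen/Triton code: simulate one 32-lane warp sum in
# float32 under two different shuffle pairing orders. The exact orders used by
# ATen and Triton are assumptions here; the point is that the summation tree
# alone can change the last bits of the result.
import numpy as np

def warp_sum(vals, offsets):
    v = [np.float32(x) for x in vals]
    for off in offsets:
        snapshot = list(v)  # every lane reads pre-step values, like a synced shuffle
        for lane in range(32 - off):
            v[lane] = np.float32(snapshot[lane] + snapshot[lane + off])
    return v[0]

rng = np.random.default_rng(0)
x = (rng.standard_normal(32) * 10.0 ** rng.integers(-3, 4, size=32)).astype(np.float32)

a = warp_sum(x, (16, 8, 4, 2, 1))  # strided pairs first (classic shuffle-down tree)
b = warp_sum(x, (1, 2, 4, 8, 16))  # adjacent pairs first (a different tree)
print(a, b, a == b)  # same math; the float32 results may differ in the last ulp
```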
602ace5eb4 |
Revert "[ATen] Fix CUDA reduction warp shuffle order (#164790)"
This reverts commit
e925dfcc6b |
Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhance code readability. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645 Approved by: https://github.com/ezyang, https://github.com/mlazos |
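For context, a couple of the rewrites that flake8-simplify (`SIM`) rules push for (illustrative snippets, not taken from the PR):

```python
# Examples of patterns SIM rules flag, next to their simplified forms.
def is_ready(flag: bool, items: list) -> bool:
    ready = True if flag and items else False  # SIM210-style: redundant ternary
    ready = bool(flag and items)               # simplified
    return ready

def get_name(config: dict) -> str:
    if "name" in config:                       # SIM401-style: if/else around a lookup
        name = config["name"]
    else:
        name = "unknown"
    name = config.get("name", "unknown")       # simplified
    return name
```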
36371b8ec7 |
[ATen] Fix CUDA reduction warp shuffle order (#164790)
Typical warp shuffle reduction has the following pattern: <img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" /> which is exhibited in Triton generated by torch.compile: <img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" /> Switch the warp shuffle order to make bitwise equivalence between the 2 easier. PTX difference between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/ Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order: ``` Tensor Shape Operation New all dims (ms) New dim=0 (ms) New dim=1 (ms) Old all dims (ms) Old dim=0 (ms) Old dim=1 (ms) ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.015817 0.016259 0.013642 0.015990 0.016258 0.013631 (1024, 1024) sum 0.015917 0.015906 0.013359 0.015707 0.016266 0.013226 (1024, 1024) min 0.016021 0.024625 0.015631 0.015761 0.024485 0.015317 (1024, 1024) max 0.016349 0.024971 0.015972 0.015771 0.025001 0.015314 (1024, 1024) argmin 0.018070 0.024448 0.015578 0.018135 0.025370 0.015322 (1024, 1024) argmax 0.018427 0.024859 0.015932 0.018164 0.024452 0.015639 (1024, 1024) var 0.020078 0.026413 0.020295 0.020199 0.026381 0.020214 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.023826 0.023726 0.022273 0.023236 0.023776 0.022248 (2048, 2048) sum 0.023840 0.023355 0.021974 0.023294 0.023354 0.021884 (2048, 2048) min 0.024519 0.041263 0.024620 0.023292 0.041491 0.024358 (2048, 2048) max 0.024509 0.041670 0.024277 0.023334 0.041231 0.024395 (2048, 2048) argmin 0.026125 0.041282 0.024567 0.026772 0.041773 0.024296 (2048, 2048) argmax 0.026117 0.041487 0.024572 0.026412 0.041477 0.024273 (2048, 2048) var 0.026603 0.048581 0.031308 0.027587 0.048603 0.030860 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.053927 0.057070 0.054073 0.053028 0.057544 0.053935 (4096, 4096) sum 0.053604 0.057410 0.054451 0.053076 0.057033 0.054266 (4096, 4096) min 0.054293 0.109122 0.058363 0.053821 0.108689 0.058382 (4096, 4096) max 0.054258 0.108035 0.058703 0.053492 0.110552 0.058376 (4096, 4096) argmin 0.056805 0.111167 0.058301 0.056836 0.112325 0.058292 (4096, 4096) argmax 0.056488 0.110958 0.058636 0.056844 0.111000 0.057928 (4096, 4096) var 0.058936 0.141755 0.068693 0.059735 0.141284 0.068500 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.145552 0.148082 0.138647 0.145364 0.147818 0.138207 (8192, 8192) sum 0.145985 0.147900 0.138714 0.145755 0.148031 0.138616 (8192, 8192) min 0.146566 0.205359 0.192739 0.145611 0.205237 0.182335 (8192, 8192) max 0.146526 0.204844 0.193050 0.146073 0.205457 0.182697 (8192, 8192) argmin 0.150190 0.206605 0.192543 0.150654 0.206847 0.182007 (8192, 8192) argmax 0.150481 0.206368 0.192535 0.150845 0.206430 0.182022 (8192, 8192) var 0.150884 
0.184546 0.203900 0.151594 0.184172 0.197983 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1, 1024, 128) mean 0.014293 0.008119 0.014533 0.013861 0.008022 0.014449 (1, 1024, 128) sum 0.014039 0.007877 0.014111 0.014219 0.008227 0.014045 (1, 1024, 128) min 0.014159 0.011354 0.023493 0.014271 0.010862 0.023644 (1, 1024, 128) max 0.014154 0.011027 0.023368 0.014259 0.011234 0.023692 (1, 1024, 128) argmin 0.016403 0.005677 0.023328 0.016273 0.005683 0.024073 (1, 1024, 128) argmax 0.016734 0.005675 0.023437 0.016580 0.005318 0.023331 (1, 1024, 128) var 0.018338 0.009549 0.025538 0.018528 0.009391 0.024777 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (5, 1024, 128) mean 0.014873 0.010131 0.015546 0.015123 0.010131 0.015481 (5, 1024, 128) sum 0.015334 0.009673 0.015824 0.014736 0.009671 0.015438 (5, 1024, 128) min 0.015047 0.013252 0.024573 0.014803 0.013163 0.024551 (5, 1024, 128) max 0.015050 0.013339 0.024197 0.014810 0.013525 0.024230 (5, 1024, 128) argmin 0.017341 0.012737 0.024306 0.017471 0.012379 0.024991 (5, 1024, 128) argmax 0.017345 0.012411 0.024421 0.017422 0.012471 0.024237 (5, 1024, 128) var 0.019973 0.011453 0.026188 0.020050 0.011438 0.026282 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (10, 1024, 128) mean 0.016976 0.011575 0.016831 0.016722 0.011927 0.017173 (10, 1024, 128) sum 0.017039 0.011841 0.017159 0.016385 0.011860 0.016753 (10, 1024, 128) min 0.017036 0.015331 0.026770 0.016944 0.015205 0.027166 (10, 1024, 128) max 0.017369 0.015348 0.027077 0.016531 0.015716 0.026819 (10, 1024, 128) argmin 0.019203 0.014447 0.026813 0.018994 0.014497 0.027313 (10, 1024, 128) argmax 0.019563 0.014795 0.027140 0.019460 0.014912 0.026733 (10, 1024, 128) var 0.020529 0.014316 0.030405 0.020719 0.013960 0.029964 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (100, 1024, 128) mean 0.045046 0.039168 0.046082 0.044839 0.039217 0.045782 (100, 1024, 128) sum 0.045094 0.039150 0.045777 0.044496 0.039542 0.046083 (100, 1024, 128) min 0.045768 0.054466 0.076244 0.044915 0.053943 0.076599 (100, 1024, 128) max 0.045748 0.054459 0.076188 0.044931 0.053949 0.076856 (100, 1024, 128) argmin 0.048275 0.054046 0.076647 0.048694 0.054105 0.077004 (100, 1024, 128) argmax 0.048267 0.054395 0.077401 0.048691 0.054131 0.076751 (100, 1024, 128) var 0.049710 0.043254 0.083077 0.050971 0.043251 0.082378 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000, 100) mean 0.202312 0.196723 0.197765 0.201774 0.196641 0.197459 (1000, 1000, 100) sum 0.202651 0.196682 0.197736 0.202175 0.196313 0.197523 (1000, 1000, 100) min 0.203022 0.264762 0.269200 0.202729 0.264129 0.268694 (1000, 1000, 100) max 0.202864 0.264396 0.269388 0.202486 0.263896 0.268720 (1000, 1000, 100) argmin 0.226727 0.263781 0.268651 0.226597 0.264676 0.268983 (1000, 1000, 100) argmax 0.226412 0.264469 0.269090 0.226570 0.264595 0.269178 (1000, 1000, 100) var 0.243223 0.204079 0.216096 
0.241942 0.204079 0.215925 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (10000, 100) mean 0.016193 0.020277 0.014316 0.016152 0.020324 0.013712 (10000, 100) sum 0.016289 0.020237 0.014034 0.016168 0.020265 0.013708 (10000, 100) min 0.016046 0.030872 0.019609 0.016208 0.030867 0.018627 (10000, 100) max 0.016369 0.030835 0.019257 0.016218 0.030861 0.018209 (10000, 100) argmin 0.017957 0.031171 0.019517 0.018050 0.031556 0.018077 (10000, 100) argmax 0.017961 0.031658 0.019521 0.018060 0.031564 0.018087 (10000, 100) var 0.020393 0.035652 0.019339 0.020144 0.035987 0.019171 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (100000, 10) mean 0.015718 0.016576 0.016555 0.015999 0.016246 0.014869 (100000, 10) sum 0.015833 0.016247 0.016572 0.016007 0.016627 0.014872 (100000, 10) min 0.015888 0.020510 0.023920 0.015671 0.020821 0.021417 (100000, 10) max 0.015889 0.020479 0.023918 0.016077 0.020386 0.021421 (100000, 10) argmin 0.018233 0.020863 0.023647 0.017574 0.020864 0.021103 (100000, 10) argmax 0.017896 0.020527 0.023296 0.017569 0.020447 0.021098 (100000, 10) var 0.020005 0.024198 0.024372 0.020075 0.024167 0.022415 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1023, 1023, 1023) mean 1.874816 1.963506 1.903909 1.873279 1.963859 1.903230 (1023, 1023, 1023) sum 1.875030 1.965716 1.902458 1.873566 1.960730 1.901642 (1023, 1023, 1023) min 1.878563 2.473455 2.179092 1.875174 2.482086 2.183027 (1023, 1023, 1023) max 1.879128 2.474803 2.178895 1.874831 2.482253 2.183884 (1023, 1023, 1023) argmin 1.921800 2.476629 2.174831 1.923987 2.472641 2.170453 (1023, 1023, 1023) argmax 1.922605 2.476688 2.177927 1.923366 2.472808 2.172979 (1023, 1023, 1023) var 1.972606 3.088695 2.758797 1.978679 3.095658 2.762243 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1023, 1023, 255) mean 0.489984 0.500954 0.492957 0.489891 0.500654 0.491971 (1023, 1023, 255) sum 0.490228 0.500764 0.492289 0.489624 0.501089 0.492824 (1023, 1023, 255) min 0.491457 0.563560 0.553334 0.490355 0.564709 0.554754 (1023, 1023, 255) max 0.491396 0.563628 0.553345 0.490017 0.565004 0.554947 (1023, 1023, 255) argmin 0.503666 0.561512 0.551831 0.503845 0.560972 0.551017 (1023, 1023, 255) argmax 0.503602 0.561185 0.551407 0.504328 0.561267 0.551448 (1023, 1023, 255) var 0.510844 0.709452 0.701630 0.512693 0.710365 0.701965 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1023, 1023, 377) mean 0.707439 0.727646 0.712019 0.706769 0.727101 0.711632 (1023, 1023, 377) sum 0.707780 0.727453 0.711554 0.706807 0.726656 0.711729 (1023, 1023, 377) min 0.709423 0.819809 0.794379 0.707847 0.822086 0.796664 (1023, 1023, 377) max 0.709297 0.819780 0.794308 0.707566 0.821913 0.796690 (1023, 1023, 377) argmin 0.725028 0.817088 0.791695 0.726039 0.816445 0.790828 (1023, 1023, 377) argmax 0.725301 0.817011 0.791420 0.726040 0.816917 0.791143 (1023, 1023, 377) var 0.740859 1.034165 1.006712 0.743413 1.035506 
1.007638 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790 Approved by: https://github.com/ngimel, https://github.com/eqy ghstack dependencies: #165494 |
8de85896e0 |
Enable ruff rule E721 (#165162)
`E721` checks for object type comparisons that use `==` and other comparison operators. This is useful because `is` is the recommended way to compare types. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165162 Approved by: https://github.com/Skylion007 |
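An illustrative snippet (not from the PR) of what E721 flags and the preferred spellings:

```python
class Base: ...
class Child(Base): ...

a, b = Child(), Child()

print(type(a) == type(b))      # flagged by E721: type comparison with `==`
print(type(a) is type(b))      # preferred: identity check for the exact type
print(isinstance(a, Base))     # preferred when subclasses should also match
```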
816fb7f48d |
Revert "Enable ruff rule E721 (#165162)"
This reverts commit
9e7c19f72b |
Enable ruff rule E721 (#165162)
`E721` checks for object type comparisons that use `==` and other comparison operators. This is useful because `is` is the recommended way to compare types. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165162 Approved by: https://github.com/Skylion007 |
de8d81275a |
Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to eager, because it avoids a decomposition. You can force the decomposition by having the decomposition in the dispatch table, but if eager mode wouldn't have decomposed (because it went to the fused one), we now preserve the fused call by default. This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel. Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939 Approved by: https://github.com/bdhirsh |
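A hedged sketch of the property this change is after, namely compiled rms_norm matching eager bitwise; whether the check passes also depends on the backend and hardware:

```python
# Sketch only: compare eager and compiled rms_norm bit-for-bit.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 1024, device=device)
w = torch.randn(1024, device=device)

def f(x, w):
    return F.rms_norm(x, (1024,), w)

eager = f(x, w)
compiled = torch.compile(f)(x, w)
print(torch.equal(eager, compiled))  # exact (bitwise) comparison is the goal here
```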
5c3fe9fb30 |
Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"
This reverts commit
a6fa4f9c28 |
Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to eager, because it avoids a decomposition. You can force the decomposition by having the decomposition in the dispatch table, but if eager mode wouldn't have decomposed (because it went to the fused one), we now preserve the fused call by default. This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel. Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939 Approved by: https://github.com/bdhirsh |
06d86e58d0 |
Revert "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)"
This reverts commit
d40a9bfb8d |
Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed (#164939)
This fixes AOTAutograd rms_norm not being bitwise equivalent to eager, because it avoids a decomposition. You can force the decomposition by having the decomposition in the dispatch table, but if eager mode wouldn't have decomposed (because it went to the fused one), we now preserve the fused call by default. This largely reverts https://github.com/pytorch/pytorch/pull/103275/ for view ops. This means that in inference mode we could hit the wrong C++ kernel; if this occurs we should just SymInt'ify the C++ kernel. Another neat side effect of this change is that Inductor's generated kernels for rms_norm now have rms_norm in their name. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/164939 Approved by: https://github.com/bdhirsh ghstack dependencies: #164573 |
5d7360bb03 |
Revert "Enable all SIM rules except disabled ones (#164645)"
This reverts commit
321e602692 |
Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhance code readability. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645 Approved by: https://github.com/ezyang |
4660e38e5a |
write conv1d decomposition (#163080)
In Unified Runtime, we cannot have any fallback ops (for now). Not all conv1d ops can avoid fallbacks yet, so we write a decomposition for conv1d. It's not registered in the default decomposition table, as currently only ExecuTorch/Unified Runtime needs it, but it might benefit Inductor as well because conv2d can generate Triton kernels while there's no Triton codegen for conv1d. I don't know if the conv2d Triton kernel will have better perf than aten::conv1d, so it's not registered by default yet. To register it, one just needs to do `import torch._decomp as decomp;decomp.register_decomposition(torch.ops.aten.conv1d.default, conv1d_to_conv2d)` Pull Request resolved: https://github.com/pytorch/pytorch/pull/163080 Approved by: https://github.com/angelayi |
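A sketch of what such a decomposition can look like (the actual `conv1d_to_conv2d` in the PR may differ; int stride/padding/dilation are assumed for brevity):

```python
# Decompose conv1d into conv2d by inserting a unit height dimension.
import torch
import torch.nn.functional as F

def conv1d_via_conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    # x: (N, C_in, L) -> (N, C_in, 1, L); weight: (C_out, C_in/groups, k) -> (..., 1, k)
    out = F.conv2d(x.unsqueeze(2), weight.unsqueeze(2), bias,
                   stride=(1, stride), padding=(0, padding),
                   dilation=(1, dilation), groups=groups)
    return out.squeeze(2)

x, w, b = torch.randn(2, 4, 16), torch.randn(8, 4, 3), torch.randn(8)
ref = F.conv1d(x, w, b, stride=2, padding=1)
alt = conv1d_via_conv2d(x, w, b, stride=2, padding=1)
print(torch.allclose(ref, alt, atol=1e-6))
```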
fa0355c18d |
Fix full_like decomposition to preserve strides (#158898)
Summary: See original PR at: https://github.com/pytorch/pytorch/pull/144765, which landed internally but was reverted due to test failures. Addressing reviewer comments and trying again. Rollback Plan: Differential Revision: D78783627 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158898 Approved by: https://github.com/eellison |
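A sketch of the behavior being fixed, namely `full_like` keeping a non-contiguous input's strides (illustrative, not the exact decomposition from the PR):

```python
import torch

def full_like_preserving_strides(x, fill_value):
    # empty_like with preserve_format keeps x's layout for dense, non-overlapping inputs
    return torch.empty_like(x, memory_format=torch.preserve_format).fill_(fill_value)

x = torch.randn(4, 6).t()                     # non-contiguous view, strides (1, 4)
y = full_like_preserving_strides(x, 3.0)
print(x.stride(), y.stride())                 # both should be (1, 4)
print(torch.full_like(x, 3.0).stride())       # eager full_like keeps them as well
```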
04a393507b |
Fused RMSNorm implementation (#153666)
Relevant #72643 Benchmarked versus unfused torch implementation and torch.compile implementation. Around 9x speedup vs unfused implementation on cuda and slightly faster vs inductor compile on 5090. ```py import torch import torch.nn as nn class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x): norm_x = x.norm(2, dim=-1, keepdim=True) rms_x = norm_x * torch.rsqrt(torch.tensor(x.shape[-1], dtype=x.dtype)) x_normed = x / (rms_x + self.eps) return self.scale * x_normed def benchmark_rmsnorm_cuda(input_shape, normalized_dim, num_iterations=100, warmup_iterations=10, dtype=torch.float16): rms_norm_layer = torch.nn.RMSNorm(normalized_dim, device='cuda', dtype=dtype) input_data = torch.randn(input_shape, device='cuda', dtype=dtype) for _ in range(warmup_iterations): _ = rms_norm_layer(input_data) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_iterations): _ = rms_norm_layer(input_data) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = elapsed_time_ms / num_iterations print(f"--- RMSNorm CUDA Benchmark ---") print(f"Input Shape: {input_shape}") print(f"Normalized Dimension: {normalized_dim}") print(f"Benchmark Iterations: {num_iterations}") print(f"--- Fused Implementation ---") print(f"Average Time per Iteration: {avg_time_ms:.4f} ms") print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms") compiled_rms_norm = torch.compile(RMSNorm(dim=normalized_dim)).cuda() for _ in range(warmup_iterations): _ = compiled_rms_norm(input_data) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_iterations): _ = compiled_rms_norm(input_data) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = elapsed_time_ms / num_iterations print(f"--- TorchCompile Implementation ---") print(f"Average Time per Iteration: {avg_time_ms:.4f} ms") print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms") print("-" * 50) if __name__ == '__main__': parameter_sets = [ {'batch_size': 16, 'sequence_length': 256, 'hidden_features': 512, 'dtype': torch.float16}, {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float16}, {'batch_size': 64, 'sequence_length': 1024, 'hidden_features': 1024, 'dtype': torch.float16}, {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float32}, {'batch_size': 8, 'sequence_length': 2048, 'hidden_features': 2048, 'dtype': torch.float16}, ] num_benchmark_iterations = 200 num_warmup_iterations = 20 for params in parameter_sets: batch_size = params['batch_size'] sequence_length = params['sequence_length'] hidden_features = params['hidden_features'] data_type = params.get('dtype', torch.float16) shape = (batch_size, sequence_length, hidden_features) norm_dim_to_normalize = hidden_features print(f"Benchmarking with: BS={batch_size}, SeqLen={sequence_length}, Hidden={hidden_features}, DType={data_type}") benchmark_rmsnorm_cuda(input_shape=shape, normalized_dim=norm_dim_to_normalize, num_iterations=num_benchmark_iterations, warmup_iterations=num_warmup_iterations, dtype=data_type) ``` Here are the triton compile tests ran on a 5090 (comparing this branch vs main) 
```py import torch import torch.nn as nn from torch._inductor.utils import run_and_get_code, run_fw_bw_and_get_code torch.manual_seed(0) device = torch.device("cuda") for batch in range(0, 9): for i in range(9, 16): normalized_shape_arg = (2**batch, 2**i) input_tensor = torch.randn(2**batch, 2**i, device=device, requires_grad=True) weight_tensor = torch.randn(2**batch, 2**i,device=device, requires_grad=True) model = torch.nn.functional.rms_norm compiled_model = torch.compile(model) loss = torch.randn_like(input_tensor) num_iter = 5 for j in range(num_iter): output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor) output.backward(loss) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() num_iter = 10 for j in range(num_iter): output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor) output.backward(loss) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = round(elapsed_time_ms / num_iter, 5) print(2**batch, 2**i, avg_time_ms) ``` main ``` 32 512 0.1812 32 1024 0.19021 32 2048 0.18871 32 4096 0.17019 32 8192 0.21944 32 16384 0.38871 32 32768 0.83282 64 512 0.14705 64 1024 0.13987 64 2048 0.14111 64 4096 0.21699 64 8192 0.43141 64 16384 0.90652 64 32768 2.18573 128 512 0.19361 128 1024 0.1963 128 2048 0.20122 128 4096 0.38888 128 8192 0.93795 128 16384 2.23437 128 32768 5.50079 256 512 0.16722 256 1024 0.22856 256 2048 0.39421 256 4096 0.96621 256 8192 2.48746 256 16384 5.53571 256 32768 11.97932 ``` current branch ``` 32 512 0.16328 32 1024 0.18104 32 2048 0.15508 32 4096 0.14356 32 8192 0.20111 32 16384 0.45974 32 32768 0.94799 64 512 0.16874 64 1024 0.18701 64 2048 0.16107 64 4096 0.20152 64 8192 0.46568 64 16384 0.96599 64 32768 2.21661 128 512 0.14982 128 1024 0.15565 128 2048 0.22241 128 4096 0.46128 128 8192 0.88883 128 16384 2.3097 128 32768 5.84448 256 512 0.14346 256 1024 0.2007 256 2048 0.45927 256 4096 0.87876 256 8192 2.10571 256 16384 5.73948 256 32768 12.98581 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153666 Approved by: https://github.com/ngimel, https://github.com/albanD |
35f1b4ad9e |
Revert "Fused RMSNorm implementation (#153666)"
This reverts commit
15ef4f28df |
Fused RMSNorm implementation (#153666)
Relevant #72643 Benchmarked versus unfused torch implementation and torch.compile implementation. Around 9x speedup vs unfused implementation on cuda and slightly faster vs inductor compile on 5090. ```py import torch import torch.nn as nn class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x): norm_x = x.norm(2, dim=-1, keepdim=True) rms_x = norm_x * torch.rsqrt(torch.tensor(x.shape[-1], dtype=x.dtype)) x_normed = x / (rms_x + self.eps) return self.scale * x_normed def benchmark_rmsnorm_cuda(input_shape, normalized_dim, num_iterations=100, warmup_iterations=10, dtype=torch.float16): rms_norm_layer = torch.nn.RMSNorm(normalized_dim, device='cuda', dtype=dtype) input_data = torch.randn(input_shape, device='cuda', dtype=dtype) for _ in range(warmup_iterations): _ = rms_norm_layer(input_data) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_iterations): _ = rms_norm_layer(input_data) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = elapsed_time_ms / num_iterations print(f"--- RMSNorm CUDA Benchmark ---") print(f"Input Shape: {input_shape}") print(f"Normalized Dimension: {normalized_dim}") print(f"Benchmark Iterations: {num_iterations}") print(f"--- Fused Implementation ---") print(f"Average Time per Iteration: {avg_time_ms:.4f} ms") print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms") compiled_rms_norm = torch.compile(RMSNorm(dim=normalized_dim)).cuda() for _ in range(warmup_iterations): _ = compiled_rms_norm(input_data) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_iterations): _ = compiled_rms_norm(input_data) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = elapsed_time_ms / num_iterations print(f"--- TorchCompile Implementation ---") print(f"Average Time per Iteration: {avg_time_ms:.4f} ms") print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms") print("-" * 50) if __name__ == '__main__': parameter_sets = [ {'batch_size': 16, 'sequence_length': 256, 'hidden_features': 512, 'dtype': torch.float16}, {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float16}, {'batch_size': 64, 'sequence_length': 1024, 'hidden_features': 1024, 'dtype': torch.float16}, {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float32}, {'batch_size': 8, 'sequence_length': 2048, 'hidden_features': 2048, 'dtype': torch.float16}, ] num_benchmark_iterations = 200 num_warmup_iterations = 20 for params in parameter_sets: batch_size = params['batch_size'] sequence_length = params['sequence_length'] hidden_features = params['hidden_features'] data_type = params.get('dtype', torch.float16) shape = (batch_size, sequence_length, hidden_features) norm_dim_to_normalize = hidden_features print(f"Benchmarking with: BS={batch_size}, SeqLen={sequence_length}, Hidden={hidden_features}, DType={data_type}") benchmark_rmsnorm_cuda(input_shape=shape, normalized_dim=norm_dim_to_normalize, num_iterations=num_benchmark_iterations, warmup_iterations=num_warmup_iterations, dtype=data_type) ``` Here are the triton compile tests ran on a 5090 (comparing this branch vs main) 
```py import torch import torch.nn as nn from torch._inductor.utils import run_and_get_code, run_fw_bw_and_get_code torch.manual_seed(0) device = torch.device("cuda") for batch in range(0, 9): for i in range(9, 16): normalized_shape_arg = (2**batch, 2**i) input_tensor = torch.randn(2**batch, 2**i, device=device, requires_grad=True) weight_tensor = torch.randn(2**batch, 2**i,device=device, requires_grad=True) model = torch.nn.functional.rms_norm compiled_model = torch.compile(model) loss = torch.randn_like(input_tensor) num_iter = 5 for j in range(num_iter): output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor) output.backward(loss) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() num_iter = 10 for j in range(num_iter): output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor) output.backward(loss) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = round(elapsed_time_ms / num_iter, 5) print(2**batch, 2**i, avg_time_ms) ``` main ``` 32 512 0.1812 32 1024 0.19021 32 2048 0.18871 32 4096 0.17019 32 8192 0.21944 32 16384 0.38871 32 32768 0.83282 64 512 0.14705 64 1024 0.13987 64 2048 0.14111 64 4096 0.21699 64 8192 0.43141 64 16384 0.90652 64 32768 2.18573 128 512 0.19361 128 1024 0.1963 128 2048 0.20122 128 4096 0.38888 128 8192 0.93795 128 16384 2.23437 128 32768 5.50079 256 512 0.16722 256 1024 0.22856 256 2048 0.39421 256 4096 0.96621 256 8192 2.48746 256 16384 5.53571 256 32768 11.97932 ``` current branch ``` 32 512 0.16328 32 1024 0.18104 32 2048 0.15508 32 4096 0.14356 32 8192 0.20111 32 16384 0.45974 32 32768 0.94799 64 512 0.16874 64 1024 0.18701 64 2048 0.16107 64 4096 0.20152 64 8192 0.46568 64 16384 0.96599 64 32768 2.21661 128 512 0.14982 128 1024 0.15565 128 2048 0.22241 128 4096 0.46128 128 8192 0.88883 128 16384 2.3097 128 32768 5.84448 256 512 0.14346 256 1024 0.2007 256 2048 0.45927 256 4096 0.87876 256 8192 2.10571 256 16384 5.73948 256 32768 12.98581 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153666 Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/albanD |
fc0376e8b1 |
[BE][2/6] fix typos in test/ (test/test_*.py) (#157636)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157636 Approved by: https://github.com/yewentao256, https://github.com/mlazos ghstack dependencies: #156311, #156609 |
c553c55be7 |
Revert "Fix full_like decomposition to preserve strides (#144765)"
This reverts commit
01b0f09931 |
Fix full_like decomposition to preserve strides (#144765)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144765 Approved by: https://github.com/amjames, https://github.com/jansel |
6401d1d53d |
Revert "Fused RMSNorm implementation (#153666)"
This reverts commit
e1aee86646 |
Fused RMSNorm implementation (#153666)
Relevant #72643 Benchmarked versus unfused torch implementation and torch.compile implementation. Around 9x speedup vs unfused implementation on cuda and slightly faster vs inductor compile on 5090. ```py import torch import torch.nn as nn class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x): norm_x = x.norm(2, dim=-1, keepdim=True) rms_x = norm_x * torch.rsqrt(torch.tensor(x.shape[-1], dtype=x.dtype)) x_normed = x / (rms_x + self.eps) return self.scale * x_normed def benchmark_rmsnorm_cuda(input_shape, normalized_dim, num_iterations=100, warmup_iterations=10, dtype=torch.float16): rms_norm_layer = torch.nn.RMSNorm(normalized_dim, device='cuda', dtype=dtype) input_data = torch.randn(input_shape, device='cuda', dtype=dtype) for _ in range(warmup_iterations): _ = rms_norm_layer(input_data) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_iterations): _ = rms_norm_layer(input_data) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = elapsed_time_ms / num_iterations print(f"--- RMSNorm CUDA Benchmark ---") print(f"Input Shape: {input_shape}") print(f"Normalized Dimension: {normalized_dim}") print(f"Benchmark Iterations: {num_iterations}") print(f"--- Fused Implementation ---") print(f"Average Time per Iteration: {avg_time_ms:.4f} ms") print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms") compiled_rms_norm = torch.compile(RMSNorm(dim=normalized_dim)).cuda() for _ in range(warmup_iterations): _ = compiled_rms_norm(input_data) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() for _ in range(num_iterations): _ = compiled_rms_norm(input_data) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = elapsed_time_ms / num_iterations print(f"--- TorchCompile Implementation ---") print(f"Average Time per Iteration: {avg_time_ms:.4f} ms") print(f"Total Time for {num_iterations} Iterations: {elapsed_time_ms:.3f} ms") print("-" * 50) if __name__ == '__main__': parameter_sets = [ {'batch_size': 16, 'sequence_length': 256, 'hidden_features': 512, 'dtype': torch.float16}, {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float16}, {'batch_size': 64, 'sequence_length': 1024, 'hidden_features': 1024, 'dtype': torch.float16}, {'batch_size': 32, 'sequence_length': 512, 'hidden_features': 768, 'dtype': torch.float32}, {'batch_size': 8, 'sequence_length': 2048, 'hidden_features': 2048, 'dtype': torch.float16}, ] num_benchmark_iterations = 200 num_warmup_iterations = 20 for params in parameter_sets: batch_size = params['batch_size'] sequence_length = params['sequence_length'] hidden_features = params['hidden_features'] data_type = params.get('dtype', torch.float16) shape = (batch_size, sequence_length, hidden_features) norm_dim_to_normalize = hidden_features print(f"Benchmarking with: BS={batch_size}, SeqLen={sequence_length}, Hidden={hidden_features}, DType={data_type}") benchmark_rmsnorm_cuda(input_shape=shape, normalized_dim=norm_dim_to_normalize, num_iterations=num_benchmark_iterations, warmup_iterations=num_warmup_iterations, dtype=data_type) ``` Here are the triton compile tests ran on a 5090 (comparing this branch vs main) 
```py import torch import torch.nn as nn from torch._inductor.utils import run_and_get_code, run_fw_bw_and_get_code torch.manual_seed(0) device = torch.device("cuda") for batch in range(0, 9): for i in range(9, 16): normalized_shape_arg = (2**batch, 2**i) input_tensor = torch.randn(2**batch, 2**i, device=device, requires_grad=True) weight_tensor = torch.randn(2**batch, 2**i,device=device, requires_grad=True) model = torch.nn.functional.rms_norm compiled_model = torch.compile(model) loss = torch.randn_like(input_tensor) num_iter = 5 for j in range(num_iter): output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor) output.backward(loss) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() num_iter = 10 for j in range(num_iter): output = compiled_model(input_tensor, normalized_shape_arg, weight_tensor) output.backward(loss) end_event.record() torch.cuda.synchronize() elapsed_time_ms = start_event.elapsed_time(end_event) avg_time_ms = round(elapsed_time_ms / num_iter, 5) print(2**batch, 2**i, avg_time_ms) ``` main ``` 32 512 0.1812 32 1024 0.19021 32 2048 0.18871 32 4096 0.17019 32 8192 0.21944 32 16384 0.38871 32 32768 0.83282 64 512 0.14705 64 1024 0.13987 64 2048 0.14111 64 4096 0.21699 64 8192 0.43141 64 16384 0.90652 64 32768 2.18573 128 512 0.19361 128 1024 0.1963 128 2048 0.20122 128 4096 0.38888 128 8192 0.93795 128 16384 2.23437 128 32768 5.50079 256 512 0.16722 256 1024 0.22856 256 2048 0.39421 256 4096 0.96621 256 8192 2.48746 256 16384 5.53571 256 32768 11.97932 ``` current branch ``` 32 512 0.16328 32 1024 0.18104 32 2048 0.15508 32 4096 0.14356 32 8192 0.20111 32 16384 0.45974 32 32768 0.94799 64 512 0.16874 64 1024 0.18701 64 2048 0.16107 64 4096 0.20152 64 8192 0.46568 64 16384 0.96599 64 32768 2.21661 128 512 0.14982 128 1024 0.15565 128 2048 0.22241 128 4096 0.46128 128 8192 0.88883 128 16384 2.3097 128 32768 5.84448 256 512 0.14346 256 1024 0.2007 256 2048 0.45927 256 4096 0.87876 256 8192 2.10571 256 16384 5.73948 256 32768 12.98581 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/153666 Approved by: https://github.com/ngimel |
086d146f6f |
Update ruff linter for PEP585 (#147540)
This turns on PEP 585 enforcement in Ruff. - Updates the target Python version - Stops ignoring UP006 warnings (PEP 585) - Fixes a few issues that crept into the tree in the last day Pull Request resolved: https://github.com/pytorch/pytorch/pull/147540 Approved by: https://github.com/justinchuby, https://github.com/Skylion007 |
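For reference, the kind of change UP006 enforcement drives (illustrative):

```python
from typing import Dict, List

# Before: typing generics, flagged by UP006 once PEP 585 enforcement is on
def adults_old(names: List[str], ages: Dict[str, int]) -> List[str]:
    return [n for n in names if ages.get(n, 0) >= 18]

# After: builtin generics per PEP 585 (Python 3.9+)
def adults_new(names: list[str], ages: dict[str, int]) -> list[str]:
    return [n for n in names if ages.get(n, 0) >= 18]
```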
288aa87383 |
[Inductor][CPU] disable bernoulli_p decomposition (#143460)
Fixes https://github.com/pytorch/pytorch/issues/142853. `fallback_random=True` should cause RNG to match between compile/eager (by having compile fall back to eager for RNG ops), but the `bernoulli_p` decomposition is not fully consistent with the eager CPU implementation. We remove the decomp and keep the version for `fallback_random=False`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/143460 Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jansel |
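A hedged sketch of the contract this restores: with `fallback_random=True`, compiled RNG ops should reproduce eager exactly for the same seed (the flag name is Inductor's config option; the expected result reflects the fix described above):

```python
# Sketch: compiled bernoulli_ should match eager bit-for-bit when Inductor
# falls back to the eager RNG kernels.
import torch
import torch._inductor.config as inductor_config

inductor_config.fallback_random = True

def noisy(x, p):
    mask = torch.empty_like(x)
    mask.bernoulli_(p)          # RNG op that previously diverged on CPU
    return x * mask

x = torch.randn(1024)
torch.manual_seed(0)
eager = noisy(x, 0.3)
torch.manual_seed(0)
compiled = torch.compile(noisy)(x, 0.3)
print(torch.equal(eager, compiled))  # expected True once the decomp is removed
```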
f85e238186 |
[aotd] capture rrelu_with_noise noise mutation in compile (#141867)
Rebase copy of the long-standing, already-approved PR https://github.com/pytorch/pytorch/pull/138503 that was blocked from landing by XLA build issues. Opened a new PR with the same content (ghstack checkout was failing due to changed submodules). Corresponding XLA PR: https://github.com/pytorch/xla/pull/8363 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141867 Approved by: https://github.com/bdhirsh |
161425ff9f |
Added aten.bernoulli.p and aten.bernoulli.default decompositions (#139141)
Fixes #105519. Added an aten.bernoulli.p decomposition and moved/rewrote aten.bernoulli.default so that both are included in the core ATen decompositions. Tested the sample code in [#105519](https://github.com/pytorch/pytorch/issues/105519); torch.bernoulli could be decomposed by the code snippet. Pull Request resolved: https://github.com/pytorch/pytorch/pull/139141 Approved by: https://github.com/eellison |
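The heart of such a decomposition is a comparison against uniform noise; a sketch (not the exact registered decomp):

```python
import torch

def bernoulli_p_decomp(self: torch.Tensor, p: float) -> torch.Tensor:
    # aten.bernoulli.p: each output element is 1 with probability p, else 0
    return (torch.rand_like(self, dtype=torch.float32) < p).to(self.dtype)

x = torch.empty(1000, 1000)
print(bernoulli_p_decomp(x, 0.25).mean().item())  # ~0.25
```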
9ae19ffbed |
fix layer_norm decomp precision for cpu (#140557)
xref: https://fb.workplace.com/groups/1075192433118967/posts/1540519826586223/?comment_id=1543752356262970&reply_comment_id=1544425069529032 The issue is that our decomp needs to branch on device (it only upcasts for CPU), but the device shows up as "meta" because it is registered as a meta tensor rule. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140557 Approved by: https://github.com/ezyang |
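A sketch of the device branch the decomp needs, and of why a meta-tensor registration hides the real device (illustrative, not the PR's code):

```python
import torch

def layer_norm_decomp(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    # CPU eager computes fp16/bf16 layer_norm in fp32, so the decomp should too.
    # If this only ever ran under a meta-tensor rule, x.device.type would read
    # "meta" here and the CPU branch would never fire, which is the bug described above.
    upcast = x.device.type == "cpu" and x.dtype in (torch.float16, torch.bfloat16)
    xc = x.float() if upcast else x
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    mean = xc.mean(dims, keepdim=True)
    var = xc.var(dims, unbiased=False, keepdim=True)
    out = (xc - mean) / torch.sqrt(var + eps)
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias
    return out.to(x.dtype)

x = torch.randn(4, 8, dtype=torch.bfloat16)
ref = torch.nn.functional.layer_norm(x, (8,))
print(torch.allclose(ref.float(), layer_norm_decomp(x, (8,)).float(), rtol=1e-2, atol=1e-2))
```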
546318e559 |
[7/N] Don't skip ASAN on some tests (#139675)
Follows #139565 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139675 Approved by: https://github.com/ezyang |
29297731bb |
[5/N] Don't skip ASAN on some tests (#139265)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/139265 Approved by: https://github.com/ezyang |
1f32a1fb80 |
Replace torch.export default decomp table to be lazily populated (#137650)
In this PR, we implement a lazy dictionary for export decomp behaviour, for the following reason: custom op loading can happen after import time, and as a result the decomp table might not be able to pick up the decomp, so we try to delay materialization as late as possible. I intentionally separated out core_aten_decomp so that it doesn't have any custom CIA ops in this PR, to mitigate the risk of getting reverted, but in the future core_aten_decomp under torch/_decomp will exist as an alias to the official export table (torch.export.default_decompositions). Differential Revision: [D64140807](https://our.internmc.facebook.com/intern/diff/D64140807) Pull Request resolved: https://github.com/pytorch/pytorch/pull/137650 Approved by: https://github.com/justinchuby, https://github.com/bdhirsh |
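The mechanism can be sketched as a mapping that materializes its contents on first lookup, so decomps for custom ops registered after import time are still picked up (a simplified illustration, not the actual implementation):

```python
from collections.abc import Mapping

class LazyDecompDict(Mapping):
    """Dict-like table whose contents are built lazily on first access."""

    def __init__(self, populate):
        self._populate = populate  # callable returning {op: decomposition_fn}
        self._table = None

    def _materialize(self):
        if self._table is None:
            # Deferred until lookup time, after custom ops have had a chance to load
            self._table = self._populate()
        return self._table

    def __getitem__(self, op):
        return self._materialize()[op]

    def __iter__(self):
        return iter(self._materialize())

    def __len__(self):
        return len(self._materialize())

# Usage sketch: in export, populate() would pull from torch._decomp at call time.
table = LazyDecompDict(lambda: {"aten::relu": lambda x: x.clamp(min=0)})
print(len(table), list(table))
```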
382fad58b3 |
Deprecate _preserve_ops and consolidate with decomp_table (#135080)
In this PR, we deprecate the _preserve_ops feature in the run_decompositions API. We can't kill this API completely because the Executorch team depends on it. As syncing between the two repos is non-trivial, I just leave this argument as deprecated for now; in the next PR, I will remove it immediately. After this PR, run_decompositions will only decompose what's inside the decomp table and preserve the rest by default. Note that this feature is only rolled out to OSS for now; the old code path is protected under the IS_FBCODE flag. Differential Revision: [D62163161](https://our.internmc.facebook.com/intern/diff/D62163161/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/135080 Approved by: https://github.com/justinchuby, https://github.com/avikchaudhuri, https://github.com/bdhirsh |
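After this change the flow is, roughly: build the default decomp table, drop the entries you want preserved, and pass the table to `run_decompositions` (a sketch; the exact table API may differ across releases):

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

ep = torch.export.export(M(), (torch.randn(2, 4),))

decomp_table = torch.export.default_decompositions()
# Removing an entry now means "preserve this op instead of decomposing it"
decomp_table.pop(torch.ops.aten.linear.default, None)

ep = ep.run_decompositions(decomp_table)
print(ep.graph_module.code)
```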
5d964a5eb7 |
[Export] Fix SDPA decomposition (#135297)
Summary: Update the SDPA decomposition to match the updated strides from D62009189, which aligns strides with `aten._scaled_dot_product_attention_math.default` and makes the `t.permute().contiguous().permute()` round-trip no longer necessary. Test Plan: CI Differential Revision: D62278378 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135297 Approved by: https://github.com/drisspg |
1434e0b121 |
Add a private _safe_softmax (#131060)
# Summary Changes the stance of SDPA on what to do for fully masked out rows ## Current Behavior Several PyTorch users have expressed frustration over this issue: - https://github.com/pytorch/pytorch/issues/41508 - https://github.com/pytorch/pytorch/issues/103749 - https://github.com/pytorch/pytorch/issues/103963 These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated here: https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617 Can be paraphrased as follows: When passing in fully masked out rows, attention becomes ambiguous. We have two main options: 1. Uniformly attend to all values: ```python scores[masked_out_rows] = 1 / len(row) out[masked_out_rows] = 1 / len(row) * value ``` 2. Decide that attention between no queries (masked) and no keys (masked) is meaningless: ```python output[fully_masked_rows] = NaN ``` We went with option 2. Partially because it was easier to implement, but also people argued that users can slice the output to remove the NaNs: ``` Python >fill_value = -float("inf") >row0 = torch.randn(4) >row1 = torch.tensor([(fill_value for _ in range(4)]) >matrix = torch.stack([row0, row1]).requires_grad_(True) >out = torch.softmax(matrix, 1) >out = out[0] >print(out) tensor([0.5377, 0.2729, 0.0692, 0.1201]) ``` Cool, problem solved. But what happends when you call backwards.. ```Python >out.backward(torch.ones_like(out)) >print(matrix.grad) tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08], [ nan, nan, nan, nan]]) ``` Those pesky NaNs are back! ## Why do we see NaNs today? The core of the problem revolves around using softmax function in sdpa: ```python > row = torch.tensor([(-float("inf")) for _ in range(4)]) > torch.softmax(row, 0) tensor([nan, nan, nan, nan]) ``` ## Quick Aside: Masking in Attention Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs. We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values. ## Alternative Approaches If we use a very large negative number instead of -inf: ```python > row = torch.tensor([(-1e6) for _ in range(4)]) > torch.softmax(row, 0) tensor([0.2500, 0.2500, 0.2500, 0.2500]) ``` However if users always remembered to "slice" out their outputs i.e.: ```Python >fill_value = -1e6 >... >out.backward(torch.ones_like(out)) >print(matrix.grad) tensor([[-0.0563, -0.0564, 0.1613, -0.0486], [ 0.0000, 0.0000, 0.0000, 0.0000]]) ``` This would bring us back into a better state. ## A Third Option We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation. This PR implements the new semantic for masking w/ attention in fully masked-out rows: ```python out[masked_out_rows] = 0 ``` **Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption. 
## Details This PR stack does 3 things: 1. Adds a PRIVATE _safe_softmax op 2. Updates semantic for flash_cpu fused kernel 3. Updates semantic for efficient_cuda fused kernel _safe_softmax is not supposed to be used generically and is only meant to be used within the context of SDPA. Due to this fact instead of decomposing softmax and checking for -inf rows we instead "cheat" and use nan_to_num. Why I think this is okay? (please find a counter point if avail) There are multiple ways NaNs can emerge. For the fully masked out rows case nan_to_num works. But what if there were other NaNs, wouldn't this silently remove them? The only case that this can happen is if the input itself had a NaN or an Inf For example: ```Python a = torch.ones([4], requires_grad=False, dtype=torch.float16) a[1] = torch.finfo(torch.float16).max print(a.softmax(-1)) ``` Will return `tensor([0., 1., 0., 0.], dtype=torch.float16)` Where ```Python a = torch.ones([4], requires_grad=False, dtype=torch.float16) a[1] = float("inf") a.softmax(-1) ``` returns: `tensor([nan, nan, nan, nan], dtype=torch.float16)` If we dont want to even allow for the possibility of "inf" or "NaN" attention scores to be converted to 0 then we can implemented it something like this ```Python max = torch.max(a, dim=-1, keepdim=True) exp = torch.exp(a - max.values) denom = torch.sum(exp, dim=-1, keepdim=True) softmax = exp / denom softmax = torch.where(max.values == float('-inf'), 0.0, softmax) ``` however we would be paying for this in math performance. ## Why Now I think one point that has substantially changed where PyTorch should lie on this argument is the fact that we have fused implementations for SDPA now. And these fused implementations allow us to easily and performantly support this new semantic. Pull Request resolved: https://github.com/pytorch/pytorch/pull/131060 Approved by: https://github.com/jbschlosser |
||
|
|
c7cfa51721 |
Always use high precision for SDPA math backend (#128922)
Summary: feikou observed big numerical gaps when using the math backend on AMD and NV GPUs, mainly because we are not using higher-precision FP32 for the intermediate accumulated/materialized parts. Since the math backend is expected to be slower anyway, and we expect it to produce the correct reference result, I think it is worth upcasting FP16/BF16 inputs to FP32, doing the computation in FP32/TF32, and then downcasting the FP32 output back to FP16/BF16. Differential Revision: D58710805 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922 Approved by: https://github.com/xw285cornell, https://github.com/drisspg |
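A minimal sketch of the upcast described above (not the ATen code; `sdpa_math_reference` is a hypothetical helper): run the math reference path in FP32 and cast the result back to the input dtype only at the end.

```python
import torch

def sdpa_math_reference(q, k, v, attn_bias=None):
    orig_dtype = q.dtype
    q32, k32, v32 = (t.to(torch.float32) for t in (q, k, v))
    scores = q32 @ k32.transpose(-2, -1) / (q32.shape[-1] ** 0.5)
    if attn_bias is not None:
        scores = scores + attn_bias.to(torch.float32)  # additive bias, per SDPA convention
    out = torch.softmax(scores, dim=-1) @ v32
    return out.to(orig_dtype)  # downcast only at the very end

q = k = v = torch.randn(2, 4, 8, dtype=torch.float16)
print(sdpa_math_reference(q, k, v).dtype)  # torch.float16, but accumulated in fp32
```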
||
|
|
4226ed1585 |
[BE] Format uncategorized Python files with ruff format (#132576)
Remove patterns `**`, `test/**`, and `torch/**` in `tools/linter/adapters/pyfmt_linter.py` and run `lintrunner`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132576 Approved by: https://github.com/ezyang, https://github.com/Skylion007 ghstack dependencies: #132574 |
||
|
|
59b73079a0 |
Revert "Always use high precision for SDPA math backend (#128922)"
This reverts commit
|
||
|
|
fbf3bc0a60 |
Always use high precision for SDPA math backend (#128922)
Summary: feikou observed big numerical gaps when using the math backend on AMD and NV GPUs, mainly because we are not using higher-precision FP32 for the intermediate accumulated/materialized parts. Since the math backend is expected to be slower anyway, and we expect it to produce the correct reference result, I think it is worth upcasting FP16/BF16 inputs to FP32, doing the computation in FP32/TF32, and then downcasting the FP32 output back to FP16/BF16. Differential Revision: D58710805 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922 Approved by: https://github.com/xw285cornell, https://github.com/drisspg |
||
|
|
221350e3a4 |
Add None return type to init -- tests (#132352)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352 Approved by: https://github.com/ezyang ghstack dependencies: #132335, #132351 |
||
|
|
baa4c9ca46 |
Optimize aten.cat calls of a repeated element (#132081)
This was a particular problem for a model I saw that had a large number of repeats, which made compilation slow. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132081 Approved by: https://github.com/shunting314 |
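A minimal illustration of the underlying equivalence (hypothetical helper, not the actual Inductor pass): concatenating the same tensor N times along a dimension is the same as a repeat, which avoids materializing a long list of cat inputs.

```python
import torch

def cat_repeated(x, n, dim=0):
    # Equivalent to torch.cat([x] * n, dim=dim) without building the list.
    repeats = [1] * x.dim()
    repeats[dim] = n
    return x.repeat(*repeats)

x = torch.randn(2, 3)
assert torch.equal(cat_repeated(x, 4, dim=0), torch.cat([x] * 4, dim=0))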
||
|
|
43a6d20883 |
Add decomposition for reflection_pad{1,2,3}d_backward (#130299)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130299 Approved by: https://github.com/lezcano ghstack dependencies: #130130 |
||
|
|
ba48cf6535 |
[BE][Easy][6/19] enforce style for empty lines in import segments in test/ (#129757)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129757
Approved by: https://github.com/ezyang |
||
|
|
c8ab2e8b63 |
Set seed per sample for OpInfo tests + support for restricting to a single sample input (#128238)
This PR:
* Sets a random seed before generating each sample for an OpInfo test. It does this by intercepting the sample input iterator via `TrackedInputIter` and optionally setting a test-name-specific seed before each iterator call (the default is to set the seed); see the sketch after this message.
* Some quick and dirty benchmarking shows (hopefully) negligible overhead from setting the random seed before each sample input generation. For a trivial (single assert) test that uses `@ops`:
* Uncovered a bunch of test issues:
* Test breakdown (>100 total)
* A lot of tolerance issues (tweaked tolerance values to fix)
* 1 broken OpInfo (`sample_inputs_masked_fill` was generating a sample of the wrong dtype)
* 3 actually broken semantics (for masked tensor; added xfails)
* 4 Jacobian mismatches (added xfails)
* 2 nan results (skip for now, need fixing)
* 3 results too far from reference result (add xfails)
* Skips MPS tests for now (there are so many failures!). Those will default to the old behavior.
**before (no seed setting):**
```
real 0m21.306s
user 0m19.053s
sys 0m5.192s
```
**after (with seed setting):**
```
real 0m21.905s
user 0m19.578s
sys 0m5.390s
```
* Utilizing the above for reproducible sample input generation, adds support for restricting the iterator to a single sample input. This is done via an env var `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX` and its usage is included in the repro command.
```
======================================================================
ERROR: test_bar_add_cuda_uint8 (__main__.TestFooCUDA.test_bar_add_cuda_uint8)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_device_type.py", line 971, in test_wrapper
return test(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jbschlosser/branches/testing_updates/test/test_ops.py", line 2671, in test_bar
self.assertFalse(True)
AssertionError: True is not false
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_utils.py", line 2816, in wrapper
method(*args, **kwargs)
File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_utils.py", line 2816, in wrapper
method(*args, **kwargs)
File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_device_type.py", line 419, in instantiated_test
result = test(self, **param_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_utils.py", line 1426, in wrapper
fn(*args, **kwargs)
File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_device_type.py", line 982, in test_wrapper
raise new_e from e
Exception: Caused by sample input at index 3: SampleInput(input=Tensor[size=(10, 5), device="cuda:0", dtype=torch.uint8], args=TensorList[Tensor[size=(), device="cuda:0", dtype=torch.uint8]], kwargs={}, broadcasts_input=False, name='')
To execute this test, run the following from the base repo dir:
PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=3 python test/test_ops.py -k TestFooCUDA.test_bar_add_cuda_uint8
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 1 test in 0.037s
FAILED (errors=1)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128238
Approved by: https://github.com/janeyx99, https://github.com/justinchuby
|
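A minimal sketch of the seeding scheme described above (hypothetical helper, not the actual `TrackedInputIter`): seed the RNG before generating each sample so any single sample can be reproduced later by its index.

```python
import torch

def seeded_sample_inputs(sample_input_gen, base_seed, restrict_to_index=None):
    """Yield (index, sample) pairs, seeding the RNG before each generation step."""
    it = iter(sample_input_gen)
    idx = 0
    while True:
        torch.manual_seed(base_seed + idx)  # deterministic per-sample seed
        try:
            sample = next(it)
        except StopIteration:
            return
        if restrict_to_index is None or idx == restrict_to_index:
            yield idx, sample
        idx += 1

# Usage: reproduce only sample index 3 with identical random contents each run.
gen = (torch.randn(3) for _ in range(10))
for idx, sample in seeded_sample_inputs(gen, base_seed=1234, restrict_to_index=3):
    print(idx, sample)
```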
||
|
|
8cd9b10456 |
Fix exp decomp numerics (#129154)
Our previous implementation would sometimes generate `inf` because we did not apply the same numerics tricks as eager; see the comment at [TransformationHelper.h](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/TransformationHelper.h#L123-L144):

```
# curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
# we need log to be not 0, and not underflow when converted to half
# fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args
```

Fix for https://github.com/pytorch/pytorch/issues/127749. Added a test for non-inf results, but it would be great to have more robust decomp distribution tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129154
Approved by: https://github.com/bdhirsh, https://github.com/zou3519 |
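A simplified sketch of the kind of guard described above (assumed, not the actual decomposition; `exponential_decomp_sketch` is a hypothetical name): an exponential sample is `-log(u) / lambd` for `u` drawn from (0, 1], and the log must be neither exactly 0 (the exponential distribution excludes 0) nor allowed to blow up, so it is clamped to a tiny negative epsilon near `u == 1`.

```python
import torch

def exponential_decomp_sketch(shape, lambd=1.0, dtype=torch.float16):
    u = 1.0 - torch.rand(shape, dtype=torch.float32)   # shift [0, 1) -> (0, 1]
    log_u = torch.log(u)
    eps = torch.finfo(dtype).eps
    # Keep log strictly negative so the sample is strictly positive and finite.
    log_u = torch.where(u == 1.0, torch.full_like(log_u, -eps / 2), log_u)
    return (-log_u / lambd).to(dtype)

samples = exponential_decomp_sketch((1024,))
assert torch.isfinite(samples).all() and (samples > 0).all()
```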
||
|
|
39de62845a |
[decomp] Fix default values missing from inplace rrelu decomposition (#126978)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126978 Approved by: https://github.com/lezcano |
||
|
|
c165a8e71d |
Enable UFMT on test_decomp.py, test_expanded_weights.py and some files (#125117)
Part of: #123062

Ran lintrunner on:
- test/test_decomp.py
- test/test_deploy.py
- test/test_determination.py
- test/test_dlpack.py
- test/test_dynamic_shapes.py
- test/test_expanded_weights.py

Detail:
```bash
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125117
Approved by: https://github.com/jansel |
||
|
|
af67704dcc |
[privateuse1] _refs.masked_fill support privateuse1 when value.device.type is cpu (#124835)
`_refs.masked_fill` supports privateuse1 when `value.device.type` is cpu.
1. Maybe I should consider whether this modification meets the expectations of other privateuse1 devices.
2. Add a TestCase.

Fixes #124693

Co-authored-by: albanD <desmaison.alban@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124835
Approved by: https://github.com/albanD |
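A minimal illustration of the pattern involved (using CUDA as a stand-in for a privateuse1 backend, since the specific backend is not named here): masked_fill on a non-CPU tensor accepts a 0-dim `value` tensor that lives on CPU.

```python
import torch

if torch.cuda.is_available():
    x = torch.zeros(4, device="cuda")
    mask = torch.tensor([True, False, True, False], device="cuda")
    value = torch.tensor(3.0)                 # 0-dim tensor on CPU is accepted
    print(x.masked_fill(mask, value))         # tensor([3., 0., 3., 0.], device='cuda:0')
```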
||
|
|
97ccfad915 |
Fix test_decomp test for ops with py_impl(CompositeImplicitAutograd) (#116832)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116832 Approved by: https://github.com/lezcano |