Commit Graph

355 Commits

Author SHA1 Message Date
Isuru Fernando
edcd968b51 Add out wrappers to some decompositions (#115437)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115437
Approved by: https://github.com/lezcano
2024-04-23 06:26:11 +00:00
vfdev-5
6330acae76 Refactored implementation for upsample_nearest decompostions (#122783)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122783
Approved by: https://github.com/peterbell10
2024-04-17 23:05:40 +00:00
Edward Z. Yang
60d7fbe89a Register matmul out variant so it is used (#122979)
Fixes https://github.com/pytorch/pytorch/issues/122774

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122979
Approved by: https://github.com/Chillee, https://github.com/Skylion007
2024-04-09 22:21:37 +00:00
Andrew M. James
bde1a93bc4 Add lowering for resize, decomp for resize_as. (#122317)
This has been split off from #121354 as the inplace version of these
methods prove to be rather tricky.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122317
Approved by: https://github.com/peterbell10, https://github.com/lezcano
2024-04-03 17:47:29 +00:00
vfdev-5
38946bff51 Added DispatchKey.CompositeImplicitAutograd to all upsample_nearest*.default decompositions (#122782)
Related to https://github.com/pytorch/pytorch/pull/117632#issuecomment-2021321172
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122782
Approved by: https://github.com/ezyang
2024-03-29 13:55:25 +00:00
vfdev-5
b524a404e0 Fixed support for uint8 in upsample bicubic2d decomposition (#120411)
Superseeds https://github.com/pytorch/pytorch/pull/104248

Description:
- Fixed support for uint8 for upsample bicubic2d decomposition (on `main` results are wrong, so we can tolerate the slowdown)
- Added missing clamp(0, 1) for xscale and yscale
  - slowdown for f32 on cpu. PR on nodes fusion on CPU: https://github.com/pytorch/pytorch/pull/120077 can help for upsampling cases with align corners = true
  - the slowdown mainly due to the added clamp op and also partially reduced when using torch.stack in weights computation on cpu.
- Removed lowering implementation

Benchmarks:
```
[-------------------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cpu --------------------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                                   |  Eager (2.4.0a0+git0c61c20) PR  |  Compiled (2.4.0a0+git0c61c20) PR  |  Compiled (2.4.0a0+git069270d) Nightly  |  speed-up PR vs Nightly  |  Eager (2.4.0a0+git069270d) Nightly
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)       |        613.029 (+-1.590)        |         5477.608 (+-9.027)         |           3060.314 (+-12.368)           |     0.559 (+-0.000)      |          608.735 (+-6.336)
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)      |        610.176 (+-1.428)        |        5718.503 (+-11.203)         |           3424.022 (+-12.836)           |     0.599 (+-0.000)      |          604.781 (+-6.229)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)           |        325.001 (+-0.840)        |        6183.029 (+-10.893)         |            3275.032 (+-7.625)           |     0.530 (+-0.000)      |          325.693 (+-1.067)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)          |        325.855 (+-1.108)        |        6391.394 (+-11.552)         |            3533.410 (+-7.666)           |     0.553 (+-0.000)      |          325.838 (+-1.457)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)     |       2521.533 (+-14.857)       |        5025.217 (+-13.415)         |            2814.304 (+-6.742)           |     0.560 (+-0.000)      |         2520.308 (+-10.796)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)    |       2531.204 (+-12.534)       |        5294.925 (+-11.994)         |            3147.590 (+-6.808)           |     0.594 (+-0.000)      |         2521.228 (+-11.732)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)         |        758.352 (+-10.362)       |        5639.912 (+-14.495)         |            3014.123 (+-8.799)           |     0.534 (+-0.000)      |          756.114 (+-4.792)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)        |        758.712 (+-5.781)        |         5927.541 (+-9.982)         |            3249.555 (+-7.226)           |     0.548 (+-0.000)      |          757.719 (+-5.653)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)       |       1524.469 (+-12.860)       |        34321.641 (+-80.310)        |           19373.714 (+-56.351)          |     0.564 (+-0.000)      |         1518.082 (+-49.653)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)      |       1521.746 (+-13.780)       |        35949.711 (+-81.010)        |           21782.366 (+-68.938)          |     0.606 (+-0.000)      |         1467.911 (+-15.901)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)           |        712.311 (+-5.361)        |        38826.510 (+-92.267)        |           20762.314 (+-59.303)          |     0.535 (+-0.000)      |          712.669 (+-4.673)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)          |        715.060 (+-4.757)        |        40269.353 (+-92.543)        |           22402.114 (+-81.574)          |     0.556 (+-0.000)      |          716.001 (+-8.945)

      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)       |       2331.889 (+-29.159)       |        21541.096 (+-72.346)        |           12181.194 (+-45.288)          |     0.565 (+-0.000)      |         2304.864 (+-21.351)
      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)      |       2333.697 (+-10.066)       |        22514.154 (+-57.798)        |           21709.449 (+-98.307)          |     0.964 (+-0.000)      |         2302.141 (+-13.041)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)           |        1198.768 (+-5.364)       |       37652.371 (+-101.644)        |           42740.413 (+-98.571)          |     1.135 (+-0.000)      |          1197.104 (+-7.225)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)          |        1196.851 (+-5.118)       |       39678.341 (+-173.750)        |           46807.738 (+-92.744)          |     1.180 (+-0.000)      |          1189.322 (+-5.681)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)     |       10020.978 (+-54.855)      |        19955.290 (+-71.891)        |           11420.521 (+-53.179)          |     0.572 (+-0.000)      |         9999.583 (+-61.230)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)    |       10066.441 (+-62.700)      |       21058.334 (+-183.414)        |           19986.577 (+-65.304)          |     0.949 (+-0.000)      |         10018.672 (+-59.188)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)         |       3171.135 (+-14.635)       |        19687.864 (+-54.320)        |           23313.699 (+-57.391)          |     1.184 (+-0.000)      |         3182.191 (+-17.686)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)        |       3181.314 (+-13.784)       |        20224.468 (+-50.827)        |          30541.963 (+-381.385)          |     1.510 (+-0.000)      |         3183.578 (+-16.203)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)       |       5879.450 (+-31.551)       |       136918.555 (+-480.320)       |          77723.568 (+-331.766)          |     0.568 (+-0.000)      |         5726.061 (+-87.517)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)      |       5882.869 (+-30.325)       |       143378.094 (+-513.842)       |         137244.074 (+-4827.730)         |     0.957 (+-0.000)      |         5727.679 (+-22.164)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)           |       2674.937 (+-45.003)       |      244829.360 (+-1930.579)       |         271283.073 (+-2243.245)         |     1.108 (+-0.000)      |         2676.054 (+-24.632)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)          |       2676.217 (+-16.601)       |      248658.668 (+-2904.952)       |         296514.520 (+-2983.281)         |     1.192 (+-0.000)      |         2682.844 (+-19.886)

      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)     |        1768.437 (+-6.294)       |        2934.013 (+-28.870)         |            2520.649 (+-6.797)           |     0.859 (+-0.000)      |          1759.292 (+-5.097)
      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)    |        1748.660 (+-5.550)       |         3271.104 (+-7.557)         |            2891.306 (+-7.632)           |     0.884 (+-0.000)      |          1746.341 (+-5.845)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)         |        2813.150 (+-6.656)       |         3258.973 (+-7.543)         |            2766.286 (+-6.473)           |     0.849 (+-0.000)      |          2805.077 (+-7.611)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)        |        2812.102 (+-8.211)       |         3568.780 (+-9.018)         |            3125.870 (+-7.324)           |     0.876 (+-0.000)      |          2834.178 (+-9.034)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)   |        1687.975 (+-9.527)       |         2752.085 (+-9.627)         |            2373.274 (+-7.888)           |     0.862 (+-0.000)      |          1698.782 (+-8.098)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)  |        1696.606 (+-8.678)       |        3056.317 (+-13.303)         |           2699.160 (+-10.638)           |     0.883 (+-0.000)      |         1684.942 (+-10.519)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)       |        2613.491 (+-9.769)       |        3176.493 (+-13.366)         |            2730.193 (+-9.573)           |     0.859 (+-0.000)      |          2625.085 (+-9.943)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)      |       2614.946 (+-34.129)       |        3465.398 (+-11.165)         |           3044.396 (+-11.447)           |     0.879 (+-0.000)      |          2627.355 (+-9.608)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)     |       10784.549 (+-58.181)      |        18292.452 (+-59.344)        |           15909.922 (+-49.864)          |     0.870 (+-0.000)      |         10837.656 (+-51.947)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)    |       10786.513 (+-52.308)      |        20449.038 (+-56.204)        |           18295.997 (+-54.522)          |     0.895 (+-0.000)      |         10843.751 (+-44.781)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)         |       17532.699 (+-64.807)      |        20425.699 (+-80.271)        |           17517.040 (+-79.705)          |     0.858 (+-0.000)      |         17595.597 (+-61.870)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)        |       17530.816 (+-55.131)      |        22450.080 (+-92.899)        |           19827.828 (+-77.649)          |     0.883 (+-0.000)      |         17615.934 (+-71.716)

      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)     |       6875.484 (+-40.543)       |        11569.509 (+-62.462)        |          10053.350 (+-208.136)          |     0.869 (+-0.000)      |         6864.501 (+-46.747)
      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)    |       6843.126 (+-44.498)       |        12915.236 (+-60.654)        |          25335.058 (+-382.640)          |     1.962 (+-0.000)      |         6899.002 (+-46.861)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)         |       11103.418 (+-51.318)      |        28834.389 (+-78.395)        |          37405.463 (+-581.646)          |     1.297 (+-0.000)      |         11223.012 (+-60.709)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)        |       11092.994 (+-70.835)      |       36597.023 (+-118.988)        |           45761.267 (+-85.051)          |     1.250 (+-0.000)      |         11104.014 (+-61.288)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)   |       7106.791 (+-63.666)       |        11191.071 (+-45.402)        |           9786.037 (+-75.781)           |     0.874 (+-0.000)      |         7129.419 (+-77.674)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)  |       7146.519 (+-28.376)       |        12443.571 (+-39.425)        |           20147.067 (+-74.771)          |     1.619 (+-0.000)      |         7179.622 (+-64.847)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)       |       10533.849 (+-44.227)      |       34814.909 (+-138.127)        |          42803.001 (+-114.326)          |     1.229 (+-0.000)      |         10644.039 (+-59.681)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)      |       10548.910 (+-44.221)      |       42876.940 (+-146.959)        |          49711.443 (+-139.276)          |     1.159 (+-0.000)      |         10652.375 (+-44.174)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)     |      42814.521 (+-103.198)      |       73100.489 (+-435.262)        |          63587.659 (+-134.266)          |     0.870 (+-0.000)      |        43208.921 (+-195.287)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)    |      42812.373 (+-103.870)      |       81769.160 (+-373.369)        |         175159.813 (+-2028.558)         |     2.142 (+-0.000)      |         43007.691 (+-96.358)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)         |      69955.505 (+-373.373)      |      215248.616 (+-2040.775)       |         267511.246 (+-2094.161)         |     1.243 (+-0.000)      |        70382.679 (+-594.941)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)        |      69852.157 (+-490.076)      |      242841.484 (+-19645.513)      |         317931.678 (+-2016.498)         |     1.309 (+-0.000)      |        70074.819 (+-352.919)

Times are in microseconds (us).

[-------------------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cuda ---------------------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                                     |  Eager (2.4.0a0+git0c61c20) PR  |  Compiled (2.4.0a0+git0c61c20) PR  |  Compiled (2.4.0a0+git069270d) Nightly  |  speed-up PR vs Nightly  |  Eager (2.4.0a0+git069270d) Nightly
1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (1234, 1345)   |         97.727 (+-0.018)        |          97.765 (+-0.025)          |             97.773 (+-0.027)            |     1.000 (+-0.000)      |           97.905 (+-0.040)
      Input (1, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (1234, 1345)  |         97.615 (+-0.066)        |          97.332 (+-0.032)          |             97.950 (+-0.026)            |     1.006 (+-0.000)      |           97.690 (+-0.062)
      Input (1, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (1234, 1345)       |        100.635 (+-0.033)        |         125.883 (+-0.020)          |            102.499 (+-0.116)            |     0.814 (+-0.000)      |          101.103 (+-0.027)
      Input (1, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (1234, 1345)      |        100.898 (+-0.036)        |         109.717 (+-0.336)          |            102.558 (+-0.120)            |     0.935 (+-0.000)      |          101.642 (+-0.105)
      Input (4, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (1234, 1345)   |        462.853 (+-0.028)        |         382.475 (+-0.047)          |            382.472 (+-0.033)            |     1.000 (+-0.000)      |          462.188 (+-0.014)
      Input (4, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (1234, 1345)  |        462.783 (+-0.021)        |         382.806 (+-0.037)          |            382.563 (+-0.043)            |     0.999 (+-0.000)      |          462.089 (+-0.028)
      Input (4, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (1234, 1345)       |        466.721 (+-0.022)        |         384.438 (+-0.027)          |            384.886 (+-0.037)            |     1.001 (+-0.000)      |          467.014 (+-0.025)
      Input (4, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (1234, 1345)      |        466.993 (+-0.032)        |         384.212 (+-0.009)          |            383.946 (+-0.029)            |     0.999 (+-0.000)      |          466.575 (+-0.020)
      Input (1, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (2345, 2456)   |        190.070 (+-0.082)        |         209.353 (+-1.096)          |            202.870 (+-0.888)            |     0.969 (+-0.000)      |          189.371 (+-0.164)
      Input (1, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (2345, 2456)  |        190.021 (+-0.018)        |         210.504 (+-0.456)          |            201.814 (+-0.770)            |     0.959 (+-0.000)      |          189.314 (+-0.036)
      Input (1, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (2345, 2456)       |        188.860 (+-0.207)        |         336.635 (+-0.023)          |            252.026 (+-0.510)            |     0.749 (+-0.000)      |          188.860 (+-0.170)
      Input (1, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (2345, 2456)      |        188.725 (+-0.214)        |         276.329 (+-0.563)          |            251.439 (+-0.524)            |     0.910 (+-0.000)      |          188.776 (+-0.189)
      Input (4, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (2345, 2456)   |        781.879 (+-0.086)        |         836.389 (+-7.177)          |            816.483 (+-6.626)            |     0.976 (+-0.000)      |          781.362 (+-0.106)
      Input (4, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (2345, 2456)  |        781.824 (+-0.099)        |         840.406 (+-7.111)          |            807.530 (+-6.514)            |     0.961 (+-0.000)      |          781.307 (+-0.129)
      Input (4, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (2345, 2456)       |        769.290 (+-0.309)        |         675.498 (+-1.537)          |            688.171 (+-4.326)            |     1.019 (+-0.000)      |          769.830 (+-0.222)
      Input (4, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (2345, 2456)      |        769.240 (+-0.179)        |         675.800 (+-1.113)          |            673.176 (+-1.740)            |     0.996 (+-0.000)      |          769.935 (+-0.171)

Times are in microseconds (us).

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120411
Approved by: https://github.com/lezcano
2024-03-29 13:15:25 +00:00
andrewor14
773ae817f7 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Differential Revision: [D54805279](https://our.internmc.facebook.com/intern/diff/D54805279)
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-18 21:01:30 +00:00
PyTorch MergeBot
fd0dbcd891 Revert "Batch Norm Consolidation (#116092)"
This reverts commit 7b4f70eda5.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/osalpekar due to Causes build failure in //caffe2:aten-hip (AMD build) target. See [D54707318](https://www.internalfb.com/diff/D54707318) for more details, may require internal build system changes to resolve. ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1989542965))
2024-03-11 22:22:41 +00:00
BowenBao
8c96b4367a Remove opmath cast for im2col decomp (#121363)
It is unclear why opmath cast is needed for im2col decomp, given that the decomposition is mainly performing padding, slicing, indexing and shape manipulation. There is no need for performing these operations in a higher precision, and in doing so it requires more memory and yields less performance.

Sample script to demonstrate inserted cast before this change

```python
import torch
from torch._decomp.decompositions import im2col

def func(x):
    return torch.nn.functional.unfold(
        x, kernel_size=[3, 1], padding=[2, 0], dilation=1, stride=1
    )

x = torch.rand(1, 1, 5, 5, dtype=torch.float16)

eo = torch._dynamo.export(
    func, aten_graph=True, decomposition_table={torch.ops.aten.im2col.default: im2col}
)(x)
eo.graph_module.print_readable()
```

```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: "f16[1, 1, s0, s0]";

        arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        arg0_1 = arg0

        _to_copy: "f32[1, 1, s0, s0]" = torch.ops.aten._to_copy.default(arg0_1, dtype = torch.float32)
        ...
        constant_pad_nd: "f32[1, 1, s0 + 4, s0]" = torch.ops.aten.constant_pad_nd.default(_to_copy, [0, 0, 2, 2], 0.0);  _to_copy = None
        ...
        slice_1: "f32[1, 1, s0 + 4, s0]" = torch.ops.aten.slice.Tensor(constant_pad_nd, 0, 0, 9223372036854775807);  constant_pad_nd = None
        slice_2: "f32[1, 1, s0 + 4, s0]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807);  slice_1 = None
        index: "f32[1, 1, 3, s0 + 2, 1, s0]" = torch.ops.aten.index.Tensor(slice_2, [None, None, unsqueeze_5, add_3]);  slice_2 = unsqueeze_5 = add_3 = None
        permute: "f32[1, 1, 3, 1, s0 + 2, s0]" = torch.ops.aten.permute.default(index, [0, 1, 2, 4, 3, 5]);  index = None
        ...
        view: "f32[1, 3, s0**2 + 2*s0]" = torch.ops.aten.view.default(permute, [1, 3, mul]);  permute = mul = None
        _to_copy_1: "f16[1, 3, s0**2 + 2*s0]" = torch.ops.aten._to_copy.default(view, dtype = torch.float16);  view = None
        return pytree.tree_unflatten([_to_copy_1], self._out_spec)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121363
Approved by: https://github.com/lezcano
2024-03-09 15:37:27 +00:00
Boyuan Feng
35d3adb4b0 Add ATen Op _chunk_cat and _chunk_cat.out (#121081)
# Motivation

In backward of per-parameter sharding FSDP, each rank performs reduce scatter to sync gradients across ranks. A rank chunks each gradient tensor into `world_size` slices along the 0-th dimension and concatenate all slices along the 1-th dimension. Gradient tensors will be padded before concatenation when tensor.size(0) % world_size != 0.

### Example 1
Consider `world_size=3` and tensors A (2x4), B (3x3), C (1x2):

Input tensors:
```
AAAA   BBB   CC
AAAA   BBB
       BBB
```

Reduce-scatter-copy-in Output:
```
AAAABBBCC
AAAABBB00
0000BBB00
```

### Example 2
Consider `world_size=2` and tensors A (2x4), B (3x3), C(1x2), D(4x2):

Input tensors:
```
AAAA   BBB   CC   DD
AAAA   BBB   00   DD
       BBB        DD
       000        DD
```

Reduce-scatter-copy-in first pad:
```
AAAA   BBB   CC   DD
AAAA   BBB   00   DD
       BBB        DD
       000        DD
```

Then chunk and cat along dim as the output:
```
AAAABBBBBBCCDDDD
AAAABBB00000DDDD
```

The performance of reduce-scatter-copy-in is critical to per-parameter sharding FSDP. However, reduce-scatter-copy-in via composing existing ATen ops involves `cat` and irregular `pad`, leading redundant data copies and unsatisfactory performance.

# PR
We provide aten native support for reduce-scatter-copy-in, namely `_chunk_cat()`:

```
_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor
```

This PR includes the registration of `_chunk_cat` and `_chunk_cat.out`, OpInfo tests, and basic implementation composing existing ATen ops.
In the next PR, we will add the CUDA implementation. Comparing with baselines of composing existing ATen ops, `_chunk_cat()` CUDA implementation improves copy bandwidth from 498 GB/s to 966 GB/s on a production benchmark.

## Requirements on input

1. If input tensors have different ndims, dim should be non-negative and be less than the ndims of every input tensors. If all input tensors have the same ndims, we support both negative and non-negative dim.
2. For wrapped_dim, all tensors should have the same size for 0,...,wrapped_dim-1 dimensions. No requirements for (wrapped_dim, ...)-th dimension.
3. Expect positive num_chunks
4. Expect non-empty input tensor list and each input tensor should have at least 1 element

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121081
Approved by: https://github.com/albanD
2024-03-08 21:48:12 +00:00
andrewor14
7b4f70eda5 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-08 15:07:15 +00:00
PyTorch MergeBot
b529c19bdf Revert "Batch Norm Consolidation (#116092)"
This reverts commit 5680f565d5.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/jeffdaily due to broke ROCm, PR signal was clean but trunk was not, the merge should have been blocked but wasn't ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1981373237))
2024-03-06 17:10:01 +00:00
Tugsbayasgalan Manlaibaatar
5680f565d5 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-06 04:50:46 +00:00
Jane Xu
da559c98e3 Fix isin decomp and add python meta registration (#120821)
Fixes #119792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120821
Approved by: https://github.com/malfet, https://github.com/peterbell10
2024-02-29 22:08:50 +00:00
laith sakka
d21c6eb215 Do not wrap output with input device inside _to_copy (#119868)
Fixing https://github.com/pytorch/pytorch/issues/118790

This diff revert a small part of the code that was introduced in https://github.com/pytorch/pytorch/pull/104689

The PR above added a comment that "In case of dtype promotion, fake tensor converted into tensor"
but its not always the case that a conversion in dtype causes a fake tensor to be a tensor.

When such conversion does not happen we get the following error
```
Creating a new Tensor subclass FakeTensor but the raw Tensor object is already associated to
 a python object of type FakeTensor
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119868
Approved by: https://github.com/ezyang, https://github.com/thiagocrepaldi
2024-02-28 01:51:43 +00:00
Isuru Fernando
435063aa89 Decomposition for upsample_linear{1d, 3d} (#114774)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114774
Approved by: https://github.com/lezcano, https://github.com/vfdev-5, https://github.com/peterbell10
2024-02-27 11:57:45 +00:00
Aaron Meurer
5ce305270b Add a decomposition for isin() (#115390)
Co-authored-by: Peter Bell <peterbell10@live.co.uk>
Co-authored-by: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115390
Approved by: https://github.com/peterbell10
2024-02-14 03:03:42 +00:00
Edward Z. Yang
52de407b6c Avoid performing replacements when it would unrefine ranges (#117356)
Fixes https://github.com/pytorch/pytorch/issues/117268; check this issue for background.

This PR does the following:

* Do not perform a replacement if the expression we're replacing the symbol with has a less refined value range than the original. There's a little bit of trickiness around the handling for values close to INT64_MAX; when checking if a range refines another, I *only* consider the range representable in 64-bit integers. This is enough to prevent us from doing a substitution like `i0 = 10 - i1`, but it appears to still let us do the other substitutions we like, such as `i0 = i1` or `i0 = 12 * i1`
* The test above is order dependent: if we assert an equality BEFORE we have refined a range, we might be willing to do the replacement because there isn't a meaningful range. This means that it's important to mark things as sizes, before you start doing other error checking. `split_with_sizes` is adjusted accordingly. It would be good to raise an error if you get the ordering wrong, but I leave this to future work.
* It turns out this is not enough to fix AOTAutograd, because we lose the size-ness of unbacked SymInts when AOTAutograd retraces the Dynamo graph. So update deferred runtime assert insertion to also insert size-ness and value ranges annotations. Note that, in principle, it shouldn't be necessary to explicitly do the latter; these should just show up as deferred runtime asserts. That's some extra refactoring for a later day.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117356
Approved by: https://github.com/lezcano
2024-02-13 15:56:59 +00:00
PyTorch MergeBot
472500e32a Revert "Avoid performing replacements when it would unrefine ranges (#117356)"
This reverts commit 0e6b314fc2.

Reverted https://github.com/pytorch/pytorch/pull/117356 on behalf of https://github.com/huydhn due to Sorry for reverting the change but it looks like the forward fix still needs more work https://github.com/pytorch/pytorch/pull/119712, so it would be cleaner to reland them ([comment](https://github.com/pytorch/pytorch/pull/117356#issuecomment-1940032407))
2024-02-13 01:16:58 +00:00
vfdev-5
ed20e9118b Fixed hash issue in fx_graph_cse (#119567)
Description:
- Fixed issue with hash collision for `hash((primals_2, 1.0)) == hash((primals_2, 1))`

Repro code:
```python
import torch
from torch._functorch.compile_utils import fx_graph_cse

def func(inpt, osize):
    size = inpt.shape[-1]
    s1 = size - 1
    s2 = size - 1.0
    scale = s2 / (osize - 1.0)
    inpt = torch.clamp(inpt, 0, s1)
    return scale * inpt

gms = []
def toy_backend(gm, _):
    gms.append(gm)
    return gm.forward

torch._dynamo.reset()
fn = torch.compile(backend=toy_backend, dynamic=True)(func)
t = torch.rand(3, 100)
out = fn(t, 50)
gm = gms[0]

print(gm.graph)
new_fx_g = fx_graph_cse(gm.graph)
print(str(new_fx_g))
```
Original graph
```
graph():
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
    %l_inpt_ : torch.Tensor [num_users=2] = placeholder[target=L_inpt_]
    %l_osize_ : torch.SymInt [num_users=1] = placeholder[target=L_osize_]
    %size : [num_users=1] = call_method[target=size](args = (%l_inpt_,), kwargs = {})
    %getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%size, 1), kwargs = {})
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%getitem_1, 1), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=operator.sub](args = (%getitem_1, 1.0), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=operator.sub](args = (%l_osize_, 1.0), kwargs = {})
    %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%sub_1, %sub_2), kwargs = {})
    %inpt : [num_users=1] = call_function[target=torch.clamp](args = (%l_inpt_, 0, %sub), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%truediv, %inpt), kwargs = {})
    return (mul,)
```
New wrong graph where `sub_2` is replaced incorrectly with `sub`:
```
graph():
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
    %l_inpt_ : torch.Tensor [num_users=2] = placeholder[target=L_inpt_]
    %l_osize_ : torch.SymInt [num_users=1] = placeholder[target=L_osize_]
    %size : [num_users=1] = call_method[target=size](args = (%l_inpt_,), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%size, 1), kwargs = {})
    %sub : [num_users=2] = call_function[target=operator.sub](args = (%getitem_1, 1), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=operator.sub](args = (%l_osize_, 1.0), kwargs = {})
    %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%sub, %sub_2), kwargs = {})
    %inpt : [num_users=1] = call_function[target=torch.clamp](args = (%l_inpt_, 0, %sub), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%truediv, %inpt), kwargs = {})
    return (mul,)
```
With this PR the new graph is the following:
```
graph():
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
    %l_inpt_ : torch.Tensor [num_users=2] = placeholder[target=L_inpt_]
    %l_osize_ : torch.SymInt [num_users=1] = placeholder[target=L_osize_]
    %size : [num_users=1] = call_method[target=size](args = (%l_inpt_,), kwargs = {})
    %getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%size, 1), kwargs = {})
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%getitem_1, 1), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=operator.sub](args = (%getitem_1, 1.0), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=operator.sub](args = (%l_osize_, 1.0), kwargs = {})
    %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%sub_1, %sub_2), kwargs = {})
    %inpt : [num_users=1] = call_function[target=torch.clamp](args = (%l_inpt_, 0, %sub), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%truediv, %inpt), kwargs = {})
    return (mul,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119567
Approved by: https://github.com/eellison
2024-02-12 18:52:11 +00:00
Edward Z. Yang
0e6b314fc2 Avoid performing replacements when it would unrefine ranges (#117356)
Fixes https://github.com/pytorch/pytorch/issues/117268; check this issue for background.

This PR does the following:

* Do not perform a replacement if the expression we're replacing the symbol with has a less refined value range than the original. There's a little bit of trickiness around the handling for values close to INT64_MAX; when checking if a range refines another, I *only* consider the range representable in 64-bit integers. This is enough to prevent us from doing a substitution like `i0 = 10 - i1`, but it appears to still let us do the other substitutions we like, such as `i0 = i1` or `i0 = 12 * i1`
* The test above is order dependent: if we assert an equality BEFORE we have refined a range, we might be willing to do the replacement because there isn't a meaningful range. This means that it's important to mark things as sizes, before you start doing other error checking. `split_with_sizes` is adjusted accordingly. It would be good to raise an error if you get the ordering wrong, but I leave this to future work.
* It turns out this is not enough to fix AOTAutograd, because we lose the size-ness of unbacked SymInts when AOTAutograd retraces the Dynamo graph. So update deferred runtime assert insertion to also insert size-ness and value ranges annotations. Note that, in principle, it shouldn't be necessary to explicitly do the latter; these should just show up as deferred runtime asserts. That's some extra refactoring for a later day.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117356
Approved by: https://github.com/lezcano
2024-02-09 14:43:58 +00:00
CaoE
dfdbd73360 add Half support for flash attention (#119247)
Re-open for https://github.com/pytorch/pytorch/pull/118368.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119247
Approved by: https://github.com/drisspg, https://github.com/malfet
2024-02-07 05:57:41 +00:00
Edward Z. Yang
3f0fd36835 Introduce size oblivious guards (#118579)
Fixes https://github.com/pytorch/pytorch/issues/117361

The implementation here slightly diverges from what was proposed in the issue, so I will recap what this PR is doing here. Today, when doing computations involving size-like unbacked SymInts, we assume for all operations that the compile time range of the integer is `[2, inf]`, even though at runtime we also accept zero and one.

This PR removes the carte blanche assumption, and instead does the analysis in a much more limited and controlled fashion: only for guards which we have designated as "size oblivious" are we willing to do the analysis under the assumption that the range of all size-like unbacked SymInts is `[2, inf]`; otherwise, we will faithfully only do analysis with `[0, inf]` (or whatever the user provided) bounds.

The infra pieces of this PR are:

* Remove runtime_var_to_range from torch/fx/experimental/symbolic_shapes.py; modify `_constrain_range_for_size` to refine the range without clamping min to 2, and instead add the symbol to a `size_like` set in the ShapeEnv
* When evaluating an expression, if the expression is requested to be evaluated in a `size_oblivious` way, we attempt to statically compute the value of the expression with the assumption that all symbols in `size_like` are updated to assume that they are `>= 2`.
* Add Python and C++ APIs for guarding on a SymBool in a size-oblivious way. In C++, I also need to add some helpers for performing symbolic comparisons, since the stock comparisons immediately specialize in the "normal" way.

The rest of the changes of the PR are marking various spots in PyTorch framework code as size oblivious, based on what our current test suite exercises.

As you review the places where we have marked things as size oblivious, it may become clear why I ended up not opting for the "designate a branch as the default branch when it's not statically obvious which way to go": for some of the conditions, this answer is rather non-obvious. I think potentially there is another refinement on top of this PR, which is something like "I don't care if you can't figure it out with ValueRange analysis, go down this path anyway if there are unbacked sizes involved." But even if we add this API, I think we are obligated to attempt the ValueRange analysis first, since it can lead to better outcomes sometimes (e.g., we are able to figure out that something is contiguous no matter what the unbacked size is.)

When is it permissible to mark something as size oblivious? Heuristically, it is OK anywhere in framework code if it gets you past a guard on unbacked SymInt problem. It is somewhat difficult to provide a true semantic answer, however. In particular, these annotations don't have any observational equivalence guarantee; for example, if I have `torch.empty(u0, 1).squeeze()`, we will always produce a `[u0]` size tensor, even though if `u0 == 1` PyTorch will actually produce a `[]` size tensor. The argument that I gave to Lezcano is that we are in fact defining an alternate semantics for a "special" size = 0, 1, for which we have these alternate eager mode semantics. In particular, suppose that we have a constant `special1` which semantically denotes 1, but triggers alternate handling rules. We would define `torch.empty(special1, 1).squeeze()` to always produce a `[special1]` size tensor, making its semantics coincide with unbacked SymInt semantics. In this model, the decision to designate guards as size oblivious is simply a user API question: you put them where ever you need some handling for special1! As we conservatively error out whenever it is not obvious what `special1` semantics should be, it is always valid to expand these semantics to cover more cases (although you can always choose the wrong semantics!)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118579
Approved by: https://github.com/eellison, https://github.com/lezcano
2024-02-06 19:45:32 +00:00
Mengwei Liu
1e4b408b02 [decomp] Add tests for different dtypes to SDPA decomposition (#119239)
Summary: As titled. Skipping torch.bfloat16 because for some reason the
difference is 0.01.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119239
Approved by: https://github.com/drisspg
2024-02-06 11:17:07 +00:00
Yifu Wang
a1280f0cc6 Add an OpInfo test for split_with_sizes_copy (#118512)
Adding an `OpInfo` test for `split_with_sizes_copy` so we can use it to test [CUDA fast path for split_with_sizes_copy.out](https://github.com/pytorch/pytorch/pull/117203). Since the `OpInfo` test doesn't exist yet and introducing it requires modifications to the `CompositeExplicitAutograd` impl, adding the `OpInfo` test in a separate PR to establish a healthy baseline.

Changes made:
- Registered a batching rule for `split_with_sizes_copy`.
- Registered a decomposition for `split_with_sizes_copy`.
- Registered a DTensor prop rule for `split_with_sizes_copy`.
- Added required dtype and device checks to the composite impl.
- Added output resize to the composite impl.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118512
Approved by: https://github.com/albanD
2024-02-01 07:09:27 +00:00
Elias Ellison
e87ac82c98 Fix missing default dim param in weight norm interface decomp (#118762)
Fix for https://github.com/pytorch/pytorch/issues/118742

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118762
Approved by: https://github.com/ezyang, https://github.com/shunting314
2024-01-31 22:10:10 +00:00
soulitzer
81b55f58ce Matmul decide should_fold using has_out instead of grad_mode (#118617)
Fixes https://github.com/pytorch/pytorch/issues/118548

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118617
Approved by: https://github.com/lezcano
2024-01-31 18:34:16 +00:00
Isuru Fernando
2f7839e6db register decomposition for rsub in torch._refs (#118288)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118288
Approved by: https://github.com/lezcano
ghstack dependencies: #118398
2024-01-30 22:18:15 +00:00
Catherine Lee
4f5785b6b3 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 21:07:01 +00:00
PyTorch MergeBot
40ece2e579 Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
2024-01-30 19:00:34 +00:00
Qingpeng Li
827949cef2 accelerate binary_cross_entropy_with_logits by using log_sigmoid operator (#115539)
When I was reimplementing BCEwithLogits, I found that `log_sigmoid` operator could accelerate the function.

Simple benchmark on AMD 3600 CPU Ubuntu 22.04:
|avg time (ms)|with `pos_weight`|no `pos_weight`|
|-|-|-|
|original|1986|1658|
|this PR|1295|995|

faster 35-40%. This is probably benefited by the `log_sigmoid` vectorization code.

CUDA benchmark was not obtained, but I believe CUDA can be also benefited by reduecing kernel launches as https://github.com/pytorch/pytorch/pull/11054#issuecomment-442233714 and https://github.com/pytorch/pytorch/pull/78267#issue-1248398454 mentioned.

The simple benchmark cpp file:
[demo.txt](https://github.com/pytorch/pytorch/files/13635355/demo.txt)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115539
Approved by: https://github.com/malfet
2024-01-30 13:24:13 +00:00
Edward Z. Yang
4f13f69a45 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 05:08:10 +00:00
albanD
24133e44b1 Fix return type hint for list types (#118238)
All single element list types are `Tensor[]` so they will always be Tuple.
I don't know of any way to easily access the pyi type and compare that to a real run so no testing here :(
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118238
Approved by: https://github.com/ezyang
2024-01-25 23:35:20 +00:00
PyTorch MergeBot
8dc421a6b4 Revert "accelerate binary_cross_entropy_with_logits by using log_sigmoid operator (#115539)"
This reverts commit 03b12e56c7.

Reverted https://github.com/pytorch/pytorch/pull/115539 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/115539#issuecomment-1904157729))
2024-01-22 14:48:35 +00:00
Qingpeng Li
03b12e56c7 accelerate binary_cross_entropy_with_logits by using log_sigmoid operator (#115539)
When I was reimplementing BCEwithLogits, I found that `log_sigmoid` operator could accelerate the function.

Simple benchmark on AMD 3600 CPU Ubuntu 22.04:
|avg time (ms)|with `pos_weight`|no `pos_weight`|
|-|-|-|
|original|1986|1658|
|this PR|1295|995|

faster 35-40%. This is probably benefited by the `log_sigmoid` vectorization code.

CUDA benchmark was not obtained, but I believe CUDA can be also benefited by reduecing kernel launches as https://github.com/pytorch/pytorch/pull/11054#issuecomment-442233714 and https://github.com/pytorch/pytorch/pull/78267#issue-1248398454 mentioned.

The simple benchmark cpp file:
[demo.txt](https://github.com/pytorch/pytorch/files/13635355/demo.txt)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115539
Approved by: https://github.com/lezcano
2024-01-19 14:56:43 +00:00
Digant Desai
e2830e6328 [PyTorch] SDPA decomp: actually use attn_mask (#117579)
Summary: Need to pass this along

Test Plan:
```
cd ~/fbsource/fbcode/executorch/backends/xnnpack/test
buck test fbcode//mode/dev-nosan :test_xnnpack_ops -- test_fp32_sdpa
buck run fbcode//mode/dev-nosan :test_xnnpack_models -- executorch.backends.xnnpack.test.models.llama2_et_example.TestLlama2ETExample.test_fp32
```

Reviewed By: larryliu0820

Differential Revision: D52812369

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117579
Approved by: https://github.com/larryliu0820
2024-01-17 10:26:43 +00:00
vfdev-5
f6767244cf Added meta function for _upsample_bicubic2d_aa (#117347)
This should fix remaining errors with Resize op in torchvision: https://github.com/pytorch/vision/actions/runs/7298953575?pr=8127
```
/opt/conda/envs/ci/lib/python3.8/site-packages/torch/nn/functional.py:4072: in interpolate
    return torch._C._nn._upsample_bicubic2d_aa(input, output_size, align_corners, scale_factors)
E   torch._dynamo.exc.TorchRuntimeError: Failed running call_function <function interpolate at 0x7f4443fe00d0>(*(FakeTensor(..., size=(1, s0, s1, s2)),), **{'size': [s4, floor(s3*s4/floor(s1*s3/s2))], 'mode': 'bicubic', 'align_corners': False, 'antialias': True}):
E   aten/src/ATen/RegisterCompositeImplicitAutograd.cpp:5567: SymIntArrayRef expected to contain only concrete integers
E
E   from user code:
E      File "/pytorch/vision/torchvision/transforms/v2/functional/_geometry.py", line 260, in resize_image
E       image = interpolate(
E
E   Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E
E
E   You can suppress this exception and fall back to eager by setting:
E       import torch._dynamo
E       torch._dynamo.config.suppress_errors = True
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117347
Approved by: https://github.com/peterbell10
2024-01-16 23:33:55 +00:00
Aaron Orenstein
638f85fd67 Add default parameters to rrelu_with_noise() (#117141)
Summary:
rrelu_with_noise() was listed as having default parameters in the schema but the
actual code definition didn't have them.

The failing example was calling rrelu() which DOES have default parameters and
it passes those defaulted values to C++. Under the covers the C code was calling
the python version of rrelu_with_noise().

Although the C++ code was passing all the values to the python version of
rrelu_with_noise() the pytorch C++ -> Python dispatch code looks at the schema
and strips any parameters which match the schema's listed defaults so if the
schema shows defaults that aren't in the code it will be a problem.

Test Plan:
I added a unit test for this specific case. It would probably be better to write
a more general one to validate all the ops against their schemas - but I haven't
learned enough about the test harness to do that yet.

Fixes #115811

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117141
Approved by: https://github.com/yanboliang, https://github.com/oulgen
2024-01-12 05:32:13 +00:00
Elias Ellison
e3d4f4d14b [ProxyTensor] dedupe symbolic shapes in tracing (#116158)
Dedupes symbolic shapes in proxy tensor tracing. Reusing the existing sym shape avoids inserting spurious sym_size calls, which can interfere with pattern matching and graph passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116158
Approved by: https://github.com/ezyang
2024-01-11 07:15:11 +00:00
Mengwei Liu
8783fe9cf3 [export] Modify SDPA decomposition to decompose _scaled_dot_product_flash_attention_for_cpu (#117097)
Summary: As titled. #115913 added
`_scaled_dot_product_flash_attention_for_cpu` and the export result of
`scaled_dot_product_attention` includes this op. Adding this
decomposition so that it's being decomposed the same way as
`_scaled_dot_product_attention_math`.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117097
Approved by: https://github.com/lezcano
2024-01-10 23:46:14 +00:00
Elias Ellison
d6540038c0 Fix 0-dim Index in Index Copy decomp (#117065)
Fix for https://github.com/pytorch/pytorch/issues/115931

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117065
Approved by: https://github.com/jansel, https://github.com/shunting314
2024-01-10 22:13:43 +00:00
Zhengxu Chen
b3f7fdbf0a Add decomp for pad_sequence (#116285)
Summary: currently pad_sequence caused symbolic shape specialization in export which is unintended. Adding a decomp seems to work to avoid the c++ kernel which caused the specialization.

Test Plan: buck test mode/opt caffe2/test:test_export -- -r pad_sequence

Reviewed By: SherlockNoMad

Differential Revision: D52345667

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116285
Approved by: https://github.com/peterbell10, https://github.com/lezcano
2023-12-27 23:56:51 +00:00
Aaron Meurer
f08c4da86d Add a decomposition for take() (#114813)
Presumably this can close https://github.com/pytorch/pytorch/pull/109784

Also related to https://github.com/pytorch/pytorch/issues/93757 (though `take` is not listed there).

There's no bounds checking here (out of bounds indices cause a segfault or undefined behavior). Should that be added somehow?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114813
Approved by: https://github.com/peterbell10, https://github.com/lezcano
2023-12-22 18:14:57 +00:00
vfdev-5
f727bed2e6 [inductor] Updated upsample_bilinear2d decomposition (#104182)
Description:
- Updated upsample_bilinear2d decomposition
  - added support for uint8 dtype support
  - code improvements
- Added uint8 dtype tests

Perf considerations:
- There is minor perf regression (speed-up ~0.7) on cases uint8, align_corners=True when output is smaller/equal (256, 256)
- For cases, when output is larger (256, 256) and input dtype uint8, nightly output is wrong, so IMO large perf regression (speed-up around ~0.2) should not be taken into account.

## Perfs benchmarks

```
[--------------------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cpu --------------------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                                    |  Eager (2.3.0a0+gitafcfdb1) PR  |  Compiled (2.3.0a0+gitafcfdb1) PR  |  Compiled (2.3.0a0+gitde89a53) Nightly  |  speed-up PR vs Nightly  |  Eager (2.3.0a0+gitde89a53) Nightly
1 threads: --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256)       |        565.212 (+-3.548)        |        1384.210 (+-10.798)         |           1230.996 (+-32.930)           |     0.889 (+-0.000)      |          566.253 (+-1.526)
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256)      |        565.404 (+-1.614)        |         1491.649 (+-7.763)         |            2974.959 (+-6.006)           |     1.994 (+-0.000)      |          566.476 (+-1.742)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256)           |        270.761 (+-0.861)        |         1557.777 (+-4.699)         |            1080.919 (+-4.243)           |     0.694 (+-0.000)      |          269.829 (+-0.986)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256)          |        270.960 (+-0.995)        |        1723.913 (+-12.433)         |            3191.938 (+-6.194)           |     1.852 (+-0.000)      |          269.962 (+-1.657)
      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256)     |        1555.884 (+-5.169)       |         1178.753 (+-4.957)         |            1910.445 (+-5.988)           |     1.621 (+-0.000)      |          1560.804 (+-6.793)
      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256)    |        1651.193 (+-6.952)       |         1323.466 (+-6.059)         |            3374.842 (+-8.168)           |     2.550 (+-0.000)      |          1653.497 (+-8.018)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256)         |        978.482 (+-10.183)       |         1383.768 (+-4.341)         |            2147.841 (+-6.581)           |     1.552 (+-0.000)      |          979.983 (+-1.499)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256)        |        1074.472 (+-5.031)       |         1414.912 (+-5.754)         |           3590.968 (+-10.042)           |     2.538 (+-0.000)      |          1074.589 (+-3.948)
      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256)       |        2168.703 (+-8.964)       |        5400.528 (+-26.628)         |           4777.299 (+-11.891)           |     0.885 (+-0.000)      |          2168.133 (+-7.667)
      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256)      |       2169.132 (+-12.618)       |        6583.866 (+-28.959)         |           11986.894 (+-45.838)          |     1.821 (+-0.000)      |         2174.488 (+-10.317)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256)           |        992.808 (+-6.086)        |         5985.028 (+-9.532)         |            4334.158 (+-9.423)           |     0.724 (+-0.000)      |          989.604 (+-5.499)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256)          |        987.618 (+-6.350)        |        6963.044 (+-28.885)         |           15441.096 (+-55.324)          |     2.218 (+-0.000)      |          985.573 (+-5.159)
      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256)     |       6695.557 (+-35.067)       |        4657.603 (+-14.220)         |           8058.708 (+-41.684)           |     1.730 (+-0.000)      |         6714.996 (+-38.626)
      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256)    |       7040.481 (+-39.486)       |        5445.704 (+-16.659)         |           13906.618 (+-53.298)          |     2.554 (+-0.000)      |         7034.453 (+-44.626)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256)         |       3926.186 (+-10.660)       |        5741.433 (+-12.748)         |           9356.036 (+-40.848)           |     1.630 (+-0.000)      |         3930.598 (+-17.086)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256)        |        4308.536 (+-9.607)       |        6122.755 (+-47.278)         |           15637.567 (+-54.392)          |     2.554 (+-0.000)      |         4307.463 (+-11.268)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300)     |       2512.740 (+-10.860)       |         1573.590 (+-5.061)         |            451.355 (+-1.210)            |     0.287 (+-0.000)      |         2511.727 (+-10.930)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300)    |       2489.926 (+-11.915)       |         1537.233 (+-4.212)         |            2501.470 (+-7.446)           |     1.627 (+-0.000)      |         2500.000 (+-12.155)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300)         |        632.032 (+-2.108)        |         1496.994 (+-4.194)         |            404.759 (+-1.064)            |     0.270 (+-0.000)      |          630.122 (+-4.086)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300)        |        629.174 (+-4.386)        |         1708.935 (+-8.817)         |            2643.296 (+-9.723)           |     1.547 (+-0.000)      |          628.388 (+-1.326)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300)   |        4409.941 (+-8.016)       |         1160.133 (+-4.698)         |            1897.089 (+-9.392)           |     1.635 (+-0.000)      |         4450.959 (+-10.438)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300)  |       4493.427 (+-11.703)       |         1329.226 (+-4.740)         |           2835.872 (+-12.241)           |     2.133 (+-0.000)      |          4506.973 (+-9.914)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300)       |        901.712 (+-4.071)        |         1320.739 (+-5.197)         |            2207.605 (+-8.219)           |     1.671 (+-0.000)      |          904.757 (+-4.558)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300)      |        990.080 (+-3.922)        |         1702.563 (+-7.909)         |           3074.196 (+-10.478)           |     1.806 (+-0.000)      |          990.482 (+-4.444)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300)     |       9785.550 (+-58.445)       |        6135.680 (+-33.569)         |           1628.572 (+-19.770)           |     0.265 (+-0.000)      |         9893.606 (+-62.377)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300)    |       9710.191 (+-57.597)       |        6066.824 (+-36.364)         |           10469.110 (+-42.775)          |     1.726 (+-0.000)      |         9919.022 (+-72.190)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300)         |       2790.356 (+-12.188)       |        6134.101 (+-28.694)         |            1576.832 (+-6.030)           |     0.257 (+-0.000)      |         2761.122 (+-11.503)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300)        |       2778.711 (+-13.603)       |        6608.528 (+-37.776)         |           10841.549 (+-49.429)          |     1.641 (+-0.000)      |         2753.037 (+-10.995)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300)   |      45533.868 (+-102.618)      |         4962.994 (+-8.215)         |           9003.968 (+-38.179)           |     1.814 (+-0.000)      |        43531.261 (+-102.951)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300)  |       45932.699 (+-81.207)      |        5595.682 (+-11.482)         |           12302.907 (+-50.254)          |     2.199 (+-0.000)      |         43916.455 (+-80.468)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300)       |        3827.804 (+-8.057)       |        6311.580 (+-25.021)         |           11760.614 (+-51.531)          |     1.863 (+-0.000)      |         3849.959 (+-10.848)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300)      |        4169.007 (+-8.452)       |        6820.716 (+-35.310)         |           15264.633 (+-49.982)          |     2.238 (+-0.000)      |         4183.875 (+-19.104)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700)       |        1306.914 (+-7.470)       |        10598.101 (+-38.410)        |           2678.031 (+-11.051)           |     0.253 (+-0.000)      |          1307.470 (+-8.519)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700)      |        1307.268 (+-8.197)       |        10161.123 (+-45.643)        |           17148.842 (+-55.402)          |     1.688 (+-0.000)      |          1308.077 (+-8.553)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700)           |        548.574 (+-2.157)        |        10072.806 (+-41.368)        |            2408.971 (+-6.997)           |     0.239 (+-0.000)      |          547.726 (+-1.721)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700)          |        546.664 (+-1.484)        |        11123.694 (+-43.636)        |           18058.070 (+-48.552)          |     1.623 (+-0.000)      |          547.151 (+-1.627)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700)     |       7935.051 (+-71.022)       |        7654.533 (+-29.512)         |           12414.194 (+-87.450)          |     1.622 (+-0.000)      |         7900.056 (+-53.997)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700)    |       8546.732 (+-53.118)       |        8583.572 (+-35.656)         |          19111.824 (+-166.978)          |     2.227 (+-0.000)      |         8515.433 (+-63.300)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700)         |       6202.642 (+-34.355)       |        8915.622 (+-62.293)         |           14327.295 (+-52.188)          |     1.607 (+-0.000)      |         6213.329 (+-39.740)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700)        |       6811.128 (+-33.747)       |        9647.316 (+-50.837)         |           20830.594 (+-62.979)          |     2.159 (+-0.000)      |         6822.512 (+-37.092)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700)       |       5079.586 (+-19.067)       |        42238.442 (+-87.643)        |           11282.141 (+-42.477)          |     0.267 (+-0.000)      |         5104.234 (+-17.706)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700)      |       5079.575 (+-16.306)       |        41512.995 (+-83.710)        |          68789.816 (+-440.001)          |     1.657 (+-0.000)      |         5097.446 (+-21.724)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700)           |        2039.974 (+-8.614)       |       42322.773 (+-111.866)        |           10399.237 (+-43.140)          |     0.246 (+-0.000)      |         2043.808 (+-10.707)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700)          |       2036.214 (+-10.083)       |        44353.281 (+-71.548)        |          73340.412 (+-324.780)          |     1.654 (+-0.000)      |          2039.000 (+-9.554)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700)     |       33821.523 (+-96.639)      |        30552.094 (+-65.023)        |          49494.486 (+-872.916)          |     1.620 (+-0.000)      |         33844.404 (+-92.466)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700)    |      36196.104 (+-128.169)      |        34038.432 (+-79.697)        |          75761.226 (+-905.194)          |     2.226 (+-0.000)      |         36260.473 (+-94.642)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700)         |       24827.821 (+-77.335)      |        37006.218 (+-86.318)        |          61297.625 (+-898.192)          |     1.656 (+-0.000)      |         24823.275 (+-80.945)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700)        |       27266.138 (+-70.262)      |        40109.475 (+-94.248)        |          92086.075 (+-404.922)          |     2.296 (+-0.000)      |         27287.992 (+-89.507)

Times are in microseconds (us).

[--------------------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cuda ---------------------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                                      |  Eager (2.3.0a0+gitafcfdb1) PR  |  Compiled (2.3.0a0+gitafcfdb1) PR  |  Compiled (2.3.0a0+gitde89a53) Nightly  |  speed-up PR vs Nightly  |  Eager (2.3.0a0+gitde89a53) Nightly
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (1234, 1345)   |         98.259 (+-0.014)        |          97.156 (+-0.008)          |             97.443 (+-0.031)            |     1.003 (+-0.000)      |           98.248 (+-0.021)
      Input (1, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (1234, 1345)  |         97.048 (+-0.016)        |          97.480 (+-0.018)          |             96.819 (+-0.126)            |     0.993 (+-0.000)      |           97.045 (+-0.015)
      Input (1, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (1234, 1345)       |         97.944 (+-0.028)        |          91.686 (+-0.411)          |             93.894 (+-1.011)            |     1.024 (+-0.000)      |           97.933 (+-0.008)
      Input (1, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (1234, 1345)      |         98.008 (+-0.011)        |          91.205 (+-0.346)          |             96.854 (+-0.058)            |     1.062 (+-0.000)      |           97.203 (+-0.010)
      Input (4, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (1234, 1345)   |        384.318 (+-0.011)        |         382.793 (+-0.007)          |            382.472 (+-0.011)            |     0.999 (+-0.000)      |          384.701 (+-0.012)
      Input (4, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (1234, 1345)  |        384.266 (+-0.009)        |         385.333 (+-0.024)          |            382.554 (+-0.022)            |     0.993 (+-0.000)      |          384.386 (+-0.016)
      Input (4, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (1234, 1345)       |        383.924 (+-0.011)        |         570.071 (+-0.030)          |            545.615 (+-0.051)            |     0.957 (+-0.000)      |          384.044 (+-0.012)
      Input (4, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (1234, 1345)      |        384.184 (+-0.016)        |         560.857 (+-0.026)          |            552.447 (+-0.040)            |     0.985 (+-0.000)      |          384.063 (+-0.016)
      Input (1, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (2345, 2456)   |        122.188 (+-0.053)        |         116.744 (+-1.006)          |            163.762 (+-0.015)            |     1.403 (+-0.000)      |          121.874 (+-0.015)
      Input (1, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (2345, 2456)  |        122.156 (+-0.012)        |         182.692 (+-0.013)          |            161.653 (+-0.018)            |     0.885 (+-0.000)      |          121.926 (+-0.014)
      Input (1, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (2345, 2456)       |        105.852 (+-0.324)        |         119.545 (+-0.294)          |            190.527 (+-0.023)            |     1.594 (+-0.000)      |          105.999 (+-0.446)
      Input (1, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (2345, 2456)      |        106.507 (+-0.282)        |         120.060 (+-0.257)          |            162.330 (+-0.012)            |     1.352 (+-0.000)      |          106.567 (+-0.385)
      Input (4, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (2345, 2456)   |        447.907 (+-0.015)        |         463.863 (+-1.779)          |            650.492 (+-0.331)            |     1.402 (+-0.000)      |          446.596 (+-0.017)
      Input (4, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (2345, 2456)  |        447.750 (+-0.017)        |         723.832 (+-0.170)          |            641.539 (+-0.075)            |     0.886 (+-0.000)      |          446.467 (+-0.019)
      Input (4, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (2345, 2456)       |        439.549 (+-0.031)        |         507.772 (+-2.879)          |            758.795 (+-0.482)            |     1.494 (+-0.000)      |          440.372 (+-0.025)
      Input (4, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (2345, 2456)      |        439.538 (+-0.029)        |         509.260 (+-2.704)          |            654.195 (+-2.621)            |     1.285 (+-0.000)      |          440.362 (+-0.026)

Times are in microseconds (us).
```

[Source](f4751a3196/perf_interp_mode.py), [Output](899f34c024/output/20231213-214209-upsample-bilinear-pr_vs_nightly-speedup.md)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104182
Approved by: https://github.com/lezcano
2023-12-14 14:50:06 +00:00
angelayi
639060cb0b Use get_mkldnn_enabled for decompositions (#115448)
`torch._C.has_mkldnn` does not respect cases where users try to disable mkldnn using `torch._C._set_mkldnn_enabled()`. This is relevant to edge use cases, where they do not want decompositions to go to the ATen opset, and do not want the mkldnn operator to appear in the graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115448
Approved by: https://github.com/jgong5, https://github.com/ydwu4
2023-12-12 22:42:51 +00:00
Isuru Fernando
d40a7c6026 Add decompositions for replication_pad (#115113)
Fixes #115395

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115113
Approved by: https://github.com/peterbell10
2023-12-09 02:44:07 +00:00
Isuru Fernando
fb19947962 Add decompositions for reflection_pad{1, 2, 3}d (#115100)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115100
Approved by: https://github.com/peterbell10
2023-12-08 23:05:57 +00:00
Jason Ansel
7979ba7b43 [inductor] Add dropout type check to match eager (#115040)
Fixes #98970

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115040
Approved by: https://github.com/oulgen
2023-12-03 23:05:02 +00:00
Kurt Mohler
6f32eb7eef Add decomp for replication_pad2d and use for CUDA deterministic (#111590)
Fixes #95578

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111590
Approved by: https://github.com/peterbell10
2023-12-01 18:56:09 +00:00
PyTorch MergeBot
013675ff59 Revert "Add decomp for replication_pad2d and use for CUDA deterministic (#111590)"
This reverts commit f1286161a6.

Reverted https://github.com/pytorch/pytorch/pull/111590 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing XLA job.  The job is also failing on the PR, but the log classifier failed to find the failed test which lead to it being marked wrongly as flaky ([comment](https://github.com/pytorch/pytorch/pull/111590#issuecomment-1833004794))
2023-11-30 02:28:14 +00:00