mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Description:
- Updated upsample_bilinear2d decomposition
- Added uint8 dtype support
- Code improvements
- Added uint8 dtype tests
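
For readers unfamiliar with what bilinear upsampling on uint8 data involves, here is a minimal pure-Python sketch. It is not the PR's decomposition code. The helper names `source_index` and `upsample_bilinear_uint8` are hypothetical, and the uint8 handling (interpolate in floating point, then round and clamp to 0..255) is an assumption about one common approach, not a statement of how the decomposition does it.

```python
def source_index(dst, in_size, out_size, align_corners):
    """Map an output index to a fractional input coordinate."""
    if align_corners:
        # Corner pixels of input and output are aligned exactly.
        scale = (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
        return dst * scale
    # Pixel centers are aligned; negative coordinates are clamped to 0.
    scale = in_size / out_size
    return max((dst + 0.5) * scale - 0.5, 0.0)


def upsample_bilinear_uint8(img, out_h, out_w, align_corners=False):
    """Resize a 2-D list of ints in 0..255 (hypothetical helper)."""
    in_h, in_w = len(img), len(img[0])
    out = []
    for oy in range(out_h):
        sy = source_index(oy, in_h, out_h, align_corners)
        y0 = min(int(sy), in_h - 1)
        y1 = min(y0 + 1, in_h - 1)
        wy = sy - y0
        row = []
        for ox in range(out_w):
            sx = source_index(ox, in_w, out_w, align_corners)
            x0 = min(int(sx), in_w - 1)
            x1 = min(x0 + 1, in_w - 1)
            wx = sx - x0
            # Interpolate along x on the two neighbouring rows, then along y.
            top = img[y0][x0] * (1 - wx) + img[y0][x1] * wx
            bottom = img[y1][x0] * (1 - wx) + img[y1][x1] * wx
            value = top * (1 - wy) + bottom * wy
            # Assumed uint8 handling: round, then clamp to the valid range.
            row.append(min(255, max(0, round(value))))
        out.append(row)
    return out


small = [[0, 100], [200, 240]]
print(upsample_bilinear_uint8(small, 3, 3, align_corners=True))
```

The intermediate values are kept as floats so that the fractional weights are not lost before the final rounding step.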
Perf considerations:
- There is a minor perf regression (speed-up ~0.7 to ~0.9) for uint8 input with align_corners=True when the output is smaller than or equal to (256, 256).
- For uint8 input with an output larger than (256, 256), the nightly output is wrong, so IMO the large perf regression there (speed-up ~0.2 to ~0.3) should not be taken into account.
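
Judging by the numbers in the benchmark table, the "speed-up PR vs Nightly" column appears to be the nightly compiled time divided by the PR compiled time, so values below 1 mean the PR is slower. The sketch below checks that reading against three uint8 rows copied from the CPU table. The function name `speed_up` is an illustrative choice, not part of the benchmark script.

```python
def speed_up(nightly_compiled_us, pr_compiled_us):
    """Ratio > 1 means the PR's compiled code is faster than nightly's."""
    return nightly_compiled_us / pr_compiled_us


# (label, compiled PR time in us, compiled nightly time in us),
# taken from the CPU table for uint8, contiguous_format input.
rows = [
    ("(1, 3, 500, 400) -> (256, 256), align_corners=True", 1384.210, 1230.996),
    ("(1, 3, 500, 400) -> (256, 256), align_corners=False", 1491.649, 2974.959),
    ("(1, 3, 1200, 1300) -> (200, 300), align_corners=True", 1573.590, 451.355),
]

for label, pr_us, nightly_us in rows:
    print(f"{label}: {speed_up(nightly_us, pr_us):.3f}")
```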
## Perf benchmarks
```
[--------------------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cpu --------------------------------------------------------------------------------------------------------------------------------------------------------]
| Eager (2.3.0a0+gitafcfdb1) PR | Compiled (2.3.0a0+gitafcfdb1) PR | Compiled (2.3.0a0+gitde89a53) Nightly | speed-up PR vs Nightly | Eager (2.3.0a0+gitde89a53) Nightly
1 threads: --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256) | 565.212 (+-3.548) | 1384.210 (+-10.798) | 1230.996 (+-32.930) | 0.889 (+-0.000) | 566.253 (+-1.526)
Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256) | 565.404 (+-1.614) | 1491.649 (+-7.763) | 2974.959 (+-6.006) | 1.994 (+-0.000) | 566.476 (+-1.742)
Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256) | 270.761 (+-0.861) | 1557.777 (+-4.699) | 1080.919 (+-4.243) | 0.694 (+-0.000) | 269.829 (+-0.986)
Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256) | 270.960 (+-0.995) | 1723.913 (+-12.433) | 3191.938 (+-6.194) | 1.852 (+-0.000) | 269.962 (+-1.657)
Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256) | 1555.884 (+-5.169) | 1178.753 (+-4.957) | 1910.445 (+-5.988) | 1.621 (+-0.000) | 1560.804 (+-6.793)
Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256) | 1651.193 (+-6.952) | 1323.466 (+-6.059) | 3374.842 (+-8.168) | 2.550 (+-0.000) | 1653.497 (+-8.018)
Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256) | 978.482 (+-10.183) | 1383.768 (+-4.341) | 2147.841 (+-6.581) | 1.552 (+-0.000) | 979.983 (+-1.499)
Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256) | 1074.472 (+-5.031) | 1414.912 (+-5.754) | 3590.968 (+-10.042) | 2.538 (+-0.000) | 1074.589 (+-3.948)
Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256) | 2168.703 (+-8.964) | 5400.528 (+-26.628) | 4777.299 (+-11.891) | 0.885 (+-0.000) | 2168.133 (+-7.667)
Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256) | 2169.132 (+-12.618) | 6583.866 (+-28.959) | 11986.894 (+-45.838) | 1.821 (+-0.000) | 2174.488 (+-10.317)
Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256) | 992.808 (+-6.086) | 5985.028 (+-9.532) | 4334.158 (+-9.423) | 0.724 (+-0.000) | 989.604 (+-5.499)
Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256) | 987.618 (+-6.350) | 6963.044 (+-28.885) | 15441.096 (+-55.324) | 2.218 (+-0.000) | 985.573 (+-5.159)
Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256) | 6695.557 (+-35.067) | 4657.603 (+-14.220) | 8058.708 (+-41.684) | 1.730 (+-0.000) | 6714.996 (+-38.626)
Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256) | 7040.481 (+-39.486) | 5445.704 (+-16.659) | 13906.618 (+-53.298) | 2.554 (+-0.000) | 7034.453 (+-44.626)
Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (256, 256) | 3926.186 (+-10.660) | 5741.433 (+-12.748) | 9356.036 (+-40.848) | 1.630 (+-0.000) | 3930.598 (+-17.086)
Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (256, 256) | 4308.536 (+-9.607) | 6122.755 (+-47.278) | 15637.567 (+-54.392) | 2.554 (+-0.000) | 4307.463 (+-11.268)
Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300) | 2512.740 (+-10.860) | 1573.590 (+-5.061) | 451.355 (+-1.210) | 0.287 (+-0.000) | 2511.727 (+-10.930)
Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300) | 2489.926 (+-11.915) | 1537.233 (+-4.212) | 2501.470 (+-7.446) | 1.627 (+-0.000) | 2500.000 (+-12.155)
Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300) | 632.032 (+-2.108) | 1496.994 (+-4.194) | 404.759 (+-1.064) | 0.270 (+-0.000) | 630.122 (+-4.086)
Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300) | 629.174 (+-4.386) | 1708.935 (+-8.817) | 2643.296 (+-9.723) | 1.547 (+-0.000) | 628.388 (+-1.326)
Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300) | 4409.941 (+-8.016) | 1160.133 (+-4.698) | 1897.089 (+-9.392) | 1.635 (+-0.000) | 4450.959 (+-10.438)
Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300) | 4493.427 (+-11.703) | 1329.226 (+-4.740) | 2835.872 (+-12.241) | 2.133 (+-0.000) | 4506.973 (+-9.914)
Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300) | 901.712 (+-4.071) | 1320.739 (+-5.197) | 2207.605 (+-8.219) | 1.671 (+-0.000) | 904.757 (+-4.558)
Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300) | 990.080 (+-3.922) | 1702.563 (+-7.909) | 3074.196 (+-10.478) | 1.806 (+-0.000) | 990.482 (+-4.444)
Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300) | 9785.550 (+-58.445) | 6135.680 (+-33.569) | 1628.572 (+-19.770) | 0.265 (+-0.000) | 9893.606 (+-62.377)
Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300) | 9710.191 (+-57.597) | 6066.824 (+-36.364) | 10469.110 (+-42.775) | 1.726 (+-0.000) | 9919.022 (+-72.190)
Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300) | 2790.356 (+-12.188) | 6134.101 (+-28.694) | 1576.832 (+-6.030) | 0.257 (+-0.000) | 2761.122 (+-11.503)
Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300) | 2778.711 (+-13.603) | 6608.528 (+-37.776) | 10841.549 (+-49.429) | 1.641 (+-0.000) | 2753.037 (+-10.995)
Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300) | 45533.868 (+-102.618) | 4962.994 (+-8.215) | 9003.968 (+-38.179) | 1.814 (+-0.000) | 43531.261 (+-102.951)
Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300) | 45932.699 (+-81.207) | 5595.682 (+-11.482) | 12302.907 (+-50.254) | 2.199 (+-0.000) | 43916.455 (+-80.468)
Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (200, 300) | 3827.804 (+-8.057) | 6311.580 (+-25.021) | 11760.614 (+-51.531) | 1.863 (+-0.000) | 3849.959 (+-10.848)
Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (200, 300) | 4169.007 (+-8.452) | 6820.716 (+-35.310) | 15264.633 (+-49.982) | 2.238 (+-0.000) | 4183.875 (+-19.104)
Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700) | 1306.914 (+-7.470) | 10598.101 (+-38.410) | 2678.031 (+-11.051) | 0.253 (+-0.000) | 1307.470 (+-8.519)
Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700) | 1307.268 (+-8.197) | 10161.123 (+-45.643) | 17148.842 (+-55.402) | 1.688 (+-0.000) | 1308.077 (+-8.553)
Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700) | 548.574 (+-2.157) | 10072.806 (+-41.368) | 2408.971 (+-6.997) | 0.239 (+-0.000) | 547.726 (+-1.721)
Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700) | 546.664 (+-1.484) | 11123.694 (+-43.636) | 18058.070 (+-48.552) | 1.623 (+-0.000) | 547.151 (+-1.627)
Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700) | 7935.051 (+-71.022) | 7654.533 (+-29.512) | 12414.194 (+-87.450) | 1.622 (+-0.000) | 7900.056 (+-53.997)
Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700) | 8546.732 (+-53.118) | 8583.572 (+-35.656) | 19111.824 (+-166.978) | 2.227 (+-0.000) | 8515.433 (+-63.300)
Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700) | 6202.642 (+-34.355) | 8915.622 (+-62.293) | 14327.295 (+-52.188) | 1.607 (+-0.000) | 6213.329 (+-39.740)
Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700) | 6811.128 (+-33.747) | 9647.316 (+-50.837) | 20830.594 (+-62.979) | 2.159 (+-0.000) | 6822.512 (+-37.092)
Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700) | 5079.586 (+-19.067) | 42238.442 (+-87.643) | 11282.141 (+-42.477) | 0.267 (+-0.000) | 5104.234 (+-17.706)
Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700) | 5079.575 (+-16.306) | 41512.995 (+-83.710) | 68789.816 (+-440.001) | 1.657 (+-0.000) | 5097.446 (+-21.724)
Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700) | 2039.974 (+-8.614) | 42322.773 (+-111.866) | 10399.237 (+-43.140) | 0.246 (+-0.000) | 2043.808 (+-10.707)
Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700) | 2036.214 (+-10.083) | 44353.281 (+-71.548) | 73340.412 (+-324.780) | 1.654 (+-0.000) | 2039.000 (+-9.554)
Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700) | 33821.523 (+-96.639) | 30552.094 (+-65.023) | 49494.486 (+-872.916) | 1.620 (+-0.000) | 33844.404 (+-92.466)
Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700) | 36196.104 (+-128.169) | 34038.432 (+-79.697) | 75761.226 (+-905.194) | 2.226 (+-0.000) | 36260.473 (+-94.642)
Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (600, 700) | 24827.821 (+-77.335) | 37006.218 (+-86.318) | 61297.625 (+-898.192) | 1.656 (+-0.000) | 24823.275 (+-80.945)
Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (600, 700) | 27266.138 (+-70.262) | 40109.475 (+-94.248) | 92086.075 (+-404.922) | 2.296 (+-0.000) | 27287.992 (+-89.507)
Times are in microseconds (us).
[--------------------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cuda ---------------------------------------------------------------------------------------------------------------------------------------------------------]
| Eager (2.3.0a0+gitafcfdb1) PR | Compiled (2.3.0a0+gitafcfdb1) PR | Compiled (2.3.0a0+gitde89a53) Nightly | speed-up PR vs Nightly | Eager (2.3.0a0+gitde89a53) Nightly
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Input (1, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (1234, 1345) | 98.259 (+-0.014) | 97.156 (+-0.008) | 97.443 (+-0.031) | 1.003 (+-0.000) | 98.248 (+-0.021)
Input (1, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (1234, 1345) | 97.048 (+-0.016) | 97.480 (+-0.018) | 96.819 (+-0.126) | 0.993 (+-0.000) | 97.045 (+-0.015)
Input (1, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (1234, 1345) | 97.944 (+-0.028) | 91.686 (+-0.411) | 93.894 (+-1.011) | 1.024 (+-0.000) | 97.933 (+-0.008)
Input (1, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (1234, 1345) | 98.008 (+-0.011) | 91.205 (+-0.346) | 96.854 (+-0.058) | 1.062 (+-0.000) | 97.203 (+-0.010)
Input (4, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (1234, 1345) | 384.318 (+-0.011) | 382.793 (+-0.007) | 382.472 (+-0.011) | 0.999 (+-0.000) | 384.701 (+-0.012)
Input (4, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (1234, 1345) | 384.266 (+-0.009) | 385.333 (+-0.024) | 382.554 (+-0.022) | 0.993 (+-0.000) | 384.386 (+-0.016)
Input (4, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (1234, 1345) | 383.924 (+-0.011) | 570.071 (+-0.030) | 545.615 (+-0.051) | 0.957 (+-0.000) | 384.044 (+-0.012)
Input (4, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (1234, 1345) | 384.184 (+-0.016) | 560.857 (+-0.026) | 552.447 (+-0.040) | 0.985 (+-0.000) | 384.063 (+-0.016)
Input (1, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (2345, 2456) | 122.188 (+-0.053) | 116.744 (+-1.006) | 163.762 (+-0.015) | 1.403 (+-0.000) | 121.874 (+-0.015)
Input (1, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (2345, 2456) | 122.156 (+-0.012) | 182.692 (+-0.013) | 161.653 (+-0.018) | 0.885 (+-0.000) | 121.926 (+-0.014)
Input (1, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (2345, 2456) | 105.852 (+-0.324) | 119.545 (+-0.294) | 190.527 (+-0.023) | 1.594 (+-0.000) | 105.999 (+-0.446)
Input (1, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (2345, 2456) | 106.507 (+-0.282) | 120.060 (+-0.257) | 162.330 (+-0.012) | 1.352 (+-0.000) | 106.567 (+-0.385)
Input (4, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: True, antialias: False, osize: (2345, 2456) | 447.907 (+-0.015) | 463.863 (+-1.779) | 650.492 (+-0.331) | 1.402 (+-0.000) | 446.596 (+-0.017)
Input (4, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bilinear, align_corners: False, antialias: False, osize: (2345, 2456) | 447.750 (+-0.017) | 723.832 (+-0.170) | 641.539 (+-0.075) | 0.886 (+-0.000) | 446.467 (+-0.019)
Input (4, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bilinear, align_corners: True, antialias: False, osize: (2345, 2456) | 439.549 (+-0.031) | 507.772 (+-2.879) | 758.795 (+-0.482) | 1.494 (+-0.000) | 440.372 (+-0.025)
Input (4, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bilinear, align_corners: False, antialias: False, osize: (2345, 2456) | 439.538 (+-0.029) | 509.260 (+-2.704) | 654.195 (+-2.621) | 1.285 (+-0.000) | 440.362 (+-0.026)
Times are in microseconds (us).
```