Commit Graph

14 Commits

Author SHA1 Message Date
jiej
ce1380f9b5 fixing Optional[Tensor] type in autodiff (#55565)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/54783

We need to be extra careful with the pattern to legitimately use `unchecked_unwrap_optional` in autodiff.
This would at least allow us to start support `Optional[Tensor]` in autodiff, which is quite common in composite layers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55565

Reviewed By: ejguan

Differential Revision: D27825336

Pulled By: Krovatkin

fbshipit-source-id: a8562eb10ea741effff430d7417d313b1eb53dfe
2021-04-16 14:06:49 -07:00
zsef123
3498fde20e Add AccumulateType in AdaptiveAveragePooling3d.cu (#53607)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/52719

- Changed the type(`scalar_t`) of intermediate results to `at::acc_type<scalar_t, true>`

This issue occurs by decimal precision of the half precision.

Follows test cases of upper issue, The value range of input tensors are [0, 1] because init by `rand`.
And when the kernel size 1, summations all target values and divide numel of kernel
34d9278c19/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu (L94-L95)

When adding [0, 1] values, if `sum` more than 2048 then not changed values. ( Even if the value is small, the mored exact value is added, but there are still precision issues.)
(https://en.wikipedia.org/wiki/Half-precision_floating-point_format)

Benchmarks
- In V100 32GB, Driver : 450.80, cuda 10.1
- faster than prev

<details><summary>Script</summary><p>

```import torch
from torch.utils.benchmark import Timer

torch.manual_seed(0)

kernel_sizes = [1, 3, 5, 7, 9, 11, 13]
shapes = [(12, 12, 12), (16, 16, 16), (16, 32, 32), (16, 56, 56), (16, 112, 112)]

def run(batch, channel):
    print(f"Batch : {batch}, Channel : {channel} / (diff, diff / numel, time)")

    head = "\t".join(f"{str(s):30s}" for s in ["k \ shape"] + shapes)
    print(head)
    for kernel_size in kernel_sizes:
        kernel_size = (kernel_size, kernel_size, kernel_size)
        pool = torch.nn.AdaptiveAvgPool3d(kernel_size)

        print(f"{str(kernel_size):30s}", end="\t")
        for shape in shapes:
            x_half = torch.rand([batch, channel, *shape], dtype=torch.half, device="cuda")
            x_float = x_half.float()

            y_half = pool(x_half)
            y_float = pool(x_float)

            timer = Timer("pool(x_half)", globals={"pool": pool, "x_half": x_half})
            measurement = timer.blocked_autorange(min_run_time=5)

            diff = (y_float - y_half).abs().sum().item()
            diff = f"{diff:.4f}, {diff / y_half.numel():.6f}, {measurement.median * 1e6 :3.2f}us"
            print(f"{diff:30s}", end="\t")
        print("")

run(1, 1)
run(1, 3)
run(1, 54)
run(1, 16)

run(8, 1)
run(8, 16)
run(8, 54)

import torch
m = torch.nn.AdaptiveAvgPool3d((1,1,1))

inputs = torch.rand([8,54,16,56,56])
inputs = inputs.cuda()
inputs_2 = inputs.half()

print("Float")
out = m(inputs).float()
print("half")
out2 = m(inputs_2).float()

print('Discepancies', torch.sum(torch.abs(out2- out)).item(), torch.sum(torch.abs(out2- out)).item() / out.numel() , out.numel())

print("Sum : ", torch.sum(inputs, dim=(2,3,4))[0, 0], torch.sum(inputs_2, dim=(2,3,4))[0, 0])
```
</p>
</details>

<details><summary>This commit</summary><p>

```
Batch : 1, Channel : 1 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                         (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.0001, 0.000078, 55.73us       0.0001, 0.000079, 117.51us       0.0000, 0.000003, 379.60us      0.0000, 0.000046, 1046.21us      0.0001, 0.000139, 3897.17us
(3, 3, 3)                       0.0021, 0.000076, 22.04us       0.0031, 0.000115, 21.47us        0.0022, 0.000080, 41.63us       0.0030, 0.000111, 100.59us       0.0025, 0.000091, 295.04us
(5, 5, 5)                       0.0103, 0.000083, 21.65us       0.0097, 0.000078, 21.37us        0.0103, 0.000083, 21.60us       0.0114, 0.000091, 25.69us        0.0107, 0.000085, 97.06us
(7, 7, 7)                       0.0312, 0.000091, 21.52us       0.0290, 0.000084, 21.61us        0.0311, 0.000091, 21.60us       0.0309, 0.000090, 21.44us        0.0334, 0.000097, 33.60us
(9, 9, 9)                       0.0646, 0.000089, 21.57us       0.0672, 0.000092, 21.89us        0.0662, 0.000091, 21.89us       0.0684, 0.000094, 27.64us        0.0660, 0.000091, 54.85us
(11, 11, 11)                    0.1251, 0.000094, 21.68us       0.1194, 0.000090, 21.70us        0.1202, 0.000090, 21.72us       0.1233, 0.000093, 22.25us        0.1229, 0.000092, 41.39us
(13, 13, 13)                    0.2038, 0.000093, 21.57us       0.2047, 0.000093, 21.58us        0.1964, 0.000089, 21.54us       0.2021, 0.000092, 21.94us        0.1989, 0.000091, 40.01us
Batch : 1, Channel : 3 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                     (16, 32, 32)                    (16, 56, 56)                     (16, 112, 112)
(1, 1, 1)                       0.0003, 0.000110, 55.74us       0.0003, 0.000093, 118.62us       0.0003, 0.000093, 382.12us      0.0001, 0.000040, 1052.33us      0.0003, 0.000114, 3917.90us
(3, 3, 3)                       0.0073, 0.000090, 21.84us       0.0075, 0.000093, 22.25us        0.0072, 0.000089, 41.78us       0.0070, 0.000087, 100.27us       0.0069, 0.000086, 293.96us
(5, 5, 5)                       0.0353, 0.000094, 22.57us       0.0325, 0.000087, 21.64us        0.0343, 0.000092, 22.63us       0.0338, 0.000090, 25.82us        0.0332, 0.000089, 97.16us
(7, 7, 7)                       0.0937, 0.000091, 22.50us       0.0910, 0.000088, 21.92us        0.0933, 0.000091, 21.99us       0.0948, 0.000092, 21.56us        0.0928, 0.000090, 34.17us
(9, 9, 9)                       0.1957, 0.000089, 21.68us       0.1984, 0.000091, 21.57us        0.2025, 0.000093, 22.10us       0.1986, 0.000091, 27.66us        0.2020, 0.000092, 55.32us
(11, 11, 11)                    0.3585, 0.000090, 21.75us       0.3684, 0.000092, 22.70us        0.3706, 0.000093, 21.67us       0.3752, 0.000094, 21.86us        0.3663, 0.000092, 41.22us
(13, 13, 13)                    0.5931, 0.000090, 21.67us       0.6056, 0.000092, 21.79us        0.6005, 0.000091, 21.79us       0.6112, 0.000093, 21.69us        0.6034, 0.000092, 40.02us
Batch : 1, Channel : 54 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                     (16, 32, 32)                    (16, 56, 56)                     (16, 112, 112)
(1, 1, 1)                       0.0051, 0.000095, 55.76us       0.0060, 0.000112, 118.60us       0.0036, 0.000067, 381.50us      0.0054, 0.000100, 1054.03us      0.0048, 0.000089, 4888.68us
(3, 3, 3)                       0.1332, 0.000091, 21.66us       0.1344, 0.000092, 22.62us        0.1354, 0.000093, 45.72us       0.1364, 0.000094, 106.63us       0.1324, 0.000091, 448.31us
(5, 5, 5)                       0.6221, 0.000092, 22.48us       0.6220, 0.000092, 21.71us        0.6053, 0.000090, 27.65us       0.6137, 0.000091, 31.40us        0.6209, 0.000092, 172.78us
(7, 7, 7)                       1.6859, 0.000091, 22.42us       1.6972, 0.000092, 21.96us        1.6849, 0.000091, 23.14us       1.7012, 0.000092, 26.25us        1.6920, 0.000091, 75.58us
(9, 9, 9)                       3.5811, 0.000091, 21.73us       3.5746, 0.000091, 22.55us        3.6237, 0.000092, 27.66us       3.6046, 0.000092, 59.71us        3.6392, 0.000092, 168.15us
(11, 11, 11)                    6.5582, 0.000091, 22.05us       6.5746, 0.000091, 21.74us        6.5955, 0.000092, 32.91us       6.5644, 0.000091, 45.57us        6.5697, 0.000091, 114.01us
(13, 13, 13)                    10.6384, 0.000090, 21.81us      10.8608, 0.000092, 21.79us       10.8375, 0.000091, 37.01us      10.8662, 0.000092, 51.80us       10.8593, 0.000092, 123.19us
Batch : 1, Channel : 16 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                     (16, 32, 32)                    (16, 56, 56)                     (16, 112, 112)
(1, 1, 1)                       0.0015, 0.000093, 55.75us       0.0012, 0.000075, 118.10us           0.0013, 0.000079, 379.25us      0.0012, 0.000075, 1047.21us     0.0013, 0.000079, 4451.57us
(3, 3, 3)                       0.0407, 0.000094, 21.82us       0.0395, 0.000091, 21.69us            0.0385, 0.000089, 42.07us       0.0397, 0.000092, 100.33us      0.0384, 0.000089, 363.31us
(5, 5, 5)                       0.1858, 0.000093, 21.76us       0.1799, 0.000090, 21.63us            0.1834, 0.000092, 21.76us       0.1890, 0.000095, 26.04us       0.1814, 0.000091, 135.32us
(7, 7, 7)                       0.4937, 0.000090, 21.65us       0.5076, 0.000092, 21.69us            0.5001, 0.000091, 22.31us       0.4988, 0.000091, 21.59us       0.5123, 0.000093, 50.03us
(9, 9, 9)                       1.0678, 0.000092, 21.73us       1.0752, 0.000092, 21.75us            1.0673, 0.000091, 21.75us       1.0649, 0.000091, 30.01us       1.0786, 0.000092, 70.92us
(11, 11, 11)                    1.9591, 0.000092, 21.57us       1.9522, 0.000092, 21.60us            1.9566, 0.000092, 21.73us       1.9475, 0.000091, 23.46us       1.9323, 0.000091, 55.02us
(13, 13, 13)                    3.1784, 0.000090, 22.02us       3.2165, 0.000092, 21.95us            3.1969, 0.000091, 21.92us       3.2061, 0.000091, 24.40us       3.2578, 0.000093, 56.00us
Batch : 8, Channel : 1 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                         (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.0010, 0.000122, 55.74us       0.0009, 0.000114, 118.82us           0.0006, 0.000074, 379.80us      0.0009, 0.000107, 1047.31us     0.0008, 0.000102, 3900.36us
(3, 3, 3)                       0.0219, 0.000101, 21.57us       0.0200, 0.000093, 21.61us            0.0194, 0.000090, 41.74us       0.0208, 0.000096, 99.91us       0.0212, 0.000098, 293.03us
(5, 5, 5)                       0.0906, 0.000091, 21.46us       0.0911, 0.000091, 21.60us            0.0934, 0.000093, 21.93us       0.0927, 0.000093, 25.74us       0.0913, 0.000091, 96.85us
(7, 7, 7)                       0.2530, 0.000092, 22.53us       0.2526, 0.000092, 22.46us            0.2558, 0.000093, 22.03us       0.2542, 0.000093, 22.29us       0.2475, 0.000090, 34.44us
(9, 9, 9)                       0.5305, 0.000091, 22.34us       0.5368, 0.000092, 22.42us            0.5265, 0.000090, 21.74us       0.5370, 0.000092, 27.81us       0.5416, 0.000093, 55.65us
(11, 11, 11)                    0.9887, 0.000093, 21.80us       0.9660, 0.000091, 21.61us            0.9793, 0.000092, 22.11us       0.9719, 0.000091, 21.80us       0.9650, 0.000091, 43.90us
(13, 13, 13)                    1.6024, 0.000091, 21.87us       1.6198, 0.000092, 22.65us            1.6242, 0.000092, 21.73us       1.6236, 0.000092, 22.59us       1.6025, 0.000091, 42.77us
Batch : 8, Channel : 16 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                         (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.0113, 0.000088, 56.66us       0.0117, 0.000091, 119.57us           0.0130, 0.000102, 389.57us      0.0110, 0.000086, 1433.78us     0.0119, 0.000093, 5217.61us
(3, 3, 3)                       0.3209, 0.000093, 21.54us       0.3184, 0.000092, 22.87us            0.3115, 0.000090, 51.00us       0.3171, 0.000092, 164.17us      0.3182, 0.000092, 500.60us
(5, 5, 5)                       1.4391, 0.000090, 22.39us       1.4577, 0.000091, 21.69us            1.4601, 0.000091, 53.87us       1.4626, 0.000091, 93.65us       1.4567, 0.000091, 370.11us
(7, 7, 7)                       4.0501, 0.000092, 22.34us       4.0230, 0.000092, 31.45us            4.0381, 0.000092, 45.19us       4.0171, 0.000091, 65.35us       4.0108, 0.000091, 164.76us
(9, 9, 9)                       8.5360, 0.000091, 22.80us       8.5456, 0.000092, 27.24us            8.5461, 0.000092, 50.23us       8.5677, 0.000092, 117.63us      8.5645, 0.000092, 270.46us
(11, 11, 11)                    15.5521, 0.000091, 26.56us      15.5826, 0.000091, 32.81us           15.6014, 0.000092, 63.82us      15.5620, 0.000091, 96.87us      15.5722, 0.000091, 220.24us
(13, 13, 13)                    25.4146, 0.000090, 32.91us      25.7898, 0.000092, 38.48us           25.6698, 0.000091, 72.02us      25.8193, 0.000092, 121.73us     25.7718, 0.000092, 249.71us
Batch : 8, Channel : 54 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                         (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.0377, 0.000087, 109.07us      0.0405, 0.000094, 233.17us           0.0392, 0.000091, 998.97us      0.0393, 0.000091, 2960.68us     0.0408, 0.000094, 11879.53us
(3, 3, 3)                       1.0660, 0.000091, 25.68us       1.0761, 0.000092, 64.12us            1.0725, 0.000092, 182.50us      1.0801, 0.000093, 505.82us      1.0736, 0.000092, 1650.21us
(5, 5, 5)                       4.9587, 0.000092, 50.84us       4.9336, 0.000091, 47.38us            4.9696, 0.000092, 158.49us      4.9347, 0.000091, 237.39us      4.9303, 0.000091, 965.13us
(7, 7, 7)                       13.5409, 0.000091, 45.60us      13.5736, 0.000092, 87.45us           13.5012, 0.000091, 141.63us     13.6111, 0.000092, 181.51us     13.5296, 0.000091, 469.77us
(9, 9, 9)                       28.7817, 0.000091, 58.01us      28.7969, 0.000091, 77.61us           28.8761, 0.000092, 159.33us     28.8786, 0.000092, 334.47us     28.8093, 0.000091, 786.72us
(11, 11, 11)                    52.4453, 0.000091, 78.19us      52.7265, 0.000092, 95.12us           52.7322, 0.000092, 200.38us     52.6342, 0.000092, 282.41us     52.6467, 0.000092, 652.54us
(13, 13, 13)                    85.7411, 0.000090, 98.85us      86.7183, 0.000091, 115.28us          86.8545, 0.000092, 232.34us     86.9997, 0.000092, 367.32us     86.9083, 0.000092, 757.73us
Float
half
Discepancies 0.03963914513587952 9.175728040712852e-05 432
Sum :  tensor(25110.1484, device='cuda:0') tensor(25104., device='cuda:0', dtype=torch.float16)
```
</p>
</details>

<details><summary>1.8.0</summary><p>

```
Batch : 1, Channel : 1 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                  (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.0023, 0.002275, 74.35us       0.0040, 0.003985, 159.73us        0.3740, 0.374021, 546.59us      0.4587, 0.458663, 1543.16us       0.4906, 0.490637, 5945.97us
(3, 3, 3)                       0.0100, 0.000370, 20.37us       0.0230, 0.000852, 22.12us         0.0309, 0.001143, 54.75us       0.0520, 0.001926, 129.78us        7.1219, 0.263775, 377.11us
(5, 5, 5)                       0.0441, 0.000352, 20.06us       0.0394, 0.000316, 20.50us         0.0759, 0.000607, 26.43us       0.1499, 0.001199, 32.01us         0.2707, 0.002166, 128.15us
(7, 7, 7)                       0.0791, 0.000231, 20.10us       0.1002, 0.000292, 20.56us         0.1812, 0.000528, 20.48us       0.2424, 0.000707, 20.83us         0.4994, 0.001456, 43.97us
(9, 9, 9)                       0.1122, 0.000154, 20.55us       0.1778, 0.000244, 20.44us         0.2572, 0.000353, 20.15us       0.4149, 0.000569, 35.64us         0.7208, 0.000989, 68.46us
(11, 11, 11)                    0.2044, 0.000154, 20.47us       0.2647, 0.000199, 20.62us         0.3867, 0.000291, 20.61us       0.6059, 0.000455, 23.54us         1.0902, 0.000819, 53.32us
(13, 13, 13)                    0.3094, 0.000141, 20.53us       0.3843, 0.000175, 20.60us         0.5756, 0.000262, 20.80us       0.8598, 0.000391, 24.52us         1.4853, 0.000676, 47.70us
Batch : 1, Channel : 3 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                      (16, 32, 32)                    (16, 56, 56)                      (16, 112, 112)
(1, 1, 1)                       0.0054, 0.001801, 74.36us       0.0108, 0.003614, 158.94us        1.1183, 0.372768, 547.67us      1.3782, 0.459387, 1545.27us       1.4685, 0.489505, 5949.17us
(3, 3, 3)                       0.0308, 0.000380, 20.14us       0.0502, 0.000619, 22.11us         0.1210, 0.001493, 54.80us       0.1900, 0.002345, 130.47us        21.3483, 0.263560, 375.68us
(5, 5, 5)                       0.1179, 0.000314, 20.68us       0.1326, 0.000354, 20.53us         0.2662, 0.000710, 26.51us       0.4116, 0.001098, 31.85us         0.8369, 0.002232, 128.19us
(7, 7, 7)                       0.2335, 0.000227, 20.40us       0.3057, 0.000297, 20.43us         0.4954, 0.000481, 20.31us       0.7339, 0.000713, 20.74us         1.4208, 0.001381, 44.55us
(9, 9, 9)                       0.3326, 0.000152, 20.63us       0.5353, 0.000245, 20.42us         0.8025, 0.000367, 20.13us       1.2693, 0.000580, 35.64us         2.2096, 0.001010, 68.88us
(11, 11, 11)                    0.6121, 0.000153, 20.59us       0.8086, 0.000202, 20.42us         1.1700, 0.000293, 20.71us       1.8170, 0.000455, 23.54us         3.2117, 0.000804, 53.36us
(13, 13, 13)                    0.9165, 0.000139, 20.51us       1.1395, 0.000173, 20.56us         1.7343, 0.000263, 20.80us       2.5868, 0.000392, 24.59us         4.5823, 0.000695, 47.77us
Batch : 1, Channel : 54 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                      (16, 32, 32)                    (16, 56, 56)                      (16, 112, 112)
(1, 1, 1)                       0.1092, 0.002023, 75.45us       0.1709, 0.003165, 160.44us        20.2452, 0.374911, 548.61us     24.7990, 0.459240, 1550.34us      26.4494, 0.489804, 6957.79us
(3, 3, 3)                       0.5352, 0.000367, 20.58us       1.0281, 0.000705, 24.14us         2.0150, 0.001382, 59.12us       3.3069, 0.002268, 138.23us        384.5216, 0.263732, 529.71us
(5, 5, 5)                       2.0739, 0.000307, 20.60us       2.5199, 0.000373, 20.44us         4.6916, 0.000695, 33.89us       7.9482, 0.001178, 37.74us         14.2553, 0.002112, 200.54us
(7, 7, 7)                       4.2236, 0.000228, 20.61us       5.5605, 0.000300, 20.97us         9.0440, 0.000488, 26.40us       12.7847, 0.000690, 30.64us        25.3050, 0.001366, 88.05us
(9, 9, 9)                       6.0817, 0.000154, 20.63us       9.5416, 0.000242, 20.84us         14.2416, 0.000362, 32.47us      22.8452, 0.000580, 78.57us        40.3246, 0.001024, 194.50us
(11, 11, 11)                    11.1144, 0.000155, 20.56us      14.5581, 0.000203, 20.91us        20.8263, 0.000290, 38.07us      33.0004, 0.000459, 52.74us        57.3275, 0.000798, 137.19us
(13, 13, 13)                    16.5176, 0.000139, 21.26us      20.8089, 0.000175, 22.33us        31.3433, 0.000264, 42.93us      45.9733, 0.000388, 59.84us        82.8301, 0.000698, 138.42us
Batch : 1, Channel : 16 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                      (16, 32, 32)                    (16, 56, 56)                      (16, 112, 112)
(1, 1, 1)                       0.0274, 0.001715, 74.99us       0.0485, 0.003034, 159.92us    5.9925, 0.374529, 546.35us      7.3389, 0.458679, 1544.53us     7.8354, 0.489714, 6677.00us
(3, 3, 3)                       0.1560, 0.000361, 20.72us       0.3043, 0.000704, 22.37us     0.5838, 0.001352, 54.97us       1.0455, 0.002420, 130.57us      113.9739, 0.263828, 463.43us
(5, 5, 5)                       0.6121, 0.000306, 20.12us       0.7247, 0.000362, 20.73us     1.3740, 0.000687, 26.59us       2.3794, 0.001190, 32.12us       4.1929, 0.002096, 165.81us
(7, 7, 7)                       1.2389, 0.000226, 20.59us       1.6311, 0.000297, 20.53us     2.6732, 0.000487, 20.37us       3.7501, 0.000683, 20.71us       7.4575, 0.001359, 59.16us
(9, 9, 9)                       1.7983, 0.000154, 20.64us       2.8075, 0.000241, 20.59us     4.2165, 0.000361, 20.38us       6.7153, 0.000576, 38.29us       12.0530, 0.001033, 86.33us
(11, 11, 11)                    3.3326, 0.000156, 20.56us       4.3061, 0.000202, 20.67us     6.2235, 0.000292, 20.47us       9.8009, 0.000460, 27.41us       16.9994, 0.000798, 68.49us
(13, 13, 13)                    4.9016, 0.000139, 20.63us       6.1261, 0.000174, 20.65us     9.2106, 0.000262, 20.93us       13.5843, 0.000386, 27.95us      24.6476, 0.000701, 64.88us
Batch : 8, Channel : 1 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                  (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.0170, 0.002122, 74.99us       0.0316, 0.003946, 160.66us    3.0013, 0.375158, 546.94us      3.6780, 0.459753, 1544.58us     3.9197, 0.489966, 5948.43us
(3, 3, 3)                       0.0821, 0.000380, 20.27us       0.1559, 0.000722, 22.29us     0.3133, 0.001450, 54.72us       0.5100, 0.002361, 130.12us      57.0481, 0.264111, 376.71us
(5, 5, 5)                       0.3075, 0.000307, 20.57us       0.3680, 0.000368, 20.69us     0.6786, 0.000679, 26.61us       1.1744, 0.001174, 31.77us       2.0654, 0.002065, 128.31us
(7, 7, 7)                       0.6512, 0.000237, 20.60us       0.8359, 0.000305, 20.50us     1.3712, 0.000500, 20.75us       1.9472, 0.000710, 20.92us       3.7586, 0.001370, 44.59us
(9, 9, 9)                       0.9138, 0.000157, 20.43us       1.4198, 0.000243, 20.58us     2.1018, 0.000360, 20.52us       3.3691, 0.000578, 35.90us       5.9491, 0.001020, 69.16us
(11, 11, 11)                    1.6606, 0.000156, 20.63us       2.1599, 0.000203, 20.57us     3.1240, 0.000293, 20.98us       4.8874, 0.000459, 24.65us       8.4780, 0.000796, 56.47us
(13, 13, 13)                    2.4987, 0.000142, 20.71us       3.0667, 0.000174, 20.45us     4.6387, 0.000264, 20.76us       6.8187, 0.000388, 25.95us       12.2077, 0.000695, 50.46us
Batch : 8, Channel : 16 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                  (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.2635, 0.002059, 75.66us       0.4030, 0.003149, 161.78us    48.0296, 0.375231, 550.46us     58.7787, 0.459209, 1902.41us    62.6966, 0.489817, 7817.48us
(3, 3, 3)                       1.2271, 0.000355, 20.72us       2.4185, 0.000700, 26.44us     4.6933, 0.001358, 64.66us       7.7016, 0.002228, 192.69us      912.0736, 0.263910, 593.69us
(5, 5, 5)                       4.8716, 0.000304, 24.75us       5.8624, 0.000366, 21.39us     11.0705, 0.000692, 66.94us      18.9280, 0.001183, 104.93us     34.0512, 0.002128, 441.81us
(7, 7, 7)                       10.1713, 0.000232, 20.98us      13.2273, 0.000301, 36.26us    21.5426, 0.000491, 52.18us      30.1910, 0.000688, 72.94us      59.8381, 0.001363, 191.52us
(9, 9, 9)                       14.4542, 0.000155, 23.85us      22.6579, 0.000243, 30.59us    33.8839, 0.000363, 57.40us      54.3563, 0.000583, 142.53us     95.8123, 0.001027, 309.24us
(11, 11, 11)                    26.3348, 0.000155, 30.07us      34.3043, 0.000201, 37.01us    49.8093, 0.000292, 74.04us      78.3720, 0.000460, 110.53us     136.5404, 0.000801, 264.14us
(13, 13, 13)                    39.3550, 0.000140, 37.38us      49.3207, 0.000175, 43.51us    74.1139, 0.000264, 83.70us      108.7627, 0.000387, 136.09us    196.5412, 0.000699, 280.16us
Batch : 8, Channel : 54 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                  (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.8467, 0.001960, 147.36us      1.3993, 0.003239, 314.95us    162.0182, 0.375042, 1327.22us   198.3226, 0.459080, 3921.79us   211.6123, 0.489843, 15646.94us
(3, 3, 3)                       4.3146, 0.000370, 29.23us       8.1125, 0.000696, 74.94us     15.8886, 0.001362, 223.69us     26.2404, 0.002250, 601.33us     3076.5354, 0.263763, 1974.06us
(5, 5, 5)                       16.5032, 0.000306, 58.79us      19.6887, 0.000365, 53.79us    37.2731, 0.000690, 192.34us     63.3076, 0.001172, 270.01us     114.8880, 0.002128, 1148.56us
(7, 7, 7)                       34.0802, 0.000230, 51.12us      44.4087, 0.000300, 100.93us   72.4613, 0.000489, 161.48us     101.9317, 0.000688, 202.91us    201.8955, 0.001363, 545.33us
(9, 9, 9)                       48.8179, 0.000155, 65.78us      76.3465, 0.000242, 87.48us    114.0228, 0.000362, 179.11us    182.9805, 0.000581, 403.66us    322.7040, 0.001025, 894.86us
(11, 11, 11)                    88.9993, 0.000155, 88.69us      116.4213, 0.000202, 107.55us  168.3363, 0.000293, 228.71us    264.2232, 0.000460, 322.84us    459.1324, 0.000799, 784.25us
(13, 13, 13)                    132.7447, 0.000140, 112.91us    165.4525, 0.000174, 131.08us  249.7127, 0.000263, 266.43us    367.0824, 0.000387, 410.17us    663.1367, 0.000699, 847.87us
Float
half
Discepancies 198.37625122070312 0.4592042852331091 432
Sum :  tensor(25110.1484, device='cuda:0') tensor(25104., device='cuda:0', dtype=torch.float16)
```
</p>
</details>

ngimel malfet anjali411

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53607

Reviewed By: mruberry

Differential Revision: D27652337

Pulled By: ngimel

fbshipit-source-id: 6439c0cafe6ca3f761a3f5d058050a55e9a0abd8
2021-04-08 15:48:08 -07:00
Peter Bell
2ee02b30b1 Replace rounding_mode="true" with rounding_mode=None (#51988)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51988

* **#51988 Replace rounding_mode="true" with rounding_mode=None**

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D27561817

Pulled By: mruberry

fbshipit-source-id: 60d1d9c389570f60d599fc1876518717367fb368
2021-04-05 14:53:43 -07:00
Gregory Chanan
983347fa25 Allow broadcasting against lerp weights. (#52319)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52319

Fixes: https://github.com/pytorch/pytorch/issues/52254

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D26488411

Pulled By: gchanan

fbshipit-source-id: 60eb471609986584c4235ba7f263581e988e7642
2021-02-18 09:53:25 -08:00
jiej
4d703d040b Linear autodiff revert revert (#51613)
Summary:
patch PR https://github.com/pytorch/pytorch/issues/50856 and rollbak the revert D26105797 (e488e3c443)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51613

Reviewed By: mruberry

Differential Revision: D26253999

Pulled By: ngimel

fbshipit-source-id: a20b1591de06dd277e4cd95542e3291a2f5a252c
2021-02-04 16:32:05 -08:00
Peter Bell
b150f150ba Add division overload with rounding_mode selection (#51706)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51706

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50280

As mentioned in gh-43874, this adds a `rounding_mode={'true', 'trunc', 'floor'}`
argument so `torch.div` can be used as a replacement for `floor_divide` during
the transitional period.

I've included dedicated kernels for truncated and floor division which
aren't strictly necessary for float, but do perform significantly better (~2x) than
doing true division followed by a separate rounding kernel.

Note: I introduce new overloads for `aten::div` instead of just adding a default
`rounding_mode` because various JIT passes rely on the exact operator schema.

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D26123271

Pulled By: mruberry

fbshipit-source-id: 51a83717602114597ec9c4d946e35a392eb01d46
2021-02-04 13:08:36 -08:00
Natalia Gimelshein
26f9ac98e5 Revert D26105797: [pytorch][PR] Exposing linear layer to fuser
Test Plan: revert-hammer

Differential Revision:
D26105797 (e488e3c443)

Original commit changeset: 6f7cedb9f6e3

fbshipit-source-id: f0858cefed76d726e9dba61e51e1eaf2af4c99c5
2021-02-02 17:39:17 -08:00
jiej
e488e3c443 Exposing linear layer to fuser (#50856)
Summary:
1. enabling linear in autodiff;
2. remove control flow in python for linear;

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50856

Reviewed By: pbelevich

Differential Revision: D26105797

Pulled By: eellison

fbshipit-source-id: 6f7cedb9f6e3e46daa24223d2a6080880498deb4
2021-02-02 15:39:01 -08:00
Andres Suarez
8530c65e25 [codemod][fbcode/caffe2] Apply clang-format update fixes
Test Plan: Sandcastle and visual inspection.

Reviewed By: igorsugak

Differential Revision: D25849205

fbshipit-source-id: ef664c1ad4b3ee92d5c020a5511b4ef9837a09a0
2021-01-09 14:37:36 -08:00
Michael Suo
dc8176356e Various cleanups to ir_emitter and friends (#46686)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46686

I was trying to page this code back in after a while and some things
stuck out as unnecessarily confusing.

1. Improve documentation of closures and fork stuff to be more accurate
to how we use them today.
2. Change `prim::LocalVariableScope` to `prim::ListComprehension`. It is
only ever used for a list comprehensions, and in general the nodes
emitted by `ir_emitter` should correspond to concrete operations or
language features rather than semantic constraints.
3. Change the somewhat mysterious "inputs" and "attributes" argument
names throughout the codebase to be the more obvious "args" and "kwargs"
that they generally represent (I think "inputs" and "attributes" come
from the AST naming).

Test Plan: Imported from OSS

Reviewed By: navahgar, jamesr66a

Differential Revision: D24464197

Pulled By: suo

fbshipit-source-id: 1f4b1475b58b5690a0b204e705caceff969533b4
2020-10-28 16:28:05 -07:00
Meghan Lele
6384c2d81b [JIT] clang-format JIT code (#35115)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35115

This commit runs the newly added tools/clang_format.py on the JIT
codebase and includes all of the formatting changes thus produced.

Testing:
Ran the script, CI.

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D20568523

Pulled By: SplitInfinity

fbshipit-source-id: e09bdb982ccf090eecfb7c7b461b8d0681eef82b
2020-03-26 11:24:51 -07:00
Elias Ellison
514cba0661 [JIT] remove builtin interpolate functions (#34514)
Summary:
`torch.nn.functional.interpolate` was written as a builtin op when we scripted the standard library, because it has four possible overloads. As a result, whenever we make a change to `interpolate`, we need to make changes in two places, and it also makes it impossible to optimize the interpolate op. The builtin is tech debt.

I talked with ailzhang, and the symbolic script changes are good to remove (i guess that makes a third place we needed to re-implement interpolate).

I'm trying to get rid of unneccessary builtin operators because we're standardizing mobile bytecode soon, so we should try to get this landed as soon as possible.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34514

Differential Revision: D20391089

Pulled By: eellison

fbshipit-source-id: abc84cdecfac67332bcba6b308fca4db44303121
2020-03-12 09:21:33 -07:00
Michael Suo
c235be42dd [jit] kill script namespace (#34515)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34515

Once upon a time we thought this was necessary. In reality it is not, so
removing it.

For backcompat, our public interface (defined in `api/`) still has
typedefs to the old `script::` names.

There was only one collision: `Pass` as a `Stmt` and `Pass` as a graph
transform. I renamed one of them.

Test Plan: Imported from OSS

Differential Revision: D20353503

Pulled By: suo

fbshipit-source-id: 48bb911ce75120a8c9e0c6fb65262ef775dfba93
2020-03-11 23:32:48 -07:00
Michael Suo
dbe850af5b [jit] do the code reorg (#33851)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33851

Rationale and context described in #33828.

Script to reproduce the move:
https://gist.github.com/suo/16cbefaaeb67ca5a7c6caffd49b7f6e9
ghstack-source-id: 99079645

Test Plan: Make sure CI passes

Reviewed By: jamesr66a

Differential Revision: D20133869

fbshipit-source-id: 390e9241a9c85366d9005c492ac31f10aa96488e
2020-02-27 13:02:51 -08:00