Summary:
Fixes https://github.com/pytorch/pytorch/issues/52719
- Changed the type of intermediate results from `scalar_t` to `at::acc_type<scalar_t, true>`

The issue is caused by the limited precision of half-precision floating point. Following the test case from the issue above, the input tensors are initialized with `rand`, so their values lie in [0, 1]. When the output size is 1, the kernel covers the whole input: it sums every value and divides by the kernel's numel:
34d9278c19/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu (L94-L95)
When accumulating values from [0, 1] in half precision, the running `sum` stops changing once it exceeds 2048, because fp16 values at that magnitude are spaced 2 apart; even below that threshold each addition is rounded, so precision is already degraded (see https://en.wikipedia.org/wiki/Half-precision_floating-point_format). A minimal sketch of the saturation is shown below.
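To make the failure mode concrete, here is a small standalone Python sketch (illustration only, not part of this PR): a half-precision running sum of values drawn from [0, 1) stalls near 2048, while a float32 accumulator, which is roughly what `at::acc_type<scalar_t, true>` resolves to for `Half` inputs on CUDA, keeps tracking the true sum.
```
import torch

# fp16 values at magnitude >= 2048 are spaced 2 apart, so adding anything
# from [0, 1) to a running fp16 sum that has reached 2048 rounds back to 2048.
s = torch.tensor(2048.0, dtype=torch.half)
print(s + 0.9)  # tensor(2048., dtype=torch.float16)

# Accumulate 10,000 uniform [0, 1) values sequentially: the fp16 accumulator
# stalls near 2048, the float32 accumulator stays close to the true sum (~5000).
torch.manual_seed(0)
x = torch.rand(10_000).half()
acc_half = torch.zeros((), dtype=torch.half)
acc_float = torch.zeros((), dtype=torch.float)
for v in x:
    acc_half += v            # running fp16 sum: saturates once it reaches ~2048
    acc_float += v.float()   # running fp32 sum: stays accurate
print(acc_half.item(), acc_float.item())
```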
Benchmarks
- On a V100 32GB, driver 450.80, CUDA 10.1
- Faster than the previous implementation
<details><summary>Script</summary><p>
```
import torch
from torch.utils.benchmark import Timer

torch.manual_seed(0)

kernel_sizes = [1, 3, 5, 7, 9, 11, 13]
shapes = [(12, 12, 12), (16, 16, 16), (16, 32, 32), (16, 56, 56), (16, 112, 112)]

def run(batch, channel):
    print(f"Batch : {batch}, Channel : {channel} / (diff, diff / numel, time)")
    head = "\t".join(f"{str(s):30s}" for s in ["k \ shape"] + shapes)
    print(head)
    for kernel_size in kernel_sizes:
        kernel_size = (kernel_size, kernel_size, kernel_size)
        pool = torch.nn.AdaptiveAvgPool3d(kernel_size)
        print(f"{str(kernel_size):30s}", end="\t")
        for shape in shapes:
            x_half = torch.rand([batch, channel, *shape], dtype=torch.half, device="cuda")
            x_float = x_half.float()
            y_half = pool(x_half)
            y_float = pool(x_float)
            timer = Timer("pool(x_half)", globals={"pool": pool, "x_half": x_half})
            measurement = timer.blocked_autorange(min_run_time=5)
            diff = (y_float - y_half).abs().sum().item()
            diff = f"{diff:.4f}, {diff / y_half.numel():.6f}, {measurement.median * 1e6 :3.2f}us"
            print(f"{diff:30s}", end="\t")
        print("")

run(1, 1)
run(1, 3)
run(1, 54)
run(1, 16)
run(8, 1)
run(8, 16)
run(8, 54)

import torch
m = torch.nn.AdaptiveAvgPool3d((1,1,1))
inputs = torch.rand([8,54,16,56,56])
inputs = inputs.cuda()
inputs_2 = inputs.half()
print("Float")
out = m(inputs).float()
print("half")
out2 = m(inputs_2).float()
print('Discepancies', torch.sum(torch.abs(out2- out)).item(), torch.sum(torch.abs(out2- out)).item() / out.numel() , out.numel())
print("Sum : ", torch.sum(inputs, dim=(2,3,4))[0, 0], torch.sum(inputs_2, dim=(2,3,4))[0, 0])
```
</p>
</details>
<details><summary>This commit</summary><p>
```
Batch : 1, Channel : 1 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0001, 0.000078, 55.73us 0.0001, 0.000079, 117.51us 0.0000, 0.000003, 379.60us 0.0000, 0.000046, 1046.21us 0.0001, 0.000139, 3897.17us
(3, 3, 3) 0.0021, 0.000076, 22.04us 0.0031, 0.000115, 21.47us 0.0022, 0.000080, 41.63us 0.0030, 0.000111, 100.59us 0.0025, 0.000091, 295.04us
(5, 5, 5) 0.0103, 0.000083, 21.65us 0.0097, 0.000078, 21.37us 0.0103, 0.000083, 21.60us 0.0114, 0.000091, 25.69us 0.0107, 0.000085, 97.06us
(7, 7, 7) 0.0312, 0.000091, 21.52us 0.0290, 0.000084, 21.61us 0.0311, 0.000091, 21.60us 0.0309, 0.000090, 21.44us 0.0334, 0.000097, 33.60us
(9, 9, 9) 0.0646, 0.000089, 21.57us 0.0672, 0.000092, 21.89us 0.0662, 0.000091, 21.89us 0.0684, 0.000094, 27.64us 0.0660, 0.000091, 54.85us
(11, 11, 11) 0.1251, 0.000094, 21.68us 0.1194, 0.000090, 21.70us 0.1202, 0.000090, 21.72us 0.1233, 0.000093, 22.25us 0.1229, 0.000092, 41.39us
(13, 13, 13) 0.2038, 0.000093, 21.57us 0.2047, 0.000093, 21.58us 0.1964, 0.000089, 21.54us 0.2021, 0.000092, 21.94us 0.1989, 0.000091, 40.01us
Batch : 1, Channel : 3 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0003, 0.000110, 55.74us 0.0003, 0.000093, 118.62us 0.0003, 0.000093, 382.12us 0.0001, 0.000040, 1052.33us 0.0003, 0.000114, 3917.90us
(3, 3, 3) 0.0073, 0.000090, 21.84us 0.0075, 0.000093, 22.25us 0.0072, 0.000089, 41.78us 0.0070, 0.000087, 100.27us 0.0069, 0.000086, 293.96us
(5, 5, 5) 0.0353, 0.000094, 22.57us 0.0325, 0.000087, 21.64us 0.0343, 0.000092, 22.63us 0.0338, 0.000090, 25.82us 0.0332, 0.000089, 97.16us
(7, 7, 7) 0.0937, 0.000091, 22.50us 0.0910, 0.000088, 21.92us 0.0933, 0.000091, 21.99us 0.0948, 0.000092, 21.56us 0.0928, 0.000090, 34.17us
(9, 9, 9) 0.1957, 0.000089, 21.68us 0.1984, 0.000091, 21.57us 0.2025, 0.000093, 22.10us 0.1986, 0.000091, 27.66us 0.2020, 0.000092, 55.32us
(11, 11, 11) 0.3585, 0.000090, 21.75us 0.3684, 0.000092, 22.70us 0.3706, 0.000093, 21.67us 0.3752, 0.000094, 21.86us 0.3663, 0.000092, 41.22us
(13, 13, 13) 0.5931, 0.000090, 21.67us 0.6056, 0.000092, 21.79us 0.6005, 0.000091, 21.79us 0.6112, 0.000093, 21.69us 0.6034, 0.000092, 40.02us
Batch : 1, Channel : 54 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0051, 0.000095, 55.76us 0.0060, 0.000112, 118.60us 0.0036, 0.000067, 381.50us 0.0054, 0.000100, 1054.03us 0.0048, 0.000089, 4888.68us
(3, 3, 3) 0.1332, 0.000091, 21.66us 0.1344, 0.000092, 22.62us 0.1354, 0.000093, 45.72us 0.1364, 0.000094, 106.63us 0.1324, 0.000091, 448.31us
(5, 5, 5) 0.6221, 0.000092, 22.48us 0.6220, 0.000092, 21.71us 0.6053, 0.000090, 27.65us 0.6137, 0.000091, 31.40us 0.6209, 0.000092, 172.78us
(7, 7, 7) 1.6859, 0.000091, 22.42us 1.6972, 0.000092, 21.96us 1.6849, 0.000091, 23.14us 1.7012, 0.000092, 26.25us 1.6920, 0.000091, 75.58us
(9, 9, 9) 3.5811, 0.000091, 21.73us 3.5746, 0.000091, 22.55us 3.6237, 0.000092, 27.66us 3.6046, 0.000092, 59.71us 3.6392, 0.000092, 168.15us
(11, 11, 11) 6.5582, 0.000091, 22.05us 6.5746, 0.000091, 21.74us 6.5955, 0.000092, 32.91us 6.5644, 0.000091, 45.57us 6.5697, 0.000091, 114.01us
(13, 13, 13) 10.6384, 0.000090, 21.81us 10.8608, 0.000092, 21.79us 10.8375, 0.000091, 37.01us 10.8662, 0.000092, 51.80us 10.8593, 0.000092, 123.19us
Batch : 1, Channel : 16 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0015, 0.000093, 55.75us 0.0012, 0.000075, 118.10us 0.0013, 0.000079, 379.25us 0.0012, 0.000075, 1047.21us 0.0013, 0.000079, 4451.57us
(3, 3, 3) 0.0407, 0.000094, 21.82us 0.0395, 0.000091, 21.69us 0.0385, 0.000089, 42.07us 0.0397, 0.000092, 100.33us 0.0384, 0.000089, 363.31us
(5, 5, 5) 0.1858, 0.000093, 21.76us 0.1799, 0.000090, 21.63us 0.1834, 0.000092, 21.76us 0.1890, 0.000095, 26.04us 0.1814, 0.000091, 135.32us
(7, 7, 7) 0.4937, 0.000090, 21.65us 0.5076, 0.000092, 21.69us 0.5001, 0.000091, 22.31us 0.4988, 0.000091, 21.59us 0.5123, 0.000093, 50.03us
(9, 9, 9) 1.0678, 0.000092, 21.73us 1.0752, 0.000092, 21.75us 1.0673, 0.000091, 21.75us 1.0649, 0.000091, 30.01us 1.0786, 0.000092, 70.92us
(11, 11, 11) 1.9591, 0.000092, 21.57us 1.9522, 0.000092, 21.60us 1.9566, 0.000092, 21.73us 1.9475, 0.000091, 23.46us 1.9323, 0.000091, 55.02us
(13, 13, 13) 3.1784, 0.000090, 22.02us 3.2165, 0.000092, 21.95us 3.1969, 0.000091, 21.92us 3.2061, 0.000091, 24.40us 3.2578, 0.000093, 56.00us
Batch : 8, Channel : 1 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0010, 0.000122, 55.74us 0.0009, 0.000114, 118.82us 0.0006, 0.000074, 379.80us 0.0009, 0.000107, 1047.31us 0.0008, 0.000102, 3900.36us
(3, 3, 3) 0.0219, 0.000101, 21.57us 0.0200, 0.000093, 21.61us 0.0194, 0.000090, 41.74us 0.0208, 0.000096, 99.91us 0.0212, 0.000098, 293.03us
(5, 5, 5) 0.0906, 0.000091, 21.46us 0.0911, 0.000091, 21.60us 0.0934, 0.000093, 21.93us 0.0927, 0.000093, 25.74us 0.0913, 0.000091, 96.85us
(7, 7, 7) 0.2530, 0.000092, 22.53us 0.2526, 0.000092, 22.46us 0.2558, 0.000093, 22.03us 0.2542, 0.000093, 22.29us 0.2475, 0.000090, 34.44us
(9, 9, 9) 0.5305, 0.000091, 22.34us 0.5368, 0.000092, 22.42us 0.5265, 0.000090, 21.74us 0.5370, 0.000092, 27.81us 0.5416, 0.000093, 55.65us
(11, 11, 11) 0.9887, 0.000093, 21.80us 0.9660, 0.000091, 21.61us 0.9793, 0.000092, 22.11us 0.9719, 0.000091, 21.80us 0.9650, 0.000091, 43.90us
(13, 13, 13) 1.6024, 0.000091, 21.87us 1.6198, 0.000092, 22.65us 1.6242, 0.000092, 21.73us 1.6236, 0.000092, 22.59us 1.6025, 0.000091, 42.77us
Batch : 8, Channel : 16 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0113, 0.000088, 56.66us 0.0117, 0.000091, 119.57us 0.0130, 0.000102, 389.57us 0.0110, 0.000086, 1433.78us 0.0119, 0.000093, 5217.61us
(3, 3, 3) 0.3209, 0.000093, 21.54us 0.3184, 0.000092, 22.87us 0.3115, 0.000090, 51.00us 0.3171, 0.000092, 164.17us 0.3182, 0.000092, 500.60us
(5, 5, 5) 1.4391, 0.000090, 22.39us 1.4577, 0.000091, 21.69us 1.4601, 0.000091, 53.87us 1.4626, 0.000091, 93.65us 1.4567, 0.000091, 370.11us
(7, 7, 7) 4.0501, 0.000092, 22.34us 4.0230, 0.000092, 31.45us 4.0381, 0.000092, 45.19us 4.0171, 0.000091, 65.35us 4.0108, 0.000091, 164.76us
(9, 9, 9) 8.5360, 0.000091, 22.80us 8.5456, 0.000092, 27.24us 8.5461, 0.000092, 50.23us 8.5677, 0.000092, 117.63us 8.5645, 0.000092, 270.46us
(11, 11, 11) 15.5521, 0.000091, 26.56us 15.5826, 0.000091, 32.81us 15.6014, 0.000092, 63.82us 15.5620, 0.000091, 96.87us 15.5722, 0.000091, 220.24us
(13, 13, 13) 25.4146, 0.000090, 32.91us 25.7898, 0.000092, 38.48us 25.6698, 0.000091, 72.02us 25.8193, 0.000092, 121.73us 25.7718, 0.000092, 249.71us
Batch : 8, Channel : 54 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0377, 0.000087, 109.07us 0.0405, 0.000094, 233.17us 0.0392, 0.000091, 998.97us 0.0393, 0.000091, 2960.68us 0.0408, 0.000094, 11879.53us
(3, 3, 3) 1.0660, 0.000091, 25.68us 1.0761, 0.000092, 64.12us 1.0725, 0.000092, 182.50us 1.0801, 0.000093, 505.82us 1.0736, 0.000092, 1650.21us
(5, 5, 5) 4.9587, 0.000092, 50.84us 4.9336, 0.000091, 47.38us 4.9696, 0.000092, 158.49us 4.9347, 0.000091, 237.39us 4.9303, 0.000091, 965.13us
(7, 7, 7) 13.5409, 0.000091, 45.60us 13.5736, 0.000092, 87.45us 13.5012, 0.000091, 141.63us 13.6111, 0.000092, 181.51us 13.5296, 0.000091, 469.77us
(9, 9, 9) 28.7817, 0.000091, 58.01us 28.7969, 0.000091, 77.61us 28.8761, 0.000092, 159.33us 28.8786, 0.000092, 334.47us 28.8093, 0.000091, 786.72us
(11, 11, 11) 52.4453, 0.000091, 78.19us 52.7265, 0.000092, 95.12us 52.7322, 0.000092, 200.38us 52.6342, 0.000092, 282.41us 52.6467, 0.000092, 652.54us
(13, 13, 13) 85.7411, 0.000090, 98.85us 86.7183, 0.000091, 115.28us 86.8545, 0.000092, 232.34us 86.9997, 0.000092, 367.32us 86.9083, 0.000092, 757.73us
Float
half
Discepancies 0.03963914513587952 9.175728040712852e-05 432
Sum : tensor(25110.1484, device='cuda:0') tensor(25104., device='cuda:0', dtype=torch.float16)
```
</p>
</details>
<details><summary>1.8.0</summary><p>
```
Batch : 1, Channel : 1 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0023, 0.002275, 74.35us 0.0040, 0.003985, 159.73us 0.3740, 0.374021, 546.59us 0.4587, 0.458663, 1543.16us 0.4906, 0.490637, 5945.97us
(3, 3, 3) 0.0100, 0.000370, 20.37us 0.0230, 0.000852, 22.12us 0.0309, 0.001143, 54.75us 0.0520, 0.001926, 129.78us 7.1219, 0.263775, 377.11us
(5, 5, 5) 0.0441, 0.000352, 20.06us 0.0394, 0.000316, 20.50us 0.0759, 0.000607, 26.43us 0.1499, 0.001199, 32.01us 0.2707, 0.002166, 128.15us
(7, 7, 7) 0.0791, 0.000231, 20.10us 0.1002, 0.000292, 20.56us 0.1812, 0.000528, 20.48us 0.2424, 0.000707, 20.83us 0.4994, 0.001456, 43.97us
(9, 9, 9) 0.1122, 0.000154, 20.55us 0.1778, 0.000244, 20.44us 0.2572, 0.000353, 20.15us 0.4149, 0.000569, 35.64us 0.7208, 0.000989, 68.46us
(11, 11, 11) 0.2044, 0.000154, 20.47us 0.2647, 0.000199, 20.62us 0.3867, 0.000291, 20.61us 0.6059, 0.000455, 23.54us 1.0902, 0.000819, 53.32us
(13, 13, 13) 0.3094, 0.000141, 20.53us 0.3843, 0.000175, 20.60us 0.5756, 0.000262, 20.80us 0.8598, 0.000391, 24.52us 1.4853, 0.000676, 47.70us
Batch : 1, Channel : 3 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0054, 0.001801, 74.36us 0.0108, 0.003614, 158.94us 1.1183, 0.372768, 547.67us 1.3782, 0.459387, 1545.27us 1.4685, 0.489505, 5949.17us
(3, 3, 3) 0.0308, 0.000380, 20.14us 0.0502, 0.000619, 22.11us 0.1210, 0.001493, 54.80us 0.1900, 0.002345, 130.47us 21.3483, 0.263560, 375.68us
(5, 5, 5) 0.1179, 0.000314, 20.68us 0.1326, 0.000354, 20.53us 0.2662, 0.000710, 26.51us 0.4116, 0.001098, 31.85us 0.8369, 0.002232, 128.19us
(7, 7, 7) 0.2335, 0.000227, 20.40us 0.3057, 0.000297, 20.43us 0.4954, 0.000481, 20.31us 0.7339, 0.000713, 20.74us 1.4208, 0.001381, 44.55us
(9, 9, 9) 0.3326, 0.000152, 20.63us 0.5353, 0.000245, 20.42us 0.8025, 0.000367, 20.13us 1.2693, 0.000580, 35.64us 2.2096, 0.001010, 68.88us
(11, 11, 11) 0.6121, 0.000153, 20.59us 0.8086, 0.000202, 20.42us 1.1700, 0.000293, 20.71us 1.8170, 0.000455, 23.54us 3.2117, 0.000804, 53.36us
(13, 13, 13) 0.9165, 0.000139, 20.51us 1.1395, 0.000173, 20.56us 1.7343, 0.000263, 20.80us 2.5868, 0.000392, 24.59us 4.5823, 0.000695, 47.77us
Batch : 1, Channel : 54 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.1092, 0.002023, 75.45us 0.1709, 0.003165, 160.44us 20.2452, 0.374911, 548.61us 24.7990, 0.459240, 1550.34us 26.4494, 0.489804, 6957.79us
(3, 3, 3) 0.5352, 0.000367, 20.58us 1.0281, 0.000705, 24.14us 2.0150, 0.001382, 59.12us 3.3069, 0.002268, 138.23us 384.5216, 0.263732, 529.71us
(5, 5, 5) 2.0739, 0.000307, 20.60us 2.5199, 0.000373, 20.44us 4.6916, 0.000695, 33.89us 7.9482, 0.001178, 37.74us 14.2553, 0.002112, 200.54us
(7, 7, 7) 4.2236, 0.000228, 20.61us 5.5605, 0.000300, 20.97us 9.0440, 0.000488, 26.40us 12.7847, 0.000690, 30.64us 25.3050, 0.001366, 88.05us
(9, 9, 9) 6.0817, 0.000154, 20.63us 9.5416, 0.000242, 20.84us 14.2416, 0.000362, 32.47us 22.8452, 0.000580, 78.57us 40.3246, 0.001024, 194.50us
(11, 11, 11) 11.1144, 0.000155, 20.56us 14.5581, 0.000203, 20.91us 20.8263, 0.000290, 38.07us 33.0004, 0.000459, 52.74us 57.3275, 0.000798, 137.19us
(13, 13, 13) 16.5176, 0.000139, 21.26us 20.8089, 0.000175, 22.33us 31.3433, 0.000264, 42.93us 45.9733, 0.000388, 59.84us 82.8301, 0.000698, 138.42us
Batch : 1, Channel : 16 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0274, 0.001715, 74.99us 0.0485, 0.003034, 159.92us 5.9925, 0.374529, 546.35us 7.3389, 0.458679, 1544.53us 7.8354, 0.489714, 6677.00us
(3, 3, 3) 0.1560, 0.000361, 20.72us 0.3043, 0.000704, 22.37us 0.5838, 0.001352, 54.97us 1.0455, 0.002420, 130.57us 113.9739, 0.263828, 463.43us
(5, 5, 5) 0.6121, 0.000306, 20.12us 0.7247, 0.000362, 20.73us 1.3740, 0.000687, 26.59us 2.3794, 0.001190, 32.12us 4.1929, 0.002096, 165.81us
(7, 7, 7) 1.2389, 0.000226, 20.59us 1.6311, 0.000297, 20.53us 2.6732, 0.000487, 20.37us 3.7501, 0.000683, 20.71us 7.4575, 0.001359, 59.16us
(9, 9, 9) 1.7983, 0.000154, 20.64us 2.8075, 0.000241, 20.59us 4.2165, 0.000361, 20.38us 6.7153, 0.000576, 38.29us 12.0530, 0.001033, 86.33us
(11, 11, 11) 3.3326, 0.000156, 20.56us 4.3061, 0.000202, 20.67us 6.2235, 0.000292, 20.47us 9.8009, 0.000460, 27.41us 16.9994, 0.000798, 68.49us
(13, 13, 13) 4.9016, 0.000139, 20.63us 6.1261, 0.000174, 20.65us 9.2106, 0.000262, 20.93us 13.5843, 0.000386, 27.95us 24.6476, 0.000701, 64.88us
Batch : 8, Channel : 1 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.0170, 0.002122, 74.99us 0.0316, 0.003946, 160.66us 3.0013, 0.375158, 546.94us 3.6780, 0.459753, 1544.58us 3.9197, 0.489966, 5948.43us
(3, 3, 3) 0.0821, 0.000380, 20.27us 0.1559, 0.000722, 22.29us 0.3133, 0.001450, 54.72us 0.5100, 0.002361, 130.12us 57.0481, 0.264111, 376.71us
(5, 5, 5) 0.3075, 0.000307, 20.57us 0.3680, 0.000368, 20.69us 0.6786, 0.000679, 26.61us 1.1744, 0.001174, 31.77us 2.0654, 0.002065, 128.31us
(7, 7, 7) 0.6512, 0.000237, 20.60us 0.8359, 0.000305, 20.50us 1.3712, 0.000500, 20.75us 1.9472, 0.000710, 20.92us 3.7586, 0.001370, 44.59us
(9, 9, 9) 0.9138, 0.000157, 20.43us 1.4198, 0.000243, 20.58us 2.1018, 0.000360, 20.52us 3.3691, 0.000578, 35.90us 5.9491, 0.001020, 69.16us
(11, 11, 11) 1.6606, 0.000156, 20.63us 2.1599, 0.000203, 20.57us 3.1240, 0.000293, 20.98us 4.8874, 0.000459, 24.65us 8.4780, 0.000796, 56.47us
(13, 13, 13) 2.4987, 0.000142, 20.71us 3.0667, 0.000174, 20.45us 4.6387, 0.000264, 20.76us 6.8187, 0.000388, 25.95us 12.2077, 0.000695, 50.46us
Batch : 8, Channel : 16 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.2635, 0.002059, 75.66us 0.4030, 0.003149, 161.78us 48.0296, 0.375231, 550.46us 58.7787, 0.459209, 1902.41us 62.6966, 0.489817, 7817.48us
(3, 3, 3) 1.2271, 0.000355, 20.72us 2.4185, 0.000700, 26.44us 4.6933, 0.001358, 64.66us 7.7016, 0.002228, 192.69us 912.0736, 0.263910, 593.69us
(5, 5, 5) 4.8716, 0.000304, 24.75us 5.8624, 0.000366, 21.39us 11.0705, 0.000692, 66.94us 18.9280, 0.001183, 104.93us 34.0512, 0.002128, 441.81us
(7, 7, 7) 10.1713, 0.000232, 20.98us 13.2273, 0.000301, 36.26us 21.5426, 0.000491, 52.18us 30.1910, 0.000688, 72.94us 59.8381, 0.001363, 191.52us
(9, 9, 9) 14.4542, 0.000155, 23.85us 22.6579, 0.000243, 30.59us 33.8839, 0.000363, 57.40us 54.3563, 0.000583, 142.53us 95.8123, 0.001027, 309.24us
(11, 11, 11) 26.3348, 0.000155, 30.07us 34.3043, 0.000201, 37.01us 49.8093, 0.000292, 74.04us 78.3720, 0.000460, 110.53us 136.5404, 0.000801, 264.14us
(13, 13, 13) 39.3550, 0.000140, 37.38us 49.3207, 0.000175, 43.51us 74.1139, 0.000264, 83.70us 108.7627, 0.000387, 136.09us 196.5412, 0.000699, 280.16us
Batch : 8, Channel : 54 / (diff, diff / numel, time)
k \ shape (12, 12, 12) (16, 16, 16) (16, 32, 32) (16, 56, 56) (16, 112, 112)
(1, 1, 1) 0.8467, 0.001960, 147.36us 1.3993, 0.003239, 314.95us 162.0182, 0.375042, 1327.22us 198.3226, 0.459080, 3921.79us 211.6123, 0.489843, 15646.94us
(3, 3, 3) 4.3146, 0.000370, 29.23us 8.1125, 0.000696, 74.94us 15.8886, 0.001362, 223.69us 26.2404, 0.002250, 601.33us 3076.5354, 0.263763, 1974.06us
(5, 5, 5) 16.5032, 0.000306, 58.79us 19.6887, 0.000365, 53.79us 37.2731, 0.000690, 192.34us 63.3076, 0.001172, 270.01us 114.8880, 0.002128, 1148.56us
(7, 7, 7) 34.0802, 0.000230, 51.12us 44.4087, 0.000300, 100.93us 72.4613, 0.000489, 161.48us 101.9317, 0.000688, 202.91us 201.8955, 0.001363, 545.33us
(9, 9, 9) 48.8179, 0.000155, 65.78us 76.3465, 0.000242, 87.48us 114.0228, 0.000362, 179.11us 182.9805, 0.000581, 403.66us 322.7040, 0.001025, 894.86us
(11, 11, 11) 88.9993, 0.000155, 88.69us 116.4213, 0.000202, 107.55us 168.3363, 0.000293, 228.71us 264.2232, 0.000460, 322.84us 459.1324, 0.000799, 784.25us
(13, 13, 13) 132.7447, 0.000140, 112.91us 165.4525, 0.000174, 131.08us 249.7127, 0.000263, 266.43us 367.0824, 0.000387, 410.17us 663.1367, 0.000699, 847.87us
Float
half
Discepancies 198.37625122070312 0.4592042852331091 432
Sum : tensor(25110.1484, device='cuda:0') tensor(25104., device='cuda:0', dtype=torch.float16)
```
</p>
</details>
ngimel malfet anjali411
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53607
Reviewed By: mruberry
Differential Revision: D27652337
Pulled By: ngimel
fbshipit-source-id: 6439c0cafe6ca3f761a3f5d058050a55e9a0abd8
1557 lines
58 KiB
C++
#include <torch/csrc/jit/runtime/symbolic_script.h>

#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/runtime/operator.h>

namespace torch {
namespace jit {
namespace {
std::mutex lock;
const std::vector<std::string> functions = {
    R"(
        #### HELPER FUNCTIONS ###
        #### PREFIX: AD_ ###
        #### SCHEMA NOT SAVED IN CACHE ###

        def AD_unsqueeze_multiple(t,
                dims: List[int],
                n_dims: int):
            seen = [False] * n_dims
            for i in range(len(dims)):
                seen[dims[i]] = True

            for d in range(n_dims):
                if seen[d]:
                    t = t.unsqueeze(d)
            return t

        def AD_sum_backward(grad,
                sizes: List[int],
                dims: List[int],
                keepdim: bool):
            if not keepdim and len(sizes) > 0:
                if len(dims) == 1:
                    return grad.unsqueeze(dims[0]).expand(sizes)
                else:
                    res = AD_unsqueeze_multiple(grad, dims, len(sizes))
                    return res.expand(sizes)
            else:
                return grad.expand(sizes)

        def AD_logsumexp_backward(grad, self, result,
                dim: List[int],
                keepdim: bool):
            if not keepdim and self.dim() != 0:
                n_dims = len(self.size())
                grad = AD_unsqueeze_multiple(grad, dim, n_dims)
                result = AD_unsqueeze_multiple(result, dim, n_dims)
            return grad * (self - result).exp()

        def mean_0(self, *, dtype: Optional[int]):
            self_size = self.size()
            self_numel = self.numel()
            self_scalar_type = self.dtype
            def backward(grad_output):
                return grad_output.expand(self_size).to(self_scalar_type) / self_numel, None

            return torch.mean(self, dtype=dtype), backward

        def mean_1(self,
                dim: List[int],
                keepdim: bool,
                *,
                dtype: Optional[int]):
            self_size = self.size()
            self_scalar_type = self.dtype
            def backward(grad_output):
                grad_self = AD_sum_backward(grad_output, self_size, dim, keepdim).to(self_scalar_type) / AD_safe_size(self_size, dim)
                return grad_self, None, None, None

            return torch.mean(self, dim, keepdim, dtype=dtype), backward

        def logsumexp(self,
                dim: List[int],
                keepdim: bool):
            result = torch.logsumexp(self, dim, keepdim)
            self_dim = self.dim()
            def backward(grad_output):
                grad_self = AD_logsumexp_backward(grad_output, self, result, dim, keepdim)
                return grad_self, None, None

            return result, backward

        def AD_bool_to_int(b: bool):
            # FIXME: torchscript: int - bool
            if b:
                i = 1
            else:
                i = 0
            return i

        def AD_var_backward_0(grad, self, unbiased: bool):
            b = AD_bool_to_int(unbiased)

            # FIXME: torchscript: div(float, float)
            return grad * (self - self.mean()) * 2.0 / (self.numel() - b)

        def AD_safe_size(sizes: List[int],
                dims: List[int]):
            if len(sizes) == 0:
                return 1

            size = 1
            for i in range(len(dims)):
                d = dims[i]
                size *= sizes[d]

            return size

        def AD_var_backward_1(grad,
                self,
                dim: List[int],
                unbiased: bool,
                keepdim: bool):
            if self.dim() == 0:
                return AD_var_backward_0(grad, self, unbiased)
            self_size = self.size()
            b = AD_bool_to_int(unbiased)
            if not keepdim and self.dim() > 1:
                grad = AD_unsqueeze_multiple(grad, dim, len(self_size))

            # FIXME: torchscript: div(float, float)
            return grad * (self - self.mean(dim, True)) * 2.0 / (AD_safe_size(self_size, dim) - b)

        def std_0(self,
                unbiased: bool=True):
            std_out = torch.std(self, unbiased)
            def backward(grad_output):
                grad_self = AD_var_backward_0(grad_output / (std_out * 2), self, unbiased)
                return grad_self, None

            return std_out, backward

        def std_1(self,
                dim: List[int],
                unbiased: bool,
                keepdim: bool):
            std_out = torch.std(self, dim, unbiased, keepdim)
            def backward(grad_output):
                grad_self = AD_var_backward_1(grad_output / (std_out * 2), self, dim, unbiased, keepdim)
                return grad_self, None, None, None

            return std_out, backward

        def var_0(self,
                unbiased: bool=True):
            def backward(grad_output):
                grad_self = AD_var_backward_0(grad_output, self, unbiased)
                return grad_self, None

            return torch.var(self, unbiased), backward

        def var_1(self,
                dim: List[int],
                unbiased: bool,
                keepdim: bool):
            def backward(grad_output):
                grad_self = AD_var_backward_1(grad_output, self, dim, unbiased, keepdim)
                return grad_self, None, None, None

            return torch.var(self, dim, unbiased, keepdim), backward

        def tanh(self):
            output = torch.tanh(self)
            def backward(grad_output):
                return grad_output * (1 - output * output)

            return output, backward

        def AD_index_select_backward(grad,
                dim: int,
                indices,
                sizes: List[int],
                keepdim: bool):
            if not keepdim and len(sizes) > 0:
                grad = grad.unsqueeze(dim)
                indices = indices.unsqueeze(dim)

            # FIXME: torchscript: torch.zeros(sizes, grad.options())
            return torch.zeros(sizes).to(grad).scatter_(dim, indices, grad)

        # def topk(self,
        #         k: int,
        #         dim: int = -1,
        #         largest: bool = True,
        #         sorted: bool = True):
        #     result0, result1 = torch.topk(self, k, dim, largest, sorted)
        #     self_size = self.size()
        #     def backward(grad_output):
        #         grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, True)
        #         return grad_self, None, None, None, None
        #
        #     return result0, result1, backward

        # def kthvalue(self,
        #         k: int,
        #         dim: int,
        #         keepdim: bool):
        #     result0, result1 = torch.kthvalue(self, k, dim, keepdim)
        #     self_size = self.size()
        #     def backward(grad_output):
        #         grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, keepdim)
        #         return grad_self, None, None, None
        #
        #     return result0, result1, backward

        def AD_mm_backward_self(grad, mat2):
            return grad.mm(mat2.t())

        def AD_mm_backward_mat2(grad, self):
            return self.t().mm(grad)

        def mm(self, mat2):
            def backward(grad_output):
                grad_self = AD_mm_backward_self(grad_output, mat2)
                grad_mat2 = AD_mm_backward_mat2(grad_output, self)
                return grad_self, grad_mat2

            return torch.mm(self, mat2), backward

        def AD_permute_backward(grad,
                fwd_dims: List[int]):
            ndims = len(fwd_dims)
            dims = [0] * ndims

            for i in range(ndims):
                dims[fwd_dims[i]] = i

            return grad.permute(dims)

        def permute(self,
                dims: List[int]):
            def backward(grad_output):
                grad_self = AD_permute_backward(grad_output, dims)
                return grad_self, None

            return torch.permute(self, dims), backward

        def AD_select_backward(grad,
                input_sizes: List[int],
                dim: int,
                index: int):
            # FIXME: torchscript: torch.zeros(sizes, grad.options())
            grad_input = torch.zeros(input_sizes).to(grad)
            grad_input.select(dim, index).copy_(grad)
            return grad_input

        # TODO: fix torch.zeros(sizes, grad.options()) before enabling select, topk, kthvalue
        # def select(self,
        #         dim: int,
        #         index: int):
        #     self_size = self.size()
        #     def backward(grad_output):
        #         grad_self = AD_select_backward(grad_output, self_size, dim, index)
        #         return grad_self, None, None
        #
        #     return torch.select(self, dim, index), backward

        def AD_slice_backward(grad,
                input_sizes: List[int],
                dim: int,
                start: int,
                end: int,
                step: int):
            # FIXME: torchscript: torch.zeros(sizes, grad.options())
            grad_input = torch.zeros(input_sizes).to(grad)
            grad_input.slice(dim, start, end, step).copy_(grad)
            return grad_input

        # DON'T enable slice unless we can correctly handle view ops in graph executor.
        # It triggers failure of TestJit.test_sample in test_distributions.py.
        # def slice(self,
        #         dim: int=0,
        #         start: int=0,
        #         end: int=9223372036854775807,
        #         step: int=1):
        #     def backward(grad_output):
        #         grad_self = AD_slice_backward(grad_output, self.size(), dim, start, end, step)
        #         return grad_self, None, None, None, None
        #
        #     return torch.slice(self, dim, start, end, step), backward

        def AD_unsqueeze_to_0(self,
                sizes: List[int]):
            ndims = len(sizes)
            for i in range(ndims):
                if sizes[i] == 1:
                    self = self.unsqueeze(i)

            return self

        def AD_unsqueeze_to_1(self,
                dim: int,
                sizes: List[int]):
            if len(sizes) > 0 and sizes[dim] == 1:
                return self.unsqueeze(dim)
            return self

        def squeeze_0(self):
            self_size = self.size()
            def backward(grad_output):
                grad_self = AD_unsqueeze_to_0(grad_output, self_size)
                return grad_self

            return torch.squeeze(self), backward

        def squeeze_1(self,
                dim: int):
            self_size = self.size()
            def backward(grad_output):
                grad_self = AD_unsqueeze_to_1(grad_output, dim, self_size)
                return grad_self, None

            return torch.squeeze(self, dim), backward

        def AD_infer_size(a: List[int],
                b: List[int]):
            dimsA = len(a)
            dimsB = len(b)

            ndim = dimsA if dimsA > dimsB else dimsB
            expand_sizes = [0] * ndim

            for i in range(ndim):
                idx = - i + ndim - 1
                sizeA = a[i] if dimsA + i >= 0 else 1
                sizeB = b[i] if dimsB + i >= 0 else 1

                # Assert sizeA == sizeB or sizeA == 1 or sizeB == 1
                expand_sizes[i] = sizeB if sizeA == 1 else sizeA

            return expand_sizes

        def AD_bmm_backward_self(grad, mat2):
            return grad.bmm(mat2.transpose(1, 2))

        def AD_bmm_backward_mat2(grad, self):
            return self.transpose(1, 2).bmm(grad)

        def bmm(self, mat2):
            def backward(grad_output):
                grad_self = AD_bmm_backward_self(grad_output, mat2)
                grad_mat2 = AD_bmm_backward_mat2(grad_output, self)
                return grad_self, grad_mat2
            return torch.bmm(self, mat2), backward

        def AD_mat_transpose(mat):
            dim = mat.dim()
            if dim == 1:
                out = mat
            elif dim == 2:
                out = mat.t()
            else:
                dims = rangelist(dim)
                dims[-1] = dim - 2
                dims[-2] = dim - 1
                out = mat.permute(dims)
            return out

        # In matmul backward case of [b, m, n] * [b, n, p] => [m, p],
        # instead of doing [b, m, p] and then reduce to [m, p]
        # which potentially uses large intermediate of size b*m*p,
        # we do [m, bn] * [bn, p] to avoid having the large
        # intermediate, thus reduces max memory usage.
        def AD_matmul_bw_special_fold(mat1, mat2):
            mat1_transpose = AD_mat_transpose(mat1)
            mat1_fold = mat1_transpose.reshape(-1, mat1_transpose.size()[-1])
            mat2_fold = mat2.reshape(-1, mat2.size()[-1])
            return mat1_fold.t().mm(mat2_fold)

        def AD_matmul_bw_size(mat1, mat2,
                out_size: List[int]):
            dim1 = mat1.dim()
            dim2 = mat2.dim()
            dim_out = len(out_size)
            if dim1 == 0 or dim2 == 0:
                out = mat1 * mat2
            elif dim_out == 2 and dim1 == dim2 and dim1 >= 3:
                out = AD_matmul_bw_special_fold(mat1, mat2)
            elif dim_out == 1 and dim1 - dim2 == 1 and dim1 >= 3:
                mat2_unsqueeze = mat2.unsqueeze(-1)
                out = AD_matmul_bw_special_fold(mat1, mat2_unsqueeze)
                out = out.squeeze(-1)
            elif dim1 + dim2 == dim_out:
                if dim2 == 1:
                    target_dim2 = 0
                else:
                    target_dim2 = -2
                out = torch.matmul(mat1.unsqueeze(dim1), mat2.unsqueeze(target_dim2))
            elif dim_out == dim1 - dim2:
                out = torch.matmul(mat1, mat2.unsqueeze(dim2)).squeeze(-1)
            elif dim_out == dim2 - dim1:
                out = torch.matmul(mat1.unsqueeze(-2), mat2).squeeze(-2)
            else:
                out = torch.matmul(mat1, mat2)
            return out

        def matmul(self, other):
            def backward(grad_output):
                self_size = self.size()
                other_size = other.size()
                grad_self = AD_matmul_bw_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
                grad_other = AD_matmul_bw_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
                return grad_self, grad_other

            return torch.matmul(self, other), backward

        def linear(input : Tensor,
                weight : Tensor,
                bias : Optional[Tensor]):
            result = torch.linear(input, weight, bias)

            def backward(grad_output):
                if bias is not None:
                    grad_bias = grad_output._grad_sum_to_size(bias.size())
                else:
                    grad_bias = None

                weight_size = weight.size()
                grad_input = torch.matmul(grad_output, weight)
                grad_weight = torch.matmul(grad_output.reshape(-1, weight_size[0]).t(), input.reshape(-1, weight_size[1]))
                return grad_input, grad_weight, grad_bias
            return result, backward
    )",
    R"(
        def addcmul(self,
                tensor1,
                tensor2,
                *,
                value: number):
            result = torch.addcmul(self, tensor1, tensor2, value=value)
            self_size = torch._size_if_not_equal(self.size(), result.size())
            tensor1_size = torch._size_if_not_equal(tensor1.size(), result.size())
            tensor2_size = torch._size_if_not_equal(tensor2.size(), result.size())
            def backward(grad_output):
                grad = grad_output * value
                grad_tensor1 = (grad * tensor2)._grad_sum_to_size(tensor1_size)
                grad_tensor2 = (grad * tensor1)._grad_sum_to_size(tensor2_size)
                return grad_output._grad_sum_to_size(self_size), grad_tensor1, grad_tensor2, None
            return result, backward

        def _dim_arange(like,
                dim: int):
            def backward(grad_output):
                return None, None

            return torch._dim_arange(like, dim), backward

        def contiguous(self, *, memory_format: int=0):
            def backward(grad_output):
                return grad_output, None

            return self.contiguous(memory_format=memory_format), backward

        def dot(self, tensor):
            def backward(grad_output):
                return grad_output * tensor, grad_output * self

            return torch.dot(self, tensor), backward

        def erf(self):
            def backward(grad_output):
                # Precomputed constant C = 2.0 / math.sqrt(math.pi)
                C = 1.1283791670955126
                return C * torch.exp(- self * self) * grad_output

            return torch.erf(self), backward

        def expand(self,
                size: List[int],
                *,
                implicit: bool=False):
            result = torch.expand(self, size, implicit=implicit)
            self_size = torch._size_if_not_equal(self.size(), result.size())

            def backward(grad_output):
                return grad_output._grad_sum_to_size(self_size), None, None

            return result, backward

        def expand_as(self, other):
            result = torch.expand_as(self, other)
            self_size = torch._size_if_not_equal(self.size(), result.size())

            def backward(grad_output):
                return grad_output._grad_sum_to_size(self_size), None

            return result, backward

        def full_like(self,
                fill_value: float):
            def backward(grad_output):
                return None, None

            return torch.full_like(self, fill_value, memory_format=1), backward

        def lerp_0(self,
                end,
                weight: number):
            result = torch.lerp(self, end, weight)
            self_size = torch._size_if_not_equal(self.size(), result.size())
            end_size = torch._size_if_not_equal(end.size(), result.size())

            def backward(grad_output):
                grad_self = (grad_output * (1 - float(weight)))._grad_sum_to_size(self_size)
                grad_end = (grad_output * float(weight))._grad_sum_to_size(end_size)
                return grad_self, grad_end, None
            return result, backward

        def lerp_1(self,
                end,
                weight):
            result = torch.lerp(self, end, weight)
            self_size = torch._size_if_not_equal(self.size(), result.size())
            end_size = torch._size_if_not_equal(end.size(), result.size())
            weight_size = torch._size_if_not_equal(weight.size(), result.size())

            def backward(grad_output):
                grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self_size)
                grad_end = (grad_output * weight)._grad_sum_to_size(end_size)
                grad_weight = (grad_output * (end - self))._grad_sum_to_size(weight_size)
                return grad_self, grad_end, grad_weight

            return result, backward

        def reshape(self,
                shape: List[int]):
            self_size = self.size()

            def backward(grad_output):
                return grad_output.reshape(self_size), None

            return torch.reshape(self, shape), backward

        def split(self,
                split_size: int,
                dim: int):
            def backward(grad_outputs: List[Tensor]):
                grad_self = torch.cat(grad_outputs, dim)
                return grad_self, None, None

            return torch.split(self, split_size, dim), backward

        def split_with_sizes(self,
                split_sizes: List[int],
                dim: int):
            def backward(grad_outputs: List[Tensor]):
                size = len(grad_outputs)
                grad_self = torch.cat(grad_outputs, dim)
                return grad_self, None, None

            return torch.split_with_sizes(self, split_sizes, dim), backward

        def stack(tensors: List[Tensor],
                dim: int=0):
            def backward(grad_output):
                grad_tensors = torch.unbind(grad_output, dim)
                return grad_tensors, None

            return torch.stack(tensors, dim), backward

        def unbind(self,
                dim: int):
            def backward(grad_outputs: List[Tensor]):
                grad_self = torch.stack(grad_outputs, dim)
                return grad_self, None

            return torch.unbind(self, dim), backward

        def cat(tensors: List[Tensor],
                dim: int):
            size = len(tensors)
            split_sizes = [0] * size
            for i in range(size):
                if tensors[i].numel() > 0:
                    split_sizes[i] = tensors[i].size()[dim]

            def backward(grad_output):
                grad_tensors = torch.split_with_sizes(grad_output, split_sizes, dim)
                return grad_tensors, None

            return torch.cat(tensors, dim), backward

        def index(self,
                indices: List[Tensor]):
            def backward(grad_output):
                grad_self = torch.zeros_like(self, memory_format=1).index_put_(indices, grad_output, True)
                return grad_self, None

            return torch.index(self, indices), backward

        def meshgrid(tensors: List[Tensor]):
            size = len(tensors)
            sizes = [0] * size
            for i in range(size):
                if tensors[i].dim() != 0:
                    sizes[i] = tensors[i].size()[0]
            def backward(grad_outputs: List[Tensor]):
                grads_tensors = []
                for i in range(size):
                    view_shape = [1] * size
                    if sizes[i] == 0:
                        view_shape[i] = 1
                        grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape(()))
                    else:
                        view_shape[i] = sizes[i]
                        grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape([sizes[i]]))
                return grads_tensors
            return torch.meshgrid(tensors), backward

        def mv(self, vec):
            def backward(grad_output):
                return grad_output.ger(vec), self.t().mv(grad_output)

            return torch.mv(self, vec), backward

        def nonzero(self):
            def backward(grad_output):
                return None

            return torch.nonzero(self), backward

        def ones_like(self):
            def backward(grad_output):
                return None

            return torch.ones_like(self, memory_format=1), backward

        def pow_0(self,
                exponent: number):
            def backward(grad_output):
                if float(exponent) == 0.0:
                    grad_self = torch.zeros_like(self, memory_format=1)
                else:
                    grad_self = grad_output * exponent * torch.pow(self, float(exponent) - 1)
                return grad_self, None

            return torch.pow(self, exponent), backward

        def pow_1(self, exponent):
            result = torch.pow(self, exponent)
            self_size = torch._size_if_not_equal(self.size(), result.size())
            exponent_size = torch._size_if_not_equal(exponent.size(), result.size())

            def backward(grad_output):
                grad_self = torch.where(exponent == 0.0, torch.zeros_like(self, memory_format=1), grad_output * exponent * torch.pow(self, exponent - 1))._grad_sum_to_size(self_size)
                grad_exponent = (grad_output * torch.pow(self, exponent) * torch.log(self))._grad_sum_to_size(exponent_size)
                return grad_self, grad_exponent

            return result, backward

        def pow_2(self: number,
                exponent):
            def backward(grad_output):
                grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(float(self))
                return None, grad_exponent

            return torch.pow(self, exponent), backward

        def rsub_0(self,
                other,
                alpha: number):
            result = torch.rsub(self, other, alpha=alpha)
            self_size = torch._size_if_not_equal(self.size(), result.size())
            other_size = torch._size_if_not_equal(other.size(), result.size())
            def backward(grad_output):
                grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
                grad_other = (grad_output)._grad_sum_to_size(other_size)
                return grad_self, grad_other, None

            return result, backward

        def rsub_1(self,
                other: number,
                alpha: number):
            def backward(grad_output):
                grad_self = (- grad_output * alpha)
                return grad_self, None, None

            return torch.rsub(self, other, alpha), backward

        def sqrt(self):
            result = torch.sqrt(self)
            def backward(grad_output):
                return grad_output / (2 * result)

            return result, backward

        def t(self):
            def backward(grad_output):
                return torch.t(grad_output)

            return torch.t(self), backward

        def to_0(self,
                device: Optional[Device],
                dtype: Optional[int],
                non_blocking: bool,
                copy: bool):
            self_device = self.device
            self_dtype = self.dtype
            if device is not None:
                result = self.to(device, dtype=dtype, non_blocking=non_blocking, copy=copy)
            else:
                result = self.to(dtype, non_blocking=non_blocking, copy=copy)
            def backward(grad_output):
                grad_self = grad_output.to(self_device, dtype=self_dtype, non_blocking=non_blocking, copy=copy)
                return grad_self, None, None, None, None

            return result, backward

        def to_1(self,
                dtype: int,
                non_blocking: bool,
                copy: bool):
            self_dtype = self.dtype
            def backward(grad_output):
                grad_self = grad_output.to(self_dtype, non_blocking, copy)
                return grad_self, None, None, None

            return self.to(dtype=dtype, non_blocking=non_blocking, copy=copy), backward

        def to_2(self,
                other,
                non_blocking: bool,
                copy: bool):
            def backward(grad_output):
                grad_self = grad_output.to(self, non_blocking, copy)
                return grad_self, None, None, None

            return self.to(other, non_blocking=non_blocking, copy=copy), backward

        def transpose(self,
                dim0: int,
                dim1: int):
            def backward(grad_output):
                return torch.transpose(grad_output, dim0, dim1), None, None

            return torch.transpose(self, dim0, dim1), backward

        def view(self,
                size: List[int]):
            self_size = self.size()
            def backward(grad_output):
                return grad_output.reshape(self_size), None

            return torch.view(self, size), backward
    )",
    R"(
        def AD_sizes_if_not_equal_multi_0(t1, t2, res):
            return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size())

        def mul_0(self, other):
            result = self * other
            self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)

            def backward(grad_output):
                grad_self = (grad_output * other)._grad_sum_to_size(self_size)
                grad_other = (grad_output * self)._grad_sum_to_size(other_size)
                return grad_self, grad_other

            return result, backward

        def mul_1(self, other: number):
            def backward(grad_output):
                return grad_output * other, None
            return self * other, backward

        def div_0(self, other):
            result = self / other
            self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)

            def backward(grad_output):
                grad_self = (grad_output / other)._grad_sum_to_size(self_size)
                grad_other = (-grad_output * self / (other * other))._grad_sum_to_size(other_size)
                return grad_self, grad_other

            return result, backward

        def div_1(self, other: number):
            def backward(grad_output):
                return grad_output / other, None
            return self / other, backward

        def div_2(self, other, *, rounding_mode: Optional[str]):
            result = torch.div(self, other, rounding_mode=rounding_mode)
            self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
            def backward(grad_output):
                if rounding_mode is None:
                    grad_self = (grad_output / other)._grad_sum_to_size(self_size)
                    grad_other = (-grad_output * self / (other * other))._grad_sum_to_size(other_size)
                else:
                    grad_self = torch.zeros_like(self)
                    grad_other = torch.zeros_like(other)

                return grad_self, grad_other, None

            return result, backward

        def div_3(self, other: number, *, rounding_mode: Optional[str]):
            result = torch.div(self, other, rounding_mode=rounding_mode)
            def backward(grad_output):
                if rounding_mode is None:
                    grad_self = (grad_output / other)
                else:
                    grad_self = torch.zeros_like(self, memory_format=1)
                return grad_self, None, None
            return result, backward

        def max(self, other):
            result = torch.max(self, other)
            self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)

            def backward(grad_output):
                grad_self = (grad_output * (self > other).type_as(grad_output))._grad_sum_to_size(self_size)
                grad_other = (grad_output * (other > self).type_as(grad_output))._grad_sum_to_size(other_size)
                return grad_self, grad_other

            return result, backward

        def min(self, other):
            def backward(grad_output):
                grad_self = (grad_output * (self < other).type_as(grad_output))._grad_sum_to_size(self.size())
                grad_other = (grad_output * (other < self).type_as(grad_output))._grad_sum_to_size(other.size())
                return grad_self, grad_other

            return torch.min(self, other), backward

        def sigmoid(self):
            result = torch.sigmoid(self)
            def backward(grad_output):
                return (1 - result) * result * grad_output

            return result, backward

        # Share backward with threshold
        def relu(self):
            result = torch.relu(self)
            def backward(grad_output):
                return grad_output * (result > 0).type_as(result)

            return result, backward

        def erfc(self):
            def backward(grad_output):
                # Precomputed constant C = -2.0 / math.sqrt(math.pi)
                C = -1.1283791670955126
                return C * torch.exp(-self * self) * grad_output

            return torch.erfc(self), backward

        def exp(self):
            result = torch.exp(self)
            def backward(grad_output):
                return grad_output * result

            return result, backward

        def neg(self):
            def backward(grad_output):
                return grad_output.neg()

            return torch.neg(self), backward

        def where(condition, self, other):
            result = torch.where(condition, self, other)
            self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
            def backward(grad_output):
                grad_self = (grad_output * condition.type_as(grad_output))._grad_sum_to_size(self_size)
                grad_other = (grad_output * (condition.bitwise_not()).type_as(grad_output))._grad_sum_to_size(other_size)
                return None, grad_self, grad_other

            return result, backward

        def type_as(self, other):
            def backward(grad_output):
                return grad_output.type_as(self), None

            return torch.type_as(self, other), backward

        def unsqueeze(self, dim: int):
            def backward(grad_output):
                return grad_output.squeeze(dim), None

            return torch.unsqueeze(self, dim), backward

        def abs(self):
            def backward(grad_output):
                return grad_output * self.sign()

            return torch.abs(self), backward

        def acos(self):
            def backward(grad_output):
                return grad_output * -((-self * self + 1).rsqrt())

            return torch.acos(self), backward

        def asin(self):
            def backward(grad_output):
                return grad_output * (-self * self + 1).rsqrt()

            return torch.asin(self), backward

        def atan(self):
            def backward(grad_output):
                return grad_output / (self * self + 1)

            return torch.atan(self), backward

        def ceil(self):
            def backward(grad_output):
                return torch.zeros_like(grad_output, memory_format=1)

            return torch.ceil(self), backward

        def cos(self):
            def backward(grad_output):
                return grad_output * -self.sin()

            return torch.cos(self), backward

        def cosh(self):
            def backward(grad_output):
                return grad_output * self.sinh()

            return torch.cosh(self), backward

        def expm1(self):
            result = torch.expm1(self)
            def backward(grad_output):
                return grad_output * (result + 1)

            return result, backward

        def floor(self):
            def backward(grad_output):
                return torch.zeros_like(grad_output, memory_format=1)

            return torch.floor(self), backward

        def frac(self):
            def backward(grad_output):
                return grad_output

            return torch.frac(self), backward

        def log(self):
            def backward(grad_output):
                return grad_output.div(self)

            return torch.log(self), backward

        def log10(self):
            def backward(grad_output):
                return grad_output / (self * 2.3025850929940456)

            return torch.log10(self), backward

        def log1p(self):
            def backward(grad_output):
                return grad_output / (self + 1)

            return torch.log1p(self), backward

        def log2(self):
            def backward(grad_output):
                return grad_output / (self * 0.6931471805599453)

            return torch.log2(self), backward

        def rand_like(self, *, memory_format: Optional[int]):
            def backward(grad_output):
                return None

            return torch.rand_like(self, memory_format=memory_format), backward

        def reciprocal(self):
            result = torch.reciprocal(self)
            def backward(grad_output):
                return -grad_output * result * result

            return result, backward

        def round(self):
            def backward(grad_output):
                return torch.zeros_like(grad_output, memory_format=1)

            return torch.round(self), backward

        def rsqrt(self):
            result = torch.rsqrt(self)
            def backward(grad_output):
                return -grad_output * result * result * result / 2

            return result, backward

        def sin(self):
            def backward(grad_output):
                return grad_output * self.cos()

            return torch.sin(self), backward

        def sinh(self):
            def backward(grad_output):
                return grad_output * self.cosh()

            return torch.sinh(self), backward

        def tan(self):
            result = torch.tan(self)
            def backward(grad_output):
                return grad_output * (1. + result * result)

            return result, backward

        def trunc(self):
            def backward(grad_output):
                return torch.zeros_like(grad_output, memory_format=1)

            return torch.trunc(self), backward

        def _grad_sum_to_size(self,
                size: Optional[List[int]]):
            result = torch._grad_sum_to_size(self, size)
            self_size = torch._size_if_not_equal(self.size(), result.size())

            def backward(grad_output):
                if self_size is None:
                    grad_input = grad_output
                else:
                    grad_input = grad_output.expand(self_size)
                return grad_input, None

            return result, backward
    )",
    R"(
        def batch_norm_disabled(input : Tensor,
                weight : Optional[Tensor],
                bias : Optional[Tensor],
                running_mean : Optional[Tensor],
                running_var : Optional[Tensor],
                training : bool,
                momentum : float,
                eps : float,
                cudnn_enabled : bool):

            output, save1, save2, reserve, impl_idx = torch._batch_norm_impl_index(
                input, weight, bias, running_mean, running_var, training,
                momentum, eps, cudnn_enabled)
            has_weight = weight is not None
            has_bias = bias is not None

            def backward(grad_output):
                dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
                    impl_idx, input, grad_output, weight, running_mean, running_var,
                    save1, save2, training, eps, [True, has_weight, has_bias], reserve)
                return dinput, dweight, dbias, None, None, None, None, None, None

            return output, backward

        # disable the layernorm AD temporarily because of bug in https://github.com/pytorch/pytorch/issues/19769
        def layer_norm_disabled(input : Tensor,
                normalized_shape : List[int],
                weight : Optional[Tensor],
                bias : Optional[Tensor],
                eps : float,
                cudnn_enable : bool):

            input_ndim = input.dim()
            normalized_ndim = len(normalized_shape)
            n = 1
            for i in range(input_ndim - normalized_ndim):
                n *= input.size(i)

            input_reshape = input.contiguous().view(1, n, -1)

            bn_out, save1, save2, reserve, impl_idx = torch._batch_norm_impl_index(
                input_reshape, None, None, None, None, True,
                0.0, eps, cudnn_enable)

            bn_out = bn_out.view(input.size())
            if weight is not None and bias is not None:
                output = bias.addcmul(bn_out, weight, value=1)
            elif weight is not None:
                output = bn_out.mul(weight)
            elif bias is not None:
                output = bn_out.add(bias)
            else:
                output = bn_out

            def backward(grad_output):
                if weight is not None and bias is not None:
                    grad_bn_out = grad_output * weight
                    grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
                    grad_bias = grad_output._grad_sum_to_size(bias.size())
                elif weight is not None:
                    grad_bn_out = grad_output * weight
                    grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
                    grad_bias = None
                elif bias is not None:
                    grad_bn_out = grad_output
                    grad_weight = None
                    grad_bias = grad_output._grad_sum_to_size(bias.size())
                else:
                    grad_bn_out = grad_output
                    grad_weight = None
                    grad_bias = None

                grad_bn_out = grad_bn_out.contiguous().view(1, n, -1)

                grad_input, _, _ = torch._batch_norm_impl_index_backward(
                    impl_idx, input_reshape, grad_bn_out, None, None, None,
                    save1, save2, True, eps, [True, False, False], reserve)

                grad_input = grad_input.view(input.size())
                return grad_input, None, grad_weight, grad_bias, None, None

            return output, backward

        def AD_fused_dropout_backward(grad,
                mask,
                p1m: float):
            p1r = 1. / p1m
            grad_input = grad * (mask.type_as(grad) * p1r)
            return grad_input

        def dropout(input,
                p: float,
                train: bool):
            use_cuda = input.is_cuda
            # lowering is specialized for cuda because cuda fuser can efficiently fuse those operations
            # for cpu backend, where fusions are disabled, a different lowering that is more efficient
            # in the absence of fusion is used
            p1m = 1. - p
            if train:
                if use_cuda:
                    mask = torch.rand_like(input, memory_format=1) < p1m
                    res = mask.type_as(input) * input * (1./p1m)
                else:
                    mask = torch.empty_like(input, memory_format=1)
                    mask.bernoulli_(p1m)
                    res = mask * input / p1m
            else:
                p1m = 1.
                res = input
                mask = torch.empty_like(input, memory_format=1)

            def backward(grad_output):
                use_cuda = grad_output.is_cuda
                if use_cuda:
                    grad_input = AD_fused_dropout_backward(grad_output, mask, p1m)
                else:
                    grad_input = grad_output * mask / p1m
                return grad_input, None, None
            return res, backward

        def embedding(weight,
                indices,
                padding_idx: int,
                scale_grad_by_freq: bool,
                sparse: bool):
            weight_size_0 = weight.size()[0]
            def backward(grad_output):
                grad_weight = torch.embedding_backward(grad_output, indices, weight_size_0, padding_idx, scale_grad_by_freq, sparse)
                return grad_weight, None, None, None, None

            return torch.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse), backward

        def log_softmax(self, dim: int, dtype: Optional[int]):
            result = torch.log_softmax(self, dim, dtype)
            def backward(grad_output):
                grad_self = torch._log_softmax_backward_data(grad_output, result, dim, self)
                return grad_self, None, None

            return result, backward

        def nll_loss(self, target, weight: Optional[Tensor], reduction: int, ignore_index: int):
            result, total_weight = torch.nll_loss_forward(self, target, weight, reduction, ignore_index)
            def backward(grad):
                return torch.nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight), None, None, None, None
            return result, backward

        def softmax(self, dim: int, dtype: Optional[int]):
            result = torch.softmax(self, dim, dtype)
            def backward(grad_output):
                grad_self = torch._softmax_backward_data(grad_output, result, dim, self)
                return grad_self, None, None

            return result, backward
    )",
R"(
|
|
def AD_adaptive_avg_pool3d_backward(grad,
|
|
self,
|
|
output_size: List[int]):
|
|
if output_size[0] == 1 and output_size[1] == 1 and output_size[2] == 1:
|
|
self_size = self.size()
|
|
grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2] * self_size[-3])
|
|
else:
|
|
grad_self = torch._adaptive_avg_pool3d_backward(grad, self)
|
|
|
|
return grad_self
|
|
|
|
def AD_adaptive_avg_pool2d_backward(grad,
|
|
self,
|
|
output_size: List[int]):
|
|
if output_size[0] == 1 and output_size[1] == 1:
|
|
self_size = self.size()
|
|
grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2])
|
|
else:
|
|
grad_self = torch._adaptive_avg_pool2d_backward(grad, self)
|
|
|
|
return grad_self
|
|
|
|
        def AD_adaptive_avg_pool1d_backward(grad,
                                            input,
                                            output_size: List[int]):
            output_size_2d = [1, output_size[0]]
            grad_input = AD_adaptive_avg_pool2d_backward(grad.unsqueeze(2), input.unsqueeze(2), output_size_2d).squeeze(2)
            return grad_input

        def adaptive_avg_pool1d(self,
                                output_size: List[int]):
            def backward(grad_output):
                grad_self = AD_adaptive_avg_pool1d_backward(grad_output, self, output_size)
                return grad_self, None

            return torch.adaptive_avg_pool1d(self, output_size), backward

        def adaptive_avg_pool2d(self,
                                output_size: List[int]):
            def backward(grad_output):
                # self is used in backward, no need to pass in its size explicitly
                grad_self = AD_adaptive_avg_pool2d_backward(grad_output, self, output_size)
                return grad_self, None
            return torch.adaptive_avg_pool2d(self, output_size), backward

        def adaptive_avg_pool3d(self,
                                output_size: List[int]):
            def backward(grad_output):
                grad_self = AD_adaptive_avg_pool3d_backward(grad_output, self, output_size)
                return grad_self, None

            return torch.adaptive_avg_pool3d(self, output_size), backward

        def avg_pool2d(self,
                       kernel_size: List[int],
                       stride: List[int],
                       padding: List[int],
                       ceil_mode: bool,
                       count_include_pad: bool,
                       divisor_override: Optional[int]):
            def backward(grad_output):
                grad_self = torch.avg_pool2d_backward(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
                return grad_self, None, None, None, None, None, None

            return torch.avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override), backward

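        # max_pool2d is routed through max_pool2d_with_indices so that the saved
        # indices can be reused by the backward kernel.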
        def max_pool2d(self,
                       kernel_size: List[int],
                       stride: List[int],
                       padding: List[int],
                       dilation: List[int],
                       ceil_mode: bool):
            output, indices = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
            def backward(grad_output):
                grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
                return grad_self, None, None, None, None, None
            return output, backward

        def max_pool2d_with_indices(self,
                                    kernel_size: List[int],
                                    stride: List[int],
                                    padding: List[int],
                                    dilation: List[int],
                                    ceil_mode: bool):
            output, indices = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
            def backward(grad_output):
                grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
                return grad_self, None, None, None, None, None
            return output, indices, backward
    )",
    R"(
        def AD_sizes_if_not_equal_multi_1(t1, t2, res):
            return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size())

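        # For broadcasting binary ops, AD_sizes_if_not_equal_multi_1 records an operand's
        # size only when it differs from the result's size; the backward then uses
        # _grad_sum_to_size to reduce the gradient back to that operand's shape.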
        def add_0(self,
                  other,
                  *,
                  alpha: number):
            result = torch.add(self, other, alpha=alpha)
            self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
            def backward(grad_output):
                grad_other = (grad_output * alpha)._grad_sum_to_size(other_size)
                grad_self = (grad_output)._grad_sum_to_size(self_size)
                return grad_self, grad_other, None
            return result, backward

        def add_1(self,
                  other: number,
                  alpha: number):
            def backward(grad_output):
                return grad_output, None, None
            return torch.add(self, other, alpha=alpha), backward

        def sub_0(self,
                  other,
                  *,
                  alpha: number):
            result = torch.sub(self, other, alpha=alpha)
            self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
            def backward(grad_output):
                grad_other = (-grad_output * alpha)._grad_sum_to_size(other_size)
                grad_self = (grad_output)._grad_sum_to_size(self_size)
                return grad_self, grad_other, None
            return result, backward

        def sub_1(self,
                  other: number,
                  alpha: number):
            def backward(grad_output):
                return grad_output, None, None
            return torch.sub(self, other, alpha=alpha), backward

        def threshold(self,
                      threshold: number,
                      value: number):
            def backward(grad_output):
                mask = (self >= threshold).type_as(self)
                return grad_output * mask, None, None
            return torch.threshold(self, threshold, value), backward

        def fmod(self,
                 other: number):
            def backward(grad_output):
                return grad_output, None
            return torch.fmod(self, other), backward

        def remainder(self,
                      other: number):
            def backward(grad_output):
                return grad_output, None
            return torch.remainder(self, other), backward

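        # addmm computes beta * self + alpha * (mat1 @ mat2), so:
        #   self_grad = beta * grad (summed back to self's shape if it was broadcast)
        #   mat1_grad = alpha * grad @ mat2^T
        #   mat2_grad = alpha * mat1^T @ grad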
        def addmm(self,
                  mat1,
                  mat2,
                  *,
                  beta: number,
                  alpha: number):
            result = torch.addmm(self, mat1, mat2, beta=beta, alpha=alpha)
            self_size = torch._size_if_not_equal(self.size(), result.size())
            def backward(grad_output):
                self_grad = (grad_output * beta)._grad_sum_to_size(self_size)
                mat1_grad = grad_output.mm(mat2.t()) * alpha
                mat2_grad = mat1.t().mm(grad_output) * alpha
                return self_grad, mat1_grad, mat2_grad, None, None
            return result, backward

        # Comparison operators
        def lt(self, other: number):
            def backward(grad_output):
                return None, None
            return torch.lt(self, other), backward

        def le(self, other: number):
            def backward(grad_output):
                return None, None
            return torch.le(self, other), backward

        def gt(self, other: number):
            def backward(grad_output):
                return None, None
            return torch.gt(self, other), backward

        def ge(self, other: number):
            def backward(grad_output):
                return None, None
            return torch.ge(self, other), backward

        def eq(self, other: number):
            def backward(grad_output):
                return None, None
            return torch.eq(self, other), backward

        def ne(self, other: number):
            def backward(grad_output):
                return None, None
            return torch.ne(self, other), backward

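        # clamp only passes gradient through positions that were not clipped, i.e.
        # where min <= self <= max for whichever bounds are provided.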
        def clamp(self,
                  min: Optional[number],
                  max: Optional[number]):
            def backward(grad_output):
                if min is not None and max is not None:
                    mask = ((self >= float(min)) * (self <= float(max))).type_as(self)
                    return grad_output * mask, None, None
                elif min is not None:
                    mask = (self >= float(min)).type_as(self)
                    return grad_output * mask, None, None
                elif max is not None:
                    mask = (self <= float(max)).type_as(self)
                    return grad_output * mask, None, None
                else:  # min is None and max is None
                    return grad_output, None, None
            return torch.clamp(self, min=min, max=max), backward
    )"};

std::unordered_map<std::string, GradientPair> schema_to_graphs;

// This map is a workaround to cache compiled gradient_pairs. Ideally this graph
// should be compiled only once and saved in the Operator structure.
// This should be done along with merging into native_functions.yaml.
std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;

// CompilationUnit that holds all these Functions and keeps them alive.
CompilationUnit compilation_unit;
} // anonymous namespace

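// Each TorchScript AD formula returns (outputs..., backward), where backward is the
// tuple (prim::Closure, captured context). extractClosure splits that tuple into the
// backward graph and its context value.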
std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
  TORCH_CHECK(
      closure->node()->kind() == prim::TupleConstruct,
      "closure must be a literal tuple construct");
  Value* fn = closure->node()->inputs().at(0);
  Value* context = closure->node()->inputs().at(1);

  TORCH_CHECK(
      fn->node()->kind() == prim::Closure,
      "closure tuple must contain a prim::Closure");
  return std::make_pair(fn->node()->g(attr::Subgraph), context);
}

Argument originalReturnType(const TupleTypePtr& tup) {
  TORCH_CHECK(tup->elements().size() > 1);
  if (tup->elements().size() == 2)
    return Argument("", tup->elements().at(0));
  std::vector<TypePtr> types = tup->elements().vec();
  types.pop_back();
  return Argument("", TupleType::create(std::move(types)));
}

// In TorchScript AD formulas, we define {func_0, func_1, ...} as
// overloaded functions of `func`.
// Remove the suffix before adding the schema string to the
// schema_to_graphs map.
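// For example, the schemas generated from `add_0` and `add_1` above are both
// registered under the name `add`, with the numeric `_0` / `_1` suffix stripped.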
std::string overloadedSchemaString(const FunctionSchema& schema) {
  const auto& schema_name = schema.name();
  auto pos = schema_name.find_last_of('_');
  auto schema_name_suffix = schema_name.substr(pos + 1);
  std::string schema_string = canonicalSchemaString(schema);
  if (!schema_name_suffix.empty() &&
      schema_name_suffix.find_first_not_of("0123456789") == std::string::npos) {
    schema_string.replace(
        schema_string.find(schema_name),
        schema_name.length(),
        schema_name.substr(0, pos));
  }

  return schema_string;
}

bool isHelperFunction(const std::string& method_name) {
  std::string helper_prefix = "AD_";
  return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0;
}

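// Registers a forward/backward GradientPair for every non-helper function in the
// compiled module, keyed by the (de-overloaded) canonical schema string of the op.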
void loadModule(const CompilationUnit& module) {
  for (const auto& method : module.get_functions()) {
    if (isHelperFunction(method->name()))
      continue;

    GradientPair pair;
    pair.forward = method->graph();

    // lookup the backward function
    Node* forward_tuple = pair.forward->outputs().at(0)->node();

    if (forward_tuple->kind() != prim::TupleConstruct) {
      throw ErrorReport(forward_tuple->sourceRange())
          << "gradient must return a literal tuple";
    }

    Value* context;
    std::tie(pair.backward, context) =
        extractClosure(forward_tuple->inputs().back());

    // do surgery on the forward function to remove the closure tuple and
    // replace it with the context variable:
    //   backward = (<lambda>, context_tuple)
    //   return original, backward
    //   -----
    //   return original, context_tuple
    std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
    new_inputs.back() = context;
    Value* new_tuple =
        pair.forward->appendNode(pair.forward->createTuple(new_inputs))
            ->output();
    pair.forward->eraseOutput(0);
    pair.forward->registerOutput(new_tuple);
    forward_tuple->destroy();

    // derive schema from original function's schema:
    const FunctionSchema& loaded_schema = method->getSchema();
    FunctionSchema actual_schema(
        Symbol::aten(loaded_schema.name()),
        loaded_schema.overload_name(),
        loaded_schema.arguments(),
        {originalReturnType(new_tuple->type()->expect<TupleType>())});

    // modify canonical string for function overloading
    // prefer not to modify the schema name
    auto schema_string = overloadedSchemaString(actual_schema);

    schema_to_graphs[schema_string] = std::move(pair);
  }
}

void loadFunctions() {
  for (const std::string& str : functions) {
    compilation_unit.define(c10::nullopt, str, nativeResolver(), nullptr);
  }
  loadModule(compilation_unit);
}

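// Lookup is lazy: the TorchScript sources above are compiled on the first query,
// and per-schema results are cached in cached_gradient_pairs.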
c10::optional<GradientPair> gradientInfoForSchema(
    const FunctionSchema& schema) {
  std::lock_guard<std::mutex> guard(lock);
  if (schema_to_graphs.size() == 0) {
    loadFunctions();
  }
  auto cache_it = cached_gradient_pairs.find(&schema);
  if (cache_it != cached_gradient_pairs.end()) {
    return cache_it->second;
  } else {
    auto schema_str = canonicalSchemaString(schema);
    // For debugging AD changes:
    // std::cout << "Looking for " << schema_str << std::endl;

    auto sym_script_it = schema_to_graphs.find(schema_str);

    if (sym_script_it != schema_to_graphs.end()) {
      cached_gradient_pairs.emplace_hint(
          cache_it, &schema, sym_script_it->second);
      return sym_script_it->second;
    }
  }
  return c10::nullopt;
}

bool hasGradientInfoForSchema(const FunctionSchema& schema) {
  return gradientInfoForSchema(schema).has_value();
}

} // namespace jit
} // namespace torch