Commit Graph

40 Commits

Author SHA1 Message Date
jjsjann123
1ec732bc46 Add fp16/fp32 autocasting to JIT/TorchScript (#63939)
Summary:
Adds mixed-precision autocasting support between fp32/fp16 to TorchScript/JIT. A more in-depth description can be found at [torch/csrc/jit/JIT-AUTOCAST.md](https://github.com/pytorch/pytorch/pull/63939/files#diff-1f1772aaa508841c5bb58b74ab98f49a1e577612cd9ea5c386c8714a75db830b)

This PR implements an autocast optimization pass (torch/csrc/jit/passes/autocast.cpp) that inserts casting ops per the AMP rules, mimicking the behavior of eager autocast. The pass also takes the `torch.cuda.amp.autocast` context into account and only inserts casting ops within an enabled context manager, giving feature parity with eager AMP autocast.

We currently provide JIT AMP autocast as a prototype feature, so it is off by default and can be turned on via `torch._C._jit_set_autocast_mode(True)`.
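For reference, a minimal sketch of how the prototype is meant to be exercised (roughly following the pattern documented in JIT-AUTOCAST.md; the function name and shapes below are illustrative, and a CUDA device is assumed):

```
import torch
from torch.cuda.amp import autocast

# The feature is off by default; enable it for the JIT first.
torch._C._jit_set_autocast_mode(True)

@torch.jit.script
def mm_autocast(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Casts are inserted only inside the enabled autocast region.
    with autocast():
        return torch.mm(a, b)

if torch.cuda.is_available():
    a = torch.rand(8, 8, device="cuda")
    b = torch.rand(8, 8, device="cuda")
    print(mm_autocast(a, b).dtype)  # expected to be torch.float16 under autocast
```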

The JIT support for autocast is subject to different constraints compared to the eager-mode implementation (mostly related to the fact that TorchScript is statically typed); the restrictions on user-facing Python code are described in torch/csrc/jit/JIT-AUTOCAST.md.

This is a prototype; there are also implementation limitations that are necessary to keep this PR small and get something functioning quickly upstream, so we can iterate on designs.

A few limitations/challenges that are not properly resolved in this PR:
1. Autocast inserts cast operations, which affect the scalar type of output tensors feeding downstream operations. We are not currently propagating the updated scalar types, which could give wrong results for operations governed by type-promotion rules.

2. The autodiff backward in JIT misses the cast of dgrad to the input scalar type that eager autograd performs. This forces us to explicitly mark the casting operation for certain operations (e.g. binary ops); otherwise we might feed a dgrad whose scalar type does not match the input, which could break gradient functions consuming dgrad (e.g. gemm backward, which assumes grad_output has the same scalar type as the input).

3. The `torch.autocast` API has an optional `dtype` argument which is not currently supported in JIT autocast; we require a static value.

Credit goes mostly to:
tlemo
kevinstephano

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63939

Reviewed By: navahgar

Differential Revision: D31093381

Pulled By: eellison

fbshipit-source-id: da6e26c668c38b01e296f304507048d6c1794314
2021-10-27 12:11:36 -07:00
Zhengxu Chen
b55a2500d2 [jit] Remove graph() call from abstract Function interface. (#65967)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65967

Graph is an implementation detail. If a user wants access to the underlying graph, they should explicitly dynamic-cast instead.
ghstack-source-id: 141659819

Test Plan: no behavior change.

Reviewed By: gmagogsfm

Differential Revision: D31326153

fbshipit-source-id: a0e984f57c6013494b92a7095bf5bb660035eb84
2021-10-27 11:54:26 -07:00
Natalia Gimelshein
7d9bbd3596 Revert D31580382: [pytorch][PR] dropout update in autodiff
Test Plan: revert-hammer

Differential Revision:
D31580382 (eb8138d886)

Original commit changeset: 41d15da99bf4

fbshipit-source-id: 59f751ee59602a5fd09c17f8c7565dca5e2beb50
2021-10-13 19:52:05 -07:00
jiej
eb8138d886 dropout update in autodiff (#66273)
Summary:
1. Unifies dropout op in autodiff
2. Removes dropout inference support in autodiff

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66273

Reviewed By: jbschlosser, gmagogsfm

Differential Revision: D31580382

Pulled By: eellison

fbshipit-source-id: 41d15da99bf4ce6c47cc335a4156c4a1c9705a70
2021-10-13 16:23:40 -07:00
jjsjann123
d85948896c Add softplus support to autodiff (#63942)
Summary:
Add softplus definition to autodiff.

cc gmagogsfm
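
For context, the symbolic gradient relies on the identity d/dx softplus(x) = sigmoid(x) in the non-linear region; a small eager-mode sanity check, not taken from the PR:

```
import torch

x = torch.randn(16, dtype=torch.double, requires_grad=True)
y = torch.nn.functional.softplus(x)
(g,) = torch.autograd.grad(y.sum(), x)

# The gradient of softplus matches sigmoid of the input.
assert torch.allclose(g, torch.sigmoid(x))
```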

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63942

Reviewed By: ngimel

Differential Revision: D31397158

Pulled By: eellison

fbshipit-source-id: f7db547370f82e5e282505c3c8415fb4fbd86d54
2021-10-13 08:08:09 -07:00
soulitzer
4cdfceddd2 [Reland] Avoid saving self for softmax and log_softmax (#66018)
Summary:
Reland of https://github.com/pytorch/pytorch/pull/65242

The last reland attempt was automatically rebased onto stable, which did not yet contain the revert commit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66018

Reviewed By: albanD

Differential Revision: D31348822

Pulled By: soulitzer

fbshipit-source-id: 881d701b404530c1352ac9245bd67264e1652b8a
2021-10-03 21:35:01 -07:00
Michael Suo
9ae63bd87c Revert D31238123: [pytorch][PR] Avoid saving self for softmax and log_softmax
Test Plan: revert-hammer

Differential Revision:
D31238123 (fb412bdd80)

Original commit changeset: afd319d3676d

fbshipit-source-id: b7980d653a4b8322a225f1dd08c2857ecbe5bc94
2021-09-30 11:34:14 -07:00
soulitzer
fb412bdd80 Avoid saving self for softmax and log_softmax (#65242)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/64000
 - updates the double backward formula to compute the grad w.r.t. the output instead of `self` (the identity this relies on is sketched below)
 - ~~In some of the error messages, we still refer to the dtype of the input, even though we are now checking the dtype of the output~~
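
The key identity that makes this possible is that the softmax backward can be written purely in terms of the output; a small illustrative check (not the exact derivatives.yaml formula):

```
import torch

def softmax_backward_from_output(grad_output, output, dim):
    # dL/dx = y * (g - sum(g * y, dim)) -- only the output y is needed.
    return output * (grad_output - (grad_output * output).sum(dim, keepdim=True))

x = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
y = torch.softmax(x, dim=1)
g = torch.randn_like(y)

(ref,) = torch.autograd.grad(y, x, g)
assert torch.allclose(ref, softmax_backward_from_output(g, y.detach(), dim=1))
```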

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65242

Reviewed By: albanD

Differential Revision: D31238123

Pulled By: soulitzer

fbshipit-source-id: afd319d3676d9ef8d81607e0e8c2a3e6d09f68e4
2021-09-29 18:16:12 -07:00
Natalia Gimelshein
09eb3e661c don't check 0 elements for cat symbolic diff (#65751)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65751

Fixes symbolic script grad formula for cat to correctly handle empty tensors

Test Plan: Existing tests

Reviewed By: eellison

Differential Revision: D31208364

fbshipit-source-id: d676d9abcc033b56076fa946f58f3db50034502d
2021-09-29 09:34:03 -07:00
Xiaodong Wang
6d58c83007 Turn off layer norm in jit symbolic differentiation (#63816)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63816

Test Plan:
Confirmed this can rescue the NE:

https://www.internalfb.com/mast/job/torchx_xdwang-SparseNNApplication_72cf593d

Reviewed By: ngimel

Differential Revision: D30498746

fbshipit-source-id: 4a387f32ee2f70685de6104459c7f21bfbddc187
2021-08-24 15:47:13 -07:00
jiej
e926f75b0b BatchNorm autodiff re-enabled (#57321)
Summary:
Turns on BN in autodiff:

1. outputs an empty tensor for running stats to bypass the autodiff issue with None;
2. fixes BN inference backward in cuDNN & MIOpen, where backward now falls back to the native batchnorm kernel instead;

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57321

Reviewed By: albanD, ngimel

Differential Revision: D30250419

Pulled By: jansel

fbshipit-source-id: a62553789c20fb50a820003a056f40d9d642dfaa
2021-08-21 09:07:31 -07:00
jiej
ed0b8a3e83 LayerNorm Support in autodiff: (#50467)
Summary:
1. extends autodiff by adding an entry for layer_norm in symbolic script; we now use native_layer_norm_backward
2. adds a backward function `layernorm_double_backward` for `native_layer_norm_backward`, preserving double-backward support for LayerNorm in autodiff/ScriptModule (a sanity-check sketch follows below)
3. adds a Python test to verify autodiff on layer_norm with various configurations of optional tensors (verifying the fix in https://github.com/pytorch/pytorch/issues/49430)
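
A gradgradcheck-style sanity check of the double-backward behavior that point 2 preserves (an eager-mode sketch for illustration, not the PR's actual scripted test):

```
import torch

x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
w = torch.randn(8, dtype=torch.double, requires_grad=True)
b = torch.randn(8, dtype=torch.double, requires_grad=True)

def layer_norm_fn(x, w, b):
    return torch.nn.functional.layer_norm(x, (8,), w, b)

# Numerically check second-order gradients through layer_norm.
assert torch.autograd.gradgradcheck(layer_norm_fn, (x, w, b))
```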

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50467

Reviewed By: eellison

Differential Revision: D30232864

Pulled By: jansel

fbshipit-source-id: b9c33075386aff96afff7415df9f94388bfb474a

Co-authored-by: Ryan Spring <rspring@nvidia.com>
Co-authored-by: Jie <jiej@nvidia.com>
2021-08-12 11:05:53 -07:00
Nikita Shulga
a9b0a921d5 Disable avoid-non-const-global-variables lint check (#62008)
Summary:
The GoogleTest `TEST` macro is non-compliant with it, as is `DEFINE_DISPATCH`.

All changes except those to `.clang-tidy` were generated using the following script:
```
for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`;  do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008

Reviewed By: driazati, r-barnes

Differential Revision: D29838584

Pulled By: malfet

fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
2021-07-22 18:04:40 -07:00
Bert Maher
c3bf42e0d8 Fix symbolic derivative of hardswish (#59405)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59405

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D28879698

Pulled By: bertmaher

fbshipit-source-id: 2f2d9836bf592b18ed9a19aab4f5967e653b5898
2021-06-03 23:12:18 -07:00
Bert Maher
9ac954789d [nnc] Add hardsigmoid (#59069)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59069

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D28738166

Pulled By: bertmaher

fbshipit-source-id: d9f5b87ef1f2323a3631add79c2670ce794f911e
2021-06-03 23:10:36 -07:00
Bin Bao
7e4e648c2a Enable NNC fusion for relu6 (#58773)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58773

Test Plan:
```
python test/test_ops.py -k relu6
python test/test_jit_fuser_te.py
```

Reviewed By: bertmaher

Differential Revision: D28721791

Pulled By: desertfire

fbshipit-source-id: a94f711977afd080faae052f66eb8dded3cdc79e
2021-05-27 10:54:02 -07:00
Edvard Ghazaryan
ad97fd8031 Support symbolic diff for leaky_relu (#58337)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58337

Supports symbolic differentiation for leaky_relu.
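
The symbolic gradient boils down to passing the incoming gradient through where the input is positive and scaling it by the negative slope elsewhere; a small eager-mode illustration (not the symbolic_script formula itself):

```
import torch

slope = 0.01
x = torch.randn(16, dtype=torch.double, requires_grad=True)
y = torch.nn.functional.leaky_relu(x, slope)
(g,) = torch.autograd.grad(y.sum(), x)

# Gradient is 1 where x > 0 and `slope` elsewhere.
expected = torch.where(x > 0, torch.ones_like(x), torch.full_like(x, slope))
assert torch.allclose(g, expected)
```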

Test Plan:
test/test_jit.py
test/test_ops.py

Reviewed By: Krovatkin

Differential Revision: D28458898

fbshipit-source-id: bdde74d689d2c2ea1f59507456c2efa4e38de1cc
2021-05-18 14:13:40 -07:00
Nikolay Korovaiko
3072c97017 Gelu Backward, Contribution from Kevin Stephano (#58249)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58249

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D28425629

Pulled By: Krovatkin

fbshipit-source-id: 494ab165d548aa76f036344ab1c19c5fd64bae82
2021-05-13 19:39:39 -07:00
Nick Korovaiko
f3ead05d77 hardtanh (#57750)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57750

Test Plan: Imported from OSS

Reviewed By: huiguoo

Differential Revision: D28425975

fbshipit-source-id: a5e3dfbd6c77c595528c052e0b4325ef452983eb
2021-05-13 19:39:37 -07:00
Nick Korovaiko
c524448dd1 init hardshrink (#57749)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57749

add to a fx test

Test Plan: Imported from OSS

Reviewed By: huiguoo

Differential Revision: D28425974

fbshipit-source-id: 195c7a1944decb7a2a99c2831cab38485f32be17
2021-05-13 19:38:05 -07:00
Peter Bell
2043093217 Add correction parameter to std/var (#50903)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50903

First part of #50010. Also fixes #51127.
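
For context, `correction` generalizes the old `unbiased` flag: the variance divisor becomes N - correction, so correction=1 matches unbiased=True and correction=0 matches unbiased=False. A quick illustration using the keyword as it exists in current releases:

```
import torch

x = torch.randn(100, dtype=torch.double)

assert torch.allclose(torch.var(x, correction=1), torch.var(x, unbiased=True))
assert torch.allclose(torch.var(x, correction=0), torch.var(x, unbiased=False))

# An arbitrary correction simply changes the divisor to N - correction.
n = x.numel()
manual = ((x - x.mean()) ** 2).sum() / (n - 2)
assert torch.allclose(torch.var(x, correction=2), manual)
```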

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D27911345

Pulled By: mruberry

fbshipit-source-id: 7138fddc935802918ab9ff19f4bc1b9f4d745d41
2021-05-07 14:40:28 -07:00
Elias Ellison
7627dd568a hardswish reland (#57652)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57652

Test Plan: Imported from OSS

Reviewed By: Krovatkin

Differential Revision: D28226724

Pulled By: eellison

fbshipit-source-id: 585a91ffab7a855b5600e79130a37be25ef9b354
2021-05-05 17:21:43 -07:00
Shen Li
887d0e5657 Revert D28197820: [JIT][NNC] add hardswish symbolic gradient and NNC lowering
Test Plan: revert-hammer

Differential Revision:
D28197820 (0142fd0b57)

Original commit changeset: 05305d85c5bb

fbshipit-source-id: 2e1d9699515982ba2a9be06e83a2ce043ec857ee
2021-05-05 07:53:30 -07:00
eellison
0142fd0b57 [JIT][NNC] add hardswish symbolic gradient and NNC lowering (#57383)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57383

Notes: I picked up an activation from https://github.com/pytorch/pytorch/issues/56969. You can look at the [activations.cpp](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cpu/Activation.cpp#L429) file which has both forward and backward kernel code to help you write the NNC lowering and the symbolic gradient.

I added a test in test_jit_fuser_te for the fusion, and I added an OpInfo and asserted that we expect to see autodiffable nodes to test the symbolic gradient.
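
For reference, the hardswish derivative being encoded is 0 below -3, 1 above 3, and x/3 + 1/2 in between; a quick eager-mode check (illustrative only):

```
import torch

x = torch.randn(64, dtype=torch.double, requires_grad=True)
y = torch.nn.functional.hardswish(x)
(g,) = torch.autograd.grad(y.sum(), x)

expected = torch.where(
    x < -3, torch.zeros_like(x),
    torch.where(x > 3, torch.ones_like(x), x / 3 + 0.5),
)
assert torch.allclose(g, expected)
```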

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D28197820

Pulled By: eellison

fbshipit-source-id: 05305d85c5bb0847c8f911b95ba47b137dca7e90
2021-05-04 23:39:59 -07:00
Peter Bell
33eea146ee torch.clamp with tensor min and max (#52695)
Summary:
Fixes gh-2793
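
With tensor bounds, `min` and `max` broadcast against the input just like in a binary op; a brief illustration (shapes chosen arbitrarily):

```
import torch

x = torch.randn(4, 3)
lo = torch.zeros(3)            # broadcasts across rows
hi = torch.full((4, 1), 0.5)   # broadcasts across columns

y = torch.clamp(x, min=lo, max=hi)
assert torch.equal(y, torch.minimum(torch.maximum(x, lo), hi))
```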

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52695

Reviewed By: mruberry

Differential Revision: D27395977

Pulled By: ezyang

fbshipit-source-id: f86aa240feb034d42e4c45447e72218f6a773c24
2021-05-03 12:56:16 -07:00
Nikita Shulga
4cb534f92e Make PyTorch code-base clang-tidy compliant (#56892)
Summary:
This is an automatic change generated by the following script:
```
#!/usr/bin/env python3
from subprocess import check_output, check_call
import os

def get_compiled_files_list():
    import json
    with open("build/compile_commands.json") as f:
        data = json.load(f)
    files = [os.path.relpath(node['file']) for node in data]
    for idx, fname in enumerate(files):
        if fname.startswith('build/') and fname.endswith('.DEFAULT.cpp'):
            files[idx] = fname[len('build/'):-len('.DEFAULT.cpp')]
    return files

def run_clang_tidy(fname):
    check_call(["python3", "tools/clang_tidy.py", "-c", "build", "-x", fname,"-s"])
    changes = check_output(["git", "ls-files", "-m"])
    if len(changes) == 0:
        return
    check_call(["git", "commit","--all", "-m", f"NOLINT stubs for {fname}"])

def main():
    git_files = check_output(["git", "ls-files"]).decode("ascii").split("\n")
    compiled_files = get_compiled_files_list()
    for idx, fname in enumerate(git_files):
        if fname not in compiled_files:
            continue
        if fname.startswith("caffe2/contrib/aten/"):
            continue
        print(f"[{idx}/{len(git_files)}] Processing {fname}")
        run_clang_tidy(fname)

if __name__ == "__main__":
    main()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56892

Reviewed By: H-Huang

Differential Revision: D27991944

Pulled By: malfet

fbshipit-source-id: 5415e1eb2c1b34319a4f03024bfaa087007d7179
2021-04-28 14:10:25 -07:00
jiej
ce1380f9b5 fixing Optional[Tensor] type in autodiff (#55565)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/54783

We need to be extra careful with the pattern to legitimately use `unchecked_unwrap_optional` in autodiff.
This at least allows us to start supporting `Optional[Tensor]` in autodiff, which is quite common in composite layers.
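
A tiny scripted example of the kind of composite-layer pattern this targets (illustrative only; the function below is made up, not from the PR):

```
import torch
from typing import Optional

@torch.jit.script
def scale_maybe_bias(x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
    out = x * 2.0
    if bias is not None:
        out = out + bias
    return out

x = torch.randn(3, requires_grad=True)
print(scale_maybe_bias(x, None).sum())
print(scale_maybe_bias(x, torch.ones(3)).sum())
```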

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55565

Reviewed By: ejguan

Differential Revision: D27825336

Pulled By: Krovatkin

fbshipit-source-id: a8562eb10ea741effff430d7417d313b1eb53dfe
2021-04-16 14:06:49 -07:00
zsef123
3498fde20e Add AccumulateType in AdaptiveAveragePooling3d.cu (#53607)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/52719

- Changed the type of intermediate results from `scalar_t` to `at::acc_type<scalar_t, true>`

This issue is caused by the limited precision of the half-precision format.

Following the test cases from the issue above: the input tensors have values in [0, 1] because they are initialized with `rand`. When the output kernel size is 1, the kernel sums all target values and divides by the numel of the kernel:
34d9278c19/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu (L94-L95)

When adding values in [0, 1], once `sum` exceeds 2048 it no longer changes. (Even when the sum is smaller and values are added more exactly, there are still precision issues.)
(https://en.wikipedia.org/wiki/Half-precision_floating-point_format)
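
A quick demonstration of the effect (independent of this PR): at magnitude 2048 the spacing between adjacent float16 values is 2, so adding anything in [0, 1) is rounded away, whereas a float32 accumulator (what `at::acc_type<Half, true>` resolves to) keeps accumulating:

```
import torch

s = torch.tensor(2048.0, dtype=torch.float16)
print(s + 0.9)   # still tensor(2048., dtype=torch.float16): the addend is rounded away

vals = torch.rand(8192)          # values in [0, 1), true sum ~4096
acc = torch.tensor(0.0, dtype=torch.float16)
for v in vals.half():
    acc = acc + v                # step-by-step float16 accumulation stalls near 2048
print(acc, vals.sum())
```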

Benchmarks
- On a V100 32GB, driver 450.80, CUDA 10.1
- Faster than the previous implementation

<details><summary>Script</summary><p>

```
import torch
from torch.utils.benchmark import Timer

torch.manual_seed(0)

kernel_sizes = [1, 3, 5, 7, 9, 11, 13]
shapes = [(12, 12, 12), (16, 16, 16), (16, 32, 32), (16, 56, 56), (16, 112, 112)]

def run(batch, channel):
    print(f"Batch : {batch}, Channel : {channel} / (diff, diff / numel, time)")

    head = "\t".join(f"{str(s):30s}" for s in ["k \ shape"] + shapes)
    print(head)
    for kernel_size in kernel_sizes:
        kernel_size = (kernel_size, kernel_size, kernel_size)
        pool = torch.nn.AdaptiveAvgPool3d(kernel_size)

        print(f"{str(kernel_size):30s}", end="\t")
        for shape in shapes:
            x_half = torch.rand([batch, channel, *shape], dtype=torch.half, device="cuda")
            x_float = x_half.float()

            y_half = pool(x_half)
            y_float = pool(x_float)

            timer = Timer("pool(x_half)", globals={"pool": pool, "x_half": x_half})
            measurement = timer.blocked_autorange(min_run_time=5)

            diff = (y_float - y_half).abs().sum().item()
            diff = f"{diff:.4f}, {diff / y_half.numel():.6f}, {measurement.median * 1e6 :3.2f}us"
            print(f"{diff:30s}", end="\t")
        print("")

run(1, 1)
run(1, 3)
run(1, 54)
run(1, 16)

run(8, 1)
run(8, 16)
run(8, 54)

import torch
m = torch.nn.AdaptiveAvgPool3d((1,1,1))

inputs = torch.rand([8,54,16,56,56])
inputs = inputs.cuda()
inputs_2 = inputs.half()

print("Float")
out = m(inputs).float()
print("half")
out2 = m(inputs_2).float()

print('Discepancies', torch.sum(torch.abs(out2- out)).item(), torch.sum(torch.abs(out2- out)).item() / out.numel() , out.numel())

print("Sum : ", torch.sum(inputs, dim=(2,3,4))[0, 0], torch.sum(inputs_2, dim=(2,3,4))[0, 0])
```
</p>
</details>

<details><summary>This commit</summary><p>

```
Batch : 1, Channel : 1 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                         (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.0001, 0.000078, 55.73us       0.0001, 0.000079, 117.51us       0.0000, 0.000003, 379.60us      0.0000, 0.000046, 1046.21us      0.0001, 0.000139, 3897.17us
(3, 3, 3)                       0.0021, 0.000076, 22.04us       0.0031, 0.000115, 21.47us        0.0022, 0.000080, 41.63us       0.0030, 0.000111, 100.59us       0.0025, 0.000091, 295.04us
(5, 5, 5)                       0.0103, 0.000083, 21.65us       0.0097, 0.000078, 21.37us        0.0103, 0.000083, 21.60us       0.0114, 0.000091, 25.69us        0.0107, 0.000085, 97.06us
(7, 7, 7)                       0.0312, 0.000091, 21.52us       0.0290, 0.000084, 21.61us        0.0311, 0.000091, 21.60us       0.0309, 0.000090, 21.44us        0.0334, 0.000097, 33.60us
(9, 9, 9)                       0.0646, 0.000089, 21.57us       0.0672, 0.000092, 21.89us        0.0662, 0.000091, 21.89us       0.0684, 0.000094, 27.64us        0.0660, 0.000091, 54.85us
(11, 11, 11)                    0.1251, 0.000094, 21.68us       0.1194, 0.000090, 21.70us        0.1202, 0.000090, 21.72us       0.1233, 0.000093, 22.25us        0.1229, 0.000092, 41.39us
(13, 13, 13)                    0.2038, 0.000093, 21.57us       0.2047, 0.000093, 21.58us        0.1964, 0.000089, 21.54us       0.2021, 0.000092, 21.94us        0.1989, 0.000091, 40.01us
Batch : 1, Channel : 3 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                     (16, 32, 32)                    (16, 56, 56)                     (16, 112, 112)
(1, 1, 1)                       0.0003, 0.000110, 55.74us       0.0003, 0.000093, 118.62us       0.0003, 0.000093, 382.12us      0.0001, 0.000040, 1052.33us      0.0003, 0.000114, 3917.90us
(3, 3, 3)                       0.0073, 0.000090, 21.84us       0.0075, 0.000093, 22.25us        0.0072, 0.000089, 41.78us       0.0070, 0.000087, 100.27us       0.0069, 0.000086, 293.96us
(5, 5, 5)                       0.0353, 0.000094, 22.57us       0.0325, 0.000087, 21.64us        0.0343, 0.000092, 22.63us       0.0338, 0.000090, 25.82us        0.0332, 0.000089, 97.16us
(7, 7, 7)                       0.0937, 0.000091, 22.50us       0.0910, 0.000088, 21.92us        0.0933, 0.000091, 21.99us       0.0948, 0.000092, 21.56us        0.0928, 0.000090, 34.17us
(9, 9, 9)                       0.1957, 0.000089, 21.68us       0.1984, 0.000091, 21.57us        0.2025, 0.000093, 22.10us       0.1986, 0.000091, 27.66us        0.2020, 0.000092, 55.32us
(11, 11, 11)                    0.3585, 0.000090, 21.75us       0.3684, 0.000092, 22.70us        0.3706, 0.000093, 21.67us       0.3752, 0.000094, 21.86us        0.3663, 0.000092, 41.22us
(13, 13, 13)                    0.5931, 0.000090, 21.67us       0.6056, 0.000092, 21.79us        0.6005, 0.000091, 21.79us       0.6112, 0.000093, 21.69us        0.6034, 0.000092, 40.02us
Batch : 1, Channel : 54 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                     (16, 32, 32)                    (16, 56, 56)                     (16, 112, 112)
(1, 1, 1)                       0.0051, 0.000095, 55.76us       0.0060, 0.000112, 118.60us       0.0036, 0.000067, 381.50us      0.0054, 0.000100, 1054.03us      0.0048, 0.000089, 4888.68us
(3, 3, 3)                       0.1332, 0.000091, 21.66us       0.1344, 0.000092, 22.62us        0.1354, 0.000093, 45.72us       0.1364, 0.000094, 106.63us       0.1324, 0.000091, 448.31us
(5, 5, 5)                       0.6221, 0.000092, 22.48us       0.6220, 0.000092, 21.71us        0.6053, 0.000090, 27.65us       0.6137, 0.000091, 31.40us        0.6209, 0.000092, 172.78us
(7, 7, 7)                       1.6859, 0.000091, 22.42us       1.6972, 0.000092, 21.96us        1.6849, 0.000091, 23.14us       1.7012, 0.000092, 26.25us        1.6920, 0.000091, 75.58us
(9, 9, 9)                       3.5811, 0.000091, 21.73us       3.5746, 0.000091, 22.55us        3.6237, 0.000092, 27.66us       3.6046, 0.000092, 59.71us        3.6392, 0.000092, 168.15us
(11, 11, 11)                    6.5582, 0.000091, 22.05us       6.5746, 0.000091, 21.74us        6.5955, 0.000092, 32.91us       6.5644, 0.000091, 45.57us        6.5697, 0.000091, 114.01us
(13, 13, 13)                    10.6384, 0.000090, 21.81us      10.8608, 0.000092, 21.79us       10.8375, 0.000091, 37.01us      10.8662, 0.000092, 51.80us       10.8593, 0.000092, 123.19us
Batch : 1, Channel : 16 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                     (16, 32, 32)                    (16, 56, 56)                     (16, 112, 112)
(1, 1, 1)                       0.0015, 0.000093, 55.75us       0.0012, 0.000075, 118.10us           0.0013, 0.000079, 379.25us      0.0012, 0.000075, 1047.21us     0.0013, 0.000079, 4451.57us
(3, 3, 3)                       0.0407, 0.000094, 21.82us       0.0395, 0.000091, 21.69us            0.0385, 0.000089, 42.07us       0.0397, 0.000092, 100.33us      0.0384, 0.000089, 363.31us
(5, 5, 5)                       0.1858, 0.000093, 21.76us       0.1799, 0.000090, 21.63us            0.1834, 0.000092, 21.76us       0.1890, 0.000095, 26.04us       0.1814, 0.000091, 135.32us
(7, 7, 7)                       0.4937, 0.000090, 21.65us       0.5076, 0.000092, 21.69us            0.5001, 0.000091, 22.31us       0.4988, 0.000091, 21.59us       0.5123, 0.000093, 50.03us
(9, 9, 9)                       1.0678, 0.000092, 21.73us       1.0752, 0.000092, 21.75us            1.0673, 0.000091, 21.75us       1.0649, 0.000091, 30.01us       1.0786, 0.000092, 70.92us
(11, 11, 11)                    1.9591, 0.000092, 21.57us       1.9522, 0.000092, 21.60us            1.9566, 0.000092, 21.73us       1.9475, 0.000091, 23.46us       1.9323, 0.000091, 55.02us
(13, 13, 13)                    3.1784, 0.000090, 22.02us       3.2165, 0.000092, 21.95us            3.1969, 0.000091, 21.92us       3.2061, 0.000091, 24.40us       3.2578, 0.000093, 56.00us
Batch : 8, Channel : 1 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                         (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.0010, 0.000122, 55.74us       0.0009, 0.000114, 118.82us           0.0006, 0.000074, 379.80us      0.0009, 0.000107, 1047.31us     0.0008, 0.000102, 3900.36us
(3, 3, 3)                       0.0219, 0.000101, 21.57us       0.0200, 0.000093, 21.61us            0.0194, 0.000090, 41.74us       0.0208, 0.000096, 99.91us       0.0212, 0.000098, 293.03us
(5, 5, 5)                       0.0906, 0.000091, 21.46us       0.0911, 0.000091, 21.60us            0.0934, 0.000093, 21.93us       0.0927, 0.000093, 25.74us       0.0913, 0.000091, 96.85us
(7, 7, 7)                       0.2530, 0.000092, 22.53us       0.2526, 0.000092, 22.46us            0.2558, 0.000093, 22.03us       0.2542, 0.000093, 22.29us       0.2475, 0.000090, 34.44us
(9, 9, 9)                       0.5305, 0.000091, 22.34us       0.5368, 0.000092, 22.42us            0.5265, 0.000090, 21.74us       0.5370, 0.000092, 27.81us       0.5416, 0.000093, 55.65us
(11, 11, 11)                    0.9887, 0.000093, 21.80us       0.9660, 0.000091, 21.61us            0.9793, 0.000092, 22.11us       0.9719, 0.000091, 21.80us       0.9650, 0.000091, 43.90us
(13, 13, 13)                    1.6024, 0.000091, 21.87us       1.6198, 0.000092, 22.65us            1.6242, 0.000092, 21.73us       1.6236, 0.000092, 22.59us       1.6025, 0.000091, 42.77us
Batch : 8, Channel : 16 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                         (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.0113, 0.000088, 56.66us       0.0117, 0.000091, 119.57us           0.0130, 0.000102, 389.57us      0.0110, 0.000086, 1433.78us     0.0119, 0.000093, 5217.61us
(3, 3, 3)                       0.3209, 0.000093, 21.54us       0.3184, 0.000092, 22.87us            0.3115, 0.000090, 51.00us       0.3171, 0.000092, 164.17us      0.3182, 0.000092, 500.60us
(5, 5, 5)                       1.4391, 0.000090, 22.39us       1.4577, 0.000091, 21.69us            1.4601, 0.000091, 53.87us       1.4626, 0.000091, 93.65us       1.4567, 0.000091, 370.11us
(7, 7, 7)                       4.0501, 0.000092, 22.34us       4.0230, 0.000092, 31.45us            4.0381, 0.000092, 45.19us       4.0171, 0.000091, 65.35us       4.0108, 0.000091, 164.76us
(9, 9, 9)                       8.5360, 0.000091, 22.80us       8.5456, 0.000092, 27.24us            8.5461, 0.000092, 50.23us       8.5677, 0.000092, 117.63us      8.5645, 0.000092, 270.46us
(11, 11, 11)                    15.5521, 0.000091, 26.56us      15.5826, 0.000091, 32.81us           15.6014, 0.000092, 63.82us      15.5620, 0.000091, 96.87us      15.5722, 0.000091, 220.24us
(13, 13, 13)                    25.4146, 0.000090, 32.91us      25.7898, 0.000092, 38.48us           25.6698, 0.000091, 72.02us      25.8193, 0.000092, 121.73us     25.7718, 0.000092, 249.71us
Batch : 8, Channel : 54 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                         (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.0377, 0.000087, 109.07us      0.0405, 0.000094, 233.17us           0.0392, 0.000091, 998.97us      0.0393, 0.000091, 2960.68us     0.0408, 0.000094, 11879.53us
(3, 3, 3)                       1.0660, 0.000091, 25.68us       1.0761, 0.000092, 64.12us            1.0725, 0.000092, 182.50us      1.0801, 0.000093, 505.82us      1.0736, 0.000092, 1650.21us
(5, 5, 5)                       4.9587, 0.000092, 50.84us       4.9336, 0.000091, 47.38us            4.9696, 0.000092, 158.49us      4.9347, 0.000091, 237.39us      4.9303, 0.000091, 965.13us
(7, 7, 7)                       13.5409, 0.000091, 45.60us      13.5736, 0.000092, 87.45us           13.5012, 0.000091, 141.63us     13.6111, 0.000092, 181.51us     13.5296, 0.000091, 469.77us
(9, 9, 9)                       28.7817, 0.000091, 58.01us      28.7969, 0.000091, 77.61us           28.8761, 0.000092, 159.33us     28.8786, 0.000092, 334.47us     28.8093, 0.000091, 786.72us
(11, 11, 11)                    52.4453, 0.000091, 78.19us      52.7265, 0.000092, 95.12us           52.7322, 0.000092, 200.38us     52.6342, 0.000092, 282.41us     52.6467, 0.000092, 652.54us
(13, 13, 13)                    85.7411, 0.000090, 98.85us      86.7183, 0.000091, 115.28us          86.8545, 0.000092, 232.34us     86.9997, 0.000092, 367.32us     86.9083, 0.000092, 757.73us
Float
half
Discepancies 0.03963914513587952 9.175728040712852e-05 432
Sum :  tensor(25110.1484, device='cuda:0') tensor(25104., device='cuda:0', dtype=torch.float16)
```
</p>
</details>

<details><summary>1.8.0</summary><p>

```
Batch : 1, Channel : 1 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                  (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.0023, 0.002275, 74.35us       0.0040, 0.003985, 159.73us        0.3740, 0.374021, 546.59us      0.4587, 0.458663, 1543.16us       0.4906, 0.490637, 5945.97us
(3, 3, 3)                       0.0100, 0.000370, 20.37us       0.0230, 0.000852, 22.12us         0.0309, 0.001143, 54.75us       0.0520, 0.001926, 129.78us        7.1219, 0.263775, 377.11us
(5, 5, 5)                       0.0441, 0.000352, 20.06us       0.0394, 0.000316, 20.50us         0.0759, 0.000607, 26.43us       0.1499, 0.001199, 32.01us         0.2707, 0.002166, 128.15us
(7, 7, 7)                       0.0791, 0.000231, 20.10us       0.1002, 0.000292, 20.56us         0.1812, 0.000528, 20.48us       0.2424, 0.000707, 20.83us         0.4994, 0.001456, 43.97us
(9, 9, 9)                       0.1122, 0.000154, 20.55us       0.1778, 0.000244, 20.44us         0.2572, 0.000353, 20.15us       0.4149, 0.000569, 35.64us         0.7208, 0.000989, 68.46us
(11, 11, 11)                    0.2044, 0.000154, 20.47us       0.2647, 0.000199, 20.62us         0.3867, 0.000291, 20.61us       0.6059, 0.000455, 23.54us         1.0902, 0.000819, 53.32us
(13, 13, 13)                    0.3094, 0.000141, 20.53us       0.3843, 0.000175, 20.60us         0.5756, 0.000262, 20.80us       0.8598, 0.000391, 24.52us         1.4853, 0.000676, 47.70us
Batch : 1, Channel : 3 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                      (16, 32, 32)                    (16, 56, 56)                      (16, 112, 112)
(1, 1, 1)                       0.0054, 0.001801, 74.36us       0.0108, 0.003614, 158.94us        1.1183, 0.372768, 547.67us      1.3782, 0.459387, 1545.27us       1.4685, 0.489505, 5949.17us
(3, 3, 3)                       0.0308, 0.000380, 20.14us       0.0502, 0.000619, 22.11us         0.1210, 0.001493, 54.80us       0.1900, 0.002345, 130.47us        21.3483, 0.263560, 375.68us
(5, 5, 5)                       0.1179, 0.000314, 20.68us       0.1326, 0.000354, 20.53us         0.2662, 0.000710, 26.51us       0.4116, 0.001098, 31.85us         0.8369, 0.002232, 128.19us
(7, 7, 7)                       0.2335, 0.000227, 20.40us       0.3057, 0.000297, 20.43us         0.4954, 0.000481, 20.31us       0.7339, 0.000713, 20.74us         1.4208, 0.001381, 44.55us
(9, 9, 9)                       0.3326, 0.000152, 20.63us       0.5353, 0.000245, 20.42us         0.8025, 0.000367, 20.13us       1.2693, 0.000580, 35.64us         2.2096, 0.001010, 68.88us
(11, 11, 11)                    0.6121, 0.000153, 20.59us       0.8086, 0.000202, 20.42us         1.1700, 0.000293, 20.71us       1.8170, 0.000455, 23.54us         3.2117, 0.000804, 53.36us
(13, 13, 13)                    0.9165, 0.000139, 20.51us       1.1395, 0.000173, 20.56us         1.7343, 0.000263, 20.80us       2.5868, 0.000392, 24.59us         4.5823, 0.000695, 47.77us
Batch : 1, Channel : 54 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                      (16, 32, 32)                    (16, 56, 56)                      (16, 112, 112)
(1, 1, 1)                       0.1092, 0.002023, 75.45us       0.1709, 0.003165, 160.44us        20.2452, 0.374911, 548.61us     24.7990, 0.459240, 1550.34us      26.4494, 0.489804, 6957.79us
(3, 3, 3)                       0.5352, 0.000367, 20.58us       1.0281, 0.000705, 24.14us         2.0150, 0.001382, 59.12us       3.3069, 0.002268, 138.23us        384.5216, 0.263732, 529.71us
(5, 5, 5)                       2.0739, 0.000307, 20.60us       2.5199, 0.000373, 20.44us         4.6916, 0.000695, 33.89us       7.9482, 0.001178, 37.74us         14.2553, 0.002112, 200.54us
(7, 7, 7)                       4.2236, 0.000228, 20.61us       5.5605, 0.000300, 20.97us         9.0440, 0.000488, 26.40us       12.7847, 0.000690, 30.64us        25.3050, 0.001366, 88.05us
(9, 9, 9)                       6.0817, 0.000154, 20.63us       9.5416, 0.000242, 20.84us         14.2416, 0.000362, 32.47us      22.8452, 0.000580, 78.57us        40.3246, 0.001024, 194.50us
(11, 11, 11)                    11.1144, 0.000155, 20.56us      14.5581, 0.000203, 20.91us        20.8263, 0.000290, 38.07us      33.0004, 0.000459, 52.74us        57.3275, 0.000798, 137.19us
(13, 13, 13)                    16.5176, 0.000139, 21.26us      20.8089, 0.000175, 22.33us        31.3433, 0.000264, 42.93us      45.9733, 0.000388, 59.84us        82.8301, 0.000698, 138.42us
Batch : 1, Channel : 16 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                      (16, 32, 32)                    (16, 56, 56)                      (16, 112, 112)
(1, 1, 1)                       0.0274, 0.001715, 74.99us       0.0485, 0.003034, 159.92us    5.9925, 0.374529, 546.35us      7.3389, 0.458679, 1544.53us     7.8354, 0.489714, 6677.00us
(3, 3, 3)                       0.1560, 0.000361, 20.72us       0.3043, 0.000704, 22.37us     0.5838, 0.001352, 54.97us       1.0455, 0.002420, 130.57us      113.9739, 0.263828, 463.43us
(5, 5, 5)                       0.6121, 0.000306, 20.12us       0.7247, 0.000362, 20.73us     1.3740, 0.000687, 26.59us       2.3794, 0.001190, 32.12us       4.1929, 0.002096, 165.81us
(7, 7, 7)                       1.2389, 0.000226, 20.59us       1.6311, 0.000297, 20.53us     2.6732, 0.000487, 20.37us       3.7501, 0.000683, 20.71us       7.4575, 0.001359, 59.16us
(9, 9, 9)                       1.7983, 0.000154, 20.64us       2.8075, 0.000241, 20.59us     4.2165, 0.000361, 20.38us       6.7153, 0.000576, 38.29us       12.0530, 0.001033, 86.33us
(11, 11, 11)                    3.3326, 0.000156, 20.56us       4.3061, 0.000202, 20.67us     6.2235, 0.000292, 20.47us       9.8009, 0.000460, 27.41us       16.9994, 0.000798, 68.49us
(13, 13, 13)                    4.9016, 0.000139, 20.63us       6.1261, 0.000174, 20.65us     9.2106, 0.000262, 20.93us       13.5843, 0.000386, 27.95us      24.6476, 0.000701, 64.88us
Batch : 8, Channel : 1 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                  (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.0170, 0.002122, 74.99us       0.0316, 0.003946, 160.66us    3.0013, 0.375158, 546.94us      3.6780, 0.459753, 1544.58us     3.9197, 0.489966, 5948.43us
(3, 3, 3)                       0.0821, 0.000380, 20.27us       0.1559, 0.000722, 22.29us     0.3133, 0.001450, 54.72us       0.5100, 0.002361, 130.12us      57.0481, 0.264111, 376.71us
(5, 5, 5)                       0.3075, 0.000307, 20.57us       0.3680, 0.000368, 20.69us     0.6786, 0.000679, 26.61us       1.1744, 0.001174, 31.77us       2.0654, 0.002065, 128.31us
(7, 7, 7)                       0.6512, 0.000237, 20.60us       0.8359, 0.000305, 20.50us     1.3712, 0.000500, 20.75us       1.9472, 0.000710, 20.92us       3.7586, 0.001370, 44.59us
(9, 9, 9)                       0.9138, 0.000157, 20.43us       1.4198, 0.000243, 20.58us     2.1018, 0.000360, 20.52us       3.3691, 0.000578, 35.90us       5.9491, 0.001020, 69.16us
(11, 11, 11)                    1.6606, 0.000156, 20.63us       2.1599, 0.000203, 20.57us     3.1240, 0.000293, 20.98us       4.8874, 0.000459, 24.65us       8.4780, 0.000796, 56.47us
(13, 13, 13)                    2.4987, 0.000142, 20.71us       3.0667, 0.000174, 20.45us     4.6387, 0.000264, 20.76us       6.8187, 0.000388, 25.95us       12.2077, 0.000695, 50.46us
Batch : 8, Channel : 16 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                  (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.2635, 0.002059, 75.66us       0.4030, 0.003149, 161.78us    48.0296, 0.375231, 550.46us     58.7787, 0.459209, 1902.41us    62.6966, 0.489817, 7817.48us
(3, 3, 3)                       1.2271, 0.000355, 20.72us       2.4185, 0.000700, 26.44us     4.6933, 0.001358, 64.66us       7.7016, 0.002228, 192.69us      912.0736, 0.263910, 593.69us
(5, 5, 5)                       4.8716, 0.000304, 24.75us       5.8624, 0.000366, 21.39us     11.0705, 0.000692, 66.94us      18.9280, 0.001183, 104.93us     34.0512, 0.002128, 441.81us
(7, 7, 7)                       10.1713, 0.000232, 20.98us      13.2273, 0.000301, 36.26us    21.5426, 0.000491, 52.18us      30.1910, 0.000688, 72.94us      59.8381, 0.001363, 191.52us
(9, 9, 9)                       14.4542, 0.000155, 23.85us      22.6579, 0.000243, 30.59us    33.8839, 0.000363, 57.40us      54.3563, 0.000583, 142.53us     95.8123, 0.001027, 309.24us
(11, 11, 11)                    26.3348, 0.000155, 30.07us      34.3043, 0.000201, 37.01us    49.8093, 0.000292, 74.04us      78.3720, 0.000460, 110.53us     136.5404, 0.000801, 264.14us
(13, 13, 13)                    39.3550, 0.000140, 37.38us      49.3207, 0.000175, 43.51us    74.1139, 0.000264, 83.70us      108.7627, 0.000387, 136.09us    196.5412, 0.000699, 280.16us
Batch : 8, Channel : 54 / (diff, diff / numel, time)
k \ shape                       (12, 12, 12)                    (16, 16, 16)                  (16, 32, 32)                    (16, 56, 56)                    (16, 112, 112)
(1, 1, 1)                       0.8467, 0.001960, 147.36us      1.3993, 0.003239, 314.95us    162.0182, 0.375042, 1327.22us   198.3226, 0.459080, 3921.79us   211.6123, 0.489843, 15646.94us
(3, 3, 3)                       4.3146, 0.000370, 29.23us       8.1125, 0.000696, 74.94us     15.8886, 0.001362, 223.69us     26.2404, 0.002250, 601.33us     3076.5354, 0.263763, 1974.06us
(5, 5, 5)                       16.5032, 0.000306, 58.79us      19.6887, 0.000365, 53.79us    37.2731, 0.000690, 192.34us     63.3076, 0.001172, 270.01us     114.8880, 0.002128, 1148.56us
(7, 7, 7)                       34.0802, 0.000230, 51.12us      44.4087, 0.000300, 100.93us   72.4613, 0.000489, 161.48us     101.9317, 0.000688, 202.91us    201.8955, 0.001363, 545.33us
(9, 9, 9)                       48.8179, 0.000155, 65.78us      76.3465, 0.000242, 87.48us    114.0228, 0.000362, 179.11us    182.9805, 0.000581, 403.66us    322.7040, 0.001025, 894.86us
(11, 11, 11)                    88.9993, 0.000155, 88.69us      116.4213, 0.000202, 107.55us  168.3363, 0.000293, 228.71us    264.2232, 0.000460, 322.84us    459.1324, 0.000799, 784.25us
(13, 13, 13)                    132.7447, 0.000140, 112.91us    165.4525, 0.000174, 131.08us  249.7127, 0.000263, 266.43us    367.0824, 0.000387, 410.17us    663.1367, 0.000699, 847.87us
Float
half
Discepancies 198.37625122070312 0.4592042852331091 432
Sum :  tensor(25110.1484, device='cuda:0') tensor(25104., device='cuda:0', dtype=torch.float16)
```
</p>
</details>

ngimel malfet anjali411

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53607

Reviewed By: mruberry

Differential Revision: D27652337

Pulled By: ngimel

fbshipit-source-id: 6439c0cafe6ca3f761a3f5d058050a55e9a0abd8
2021-04-08 15:48:08 -07:00
Peter Bell
2ee02b30b1 Replace rounding_mode="true" with rounding_mode=None (#51988)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51988

* **#51988 Replace rounding_mode="true" with rounding_mode=None**

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D27561817

Pulled By: mruberry

fbshipit-source-id: 60d1d9c389570f60d599fc1876518717367fb368
2021-04-05 14:53:43 -07:00
Gregory Chanan
983347fa25 Allow broadcasting against lerp weights. (#52319)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52319

Fixes: https://github.com/pytorch/pytorch/issues/52254
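
A short illustration of the broadcasting now allowed for the `weight` tensor (shapes arbitrary):

```
import torch

start = torch.zeros(4, 3)
end = torch.ones(4, 3)
weight = torch.tensor([0.0, 0.5, 1.0])   # (3,) broadcasts against (4, 3)

out = torch.lerp(start, end, weight)
assert torch.allclose(out, start + weight * (end - start))
```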

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D26488411

Pulled By: gchanan

fbshipit-source-id: 60eb471609986584c4235ba7f263581e988e7642
2021-02-18 09:53:25 -08:00
jiej
4d703d040b Linear autodiff revert revert (#51613)
Summary:
Patches PR https://github.com/pytorch/pytorch/issues/50856 and rolls back the revert D26105797 (e488e3c443)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51613

Reviewed By: mruberry

Differential Revision: D26253999

Pulled By: ngimel

fbshipit-source-id: a20b1591de06dd277e4cd95542e3291a2f5a252c
2021-02-04 16:32:05 -08:00
Peter Bell
b150f150ba Add division overload with rounding_mode selection (#51706)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51706

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50280

As mentioned in gh-43874, this adds a `rounding_mode={'true', 'trunc', 'floor'}`
argument so `torch.div` can be used as a replacement for `floor_divide` during
the transitional period.

I've included dedicated kernels for truncated and floor division which
aren't strictly necessary for float, but do perform significantly better (~2x) than
doing true division followed by a separate rounding kernel.

Note: I introduce new overloads for `aten::div` instead of just adding a default
`rounding_mode` because various JIT passes rely on the exact operator schema.
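
A quick illustration of the rounding modes on a negative quotient (values arbitrary; true division is the default):

```
import torch

a = torch.tensor([7.0, -7.0])
b = torch.tensor([2.0, 2.0])

print(torch.div(a, b))                         # true division: [ 3.5, -3.5]
print(torch.div(a, b, rounding_mode='trunc'))  # toward zero:   [ 3., -3.]
print(torch.div(a, b, rounding_mode='floor'))  # toward -inf:   [ 3., -4.]
```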

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D26123271

Pulled By: mruberry

fbshipit-source-id: 51a83717602114597ec9c4d946e35a392eb01d46
2021-02-04 13:08:36 -08:00
Natalia Gimelshein
26f9ac98e5 Revert D26105797: [pytorch][PR] Exposing linear layer to fuser
Test Plan: revert-hammer

Differential Revision:
D26105797 (e488e3c443)

Original commit changeset: 6f7cedb9f6e3

fbshipit-source-id: f0858cefed76d726e9dba61e51e1eaf2af4c99c5
2021-02-02 17:39:17 -08:00
jiej
e488e3c443 Exposing linear layer to fuser (#50856)
Summary:
1. enables linear in autodiff;
2. removes the Python-side control flow for linear;

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50856

Reviewed By: pbelevich

Differential Revision: D26105797

Pulled By: eellison

fbshipit-source-id: 6f7cedb9f6e3e46daa24223d2a6080880498deb4
2021-02-02 15:39:01 -08:00
Andres Suarez
8530c65e25 [codemod][fbcode/caffe2] Apply clang-format update fixes
Test Plan: Sandcastle and visual inspection.

Reviewed By: igorsugak

Differential Revision: D25849205

fbshipit-source-id: ef664c1ad4b3ee92d5c020a5511b4ef9837a09a0
2021-01-09 14:37:36 -08:00
Michael Suo
dc8176356e Various cleanups to ir_emitter and friends (#46686)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46686

I was trying to page this code back in after a while and some things
stuck out as unnecessarily confusing.

1. Improve documentation of closures and fork stuff to be more accurate
to how we use them today.
2. Change `prim::LocalVariableScope` to `prim::ListComprehension`. It is
only ever used for a list comprehensions, and in general the nodes
emitted by `ir_emitter` should correspond to concrete operations or
language features rather than semantic constraints.
3. Change the somewhat mysterious "inputs" and "attributes" argument
names throughout the codebase to be the more obvious "args" and "kwargs"
that they generally represent (I think "inputs" and "attributes" come
from the AST naming).

Test Plan: Imported from OSS

Reviewed By: navahgar, jamesr66a

Differential Revision: D24464197

Pulled By: suo

fbshipit-source-id: 1f4b1475b58b5690a0b204e705caceff969533b4
2020-10-28 16:28:05 -07:00
Meghan Lele
6384c2d81b [JIT] clang-format JIT code (#35115)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35115

This commit runs the newly added tools/clang_format.py on the JIT
codebase and includes all of the formatting changes thus produced.

Testing:
Ran the script, CI.

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D20568523

Pulled By: SplitInfinity

fbshipit-source-id: e09bdb982ccf090eecfb7c7b461b8d0681eef82b
2020-03-26 11:24:51 -07:00
Elias Ellison
514cba0661 [JIT] remove builtin interpolate functions (#34514)
Summary:
`torch.nn.functional.interpolate` was written as a builtin op when we scripted the standard library, because it has four possible overloads. As a result, whenever we make a change to `interpolate`, we need to make changes in two places, and it also makes it impossible to optimize the interpolate op. The builtin is tech debt.

I talked with ailzhang, and the symbolic script changes are good to remove (I guess that makes a third place where we needed to re-implement interpolate).

I'm trying to get rid of unnecessary builtin operators because we're standardizing mobile bytecode soon, so we should try to get this landed as soon as possible.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34514

Differential Revision: D20391089

Pulled By: eellison

fbshipit-source-id: abc84cdecfac67332bcba6b308fca4db44303121
2020-03-12 09:21:33 -07:00
Michael Suo
c235be42dd [jit] kill script namespace (#34515)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34515

Once upon a time we thought this was necessary. In reality it is not, so
removing it.

For backcompat, our public interface (defined in `api/`) still has
typedefs to the old `script::` names.

There was only one collision: `Pass` as a `Stmt` and `Pass` as a graph
transform. I renamed one of them.

Test Plan: Imported from OSS

Differential Revision: D20353503

Pulled By: suo

fbshipit-source-id: 48bb911ce75120a8c9e0c6fb65262ef775dfba93
2020-03-11 23:32:48 -07:00
Michael Suo
dbe850af5b [jit] do the code reorg (#33851)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33851

Rationale and context described in #33828.

Script to reproduce the move:
https://gist.github.com/suo/16cbefaaeb67ca5a7c6caffd49b7f6e9
ghstack-source-id: 99079645

Test Plan: Make sure CI passes

Reviewed By: jamesr66a

Differential Revision: D20133869

fbshipit-source-id: 390e9241a9c85366d9005c492ac31f10aa96488e
2020-02-27 13:02:51 -08:00