Commit Graph

275 Commits

Author SHA1 Message Date
Nikita Shulga
8f1c3c68d3 [BE] Use nested namespaces in .cpp/.cu files (#92100)
As we now live in a C++17 world.

This is a functional no-op, just
- `s/namespace at { namespace native {/namespace at::native {/`
- `s/namespace torch { namespace jit {/namespace torch::jit {/`
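
The substitution in miniature (hypothetical declaration; the semantics are identical before and after):

```c++
// Before: doubled braces.
namespace at { namespace native {
void fooKernel();
}} // namespace at::native

// After: C++17 nested namespace definition.
namespace at::native {
void fooKernel();
} // namespace at::native
```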

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92100
Approved by: https://github.com/izaitsevfb
2023-01-13 16:32:34 +00:00
Aaron Gokaslan
3916d7a575 Apply modernize-use-emplace to aten, c10, torch (#91077)
Apply the clang-tidy check modernize-use-emplace. This is slightly more efficient because it constructs elements in place, and it is the recommended style in the parts of the codebase covered by clang-tidy. This change manually applies the check to the rest of the codebase. Pinging @ezyang as this is related to my other PRs he reviewed, like #89000.
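
The kind of rewrite modernize-use-emplace performs, in miniature (illustrative container and values, not actual changed code):

```c++
#include <string>
#include <utility>
#include <vector>

void example() {
  std::vector<std::pair<int, std::string>> v;
  // Before: constructs a temporary pair, then moves it into the vector.
  v.push_back(std::make_pair(1, "a"));
  // After: forwards the arguments and constructs the pair in place.
  v.emplace_back(1, "a");
}
```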

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91077
Approved by: https://github.com/ezyang
2022-12-19 07:49:56 +00:00
Aaron Gokaslan
7541c9f8be [Fix]: remove unnecessary copies in aten, c10, and torch bindings (#90629)
Applies various automated fixes that reduce the number of spurious copies in torch, aten, and c10. I also inlined any defaulted destructors where doing so makes the type trivially destructible.

Follow up to #89000
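
A rough illustration of the two kinds of fixes (hypothetical types and functions, not the actual changed code):

```c++
#include <string>
#include <utility>

// Spurious-copy fix: bind by const reference instead of by value, and move
// when ownership is actually transferred.
void logName(const std::string& name) { (void)name; }

struct Holder {
  std::string s;
  explicit Holder(std::string s_) : s(std::move(s_)) {}  // move, don't copy
};

// Destructor fix: defaulting the destructor on its first declaration keeps a
// type trivially destructible; an out-of-line `Flags::~Flags() = default;`
// definition would make it non-trivial.
struct Flags {
  int bits = 0;
  ~Flags() = default;
};
```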

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90629
Approved by: https://github.com/ezyang
2022-12-12 17:05:52 +00:00
Kazuaki Ishizaki
e0c194f10b Fix typos in messages under torch (#88961)
This PR fixes typos in messages and params in C++ source and header files under the `torch` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88961
Approved by: https://github.com/albanD
2022-11-14 19:06:41 +00:00
Wang, Eikan
70c6a988d6 Fix the performance issue where the for-loop before ExternalCall could not be parallelized. (#85056)
Currently, NNC only parallelizes the loop statements that produce the graph outputs. This logic can bypass other loop statements that could be parallelized. Take the example below, and suppose the output of the `ExternalCall` is also the output of the NNC fusion group. The current [parallel logic](https://github.com/pytorch/pytorch/pull/85056/files#diff-9a11174c26e4b57ab73e819520122bc314467c72962f3a5b79e7400ea3c4bbe5L781-L785) only tries to parallelize the `ExternalCall` and bypasses `stmt1` and `stmt2`.

```c++
stmt1: For:
stmt2:   For:
stmt3: ExternalCall
```
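
A sketch of the intended outcome in the same notation (not actual NNC output): after the fix, outer loop statements like `stmt1` are also parallelization candidates instead of being skipped.

```c++
stmt1: For (parallelized):
stmt2:   For:
stmt3: ExternalCall
```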
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85056
Approved by: https://github.com/frank-wei, https://github.com/bertmaher
2022-10-07 07:36:28 +00:00
Wu, Chunyuan
ebf45a0785 [NNC] support aten::_convolution when it is 2D conv (#84038)
## Motivation
Currently, only `aten::conv2d` is supported in NNC. When using `torch.jit.trace`, the node on the graph will be `aten::_convolution`. This PR adds support for the `aten::_convolution` node when it corresponds to a 2D convolution.

## Pitch
Support `aten::_convolution` in NNC whenever its parameters let us infer that it is a 2D convolution, so that models obtained from `torch.jit.trace` can be fused as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84038
Approved by: https://github.com/huiguoo
2022-09-19 17:45:20 +00:00
chunyuan-w
693a8dd04c [NNC] enable fusion of conv with elementwise OP (#77157)
## Pitch
Enable Conv-Eltwise fusion in NNC.

## Description
This PR adds a `FuseConvWithEltwise` pass that fuses convolution with an elementwise OP in TE subgraphs. The pass inserts prepack and packed-run ops for conv2d, enabling fusion of conv2d with elementwise OPs. The fused packed-run ops are implemented via external calls in NNC.

## Code structure
Graph rewrite pass related code is placed in:
```
torch/csrc/jit/passes/mkldnn_rewrite.h
torch/csrc/jit/passes/mkldnn_rewrite.cpp
```

NNC integration of fused conv-eltwise OP via external call is located in:
```
torch/csrc/jit/tensorexpr/kernel.cpp

torch/csrc/jit/tensorexpr/operators/conv2d.h
torch/csrc/jit/tensorexpr/operators/conv2d.cpp

torch/csrc/jit/tensorexpr/lowerings.cpp
torch/csrc/jit/tensorexpr/external_functions.cpp
```

Fused prepack OP context is in:
```
aten/src/ATen/native/mkldnn/Common.h
aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp
aten/src/ATen/native/mkldnn/OpContext.h
aten/src/ATen/native/mkldnn/OpContext.cpp
```

Fused OP implementation is done in:
```
aten/src/ATen/native/mkldnn/ConvPrepack.h
aten/src/ATen/native/mkldnn/ConvPrepack.cpp
```

## OP benchmark for conv-relu
The performance below was measured on top of two PRs that add NHWC support: https://github.com/pytorch/pytorch/pull/76948 and https://github.com/pytorch/pytorch/pull/78238.

- Measured on Cascade Lake 8280
- Jemalloc enabled
- batch_size = 1
- Channels Last format

### Single thread:

shape | time w/o fusion (us) | time w/ fusion (us) | gain
-- | -- | -- | --
kernel=3, N=1, iC=64, H=56, W=56, oC=64, stride=1, pad=1, dilates=1, g=1 | 1706.22 | 1371.97 | 19.59%
kernel=1, N=1, iC=256, H=56, W=56, oC=512, stride=2, pad=0, dilates=1, g=1 | 2499.28 | 1571.52 | 37.12%
kernel=3, N=1, iC=256, H=56, W=56, oC=256, stride=1, pad=1, dilates=1, g=32 | 4169.52 | 2738.53 | 34.32%
kernel=3, N=1, iC=512, H=56, W=56, oC=512, stride=2, pad=1, dilates=1, g=32 | 3998.77 | 3085.85 | 22.83%
kernel=1, N=1, iC=64, H=56, W=56, oC=64, stride=1, pad=0, dilates=1, g=1 | 673.73 | 430.81 | 36.06%
kernel=1, N=1, iC=256, H=56, W=56, oC=64, stride=1, pad=0, dilates=1, g=1 | 1101.87 | 801.07 | 27.30%
kernel=1, N=1, iC=256, H=56, W=56, oC=256, stride=1, pad=0, dilates=1, g=1 | 4692.91 | 3116.13 | 33.60%
kernel=1, N=1, iC=512, H=28, W=28, oC=512, stride=1, pad=0, dilates=1, g=1 | 3310.64 | 2503.39 | 24.38%

### 4 threads:

shape | time w/o fusion (us) | time w/ fusion (us) | gain
-- | -- | -- | --
kernel=3, N=1, iC=64, H=56, W=56, oC=64, stride=1, pad=1, dilates=1, g=1 | 360.07 | 321.21 | 10.79%
kernel=1, N=1, iC=256, H=56, W=56, oC=512, stride=2, pad=0, dilates=1, g=1 | 391.49 | 323.17 | 17.45%
kernel=3, N=1, iC=256, H=56, W=56, oC=256, stride=1, pad=1, dilates=1, g=32 | 536.4 | 465.97 | 13.13%
kernel=3, N=1, iC=512, H=56, W=56, oC=512, stride=2, pad=1, dilates=1, g=32 | 674.98 | 616.32 | 8.69%
kernel=1, N=1, iC=64, H=56, W=56, oC=64, stride=1, pad=0, dilates=1, g=1 | 160.97 | 70.05 | 56.48%
kernel=1, N=1, iC=256, H=56, W=56, oC=64, stride=1, pad=0, dilates=1, g=1 | 215.81 | 182.6 | 15.39%
kernel=1, N=1, iC=256, H=56, W=56, oC=256, stride=1, pad=0, dilates=1, g=1 | 658.45 | 576.97 | 12.37%
kernel=1, N=1, iC=512, H=28, W=28, oC=512, stride=1, pad=0, dilates=1, g=1 | 702.18 | 566.39 | 19.34%

### 1 socket (28 cores):

shape | time w/o fusion (us) | time w/ fusion (us) | gain
-- | -- | -- | --
kernel=3, N=1, iC=64, H=56, W=56, oC=64, stride=1, pad=1, dilates=1, g=1 | 149.92 | 103.78 | 30.78%
kernel=1, N=1, iC=256, H=56, W=56, oC=512, stride=2, pad=0, dilates=1, g=1 | 192.76 | 110.87 | 42.48%
kernel=3, N=1, iC=256, H=56, W=56, oC=256, stride=1, pad=1, dilates=1, g=32 | 160.67 | 127.24 | 20.81%
kernel=3, N=1, iC=512, H=56, W=56, oC=512, stride=2, pad=1, dilates=1, g=32 | 212.45 | 180.55 | 15.02%
kernel=1, N=1, iC=64, H=56, W=56, oC=64, stride=1, pad=0, dilates=1, g=1 | 114.57 | 50.58 | 55.85%
kernel=1, N=1, iC=256, H=56, W=56, oC=64, stride=1, pad=0, dilates=1, g=1 | 198.64 | 70.6 | 64.46%
kernel=1, N=1, iC=256, H=56, W=56, oC=256, stride=1, pad=0, dilates=1, g=1 | 281.35 | 155.8 | 44.62%
kernel=1, N=1, iC=512, H=28, W=28, oC=512, stride=1, pad=0, dilates=1, g=1 | 262.15 | 162.94 | 37.84%


## UT
```
test/test_mkldnn_fusion.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77157
Approved by: https://github.com/ZolotukhinM
2022-08-10 21:46:51 +00:00
Wang, Eikan
11b9a81e02 [NNC] channels last propagation within NNC fusion group (#76948)
Decide the memory-layout propagation policy and propagate it within the NNC fusion group. The policy can be `Contiguous` or `Channels-last contiguous`.
 - `Contiguous`: convert non-contiguous input tensors (including channels-last contiguous ones) to contiguous and generate a contiguous output `Buf` for the lowering function.
 - `Channels-last contiguous`: convert the input tensors to channels-last contiguous and generate a channels-last contiguous output `Buf` for the lowering function.

Currently, the rule is simple: if all the input and output tensors of the NNC fusion group are channels-last contiguous, the propagated memory layout is `Channels-last contiguous`; otherwise it is `Contiguous`, which matches the existing behavior. In other words, this PR provides a fast path for channels-last, and the optimization is conservative since its trigger conditions are strict.
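
A minimal sketch of that rule, assuming a hypothetical helper name (the real check lives inside the NNC kernel lowering):

```c++
#include <algorithm>
#include <vector>

#include <ATen/ATen.h>

// Propagate Channels-last contiguous only when every input and output of the
// fusion group is already channels-last contiguous; otherwise fall back to
// Contiguous, matching the pre-existing behavior.
bool useChannelsLast(const std::vector<at::Tensor>& inputs,
                     const std::vector<at::Tensor>& outputs) {
  auto isChannelsLast = [](const at::Tensor& t) {
    return t.is_contiguous(at::MemoryFormat::ChannelsLast);
  };
  return std::all_of(inputs.begin(), inputs.end(), isChannelsLast) &&
         std::all_of(outputs.begin(), outputs.end(), isChannelsLast);
}
```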
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76948
Approved by: https://github.com/ZolotukhinM
2022-05-30 18:31:49 +00:00
Wang, Eikan
429a80dded [NNC] Lowering function generates the output buffer with the specified stride (#76529)
Summary:
Pass stride information to the lowering function so it can generate the output buffer with the proper memory layout.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76529

Reviewed By: ZolotukhinM

Differential Revision: D36116712

Pulled By: IvanKobzarev

fbshipit-source-id: d3901f756b3710ecce172d6db3ecb0b7c12fb929
(cherry picked from commit b6cd53c91c01db36ea0e99167dc0ce0ae1d3aa23)
2022-05-04 20:04:22 +00:00
zengk95
1d55518198 Revert "[nnc] Strides to Tensor (#72962)"
This reverts commit 939060925f.

Fixes https://github.com/pytorch/vision/issues/5873

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76332
Approved by: https://github.com/seemethere
2022-04-25 19:50:00 +00:00
Ivan Kobzarev
939060925f [nnc] Strides to Tensor (#72962)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72962

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM, cpuhrsch

Differential Revision: D34589306

Pulled By: IvanKobzarev

fbshipit-source-id: ecee5249760ecc0c8b2edb1842b90218899bc944
(cherry picked from commit 9e310c4c67389da30da89126d838ffe3864aba6f)
2022-04-23 19:35:15 +00:00
Nikita Shulga
f6c275f55d Remove -Wno-unused-variable from utils.cmake (take 2) (#75538)
Summary:
The [comment](https://github.com/pytorch/pytorch/pull/62445/files#r680132022) claims it was added for consistency with the top-level CMakeLists.txt, but `-Wno-unused-variable` is not mentioned there.

Fix the violations in 50+ files that were added in the interim, either by removing unused variables or by decorating the code with `C10_UNUSED` when a local variable is likely used to extend an object's lifetime until the end of the block.
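
For illustration, the two remediation patterns look roughly like this (`C10_UNUSED` is the real macro from c10/macros/Macros.h; `ScopedGuard` is a hypothetical RAII type):

```c++
#include <c10/macros/Macros.h>

// Stand-in RAII type whose destructor does the useful work.
struct ScopedGuard {
  ScopedGuard() { /* acquire */ }
  ~ScopedGuard() { /* release */ }
};

void fixUnusedWarnings() {
  // Pattern 1: a genuinely unused local is simply deleted.
  // int count = expensiveCount();  // removed

  // Pattern 2: the variable exists only so its destructor runs at the end of
  // the block, so keep it and silence the warning with C10_UNUSED.
  C10_UNUSED ScopedGuard guard;
}
```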

The suppressed warning caused a preventable revert in https://github.com/pytorch/pytorch/pull/72633#issuecomment-1092300787.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75538

Reviewed By: anjali411

Differential Revision: D35747333

Pulled By: malfet

fbshipit-source-id: 3fc5828e44a4c05ba0e89e92613e6ebbdb260626
(cherry picked from commit c179fba21cfa2a0093fad50ccad5a22dd7cff52c)
2022-04-20 17:41:59 +00:00
PyTorch MergeBot
5c56b2286b Revert "Remove -Wno-unused-variable from utils.cmake"
This reverts commit 018cbe1f5c.

Reverted https://github.com/pytorch/pytorch/pull/75538 on behalf of https://github.com/seemethere
2022-04-19 17:19:09 +00:00
Nikita Shulga
018cbe1f5c Remove -Wno-unused-variable from utils.cmake
The [comment](https://github.com/pytorch/pytorch/pull/62445/files#r680132022) claims it was added for consistency with the top-level CMakeLists.txt, but `-Wno-unused-variable` is not mentioned there.

Fix the violations in 50+ files that were added in the interim, either by removing unused variables or by decorating the code with `C10_UNUSED` when a local variable is likely used to extend an object's lifetime until the end of the block.

The suppressed warning caused a preventable revert in https://github.com/pytorch/pytorch/pull/72633#issuecomment-1092300787.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75538
Approved by: https://github.com/cpuhrsch
2022-04-19 15:26:55 +00:00
Raghavan Raman
d8ad1a579f [nnc] Fuse loops that have variable bounds
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74346

Approved by: https://github.com/ZolotukhinM
2022-04-14 20:24:03 +00:00
Raghavan Raman
1b99996119 [nnc] Make run methods in TensorExprKernel const (#73240)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73240

Test Plan: Imported from OSS

Reviewed By: huiguoo

Differential Revision: D34399527

Pulled By: navahgar

fbshipit-source-id: 59501c6eb9a5166dbef21dbc36543862f136bfdc
(cherry picked from commit 7997f0eba269c22f64bb6b724bd5de8d4e41de8c)
2022-03-01 05:32:35 +00:00
Raghavan Raman
6d33852685 [NNC] TensorExprKernel state should not be modified on calls to run methods (#73028)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73028

A typical use case for `TensorExprKernel` is to create the kernel once and call it multiple times, possibly in parallel. For the parallel calls to work, we need to ensure that calls to the `run()` method do not change any state in `TensorExprKernel`.

Before this change, the `run()` method modified the sizes and strides vectors when dynamic shapes were present. This manifested as a data race when running a model with Static Runtime.
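
A condensed sketch of the hazard and the fix, with illustrative member names (not the actual `TensorExprKernel` internals):

```c++
#include <cstdint>
#include <vector>

class KernelSketch {
 public:
  // Before the fix, run() wrote resolved dynamic sizes into sizes_, so two
  // threads sharing one kernel raced on that member. Making run() const and
  // keeping per-call state on the stack removes the race.
  void run(const std::vector<int64_t>& dynamicSizes) const {
    std::vector<int64_t> resolved = dynamicSizes;  // thread-local copy
    // ... bind resolved sizes/strides and invoke the generated code ...
  }

 private:
  std::vector<int64_t> sizes_;  // shared state; must stay untouched in run()
};
```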
ghstack-source-id: 149398820

Test Plan:
```
buck build mode/dev-asan //caffe2/test/cpp/tensorexpr:tensorexpr
./buck-out/dev/gen/caffe2/test/cpp/tensorexpr/tensorexpr --gtest_filter="DynamicShapes.MultiThreadedExecution"
```

Reviewed By: eellison

Differential Revision: D34287960

fbshipit-source-id: d311f3c5a66c5d5de4e1deaeaa01816b53e9906e
(cherry picked from commit 161568bfae)
2022-02-17 23:14:27 +00:00
Ryan Spring
4f8b986e28 Implement Tanh Gelu Approximation (#61439)
Summary:
1. Implements https://github.com/pytorch/pytorch/issues/39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - https://github.com/pytorch/xla/pull/3039

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a9)
2022-02-14 03:40:32 +00:00
Mikhail Zolotukhin
1855b14922 [TensorExpr] Delete DimArg class. (#72390)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72390

This class didn't add much value and only caused boilerplate. This change
removes the class and updates all use sites to use `ExprHandle` instead.

A side effect of this change is different names in loop variables, which
caused massive mechanical changes in our tests.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D34030296

Pulled By: ZolotukhinM

fbshipit-source-id: 2ba4e313506a43ab129a10d99e72b638b7d40108
(cherry picked from commit c2ec46a058)
2022-02-11 01:21:59 +00:00
Raghavan Raman
ff71429906 [nnc] Add stride args while running with allocated outputs (#72223)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72223

ghstack-source-id: 148494871

Test Plan:
```
buck test mode/opt //caffe2/test/cpp/tensorexpr:tensorexpr -- --exact 'caffe2/test/cpp/tensorexpr:tensorexpr - DynamicShapes.GraphWithSymbolicStrides'
```

Reviewed By: eellison

Differential Revision: D33960592

fbshipit-source-id: 6334978d5e3713889b4ad12bcd8ed8c69df39d58
(cherry picked from commit 95cc102bc2)
2022-02-07 19:24:56 +00:00
Elias Ellison
defde3bb04 [NNC] Use index for stride mapping in kernel.cpp (#72266)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72266

Within the kernel we may manipulate `Value *`s in `OptimizeCat`, which would invalidate the input `Value *` -> stride mapping.

Fix for https://github.com/pytorch/pytorch/issues/72173

Test Plan: Imported from OSS

Reviewed By: dagitses, davidberard98

Differential Revision: D33986306

Pulled By: eellison

fbshipit-source-id: dc33cd2b545e49e90d1e46b9fcf1e6dbb4b829db
(cherry picked from commit 5e4555968a)
2022-02-04 00:12:38 +00:00
Elias Ellison
27a4d39756 NNC Dynamic Channels last fixes (#72032)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72032

This contains a few channels-last changes from benchmarking:
- Don't permute back to channels-last for dynamic shapes on CPU; perf is not good, and use cases for it are exotic at the moment.
- Remove the conditional-one handling when permuting a channels-last symbolic tensor on CUDA; it's not needed in the permutation case, as the tests show.
- Remove the logic in torch/csrc/jit/tensorexpr/loopnest.cpp that prevented inlining; the condition it checks is always valid given valid construction of the IR.

I can split this up as needed.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D33864652

Pulled By: eellison

fbshipit-source-id: f16674fb02dfff22670d8a2f856c5a317fd15717
(cherry picked from commit a9a0697839)
2022-02-01 19:07:02 +00:00
Mikhail Zolotukhin
bd6ec4efb4 [TensorExpr] Add lowerings for scalar binary ops (+,-,*,/,&,|,^,<<,>>,cmp). (#71298)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71298

Differential Revision: D33576534

Test Plan: Imported from OSS

Reviewed By: anjali411

Pulled By: ZolotukhinM

fbshipit-source-id: 93787b6f11180fcbfbacbb55e1bfb79700320a0e
(cherry picked from commit b2a8e83f97)
2022-01-26 06:32:51 +00:00
Mikhail Zolotukhin
1dbcde2ade [TensorExpr] Support scalar intermediate and output values. (#71186)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71186

So far we've only supported scalar inputs; we couldn't handle scalar outputs
or intermediates. This PR adds that support.

Scalar outputs are returned as 0-dim tensors. If the kernel is invoked on a
stack of IValues, we correctly convert the results to scalar IValues when
needed. If the kernel is invoked with a vector of void* pointers, everything
works out of the box without any conversions.

Lowerings for scalar operators are a bit tricky. Usual lowerings return a pair
<Buf, Stmt> (aka Tensor), but for scalar operators we also want the
corresponding Var that the lowering function creates (in theory we could just
use Loads and Stores, but I'm worried that could hurt performance, as there is
no guarantee LLVM would optimize them away). To work around this, we return a
fake buf plus a stmt that sets the corresponding var. Outside of the lowering
we then create a real buffer and generate a Store to it with the value of the
variable we passed as the base handle of the fake buf. This real buffer is
treated as usual by the rest of the system, and we can use it if we need to
return the scalar value as a kernel output. If we do not need to return it,
the Store is deleted by the DCE pass. The sketch below illustrates the idea.
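
In NNC IR terms, the workaround amounts to roughly the following sketch (the names and exact construction are illustrative, not the actual kernel.cpp code):

```c++
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>

using namespace torch::jit::tensorexpr;

void scalarOutputSketch() {
  // var holds the scalar value the lowering produced.
  VarHandle var("scalar_out", kFloat);
  // Wrap it in a real 0-dim buffer by storing the var into it, so the rest
  // of the pipeline can treat the scalar like any other output Buf. If the
  // value is never returned, DCE deletes this Store.
  BufHandle buf("scalar_buf", /*dims=*/{}, kFloat);
  StmtPtr store = Store::make(buf, /*indices=*/{}, var);
  (void)store;
}
```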

Differential Revision: D33539324

Test Plan: Imported from OSS

Reviewed By: navahgar

Pulled By: ZolotukhinM

fbshipit-source-id: ab4524b9820ce204f106effcf6232ed33d4ee223
(cherry picked from commit 7faa0939f0)
2022-01-26 06:32:51 +00:00
CodemodService FBSourceClangFormatLinterBot
60632a00fe [AutoAccept][Codemod][FBSourceClangFormatLinter] Daily arc lint --take CLANGFORMAT
Reviewed By: zertosh

Differential Revision: D33561057

fbshipit-source-id: 79873717c45c8bbe6d0ae760e718770fd960185d
2022-01-13 03:27:06 -08:00
Elias Ellison
5480deb183 Add support for permuting dynamic fusion group outputs to channels-last format (#70656)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70656

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D33458650

Pulled By: eellison

fbshipit-source-id: f0c7d20743deac7a87f7c9176e60da8100aefe41
2022-01-12 09:11:34 -08:00
Elias Ellison
39be20f259 [JIT][NNC] Add handling of strides to dynamic shape support. (#70464)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70464

Add handling of strided input tensors to dynamic fusion. This is done with the same set of input striding specializations as https://github.com/pytorch/pytorch/pull/60684/:
```
  S_ONE, // STRIDE_ONE: packed
  S_CONT, // STRIDE_CONTIGUOUS: stride[i + 1] * sizes[i + 1]
  S_TRAN_CONT, // STRIDE_TRANSPOSED_CONTIGUOUS: stride[i-1] * sizes[i-1]
  S_AS_ARG, // STRIDE_AS_ARG: stride passed in as runtime value
```
and then two additional specializations for (a) a contiguous tensor and (b) a channels-last tensor. Channels-last is a common case and we should optimize for it. Additionally, tensors natively store whether they are contiguous/channels-last contiguous, which makes it faster to check whether tensors follow this pattern.

Output striding will be done in a follow-up.

The striding is stored on both the TensorGroup node and the guard node. The striding descriptors are stored as a vector of strings on the node for debuggability and to make use of storing ivalues as attributes on nodes.

As an example:

```
%8 : Double(10, 11, 12, 13, strides=[1716, 1, 143, 11], requires_grad=0, device=cpu) = prim::TensorExprGroup_0[symbolic_shape_inputs=[-37, -36, -35, -34], striding_inputs_desc=[["TENSOR_CONT_CHANNELS_LAST"]]](%x, %24, %23, %22, %21)
```

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D33458649

Pulled By: eellison

fbshipit-source-id: c42616d3c683d70f6258180d23d3841a31a6030d
2022-01-12 09:11:31 -08:00
Mikhail Zolotukhin
8223ef1cd8 [TensorExpr] Clean-up logic for copying input tensors and remove some dead code. (#70535)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70535

This also fixes handling of inputs that happen to be outputs (they
require copy).

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D33399116

Pulled By: ZolotukhinM

fbshipit-source-id: 9845838eb653b82ae47b527631b51893990d5319
2022-01-07 01:03:56 -08:00
Animesh Jain
6896b2d734 [NNC Testing] Randomized loop nest infrastructure (#70410)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70410

Trying again after #70174 was reverted. Previously, the env variable was read
into a static var in C++, causing state to be retained across tests and
leading to test failures. The static qualifier is removed in this PR.

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D33321435

fbshipit-source-id: 6d108eb00cac9150a142ccc3c9a65a1867dd7de4
2022-01-06 16:21:42 -08:00
Mikhail Zolotukhin
0ee663d2fa Revert D33234529: [NNC Testing] Randomized loop nest infrastructure
Test Plan: revert-hammer

Differential Revision:
D33234529 (1d094587ea)

Original commit changeset: 9019f1f1d4ca

Original Phabricator Diff: D33234529 (1d094587ea)

fbshipit-source-id: a79deca9f186299bf884587eb7d50af2464979fb
2021-12-23 23:11:23 -08:00
Animesh Jain
1d094587ea [NNC Testing] Randomized loop nest infrastructure (#70174)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70174

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D33234529

fbshipit-source-id: 9019f1f1d4ca945c92bee401f7ec674b7d987de4
2021-12-22 22:07:39 -08:00
Raghavan Raman
4dec15e6d8 [nnc] Add a run method to TensorExprKernel that takes in output tensors (#69477)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69477

This diff adds a new run method to `TensorExprKernel` that takes output
tensors as inputs and stores the results in those tensors.
ghstack-source-id: 146107009

Test Plan: buck test mode/dev-nosan //caffe2/test/cpp/tensorexpr:tensorexpr -- --exact 'caffe2/test/cpp/tensorexpr:tensorexpr - Kernel.RunWithAllocatedOutputs'

Reviewed By: ZolotukhinM

Differential Revision: D32823890

fbshipit-source-id: edc1f4839785124048b034060feb71cb8c1be34f
2021-12-22 00:30:15 -08:00
Hui Guo
ac92f7cc75 [tensorexpr] Remove the optional argument in LoopNest::prepareForCodeGen (#67144)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67144

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D31881150

Pulled By: huiguoo

fbshipit-source-id: af99087722ec71d6deb9049b63b573ae7720c9ec
2021-12-17 01:37:59 -08:00
Hui Guo
531b045446 [tensorexpr] Fix the buf size of discontiguous tensors (#69657)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69657

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D32974473

Pulled By: huiguoo

fbshipit-source-id: 52dcd13d0ad7f7e4f1beb69dcaabc8ceb386ffca
2021-12-10 01:26:37 -08:00
Mikhail Zolotukhin
1e9dcdd2a0 [TensorExpr] TensorExprKernel: support custom-class constants. (#68856)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68856

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D32632907

Pulled By: ZolotukhinM

fbshipit-source-id: e4180f8d791ba0cdf82bcb3bd11b61405c2faadd
2021-12-02 14:34:15 -08:00
Mikhail Zolotukhin
ec94bb787a [TensorExpr] Add a way to define target triple/cpu/attrs for llvm codegen and turn on the AOT workflow. (#66527)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66527

Differential Revision: D31593869

Test Plan: Imported from OSS

Reviewed By: navahgar

Pulled By: ZolotukhinM

fbshipit-source-id: e7534c11fbcf0dab5f49d01d6053caf77b833ef0
2021-11-13 00:52:20 -08:00
Mikhail Zolotukhin
e511a7a5b4 [TensorExpr] Remove non-determinism in iterating over unordered_set of intermediate buffers. (#68277)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68277

Differential Revision: D32400553

Test Plan: Imported from OSS

Reviewed By: saketh-are, priyaramani

Pulled By: ZolotukhinM

fbshipit-source-id: a8fe820bbddaa19f95db432efaa6d3e36095a05e
2021-11-13 00:50:57 -08:00
Ivan Kobzarev
362c6069b9 [nnc] Lazy lowerings registration; custom classes network params (#67623)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67623

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D32065076

Pulled By: IvanKobzarev

fbshipit-source-id: 4945ac6483938d428c539ed1ce4fcd6988b34250
2021-11-11 09:00:23 -08:00
Raghavan Raman
e7a3bbce89 [nnc] Add support for dynamic shapes in TensorExprKernel (#67861)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67861

Previously submitted as https://github.com/pytorch/pytorch/pull/67197.
This got reverted because its failures were hidden by the failures of
another PR.

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D32178196

Pulled By: navahgar

fbshipit-source-id: cc8a5c68aed360d06289e69645461cfa773e1300
2021-11-05 11:18:19 -07:00
Natalia Gimelshein
ca445645f9 Revert D31902471: [nnc] Add support for dynamic shapes in TensorExprKernel
Test Plan: revert-hammer

Differential Revision:
D31902471 (15a3c374e2)

Original commit changeset: d2729a38ba1a

fbshipit-source-id: 4c05de82e626bbf744df84fd2b914b66fd165a19
2021-11-03 14:48:12 -07:00
Raghavan Raman
15a3c374e2 [nnc] Add support for dynamic shapes in TensorExprKernel (#67197)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67197

Test Plan: Imported from OSS

Reviewed By: eellison, ZolotukhinM

Differential Revision: D31902471

Pulled By: navahgar

fbshipit-source-id: d2729a38ba1ac607ff07f516ed56fbd9085715dc
2021-11-03 11:24:17 -07:00
Ivan Kobzarev
7fbcf79684 [tensorexpr][nnc] Support quantization (#66676)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66676

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31676329

Pulled By: IvanKobzarev

fbshipit-source-id: 288b41ff4ed603dfaacb465f296997f14bb23c22
2021-10-31 22:49:30 -07:00
Priya Ramani
fa70d72e95 Set kernel func name from AOT Compiler (#67229)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67229

Right now, the assembly code generated for a given method from the model is named `wrapper` or `func` by default. The function name is then replaced with the proper kernel_func_name after target-specific assembly is generated.
This PR propagates the desired kernel_func_name from the aotCompiler API so that the generated function gets the needed name and does not have to be renamed later.

Note: most of this change landed in https://github.com/pytorch/pytorch/pull/66337, which had to be reverted because it broke `test_profiler` in `test_jit_fuser_te` by replacing the name generated for the graph with the default kernel_func_name value. This PR fixes that as well.

```
(pytorch)  ~/local/pytorch kname
└─ $ python3 test/test_jit_fuser_te.py
CUDA not available, skipping tests
monkeytype is not installed. Skipping tests for Profile-Directed Typing
........................................<string>:3: UserWarning: torch.cholesky is deprecated in favor of torch.linalg.cholesky and will be removed in a future PyTorch release.
L = torch.cholesky(A)
should be replaced with
L = torch.linalg.cholesky(A)
and
.
.
.
......................<string>:3: UserWarning: torch.symeig is deprecated in favor of torch.linalg.eigh and will be removed in a future PyTorch release.
The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion.
L, _ = torch.symeig(A, upper=upper)
should be replaced with
L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
and
L, V = torch.symeig(A, eigenvectors=True)
should be replaced with
L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2492.)
......[W pybind_utils.cpp:35] Warning: Using sparse tensors in TorchScript is experimental. Many optimization pathways have not been thoroughly tested with sparse tensors. Please include the fact that the network is running sparse tensors in any bug reports submitted. (function operator())
/data/users/priyaramani/pytorch/torch/testing/_internal/common_utils.py:403: UserWarning: Using sparse tensors in TorchScript is experimental. Many optimization pathways have not been thoroughly tested with sparse tensors. Please include the fact that the network is running sparse tensors in any bug reports submitted. (Triggered internally at  ../torch/csrc/jit/python/pybind_utils.h:691.)
  return callable(*args, **kwargs)
.....................................................................[W Resize.cpp:23] Warning: An output with one or more elements was resized since it had shape [1], which does not match the required output shape [].This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (function resize_output_check)
[W Resize.cpp:23] Warning: An output with one or more elements was resized since it had shape [1, 5], which does not match the required output shape [5].This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (function resize_output_check)
........................................................................s.......s...s.s....s......s..sss............................
----------------------------------------------------------------------
Ran 503 tests in 37.536s

OK (skipped=10)
```

Test Plan: Imported from OSS

Reviewed By: navahgar, pbelevich

Differential Revision: D31945713

Pulled By: priyaramani

fbshipit-source-id: f2246946f0fd51afba5cb6186d9743051e3b096b
2021-10-27 13:10:49 -07:00
Natalia Gimelshein
b6fa998892 Revert D31514095: Use kernel_func_name from aotCompiler
Test Plan: revert-hammer

Differential Revision:
D31514095 (7b55dc8340)

Original commit changeset: b70c8e2c7336

fbshipit-source-id: ad4d828f33506e612b51c276149fa0e12b0565d5
2021-10-23 17:17:53 -07:00
Priya Ramani
7b55dc8340 Use kernel_func_name from aotCompiler (#66337)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66337

Right now, the assembly code generated for a given method from the model is named `wrapper` or `func` by default. The function name is then replaced with the proper kernel_func_name after target-specific assembly is generated.
This PR propagates the desired kernel_func_name from the aotCompiler API so that the generated function gets the needed name and does not have to be renamed later.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31514095

Pulled By: priyaramani

fbshipit-source-id: b70c8e2c733600a435cd4e8b32092d37b7bf7de5
2021-10-23 02:20:45 -07:00
Mikhail Zolotukhin
60a2a295ce [TensorExpr] Use schema instead of op name in NNC lowerings. (#65843)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65843

Fixes #64963.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31282334

Pulled By: ZolotukhinM

fbshipit-source-id: ffd0e1b6433d9360fedd9081c01ef41b21684439
2021-10-12 01:26:32 -07:00
Mikhail Zolotukhin
24b9b304d9 [TensorExpr] Nuke TE shape inference. (#65554)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65554

We're relying on JIT-based shape inference and are not using the TE
implementation.

Question to the audience: we set `hasBroadcasts_` in that function, but the
function was almost never invoked. Do we behave correctly in the presence of
rand calls and broadcasts?

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D31148925

Pulled By: ZolotukhinM

fbshipit-source-id: 2898a57e389ea0950163122089d0fec3d92701c4
2021-10-12 01:25:14 -07:00
Scott Wolchok
2d885ab73d [jit] Reduce refcounting of Types (#65345)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65345

FooType::get() can return a const reference. Inconveniently, converting shared_ptr<FooType> to shared_ptr<Type> requires a copy & refcount bump, so to properly take advantage of this in unshapedType() we need to take a const Type& in isSubtypeOf(), which is good practice anyway -- don't require a shared_ptr if you don't need to take ownership.
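
The API shape in question, reduced to a sketch (illustrative signatures, not the exact jit::Type declarations):

```c++
#include <memory>

struct Type {
  // Preferred: inspection only, so callers pass what they have with no
  // ownership transfer and no refcount traffic.
  bool isSubtypeOf(const Type& rhs) const;

  // Costlier: forces callers holding a shared_ptr<FooType> to copy-convert
  // it to shared_ptr<Type>, bumping the refcount on every call.
  bool isSubtypeOfShared(const std::shared_ptr<Type>& rhs) const;
};
```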
ghstack-source-id: 140044165

Test Plan:
CI

perf says c10::unshapedType time decreased from 2.8% to 2.2% during static runtime startup, though I expect this to be generally beneficial.

Reviewed By: hlu1

Differential Revision: D31027361

fbshipit-source-id: 676feb81db9f74ad7b8651d8774f4ecb4cfa6ab8
2021-10-08 09:03:04 -07:00
Mikhail Zolotukhin
765b6a90f3 [TensorExpr] Move lowerings registration from kernel.cpp to lowerings.cpp. (#65553)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65553

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31148921

Pulled By: ZolotukhinM

fbshipit-source-id: 772062155043d4be9e9a25f6259b8e4a6cb762f4
2021-09-30 22:56:22 -07:00
Mikhail Zolotukhin
015e0079e3 [TensorExpr] Move 'compute*' functions to operators/... (#65552)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65552

This PR is mostly a verbatim move of several functions to different
files. The goal is to have more consistency in what resides where.

With this PR:
* All `compute*` functions defining how a given operator needs to be
lowered to TE IR will reside in `operators/*.{cpp,h}`.
* Auxiliary functions for these functions will reside in
`operators/misc.cpp`. `compute*` functions for ops not belonging
anywhere else can also go to that file.
* `operators/unary.*` is renamed to `operators/pointwise.*` and now
includes functions like `computeTwoOperands`.
* `kernel.*` now contains *only JIT-related* logic and implementations of
`TensorExprKernel` methods.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31148923

Pulled By: ZolotukhinM

fbshipit-source-id: e36ad8e779b8d30a33b49ea4ebf6d6a7438989f4
2021-09-30 22:56:20 -07:00