Commit Graph

157 Commits

Author SHA1 Message Date
lezcano
f7b9a46880 Deprecate torch.lu
**BC-breaking note**:

This PR deprecates `torch.lu` in favor of `torch.linalg.lu_factor`.
An upgrade guide is added to the documentation for `torch.lu`.

Note this PR DOES NOT remove `torch.lu`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77636

Approved by: https://github.com/malfet
2022-06-07 22:50:14 +00:00
Akshit Khurana
bb3e1f30a8 [Pytorch NNAPI] Add compilation_preference & relax_f32_to_f16 APIs (#78758)
Summary:
compilation_preference is one of:

ANEURALNETWORKS_PREFER_LOW_POWER = 0
ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2

relax_f32_to_f16 calls Model_relaxComputationFloat32toFloat16
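
For illustration, the three preference values listed above can be mirrored in a small Python enum. This is only a sketch: the class name `CompilationPreference` and the helper `parse_preference` are hypothetical and are not taken from the PR. Only the constant names and their integer values come from the summary.

```python
from enum import IntEnum


class CompilationPreference(IntEnum):
    """Hypothetical mirror of the NNAPI preference constants listed above."""

    ANEURALNETWORKS_PREFER_LOW_POWER = 0
    ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
    ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2


def parse_preference(value: int) -> CompilationPreference:
    """Validate an integer preference; raises ValueError for unknown values."""
    return CompilationPreference(value)
```

Validating the integer up front means an out-of-range value fails early instead of being passed on to the device.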

Test Plan:
Tested on device with nnapi models

* Works with existing exported models
* Works with new exported models with options

Differential Revision: D36433236

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78758
Approved by: https://github.com/kimishpatel
2022-06-06 20:57:34 +00:00
Max Ren
93d5a722b1 [coreml] Introducing Quantization (#78108)
Summary: Adding Quantization mode to preprocess, which allows us to run through quantization for coreml models

Test Plan:
https://fburl.com/anp/r0ntsbq0

Notebook running through the quantization workflow:

Created a custom bento kernel to run it through coreml

```
bento_kernel(
    name = "coreml",
    deps = [
        "fbsource//third-party/pypi/coremltools:coremltools",
        "//caffe2:coreml_backend",
        "//caffe2:coreml_backend_cpp",
        "//caffe2:torch",
        "//caffe2/torch/fb/mobile/model_exporter:model_exporter",
    ],
)
```

Initial benchmarks on iPhone 11:

FP32 Core ML Model:
https://our.intern.facebook.com/intern/aibench/details/203998485252700

Quantized Core ML Model:
https://our.intern.facebook.com/intern/aibench/details/927584023592505

High End Quantized Model:
https://our.intern.facebook.com/intern/aibench/details/396271714697929

Summarized Results
| Backend | Quantization | p50 net latency | Model Size |
|---------|--------------|-----------------|------------|
| Core ML | No           | 1.2200          | 1.2mb      |
| Core ML | Yes          | 1.2135          | 385kb      |
| CPU     | Yes          | 3.1720          | 426kb      |

Reviewed By: SS-JIA

Differential Revision: D36559966

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78108
Approved by: https://github.com/jmdetloff
2022-06-01 17:10:17 +00:00
PyTorch MergeBot
b994ce359e Revert "[cuDNN V8 API] (reopen) Allow the number of kernels profiled under torch.backends.cudnn.benchmark = True to be limitedCudnnv8 benchmark limit (#77002)"
This reverts commit c274f2ad52.

Reverted https://github.com/pytorch/pytorch/pull/77002 on behalf of https://github.com/malfet due to please, as it breaks internal CI, but also no CUDA headers should be included from `torch/csrc/Module.cpp`, but rather should be implemented/registered in `torch/csrc/cuda/Module.cpp`
2022-05-24 21:52:35 +00:00
Nikita Shulga
6244daa6a9 [MPS] Fix torch.mps.is_available() (#78121)
By introducing `at::mps::is_available()` and changing `torch._C._is_mps_available` from property to memoizable callable

Also, if `_mtl_device` is released in MPSDevice destructor, shouldn't it be retained in the constructor?

Looks like GitHubActions Mac runner does not have any Metal devices available, according to https://github.com/malfet/deleteme/runs/6560871657?check_suite_focus=true#step:3:15

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78121
Approved by: https://github.com/albanD
2022-05-24 05:10:38 +00:00
Eddie Yan
c274f2ad52 [cuDNN V8 API] (reopen) Allow the number of kernels profiled under torch.backends.cudnn.benchmark = True to be limitedCudnnv8 benchmark limit (#77002)
(reopening due to botched merge)
The cuDNN V8 API (main support merged in https://github.com/pytorch/pytorch/pull/60755) potentially exposes many more kernels with benchmark=True. While these additional kernels can improve performance, it is often unnecessary to run every kernel returned by the heuristic and doing so may degrade the user experience by causing the first model iteration to be very slow. To alleviate this issue, this PR introduces torch.backends.cudnn.benchmark_limit. benchmark_limit specifies the maximum number of working cuDNN kernels to try for a given workload, with the default being 10 (similar to what TensorFlow does). benchmark_limit = 0 yields the current behavior of trying every kernel returned by the heuristic.
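
The selection rule described above can be sketched in plain Python. This is a toy model, not cuDNN or PyTorch code: `pick_kernel`, the candidate list, and the timing callback are all hypothetical. Only the semantics (time at most `benchmark_limit` kernels, with 0 meaning every kernel, and a default of 10) come from the description. The sketch also assumes every candidate runs successfully, whereas the real limit counts working kernels.

```python
from typing import Callable, Sequence


def pick_kernel(
    candidates: Sequence[str],
    time_kernel: Callable[[str], float],
    benchmark_limit: int = 10,
) -> str:
    """Return the fastest kernel among the first `benchmark_limit` candidates.

    A limit of 0 means every candidate is timed (the pre-PR behavior).
    """
    if benchmark_limit < 0:
        raise ValueError("benchmark_limit must be non-negative")
    if not candidates:
        raise ValueError("no candidate kernels")
    # Candidates are assumed to arrive in heuristic order, best guess first.
    tried = candidates if benchmark_limit == 0 else candidates[:benchmark_limit]
    return min(tried, key=time_kernel)
```

With a small limit, the first iteration is cheaper, at the risk of missing a faster kernel that the heuristic ranked lower.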

CC @ptrblck @ngimel @xwang233
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77002
Approved by: https://github.com/ngimel
2022-05-24 00:11:47 +00:00
Kulin Seth
f348b1b2b5 Add the Runtime components for MPS backend. (#76725)
The PR adds the runtime components and a few basic operations like copy, as_strided for MPS backend.

The current list of identified TODOs is:

-  https://github.com/pytorch/pytorch/issues/77176
- Unify the logic with CUDACachingAllocator and remove redundant code.
-  https://github.com/pytorch/pytorch/issues/77170
- Look into using C++ smart pointers where possible with ObjC code
- Use empty_strided_generic() to implement the `empty_strided_mps` code
- https://github.com/pytorch/pytorch/issues/77144
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76725
Approved by: https://github.com/albanD
2022-05-11 17:19:45 +00:00
PyTorch MergeBot
1467e0dd5d Revert "Deprecate torch.lu"
This reverts commit a5bbfd94fb.

Reverted https://github.com/pytorch/pytorch/pull/73804 on behalf of https://github.com/malfet
2022-05-09 19:06:44 +00:00
lezcano
a5bbfd94fb Deprecate torch.lu
**BC-breaking note**:

This PR deprecates `torch.lu` in favor of `torch.linalg.lu_factor`.
An upgrade guide is added to the documentation for `torch.lu`.

Note this PR DOES NOT remove `torch.lu`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73804

Approved by: https://github.com/IvanYashchuk, https://github.com/mruberry
2022-05-05 19:17:11 +00:00
Kurt Mohler
5375b2e994 Resolve int[]? arguments to new OptionalIntArrayRef class
This PR uses the `OptionalArrayRef` template class that was drafted in #64084.

Fixes #44409
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70864
Approved by: https://github.com/ezyang
2022-03-26 01:45:50 +00:00
Tao Xu
06ff4f570c [Core ML] Support enumerated input shapes (#74441)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74441

For xirp based segmentation models, we want to support enumerated input shapes. This allows us to support both landscape and portrait mode images without sacrificing the performance. P488118264
ghstack-source-id: 151736964

Test Plan: `buck run coreml:xirp -- --model="/home/taox/xirp/xirp_20a.pt" --out="/home/taox/xirp/xirp_20a_coreml_enumerated.ptl"`

Reviewed By: mcr229

Differential Revision: D34803184

fbshipit-source-id: c462c0783846a1489ca7ce4d5a654aa6927c9c44
(cherry picked from commit 67d418c97531daaf3d03d1000ca4a4ff60de2a95)
2022-03-21 21:32:24 +00:00
Weiwen Xia
060f1b822a Add onednn quant backend (#74137)
Summary:
Resolve the conflicts in https://github.com/pytorch/pytorch/pull/69820
jerryzh168 Please review. Thanks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74137

Reviewed By: samdow

Differential Revision: D34840477

Pulled By: jerryzh168

fbshipit-source-id: 8aa60981ff7be211a1609644f273b16d18efd425
(cherry picked from commit de76bb808b315e9a2e45d8c5f1c1233a47d669c4)
2022-03-15 01:28:21 +00:00
Jerry Zhang
5a897536f3 Revert D33716039: [pytorch][PR] Add ONEDNN quantization backend
Test Plan: revert-hammer

Differential Revision:
D33716039 (989b24855e)

Original commit changeset: 6f7bb807e857

Original Phabricator Diff: D33716039 (989b24855e)

fbshipit-source-id: ed233c5b99d4edb7d5a9d6c600825c78555f16d0
(cherry picked from commit d3e1f825b06ef67adb13623ccb7cbf1b700c1dd5)
2022-03-11 22:06:25 +00:00
Xia Weiwen
989b24855e Add ONEDNN quantization backend (#69820)
Summary:
This PR adds a new quantization backend, ONEDNN, with quantized conv and linear kernels in the same code path as the FBGEMM backend

The ONEDNN backend is an alternative to the FBGEMM and QNNPACK backends. It takes advantage of features of the latest Intel® CPU products. It supports VNNI on Cascade Lake and the AMX instruction set, which will be available on Sapphire Rapids and has 8X the int8 peak TOPS of VNNI.

ONEDNN demonstrates better performance on conv kernels of popular CNN models than FBGEMM. It also supports more fused ops, such as convolution-add-ReLU, than FBGEMM and QNNPACK.
To use this backend, users only need to set the quantization backend to 'onednn' before any calculation without a single change to models.
```python
torch.backends.quantized.engine = 'onednn'
```

## Design docs
https://github.com/pytorch/pytorch/issues/21120#issuecomment-562371983
https://github.com/pytorch/pytorch/pull/67177#issuecomment-963787096

## File changes
**Add ONEDNN to qengine list**
- aten/src/ATen/Context.cpp
- c10/core/QEngine.h
- torch/ao/quantization/qconfig.py
- torch/backends/quantized/\_\_init\_\_.py

**Implement qconv & qlinear for ONEDNN backend**
- aten/src/ATen/native/quantized/cpu/conv_serialization.h
- aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
- aten/src/ATen/native/quantized/cpu/onednn_utils.h
- aten/src/ATen/native/quantized/cpu/qconv.cpp
- aten/src/ATen/native/quantized/cpu/qconv_dynamic.cpp
- aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
- aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp
- aten/src/ATen/native/quantized/cpu/qlinear.cpp
- aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
- aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
- aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp

**Skip tests that are not supported by ONEDNN**
- test/ao/sparsity/test_kernels.py
- test/quantization/core/test_quantized_module.py
- test/quantization/core/test_quantized_op.py

## Validation results
This PR has passed `test_quantization.py` and `test_mkldnn.py`.
Below are performance data of int8 2d convolution and linear on the Cascade Lake Xeon® platform:
(Note: Tested with single instance on single core. Using the latest oneDNN library.)

**Table 1. Performance comparison of int8 2d convolution operator**
|No.|	Shape|	FBGEMM|	ONEDNN|	Gain|
|-|-|-|-|-|
|1|	IC=128, OC=128, kernel=3, stride=1, N=4, H=32, W=32, G=1, pad=0|	668.310us|	535.630us|	24.8%|
|2|	IC=128, OC=128, kernel=3, stride=2, N=4, H=32, W=32, G=1, pad=0|	290.630us|	281.810us|	3.1%|
|3|	IC=128, OC=256, kernel=3, stride=1, N=4, H=32, W=32, G=1, pad=0|	1.045ms|	893.010us|	17.0%|
|4|	IC=128, OC=256, kernel=3, stride=2, N=4, H=32, W=32, G=1, pad=0|	385.320us|	373.720us|	3.1%|
|5|	IC=256, OC=256, kernel=3, stride=1, N=4, H=32, W=32, G=1, pad=0|	1.876ms|	1.641ms|	14.3%|
|6|	IC=256, OC=256, kernel=3, stride=2, N=4, H=32, W=32, G=1, pad=0|	660.460us|	638.470us|	3.4%|

**Table 2. Performance comparison of int8 linear operator**
|No.|	Shape (m, n, k)|	FBGEMM|	ONEDNN|	Gap|
|-|-|-|-|-|
|1|	64, 800, 320|	80.550us|	96.770us|	20.10%|
|2|	64, 768, 512|	101.230us|	130.720us|	29.10%|
|3|	16, 256, 512|	30.230us|	51.450us|	70.20%|
|4|	128, 128, 128|	33.810us|	50.480us|	49.30%|
|5|	256, 512, 256|	154.490us|	195.050us|	26.30%|
|6|	1024, 1024, 1024|	3.134ms|	3.514ms|	12.10%|

ONEDNN showed advantages over FBGEMM for convolution. However, it has a performance gap relative to FBGEMM for linear ops. The gap is a known issue and further optimization is in progress in the oneDNN library. On the latest platforms, ONEDNN achieves better performance for both conv and linear.
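
The Gain and Gap columns in the two tables appear consistent with a simple ratio of the two timings. The formula below is inferred from the tabulated numbers and is not stated in the PR, so treat it as an assumption; the function name is hypothetical.

```python
def relative_difference_percent(slower_us: float, faster_us: float) -> float:
    """Percentage by which `slower_us` exceeds `faster_us` (both in microseconds)."""
    return (slower_us / faster_us - 1.0) * 100.0


# Table 1, row 1: FBGEMM 668.310us vs. ONEDNN 535.630us (ONEDNN is faster).
conv_gain = relative_difference_percent(668.310, 535.630)

# Table 2, row 1: ONEDNN 96.770us vs. FBGEMM 80.550us (FBGEMM is faster).
linear_gap = relative_difference_percent(96.770, 80.550)

print(f"conv gain: {conv_gain:.1f}%, linear gap: {linear_gap:.1f}%")
```

Under this reading, "Gain" measures how much slower FBGEMM is than ONEDNN for convolution, and "Gap" measures how much slower ONEDNN is than FBGEMM for linear.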

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69820

Reviewed By: HDCharles

Differential Revision: D33716039

Pulled By: jerryzh168

fbshipit-source-id: 6f7bb807e85798142dfcffccfca8b8bd652fb3dd
(cherry picked from commit 91526b373560f42ba0ad307f9cccfc0eb5218b1f)
2022-03-11 20:31:49 +00:00
lkct
7d542a4f2b Fix type annotation for torch.backends.cudnn.allow_tf32 (#72757)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/72753

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72757

Reviewed By: samdow

Differential Revision: D34204436

Pulled By: ngimel

fbshipit-source-id: 3528efd7bdf72c1d9338806555ecb643ab94ffeb
(cherry picked from commit 7036c2e6e6)
2022-02-14 17:26:37 +00:00
Akshit Khurana
a70297e7cb NNAPI: quant logistic fix (#70847)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70847

NNAPI needs a fixed zero point and scale for sigmoid (logistic)
ghstack-source-id: 146555935

Test Plan: LIBNEURALNETWORKS_PATH="/path/to/libneuralnetworks.so" pytest test/test_nnapi.py

Reviewed By: dreiss

Differential Revision: D33237918

fbshipit-source-id: 05ef3a81bf1589ad44b599a19bce4066531c432b
2022-01-07 13:36:33 -08:00
Akshit Khurana
44283c2766 NNAPI: Add qint16 support via int16 (#70621)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70621

PyTorch doesn't have support for qint16 yet. Add an option to handle qint16 via int16 & qint32 data types.

* For qint16 tensors in NNAPI, the user sends a qint32 tensor. We convert the qint32 to int16 for the converter and set the zero point and scale for nnapi
    * inputs to the model have to have fixed scale and zero point and are only supported for testing
* Added a flag use_int16_for_qint16 which will be used to maintain backwards compatibility in the converter when true qint16 is supported in PyTorch
ghstack-source-id: 146507483

Test Plan: pytest test/test_nnapi.py

Reviewed By: dreiss

Differential Revision: D33285124

fbshipit-source-id: b6376fa1bb18a0b9f6a18c545f600222b650cb66
2022-01-04 23:12:38 -08:00
Akshit Khurana
1150046d29 NNAPI: Add runtime flexible shapes & return shapes (#70334)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70334

* Use 0 for load time flexible shapes
* -1 for runtime flexible shapes
* NNAPI needs return shapes for flexible outputs
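
As a rough illustration of the convention above, the sentinel values can be decoded per dimension. The helper `classify_dim` and the constant names are hypothetical; only the meaning of 0 and -1 comes from the summary.

```python
from typing import List, Sequence

LOAD_TIME_FLEXIBLE = 0   # size is resolved when the model is loaded
RUNTIME_FLEXIBLE = -1    # size is resolved on each execution


def classify_dim(size: int) -> str:
    """Describe how a single dimension size is interpreted."""
    if size == LOAD_TIME_FLEXIBLE:
        return "load-time flexible"
    if size == RUNTIME_FLEXIBLE:
        return "runtime flexible"
    if size > 0:
        return "fixed"
    raise ValueError(f"unsupported dimension size: {size}")


def classify_shape(shape: Sequence[int]) -> List[str]:
    """Classify every dimension of a shape."""
    return [classify_dim(size) for size in shape]
```

For example, a shape of `(1, 3, 0, -1)` would have two fixed dimensions, one resolved at load time, and one resolved at run time.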

Test Plan: Tested via upcoming ops

Reviewed By: dreiss

Differential Revision: D33237922

fbshipit-source-id: 50afdd8e3c6401dfb79b4bc09513c9882a09e5d5
2022-01-04 08:37:09 -08:00
Akshit Khurana
d9106116aa nnapi: Add int32 type torchscript expressions (#70197)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70197

Test Plan:
* `pytest test/test_nnapi.py`
* Testing via ops following this commit

Reviewed By: anshuljain1, dreiss

Differential Revision: D33237917

fbshipit-source-id: f0493620f28a62ad9fe0b97b67d1e25059d50c24
2022-01-03 19:00:38 -08:00
Xiao Wang
bfe5ad28e6 [Linalg] Add a runtime switch to let pytorch prefer a backend impl in linalg functions on GPU (#67980)
Summary:
Per title.

This PR introduces a global flag that lets pytorch prefer one of the many backend implementations while calling linear algebra functions on GPU.

Usage:
```python
torch.backends.cuda.preferred_linalg_library('cusolver')
```

Available options (str): `'default'`, `'cusolver'`, `'magma'`.

Issue https://github.com/pytorch/pytorch/issues/63992 inspired me to write this PR. No heuristic is perfect on all devices, library versions, matrix shapes, workloads, etc. We can obtain better performance if we can conveniently switch linear algebra backends at runtime.

Performance of linear algebra operators after this PR should be no worse than before. The flag is set to **`'default'`** by default, which makes everything the same as before this PR.

The implementation of this PR is basically following that of https://github.com/pytorch/pytorch/pull/67790.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67980

Reviewed By: mruberry

Differential Revision: D32849457

Pulled By: ngimel

fbshipit-source-id: 679fee7744a03af057995aef06316306073010a6
2021-12-03 19:06:30 -08:00
eqy
790763b0fe Add an option to disable reduced precision reductions for FP16 GEMM (#67946)
Summary:
https://github.com/pytorch/pytorch/issues/67578 disabled reduced precision reductions for FP16 GEMMs. After benchmarking, we've found that this has substantial performance impacts for common GEMM shapes (e.g., those found in popular instantiations of multiheaded-attention) on architectures such as Volta. As these performance regressions may come as a surprise to current users, this PR adds a toggle to disable reduced precision reductions
`torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = `
rather than making it the default behavior.

CC ngimel ptrblck
stas00 Note that the behavior after the previous PR can be replicated with
`torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67946

Reviewed By: zou3519

Differential Revision: D32289896

Pulled By: ngimel

fbshipit-source-id: a1ea2918b77e27a7d9b391e030417802a0174abe
2021-11-09 17:27:20 -08:00
Akshit Khurana
1de8976e85 Add quantized::convtranspose2d (#63914)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63914

Test Plan: Imported from OSS

Reviewed By: dreiss

Differential Revision: D30531889

fbshipit-source-id: a65e389da2722efbc62e3fe1edf503732326350d
2021-09-24 17:07:29 -07:00
Akshit Khurana
ab5eb56983 add qmul (#63913)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63913

Test Plan: Imported from OSS

Reviewed By: dreiss

Differential Revision: D30531890

fbshipit-source-id: 29d88cc61bd1e328cc7ae7a91a2f8d4819803c8d
2021-09-24 17:06:17 -07:00
Tao Xu
7dc3858deb [CoreML][fbcode] Add the preprocess python APIs (#64521)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64521

Add the preprocess part for the coreml delegate. Check out the `example.py` for the usage.
ghstack-source-id: 138324214

Test Plan:
```
(base) [taox@devvm2780.vll0 ~/fbsource/fbcode/caffe2/fb]  buck run coreml:example -- --model="/home/taox/mobilenetv2/mobilenetv2.pt" --out="/home/taox/mobilenetv2/mobilenetv2_coreml.pt"
Parsing buck files: finished in 0.5 sec
Downloaded 0/1 artifacts, 0.00 bytes, 100.0% cache miss (for updated rules)
Building: finished in 10.6 sec (100%) 12611/57623 jobs, 1/57623 updated
  Total time: 11.1 sec
Converting Frontend ==> MIL Ops: 100%|██████████████████████████████████████████▉| 382/383 [00:00<00:00, 692.58 ops/s]
Running MIL optimization passes: 100%|███████████████████████████████████████████| 18/18 [00:00<00:00, 45.55 passes/s]
Translating MIL ==> MLModel Ops: 100%|███████████████████████████████████████████| 704/704 [00:01<00:00, 468.56 ops/s]
input {
  name: "input_0"
  type {
    multiArrayType {
      shape: 1
      shape: 3
      shape: 224
      shape: 224
      dataType: FLOAT32
    }
  }
}
output {
  name: "645"
  type {
    multiArrayType {
      dataType: FLOAT32
    }
  }
}
metadata {
  userDefined {
    key: "com.github.apple.coremltools.source"
    value: "torch==1.10.0a0+fb"
  }
  userDefined {
    key: "com.github.apple.coremltools.version"
    value: "4.1"
  }
}

{'inputs': '[["input_0", "0", "[1, 3, 224, 224]"]]', 'outputs': '[["645", "0", "[1, 1000]"]]', 'config': '{"spec_ver": "4", "backend": "cpu", "allow_low_precision": "True"}', 'metadata': '{"coremltool_ver": "4.1", "torch_ver": "torch==1.10.0a0+fb"}'}
WARNING: Logging before InitGoogleLogging() is written to STDERR
W0826 13:27:12.690302 2477051 backend_detail.cpp:376] Warning: Backend [coreml] is not available. Execution of this Module is still possible by saving and loading on a device where the backend is available. (function codegen_backend_module)
graph(%self.1 : torch.jit.LoweredModule.coreml.__torch__.torchvision.models.mobilenetv2.MobileNetV2,
      %x.1 : Tensor):
  %51 : str = prim::Constant[value="Exception: Backend is not available."]()
  %50 : str = prim::Constant[value="AssertionError: "]()
  %14 : str = prim::Constant[value="forward"]() # <string>:5:62
  %48 : Tensor = prim::Uninitialized()
  %44 : Tensor = prim::Uninitialized()
  %typed_inputs.1 : Any[] = prim::ListConstruct(%x.1)
  %__backend.3 : __torch__.torch.classes.__backends__.coreml = prim::GetAttr[name="__backend"](%self.1)
  %8 : bool = prim::CallMethod[name="is_available"](%__backend.3) # <string>:4:19
  %49 : Tensor = prim::If(%8) # <string>:4:16
    block0():
      %__backend : __torch__.torch.classes.__backends__.coreml = prim::GetAttr[name="__backend"](%self.1)
      %__handles : Dict(str, Any) = prim::GetAttr[name="__handles"](%self.1)
      %15 : Any = aten::__getitem__(%__handles, %14) # <string>:5:47
      %17 : Any[] = prim::CallMethod[name="execute"](%__backend, %15, %typed_inputs.1) # <string>:5:24
      %18 : Any = prim::ListUnpack(%17)
      %20 : bool = prim::isinstance[types=[Tensor]](%18)
      %39 : Tensor = prim::If(%20) # <string>:6:18
        block0():
          %22 : Tensor = prim::unchecked_cast(%18)
          -> (%22)
        block1():
           = prim::RaiseException(%50) # <string>:6:18
          -> (%44)
      -> (%39)
    block1():
       = prim::RaiseException(%51) # <string>:9:18
      -> (%48)
  return (%49)

```

Reviewed By: raziel

Differential Revision: D30585154

fbshipit-source-id: 66c7d2e931be6eaa3c43a0ee131ea8046452449d
2021-09-17 00:25:14 -07:00
Akshit Khurana
2d58f3f56d NNAPI: Support const values in binary ops
Summary:
NNAPI converter failed with 1 const value and one tensor earlier
Code suggestions from dreiss

Test Plan:
pytest test/test_nnapi.py::TestNNAPI::test_pointwise_binary

Imported from OSS

Reviewed By: anshuljain1

Differential Revision: D28893881

fbshipit-source-id: 59240373fb03c6fdafa4cb2fa4d8408dd20092f6
2021-08-20 21:10:26 -07:00
Amy He
73f1e2d1dc [8/N] Nnapi backend delegation preprocess: New refactored design (#62225)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62225

Rewrote the preprocess function for Android NNAPI delegate.
Previously, `preprocess()` called `convert_model_to_nnapi()` using Pybind and returned a NnapiModule that is serialized for mobile. Now, `preprocess()` calls a sub-function of `convert_model_to_nnapi()` and returns several preprocessed items (that were previously components of NnapiModule).

Dictionary returned contains:
   "shape_compute_module": torch::jit::Module,
   "ser_model": torch::Tensor,
   "weights": List[torch.Tensor],
   "inp_mem_fmts": List[int],
   "out_mem_fmts": List[int]

**Purpose and Future:**
The purpose of these changes is to move more implementation from bytecode and Torchscript to the delegate API, since bytecode is less efficient.
Now, only the shape computation uses bytecode. In the future, shape computation will be moved out of Torchscript as well.

**nnapi_backend_preprocess.cpp:** preprocess implementation
**prepare.py**: refactored a portion of `convert_model_to_nnapi()` to `process_for_nnapi()`, so preprocess can get components of NnapiModule

**Test:**
Ran `python test/test_jit.py TestNnapiBackend` and `python test/test_nnapi.py` on OSS successfully
ghstack-source-id: 134444190

Test Plan: Ran `python test/test_jit.py TestNnapiBackend` and `python test/test_nnapi.py` on OSS successfully

Reviewed By: raziel

Differential Revision: D29922279

fbshipit-source-id: cadcf8908d8a745dc7abbe286e97d6ead937d4ab
2021-07-27 18:52:48 -07:00
Akshit Khurana
8e71f48f0a Handle simple NNAPI flatten NHWC case (#61796)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61796

We can easily handle nnapi conversion for nhwc inputs
that have 1 channel or H & W are 1

Test Plan:
pytest test/test_nnapi.py::TestNNAPI::test_flatten

Imported from OSS

Reviewed By: saketh-are

Differential Revision: D29827735

fbshipit-source-id: 65dee4b42fceef1b032bf5dd1c4cc6e020d01e14
2021-07-26 10:59:04 -07:00
Akshit Khurana
a3670ba377 Add option to specify custom NNAPI serializer (#61025)
Summary:
To add serializer for custom ops we can subclass default serializer
and update ADDER_MAP

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61025

Test Plan:
* pytest test/test_nnapi.py::TestNNAPI for current serializer
* Custom serializers to be tested with custom ops

Imported from OSS

Reviewed By: anshuljain1

Differential Revision: D29480745

fbshipit-source-id: 37e3f8de3c97f6c8a486f9879ce11430ea89af34
2021-07-09 15:27:10 -07:00
Akshit Khurana
ae65f63971 Make nnapi flatten converter accept flex inputs (#61024)
Summary:
As title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61024

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_flatten

Reviewed By: anshuljain1

Differential Revision: D29480748

fbshipit-source-id: c334b09600a64d3e552cec843d6da3de28e7d27c
2021-07-09 15:27:02 -07:00
Akshit Khurana
76c0f223d3 Make nnapi cat converter accept flex inputs
Summary: As title

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_cat

Reviewed By: anshuljain1

Differential Revision: D29480747

fbshipit-source-id: 161803054ff1a4c2c750fc30a5f0fc6d8a24b2c9
2021-07-09 14:27:53 -07:00
Akshit Khurana
9e81d3d869 Make NNAPI linear converter accept flex inputs (#61022)
Summary:
As title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61022

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_linear

Reviewed By: anshuljain1

Differential Revision: D29480749

fbshipit-source-id: 35975861740298c9e16f866c939e7ee3c2151710
2021-07-09 14:27:51 -07:00
Akshit Khurana
9e533a62f6 Make conv2d nnapi converter accept flexible batch (#61021)
Summary:
Same as title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61021

Test Plan: pytest test/test_nnapi.py::TestNNAPI

Reviewed By: anshuljain1

Differential Revision: D29480746

fbshipit-source-id: 7217c8f3a811db8c3c373f3e7ca31caf9502ef22
2021-07-09 10:28:10 -07:00
Akshit Khurana
8bd3e52e00 Add conv2d transpose NNAPI converter (#59529)
Summary:
* Conv2d transpose support
* Quantize WIP

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59529

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_conv2d_transpose

Reviewed By: anshuljain1

Differential Revision: D28926335

fbshipit-source-id: 8f90182f96cee0a13c4f38331d421e1e8ac618de
2021-07-09 09:29:20 -07:00
Ivan Kobzarev
7b6ddb6793 [nnapi] add log_softmax (#61378)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61378

Test Plan: Imported from OSS

Reviewed By: axitkhurana

Differential Revision: D29597355

Pulled By: IvanKobzarev

fbshipit-source-id: 55124749f8eeffa2b2713f7cffd5ccf965561de1
2021-07-07 18:28:39 -07:00
Akshit Khurana
baa518e2f6 Add Int32 support for NNAPI (#59365)
Summary:
Support Int32 tensors in NNAPI converter

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59365

Test Plan: Local testing with FB prod models

Reviewed By: anshuljain1

Differential Revision: D28881040

fbshipit-source-id: 2dacceffd322a21d91bfefcf2fb2ea400d952d0d
2021-07-07 12:40:49 -07:00
Akshit Khurana
cf285d8eea Add aten::slice NNAPI converter (#59364)
Summary:
Add support for aten::slice op in the NNAPI model converter

* If start = 0; end = max -> identity
* Flexible shapes can be passed through
* Flexible shapes can't be sliced over
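
One possible reading of the three rules above, as a plain-Python sketch. It is not the converter's code: `sliced_dim` is a hypothetical helper, a size of 0 is assumed to mark a flexible dimension, `sys.maxsize` is assumed to stand for "max", and only non-negative indices with step 1 are handled.

```python
import sys

FLEXIBLE = 0  # assumed marker for a dimension whose size is not yet known


def sliced_dim(in_size: int, start: int, stop: int) -> int:
    """Return the output size of one dimension after slicing [start:stop]."""
    if start < 0 or stop < 0:
        raise ValueError("negative indices are outside this sketch")
    full_range = start == 0 and stop >= sys.maxsize
    if in_size == FLEXIBLE:
        if full_range:
            return FLEXIBLE  # identity slice: the flexible size passes through
        raise ValueError("cannot slice over a flexible dimension")
    # Fixed size: clamp the bounds to the dimension, as Python slicing does.
    stop = min(stop, in_size)
    start = min(start, in_size)
    return max(stop - start, 0)
```

The identity case matters because a full-range slice does not need to know the actual size, which is why a flexible dimension can be passed through unchanged.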

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59364

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_slice

Reviewed By: anshuljain1

Differential Revision: D28881039

fbshipit-source-id: 3c1c630ff27b5bba6eda403d87570c61d43ae90e
2021-07-07 12:40:47 -07:00
Akshit Khurana
d26372794a Add aten::detach NNAPI converter (#58543)
Summary:
* Add support for aten::detach op in the NNAPI model converter as a no-op
* Also add flexible op support for add_pointwise_simple_unary_op

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58543

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_detatch

Reviewed By: anshuljain1

Differential Revision: D28531942

fbshipit-source-id: 4387dbbbadd8ce6b690841f3a903e68a380b849d
2021-07-07 12:40:46 -07:00
Akshit Khurana
0be228dd5f Add aten::flatten NNAPI converter (#60885)
Summary:
Add support for aten::flatten op in the NNAPI model converter. Startup time
variable size support isn't supported as shapes go as inputs to NNAPI op

Runtime variable size support to be supported soon

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60885

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_flatten

Reviewed By: anshuljain1

Differential Revision: D29451725

fbshipit-source-id: 8902745f7758c8cc88ad4b4ce02b8301ff894bd4
2021-07-07 12:40:44 -07:00
Akshit Khurana
b297f65b66 Add aten::div NNAPI converter (#58541)
Summary:
Add support for aten::div op in the NNAPI model converter. Add variable
size input test as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58541

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_div

Reviewed By: anshuljain1

Differential Revision: D28531943

fbshipit-source-id: e96342146f6de216f7b88443618edfc54963747c
2021-07-07 12:40:42 -07:00
Akshit Khurana
eab18a9a40 Add aten::to NNAPI converter (#58540)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58540

Add support for aten::to op in the NNAPI model converter for simple
cases like to("cpu"), to("gpu")

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_to

Reviewed By: anshuljain1

Differential Revision: D28531941

fbshipit-source-id: 0c934f7aceaff2669307c3426efe32046d8c44f3
2021-07-07 12:40:41 -07:00
Akshit Khurana
14d604a13e Add aten::softmax NNAPI converter (#58539)
Summary:
Add support for aten::softmax op in the NNAPI model converter with
flexible size

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58539

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_softmax

Reviewed By: anshuljain1

Differential Revision: D28531946

fbshipit-source-id: 8633f3e3f7f52795f9866ff16ad0867ea36a19e8
2021-07-07 12:39:31 -07:00
Akshit Khurana
369802a504 Add aten::avgpool2d NNAPI converter (#58538)
Summary:
Add support for aten::avgpool2d op in the NNAPI model converter with var
size support

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58538

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_avgpool2d

Reviewed By: anshuljain1

Differential Revision: D28531944

fbshipit-source-id: 43ff8c9389365698c282f204042b49c7ec84d824
2021-07-01 14:07:14 -07:00
Akshit Khurana
c4bb6a5781 NNAPI: flex size support for upsample_nearest2d op (#57563)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57563

Add flexible size support for upsample_nearest2d op in nnapi model conversion

Test Plan:
pytest test/test_nnapi.py

Imported from OSS

Reviewed By: dreiss

Differential Revision: D28200847

fbshipit-source-id: 901fe3f6e68e4c16ece730f3ffa68dc88c6ed6c3
2021-05-05 13:54:43 -07:00
Akshit Khurana
4c609a9782 NNAPI: Add qadd flexible size support (#57562)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57562

Add flexible size support for qadd op in nnapi model conversion

Test Plan:
pytest test/test_nnapi.py

Imported from OSS

Reviewed By: dreiss

Differential Revision: D28200849

fbshipit-source-id: d5b2ea8e9eb8ae405ff2c960f7549cef60bc0991
2021-05-05 13:54:41 -07:00
Akshit Khurana
28cd04ea64 NNAPI: add flexible size support for conv2d (#57561)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57561

Add flexible size support for conv2d op in nnapi model conversion

Test Plan:
pytest test/test_nnapi.py

Imported from OSS

Reviewed By: dreiss

Differential Revision: D28200848

fbshipit-source-id: d94ccf48a3d8453aa8e96c7cac02948c4cd870cc
2021-05-05 13:53:33 -07:00
Guilherme Leobas
e7c79cb158 Add type annotations to nnapi (#48142)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/48141

~Mypy is complaining about a missing arg in a function call.~
```bash
torch/backends/_nnapi/serializer.py:806: error: Too few arguments for "_do_add_binary"  [call-arg]
Found 1 error in 1 file (checked 1140 source files)
```

9392137dbe/torch/backends/_nnapi/serializer.py (L804-L806)

~dreiss, would you mind take a look when you have some cycles to spare and see what would be the appropriated value for `fuse_code` here? Thanks :)~

Edit: https://github.com/pytorch/pytorch/issues/48925 got merged a couple of days ago. The blocking part is now unblocked, and I just pushed the changes to make mypy happy again. This PR is ready for review.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48142

Reviewed By: ezyang

Differential Revision: D28006249

Pulled By: walterddr

fbshipit-source-id: 5e43eeba7143512a549efaad31541f86718add7c
2021-04-26 19:08:07 -07:00
Sam Estep
75024e228c Add lint for unqualified type: ignore (#56290)
Summary:
The other half of https://github.com/pytorch/pytorch/issues/56272.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56290

Test Plan:
CI should pass on the tip of this PR, and we know that the lint works because the following CI runs (before this PR was finished) failed:

- https://github.com/pytorch/pytorch/runs/2384511062
- https://github.com/pytorch/pytorch/actions/runs/765036024

Reviewed By: seemethere

Differential Revision: D27867219

Pulled By: samestep

fbshipit-source-id: e648f07b6822867e70833e23ddafe7fb7eaca235
2021-04-21 08:07:23 -07:00
David Reiss
da7a27b847 [NNAPI] Initial flexible size support (#54701)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54701

We need NNAPI models to support inputs (and, by extension, intermediate
values and outputs) whose shape is only determined at load time.  For
example, a vision model's input shape might be dependent on the aspect
ratio of the device camera.  While NNAPI has full support for variable
shapes (by setting components of the operand shape to 0), the guidance
we have received is that vendor-provided drivers for real hardware are
not able to support this efficiently.  Therefore, we take a hybrid
approach where shapes are calculated at model load time to
semi-dynamically construct our NNAPI model.  While this doesn't let us
have truly dynamic input shapes, it does allow us to ensure that the
vendor driver only sees fixed shapes, so we get maximum performance.

In this initial commit, only PReLU supports dynamic shapes.  Additional
operators will be converted in separate diffs.

- In order to convert a flexible-shape model, the user supplies inputs
  with shapes containing dimensions of size 0 for the flexible
  dimensions.
- During conversion, we generate code to compute the shapes of all
  intermediates and outputs as a function of the input shapes.
- We no longer run the input model to produce the output templates.
  Instead, we generate code to return properly-sized templates, given
  the input shapes.
- All of this generated code goes into a "ShapeComputeModule" that is
  used by the NnapiModule during initialization.
- The ShapeComputeModule mutates the serialized model to fill in the
  computed sizes for each operand.  This requires us to change the dtype
  for the serialized model to int32, but this should be fine because
  everything in it is already 4-byte aligned.
- NnapiInitWrapper no longer exists.  Instead, initialization is
  performed on the first run, based on the real arguments.  We plan to
  provide an API for doing eager initialization.
- Unit test updated to allow separate arguments to be given for trace,
  conversion, and inference.  A flexible-shape test case was added for
  PReLU.
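
The load-time shape computation described in the list above can be sketched
in plain Python. This is only an illustration under stated assumptions, not
the generated ShapeComputeModule code: `FLEXIBLE`, `resolve_shape`, and
`prelu_output_shape` are hypothetical names, and the only facts taken from
this commit are that size-0 dimensions mark flexible dimensions and that
output shapes are computed as a function of the input shapes.

```python
from typing import Sequence, Tuple

# Assumption: a dimension of size 0 marks a flexible dimension, as described above.
FLEXIBLE = 0


def resolve_shape(template: Sequence[int], actual: Sequence[int]) -> Tuple[int, ...]:
    """Fill the flexible (size-0) dimensions of `template` from the runtime shape."""
    if len(template) != len(actual):
        raise ValueError(f"rank mismatch: {tuple(template)} vs {tuple(actual)}")
    resolved = []
    for want, got in zip(template, actual):
        # Fixed dimensions must match exactly; flexible ones take the runtime size.
        if want != FLEXIBLE and want != got:
            raise ValueError(f"fixed dimension {want} does not match runtime size {got}")
        resolved.append(got)
    return tuple(resolved)


def prelu_output_shape(input_shape: Sequence[int]) -> Tuple[int, ...]:
    """PReLU is elementwise, so the output has the same shape as the input."""
    return tuple(input_shape)


# Hypothetical example: height and width are flexible at conversion time.
template = (1, 3, 0, 0)
runtime_input = (1, 3, 224, 320)
input_shape = resolve_shape(template, runtime_input)
output_shape = prelu_output_shape(input_shape)
print(input_shape, output_shape)
```

Once every size is resolved this way, the model handed to the driver would
contain only fixed shapes, which is the point of the hybrid approach.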

Test Plan: Unit test

Reviewed By: axitkhurana

Differential Revision: D27536796

Pulled By: dreiss

fbshipit-source-id: 105585f247987b1e6ec6946a6fe44401237cb0a0
2021-04-06 13:49:43 -07:00
David Reiss
1e3b3a4714 [NNAPI] Create get_next_operand_id (#54700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54700

This is an internal method that makes it clearer what
len(self.operands) is doing.
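
A minimal sketch of the idea, assuming operands live in a list and IDs are
assigned in insertion order. The class `OperandTable` and its `add_operand`
method are hypothetical stand-ins, not the serializer's real structure; only
`get_next_operand_id` and `len(self.operands)` come from this commit.

```python
from typing import Any, List


class OperandTable:
    """Hypothetical container; stands in for the part of the serializer that owns operands."""

    def __init__(self) -> None:
        self.operands: List[Any] = []

    def get_next_operand_id(self) -> int:
        # The next free ID is simply the number of operands registered so far.
        return len(self.operands)

    def add_operand(self, operand: Any) -> int:
        # Naming the intent here is the whole point of the helper.
        operand_id = self.get_next_operand_id()
        self.operands.append(operand)
        return operand_id


table = OperandTable()
first = table.add_operand("input")
second = table.add_operand("weight")
print(first, second, table.get_next_operand_id())
```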

Test Plan: Unit test

Reviewed By: axitkhurana

Differential Revision: D27536794

Pulled By: dreiss

fbshipit-source-id: 678cee8a47df6757dd2e6feabf2560fd82d32e26
2021-04-06 13:49:41 -07:00
David Reiss
ca67c17e46 [NNAPI] Add fixed-size assertions (#54699)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54699

We'll soon be adding support for flexible-size tensors to the NNAPI
converter, but it won't be added to all ops at once.  Create
get_tensor_operand_by_jitval_fixed_size as a wrapper for
get_tensor_operand_by_jitval that verifies that the argument has a fixed
shape.  Update all call sites.  As flexible size support is added to
each op, the call sites can be converted back and proper size checks
added.
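
The wrapper pattern can be sketched roughly as follows. This is a simplified
stand-in, not the converter's code: `SketchSerializer`, `add_tensor_operand`,
the string jitval keys, and the `(operand_id, shape)` return value are all
assumptions made for illustration, as is the use of size 0 to mark a flexible
dimension. Only the two `get_tensor_operand_by_jitval*` method names come
from this commit.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


class SketchSerializer:
    """Hypothetical, heavily simplified model of the operand lookup."""

    def __init__(self) -> None:
        # Assumption: jitvals are modeled as strings mapping to (operand_id, shape).
        self.tensor_operands: Dict[str, Tuple[int, Shape]] = {}

    def add_tensor_operand(self, jitval: str, shape: Shape) -> int:
        operand_id = len(self.tensor_operands)
        self.tensor_operands[jitval] = (operand_id, shape)
        return operand_id

    def get_tensor_operand_by_jitval(self, jitval: str) -> Tuple[int, Shape]:
        return self.tensor_operands[jitval]

    def get_tensor_operand_by_jitval_fixed_size(self, jitval: str) -> Tuple[int, Shape]:
        # Same lookup, plus a check that no dimension is flexible (size 0 here).
        operand_id, shape = self.get_tensor_operand_by_jitval(jitval)
        for dim, size in enumerate(shape):
            if size == 0:
                raise ValueError(
                    f"flexible size is not supported for this operand: "
                    f"{jitval} has size 0 in dimension {dim}"
                )
        return operand_id, shape


ser = SketchSerializer()
ser.add_tensor_operand("x", (1, 3, 224, 224))
ser.add_tensor_operand("y", (1, 3, 0, 0))
print(ser.get_tensor_operand_by_jitval_fixed_size("x"))
```

An op that has gained flexible-size support would switch its call site back to
the unchecked lookup and perform its own size checks.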

Test Plan: Unit test

Reviewed By: axitkhurana

Differential Revision: D27536791

Pulled By: dreiss

fbshipit-source-id: 6fb1fea814d767b6ff263fd8b88240a51be74777
2021-04-06 13:49:38 -07:00