Commit Graph

260 Commits

Author SHA1 Message Date
Justin Chu
388368b699 [ONNX] Fix type annotations and enable type checking for all apis (#84091)
Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green.

Profile:

export `torchvision.models.alexnet(pretrained=True)`

```
with runtime type checking: 21.314 / 10 passes
without runtime type checking: 20.797 / 10 passes

+ 2.48%
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84091
Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
2022-09-03 01:40:18 +00:00
titaiwang
ece0002c4b [ONNX] Disable autocast cache in exporter (#84219)
This PR provides a temporary fix on #84092 in exporter to avoid more cases falling into this bug.
A long-term fix will be provided later.

A simple repro with torch.onnx.export is still under investigation, as torch.jit.trace() is not the API we call inside torch.onnx.export, and it may introduce the difference. Therefore, a test case is provided here only.
A specific test one can use,
```python
import torch
import onnxruntime
from onnxruntime.training.ortmodule import DebugOptions, LogLevel
from onnxruntime.training.ortmodule import ORTModule

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.cv1 = torch.nn.Conv2d(3, 3, 5, 2, 1)

    def forward(self, x):
        x = self.cv1(x)
        return x

x = torch.randn(10, 3, 20, 20) * 2
m = MyModule().eval()
x = x.cuda()
m = m.cuda()

debug_options = DebugOptions(log_level=LogLevel.VERBOSE, save_onnx=True, onnx_prefix="ViT-B")
m = ORTModule(m, debug_options=debug_options)

with torch.cuda.amp.autocast(dtype=torch.float16, cache_enabled=True):
    loss = m(x)
```
AND make assertion fail in ORTModule
17ccd6fa02/orttraining/orttraining/python/training/ortmodule/_io.py (L578-L581)

Without the fix, the user will see the weight/bias of Conv node becomes constant.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84219
Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
2022-09-01 00:34:37 +00:00
titaiwang
18264432f7 [ONNX] replace all _C._flatten to torch.jit._flatten (#83598)
_C._flatten is exactly the same as torch.jit._flatten. Unifying them to reduce confusion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83598
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
2022-09-01 00:31:28 +00:00
BowenBao
806878518f [ONNX][Reland] Export node and value with scope name (#82040)
Introduce `_jit_pass_onnx_assign_node_and_value_names` to parse and assign
scoped name for nodes and values in exported onnx graph.
Module layer information is obtained from `ONNXScopeName` captured in `scope`
attribute in nodes. For nodes, the processed onnx node name are stored in
attribute `onnx_name`. For values, the processed onnx output name are stored
as `debugName`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82040
Approved by: https://github.com/AllenTiTaiWang, https://github.com/justinchuby, https://github.com/abock
2022-08-29 20:10:38 +00:00
PyTorch MergeBot
8e6207bcd8 Revert "[ONNX] Export node and value with scope name (#82040)"
This reverts commit 6a3666282d.

Reverted https://github.com/pytorch/pytorch/pull/82040 on behalf of https://github.com/weiwangmeta due to Diff reverted internally
2022-08-29 06:36:18 +00:00
PyTorch MergeBot
d8cc8368ab Revert "[ONNX] Fix type annotations and enable type checking for all apis (#84091)"
This reverts commit 6446da1730.

Reverted https://github.com/pytorch/pytorch/pull/84091 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-08-28 12:28:58 +00:00
Justin Chu
6446da1730 [ONNX] Fix type annotations and enable type checking for all apis (#84091)
Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green.

Profile:

export `torchvision.models.alexnet(pretrained=True)`

```
with runtime type checking: 21.314 / 10 passes
without runtime type checking: 20.797 / 10 passes

+ 2.48%
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84091
Approved by: https://github.com/BowenBao
2022-08-27 04:40:41 +00:00
BowenBao
6a3666282d [ONNX] Export node and value with scope name (#82040)
Introduce `_jit_pass_onnx_assign_node_and_value_names` to parse and assign
scoped name for nodes and values in exported onnx graph.
Module layer information is obtained from `ONNXScopeName` captured in `scope`
attribute in nodes. For nodes, the processed onnx node name are stored in
attribute `onnx_name`. For values, the processed onnx output name are stored
as `debugName`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82040
Approved by: https://github.com/AllenTiTaiWang, https://github.com/justinchuby, https://github.com/abock
2022-08-26 20:59:12 +00:00
Justin Chu
bf25a140f9 [ONNX] Add runtime type checking to export (#83673)
This PR adds an internal wrapper on the [beartype](https://github.com/beartype/beartype) library to perform runtime type checking in `torch.onnx`. It uses beartype when it is found in the environment and is reduced to a no-op when beartype is not found.

Setting the env var `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=ERRORS` will turn on the feature. setting `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=DISABLED` will disable all checks. When not set and `beartype` is installed, a warning message is emitted.

Now when users call an api with invalid arguments e.g.

```python
torch.onnx.export(conv, y, path, export_params=True, training=False)

# traning should take TrainingModel, not bool
```

they get

```
Traceback (most recent call last):
  File "bisect_m1_error.py", line 63, in <module>
    main()
  File "bisect_m1_error.py", line 59, in main
    reveal_error()
  File "bisect_m1_error.py", line 32, in reveal_error
    torch.onnx.export(conv, y, cpu_model_path, export_params=True, training=False)
  File "<@beartype(torch.onnx.utils.export) at 0x1281f5a60>", line 136, in export
  File "pytorch/venv/lib/python3.9/site-packages/beartype/_decor/_error/errormain.py", line 301, in raise_pep_call_exception
    raise exception_cls(  # type: ignore[misc]
beartype.roar.BeartypeCallHintParamViolation: @beartyped export() parameter training=False violates type hint <class 'torch._C._onnx.TrainingMode'>, as False not instance of <protocol "torch._C._onnx.TrainingMode">.
```

when `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK` is not set and `beartype` is installed, a warning message is emitted.

```
>>> torch.onnx.export("foo", "bar", "f")
<stdin>:1: CallHintViolationWarning: Traceback (most recent call last):
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/_beartype.py", line 54, in _coerce_beartype_exceptions_to_warnings
    return beartyped(*args, **kwargs)
  File "<@beartype(torch.onnx.utils.export) at 0x7f1d4ab35280>", line 39, in export
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/site-packages/beartype/_decor/_error/errormain.py", line 301, in raise_pep_call_exception
    raise exception_cls(  # type: ignore[misc]
beartype.roar.BeartypeCallHintParamViolation: @beartyped export() parameter model='foo' violates type hint typing.Union[torch.nn.modules.module.Module, torch.jit._script.ScriptModule, torch.jit.ScriptFunction], as 'foo' not <protocol "torch.jit.ScriptFunction">, <protocol "torch.nn.modules.module.Module">, or <protocol "torch.jit._script.ScriptModule">.

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/_beartype.py", line 63, in _coerce_beartype_exceptions_to_warnings
    return func(*args, **kwargs)
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 482, in export
    _export(
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 1422, in _export
    with exporter_context(model, training, verbose):
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 177, in exporter_context
    with select_model_mode_for_export(
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 95, in select_model_mode_for_export
    originally_training = model.training
AttributeError: 'str' object has no attribute 'training'
```

We see the error is caught right when the type mismatch happens, improving from what otherwise would become `AttributeError: 'str' object has no attribute 'training'`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83673
Approved by: https://github.com/BowenBao
2022-08-25 21:24:37 +00:00
BowenBao
daca0ee5e2 [ONNX] Introduce ONNXScopeName (#82038)
Update `_setup_trace_module_map` to always record module/layer info
in `Scope` attribute for nodes.
Extend `Scope` name to not only record module typename, but also
module object variable name. Both names are formatted and stored
as `name` attribute in `Scope`.
Introduce `ONNXScopeName` class to manage the formatting and parsing.
Updated local function export code adjusting to this update.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82038
Approved by: https://github.com/AllenTiTaiWang, https://github.com/justinchuby, https://github.com/abock, https://github.com/malfet
2022-08-22 20:49:21 +00:00
Justin Chu
e4f74f0891 [ONNX] Update the default opset to version 14 (#83284)
Update the default opset by running the `update_default_opset_version.py` script. The update is done in a regularly to ensure we are in sync with the onnx updates. All changes are produced by the script.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83284
Approved by: https://github.com/AllenTiTaiWang, https://github.com/malfet, https://github.com/BowenBao
2022-08-18 19:13:38 +00:00
BowenBao
017ecb782d [ONNX] Update legacy code, initialize onnx_shape_inference=True by default (#82767)
Legacy code has onnx_shape_inference=False by default, which is misleading
as every other export api sets it to True unless otherwise overriden by caller.
There is only two tests that need updating according to this change.
* test_utility_funs.py::test_constant_fold_shape. The resulting number of nodes
  in graph is increased by 1, due to that previously the extra constant node was
  added as initializer.
* test_utility_funs.py::test_onnx_function_substitution_pass. Enabling onnx
  shape inference discovered discrepancy in test input shape and supplied dynamic
  axes arguments.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82767
Approved by: https://github.com/justinchuby, https://github.com/abock
2022-08-10 21:50:13 +00:00
Justin Chu
f5701a1f9a [ONNX] Remove unused patching methods (#83006)
### Description
<!-- What did you change and why was it needed? -->

Remove unused patching methods:

- `torch._C.Graph.constant`
- unpatch `torch._C.Node.__getitem__` and move the helper function to `symbolic_helper`

Add typing annotations

### Issue
<!-- Link to Issue ticket or RFP -->

#76254

### Testing
<!-- How did you test your change? -->

Unit tested
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83006
Approved by: https://github.com/BowenBao
2022-08-09 19:24:03 +00:00
qqaatw
9b4dc56c83 [ONNX] Fix quantization outputs' dtype (#79690)
Part of #79263

Previously, all quantized PyTorch tensors are all casted to the dtypes which comply with ONNX's definition, i.e. `scale` is casted to `double`, and `zero_point` is casted to `int64`. These casts lead to inconsistent dtypes when comparing PyTorch's outputs and ONNX runtime's outputs.

Now, `cast_onnx_accepted` argument is added to `unpack_quantized_tensor` function. When making example inputs for ONNX, we cast them to the ONNX compliant dtypes; otherwise, they are casted to PyTorch default types for quantization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79690
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
2022-08-09 18:32:03 +00:00
shubhambhokare1
95d873855e [ONNX] Inline prim::PythonOp for Autograd Function Export (#74765)
Add flag (inline_autograd) to enable inline export of model consisting of autograd functions. Currently, this flag should only be used in TrainingMode.EVAL and not for training.

An example:

If a model containing ``autograd.Function`` is as follows
```
                class AutogradFunc(torch.autograd.Function):
                  @staticmethod
                  def forward(ctx, i):
                      result = i.exp()
                      result = result.log()
                      ctx.save_for_backward(result)
                      return result
```
Then the model is exported as
```
                graph(%0 : Float):
                  %1 : Float = ^AutogradFunc(%0)
                  return (%1)
```
If inline_autograd is set to True, this will be exported as
```
                graph(%0 : Float):
                  %1 : Float = onnx::Exp(%0)
                  %2 : Float = onnx::Log(%1)
                  return (%2)
```

If one of the ops within the autograd module is not supported, that particular node is exported as is mirroring ONNX_FALLTHROUGH mode

Fixes: #61813
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74765
Approved by: https://github.com/BowenBao, https://github.com/malfet
2022-08-03 23:30:19 +00:00
Huy Do
6ea422dd0b Format torch/onnx with ufmt (#82137)
This is the last batch for the new ufmt (black + usort) linter. After this, black linter can finally be replaced. The previous PR to format ONNX tests was #81335
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82137
Approved by: https://github.com/kit1980, https://github.com/AllenTiTaiWang
2022-07-25 22:42:21 +00:00
Justin Chu
d1d2687d34 [ONNX] Fix potentially unbound variables (#79789)
Pylint alerts that some variables may be unbound. This PR fixes the errors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79789
Approved by: https://github.com/garymm
2022-06-29 17:01:49 +00:00
Justin Chu
4b817f5816 [ONNX] Improve docstrings and typing annotations (#78231)
- Remove wrappers in `__init__` around utils and instead expose those functions directly. Move the docstrings from `__init__` to corresponding functions in utils
- Annotate `torch.onnx.export` types
- Improve docstrings
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78231
Approved by: https://github.com/BowenBao
2022-06-16 02:59:24 +00:00
Justin Chu
c8b9b6266b [ONNX] Fix arg type in _set_training_mode (#78583)
When `TrainingMode.PRESERVE` is set for export, the exporter used to change the model's training mode based on some logic. Now we respect the option and not touch the model's training state.

- Previously `_set_training_mode`'s behavior doesn't match what the global variable expects. This PR removes the deprecated `_set_training_mode` and makes the type correct.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78583
Approved by: https://github.com/BowenBao
2022-06-15 23:47:12 +00:00
Justin Chu
d3ef5c3fa3 [ONNX] Clean up __init__ in torch.onnx (#78446)
- Move definitions in `__init__` to internal classes and expose them by importing to init (prevent circular dependencies): https://github.com/pytorch/pytorch/wiki/torch.onnx-Namespacing
  - Context classes and enums are moved to `_exporter_states.py`
  - Exceptions are moved to `errors.py`
- Define `__all__` for torch.onnx. https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation
- Moved `utils.__IN_ONNX_EXPORT` to `GLOBALS.in_onnx_export`
- Deprecated `torch.onnx._export`

Precedes #78231

Using this as an aid for finding public functions:

```python
list(filter(lambda x: not x.startswith("_"), torch.onnx.utils.__dict__.keys()))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78446
Approved by: https://github.com/BowenBao
2022-06-14 04:35:06 +00:00
BowenBao
530dcc2b94 [ONNX] Tool to verify exported model discrepancy between sets of inputs
A graph is exported for each set of inputs. The exported graphs are then compared
to each other, and discrepancies are reported. This function first checks the jit
graph, and then the onnx graph.

Unless otherwise specified, the jit/ONNX graph is expected to be the same, regardless
of the inputs it used for exporting. A discrepancy would imply the graph exported is
not accurate when running with other set of inputs, which will typically results in
runtime error or output mismatches.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78323

Approved by: https://github.com/justinchuby, https://github.com/garymm
2022-06-06 20:29:20 +00:00
Justin Chu
161e931156 [ONNX] Modernize python syntax (#77935)
Use pyupgrade(https://github.com/asottile/pyupgrade) and flynt to modernize python syntax

```sh
pyupgrade --py36-plus --keep-runtime-typing torch/onnx/**/*.py
pyupgrade --py36-plus --keep-runtime-typing test/onnx/**/*.py
flynt torch/onnx/ --line-length 120
```

- Use f-strings for string formatting
- Use the new `super()` syntax for class initialization
- Use dictionary / set comprehension
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77935
Approved by: https://github.com/BowenBao
2022-05-24 22:52:37 +00:00
Justin Chu
0d76299ff7 [ONNX] Clean up module imports (#77423)
Cleaning up onnx module imports to prepare for updating `__init__`.

- Simplify importing the `_C` and `_C._onnx` name spaces
- Remove alias of the symbolic_helper module in imports
- Remove any module level function imports. Import modules instead
    - Alias `symbilic_opsetx` as `opsetx`
- Fix some docstrings

Requires:
- https://github.com/pytorch/pytorch/pull/77448
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77423
Approved by: https://github.com/BowenBao
2022-05-20 01:56:24 +00:00
Justin Chu
563c2719bf [ONNX] Refactor to remove inline imports - attempt 2 (#77448)
Re-land
- #77142

(diff: https://github.com/pytorch/pytorch/compare/c08b8f0..justinchuby:justinchu/remove-patch2)

Fixed:
- Delay import symbolic_opsets in the registry.

Tested locally with torchvision
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77448
Approved by: https://github.com/garymm
2022-05-16 14:44:24 +00:00
PyTorch MergeBot
6b366dd3c1 Revert "[ONNX] Refactor to remove inline imports (#77142)"
This reverts commit c08b8f0967.

Reverted https://github.com/pytorch/pytorch/pull/77142 on behalf of https://github.com/malfet
2022-05-13 19:44:17 +00:00
Justin Chu
c08b8f0967 [ONNX] Refactor to remove inline imports (#77142)
Reduce circular dependencies

- Lift constants and flags from `symbolic_helper` to `_constants` and `_globals`
    - Standardized constant naming to make it consistant
- Make `utils` strictly dependent on `symbolic_helper`, removing inline imports from symbolic_helper
- Move side effects from `utils` to `_patch_torch`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77142
Approved by: https://github.com/garymm, https://github.com/BowenBao
2022-05-13 03:46:33 +00:00
Justin Chu
78d3798181 [ONNX] Fix type comparison in utils._need_symbolic_context (#77365)
In `_need_symbolic_context`, when the annotation is postponed evaluated, the annotation is a string and not a type. We need to use get_type_hints to get the real type.

For example,

```python
def g(a: int) -> int: return a

def f(a: "int") -> "int": return a
```

we will get the correct type `int` for both g and f with `typing.get_type_hints`. Otherwise, the type for `a` in `f` will be a string and is not comparable to the type `int` - `issubclass` will complain.

This is necessary as we will use postponed typing evaluation to break circular dependencies.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77365
Approved by: https://github.com/BowenBao
2022-05-12 23:49:04 +00:00
BowenBao
93953a48b7 [ONNX] Bug fix: opset_version checked before set (#76928)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76928
Approved by: https://github.com/justinchuby, https://github.com/garymm
2022-05-11 17:16:22 +00:00
Justin Chu
5dd1c67776 [ONNX] Format ONNX python with black
Format all onnx python code with black and isort with

```sh
isort torch/onnx/ test/onnx
black torch/onnx/ test/onnx
```

Updated lintrunner config to include these paths.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76754
Approved by: https://github.com/suo, https://github.com/BowenBao
2022-05-05 00:19:22 +00:00
BowenBao
679fc90cdb [ONNX] Support optional type (#68793) (#73284)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73284

Some important ops won't support optional type until opset 16,
so we can't fully test things end-to-end, but I believe this should
be all that's needed. Once ONNX Runtime supports opset 16,
we can do more testing and fix any remaining bugs.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D34625646

Pulled By: malfet

fbshipit-source-id: 537fcbc1e9d87686cc61f5bd66a997e99cec287b

Co-authored-by: BowenBao <bowbao@microsoft.com>
Co-authored-by: neginraoof <neginmr@utexas.edu>
Co-authored-by: Nikita Shulga <nshulga@fb.com>
(cherry picked from commit 822e79f31ae54d73407f34f166b654f4ba115ea5)
2022-05-04 20:24:30 +00:00
Justin Chu
2e841e68b2 [ONNX] Documentation and typing annotations in registry
Updating the docstrings and type annotations as I walk through the code.

- Turned some comments into docstrings.
- Added type annotations for some functions in utils and the registry
- Removed direct function imports; importing functions makes name space collision easier to happen and refactoring/code analysis harder: https://google.github.io/styleguide/pyguide.html#22-imports
- Formatted touched files with black
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76255
Approved by: https://github.com/BowenBao
2022-04-28 18:24:24 +00:00
Thiago Crepaldi
90d31cb311 Emit ATen ops when symbolics raise + minor fixes
Currently `torch.onnx.export(.., operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)` only issues ATen ops through explicit requests (e.g. `g.at()`) calls inside each op symbolic function. This is done based on specific conditions such as `operator_export_type==OperatorExportTypes.ONNX_ATEN_FALLBACK)` or `is_caffe2_aten_fallback()`

This PR extends the ATen fallback mechanism for scenarios when the symbolic function raises `RuntimeError` during export. The idea is that partial implementation of existing ONNX ops can fallback to ATen as a last resort. That is valuable because each operator can have many input combinations and not all are always implemented.

A minor fix was done to make sure the `overload_name` attribute is added to explicit ATen op fallback requests when a symbolic is not registered to a particular op.

ps: The behavior for builds with BUILD_CAFFE2=1 is not changed to ensure BC.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74759
Approved by: https://github.com/garymm, https://github.com/msaroufim
2022-04-23 21:24:25 +00:00
BowenBao
2c748b7573 [ONNX] Trace model if quantization is detected
Previously pre-tracing model is required for exporting quantized model.
e.g. calling `traced_m = torch.jit.trace(model, inputs)` and export `traced_m`.
The reason was quantized weights are stored in a unique `PackedParam` structure,
and they need to be handled by tracing to be exportable.
This PR enables export api to call tracing underneath if it detects quantization
in the model.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75921

Approved by: https://github.com/garymm
2022-04-22 17:27:32 +00:00
Thiago Crepaldi
9bbe1d632e Fix ONNX ATen fallback for non-caffe2 engines
This PR introduces 3 BC changes:

First, this PR propagates `BUILD_CAFFE2` flag to `libtorch` and `libtorch_python`, which is necessary for non-caffe2 ONNX runtimes when using `ONNX_ATEN_FALLBACK` operator export type.

Second, as a complement of https://github.com/pytorch/pytorch/pull/68490, this PR refactors Caffe2's Aten ops symbolics to consider not only the `operator_export_type` (aka `ONNX_ATEN_FALLBACK`) to emit Caffe2 Aten ops, but also whether `BUILD_CAFFE2` (which is called `torch.onnx._CAFFE2_ATEN_FALLBACK` in python binding) is set.

Lastly, it renames `onnx::ATen` to `aten::ATen` for ONNX spec consistency in a BC fashion.
ONNX doesn't have `ATen` op on its spec, but PyTorch ONNX converter emits them. Non-Caffe2 backend engines would be mislead by such operator's name/domain. A non-ideal workaround would be to have Aten ops handled based on its name and ignore the (non-complaint) domain. Moreover, users could incorrectly file bugs to either ONNX or ONNX Runtime when they inspect the model and notice the presence of an unspecified ONNX operator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73954
Approved by: https://github.com/BowenBao, https://github.com/malfet, https://github.com/garymm, https://github.com/jiafatom
2022-04-14 23:18:45 +00:00
BowenBao
144b7de9dd [ONNX] Adjust is_train flag for onnx pass deduplicate initializers
Previous logic didn't consider the case for TrainingMode.PRESERVE.
A more direct way is to check `model.training`, which is the accurate
training mode, set by `exporter_context(model, training)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74247
Approved by: https://github.com/garymm
2022-03-22 22:39:20 +00:00
BowenBao
54a6942f8d [ONNX] ONNX Exporter logging (#71342)
Summary:
Add ONNX exporter logging facility. Supporting both C++/Python logging api. Logging can be turned on/off. Logging output stream can be either set to `stdout` or `stderr`.

A few other changes:
* When exception is raised in passes, the current IR graph being processed will be logged.
* When exception is raised from `_jit_pass_onnx` (the pass that converts nodes from namespace `ATen` to `ONNX`), both ATen IR graph and ONNX IR graph under construction will be logged.
* Exception message for ConstantFolding is truncated to avoid being too verbose.
* Update the final printed IR graph with node name in ONNX ModelProto as node attribute. Torch IR Node does not have name. Adding this to printed IR graph helps debugging.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71342

Reviewed By: msaroufim

Differential Revision: D34433473

Pulled By: malfet

fbshipit-source-id: 4b137dfd6a33eb681a5f2612f19aadf5dfe3d84a
(cherry picked from commit 67a8ebed5192c266f604bdcca931df6fe589699f)
2022-03-17 19:40:03 +00:00
BowenBao
9210e8f540 [ONNX] Adds overload_name to Aten op (#69378) (#73280)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73280

This PR adds a new attribute overload_name to the Aten node so that third party applications can implement calls to libtorch without using PyTorch source code.

This is necessary because torch's torch::jit::findOperatorFor(fullname) requires a full name, including operator and overload names.

ATen op was originally created for Caffe2, which leveraged the availability of the pytorch yaml files to create calls to the aten oeprators directly, not relying on torch::jit::findOperatorFor

The first part of the PR refactors all symbolics that create Aten ops, so that there is a single helper for this operator. Next all symbolics are updated to pass in the relevant overload name, if empty string is not applicable

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D34625645

Pulled By: malfet

fbshipit-source-id: 37d58cfb5231833768172c122efc42edf7d8609a
(cherry picked from commit e92f09117d3645b38bc3235b30aba4b4c7c71dfa)
2022-03-09 14:26:18 +00:00
BowenBao
b3cfc74f0f [ONNX] Capture annotated attributes for local function
Enables local function export to capture annotated attributes.
For example:
```python
class M(torch.nn.Module):
    num_layers: int

    def __init__(self, num_layers):
        super().__init__()
        self.num_layers = num_layers

    def forward(self, args):
        ...
```
`num_layers` will now be captured as attribute of local function `M`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72883
2022-02-28 18:56:18 +00:00
BowenBao
28bf2f80cf Don't call _jit_pass_onnx_function_extraction if export_modules_as_functions is False (#69742)
* fix clang-format violations

* Don't call _jit_pass_onnx_function_extraction if export_modules_as_functions is False

It's just wasteful.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73100
2022-02-22 22:43:53 +00:00
BowenBao
2791725a84 Integrate full ONNX check into ONNX export API (#71125)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72988
2022-02-18 18:40:09 +00:00
BowenBao
32f6a1e2a2 [ONNX] First version of quantized model export: Support quantized.Linear (#69232)
Co-authored-by: David Fan <jiafamicrosoft.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72986
2022-02-18 18:27:26 +00:00
BowenBao
cc792746d2 [ONNX] De-duplicate initializers (#68202) (#69547)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69547

ScriptModule export introduces duplicated ONNX initializers for shared weights, unnecessarily increases ONNX model size. This PR de-duplicates ONNX initializers for model exported in eval mode, by checking if the underlying tensors share the same `data_ptr`, `strides` and `sizes`.

Test Plan: Imported from OSS

Reviewed By: msaroufim

Differential Revision: D32994271

Pulled By: malfet

fbshipit-source-id: 10ac66638b6255890875272472aa9ed07a5b1d9a

Co-authored-by: BowenBao <bowbao@microsoft.com>
(cherry picked from commit d7cbde940c)
2022-02-11 22:05:15 +00:00
BowenBao
04c5d978b9 [ONNX] Refactor _run_symbolic_function (#67573) (#68491)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68491

* Allows implementing symbolic functions for domains other than `aten`, for example `prim`, in symbolic_opset#.py.
* Allows symbolic function to access extra context if needed, through `SymbolicFunctionState`.
  * Particularly, the `prim::PythonOp` special case can access node without the need of passing node through inputs. Updates will be made downstreams, and in a follow-up PR we will remove the previous workaround in exporter.
* `prim::Loop`, `prim::If`, etc are now moved outside of `_run_symbolic_function` from utils.py, and to symbolic_opset9.py.

Motivation for this change:
- Better maintainability and reducing complexity. Easier to add symbolic for operators, both simple and complex ones (that need additional context), without the former needing to know the existence of the latter.
- The design idea was long outdated. prim ops are no longer rare special cases, and they shouldn't all be handled inside `_run_symbolic_function`. As a result this function becomes too clumsy. There were also prim ops symbolic added in symbolic_opset#.py with signature `prim_[opname]`, creating separation and confusion.

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D32483782

Pulled By: malfet

fbshipit-source-id: f9affc31b1570af30ffa6668da9375da111fd54a

Co-authored-by: BowenBao <bowbao@microsoft.com>
(cherry picked from commit 1e04ffd2fd)
2022-02-11 18:35:35 +00:00
BowenBao
eb4238fc26 Allow caffe2-specific graph transformations for OperatorExportTypes.ONNX_ATEN_FALLBACK when BUILD_CAFFE2 is ON (#67460) (#68490)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68490

The use of ATEN as a fallback operator during ONNX conversion is important for increasing operator coverage or even provide more efficient implementations over some ONNX ops.

Currently this feature is available through `OperatorExportTypes.ONNX_ATEN_FALLBACK`,
but it also performs changes to the graph that are runnable by Caffe2, only.

This PR introduces restricts caffe2-specific graph transformations for `ONNX_ATEN_FALLBACK`
operator export type for when pytorch is built with caffe2 support (aka BUILD_CAFFE2=1 during build)

The first version of this PR introduced a new operator export type `ONNX_ATEN__STRICT_FALLBACK`,
which essentially is the same as `ONNX_ATEN_FALLBACK` but without caffe2 transformations.
It was preferred to not introduce a new operator export type, but to refine the existing aten fallback one

## BC-breaking note
### The global constant `torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE` is removed in favor of
a less visible `torch.onnx._CAFFE2_ATEN_FALLBACK`.
`PYTORCH_ONNX_CAFFE2_BUNDLE` is really a dead code flag always set to False.
One alternative would be fixing it, but #66658 disables Caffe2 build by default.
Making a Caffe2 feature a private one seems to make more sense for future deprecation.

### The method `torch.onnx.export` now defaults to ONNX when `operator_export_type` is not specified.
Previously `torch.onnx.export's operator_export_type` intended to default to `ONNX_ATEN_FALLBACK` when `PYTORCH_ONNX_CAFFE2_BUNDLE` was set, but it would never happen as `PYTORCH_ONNX_CAFFE2_BUNDLE` is always undefined

 Co-authored-by: Nikita Shulga <nshulga@fb.com>

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D32483781

Pulled By: malfet

fbshipit-source-id: e9b447db9466b369e77d747188685495aec3f124
(cherry picked from commit 5fb1eb1b19)
2022-02-10 03:26:48 +00:00
BowenBao
cf70466970 [ONNX] Improve scope inference in function extraction
Cover more cases of scope inferencing where consecutive nodes don't have valid scope information. Usually these nodes are created in some pass where authors forgot to assign meaningful scope to them.
* One rule of `InferScope` is to check if the current node's outputs' users share the same scope. Recursively run `InferScope` on the user nodes if they are missing scope as well. Since the graph is SSA, the depth is finite.
* Fix one pass that missed scope information for a new node.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71897
2022-01-31 23:58:53 +00:00
BowenBao
804f13289e [ONNX] Update opset_version restriction for local function
Export should fail if export_modules_as_functions is set and opset_version<15.
This is because opeset_version < 15 implies IR version < 8, which means no local function support.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71619
2022-01-27 00:21:13 +00:00
Emilio Castillo
e2dc2aca93 Export ONNX models with readable input/output names (#68976)
Summary:
For some ONNX exported models, the inputs/outputs names have sometimes a numeric value and this makes pretty hard to inspect the generated graphs in the case of large models.

The solution in this PR was initially submitted to our internal utilities library by take-cheeze https://github.com/pfnet/pytorch-pfn-extras/pull/102

Now we would like to upstream this change by adding an extra kwarg when exporting the model to allow replacing these numeric names with actual debuggable ones.

As an example, the following code shows that the module output is `3`

```python
g, p, o = _model_to_graph(module, torch.ones(1, 10))
for n in g.nodes():
    for v in n.outputs():
        print(v.debugName())
```
output
```
3
```

With this PR

```
v3_Gemm
```

This allows identifying this out as a value from the associated Gemm layer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68976

Reviewed By: jansel

Differential Revision: D33662246

Pulled By: msaroufim

fbshipit-source-id: 45f56eef2a84d9a318db20c6a6de6c2743b9cd99
(cherry picked from commit 513c1d28f1)
2022-01-21 00:34:56 +00:00
BowenBao
ff78c73286 [ONNX] Remove f arg from export_to_pretty_string (#69045) (#69546)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69546

The arg is not used and was previously deprecated.

Also remove torch.onnx._export_to_pretty_string. It's redundant with the
public version.

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D32994270

Pulled By: msaroufim

fbshipit-source-id: f8f3933b371a0d868d9247510bcd73c31a9d6fcc
2022-01-12 21:31:36 -08:00
Deyu Huang
d32efe8bc2 [ONNX] Remove the argument use_external_data_format of export() method entirely. (#67080) (#67811)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67811

* remove the argument use_external_data_format of export() method entirely

Test Plan: Imported from OSS

Reviewed By: msaroufim

Differential Revision: D32181302

Pulled By: malfet

fbshipit-source-id: 4bc1448b7487bb9dfdad4e36008ff5b227fd64a3

Co-authored-by: hwangdeyu <dejack953@outlook.com>
2021-11-15 17:20:04 -08:00
Thiago Crepaldi
9d25554d45 [ONNX] Allow registration of custom symbolics for aten namespace (#66481) (#67810)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67810

Test Plan: Imported from OSS

Reviewed By: msaroufim

Differential Revision: D32181303

Pulled By: malfet

fbshipit-source-id: af2a715dc554b958fa3f5a7a8ae96cb3f7d112bb
2021-11-15 17:18:39 -08:00