Commit Graph

157 Commits

Author SHA1 Message Date
CYuxian
35f4bb9a25 [ONNX] Return input itself for non-fp inputs and support decimals for aten::round op (#107920)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107920
Approved by: https://github.com/justinchuby
2023-08-26 05:54:52 +00:00
Thiago Crepaldi
4be6b6b673 Add quantization support to reshape and size for the ONNX exporter (#106629)
Fixes https://github.com/microsoft/onnx-converters-private/issues/175

Add quantization support for Reshape-14, Size-9 and Size-11
For Size operators, we don't requantize outputs because we want the original scalar in the graph
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106629
Approved by: https://github.com/BowenBao
2023-08-05 02:08:52 +00:00
AllenTiTaiWang
d5d6eb2d46 [ONNX] Refactor AvgPool to support dynamic shapes (#105683)
In #87892, to pick up the corner cases found in #71549, the PR falls back the implementation of AvgPool to the way opset 9 implementing. However, it introduces a regression on dynamic shape cases found in #101397. This PR refactors the AvgPool op with the same implementation we have in onnxscript: https://github.com/microsoft/onnxscript/pull/754.

However, the corner case with `count_include_pad` remains unsolved in onnxruntime: https://github.com/microsoft/onnxruntime/issues/16203. The calculuation on the last value of each dimension is different between ORT and PyTorch. But the fix can be proved in: https://github.com/microsoft/onnxruntime/pull/16752, and it supports AvgPool since opset19.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105683
Approved by: https://github.com/thiagocrepaldi
2023-07-21 20:22:08 +00:00
Justin Chu
242a7743f3 [BE] Enable ruff's UP rules and autoformat onnx/ (#105427)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105427
Approved by: https://github.com/malfet
2023-07-18 21:41:24 +00:00
Ilya Sherstyuk
8c0b9a2d69 [ONNX] Export dynamic step size for aten::slice() (#104385)
This commit improves the export of aten::slice() to ONNX in the following ways:

1. The step size can be an input tensor rather than a constant.
2. Fixes a bug where using a 1-D, 1-element torch tensor as an index created a broken ONNX model.

This commit also adds tests for the new functionality.

Fixes #104314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104385
Approved by: https://github.com/thiagocrepaldi
2023-07-06 21:38:59 +00:00
AllenTiTaiWang
d3971f2d15 [ONNX] Support aten::hstack and aten::vstack (#102872)
https://github.com/microsoft/onnxscript/pull/762 is actually not used by FX graph.
Fixes https://github.com/microsoft/onnx-converters-private/issues/168
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102872
Approved by: https://github.com/justinchuby
2023-06-16 06:20:14 +00:00
AllenTiTaiWang
0411fc6ab6 [ONNX] Support aten::atleast_1d and aten::atleast_2d and aten::atleast_3d (#103061)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103061
Approved by: https://github.com/justinchuby
2023-06-16 06:07:00 +00:00
Matthew Hoffman
29da75cc55 Enable mypy allow redefinition (#102046)
Related #101528

I tried to enable this in another PR but it uncovered a bunch of type errors: https://github.com/pytorch/pytorch/actions/runs/4999748262/jobs/8956555243?pr=101528#step:10:1305

The goal of this PR is to fix these errors.

---

This PR enables [allow_redefinition = True](https://mypy.readthedocs.io/en/stable/config_file.html#confval-allow_redefinition) in `mypy.ini`, which allows for a common pattern:

> Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition.

`allow_redefinition` allows mypy to be more flexible by allowing reassignment to an existing variable with a different type... for instance (from the linked PR):

4a1e9230ba/torch/nn/parallel/data_parallel.py (L213)

A `Sequence[Union[int, torch.device]]` is narrowed to `Sequence[int]` thru reassignment to the same variable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102046
Approved by: https://github.com/ezyang
2023-05-24 07:05:30 +00:00
Bas Aarts
b3b0fbca11 [ONNX] Export Relu6 without using Relu (#99022)
The clamp operator used in the export of Relu6 already clamps the input value in between 0 and 6. There's no need to first perform a Relu on the input tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99022
Approved by: https://github.com/BowenBao
2023-04-19 06:18:14 +00:00
Aaron Gokaslan
47dca20d80 [BE] Enable flake8-comprehension rule C417 (#97880)
Enables flake8-comprehension rule C417. Ruff autogenerated these fixes to the codebase.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97880
Approved by: https://github.com/ezyang, https://github.com/kit1980, https://github.com/albanD
2023-03-30 14:34:24 +00:00
BowenBao
88d0235b73 [ONNX] Update CI test environment; Add symbolic functions (#94564)
* CI Test environment to install onnx and onnx-script.
* Add symbolic function for `bitwise_or`, `convert_element_type` and `masked_fill_`.
* Update symbolic function for `slice` and `arange`.
* Update .pyi signature for `_jit_pass_onnx_graph_shape_type_inference`.

Co-authored-by: Wei-Sheng Chin <wschin@outlook.com>
Co-authored-by: Ti-Tai Wang <titaiwang@microsoft.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94564
Approved by: https://github.com/abock
2023-02-10 20:44:59 +00:00
BowenBao
24172eebac [ONNX] Export 'aten::index_put(self, mask, v)' when rank(mask) < rank(self) (#92862)
Fix #92540

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92862
Approved by: https://github.com/justinchuby
2023-01-27 02:00:56 +00:00
AllenTiTaiWang
4e9539e002 [ONNX] Support ListConstruct in quantized_args (#92009)
Fixes #91303

quantized_args didn't support ListConstruct leading to an error when user uses quantized op with list inputs, ex: aten::cat. After this PR, converter can successfully export the issued model and pass ONNX checker. However, ORT doesn't seem to support it with the very same error as https://github.com/microsoft/onnxruntime/issues/12131.

Update:
I find test_quantized_cat_when_concatinating_the_same_tensor is even similar to the new case we have in here. The only difference is whether the inputs are already quantized. ONNX graphs both seem to be valid.
[test_quantized_cat_when_concatinating_the_same_tensor.zip](https://github.com/pytorch/pytorch/files/10396798/test_quantized_cat_when_concatinating_the_same_tensor.zip)
[test_quantized_list_of_inputs_with_cat.zip](https://github.com/pytorch/pytorch/files/10396799/test_quantized_list_of_inputs_with_cat.zip)

issue raised https://github.com/microsoft/onnxruntime/issues/14245
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92009
Approved by: https://github.com/BowenBao
2023-01-23 20:55:08 +00:00
Thiago Crepaldi
a8f40b39ce Update all ONNX symbolics with new JitScalarType API (#87245)
Fixes https://github.com/pytorch/pytorch/issues/84365 and more

This PR addresses not only the issue above, but the entire family of issues related to `torch._C.Value.type()` parsing when `scalarType()` or `dtype()` is not available.

This issue exists before `JitScalarType` was introduced, but the new implementation refactored the bug in because the new api `from_name` and `from_dtype` requires parsing `torch._C.Value.type()` to get proper inputs, which is exactly the root cause for this family of bugs.

Therefore `from_name` and `from_dtype` must be called when the implementor knows the `name` and `dtype` without parsing a `torch._C.Value`. To handle the corner cases hidden within `torch._C.Value`, a new `from_value` API was introduced and it should be used in favor of the former ones for most cases. The new API is safer and doesn't require type parsing from user, triggering JIT asserts in the core of pytorch.

Although CI is passing for all tests, please review carefully all symbolics/helpers refactoring to make sure the meaning/intetion of the old call are not changed in the new call

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87245
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
2022-11-03 03:01:33 +00:00
AllenTiTaiWang
f2ae459311 [ONNX] Disable ONNX ceil_mode and count_include_pad to aligntorch ceil_mode results in corner case (#87892)
ONNX and PyTorch has different equation on pooling and different strategy on ceil_mode, which leads to discrepancy on corner case (#71549 ).
Specifically, PyTorch avereage pooling is not following [the equation on documentation](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html), it allows sliding window to go off-bound instead, if they start within the left padding or the input (in NOTE section). More details can be found in #57178.

This PR changes avgpool in opset 10 and 11 back the way as opset 9, which it stops using ceil_mode and count_include_pad  in onnx::AveragePool

A comprehensive test for all combinations of parameters can be found in the next PR. #87893
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87892
Approved by: https://github.com/BowenBao
2022-10-29 11:35:10 +00:00
AllenTiTaiWang
52ac8adc20 [ONNX] Fix pad Circular Mode (#86984)
In https://github.com/pytorch/pytorch/pull/73433, a ONNX test case is missed, and the result is incorrect when it is converted to ONNX.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86984
Approved by: https://github.com/BowenBao
2022-10-25 19:39:35 +00:00
Justin Chu
5deeb09d4e [ONNX] Annotate all g as GraphContext (#85491)
- Use g.opset to test export opset version
- Annotate all `g` as GraphContext

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85491
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
2022-09-28 22:39:28 +00:00
Justin Chu
3d2316670f [ONNX] Create GraphContext and load g.op method to the class (#84728)
This PR create the `GraphContext` class and relays all graph methods to _C.Graph as well as implements the `g.op`  method. The GraphContext object is passed into the symbolic functions in place of _C.Graph for compatibility with existing symbolic functions.

This way (1) we can type annotate all `g` args because the method is defined and (2) we can use additional context information in symbolic functions. (3) no more monkey patching on `_C.Graph`

Also

- Fix return type of `_jit_pass_fixup_onnx_controlflow_node`
- Create `torchscript.py` to house torch.Graph related functions
- Change `GraphContext.op` to create nodes in the Block instead of the Graph
- Create `add_op_with_blocks` to handle scenarios where we need to directly manipulate sub-blocks. Update loop and if symbolic functions to use this function.

## Discussion

Should we put all the context inside `SymbolicContext` and make it an attribute in the `GraphContext` class? This way we only define two attributes `GraphContext.graph` and `GraphContext.context`. Currently all context attributes are directly defined in the class.

### Decision

Keep GraphContext flatand note that it will change in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84728
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
2022-09-28 22:21:55 +00:00
Justin Chu
2f50d2f685 [ONNX] Update docs on symbolic registration (#85290)
- Move inline instructions on editing symbolic functions to the README
- Add a line on using the symbolic function registration decorator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85290
Approved by: https://github.com/BowenBao
2022-09-22 13:37:11 +00:00
Justin Chu
76d60778eb [ONNX] Use decorators for symbolic function registration (#84448)
This is the 4th PR in the series of #83787. It enables the use of `@onnx_symbolic` across `torch.onnx`.

- **Backward breaking**: Removed some symbolic functions from `__all__` because of the use of  `@onnx_symbolic` for registering the same function on multiple aten names.
- Decorate all symbolic functions with `@onnx_symbolic`
- Move Quantized and Prim ops out from classes to functions defined in the modules. Eliminate the need for `isfunction` checking, speeding up the registration process by 60%.
    - Remove the outdated unit test `test_symbolic_opset9.py`
- Symbolic function registration moved from the first call to `_run_symbolic_function` to init time.
- Registration is fast:
  ![image](https://user-images.githubusercontent.com/11205048/189164959-f3fca173-19bc-4682-b150-f13a586387bf.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84448
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
2022-09-22 06:25:24 +00:00
Justin Chu
388368b699 [ONNX] Fix type annotations and enable type checking for all apis (#84091)
Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green.

Profile:

export `torchvision.models.alexnet(pretrained=True)`

```
with runtime type checking: 21.314 / 10 passes
without runtime type checking: 20.797 / 10 passes

+ 2.48%
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84091
Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
2022-09-03 01:40:18 +00:00
BowenBao
fd756caa36 [ONNX] Support nn.init.normal (#84149)
* Updated symbolic function for `aten::normal` to support additional generator arguments emitted from 5563248b58/torch/csrc/jit/passes/remove_mutation.cpp (L51)
* Added symbolic function for `aten::is_pinned` and `prim::layout`. Both are unused by ONNX later on.

Fixes #83647

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84149
Approved by: https://github.com/AllenTiTaiWang, https://github.com/abock
2022-09-01 18:29:41 +00:00
PyTorch MergeBot
d8cc8368ab Revert "[ONNX] Fix type annotations and enable type checking for all apis (#84091)"
This reverts commit 6446da1730.

Reverted https://github.com/pytorch/pytorch/pull/84091 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-08-28 12:28:58 +00:00
Justin Chu
6446da1730 [ONNX] Fix type annotations and enable type checking for all apis (#84091)
Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green.

Profile:

export `torchvision.models.alexnet(pretrained=True)`

```
with runtime type checking: 21.314 / 10 passes
without runtime type checking: 20.797 / 10 passes

+ 2.48%
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84091
Approved by: https://github.com/BowenBao
2022-08-27 04:40:41 +00:00
Justin Chu
3dfb8dfcf3 [ONNX] Use errors.SymbolicValueError for more context (#83332)
Replace runtime errors in torch.onnx with `errors.SymbolicValueError` for more context around jit values.

- Extend `_unimplemented`, `_onnx_unsupported`, `_onnx_opset_unsupported`, `_onnx_opset_unsupported_detailed` errors to include JIT value information
- Replace plain RuntimeError with `errors.SymbolicValueError`
- Clean up: Use `_is_bool` to replace string comparison on jit types
- Clean up: Remove the todo `Remove type ignore after #81112`

#77316
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83332
Approved by: https://github.com/AllenTiTaiWang, https://github.com/thiagocrepaldi, https://github.com/BowenBao
2022-08-23 05:39:17 +00:00
Justin Chu
80cfafc385 [ONNX] Add quantization support to more single output ops (#83008)
#80039

- Implement quantization support for single output ops
  - quantized::sigmoid
  - quantized::instance_norm
  - aten::reshape
  - aten::reshape_as
  - aten::sum
  - aten::mean
  - aten::prod
  - aten::t
  - aten::numpy_T
  - aten::expand
  - aten::expand_as
  - aten::embedding
  - aten::embedding_bag
  - aten::view
  - aten::select
  - aten::eq
  - aten::ne
  - aten::gt
  - aten::lt
  - aten::le
  - aten::ge
  - quantized::layer_norm
  - aten::elu
  - aten::selu
  - aten::maximum
  - aten::minimum
  - aten::amax
  - aten::amin
  - aten::hardtanh
  - aten::hardswish
  - quantized::group_norm
  - aten::as_strided
  - quantized::leaky_relu
  - aten::transpose
- Avoid modifying functions in `quantized_args` and have the wrapper closed over `scale` and `zero_point` instead (for purity)
- Remove magic number and assign it to INT64_MAX
- implement `_unpack_quantized_tensor` for handling quantized tensor unpacking to separate the logic from tuple unpacking and for clearer error handling
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83008
Approved by: https://github.com/BowenBao
2022-08-23 00:39:24 +00:00
Justin Chu
c6cdca5c68 [ONNX] Reland #81953 Type utility for converting among JIT, torch and ONNX data types (#82995)
Re-land #81953

Add `_type_utils` for handling data type conversion among JIT, torch and ONNX.

- Replace dictionary / list indexing with methods in ScalarType
- Breaking: **Remove ScalarType from `symbolic_helper`** and move it to `_type_utils`
- Deprecated: "cast_pytorch_to_onnx", "pytorch_name_to_type", "scalar_name_to_pytorch", "scalar_type_to_onnx", "scalar_type_to_pytorch_type" in `symbolic_helper`
- Deprecate the type mappings and lists. Remove all internal references
- Move _cast_func_template to opset 9 and remove its reference elsewhere (clean up). Added documentation for easy discovery

Why: List / dictionary indexing and lookup are error-prone and convoluted.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82995
Approved by: https://github.com/kit1980
2022-08-08 23:43:43 +00:00
PyTorch MergeBot
b170a52a09 Revert "[ONNX] Type utility for converting among JIT, torch and ONNX data types (#81953)"
This reverts commit 6ddf4c6f58.

Reverted https://github.com/pytorch/pytorch/pull/81953 on behalf of https://github.com/kit1980 due to Broke internal builds by removing functions without deprecation
2022-08-07 20:15:28 +00:00
Justin Chu
6ddf4c6f58 [ONNX] Type utility for converting among JIT, torch and ONNX data types (#81953)
Add `_type_utils` for handling data type conversion among JIT, torch and ONNX.

- Replace dictionary / list indexing with methods in ScalarType
- Breaking: **Remove ScalarType from `symbolic_helper`** and move it to `_type_utils`
- Breaking: **Remove "cast_pytorch_to_onnx", "pytorch_name_to_type", "scalar_name_to_pytorch", "scalar_type_to_onnx", "scalar_type_to_pytorch_type"** from `symbolic_helper`
- Deprecate the type mappings and lists. Remove all internal references
- Move _cast_func_template to opset 9 and remove its reference elsewhere (clean up). Added documentation for easy discovery

Why: List / dictionary indexing and lookup are error-prone and convoluted.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81953
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
2022-08-05 22:24:45 +00:00
Huy Do
6ea422dd0b Format torch/onnx with ufmt (#82137)
This is the last batch for the new ufmt (black + usort) linter. After this, black linter can finally be replaced. The previous PR to format ONNX tests was #81335
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82137
Approved by: https://github.com/kit1980, https://github.com/AllenTiTaiWang
2022-07-25 22:42:21 +00:00
Justin Chu
ed1da2a9df [ONNX] Quantization support for quantized::cat (#79826)
- Add support for quantized `cat`
- Add type annotations for helper functions

Now we can export

```python
import torchvision.models.quantization as models
from torchvision import transforms

torch_model = models.inception_v3(pretrained=True, quantize=True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79826
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
2022-07-12 15:44:23 +00:00
Justin Chu
06710ec1b9 [ONNX] Reland: Add quantization support to _avg_pool opset 9 and clean up (#81267)
- Add quantization support to _avg_pool opset 9
- Clean up reused / unused variables in avgpool helper
- Add types
- Sort `__all__`

Reland #79793 (Added `argsort` in `__all__`)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81267
Approved by: https://github.com/BowenBao
2022-07-11 23:58:28 +00:00
PyTorch MergeBot
b98b9eaae5 Revert "[ONNX] Add quantization support to _avg_pool opset 9 and clean up (#79793)"
This reverts commit 356341a3ec.

Reverted https://github.com/pytorch/pytorch/pull/79793 on behalf of https://github.com/malfet due to Broke trunk, see 356341a3ec
2022-07-09 02:40:28 +00:00
Justin Chu
356341a3ec [ONNX] Add quantization support to _avg_pool opset 9 and clean up (#79793)
- Add quantization support to _avg_pool opset 9
- Clean up reused / unused variables in avgpool helper
- Add types
- Sort `__all__`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79793
Approved by: https://github.com/BowenBao
2022-07-08 23:14:01 +00:00
shubhambhokare1
712b3e76ef [onnx] Add argsort support (#80234)
Fixes #79859 where exporting the operator '::argsort' to ONNX opset version 16 is not supported.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80234
Approved by: https://github.com/BowenBao
2022-07-07 22:06:29 +00:00
qqaatw
c47f78d25f [ONNX] Fix linalg norm output's shapes and dtypes (#79506)
Part of #79263

This PR fixes the following matters:

1. Before this fix, the reduced output has `[1]` shape when `dim = None` and `keepdim = False`. Now the output is reduced to `[]` shape, which matches Pytorch's behavior.
2. Before this fix, the output is always casted to `Long`. Now the output is casted to the input's dtype.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79506
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
2022-06-22 02:12:34 +00:00
Justin Chu
c8b9b6266b [ONNX] Fix arg type in _set_training_mode (#78583)
When `TrainingMode.PRESERVE` is set for export, the exporter used to change the model's training mode based on some logic. Now we respect the option and not touch the model's training state.

- Previously `_set_training_mode`'s behavior doesn't match what the global variable expects. This PR removes the deprecated `_set_training_mode` and makes the type correct.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78583
Approved by: https://github.com/BowenBao
2022-06-15 23:47:12 +00:00
Justin Chu
d3ef5c3fa3 [ONNX] Clean up __init__ in torch.onnx (#78446)
- Move definitions in `__init__` to internal classes and expose them by importing to init (prevent circular dependencies): https://github.com/pytorch/pytorch/wiki/torch.onnx-Namespacing
  - Context classes and enums are moved to `_exporter_states.py`
  - Exceptions are moved to `errors.py`
- Define `__all__` for torch.onnx. https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation
- Moved `utils.__IN_ONNX_EXPORT` to `GLOBALS.in_onnx_export`
- Deprecated `torch.onnx._export`

Precedes #78231

Using this as an aid for finding public functions:

```python
list(filter(lambda x: not x.startswith("_"), torch.onnx.utils.__dict__.keys()))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78446
Approved by: https://github.com/BowenBao
2022-06-14 04:35:06 +00:00
Justin Chu
def778527e [ONNX] Quantization support for five ops (#78103)
- Add quantization support for `interpolate`, `avgpool`, `sigmoid` and `add_relu`
- Return the inputs to ListUnpack if the previous node is ListConstruct so that `ListConstruct` and `ListUnpack` are canceled and removed in the jit passes. ONNX doesn't support them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78103
Approved by: https://github.com/garymm
2022-06-03 20:22:07 +00:00
Justin Chu
0d76299ff7 [ONNX] Clean up module imports (#77423)
Cleaning up onnx module imports to prepare for updating `__init__`.

- Simplify importing the `_C` and `_C._onnx` name spaces
- Remove alias of the symbolic_helper module in imports
- Remove any module level function imports. Import modules instead
    - Alias `symbilic_opsetx` as `opsetx`
- Fix some docstrings

Requires:
- https://github.com/pytorch/pytorch/pull/77448
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77423
Approved by: https://github.com/BowenBao
2022-05-20 01:56:24 +00:00
Justin Chu
563c2719bf [ONNX] Refactor to remove inline imports - attempt 2 (#77448)
Re-land
- #77142

(diff: https://github.com/pytorch/pytorch/compare/c08b8f0..justinchuby:justinchu/remove-patch2)

Fixed:
- Delay import symbolic_opsets in the registry.

Tested locally with torchvision
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77448
Approved by: https://github.com/garymm
2022-05-16 14:44:24 +00:00
PyTorch MergeBot
6b366dd3c1 Revert "[ONNX] Refactor to remove inline imports (#77142)"
This reverts commit c08b8f0967.

Reverted https://github.com/pytorch/pytorch/pull/77142 on behalf of https://github.com/malfet
2022-05-13 19:44:17 +00:00
Justin Chu
c08b8f0967 [ONNX] Refactor to remove inline imports (#77142)
Reduce circular dependencies

- Lift constants and flags from `symbolic_helper` to `_constants` and `_globals`
    - Standardized constant naming to make it consistant
- Make `utils` strictly dependent on `symbolic_helper`, removing inline imports from symbolic_helper
- Move side effects from `utils` to `_patch_torch`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77142
Approved by: https://github.com/garymm, https://github.com/BowenBao
2022-05-13 03:46:33 +00:00
Justin Chu
5dd1c67776 [ONNX] Format ONNX python with black
Format all onnx python code with black and isort with

```sh
isort torch/onnx/ test/onnx
black torch/onnx/ test/onnx
```

Updated lintrunner config to include these paths.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76754
Approved by: https://github.com/suo, https://github.com/BowenBao
2022-05-05 00:19:22 +00:00
Peter Bell
cb37e7a080 Remove F.pad python implementation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73433

Approved by: https://github.com/albanD, https://github.com/jbschlosser
2022-04-23 00:13:20 +00:00
Thiago Crepaldi
eab3f42883 Update symbolics policy to emit aten::ATen for Caffe2 build only
Currently ONNX exporter symbolics can emit ATen operators when `operator_export_type==ONNX_ATEN_FALLBACK`. However, this is a behavior specific to Caffe2 builds, as the intend use of `ONNX_ATEN_FALLBACK` is to emit ATen operators only when there is no ONNX equivalent.

The reason Caffe2 choses to emit ATen operators when ONNX counterpart exists is for performance on their particular engine implementation, which might not be true for other implementations. e.g. ONNX Runtime can optimize the generated ONNX graph into something more efficient

This PR must be merged only after https://github.com/pytorch/pytorch/pull/73954
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74680
Approved by: https://github.com/garymm, https://github.com/malfet
2022-04-19 15:57:54 +00:00
Thiago Crepaldi
9bbe1d632e Fix ONNX ATen fallback for non-caffe2 engines
This PR introduces 3 BC changes:

First, this PR propagates `BUILD_CAFFE2` flag to `libtorch` and `libtorch_python`, which is necessary for non-caffe2 ONNX runtimes when using `ONNX_ATEN_FALLBACK` operator export type.

Second, as a complement of https://github.com/pytorch/pytorch/pull/68490, this PR refactors Caffe2's Aten ops symbolics to consider not only the `operator_export_type` (aka `ONNX_ATEN_FALLBACK`) to emit Caffe2 Aten ops, but also whether `BUILD_CAFFE2` (which is called `torch.onnx._CAFFE2_ATEN_FALLBACK` in python binding) is set.

Lastly, it renames `onnx::ATen` to `aten::ATen` for ONNX spec consistency in a BC fashion.
ONNX doesn't have `ATen` op on its spec, but PyTorch ONNX converter emits them. Non-Caffe2 backend engines would be mislead by such operator's name/domain. A non-ideal workaround would be to have Aten ops handled based on its name and ignore the (non-complaint) domain. Moreover, users could incorrectly file bugs to either ONNX or ONNX Runtime when they inspect the model and notice the presence of an unspecified ONNX operator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73954
Approved by: https://github.com/BowenBao, https://github.com/malfet, https://github.com/garymm, https://github.com/jiafatom
2022-04-14 23:18:45 +00:00
ganler
c1f0e6e763 [ONNX] Make Non-Float Op Exportation Compatible to Avoid Invalid ONNX Models
There are a few ONNX operators do not support non-float (e.g., integer) inputs at early versions. For example, Clip supports non-float types until [opset 12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#type-constraints-280), that said older versions like [opset 6](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#type-constraints-107) cannot deal with integer types.

I initially find such a bug in Clip (https://github.com/pytorch/pytorch/pull/70584), but later found more:
1. Clip < 12;
2. Min/Max < 12;
3. ReLU < 14;
4. Pad < 11;

In PyTorch, if we export Max-11 with integer inputs, actually the exportation will succeed; however, fail when imported by other frameworks like ONNXRuntime.

```python
import torch

class Net(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor):
        return torch.max(x, x + 1)

net = Net()
onnx_model = 'test.onnx'

torch.onnx.export(net, (torch.zeros((3, 3), dtype=torch.int32),),
                  onnx_model, verbose=True, opset_version=11)
```

This is an unexpected behavior as we want to ensure that every model exported by PyTorch is valid (https://github.com/pytorch/pytorch/pull/70584#issuecomment-1020636579). Theoretically, we can simply forbid such cases (e.g., `Clip<int>` < 12, `ReLU<int>` < 14). But actually we can enhance the compatibility and flexibility of PyTorch by simply casting inputs of those operators into float tensors, which allows the float operator functions, and then casting it back to original types.

This PR implements the second approach to achieve better compatibility in PyTorch.

@garymm  @thiagocrepaldi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72401
Approved by: https://github.com/garymm, https://github.com/thiagocrepaldi
2022-04-11 23:26:44 +00:00
BowenBao
2ba496bc23 [ONNX] Fix 1d case flatten export
Fixes #74142

Previous check `dim is not None and end_dim == dim - 2` didn't consider `end_dim` being negative. However the error only occurs when input tensor has rank 1, and the rank is known to symbolic function. So a better fix is to return `input` directly when rank is 1.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74595
Approved by: https://github.com/garymm
2022-03-23 23:50:51 +00:00
BowenBao
9210e8f540 [ONNX] Adds overload_name to Aten op (#69378) (#73280)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73280

This PR adds a new attribute overload_name to the Aten node so that third party applications can implement calls to libtorch without using PyTorch source code.

This is necessary because torch's torch::jit::findOperatorFor(fullname) requires a full name, including operator and overload names.

ATen op was originally created for Caffe2, which leveraged the availability of the pytorch yaml files to create calls to the aten oeprators directly, not relying on torch::jit::findOperatorFor

The first part of the PR refactors all symbolics that create Aten ops, so that there is a single helper for this operator. Next all symbolics are updated to pass in the relevant overload name, if empty string is not applicable

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D34625645

Pulled By: malfet

fbshipit-source-id: 37d58cfb5231833768172c122efc42edf7d8609a
(cherry picked from commit e92f09117d3645b38bc3235b30aba4b4c7c71dfa)
2022-03-09 14:26:18 +00:00