Commit Graph

59 Commits

Author SHA1 Message Date
Thiago Crepaldi
a8f40b39ce Update all ONNX symbolics with new JitScalarType API (#87245)
Fixes https://github.com/pytorch/pytorch/issues/84365 and more

This PR addresses not only the issue above, but the entire family of issues related to `torch._C.Value.type()` parsing when `scalarType()` or `dtype()` is not available.

This issue exists before `JitScalarType` was introduced, but the new implementation refactored the bug in because the new api `from_name` and `from_dtype` requires parsing `torch._C.Value.type()` to get proper inputs, which is exactly the root cause for this family of bugs.

Therefore `from_name` and `from_dtype` must be called when the implementor knows the `name` and `dtype` without parsing a `torch._C.Value`. To handle the corner cases hidden within `torch._C.Value`, a new `from_value` API was introduced and it should be used in favor of the former ones for most cases. The new API is safer and doesn't require type parsing from user, triggering JIT asserts in the core of pytorch.

Although CI is passing for all tests, please review carefully all symbolics/helpers refactoring to make sure the meaning/intetion of the old call are not changed in the new call

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87245
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
2022-11-03 03:01:33 +00:00
Justin Chu
85a79a7f50 [ONNX] Expand _cast_ symbolic functions (#87666)
The `_cast_` family of symbolic functions has been created from a template function. Even though it saved some lines, it very much obscured the intention of the code. Since the list doesn't really change and the `_cast_` family are IIRC deprecated, it is safe for us to expand the templates and make the code more readable.

This PR also removes any direct calls to `_cast_` functions to maintain a consistent pattern of directly creating `Cast` nodes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87666
Approved by: https://github.com/BowenBao
2022-10-26 00:39:59 +00:00
Justin Chu
5deeb09d4e [ONNX] Annotate all g as GraphContext (#85491)
- Use g.opset to test export opset version
- Annotate all `g` as GraphContext

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85491
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
2022-09-28 22:39:28 +00:00
Justin Chu
3d2316670f [ONNX] Create GraphContext and load g.op method to the class (#84728)
This PR create the `GraphContext` class and relays all graph methods to _C.Graph as well as implements the `g.op`  method. The GraphContext object is passed into the symbolic functions in place of _C.Graph for compatibility with existing symbolic functions.

This way (1) we can type annotate all `g` args because the method is defined and (2) we can use additional context information in symbolic functions. (3) no more monkey patching on `_C.Graph`

Also

- Fix return type of `_jit_pass_fixup_onnx_controlflow_node`
- Create `torchscript.py` to house torch.Graph related functions
- Change `GraphContext.op` to create nodes in the Block instead of the Graph
- Create `add_op_with_blocks` to handle scenarios where we need to directly manipulate sub-blocks. Update loop and if symbolic functions to use this function.

## Discussion

Should we put all the context inside `SymbolicContext` and make it an attribute in the `GraphContext` class? This way we only define two attributes `GraphContext.graph` and `GraphContext.context`. Currently all context attributes are directly defined in the class.

### Decision

Keep GraphContext flatand note that it will change in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84728
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
2022-09-28 22:21:55 +00:00
Justin Chu
76d60778eb [ONNX] Use decorators for symbolic function registration (#84448)
This is the 4th PR in the series of #83787. It enables the use of `@onnx_symbolic` across `torch.onnx`.

- **Backward breaking**: Removed some symbolic functions from `__all__` because of the use of  `@onnx_symbolic` for registering the same function on multiple aten names.
- Decorate all symbolic functions with `@onnx_symbolic`
- Move Quantized and Prim ops out from classes to functions defined in the modules. Eliminate the need for `isfunction` checking, speeding up the registration process by 60%.
    - Remove the outdated unit test `test_symbolic_opset9.py`
- Symbolic function registration moved from the first call to `_run_symbolic_function` to init time.
- Registration is fast:
  ![image](https://user-images.githubusercontent.com/11205048/189164959-f3fca173-19bc-4682-b150-f13a586387bf.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84448
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
2022-09-22 06:25:24 +00:00
Justin Chu
3dfb8dfcf3 [ONNX] Use errors.SymbolicValueError for more context (#83332)
Replace runtime errors in torch.onnx with `errors.SymbolicValueError` for more context around jit values.

- Extend `_unimplemented`, `_onnx_unsupported`, `_onnx_opset_unsupported`, `_onnx_opset_unsupported_detailed` errors to include JIT value information
- Replace plain RuntimeError with `errors.SymbolicValueError`
- Clean up: Use `_is_bool` to replace string comparison on jit types
- Clean up: Remove the todo `Remove type ignore after #81112`

#77316
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83332
Approved by: https://github.com/AllenTiTaiWang, https://github.com/thiagocrepaldi, https://github.com/BowenBao
2022-08-23 05:39:17 +00:00
Justin Chu
f5701a1f9a [ONNX] Remove unused patching methods (#83006)
### Description
<!-- What did you change and why was it needed? -->

Remove unused patching methods:

- `torch._C.Graph.constant`
- unpatch `torch._C.Node.__getitem__` and move the helper function to `symbolic_helper`

Add typing annotations

### Issue
<!-- Link to Issue ticket or RFP -->

#76254

### Testing
<!-- How did you test your change? -->

Unit tested
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83006
Approved by: https://github.com/BowenBao
2022-08-09 19:24:03 +00:00
Justin Chu
c6cdca5c68 [ONNX] Reland #81953 Type utility for converting among JIT, torch and ONNX data types (#82995)
Re-land #81953

Add `_type_utils` for handling data type conversion among JIT, torch and ONNX.

- Replace dictionary / list indexing with methods in ScalarType
- Breaking: **Remove ScalarType from `symbolic_helper`** and move it to `_type_utils`
- Deprecated: "cast_pytorch_to_onnx", "pytorch_name_to_type", "scalar_name_to_pytorch", "scalar_type_to_onnx", "scalar_type_to_pytorch_type" in `symbolic_helper`
- Deprecate the type mappings and lists. Remove all internal references
- Move _cast_func_template to opset 9 and remove its reference elsewhere (clean up). Added documentation for easy discovery

Why: List / dictionary indexing and lookup are error-prone and convoluted.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82995
Approved by: https://github.com/kit1980
2022-08-08 23:43:43 +00:00
PyTorch MergeBot
b170a52a09 Revert "[ONNX] Type utility for converting among JIT, torch and ONNX data types (#81953)"
This reverts commit 6ddf4c6f58.

Reverted https://github.com/pytorch/pytorch/pull/81953 on behalf of https://github.com/kit1980 due to Broke internal builds by removing functions without deprecation
2022-08-07 20:15:28 +00:00
Justin Chu
6ddf4c6f58 [ONNX] Type utility for converting among JIT, torch and ONNX data types (#81953)
Add `_type_utils` for handling data type conversion among JIT, torch and ONNX.

- Replace dictionary / list indexing with methods in ScalarType
- Breaking: **Remove ScalarType from `symbolic_helper`** and move it to `_type_utils`
- Breaking: **Remove "cast_pytorch_to_onnx", "pytorch_name_to_type", "scalar_name_to_pytorch", "scalar_type_to_onnx", "scalar_type_to_pytorch_type"** from `symbolic_helper`
- Deprecate the type mappings and lists. Remove all internal references
- Move _cast_func_template to opset 9 and remove its reference elsewhere (clean up). Added documentation for easy discovery

Why: List / dictionary indexing and lookup are error-prone and convoluted.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81953
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
2022-08-05 22:24:45 +00:00
Huy Do
6ea422dd0b Format torch/onnx with ufmt (#82137)
This is the last batch for the new ufmt (black + usort) linter. After this, black linter can finally be replaced. The previous PR to format ONNX tests was #81335
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82137
Approved by: https://github.com/kit1980, https://github.com/AllenTiTaiWang
2022-07-25 22:42:21 +00:00
qqaatw
efbee797b8 [ONNX] Fix prelu output's shape (#79846)
Part of #79263

Before: The output has `[1]` shape when the input is a scalar.
After: The output has `[]` shape, matching PyTorch's behavior.

The original comment along the code states `torch allows scalar self, and ONNX is ambiguous about whether this is allowed`. The fact seems to be that ONNX never clearly indicates whether scalar inputs are allowed for all the ONNX operators. At least in this case, a scalar input seems to be allowed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79846
Approved by: https://github.com/BowenBao
2022-06-22 02:20:27 +00:00
Justin Chu
d3ef5c3fa3 [ONNX] Clean up __init__ in torch.onnx (#78446)
- Move definitions in `__init__` to internal classes and expose them by importing to init (prevent circular dependencies): https://github.com/pytorch/pytorch/wiki/torch.onnx-Namespacing
  - Context classes and enums are moved to `_exporter_states.py`
  - Exceptions are moved to `errors.py`
- Define `__all__` for torch.onnx. https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation
- Moved `utils.__IN_ONNX_EXPORT` to `GLOBALS.in_onnx_export`
- Deprecated `torch.onnx._export`

Precedes #78231

Using this as an aid for finding public functions:

```python
list(filter(lambda x: not x.startswith("_"), torch.onnx.utils.__dict__.keys()))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78446
Approved by: https://github.com/BowenBao
2022-06-14 04:35:06 +00:00
Justin Chu
161e931156 [ONNX] Modernize python syntax (#77935)
Use pyupgrade(https://github.com/asottile/pyupgrade) and flynt to modernize python syntax

```sh
pyupgrade --py36-plus --keep-runtime-typing torch/onnx/**/*.py
pyupgrade --py36-plus --keep-runtime-typing test/onnx/**/*.py
flynt torch/onnx/ --line-length 120
```

- Use f-strings for string formatting
- Use the new `super()` syntax for class initialization
- Use dictionary / set comprehension
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77935
Approved by: https://github.com/BowenBao
2022-05-24 22:52:37 +00:00
Justin Chu
0d76299ff7 [ONNX] Clean up module imports (#77423)
Cleaning up onnx module imports to prepare for updating `__init__`.

- Simplify importing the `_C` and `_C._onnx` name spaces
- Remove alias of the symbolic_helper module in imports
- Remove any module level function imports. Import modules instead
    - Alias `symbilic_opsetx` as `opsetx`
- Fix some docstrings

Requires:
- https://github.com/pytorch/pytorch/pull/77448
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77423
Approved by: https://github.com/BowenBao
2022-05-20 01:56:24 +00:00
Justin Chu
563c2719bf [ONNX] Refactor to remove inline imports - attempt 2 (#77448)
Re-land
- #77142

(diff: https://github.com/pytorch/pytorch/compare/c08b8f0..justinchuby:justinchu/remove-patch2)

Fixed:
- Delay import symbolic_opsets in the registry.

Tested locally with torchvision
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77448
Approved by: https://github.com/garymm
2022-05-16 14:44:24 +00:00
PyTorch MergeBot
6b366dd3c1 Revert "[ONNX] Refactor to remove inline imports (#77142)"
This reverts commit c08b8f0967.

Reverted https://github.com/pytorch/pytorch/pull/77142 on behalf of https://github.com/malfet
2022-05-13 19:44:17 +00:00
Justin Chu
c08b8f0967 [ONNX] Refactor to remove inline imports (#77142)
Reduce circular dependencies

- Lift constants and flags from `symbolic_helper` to `_constants` and `_globals`
    - Standardized constant naming to make it consistant
- Make `utils` strictly dependent on `symbolic_helper`, removing inline imports from symbolic_helper
- Move side effects from `utils` to `_patch_torch`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77142
Approved by: https://github.com/garymm, https://github.com/BowenBao
2022-05-13 03:46:33 +00:00
Justin Chu
5dd1c67776 [ONNX] Format ONNX python with black
Format all onnx python code with black and isort with

```sh
isort torch/onnx/ test/onnx
black torch/onnx/ test/onnx
```

Updated lintrunner config to include these paths.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76754
Approved by: https://github.com/suo, https://github.com/BowenBao
2022-05-05 00:19:22 +00:00
vfdev-5
3da2e09c9b Added antialias flag to interpolate (CPU only, bilinear) (#65142)
Summary:
Description:
- Added antialias flag to interpolate (CPU only)
  - forward and backward for bilinear mode
  - added tests

### Benchmarks

<details>
<summary>
Forward pass, CPU. PTH interpolation vs PIL
</summary>

Cases:
- PTH RGB 3 Channels, float32 vs PIL RGB uint8 (apply vs pears)
- PTH 1 Channel, float32 vs PIL 1 Channel Float

Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112

```
# OMP_NUM_THREADS=1 python bench_interp_aa_vs_pillow.py

Torch config: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.1
  - NVCC architecture flags: -gencode;arch=compute_75,code=sm_75
  - CuDNN 8.0.5
  - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.10.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON,

Num threads: 1
[------------------------ Downsampling: torch.Size([1, 3, 906, 438]) -> (320, 196) ------------------------]
                                                  |  Reference, PIL 8.3.2, mode: RGB  |  1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                2.9                |          3.1
      channels_last non-contiguous torch.float32  |                2.6                |          3.6

Times are in milliseconds (ms).

[------------------------ Downsampling: torch.Size([1, 3, 906, 438]) -> (460, 220) ------------------------]
                                                  |  Reference, PIL 8.3.2, mode: RGB  |  1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                3.4                |          4.0
      channels_last non-contiguous torch.float32  |                3.4                |          4.8

Times are in milliseconds (ms).

[------------------------ Downsampling: torch.Size([1, 3, 906, 438]) -> (120, 96) -------------------------]
                                                  |  Reference, PIL 8.3.2, mode: RGB  |  1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                1.6                |          1.8
      channels_last non-contiguous torch.float32  |                1.6                |          1.9

Times are in milliseconds (ms).

[----------------------- Downsampling: torch.Size([1, 3, 906, 438]) -> (1200, 196) ------------------------]
                                                  |  Reference, PIL 8.3.2, mode: RGB  |  1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                9.0                |          11.3
      channels_last non-contiguous torch.float32  |                8.9                |          12.5

Times are in milliseconds (ms).

[----------------------- Downsampling: torch.Size([1, 3, 906, 438]) -> (120, 1200) ------------------------]
                                                  |  Reference, PIL 8.3.2, mode: RGB  |  1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                2.1                |          1.8
      channels_last non-contiguous torch.float32  |                2.1                |          3.4

Times are in milliseconds (ms).

[--------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (320, 196) --------------]
                                 |  Reference, PIL 8.3.2, mode: F  |  1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               1.2               |          1.0

Times are in milliseconds (ms).

[--------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (460, 220) --------------]
                                 |  Reference, PIL 8.3.2, mode: F  |  1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               1.4               |          1.3

Times are in milliseconds (ms).

[--------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (120, 96) ---------------]
                                 |  Reference, PIL 8.3.2, mode: F  |  1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |              719.9              |         599.9

Times are in microseconds (us).

[-------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (1200, 196) --------------]
                                 |  Reference, PIL 8.3.2, mode: F  |  1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               3.7               |          3.5

Times are in milliseconds (ms).

[-------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (120, 1200) --------------]
                                 |  Reference, PIL 8.3.2, mode: F  |  1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |              834.4              |         605.7

Times are in microseconds (us).

```

</details>

Code is moved from torchvision: https://github.com/pytorch/vision/pull/4208

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65142

Reviewed By: mrshenli

Differential Revision: D32432405

Pulled By: jbschlosser

fbshipit-source-id: b66c548347f257c522c36105868532e8bc1d4c6d
2021-11-17 09:10:15 -08:00
Gary Miguel
eb22d06e5e [ONNX] Use human readable enum for dtype scalars (#66822) (#67807)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67807

Also make quoting of string literals consistent.

Test Plan: Imported from OSS

Reviewed By: msaroufim

Differential Revision: D32181309

Pulled By: malfet

fbshipit-source-id: e1053701e3589f0310d8b5ef920359c03c6713f0
2021-11-08 14:37:05 -08:00
BowenBao
d4ff344fae [ONNX] Fix remainder export (#64230) (#64578)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64578

* Fix remainder export for edge case when input is negative. New export relies on true_divide export.
* Simplified true_divide export. Cleaned up redundant code which is handled by scalar type analysis pass. Removed dependency on `onnx::Where`, thus supports opset 7 & 8.

Fixes #60179

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D30919601

Pulled By: malfet

fbshipit-source-id: 0f78621c0ac3bdb6bf4225e049ba5f470dc8ab12

Co-authored-by: BowenBao <bowbao@microsoft.com>
2021-09-30 21:08:54 -07:00
BowenBao
0a6828a306 [ONNX] use consistent quoting for string literals (#57757) (#58695)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58695

As PEP8 says: "Pick a rule and stick to it." [1]

[1] https://www.python.org/dev/peps/pep-0008/#string-quotes

Test Plan: Imported from OSS

Reviewed By: driazati

Differential Revision: D28714811

Pulled By: SplitInfinity

fbshipit-source-id: c95103aceb1725c17c034dc6fc8216627f189548

Co-authored-by: Gary Miguel <garymiguel@microsoft.com>
2021-05-27 12:06:42 -07:00
Sam Estep
75024e228c Add lint for unqualified type: ignore (#56290)
Summary:
The other half of https://github.com/pytorch/pytorch/issues/56272.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56290

Test Plan:
CI should pass on the tip of this PR, and we know that the lint works because the following CI runs (before this PR was finished) failed:

- https://github.com/pytorch/pytorch/runs/2384511062
- https://github.com/pytorch/pytorch/actions/runs/765036024

Reviewed By: seemethere

Differential Revision: D27867219

Pulled By: samestep

fbshipit-source-id: e648f07b6822867e70833e23ddafe7fb7eaca235
2021-04-21 08:07:23 -07:00
BowenBao
a572f70f2f [ONNX] Support torch.isinf, torch.any and torch.all export to ONNX (#53328) (#53529)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53529

Supported for ONNX export after opset 10.
This is not exportable to opsets < 10 due to
1. onnx::IsInf is introduced in opset 10
2. onnx::Equal does not accept float tensor prior to opset 11

Test Plan: Imported from OSS

Reviewed By: pbelevich, malfet

Differential Revision: D26922418

Pulled By: SplitInfinity

fbshipit-source-id: 69bcba50520fa3d69db4bd4c2b9f88c00146fca7
2021-03-12 02:49:41 -08:00
BowenBao
a6a811f23a [ONNX] Add repeat_interleave symbolic (#52855) (#53312)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53312

- Add support for aten::repeat_interleave
- NOTE: Also adds fix for cases with split op where input tensor sizes are not known but _outputs is provided

Test Plan: Imported from OSS

Reviewed By: pbelevich, malfet

Differential Revision: D26922422

Pulled By: SplitInfinity

fbshipit-source-id: 5362d0d8ccfdc14c15e1ae73fd70c4c113f823e6
2021-03-12 02:49:34 -08:00
Mike Ruberry
594a66d778 Warn about floor_divide performing incorrect rounding (#50281) (#50281)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50281

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51745

Test Plan: Imported from OSS

Reviewed By: ngimel

Pulled By: mruberry

Differential Revision: D26257855

fbshipit-source-id: e5d497cf07b0c746838ed081c5d0e82fb4cb701b
2021-02-10 03:13:34 -08:00
BowenBao
e5a98c5ab0 [ONNX] Remove usage of isCompleteTensor() in symbolic functions (#48162)
Summary:
`isCompleteTensor()` only returns true when both scalar type and shape is present. All dimensions in the shape must be static. This high requirement is unnecessary for many use cases such as when only rank or scalar type needs to be known.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48162

Reviewed By: malfet

Differential Revision: D25340823

Pulled By: bzinodev

fbshipit-source-id: 1fef61f44918f4339dd6654fb725b18cd58d99cf
2020-12-09 11:37:19 -08:00
Guilherme Leobas
34cc77a811 Torch onnx (#48980)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/45215

This is a follow up PR of https://github.com/pytorch/pytorch/issues/45258 and https://github.com/pytorch/pytorch/issues/48782

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48980

Reviewed By: zhangguanheng66

Differential Revision: D25399823

Pulled By: ezyang

fbshipit-source-id: 798055f4abbbffecdfab0325884193c81addecec
2020-12-08 19:41:44 -08:00
Edward Yang
88ebf6f894 Revert D25304229: [pytorch][PR] Add type annotations to torch.onnx.* modules
Test Plan: revert-hammer

Differential Revision:
D25304229 (8bc6023d7a)

Original commit changeset: b01b21ddbf86

fbshipit-source-id: bc3308176e2c70423f29f694e9db94828213e7d6
2020-12-07 11:58:03 -08:00
Guilherme Leobas
8bc6023d7a Add type annotations to torch.onnx.* modules (#48782)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/45215

This is a follow up PR of https://github.com/pytorch/pytorch/issues/45258

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48782

Reviewed By: heitorschueroff

Differential Revision: D25304229

Pulled By: ezyang

fbshipit-source-id: b01b21ddbf86f908ca08173e68b81fb25851bc81
2020-12-07 08:23:02 -08:00
Mike Ruberry
6299c870ee Revert D25254920: [pytorch][PR] Add type annotations to torch.onnx.* modules
Test Plan: revert-hammer

Differential Revision:
D25254920 (40a2dd7e1e)

Original commit changeset: dc9dc036da43

fbshipit-source-id: c17cb282ebf90ecbae4023aa63ecbb443a87037d
2020-12-02 02:25:31 -08:00
Guilherme Leobas
40a2dd7e1e Add type annotations to torch.onnx.* modules (#45258)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/45215

Still need to resolve a few mypy issues before a review. In special, there is an error which I don't know how to solve, see:
```python
torch/onnx/utils.py:437: error: Name 'is_originally_training' is not defined  [name-defined]
        if training is None or training == TrainingMode.EVAL or (training == TrainingMode.PRESERVE and not is_originally_training):
```

`is_originally_training` is used but never defined/imported on [`torch/onnx/utils.py`](ab5cc97fb0/torch/onnx/utils.py (L437)),

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45258

Reviewed By: zhangguanheng66

Differential Revision: D25254920

Pulled By: ezyang

fbshipit-source-id: dc9dc036da43dd56b23bd6141e3ab92e1a16e3b8
2020-12-01 20:41:39 -08:00
Negin Raoof
a77d633db1 [ONNX] Fix view for dynamic input shape (#43558)
Summary:
Export of view op with dynamic input shape is broken when using tensors with a 0-dim.
This fix removes symbolic use of static input size to fix this issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43558

Reviewed By: ailzhang

Differential Revision: D23965090

Pulled By: bzinodev

fbshipit-source-id: 628e9d7ee5d53375f25052340ca6feabf7ba7c53
2020-09-28 14:46:51 -07:00
Xiang Gao
20ac736200 Remove py2 compatible future imports (#44735)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44735

Reviewed By: mruberry

Differential Revision: D23731306

Pulled By: ezyang

fbshipit-source-id: 0ba009a99e475ddbe22981be8ac636f8a1c8b02f
2020-09-16 12:55:57 -07:00
shubhambhokare1
f3bf6a41ca [ONNX] Update repeat op (#43430)
Summary:
Update repeat op so that the inputs to sizes argument can a mixture of dynamic and constant inputs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43430

Reviewed By: houseroad

Differential Revision: D23494257

Pulled By: bzinodev

fbshipit-source-id: 90c5e90e4f73e98f3a9d5c8772850e72cecdf0d4
2020-09-04 18:53:31 -07:00
Deepak Velmurugan
c9f125bf70 Black to Block for various files (#42913)
Summary:
Fixes  https://github.com/pytorch/pytorch/issues/41735 #41736 https://github.com/pytorch/pytorch/issues/41737 #41738 all areas where black is mentioned is replaced to block

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42913

Reviewed By: houseroad

Differential Revision: D23112873

Pulled By: malfet

fbshipit-source-id: a515b56dc2ed20aa75741c577988d95f750b364c
2020-08-25 17:43:31 -07:00
BowenBao
c4f10e0fe7 Renaming scales parameter for interpolate (#31526)
Summary:
PR separated from https://github.com/pytorch/pytorch/pull/31274.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31526

Reviewed By: zou3519

Differential Revision: D19221931

Pulled By: gchanan

fbshipit-source-id: 81958a9910867ac9d62f2b47abc49384526c4e51
2020-01-02 08:19:30 -08:00
Lara
97c1e90f46 ONNX Interpolate Add Scales Params (#28324)
Summary:
Fix for : https://github.com/pytorch/pytorch/issues/27176
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28324

Reviewed By: hl475

Differential Revision: D18309133

Pulled By: houseroad

fbshipit-source-id: 348bb41393442c6b107d88fc2cd3224e0afa3ccf
2019-12-11 20:09:15 -08:00
Vitaly Fedyunin
4bfe2f0900 Fix jit outplace tracing and reapply changes to *_like operators. (#28839)
Summary:
Reapply reverted and fix files `gen_variable_type.py` `test_jit.py`

https://github.com/pytorch/pytorch/issues/27891 Cleanup testing of _like operators
https://github.com/pytorch/pytorch/issues/27890 Add memory format support to randn_like operator
https://github.com/pytorch/pytorch/issues/27889 Add memory format support to randint_like operator
https://github.com/pytorch/pytorch/issues/27562 Add memory format support to zeros_like operator
https://github.com/pytorch/pytorch/issues/27561 Add memory format support to rand_like operator
https://github.com/pytorch/pytorch/issues/27270 Add memory format support to ones_like operator
https://github.com/pytorch/pytorch/issues/27262 Add memory format support to full_like operator
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28839

Test Plan:
Imported from GitHub, without a `Test Plan:` line.

buck test mode/dev //language_technology/neural_mt/os/pytorch_translate/test:test_onnx -- 'test_forced_decoder_export_vocab_reduction \(language_technology\.neural_mt\.os\.pytorch_translate\.test\.test_onnx\.TestONNX\)'

Differential Revision: D18203397

Pulled By: VitalyFedyunin

fbshipit-source-id: eea41cbd4c232cf5a54172b1e1b16b173798f298
2019-10-31 13:23:08 -07:00
Vitaly Fedyunin
87c98acf5d Back out "Add memory format support to full_like operator" (#28803)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28803

Original commit changeset: 1761a9939aa7
ghstack-source-id: 92748946

Test Plan: buck test language_technology/neural_mt/os/pytorch_translate/test:test_onnx

Reviewed By: ifedan

Differential Revision: D18175282

fbshipit-source-id: d3f537bbed50a4524797edd96b210b8455ef1bcc
2019-10-28 12:44:53 -07:00
Vitaly Fedyunin
d828fef8ac Back out "Add memory format support to ones_like operator" (#28802)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28802

Original commit changeset: 5da9530f6b23
ghstack-source-id: 92748794

Test Plan: buck test language_technology/neural_mt/os/pytorch_translate/test:test_onnx

Reviewed By: ifedan

Differential Revision: D18175303

fbshipit-source-id: ac36c7d345cba901bc2b64dc22661b8d0f179f13
2019-10-28 12:44:49 -07:00
Vitaly Fedyunin
7ff272c6da Back out D17980308-D17980313 (#28748)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28748

Found D17980313 to break unit tests, backed out descendants too to avoid conflicts.

Test Plan:
Failed on master:

   buck test mode/dev-nosan language_technology/neural_mt/fb/pytorch_translate/test:test_onnx

Passes with this diff.

Differential Revision: D18157588

fbshipit-source-id: e2b56eac8c5bfccf3ce9a3a2993f6332ab1471e7
2019-10-26 13:08:49 -07:00
Vitaly Fedyunin
c258cd039a Add memory format support to zeros_like operator (#27562)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27562

Adds memory_format keyword argument (positional for cpp).

'Preserve' behavior now follows next rules:
1) If tensor is non-overlapping and dense - output tensor will have the same strides as input tensor.
2) If not (1) and tensor is stored in the channels last format, output tensor going to have channels last format.
3) Output tensor is going to be contiguous in all other cases.

 ---
Dense tensor is the tensor that store values in a contiguous block of memory.
Non-overlapping tensor is the tensor in which elements occupy individual non-repetitive memory.

Test Plan: Imported from OSS

Differential Revision: D17980313

Pulled By: VitalyFedyunin

fbshipit-source-id: 9ca8453dc1a554ceea93c6949e01263cc576384b
2019-10-25 07:29:16 -07:00
Vitaly Fedyunin
2c339a24ec Add memory format support to ones_like operator (#27270)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27270

Adds memory_format keyword argument (positional for cpp).

'Preserve' behavior now follows next rules:
1) If tensor is non-overlapping and dense - output tensor will have the same strides as input tensor.
2) If not (1) and tensor is stored in the channels last format, output tensor going to have channels last format.
3) Output tensor is going to be contiguous in all other cases.

 ---
Dense tensor is the tensor that store values in a contiguous block of memory.
Non-overlapping tensor is the tensor in which elements occupy individual non-repetitive memory.

Test Plan: Imported from OSS

Differential Revision: D17980312

Pulled By: VitalyFedyunin

fbshipit-source-id: 5da9530f6b239306dbb66d1dfeefe88237f13bbd
2019-10-25 07:29:08 -07:00
Vitaly Fedyunin
85d5aee863 Add memory format support to full_like operator (#27262)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27262

Adds memory_format keyword argument (positional for cpp).

'Preserve' behavior now follows next rules:
1) If tensor is non-overlapping and dense - output tensor will have the same strides as input tensor.
2) If not (1) and tensor is stored in the channels last format, output tensor going to have channels last format.
3) Output tensor is going to be contiguous in all other cases.

 ---
Dense tensor is the tensor that store values in a contiguous block of memory.
Non-overlapping tensor is the tensor in which elements occupy individual non-repetitive memory.

Test Plan: Imported from OSS

Differential Revision: D17980309

Pulled By: VitalyFedyunin

fbshipit-source-id: 1761a9939aa7c5ab23e927b897e25e225089a8e7
2019-10-25 07:29:04 -07:00
Lara
735463f210 ONNX Export Scripted Interpolate Op (#27566)
Summary:
We currently support exporting traced interpolate ops to ONNX.

Scripting interpolate op invokes aten::__interpolate in the Torch IR (instead of aten::upsample_[mode][dim]d), which we do not support yet.
This PR implements the ONNX symbolic for __interpolate() to support exporting interpolate in scripting scenarios.

Related open issue: https://github.com/pytorch/pytorch/issues/25807
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27566

Reviewed By: hl475

Differential Revision: D17817731

Pulled By: houseroad

fbshipit-source-id: e091793df503e2497f24821cf2954ff157492c75
2019-10-16 11:22:22 -07:00
Lara Haidar
2093fac4ee ONNX Export ConstantOfShape with default dtype (#27577)
Summary:
Exporting a scripted module to ONNX, with ops like torch.zeros(), fails when the dtype is not specified.
This PR adds support to exporting scripted torch.zeros() ops (and similar ops) without specifying the dtype (dtype will default to float).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27577

Reviewed By: hl475

Differential Revision: D17822318

Pulled By: houseroad

fbshipit-source-id: b2d4300b869e782a9b72534fea1263eb83744953
2019-10-09 17:05:35 -07:00
BowenBao
638c4375de Export index_fill and index_copy, fix caffe2 scatter
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23052

Reviewed By: hl475

Differential Revision: D16428486

Pulled By: houseroad

fbshipit-source-id: 8c5905052763fd70197c67aba5f28eeff0790721
2019-09-26 16:23:32 -07:00
Lara
d396c7332a Update ONNX Export for Interpolate in Opset 11 (#26778)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26778

- Add support for linear and cubic interpolate in opset 11.
- Add support for 1d and 3d interpolate in nearest mode for opset 7 and 8.
- Add tests for all cases of interpolate in ORT tests (nearest/linear/cubic, 1d/2d/3d, upsample/downsample).
Original PR resolved: https://github.com/pytorch/pytorch/pull/24805

Reviewed By: hl475

Differential Revision: D17564911

Pulled By: houseroad

fbshipit-source-id: 591e1f5b361854ace322eca1590f8f84d29c1a5d
2019-09-25 05:43:20 -07:00