Commit Graph

424 Commits

liqunfu
bbe846f430 Add symbolic_opset19.py and symbolic_opset20.py to support opset 19/20, extend opset 18 support (#118828)
Start to fix https://github.com/pytorch/pytorch/issues/114801
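
For illustration (not part of the original commit message), a minimal export exercising the new opsets; the module is a placeholder and this assumes a PyTorch build that includes this change:

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x.relu()

# Exporting with opset_version=19 or 20 should now succeed for supported ops.
torch.onnx.export(M(), (torch.randn(2, 3),), "m_opset20.onnx", opset_version=20)
```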

Co-authored-by: Thiago Crepaldi <thiagofc@microsoft.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118828
Approved by: https://github.com/thiagocrepaldi
2024-03-22 18:01:33 +00:00
Catherine Lee
4f5785b6b3 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

# Collect mypy errors of the form "path/to/file.py:12:3: error: ... [code]".
with open("error_file.txt", "r") as f:
    errors = f.readlines()

# Build a map of file path -> {line number: error code}.
error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

# Rewrite each flagged file, appending a targeted "# type: ignore[code]"
# suppression to every offending line.
for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 21:07:01 +00:00
PyTorch MergeBot
40ece2e579 Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
2024-01-30 19:00:34 +00:00
Edward Z. Yang
4f13f69a45 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

# Collect mypy errors of the form "path/to/file.py:12:3: error: ... [code]".
with open("error_file.txt", "r") as f:
    errors = f.readlines()

# Build a map of file path -> {line number: error code}.
error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

# Rewrite each flagged file, appending a targeted "# type: ignore[code]"
# suppression to every offending line.
for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 05:08:10 +00:00
CYuxian
f543093e06 [ONNX] Fix output mismatch issue of repeat_interleave when dim is None (#116689)
'input' is introduced but gets mixed up with 'self' in repeat_interleave, which causes the output mismatch.
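
For reference, a snippet (added for illustration) of the eager-mode behavior the export must match when dim is None:

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])
# dim=None flattens the input before repeating each element.
print(torch.repeat_interleave(x, 2))  # tensor([1, 1, 2, 2, 3, 3, 4, 4])
```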

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116689
Approved by: https://github.com/thiagocrepaldi
2024-01-03 18:38:00 +00:00
Thiago Crepaldi
16f82198ca Export ReduceL1/ReduceL2 ONNX ops for aten::linalg_vector_norm(ord={1,2}) (#113173)
After #84624, aten::linalg_vector_norm started being used instead of aten::norm. In the ONNX exporter, the latter leveraged Reduce{L1,L2} when p={1,2}, which resulted in more optimized code in ONNX Runtime.

This PR extends aten::linalg_vector_norm to also use Reduce{L1,L2} when ord={1,2}, producing an equivalent ONNX subgraph.

This PR is a WIP. Pending work includes checking argument equivalence between `aten::norm` and `aten::linalg_vector_norm`, and possibly re-enabling the tests disabled by #84624.
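
A minimal sketch (illustrative, not from the commit) of an export that should now take the ReduceL1 path; the exact exported graph is an assumption:

```python
import torch

class L1Norm(torch.nn.Module):
    def forward(self, x):
        return torch.linalg.vector_norm(x, ord=1)

# The graph should now contain ReduceL1 rather than an Abs/ReduceSum chain.
torch.onnx.export(L1Norm(), (torch.randn(3, 4),), "l1norm.onnx", opset_version=13)
```
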
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113173
Approved by: https://github.com/justinchuby
2023-11-08 19:08:43 +00:00
Peter Bell
46e80ce58a [ATen] Support multi dim any and all reductions (#110310)
This adds a new overload to `all` and `any` with support for multiple reduction dims.
```
all.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor
```
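
A small usage sketch of the new overloads (illustrative, assuming a build that includes this change):

```python
import torch

x = torch.zeros(2, 3, 4, dtype=torch.bool)
x[0, 1, 2] = True

# A tuple of dims reduces over all of them at once.
print(torch.any(x, dim=(0, 2)))                      # tensor([False,  True, False])
print(torch.all(x, dim=(0, 2)))                      # tensor([False, False, False])
print(torch.any(x, dim=(0, 2), keepdim=True).shape)  # torch.Size([1, 3, 1])
```
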
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110310
Approved by: https://github.com/lezcano, https://github.com/albanD, https://github.com/justinchuby
2023-10-24 21:33:53 +00:00
Thiago Crepaldi
9ab6ac5bc1 [ONNX] Fix aten::new_zeros due to TorchScript behavior change on PyTorch 2.1 Fix #110935 (#110956)
Fixes #110597

Summary:

* Generic code: The `torch._C.Value.node().mustBeNone()` is encapsulated into the high-level API `JitScalarType.from_value`; `_is_none` was also extended to allow either `None` or `torch._C.Value.node.mustBeNone()`, so users don't have to manually call into the TorchScript API when implementing operators
* Specific to `new_zeros` (and the `*_like` and `new_*` families of ops): when checking `dtype`, we must always use `_is_none`, which applies the check proposed by #110935
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110956
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
2023-10-16 18:28:20 +00:00
Kazuaki Ishizaki
f7ce19d40a Fix typo under torch/onnx directory (#110697)
This PR fixes typos in comments in files under the `torch/onnx` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110697
Approved by: https://github.com/ezyang
2023-10-06 18:21:00 +00:00
PyTorch MergeBot
a5364b12bb Revert "[ONNX] Remove the deprecated function _export (#109763)"
This reverts commit d7c05bb2e8.

Reverted https://github.com/pytorch/pytorch/pull/109763 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/109763#issuecomment-1734201053))
2023-09-25 17:47:21 +00:00
wangxiyuan
d7c05bb2e8 [ONNX] Remove the deprecated function _export (#109763)
The `_export` API was deprecated and should be removed after 2.0.

See: https://github.com/pytorch/pytorch/pull/107208

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109763
Approved by: https://github.com/thiagocrepaldi
2023-09-22 07:14:13 +00:00
PyTorch MergeBot
cd31c170c9 Revert "[ONNX] Remove deprecated functions (#107208)"
This reverts commit 263ca7d69b.

Reverted https://github.com/pytorch/pytorch/pull/107208 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/107208#issuecomment-1726183104))
2023-09-19 17:26:48 +00:00
CYuxian
504dceacb1 [ONNX] Fix indexing issue of meshgrid op (#109350)
Should unpack tensor_list before swapping the elements for indexing 'xy'.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109350
Approved by: https://github.com/thiagocrepaldi
2023-09-15 19:49:43 +00:00
wangxiyuan
263ca7d69b [ONNX] Remove deprecated functions (#107208)
The usage of some functions is deprecated. This PR drops them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107208
Approved by: https://github.com/justinchuby, https://github.com/thiagocrepaldi
2023-09-14 19:09:56 +00:00
Thiago Crepaldi
4be6b6b673 Add quantization support to reshape and size for the ONNX exporter (#106629)
Fixes https://github.com/microsoft/onnx-converters-private/issues/175

Add quantization support for Reshape-14, Size-9 and Size-11
For Size operators, we don't requantize outputs because we want the original scalar in the graph
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106629
Approved by: https://github.com/BowenBao
2023-08-05 02:08:52 +00:00
BowenBao
bf40561ab4 [ONNX] Support 'aten::randint' in torchscript onnx exporter (#105089)
Export as 'ONNX::RandomUniform', which produces a floating-point result,
then round it to an integer with 'ONNX::Cast'.

Fixes https://github.com/microsoft/onnx-converters-private/issues/173
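
A minimal repro of the newly supported export (illustrative; the RandomUniform-followed-by-Cast structure is the lowering described above):

```python
import torch

class AddNoise(torch.nn.Module):
    def forward(self, x):
        return x + torch.randint(0, 10, (3,))

# The exported graph should contain RandomUniform followed by a Cast to int64.
torch.onnx.export(AddNoise(), (torch.zeros(3, dtype=torch.long),),
                  "randint.onnx", opset_version=13)
```
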
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105089
Approved by: https://github.com/thiagocrepaldi
2023-07-13 01:50:03 +00:00
Ilya Sherstyuk
8c0b9a2d69 [ONNX] Export dynamic step size for aten::slice() (#104385)
This commit improves the export of aten::slice() to ONNX in the following ways:

1. The step size can be an input tensor rather than a constant.
2. Fixes a bug where using a 1-D, 1-element torch tensor as an index created a broken ONNX model.

This commit also adds tests for the new functionality.

Fixes #104314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104385
Approved by: https://github.com/thiagocrepaldi
2023-07-06 21:38:59 +00:00
CYuxian
42b0bdd0c5 [onnx] Convert aten::flatten with 0d input to onnx Reshape and 1d to Identity (#104089)
Avoid the empty tensor generated by the Slice op when using _flatten_helper for aten::flatten with 0d/1d input.
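
For illustration (not from the commit), the 0d case that previously produced an empty tensor via Slice:

```python
import torch

class Flat(torch.nn.Module):
    def forward(self, x):
        return torch.flatten(x)

# A 0-d input is now exported as a Reshape to a 1-element 1-D tensor
# instead of going through _flatten_helper's Slice.
torch.onnx.export(Flat(), (torch.tensor(1.0),), "flatten_0d.onnx", opset_version=13)
```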

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104089
Approved by: https://github.com/thiagocrepaldi
2023-06-28 17:01:43 +00:00
Tung D. Le
b77f1b0f27 Wrong type when exporting {zeros, ones, full, empty, rand, randn}_like ops to onnx (#103048)
Fixes #99788

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103048
Approved by: https://github.com/thiagocrepaldi
2023-06-13 17:17:28 +00:00
AllenTiTaiWang
1ca2e993af [ONNX] Support aten::logit (#102377)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102377
Approved by: https://github.com/BowenBao
2023-06-02 03:39:35 +00:00
AllenTiTaiWang
0df691df4e [ONNX] Support aten::broadcast_to (#101833)
Support aten::broadcast_to the same way we support aten::expand.

Fixes #92678 and #101768
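
A minimal sketch of the newly supported export (illustrative; lowering to an ONNX Expand node, as for aten::expand, is an assumption):

```python
import torch

class Bcast(torch.nn.Module):
    def forward(self, x):
        return x.broadcast_to((3, 4))

torch.onnx.export(Bcast(), (torch.randn(4),), "broadcast_to.onnx", opset_version=13)
```
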
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101833
Approved by: https://github.com/thiagocrepaldi
2023-05-19 16:59:54 +00:00
Peter Pham
3362c1d240 [ONNX] add cast operator after reduce to match desired dtype (#100700)
This PR conditionally inserts a cast operator after a reduction operation to match the specified dtype in the exported ONNX model. The code changes affect **opset9** and **opset13**.

I understand there's an [automatic upcast to int64](c91a41fd68/torch/onnx/symbolic_opset9.py (L783)) before reduction, most likely to prevent overflow, so I left that alone and only conditionally add casting back to the desired dtype.

## Test int32
```
import torch
import onnx
a = torch.tensor([10, 20, 30, 80], dtype=torch.int32)
def test():
    class SumInt32(torch.nn.Module):
        def forward(self, a):
            return torch.sum(a, dtype=torch.int32)

    sumi = SumInt32().eval()
    assert sumi(a).dtype == torch.int32
    print("Torch model output type matches input type")

    torch.onnx.export(sumi, (a), "/tmp/sumi_int32.onnx", opset_version=12)
    model = onnx.load("/tmp/sumi_int32.onnx")

    assert model.graph.output[0].type.tensor_type.elem_type == onnx.TensorProto.INT32
    print("ONNX model output type matches input type")
test()
```
![sumi_int32 onnx](https://user-images.githubusercontent.com/10516699/236499220-59b64821-5807-4f69-b0e2-90ae34280e03.png)

## Test int64

```
import onnx
import torch

a = torch.tensor([10, 20, 30, 80], dtype=torch.int64)

def test():
    class SumInt64(torch.nn.Module):
        def forward(self, a):
            return torch.sum(a, dtype=torch.int64)

    sumi = SumInt64().eval()
    assert sumi(a).dtype == torch.int64
    print("Torch model output type matches input type")
    torch.onnx.export(sumi, (a), "/tmp/sumi_int64.onnx", opset_version=12)
    model = onnx.load("/tmp/sumi_int64.onnx")
    assert model.graph.output[0].type.tensor_type.elem_type == onnx.TensorProto.INT64
    print("ONNX model output type matches input type")

test()

```
![sum_int64 onnx](https://user-images.githubusercontent.com/10516699/236422133-15f9cda3-242f-46da-9b23-c2e920f27078.png)

Fixes https://github.com/pytorch/pytorch/issues/100097

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100700
Approved by: https://github.com/thiagocrepaldi
2023-05-06 00:05:57 +00:00
Ilya Sherstyuk
40df6e1647 [ONNX] Simplify repeat_interleave export for scalar-valued 'repeat' (#100575)
This PR simplifies the ONNX export of torch.repeat_interleave when 'repeat' is a scalar value (so each element in the input is repeated the same number of times). (Issue #100438)

Here is a before/after of a simple model export:
```python
# Model + export code
import torch

class RepeatInterleaveModel(torch.nn.Module):
    def forward(self, x):
        return x.repeat_interleave(2, dim=-1)

args = (torch.rand((2, 2, 16)),)
model = RepeatInterleaveModel()
torch.onnx.export(model, args, "repeat_interleave.onnx", opset_version=17)
```

**Before (static shapes)**
![repeat_interleave onnx(1)](https://user-images.githubusercontent.com/46343317/236014996-00726832-1e76-4fb4-950d-4b54cc5cc20c.png)

-----
**Before (dynamic shapes, second graph is Loop body)**
<p float="left">
  <img src="https://user-images.githubusercontent.com/46343317/236029895-20b0ae0a-240f-466d-bb01-e619ec5967ad.png" width="45%" />
  <img src="https://user-images.githubusercontent.com/46343317/236029915-e67b808a-029b-4997-bc05-1ce59eec409a.png" width="47%" />
</p>

-----
**After (for both static and dynamic shapes)**
<img src="https://user-images.githubusercontent.com/46343317/236015235-633811cb-09a2-435d-a293-1b2bcb7dea50.png" width="66%" />

-----

This PR also fixes a bug where the exporter throws an exception when the input has dynamic shapes and the 'dim' parameter is not specified to torch.repeat_interleave. It also adds a new test case to cover this. (Issue #100429)

Fixes #100438 and #100429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100575
Approved by: https://github.com/BowenBao
2023-05-05 17:00:42 +00:00
shubhambhokare1
0595ecf00c [ONNX] Add symbolic for _convolution_mode (#89107)
As per #68880, implement the operator _convolution_mode in the ONNX exporter. This allows users to leverage the string padding mode, where padding can be set to 'valid' or 'same'.
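
A minimal example (added for illustration) of the case this enables; Conv2d with a string padding mode routes through aten::_convolution_mode:

```python
import torch

class ConvSame(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding="same")

    def forward(self, x):
        return self.conv(x)

torch.onnx.export(ConvSame(), (torch.randn(1, 3, 32, 32),),
                  "conv_same.onnx", opset_version=13)
```
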
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89107
Approved by: https://github.com/titaiwangms, https://github.com/BowenBao
2023-05-03 20:42:30 +00:00
AllenTiTaiWang
08c49eee5e [ONNX] Support aten::atan2 in torchscript exporter (#100040)
Fixes #51334
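
A minimal repro of the newly supported export (illustrative; ONNX has no Atan2 op, so the decomposition the exporter emits is an internal detail):

```python
import torch

class Atan2(torch.nn.Module):
    def forward(self, y, x):
        return torch.atan2(y, x)

torch.onnx.export(Atan2(), (torch.randn(4), torch.randn(4)),
                  "atan2.onnx", opset_version=13)
```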

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100040
Approved by: https://github.com/BowenBao
2023-04-26 04:00:47 +00:00
Aaron Gokaslan
e2a3817dfd [BE] Enable C419 rule for any all shortcircuiting (#99890)
Apparently https://github.com/pytorch/pytorch/pull/78142 made torch.jit allow simple generator expressions, which lets us enable rules that replace unnecessary list comprehensions with generators in any/all. This was originally part of #99280, but I split it off into this PR so that it can be easily reverted should anything break.
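
For illustration, the kind of rewrite C419 enforces:

```python
values = [0, 1, 2, 3]

# Before: builds an intermediate list, no short-circuiting.
any([v > 2 for v in values])

# After: the generator expression lets any() stop at the first True.
any(v > 2 for v in values)
```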

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99890
Approved by: https://github.com/justinchuby, https://github.com/kit1980, https://github.com/malfet
2023-04-25 15:02:13 +00:00
Bas Aarts
b3b0fbca11 [ONNX] Export Relu6 without using Relu (#99022)
The clamp operator used in the export of Relu6 already clamps the input value between 0 and 6. There's no need to first perform a Relu on the input tensor.
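
A quick check (added for illustration) of the identity the simplification relies on:

```python
import torch

x = torch.tensor([-3.0, 2.0, 9.0])
# relu6(x) is exactly clamp(x, 0, 6), so a leading Relu is redundant.
assert torch.equal(torch.nn.functional.relu6(x), torch.clamp(x, 0.0, 6.0))
```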

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99022
Approved by: https://github.com/BowenBao
2023-04-19 06:18:14 +00:00
BowenBao
60a68477a6 Bump black version to 23.1.0 (#96578)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96578
Approved by: https://github.com/ezyang
2023-03-15 06:27:59 +00:00
BowenBao
b0a580a21d [ONNX] Export logical_not (#96315)
Fixes https://github.com/pytorch/pytorch/issues/95154

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96315
Approved by: https://github.com/justinchuby
2023-03-10 02:25:08 +00:00
Ilya Sherstyuk
6154be1dd1 [ONNX] Fix circular padding to support dynamic axes (#95647)
This commit fixes a bug where the ONNX exporter for circular padding queried the input tensor shape in order to get the correct 'end' index for a slice node. This doesn't work when the axis in question has a dynamic size. The commit fixes this by setting the 'end' index to INT_MAX, which is the recommended way of slicing to the end of a dimension with unknown size per the ONNX spec.

See https://onnx.ai/onnx/operators/onnx__Slice.html

Also adds a regression test.
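
A minimal export (illustrative) of the previously failing case, with the padded axis marked dynamic:

```python
import torch

class CircPad(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.pad(x, (2, 2), mode="circular")

# The slice 'end' on the dynamic width axis can no longer come from the
# static shape; per the ONNX Slice spec it is now INT64_MAX.
torch.onnx.export(CircPad(), (torch.randn(1, 3, 16),), "circular_pad.onnx",
                  input_names=["x"], dynamic_axes={"x": {2: "width"}},
                  opset_version=13)
```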

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95647
Approved by: https://github.com/BowenBao
2023-03-10 00:29:33 +00:00
guyang3532
79d49c60c1 [ONNX] Fix expand_as (#95962)
Fixes [#95961](https://github.com/pytorch/pytorch/issues/95961)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95962
Approved by: https://github.com/BowenBao, https://github.com/justinchuby
2023-03-07 22:11:50 +00:00
BowenBao
2fbbc3362b [ONNX] Support 'dtype' argument for 'aten::norm' (#95637)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95637
Approved by: https://github.com/titaiwangms
2023-03-01 00:07:34 +00:00
Justin Chu
afadc3697a [ONNX] Fix assert in cat (#94870)
The assert statement blocks tensors with unknown ranks. This change unblocks those cases. Needed for https://github.com/pytorch/vision/pull/7056

Verified against https://github.com/pytorch/vision/pull/7056
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94870
Approved by: https://github.com/BowenBao
2023-02-15 04:09:59 +00:00
Justin Chu
5ed7c701a3 [ONNX] Remove the deprecated monkey patches to torch.Graph (#94747)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94747
Approved by: https://github.com/BowenBao, https://github.com/Skylion007
2023-02-14 00:08:09 +00:00
BowenBao
88d0235b73 [ONNX] Update CI test environment; Add symbolic functions (#94564)
* CI Test environment to install onnx and onnx-script.
* Add symbolic function for `bitwise_or`, `convert_element_type` and `masked_fill_`.
* Update symbolic function for `slice` and `arange`.
* Update .pyi signature for `_jit_pass_onnx_graph_shape_type_inference`.

Co-authored-by: Wei-Sheng Chin <wschin@outlook.com>
Co-authored-by: Ti-Tai Wang <titaiwang@microsoft.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94564
Approved by: https://github.com/abock
2023-02-10 20:44:59 +00:00
Thiago Crepaldi
4e1bd4abe7 Fix scalar type resolution for optional tensor (#94427)
When a TorchScript Value holds an optional tensor, `dtype()` and `scalarType()` are not available and raise (by design).

The symbolic `_op_with_optional_float_cast` must check whether the tensor is optional before calling the scalar type resolution API. This PR fixes that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94427
Approved by: https://github.com/abock, https://github.com/shubhambhokare1
2023-02-09 15:22:02 +00:00
shubhambhokare1
fcde6dbbac [onnx] Add mse_loss symbolic (#90717)
Adds support for mse_loss operator
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90717
Approved by: https://github.com/BowenBao, https://github.com/titaiwangms, https://github.com/abock
2023-01-18 00:04:59 +00:00
Peter Bell
6912f7c564 Update references to 1.14 to 2.0 (#91769)
There won't be a 1.14 release, so these should be updated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91769
Approved by: https://github.com/Skylion007, https://github.com/svekars, https://github.com/lezcano
2023-01-10 23:42:07 +00:00
AllenTiTaiWang
e3ed55d483 [ONNX] Add aten::zero support (#91731)
Fixes #90268

When we use `tensor.zero_()` on an in-place slice, it actually dispatches to `aten::zero` instead.
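
A minimal repro of the pattern (illustrative; the lowering to `aten::zero` is per the description above):

```python
import torch

class ZeroSlice(torch.nn.Module):
    def forward(self, x):
        x[:, 1:].zero_()  # zeroing a slice dispatches to aten::zero
        return x

torch.onnx.export(ZeroSlice(), (torch.ones(2, 3),), "zero.onnx", opset_version=13)
```
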
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91731
Approved by: https://github.com/BowenBao
2023-01-07 11:07:54 +00:00
PyTorch MergeBot
08a378a286 Revert "[ONNX] Add aten::zero support (#91731)"
This reverts commit ff23508c0d.

Reverted https://github.com/pytorch/pytorch/pull/91731 on behalf of https://github.com/clee2000 due to failing test_correct_module_names ff23508c0d https://github.com/pytorch/pytorch/actions/runs/3859079162/jobs/6578419644
2023-01-06 23:57:57 +00:00
AllenTiTaiWang
ff23508c0d [ONNX] Add aten::zero support (#91731)
Fixes #90268

When we use `tensor.zero_()` on an in-place slice, it actually dispatches to `aten::zero` instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91731
Approved by: https://github.com/BowenBao
2023-01-06 22:48:54 +00:00
BowenBao
66745831d7 [ONNX] Support constant 'aten::__contains__' (#91660)
#84624 introduces an update on `torch.norm` [dispatch logic](eaa43d9f25/torch/functional.py (L1489)), which now depends on `layout`, resulting in regressions when exporting related operators from TorchScript.

This PR resolves the regression by partially supporting a subset of use cases for the `prim::layout` (only `torch.strided`) and `aten::__contains__` (only constants) operators. Properly supporting other layouts, e.g. `torch.sparse_coo`, would require much more effort: extending JIT types and supporting the related family of ops like `aten::to_sparse`. That is out of the scope of this PR.

Fixes #83661
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91660
Approved by: https://github.com/justinchuby, https://github.com/kit1980
2023-01-06 01:39:32 +00:00
Justin Chu
634555d981 [ONNX] Auto test based on OpInfo (#86182)
This change introduces a mechanism to test onnx export based on sample inputs registered in OpInfo, similar to how MPS and other components of PyTorch are tested. It provides test coverage on ops and dtypes previously unattainable with manually created test models. This is the best way for us to discover gaps in exporter support, especially for ops with partial existing support.

This test is adapted from https://github.com/pytorch/pytorch/blob/master/test/test_mps.py

This PR also

- Update sqrt to support integer inputs to match PyTorch behavior
- Add pytest-subtests for unittest subtests support in the new test file

I only enabled a few ops (`t`, `ceil`, and `sqrt`) because otherwise too many things would fail due to (1) unsupported dtypes in the exporter, (2) unimplemented dtype support in onnxruntime, and (3) unexpected input to verification.verify.

Subsequent PRs should first improve `verification.verify` so that it accepts any legal input to a PyTorch model, then incrementally fix the symbolic functions to enable more test cases.

Fixes #85363
Design #88118
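
A simplified sketch (not the actual test code) of what OpInfo-driven testing looks like; `op_db` is PyTorch's registry of operators with sample inputs:

```python
import torch
from torch.testing._internal.common_methods_invocations import op_db

for op in op_db:
    if op.name not in {"t", "ceil", "sqrt"}:  # the ops enabled so far
        continue
    for sample in op.sample_inputs("cpu", torch.float32):
        # The real test wraps the op in a module, exports it, and compares
        # ONNX Runtime output against eager mode via verification.verify.
        ...
```
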
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86182
Approved by: https://github.com/BowenBao
2022-12-16 14:43:41 +00:00
titaiwang
06c98e673f [ONNX] Fix ignored small eps in layer normalization in fp16 (#89869)
Prior to this change, the symbolic_fn `layer_norm` (before ONNX version 17) always loses precision when eps is smaller than what the Float type can represent, while PyTorch always takes eps as Double. This PR adds `onnx::Cast` into eps-related operations to prevent losing precision during the calculation.
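
A small demonstration (added for illustration) of the underlying precision loss:

```python
import torch

eps = torch.tensor(1e-12)
# 1e-12 underflows to zero in fp16, so without an onnx::Cast the exported
# layer_norm silently dropped the user's eps.
print(eps.half())    # tensor(0., dtype=torch.float16)
print(eps.double())  # tensor(1.0000e-12, dtype=torch.float64)
```
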
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89869
Approved by: https://github.com/BowenBao
2022-12-08 06:13:09 +00:00
Thiago Crepaldi
6d794f6a4a [ONNX] Fix concat with empty tensors (#87620)
Fixes #54410

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87620
Approved by: https://github.com/BowenBao
2022-12-05 17:36:31 +00:00
PyTorch MergeBot
cba96366a2 Revert "remove torch.equal usages (#89527)"
This reverts commit 4095ef8b80.

Reverted https://github.com/pytorch/pytorch/pull/89527 on behalf of https://github.com/clee2000 due to broke periodic multigpu tests 4095ef8b80 https://github.com/pytorch/pytorch/actions/runs/3592806602/jobs/6049368502
2022-12-02 21:36:13 +00:00
Philip Meier
4095ef8b80 remove torch.equal usages (#89527)
Preparation for the next PR in this stack: #89559.

I replaced

- `self.assertTrue(torch.equal(...))` with `self.assertEqual(..., rtol=0, atol=0, exact_device=True)`,
- the same for `self.assertFalse(...)` with `self.assertNotEqual(...)`, and
- `assert torch.equal(...)` with `torch.testing.assert_close(..., rtol=0, atol=0)` (note that we don't need to set `check_device=True` here since that is the default).

There were a few instances where the result of `torch.equal` is used directly. In those cases I've replaced it with `(... == ...).all().item()`, sometimes also dropping the `.item()` depending on the context.
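
The replacement patterns, sketched (illustrative):

```python
import torch

a = torch.tensor([1.0, 2.0])
b = a.clone()

# Instead of `assert torch.equal(a, b)`:
torch.testing.assert_close(a, b, rtol=0, atol=0)

# Where the boolean result is used directly:
same = (a == b).all().item()
```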

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89527
Approved by: https://github.com/mruberry
2022-12-01 11:22:52 +00:00
mindest
9fe36a0214 [ONNX] Extra support for bernoulli export (#88655)
* add opset 15 support for `bernoulli`.
* add extra export options for different `bernoulli` cases: `x.bernoulli(p)` where `p` is a tensor or float.

Fixes #88299

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88655
Approved by: https://github.com/BowenBao
2022-11-16 15:08:41 +00:00
Justin Chu
23a6e15321 [ONNX] Remove the INT64_MAX magic numbers (#88341)
Remove the magic numbers in symbolic opsets and use an INT64_MAX global instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88341
Approved by: https://github.com/BowenBao
2022-11-03 20:18:36 +00:00
Thiago Crepaldi
a8f40b39ce Update all ONNX symbolics with new JitScalarType API (#87245)
Fixes https://github.com/pytorch/pytorch/issues/84365 and more

This PR addresses not only the issue above, but the entire family of issues related to `torch._C.Value.type()` parsing when `scalarType()` or `dtype()` is not available.

This issue existed before `JitScalarType` was introduced, but the new implementation carried the bug over because the new APIs `from_name` and `from_dtype` require parsing `torch._C.Value.type()` to get proper inputs, which is exactly the root cause of this family of bugs.

Therefore `from_name` and `from_dtype` must be called only when the implementor knows the `name` and `dtype` without parsing a `torch._C.Value`. To handle the corner cases hidden within `torch._C.Value`, a new `from_value` API was introduced; it should be used in favor of the former ones in most cases. The new API is safer and doesn't require type parsing from the user, which could trigger JIT asserts in the core of PyTorch.

Although CI is passing for all tests, please carefully review all of the symbolics/helpers refactoring to make sure the meaning/intention of the old calls is not changed in the new calls.
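
A hedged sketch of the intended usage inside a symbolic function; the module path (`torch.onnx._type_utils`) and the default-argument form are assumptions based on the description above:

```python
from torch.onnx._type_utils import JitScalarType

def symbolic_fn(g, self):
    # from_value inspects the torch._C.Value safely, instead of the caller
    # parsing value.type() and hitting JIT asserts on optional/None types.
    scalar_type = JitScalarType.from_value(self, JitScalarType.FLOAT)
    ...
```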

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87245
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
2022-11-03 03:01:33 +00:00