pytorch/test/onnx/error_reproduction.py
titaiwangms 34f7dc9eba [ONNX] Support op consistency error reproduction (#119512)
Fixes #119472

Introduce the debugging tool from onnxscript: https://github.com/microsoft/onnxscript/blob/main/onnxscript/tests/function_libs/torch_lib/error_reproduction.py

This tool helps us quickly find the inputs that lead to output-mismatch errors.

NOTE: this produces an `error_reports` folder containing a separate markdown report for each mismatched test case.

For example: `CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_fft_fft_cpu_bool`

### Summary

The output of ONNX Runtime does not match that of PyTorch when executing test
`test_fx_op_consistency.TestOnnxModelOutputConsistency_opset_version_18_model_type_TorchModelType.TORCH_NN_MODULECPU.test_output_match_fft_fft_cpu_bool`, `sample 3` in ONNX Script `TorchLib`.

To recreate this report, use

```bash
CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_fft_fft_cpu_bool
```

### ONNX Model

```
<
   ir_version: 8,
   opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1],
   producer_name: "pytorch",
   producer_version: "2.2.0"
>
main_graph (bool[31] l_args_0_) => (float[31,2] _fft_r2c)
   <bool[31] l_args_0_, float[31] _to_copy, float[31,2] _fft_r2c>
{
   _to_copy = Cast <to: int = 1> (l_args_0_)
   _val_2 = Constant <value: tensor = int64[1] {-1}> ()
   _val_3 = Unsqueeze (_to_copy, _val_2)
   _val_4 = Constant <value: tensor = int64[1] {0}> ()
   _val_5 = Unsqueeze (_val_3, _val_4)
   _val_6 = DFT <axis: int = 1, inverse: int = 0, onesided: int = 0> (_val_5)
   _val_7 = Constant <value: tensor = int64[1] {0}> ()
   _val_8 = Squeeze (_val_6, _val_7)
   _fft_r2c = pkg.onnxscript.torch_lib._fftn_onnx_normalization <dims: ints = [0], forward: int = 1, normalization: int = 0> (_val_3, _val_8)
}
<
  domain: "pkg.onnxscript.torch_lib",
  opset_import: ["" : 18]
>
_fftn_onnx_normalization <normalization,forward,dims>(self, transformed) => (result_15)
{
   self_shape = Shape (self)
   dims = Constant <value_ints: ints = @dims> ()
   self_shape_subscripted = Gather <axis: int = 0> (self_shape, dims)
   total_sample_count = ReduceProd <keepdims: int = 0> (self_shape_subscripted)
   total_sample_count_0 = CastLike (total_sample_count, transformed)
   normalization = Constant <value_int: int = @normalization> ()
   int64_1 = Constant <value: tensor = int64 int64_1 {1}> ()
   cond = Equal (normalization, int64_1)
   result_15 = If (cond) <then_branch: graph = thenGraph_21 () => ( result_3) {
      forward = Constant <value_int: int = @forward> ()
      forward_as_bool = Cast <to: int = 9> (forward)
      result_3 = If (forward_as_bool) <then_branch: graph = thenGraph_23 () => ( result) {
         tmp = Sqrt (total_sample_count_0)
         result = Div (transformed, tmp)
      }, else_branch: graph = elseGraph_23 () => ( result_2) {
         tmp_1 = Sqrt (total_sample_count_0)
         result_2 = Mul (transformed, tmp_1)
      }>
   }, else_branch: graph = elseGraph_21 () => ( result_14) {
      normalization_4 = Constant <value_int: int = @normalization> ()
      int64_2 = Constant <value: tensor = int64 int64_2 {2}> ()
      cond_5 = Equal (normalization_4, int64_2)
      result_14 = If (cond_5) <then_branch: graph = thenGraph_27 () => ( result_9) {
         forward_6 = Constant <value_int: int = @forward> ()
         forward_6_as_bool = Cast <to: int = 9> (forward_6)
         result_9 = If (forward_6_as_bool) <then_branch: graph = thenGraph_29 () => ( result_7) {
            result_7 = Div (transformed, total_sample_count_0)
         }, else_branch: graph = elseGraph_29 () => ( result_8) {
            result_8 = Identity (transformed)
         }>
      }, else_branch: graph = elseGraph_27 () => ( result_13) {
         forward_10 = Constant <value_int: int = @forward> ()
         forward_10_as_bool = Cast <to: int = 9> (forward_10)
         result_13 = If (forward_10_as_bool) <then_branch: graph = thenGraph_35 () => ( result_11) {
            result_11 = Identity (transformed)
         }, else_branch: graph = elseGraph_35 () => ( result_12) {
            result_12 = Mul (transformed, total_sample_count_0)
         }>
      }>
   }>
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
Rank (input) => (return_val)
{
   tmp = Shape (input)
   return_val = Size (tmp)
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
IsScalar (input) => (return_val)
{
   tmp = Shape (input)
   tmp_0 = Size (tmp)
   tmp_1 = Constant <value_int: int = 0> ()
   return_val = Equal (tmp_0, tmp_1)
}
```

### Inputs

Shapes: `['Tensor<torch.Size([31]), dtype=torch.bool>']`

<details><summary>Details</summary>
<p>

```python
kwargs = {}
inputs = (tensor([False, False,  True,  True, False,  True, False,  True, False, False,
         True, False, False, False, False, False,  True,  True,  True,  True,
         True,  True,  True,  True, False, False, False, False,  True,  True,
         True]),)
```

</p>
</details>

### Expected output

Shape: `torch.Size([31, 2])`

<details><summary>Details</summary>
<p>

```python
expected = tensor([[16.0000,  0.0000],
        [-0.2369,  2.6590],
        [ 0.7336, -4.9670],
        [ 2.2093,  2.9865],
        [-0.7166,  1.0928],
        [-3.0614,  3.0015],
        [-1.8945, -0.9677],
        [-2.1538,  0.2513],
        [-2.2432,  1.3978],
        [-0.3429,  1.9494],
        [-0.6495, -1.5423],
        [-0.6005,  2.2398],
        [ 2.2639,  2.6430],
        [ 1.7609,  0.2033],
        [-1.3829, -2.3365],
        [-1.6854, -0.0311],
        [-1.6854,  0.0311],
        [-1.3829,  2.3365],
        [ 1.7609, -0.2033],
        [ 2.2639, -2.6430],
        [-0.6005, -2.2398],
        [-0.6495,  1.5423],
        [-0.3429, -1.9494],
        [-2.2432, -1.3978],
        [-2.1538, -0.2513],
        [-1.8945,  0.9677],
        [-3.0614, -3.0015],
        [-0.7166, -1.0928],
        [ 2.2093, -2.9865],
        [ 0.7336,  4.9670],
        [-0.2369, -2.6590]])
```

</p>
</details>

### Actual output

Shape: `torch.Size([31, 2])`

<details><summary>Details</summary>
<p>

```python
actual = tensor([[ 1.6000e+01, -9.1791e-06],
        [-2.3695e-01,  2.6590e+00],
        [ 7.3355e-01, -4.9670e+00],
        [ 2.2093e+00,  2.9865e+00],
        [-7.1663e-01,  1.0928e+00],
        [-3.0614e+00,  3.0015e+00],
        [-1.8946e+00, -9.6773e-01],
        [-2.1538e+00,  2.5126e-01],
        [-2.2432e+00,  1.3978e+00],
        [-3.4294e-01,  1.9494e+00],
        [-6.4946e-01, -1.5423e+00],
        [-6.0044e-01,  2.2398e+00],
        [ 2.2639e+00,  2.6430e+00],
        [ 1.7609e+00,  2.0326e-01],
        [-1.3829e+00, -2.3365e+00],
        [-1.6854e+00, -3.1130e-02],
        [-1.6854e+00,  3.1161e-02],
        [-1.3829e+00,  2.3365e+00],
        [ 1.7609e+00, -2.0327e-01],
        [ 2.2639e+00, -2.6430e+00],
        [-6.0047e-01, -2.2398e+00],
        [-6.4945e-01,  1.5423e+00],
        [-3.4294e-01, -1.9494e+00],
        [-2.2432e+00, -1.3978e+00],
        [-2.1538e+00, -2.5129e-01],
        [-1.8945e+00,  9.6773e-01],
        [-3.0615e+00, -3.0015e+00],
        [-7.1663e-01, -1.0928e+00],
        [ 2.2093e+00, -2.9865e+00],
        [ 7.3354e-01,  4.9670e+00],
        [-2.3695e-01, -2.6589e+00]])
```

</p>
</details>

### Difference

<details><summary>Details</summary>
<p>

```diff
--- actual
+++ expected
@@ -1,31 +1,31 @@
-tensor([[ 1.6000e+01, -9.1791e-06],
-        [-2.3695e-01,  2.6590e+00],
-        [ 7.3355e-01, -4.9670e+00],
-        [ 2.2093e+00,  2.9865e+00],
-        [-7.1663e-01,  1.0928e+00],
-        [-3.0614e+00,  3.0015e+00],
-        [-1.8946e+00, -9.6773e-01],
-        [-2.1538e+00,  2.5126e-01],
-        [-2.2432e+00,  1.3978e+00],
-        [-3.4294e-01,  1.9494e+00],
-        [-6.4946e-01, -1.5423e+00],
-        [-6.0044e-01,  2.2398e+00],
-        [ 2.2639e+00,  2.6430e+00],
-        [ 1.7609e+00,  2.0326e-01],
-        [-1.3829e+00, -2.3365e+00],
-        [-1.6854e+00, -3.1130e-02],
-        [-1.6854e+00,  3.1161e-02],
-        [-1.3829e+00,  2.3365e+00],
-        [ 1.7609e+00, -2.0327e-01],
-        [ 2.2639e+00, -2.6430e+00],
-        [-6.0047e-01, -2.2398e+00],
-        [-6.4945e-01,  1.5423e+00],
-        [-3.4294e-01, -1.9494e+00],
-        [-2.2432e+00, -1.3978e+00],
-        [-2.1538e+00, -2.5129e-01],
-        [-1.8945e+00,  9.6773e-01],
-        [-3.0615e+00, -3.0015e+00],
-        [-7.1663e-01, -1.0928e+00],
-        [ 2.2093e+00, -2.9865e+00],
-        [ 7.3354e-01,  4.9670e+00],
-        [-2.3695e-01, -2.6589e+00]])
+tensor([[16.0000,  0.0000],
+        [-0.2369,  2.6590],
+        [ 0.7336, -4.9670],
+        [ 2.2093,  2.9865],
+        [-0.7166,  1.0928],
+        [-3.0614,  3.0015],
+        [-1.8945, -0.9677],
+        [-2.1538,  0.2513],
+        [-2.2432,  1.3978],
+        [-0.3429,  1.9494],
+        [-0.6495, -1.5423],
+        [-0.6005,  2.2398],
+        [ 2.2639,  2.6430],
+        [ 1.7609,  0.2033],
+        [-1.3829, -2.3365],
+        [-1.6854, -0.0311],
+        [-1.6854,  0.0311],
+        [-1.3829,  2.3365],
+        [ 1.7609, -0.2033],
+        [ 2.2639, -2.6430],
+        [-0.6005, -2.2398],
+        [-0.6495,  1.5423],
+        [-0.3429, -1.9494],
+        [-2.2432, -1.3978],
+        [-2.1538, -0.2513],
+        [-1.8945,  0.9677],
+        [-3.0614, -3.0015],
+        [-0.7166, -1.0928],
+        [ 2.2093, -2.9865],
+        [ 0.7336,  4.9670],
+        [-0.2369, -2.6590]])
```

</p>
</details>

### Full error stack

```
Tensor-likes are not close!

Mismatched elements: 21 / 62 (33.9%)
Greatest absolute difference: 3.719329833984375e-05 at index (26, 1) (up to 1e-05 allowed)
Greatest relative difference: 0.0005033136694692075 at index (15, 1) (up to 1.3e-06 allowed)
  File "/home/titaiwang/pytorch/test/onnx/test_fx_op_consistency.py", line 1763, in _compare_onnx_and_torch_exported_program
    torch.testing.assert_close(
  File "/home/titaiwang/pytorch/torch/testing/_comparison.py", line 1523, in assert_close
    raise error_metas[0].to_error(msg)

```

### Environment

```
OS: Linux-5.15.135.1-2.cm2-x86_64-with-glibc2.35
Python version: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0]
onnx==1.15.0
onnxruntime==1.17.0
onnxscript==0.1.0.dev20240207
numpy==1.26.0
torch==2.2.0a0+git684ce1b
```
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119512
Approved by: https://github.com/justinchuby, https://github.com/thiagocrepaldi
2024-02-09 23:24:01 +00:00

177 lines
3.4 KiB
Python

"""Error reproduction utilities for op consistency tests."""
from __future__ import annotations
import difflib
import pathlib
import platform
import sys
import time
import traceback
import numpy as np
import onnx
import onnxruntime as ort
import onnxscript
import torch
_MISMATCH_MARKDOWN_TEMPLATE = """\
### Summary
The output of ONNX Runtime does not match that of PyTorch when executing test
`{test_name}`, `sample {sample_num}` in ONNX Script `TorchLib`.
To recreate this report, use
```bash
CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k {short_test_name}
```
### ONNX Model
```
{onnx_model_text}
```
### Inputs
Shapes: `{input_shapes}`
<details><summary>Details</summary>
<p>
```python
kwargs = {kwargs}
inputs = {inputs}
```
</p>
</details>
### Expected output
Shape: `{expected_shape}`
<details><summary>Details</summary>
<p>
```python
expected = {expected}
```
</p>
</details>
### Actual output
Shape: `{actual_shape}`
<details><summary>Details</summary>
<p>
```python
actual = {actual}
```
</p>
</details>
### Difference
<details><summary>Details</summary>
<p>
```diff
{diff}
```
</p>
</details>
### Full error stack
```
{error_stack}
```
### Environment
```
{sys_info}
```
"""
def _format_sys_info() -> str:
    """Describe the host OS, the Python interpreter and relevant package versions."""
    return "\n".join(
        [
            f"OS: {platform.platform()}",
            f"Python version: {sys.version}",
            f"onnx=={onnx.__version__}",
            f"onnxruntime=={ort.__version__}",
            f"onnxscript=={onnxscript.__version__}",
            f"numpy=={np.__version__}",
            f"torch=={torch.__version__}",
        ]
    )


def create_mismatch_report(
    test_name: str,
    sample_num: int,
    onnx_model: onnx.ModelProto,
    inputs,
    kwargs,
    actual,
    expected,
    error: Exception,
) -> None:
    """Write a markdown report for an ONNX Runtime vs. PyTorch output mismatch.

    The report is rendered from ``_MISMATCH_MARKDOWN_TEMPLATE``. It contains:
    the ONNX model as text, the inputs, the expected and actual outputs, a
    unified diff between them, the error message with its traceback, and
    environment information. It is saved with ``save_error_report`` and its
    path is printed to stdout.

    Args:
        test_name: Fully qualified test name. The part after the last ``.``
            is used for the ``pytest -k`` filter and in the report file name.
        sample_num: Index of the failing sample, shown in the summary.
        onnx_model: The exported model, rendered with ``onnx.printer.to_text``.
        inputs: Iterable of positional inputs. ``torch.Tensor`` entries are
            summarized by shape and dtype in the "Shapes" line.
        kwargs: Keyword inputs, rendered with ``str``.
        actual: Output from ONNX Runtime (the "actual" side of the diff).
        expected: Output from PyTorch (the "expected" side of the diff).
        error: The comparison error. Its message and traceback are included.
    """
    short_test_name = test_name.split(".")[-1]

    # Print tensors in full (no "..." summarization) so the report contains
    # every value. The caller's threshold is restored afterwards, so the
    # global print options of the test session are not changed as a side
    # effect of writing a report.
    previous_threshold = torch._tensor_str.PRINT_OPTS.threshold
    torch.set_printoptions(threshold=sys.maxsize)
    try:
        error_stack = str(error) + "\n" + "".join(traceback.format_tb(error.__traceback__))
        diff = difflib.unified_diff(
            str(actual).splitlines(),
            str(expected).splitlines(),
            fromfile="actual",
            tofile="expected",
            lineterm="",
        )
        input_shapes = repr(
            [
                f"Tensor<{inp.shape}, dtype={inp.dtype}>"
                if isinstance(inp, torch.Tensor)
                else inp
                for inp in inputs
            ]
        )
        # ``str.format`` stringifies the tensors (and the tuples holding them)
        # here, so it must also run under the un-summarized print options.
        markdown = _MISMATCH_MARKDOWN_TEMPLATE.format(
            test_name=test_name,
            short_test_name=short_test_name,
            sample_num=sample_num,
            input_shapes=input_shapes,
            inputs=inputs,
            kwargs=kwargs,
            expected=expected,
            expected_shape=expected.shape if isinstance(expected, torch.Tensor) else None,
            actual=actual,
            actual_shape=actual.shape if isinstance(actual, torch.Tensor) else None,
            diff="\n".join(diff),
            error_stack=error_stack,
            sys_info=_format_sys_info(),
            onnx_model_text=onnx.printer.to_text(onnx_model),
        )
    finally:
        torch.set_printoptions(threshold=previous_threshold)

    # Keep the file name free of path separators / drive-letter colons, and
    # make it unique per call with a timestamp.
    safe_test_name = short_test_name.replace("/", "-").replace(":", "-")
    timestamp = str(time.time()).replace(".", "_")
    markdown_file_name = f"mismatch-{safe_test_name}-{timestamp}.md"
    markdown_file_path = save_error_report(markdown_file_name, markdown)
    print(f"Created reproduction report at {markdown_file_path}")
def save_error_report(file_name: str, text: str):
    """Write ``text`` to ``error_reports/<file_name>`` and return that path.

    The ``error_reports`` directory is resolved relative to the current
    working directory and is created if missing. The file is written as
    UTF-8 in truncate mode, so an existing file with the same name is
    overwritten.
    """
    target_dir = pathlib.Path("error_reports")
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir.joinpath(file_name)
    target.write_text(text, encoding="utf-8")
    return target