mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fixes #119472 Introduce the debugging tool in onnxscript: https://github.com/microsoft/onnxscript/blob/main/onnxscript/tests/function_libs/torch_lib/error_reproduction.py This tool can help us quickly find the inputs leading to mismatched errors. NOTE: this produces `error_reports` folder where there are different markdown reports for each mismatched test cases. For example - CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_fft_fft_cpu_bool ### Summary The output of ONNX Runtime does not match that of PyTorch when executing test `test_fx_op_consistency.TestOnnxModelOutputConsistency_opset_version_18_model_type_TorchModelType.TORCH_NN_MODULECPU.test_output_match_fft_fft_cpu_bool`, `sample 3` in ONNX Script `TorchLib`. To recreate this report, use ```bash CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_fft_fft_cpu_bool ``` ### ONNX Model ``` < ir_version: 8, opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1], producer_name: "pytorch", producer_version: "2.2.0" > main_graph (bool[31] l_args_0_) => (float[31,2] _fft_r2c) <bool[31] l_args_0_, float[31] _to_copy, float[31,2] _fft_r2c> { _to_copy = Cast <to: int = 1> (l_args_0_) _val_2 = Constant <value: tensor = int64[1] {-1}> () _val_3 = Unsqueeze (_to_copy, _val_2) _val_4 = Constant <value: tensor = int64[1] {0}> () _val_5 = Unsqueeze (_val_3, _val_4) _val_6 = DFT <axis: int = 1, inverse: int = 0, onesided: int = 0> (_val_5) _val_7 = Constant <value: tensor = int64[1] {0}> () _val_8 = Squeeze (_val_6, _val_7) _fft_r2c = pkg.onnxscript.torch_lib._fftn_onnx_normalization <dims: ints = [0], forward: int = 1, normalization: int = 0> (_val_3, _val_8) } < domain: "pkg.onnxscript.torch_lib", opset_import: ["" : 18] > _fftn_onnx_normalization <normalization,forward,dims>(self, transformed) => (result_15) { self_shape = Shape (self) dims = 
Constant <value_ints: ints = @dims> () self_shape_subscripted = Gather <axis: int = 0> (self_shape, dims) total_sample_count = ReduceProd <keepdims: int = 0> (self_shape_subscripted) total_sample_count_0 = CastLike (total_sample_count, transformed) normalization = Constant <value_int: int = @normalization> () int64_1 = Constant <value: tensor = int64 int64_1 {1}> () cond = Equal (normalization, int64_1) result_15 = If (cond) <then_branch: graph = thenGraph_21 () => ( result_3) { forward = Constant <value_int: int = @forward> () forward_as_bool = Cast <to: int = 9> (forward) result_3 = If (forward_as_bool) <then_branch: graph = thenGraph_23 () => ( result) { tmp = Sqrt (total_sample_count_0) result = Div (transformed, tmp) }, else_branch: graph = elseGraph_23 () => ( result_2) { tmp_1 = Sqrt (total_sample_count_0) result_2 = Mul (transformed, tmp_1) }> }, else_branch: graph = elseGraph_21 () => ( result_14) { normalization_4 = Constant <value_int: int = @normalization> () int64_2 = Constant <value: tensor = int64 int64_2 {2}> () cond_5 = Equal (normalization_4, int64_2) result_14 = If (cond_5) <then_branch: graph = thenGraph_27 () => ( result_9) { forward_6 = Constant <value_int: int = @forward> () forward_6_as_bool = Cast <to: int = 9> (forward_6) result_9 = If (forward_6_as_bool) <then_branch: graph = thenGraph_29 () => ( result_7) { result_7 = Div (transformed, total_sample_count_0) }, else_branch: graph = elseGraph_29 () => ( result_8) { result_8 = Identity (transformed) }> }, else_branch: graph = elseGraph_27 () => ( result_13) { forward_10 = Constant <value_int: int = @forward> () forward_10_as_bool = Cast <to: int = 9> (forward_10) result_13 = If (forward_10_as_bool) <then_branch: graph = thenGraph_35 () => ( result_11) { result_11 = Identity (transformed) }, else_branch: graph = elseGraph_35 () => ( result_12) { result_12 = Mul (transformed, total_sample_count_0) }> }> }> } < domain: "pkg.onnxscript.torch_lib.common", opset_import: ["" : 18] > Rank (input) 
=> (return_val) { tmp = Shape (input) return_val = Size (tmp) } < domain: "pkg.onnxscript.torch_lib.common", opset_import: ["" : 18] > IsScalar (input) => (return_val) { tmp = Shape (input) tmp_0 = Size (tmp) tmp_1 = Constant <value_int: int = 0> () return_val = Equal (tmp_0, tmp_1) } ``` ### Inputs Shapes: `['Tensor<torch.Size([31]), dtype=torch.bool>']` <details><summary>Details</summary> <p> ```python kwargs = {} inputs = (tensor([False, False, True, True, False, True, False, True, False, False, True, False, False, False, False, False, True, True, True, True, True, True, True, True, False, False, False, False, True, True, True]),) ``` </p> </details> ### Expected output Shape: `torch.Size([31, 2])` <details><summary>Details</summary> <p> ```python expected = tensor([[16.0000, 0.0000], [-0.2369, 2.6590], [ 0.7336, -4.9670], [ 2.2093, 2.9865], [-0.7166, 1.0928], [-3.0614, 3.0015], [-1.8945, -0.9677], [-2.1538, 0.2513], [-2.2432, 1.3978], [-0.3429, 1.9494], [-0.6495, -1.5423], [-0.6005, 2.2398], [ 2.2639, 2.6430], [ 1.7609, 0.2033], [-1.3829, -2.3365], [-1.6854, -0.0311], [-1.6854, 0.0311], [-1.3829, 2.3365], [ 1.7609, -0.2033], [ 2.2639, -2.6430], [-0.6005, -2.2398], [-0.6495, 1.5423], [-0.3429, -1.9494], [-2.2432, -1.3978], [-2.1538, -0.2513], [-1.8945, 0.9677], [-3.0614, -3.0015], [-0.7166, -1.0928], [ 2.2093, -2.9865], [ 0.7336, 4.9670], [-0.2369, -2.6590]]) ``` </p> </details> ### Actual output Shape: `torch.Size([31, 2])` <details><summary>Details</summary> <p> ```python actual = tensor([[ 1.6000e+01, -9.1791e-06], [-2.3695e-01, 2.6590e+00], [ 7.3355e-01, -4.9670e+00], [ 2.2093e+00, 2.9865e+00], [-7.1663e-01, 1.0928e+00], [-3.0614e+00, 3.0015e+00], [-1.8946e+00, -9.6773e-01], [-2.1538e+00, 2.5126e-01], [-2.2432e+00, 1.3978e+00], [-3.4294e-01, 1.9494e+00], [-6.4946e-01, -1.5423e+00], [-6.0044e-01, 2.2398e+00], [ 2.2639e+00, 2.6430e+00], [ 1.7609e+00, 2.0326e-01], [-1.3829e+00, -2.3365e+00], [-1.6854e+00, -3.1130e-02], [-1.6854e+00, 3.1161e-02], [-1.3829e+00, 
2.3365e+00], [ 1.7609e+00, -2.0327e-01], [ 2.2639e+00, -2.6430e+00], [-6.0047e-01, -2.2398e+00], [-6.4945e-01, 1.5423e+00], [-3.4294e-01, -1.9494e+00], [-2.2432e+00, -1.3978e+00], [-2.1538e+00, -2.5129e-01], [-1.8945e+00, 9.6773e-01], [-3.0615e+00, -3.0015e+00], [-7.1663e-01, -1.0928e+00], [ 2.2093e+00, -2.9865e+00], [ 7.3354e-01, 4.9670e+00], [-2.3695e-01, -2.6589e+00]]) ``` </p> </details> ### Difference <details><summary>Details</summary> <p> ```diff --- actual +++ expected @@ -1,31 +1,31 @@ -tensor([[ 1.6000e+01, -9.1791e-06], - [-2.3695e-01, 2.6590e+00], - [ 7.3355e-01, -4.9670e+00], - [ 2.2093e+00, 2.9865e+00], - [-7.1663e-01, 1.0928e+00], - [-3.0614e+00, 3.0015e+00], - [-1.8946e+00, -9.6773e-01], - [-2.1538e+00, 2.5126e-01], - [-2.2432e+00, 1.3978e+00], - [-3.4294e-01, 1.9494e+00], - [-6.4946e-01, -1.5423e+00], - [-6.0044e-01, 2.2398e+00], - [ 2.2639e+00, 2.6430e+00], - [ 1.7609e+00, 2.0326e-01], - [-1.3829e+00, -2.3365e+00], - [-1.6854e+00, -3.1130e-02], - [-1.6854e+00, 3.1161e-02], - [-1.3829e+00, 2.3365e+00], - [ 1.7609e+00, -2.0327e-01], - [ 2.2639e+00, -2.6430e+00], - [-6.0047e-01, -2.2398e+00], - [-6.4945e-01, 1.5423e+00], - [-3.4294e-01, -1.9494e+00], - [-2.2432e+00, -1.3978e+00], - [-2.1538e+00, -2.5129e-01], - [-1.8945e+00, 9.6773e-01], - [-3.0615e+00, -3.0015e+00], - [-7.1663e-01, -1.0928e+00], - [ 2.2093e+00, -2.9865e+00], - [ 7.3354e-01, 4.9670e+00], - [-2.3695e-01, -2.6589e+00]]) +tensor([[16.0000, 0.0000], + [-0.2369, 2.6590], + [ 0.7336, -4.9670], + [ 2.2093, 2.9865], + [-0.7166, 1.0928], + [-3.0614, 3.0015], + [-1.8945, -0.9677], + [-2.1538, 0.2513], + [-2.2432, 1.3978], + [-0.3429, 1.9494], + [-0.6495, -1.5423], + [-0.6005, 2.2398], + [ 2.2639, 2.6430], + [ 1.7609, 0.2033], + [-1.3829, -2.3365], + [-1.6854, -0.0311], + [-1.6854, 0.0311], + [-1.3829, 2.3365], + [ 1.7609, -0.2033], + [ 2.2639, -2.6430], + [-0.6005, -2.2398], + [-0.6495, 1.5423], + [-0.3429, -1.9494], + [-2.2432, -1.3978], + [-2.1538, -0.2513], + [-1.8945, 0.9677], + [-3.0614, 
-3.0015], + [-0.7166, -1.0928], + [ 2.2093, -2.9865], + [ 0.7336, 4.9670], + [-0.2369, -2.6590]]) ``` </p> </details> ### Full error stack ``` Tensor-likes are not close! Mismatched elements: 21 / 62 (33.9%) Greatest absolute difference: 3.719329833984375e-05 at index (26, 1) (up to 1e-05 allowed) Greatest relative difference: 0.0005033136694692075 at index (15, 1) (up to 1.3e-06 allowed) File "/home/titaiwang/pytorch/test/onnx/test_fx_op_consistency.py", line 1763, in _compare_onnx_and_torch_exported_program torch.testing.assert_close( File "/home/titaiwang/pytorch/torch/testing/_comparison.py", line 1523, in assert_close raise error_metas[0].to_error(msg) ``` ### Environment ``` OS: Linux-5.15.135.1-2.cm2-x86_64-with-glibc2.35 Python version: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] onnx==1.15.0 onnxruntime==1.17.0 onnxscript==0.1.0.dev20240207 numpy==1.26.0 torch==2.2.0a0+git684ce1b ``` Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/119512 Approved by: https://github.com/justinchuby, https://github.com/thiagocrepaldi
177 lines
3.4 KiB
Python
177 lines
3.4 KiB
Python
"""Error reproduction utilities for op consistency tests."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import difflib
|
|
import pathlib
|
|
import platform
|
|
import sys
|
|
import time
|
|
import traceback
|
|
|
|
import numpy as np
|
|
|
|
import onnx
|
|
import onnxruntime as ort
|
|
import onnxscript
|
|
import torch
|
|
|
|
# Markdown skeleton for a single mismatch report, filled in with ``str.format``
# by ``create_mismatch_report``. Every ``{name}`` below is a format placeholder,
# so any literal brace added to this text must be doubled (``{{`` / ``}}``).
# The reproduction command targets this repository's FX op-consistency test
# file (the template was originally copied from onnxscript, whose
# ``ops_test.py`` path does not exist here).
_MISMATCH_MARKDOWN_TEMPLATE = """\
### Summary

The output of ONNX Runtime does not match that of PyTorch when executing test
`{test_name}`, `sample {sample_num}` in ONNX Script `TorchLib`.

To recreate this report, use

```bash
CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/test_fx_op_consistency.py -k {short_test_name}
```

### ONNX Model

```
{onnx_model_text}
```

### Inputs

Shapes: `{input_shapes}`

<details><summary>Details</summary>
<p>

```python
kwargs = {kwargs}
inputs = {inputs}
```

</p>
</details>

### Expected output

Shape: `{expected_shape}`

<details><summary>Details</summary>
<p>

```python
expected = {expected}
```

</p>
</details>

### Actual output

Shape: `{actual_shape}`

<details><summary>Details</summary>
<p>

```python
actual = {actual}
```

</p>
</details>

### Difference

<details><summary>Details</summary>
<p>

```diff
{diff}
```

</p>
</details>

### Full error stack

```
{error_stack}
```

### Environment

```
{sys_info}
```

"""
|
|
|
|
|
|
def create_mismatch_report(
    test_name: str,
    sample_num: int,
    onnx_model: onnx.ModelProto,
    inputs,
    kwargs,
    actual,
    expected,
    error: Exception,
) -> None:
    """Render a markdown mismatch report for one failing sample and save it.

    The report is produced from ``_MISMATCH_MARKDOWN_TEMPLATE`` and written via
    ``save_error_report``. The resulting file path is printed to stdout.

    Args:
        test_name: Fully qualified test name. The segment after the last ``.``
            is used as the ``-k`` filter in the reproduction command and in the
            report file name.
        sample_num: Index of the failing sample, shown in the summary.
        onnx_model: Model under test, rendered with ``onnx.printer.to_text``.
        inputs: Positional inputs of the sample. ``torch.Tensor`` entries are
            summarized by shape and dtype on the "Shapes" line; all inputs are
            also printed in full.
        kwargs: Keyword inputs of the sample, printed as-is.
        actual: Output being checked (the report labels it the ONNX Runtime
            output).
        expected: Reference output (the report labels it the PyTorch output).
        error: Exception describing the mismatch; its message and traceback
            are embedded in the report.
    """
    # Tensors are printed in full so the report shows every element. However,
    # ``torch.set_printoptions`` mutates process-wide state, so remember the
    # current threshold (``PRINT_OPTS`` is the object set_printoptions writes
    # to) and restore it afterwards. Otherwise every later tensor repr in the
    # same process (e.g. other tests' failure messages) would print in full.
    previous_threshold = torch._tensor_str.PRINT_OPTS.threshold
    torch.set_printoptions(threshold=sys.maxsize)
    try:
        error_stack = (
            str(error) + "\n" + "".join(traceback.format_tb(error.__traceback__))
        )
        short_test_name = test_name.split(".")[-1]
        diff = difflib.unified_diff(
            str(actual).splitlines(),
            str(expected).splitlines(),
            fromfile="actual",
            tofile="expected",
            lineterm="",
        )
        onnx_model_text = onnx.printer.to_text(onnx_model)
        input_shapes = repr(
            [
                f"Tensor<{inp.shape}, dtype={inp.dtype}>"
                if isinstance(inp, torch.Tensor)
                else inp
                for inp in inputs
            ]
        )
        sys_info = "\n".join(
            [
                f"OS: {platform.platform()}",
                f"Python version: {sys.version}",
                f"onnx=={onnx.__version__}",
                f"onnxruntime=={ort.__version__}",
                f"onnxscript=={onnxscript.__version__}",
                f"numpy=={np.__version__}",
                f"torch=={torch.__version__}",
            ]
        )

        # The tensors themselves (not pre-rendered strings) are passed to
        # ``format`` so the report text is exactly as before.
        markdown = _MISMATCH_MARKDOWN_TEMPLATE.format(
            test_name=test_name,
            short_test_name=short_test_name,
            sample_num=sample_num,
            input_shapes=input_shapes,
            inputs=inputs,
            kwargs=kwargs,
            expected=expected,
            expected_shape=expected.shape
            if isinstance(expected, torch.Tensor)
            else None,
            actual=actual,
            actual_shape=actual.shape if isinstance(actual, torch.Tensor) else None,
            diff="\n".join(diff),
            error_stack=error_stack,
            sys_info=sys_info,
            onnx_model_text=onnx_model_text,
        )
    finally:
        torch.set_printoptions(threshold=previous_threshold)

    # Keep the file name path- and drive-separator free, and unique per call.
    safe_name = short_test_name.replace("/", "-").replace(":", "-")
    timestamp = str(time.time()).replace(".", "_")
    markdown_file_name = f"mismatch-{safe_name}-{timestamp}.md"
    markdown_file_path = save_error_report(markdown_file_name, markdown)
    print(f"Created reproduction report at {markdown_file_path}")
|
|
|
|
|
|
def save_error_report(file_name: str, text: str):
    """Write *text* to ``error_reports/<file_name>`` and return that path.

    The ``error_reports`` directory is resolved relative to the current working
    directory and is created (including missing parents) when absent. An
    existing file of the same name is overwritten. The text is written as
    UTF-8.
    """
    destination_dir = pathlib.Path("error_reports")
    destination_dir.mkdir(parents=True, exist_ok=True)

    destination = destination_dir.joinpath(file_name)
    with destination.open("w", encoding="utf-8") as handle:
        handle.write(text)
    return destination
|