pytorch/torch/onnx/errors.py
Justin Chu bf25a140f9 [ONNX] Add runtime type checking to export (#83673)
This PR adds an internal wrapper on the [beartype](https://github.com/beartype/beartype) library to perform runtime type checking in `torch.onnx`. It uses beartype when it is found in the environment and is reduced to a no-op when beartype is not found.

Setting the env var `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=ERRORS` will turn on the feature. Setting `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=DISABLED` will disable all checks. When not set and `beartype` is installed, a warning message is emitted.

Now when users call an api with invalid arguments e.g.

```python
torch.onnx.export(conv, y, path, export_params=True, training=False)

# training should take TrainingMode, not bool
```

they get

```
Traceback (most recent call last):
  File "bisect_m1_error.py", line 63, in <module>
    main()
  File "bisect_m1_error.py", line 59, in main
    reveal_error()
  File "bisect_m1_error.py", line 32, in reveal_error
    torch.onnx.export(conv, y, cpu_model_path, export_params=True, training=False)
  File "<@beartype(torch.onnx.utils.export) at 0x1281f5a60>", line 136, in export
  File "pytorch/venv/lib/python3.9/site-packages/beartype/_decor/_error/errormain.py", line 301, in raise_pep_call_exception
    raise exception_cls(  # type: ignore[misc]
beartype.roar.BeartypeCallHintParamViolation: @beartyped export() parameter training=False violates type hint <class 'torch._C._onnx.TrainingMode'>, as False not instance of <protocol "torch._C._onnx.TrainingMode">.
```

when `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK` is not set and `beartype` is installed, a warning message is emitted.

```
>>> torch.onnx.export("foo", "bar", "f")
<stdin>:1: CallHintViolationWarning: Traceback (most recent call last):
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/_beartype.py", line 54, in _coerce_beartype_exceptions_to_warnings
    return beartyped(*args, **kwargs)
  File "<@beartype(torch.onnx.utils.export) at 0x7f1d4ab35280>", line 39, in export
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/site-packages/beartype/_decor/_error/errormain.py", line 301, in raise_pep_call_exception
    raise exception_cls(  # type: ignore[misc]
beartype.roar.BeartypeCallHintParamViolation: @beartyped export() parameter model='foo' violates type hint typing.Union[torch.nn.modules.module.Module, torch.jit._script.ScriptModule, torch.jit.ScriptFunction], as 'foo' not <protocol "torch.jit.ScriptFunction">, <protocol "torch.nn.modules.module.Module">, or <protocol "torch.jit._script.ScriptModule">.

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/_beartype.py", line 63, in _coerce_beartype_exceptions_to_warnings
    return func(*args, **kwargs)
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 482, in export
    _export(
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 1422, in _export
    with exporter_context(model, training, verbose):
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 177, in exporter_context
    with select_model_mode_for_export(
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 95, in select_model_mode_for_export
    originally_training = model.training
AttributeError: 'str' object has no attribute 'training'
```

We see the error is caught right when the type mismatch happens — an improvement over the obscure `AttributeError: 'str' object has no attribute 'training'` that would otherwise surface deep inside the exporter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83673
Approved by: https://github.com/BowenBao
2022-08-25 21:24:37 +00:00

114 lines
3.6 KiB
Python

"""ONNX exporter exceptions."""
from __future__ import annotations
import textwrap
from typing import Optional
from torch import _C
from torch.onnx import _constants
# Public API of this module; keep in sync with the classes defined below.
__all__ = [
    "OnnxExporterError",
    "OnnxExporterWarning",
    "CallHintViolationWarning",
    "CheckerError",
    "UnsupportedOperatorError",
    "SymbolicValueError",
]
class OnnxExporterWarning(UserWarning):
    """Base category for every warning emitted by the ONNX exporter."""
class CallHintViolationWarning(OnnxExporterWarning):
    """Emitted when an argument passed to an exporter API violates its type hint."""
class OnnxExporterError(RuntimeError):
    """Base error type for failures raised by the ONNX exporter."""
class CheckerError(OnnxExporterError):
    """Raised when the ONNX checker rejects an exported model as invalid."""
class UnsupportedOperatorError(OnnxExporterError):
    """Raised when an operator is unsupported by the exporter.

    Attributes:
        domain: Namespace of the operator (e.g. ``"aten"``, ``""`` for the
            default domain, or a custom namespace).
        op_name: Name of the operator within the domain.
        version: ONNX opset version the export was attempted with.
        supported_version: Lowest opset version known to support the operator,
            or ``None`` when no opset supports it.
    """

    def __init__(
        self, domain: str, op_name: str, version: int, supported_version: Optional[int]
    ):
        # Keep the raw arguments as attributes so callers can handle the
        # error programmatically instead of parsing the message text.
        self.domain = domain
        self.op_name = op_name
        self.version = version
        self.supported_version = supported_version

        if domain in {"", "aten", "prim", "quantized"}:
            # Builtin namespaces: either the requested opset is too old, or
            # support simply has not been implemented yet.
            msg = f"Exporting the operator '{domain}::{op_name}' to ONNX opset version {version} is not supported. "
            if supported_version is not None:
                msg += (
                    f"Support for this operator was added in version {supported_version}, "
                    "try exporting with this version."
                )
            else:
                msg += "Please feel free to request support or submit a pull request on PyTorch GitHub: "
                msg += _constants.PYTORCH_GITHUB_ISSUES_URL
        else:
            # Unrecognized (custom) namespace: most likely the custom operator
            # was not registered, or was registered under a different domain.
            msg = (
                f"ONNX export failed on an operator with unrecognized namespace '{domain}::{op_name}'. "
                "If you are trying to export a custom operator, make sure you registered "
                "it with the right domain and version."
            )
        super().__init__(msg)
class SymbolicValueError(OnnxExporterError):
    """Errors around TorchScript values and nodes.

    Args:
        msg: The error message.
        value: The TorchScript value the error refers to; its node's kind,
            source location, inputs, and outputs are appended to the message
            for debugging.
    """

    def __init__(self, msg: str, value: _C.Value):
        # Start from the caller's message, annotated with the offending value
        # and the kind of the node that produced it.
        message = (
            f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the "
            f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] "
        )

        # Point at the original Python source location when TorchScript kept it.
        code_location = value.node().sourceRange()
        if code_location:
            message += f"\n (node defined in {code_location})"

        try:
            # Add its input and output to the message.
            message += "\n\n"
            message += textwrap.indent(
                (
                    "Inputs:\n"
                    + (
                        "\n".join(
                            f" #{i}: {input_} (type '{input_.type()}')"
                            for i, input_ in enumerate(value.node().inputs())
                        )
                        or " Empty"
                    )
                    + "\n"
                    + "Outputs:\n"
                    + (
                        "\n".join(
                            f" #{i}: {output} (type '{output.type()}')"
                            for i, output in enumerate(value.node().outputs())
                        )
                        or " Empty"
                    )
                ),
                " ",
            )
        except AttributeError:
            # Best effort: if the node does not expose inputs()/outputs(),
            # fall back to a generic notice rather than failing while
            # constructing the error message itself.
            message += (
                " Failed to obtain its input and output for debugging. "
                "Please refer to the TorchScript graph for debugging information."
            )
        super().__init__(message)