Fixes#88286, Fixes#97160
Repro:
```python
import torch
import io
from torch.utils.checkpoint import checkpoint
class A(torch.nn.Module):
# A supported module.
def __init__(self):
super(A, self).__init__()
self.l1 = torch.nn.Linear(2, 2)
def forward(self, x):
return self.l1(x)
class B(torch.nn.Module):
# This module is not exportable to ONNX because it
# uses gradient-checkpointing. However, its two sub-module's
# are exportable, so ORTModule should be used to compute them.
def __init__(self):
super(B, self).__init__()
self.l1 = torch.nn.Linear(2, 2)
self.a = A()
def forward(self, x):
def custom():
def custom_forward(x_):
return self.a(x_)
return custom_forward
z = self.l1(checkpoint(custom(), x))
return z
torch.onnx.export(
B(),
(torch.randn(2, 2),),
io.BytesIO(),
autograd_inlining=True
)
```
`torch.onnx.export(autograd_inlining=True)` should repro the user error as this is the original execution path.
```bash
Traceback (most recent call last):
File "repro88286.py", line 36, in <module>
torch.onnx.export(
File "<@beartype(torch.onnx.utils.export) at 0x7f0f011faee0>", line 385, in export
File "/opt/pytorch/torch/onnx/utils.py", line 511, in export
_export(
File "/opt/pytorch/torch/onnx/utils.py", line 1576, in _export
graph, params_dict, torch_out = _model_to_graph(
File "<@beartype(torch.onnx.utils._model_to_graph) at 0x7f0f01187dc0>", line 11, in _model_to_graph
File "/opt/pytorch/torch/onnx/utils.py", line 1130, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/opt/pytorch/torch/onnx/utils.py", line 1006, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/opt/pytorch/torch/onnx/utils.py", line 910, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
File "/opt/pytorch/torch/jit/_trace.py", line 1269, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/opt/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/pytorch/torch/jit/_trace.py", line 128, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/opt/pytorch/torch/jit/_trace.py", line 119, in wrapper
outs.append(self.inner(*trace_inputs))
File "/opt/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/pytorch/torch/nn/modules/module.py", line 1492, in _slow_forward
result = self.forward(*input, **kwargs)
File "repro88286.py", line 32, in forward
z = self.l1(checkpoint(custom(), x))
File "/opt/pytorch/torch/utils/checkpoint.py", line 412, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/opt/pytorch/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
RuntimeError: _Map_base::at
```
By using `autograd_inlining=False`, the export still fail with a different error because autograd inlining is not enabled:
```bash
Traceback (most recent call last):
File "repro88286.py", line 36, in <module>
torch.onnx.export(
File "<@beartype(torch.onnx.utils.export) at 0x7f6088b32ee0>", line 385, in export
File "/opt/pytorch/torch/onnx/utils.py", line 511, in export
_export(
File "/opt/pytorch/torch/onnx/utils.py", line 1615, in _export
) = graph._export_onnx( # type: ignore[attr-defined]
RuntimeError: ONNX export failed: Couldn't export Python operator CheckpointFunction
```
To allow `CheckpointFunction` into the onnx graph, `operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH` flag can be added to `torch.onnx.export`, which would lead to the following ONNX graph:
```bash
Exported graph: graph(%prim::PythonOp_0 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu),
%l1.weight : Float(2, 2, strides=[2, 1], requires_grad=1, device=cpu),
%l1.bias : Float(2, strides=[1], requires_grad=1, device=cpu)):
%/PythonOp_output_0 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = ^CheckpointFunction[inplace=0, module="torch.utils.checkpoint", onnx_name="/PythonOp"](<function B.forward.<locals>.custom.<locals>.custom_forward at 0x7fdf9182f670>, True)(%prim::PythonOp_0), scope: __main__.B:: # /opt/pytorch/torch/autograd/function.py:506:0
%6 : Float(2, 2, strides=[2, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/l1/Gemm"](%/PythonOp_output_0, %l1.weight, %l1.bias), scope: __main__.B::/torch.nn.modules.linear.Linear::l1 # /opt/pytorch/torch/nn/modules/linear.py:114:0
return (%6)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104067
Approved by: https://github.com/BowenBao, https://github.com/kit1980
`_set_opset_version` and `_set_operator_export_type` are previously deprecated. This PR decorates them with the deprecation decorator, so warnings are emitted.
- Remove usage of `_set_opset_version` and `_set_operator_export_type` in favor of setting the globals vars directly in torch.onnx internal
- Update `GLOBALS.operator_export_type`'s default to not be None to tighten types
- Remove usage of `_set_onnx_shape_inference`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85165
Approved by: https://github.com/BowenBao, https://github.com/AllenTiTaiWang
This PR adds an internal wrapper on the [beartype](https://github.com/beartype/beartype) library to perform runtime type checking in `torch.onnx`. It uses beartype when it is found in the environment and is reduced to a no-op when beartype is not found.
Setting the env var `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=ERRORS` will turn on the feature. setting `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=DISABLED` will disable all checks. When not set and `beartype` is installed, a warning message is emitted.
Now when users call an api with invalid arguments e.g.
```python
torch.onnx.export(conv, y, path, export_params=True, training=False)
# traning should take TrainingModel, not bool
```
they get
```
Traceback (most recent call last):
File "bisect_m1_error.py", line 63, in <module>
main()
File "bisect_m1_error.py", line 59, in main
reveal_error()
File "bisect_m1_error.py", line 32, in reveal_error
torch.onnx.export(conv, y, cpu_model_path, export_params=True, training=False)
File "<@beartype(torch.onnx.utils.export) at 0x1281f5a60>", line 136, in export
File "pytorch/venv/lib/python3.9/site-packages/beartype/_decor/_error/errormain.py", line 301, in raise_pep_call_exception
raise exception_cls( # type: ignore[misc]
beartype.roar.BeartypeCallHintParamViolation: @beartyped export() parameter training=False violates type hint <class 'torch._C._onnx.TrainingMode'>, as False not instance of <protocol "torch._C._onnx.TrainingMode">.
```
when `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK` is not set and `beartype` is installed, a warning message is emitted.
```
>>> torch.onnx.export("foo", "bar", "f")
<stdin>:1: CallHintViolationWarning: Traceback (most recent call last):
File "/home/justinchu/dev/pytorch/torch/onnx/_internal/_beartype.py", line 54, in _coerce_beartype_exceptions_to_warnings
return beartyped(*args, **kwargs)
File "<@beartype(torch.onnx.utils.export) at 0x7f1d4ab35280>", line 39, in export
File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/site-packages/beartype/_decor/_error/errormain.py", line 301, in raise_pep_call_exception
raise exception_cls( # type: ignore[misc]
beartype.roar.BeartypeCallHintParamViolation: @beartyped export() parameter model='foo' violates type hint typing.Union[torch.nn.modules.module.Module, torch.jit._script.ScriptModule, torch.jit.ScriptFunction], as 'foo' not <protocol "torch.jit.ScriptFunction">, <protocol "torch.nn.modules.module.Module">, or <protocol "torch.jit._script.ScriptModule">.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/justinchu/dev/pytorch/torch/onnx/_internal/_beartype.py", line 63, in _coerce_beartype_exceptions_to_warnings
return func(*args, **kwargs)
File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 482, in export
_export(
File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 1422, in _export
with exporter_context(model, training, verbose):
File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/contextlib.py", line 119, in __enter__
return next(self.gen)
File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 177, in exporter_context
with select_model_mode_for_export(
File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/contextlib.py", line 119, in __enter__
return next(self.gen)
File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 95, in select_model_mode_for_export
originally_training = model.training
AttributeError: 'str' object has no attribute 'training'
```
We see the error is caught right when the type mismatch happens, improving from what otherwise would become `AttributeError: 'str' object has no attribute 'training'`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83673
Approved by: https://github.com/BowenBao
Legacy code has onnx_shape_inference=False by default, which is misleading
as every other export api sets it to True unless otherwise overriden by caller.
There is only two tests that need updating according to this change.
* test_utility_funs.py::test_constant_fold_shape. The resulting number of nodes
in graph is increased by 1, due to that previously the extra constant node was
added as initializer.
* test_utility_funs.py::test_onnx_function_substitution_pass. Enabling onnx
shape inference discovered discrepancy in test input shape and supplied dynamic
axes arguments.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82767
Approved by: https://github.com/justinchuby, https://github.com/abock
When `TrainingMode.PRESERVE` is set for export, the exporter used to change the model's training mode based on some logic. Now we respect the option and not touch the model's training state.
- Previously `_set_training_mode`'s behavior doesn't match what the global variable expects. This PR removes the deprecated `_set_training_mode` and makes the type correct.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78583
Approved by: https://github.com/BowenBao
Cleaning up onnx module imports to prepare for updating `__init__`.
- Simplify importing the `_C` and `_C._onnx` name spaces
- Remove alias of the symbolic_helper module in imports
- Remove any module level function imports. Import modules instead
- Alias `symbilic_opsetx` as `opsetx`
- Fix some docstrings
Requires:
- https://github.com/pytorch/pytorch/pull/77448
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77423
Approved by: https://github.com/BowenBao
Reduce circular dependencies
- Lift constants and flags from `symbolic_helper` to `_constants` and `_globals`
- Standardized constant naming to make it consistant
- Make `utils` strictly dependent on `symbolic_helper`, removing inline imports from symbolic_helper
- Move side effects from `utils` to `_patch_torch`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77142
Approved by: https://github.com/garymm, https://github.com/BowenBao