pytorch/torch/onnx/_globals.py
Thiago Crepaldi 3834582327 [ONNX] Add autograd_inlining flag to torch.onnx.export (#104067)
Fixes #88286, Fixes #97160

Repro:

```python
import torch
import io
from torch.utils.checkpoint import checkpoint

class A(torch.nn.Module):
    # A supported module.
    def __init__(self):
        super(A, self).__init__()
        self.l1 = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.l1(x)

class B(torch.nn.Module):
    # This module is not exportable to ONNX because it
    # uses gradient-checkpointing. However, its two sub-module's
    # are exportable, so ORTModule should be used to compute them.
    def __init__(self):
        super(B, self).__init__()
        self.l1 = torch.nn.Linear(2, 2)
        self.a = A()

    def forward(self, x):
        def custom():
            def custom_forward(x_):
                return self.a(x_)

            return custom_forward

        z = self.l1(checkpoint(custom(), x))
        return z

torch.onnx.export(
    B(),
    (torch.randn(2, 2),),
    io.BytesIO(),
    autograd_inlining=True
)
```

`torch.onnx.export(autograd_inlining=True)` should reproduce the user's error, as this is the original execution path.
```bash
Traceback (most recent call last):
  File "repro88286.py", line 36, in <module>
    torch.onnx.export(
  File "<@beartype(torch.onnx.utils.export) at 0x7f0f011faee0>", line 385, in export
  File "/opt/pytorch/torch/onnx/utils.py", line 511, in export
    _export(
  File "/opt/pytorch/torch/onnx/utils.py", line 1576, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "<@beartype(torch.onnx.utils._model_to_graph) at 0x7f0f01187dc0>", line 11, in _model_to_graph
  File "/opt/pytorch/torch/onnx/utils.py", line 1130, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/opt/pytorch/torch/onnx/utils.py", line 1006, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/opt/pytorch/torch/onnx/utils.py", line 910, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/opt/pytorch/torch/jit/_trace.py", line 1269, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/opt/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/pytorch/torch/jit/_trace.py", line 128, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/opt/pytorch/torch/jit/_trace.py", line 119, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/opt/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/pytorch/torch/nn/modules/module.py", line 1492, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "repro88286.py", line 32, in forward
    z = self.l1(checkpoint(custom(), x))
  File "/opt/pytorch/torch/utils/checkpoint.py", line 412, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/opt/pytorch/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
RuntimeError: _Map_base::at
```
With `autograd_inlining=False`, the export still fails, but with a different error, because autograd inlining is disabled:

```bash
Traceback (most recent call last):
  File "repro88286.py", line 36, in <module>
    torch.onnx.export(
  File "<@beartype(torch.onnx.utils.export) at 0x7f6088b32ee0>", line 385, in export
  File "/opt/pytorch/torch/onnx/utils.py", line 511, in export
    _export(
  File "/opt/pytorch/torch/onnx/utils.py", line 1615, in _export
    ) = graph._export_onnx(  # type: ignore[attr-defined]
RuntimeError: ONNX export failed: Couldn't export Python operator CheckpointFunction
```
To allow `CheckpointFunction` into the ONNX graph, the `operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH` flag can be added to `torch.onnx.export`, which produces the following ONNX graph:

```bash
Exported graph: graph(%prim::PythonOp_0 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu),
      %l1.weight : Float(2, 2, strides=[2, 1], requires_grad=1, device=cpu),
      %l1.bias : Float(2, strides=[1], requires_grad=1, device=cpu)):
  %/PythonOp_output_0 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = ^CheckpointFunction[inplace=0, module="torch.utils.checkpoint", onnx_name="/PythonOp"](<function B.forward.<locals>.custom.<locals>.custom_forward at 0x7fdf9182f670>, True)(%prim::PythonOp_0), scope: __main__.B:: # /opt/pytorch/torch/autograd/function.py:506:0
  %6 : Float(2, 2, strides=[2, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/l1/Gemm"](%/PythonOp_output_0, %l1.weight, %l1.bias), scope: __main__.B::/torch.nn.modules.linear.Linear::l1 # /opt/pytorch/torch/nn/modules/linear.py:114:0
  return (%6)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104067
Approved by: https://github.com/BowenBao, https://github.com/kit1980
2023-07-05 15:27:36 +00:00

86 lines
2.9 KiB
Python

"""Globals used internally by the ONNX exporter.
Do not use this module outside of `torch.onnx` and its tests.
Be very judicious when adding any new global variables. Do not create new global
variables unless they are absolutely necessary.
"""
import torch._C._onnx as _C_onnx
# This module should only depend on _constants and nothing else in torch.onnx to keep
# dependency direction clean.
from torch.onnx import _constants
class _InternalGlobals:
    """Mutable state shared by the internals of the ONNX exporter.

    NOTE: Think carefully before adding an attribute here. Only introduce new
    global state when there is truly no reasonable alternative.

    Most fields are exposed through validating properties backed by
    underscore-prefixed attributes. ``export_training``,
    ``operator_export_type`` and ``onnx_shape_inference`` are plain,
    unvalidated attributes.
    """

    def __init__(self):
        # Backing field for ``export_onnx_opset_version``.
        self._export_onnx_opset_version = _constants.ONNX_DEFAULT_OPSET
        # Backing field for ``training_mode``; starts in EVAL mode.
        self._training_mode: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL
        # Backing field for ``in_onnx_export``.
        self._in_onnx_export: bool = False
        # True when the user's model is being trained during export.
        self.export_training: bool = False
        self.operator_export_type: _C_onnx.OperatorExportTypes = (
            _C_onnx.OperatorExportTypes.ONNX
        )
        self.onnx_shape_inference: bool = True
        # Backing field for ``autograd_inlining``.
        self._autograd_inlining: bool = True

    @staticmethod
    def _check_bool(name: str, value) -> None:
        """Raise ``TypeError`` unless *value* is exactly of type ``bool``.

        The exact-type test (rather than ``isinstance``) rejects ints such as
        ``0``/``1`` as well as other bool-like objects.
        """
        if type(value) is bool:
            return
        raise TypeError(f"{name} must be a boolean")

    @property
    def training_mode(self) -> _C_onnx.TrainingMode:
        """Training mode that the exporter operates under."""
        return self._training_mode

    @training_mode.setter
    def training_mode(self, training_mode: _C_onnx.TrainingMode):
        # Only genuine TrainingMode members are accepted. Anything else
        # indicates a programming error inside torch.onnx itself.
        if isinstance(training_mode, _C_onnx.TrainingMode):
            self._training_mode = training_mode
            return
        raise TypeError(
            "training_mode must be of type 'torch.onnx.TrainingMode'. This is "
            "likely a bug in torch.onnx."
        )

    @property
    def export_onnx_opset_version(self) -> int:
        """ONNX opset version targeted by the export."""
        return self._export_onnx_opset_version

    @export_onnx_opset_version.setter
    def export_onnx_opset_version(self, value: int):
        # Both bounds are inclusive: ONNX_MIN_OPSET .. ONNX_MAX_OPSET.
        allowed = range(_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1)
        if value in allowed:
            self._export_onnx_opset_version = value
            return
        raise ValueError(f"Unsupported ONNX opset version: {value}")

    @property
    def in_onnx_export(self) -> bool:
        """``True`` while an ONNX export is in progress."""
        return self._in_onnx_export

    @in_onnx_export.setter
    def in_onnx_export(self, value: bool):
        self._check_bool("in_onnx_export", value)
        self._in_onnx_export = value

    @property
    def autograd_inlining(self) -> bool:
        """Whether autograd must be inlined during export."""
        return self._autograd_inlining

    @autograd_inlining.setter
    def autograd_inlining(self, value: bool):
        self._check_bool("autograd_inlining", value)
        self._autograd_inlining = value
# Module-level instance of _InternalGlobals. Python caches imported modules,
# so every importer of this module sees this same object. Per the module
# docstring, only torch.onnx and its tests should use it.
GLOBALS = _InternalGlobals()