mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Fixes #88286, Fixes #97160

Repro:

```python
import torch
import io
from torch.utils.checkpoint import checkpoint

class A(torch.nn.Module):
    # A supported module.
    def __init__(self):
        super(A, self).__init__()
        self.l1 = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.l1(x)

class B(torch.nn.Module):
    # This module is not exportable to ONNX because it
    # uses gradient-checkpointing. However, its two sub-module's
    # are exportable, so ORTModule should be used to compute them.
    def __init__(self):
        super(B, self).__init__()
        self.l1 = torch.nn.Linear(2, 2)
        self.a = A()

    def forward(self, x):
        def custom():
            def custom_forward(x_):
                return self.a(x_)
            return custom_forward

        z = self.l1(checkpoint(custom(), x))
        return z

torch.onnx.export(
    B(),
    (torch.randn(2, 2),),
    io.BytesIO(),
    autograd_inlining=True
)
```

`torch.onnx.export(autograd_inlining=True)` should repro the user error as this is the original execution path.

```bash
Traceback (most recent call last):
  File "repro88286.py", line 36, in <module>
    torch.onnx.export(
  File "<@beartype(torch.onnx.utils.export) at 0x7f0f011faee0>", line 385, in export
  File "/opt/pytorch/torch/onnx/utils.py", line 511, in export
    _export(
  File "/opt/pytorch/torch/onnx/utils.py", line 1576, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "<@beartype(torch.onnx.utils._model_to_graph) at 0x7f0f01187dc0>", line 11, in _model_to_graph
  File "/opt/pytorch/torch/onnx/utils.py", line 1130, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/opt/pytorch/torch/onnx/utils.py", line 1006, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/opt/pytorch/torch/onnx/utils.py", line 910, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/opt/pytorch/torch/jit/_trace.py", line 1269, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/opt/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/pytorch/torch/jit/_trace.py", line 128, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/opt/pytorch/torch/jit/_trace.py", line 119, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/opt/pytorch/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/pytorch/torch/nn/modules/module.py", line 1511, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/pytorch/torch/nn/modules/module.py", line 1492, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "repro88286.py", line 32, in forward
    z = self.l1(checkpoint(custom(), x))
  File "/opt/pytorch/torch/utils/checkpoint.py", line 412, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/opt/pytorch/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
RuntimeError: _Map_base::at
```

By using `autograd_inlining=False`, the export still fails with a different error because autograd inlining is not enabled:

```bash
Traceback (most recent call last):
  File "repro88286.py", line 36, in <module>
    torch.onnx.export(
  File "<@beartype(torch.onnx.utils.export) at 0x7f6088b32ee0>", line 385, in export
  File "/opt/pytorch/torch/onnx/utils.py", line 511, in export
    _export(
  File "/opt/pytorch/torch/onnx/utils.py", line 1615, in _export
    ) = graph._export_onnx(  # type: ignore[attr-defined]
RuntimeError: ONNX export failed: Couldn't export Python operator CheckpointFunction
```

To allow `CheckpointFunction` into the onnx graph, the `operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH` flag can be added to `torch.onnx.export`, which would lead to the following ONNX graph:

```bash
Exported graph: graph(%prim::PythonOp_0 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu),
      %l1.weight : Float(2, 2, strides=[2, 1], requires_grad=1, device=cpu),
      %l1.bias : Float(2, strides=[1], requires_grad=1, device=cpu)):
  %/PythonOp_output_0 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = ^CheckpointFunction[inplace=0, module="torch.utils.checkpoint", onnx_name="/PythonOp"](<function B.forward.<locals>.custom.<locals>.custom_forward at 0x7fdf9182f670>, True)(%prim::PythonOp_0), scope: __main__.B:: # /opt/pytorch/torch/autograd/function.py:506:0
  %6 : Float(2, 2, strides=[2, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/l1/Gemm"](%/PythonOp_output_0, %l1.weight, %l1.bias), scope: __main__.B::/torch.nn.modules.linear.Linear::l1 # /opt/pytorch/torch/nn/modules/linear.py:114:0
  return (%6)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104067
Approved by: https://github.com/BowenBao, https://github.com/kit1980
86 lines
2.9 KiB
Python
86 lines
2.9 KiB
Python
"""Globals used internally by the ONNX exporter.
|
|
|
|
Do not use this module outside of `torch.onnx` and its tests.
|
|
|
|
Be very judicious when adding any new global variables. Do not create new global
|
|
variables unless they are absolutely necessary.
|
|
"""
|
|
import operator

import torch._C._onnx as _C_onnx

# This module should only depend on _constants and nothing else in torch.onnx to keep
# dependency direction clean.
from torch.onnx import _constants
|
|
|
|
|
|
class _InternalGlobals:
    """Globals used internally by ONNX exporter.

    NOTE: Be very judicious when adding any new variables. Do not create new
    global variables unless they are absolutely necessary.

    ``export_training``, ``operator_export_type`` and ``onnx_shape_inference``
    are plain attributes with no validation. ``training_mode``,
    ``export_onnx_opset_version``, ``in_onnx_export`` and ``autograd_inlining``
    are properties whose setters validate the assigned value before storing it.
    """

    def __init__(self) -> None:
        """Initialize every global to its default value."""
        self._export_onnx_opset_version = _constants.ONNX_DEFAULT_OPSET
        self._training_mode: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL
        self._in_onnx_export: bool = False
        # Whether the user's model is training during export
        self.export_training: bool = False
        self.operator_export_type: _C_onnx.OperatorExportTypes = (
            _C_onnx.OperatorExportTypes.ONNX
        )
        self.onnx_shape_inference: bool = True
        self._autograd_inlining: bool = True

    @property
    def training_mode(self) -> _C_onnx.TrainingMode:
        """The training mode for the exporter."""
        return self._training_mode

    @training_mode.setter
    def training_mode(self, training_mode: _C_onnx.TrainingMode) -> None:
        """Set the training mode.

        Raises:
            TypeError: If ``training_mode`` is not a ``_C_onnx.TrainingMode``.
        """
        if not isinstance(training_mode, _C_onnx.TrainingMode):
            raise TypeError(
                "training_mode must be of type 'torch.onnx.TrainingMode'. This is "
                "likely a bug in torch.onnx."
            )
        self._training_mode = training_mode

    @property
    def export_onnx_opset_version(self) -> int:
        """Opset version used during export."""
        return self._export_onnx_opset_version

    @export_onnx_opset_version.setter
    def export_onnx_opset_version(self, value: int) -> None:
        """Set the opset version used during export.

        The value is normalized with ``operator.index`` so that the stored
        version is always a real ``int``.

        Raises:
            ValueError: If ``value`` is not an integer, or is outside the
                inclusive interval from ``_constants.ONNX_MIN_OPSET`` to
                ``_constants.ONNX_MAX_OPSET``.
        """
        try:
            # ``x in range(...)`` falls back to ``==`` for non-int objects, so a
            # bare membership test would let ``11.0`` through and store a float.
            # ``operator.index`` only succeeds for objects that are integers.
            version = operator.index(value)
        except TypeError:
            raise ValueError(f"Unsupported ONNX opset version: {value}") from None

        supported_versions = range(
            _constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1
        )
        if version not in supported_versions:
            raise ValueError(f"Unsupported ONNX opset version: {value}")
        self._export_onnx_opset_version = version

    @property
    def in_onnx_export(self) -> bool:
        """Whether it is in the middle of ONNX export."""
        return self._in_onnx_export

    @in_onnx_export.setter
    def in_onnx_export(self, value: bool) -> None:
        """Set the in-export flag.

        Raises:
            TypeError: If ``value`` is not a ``bool``.
        """
        # ``bool`` cannot be subclassed, so this matches an exact type check.
        if not isinstance(value, bool):
            raise TypeError("in_onnx_export must be a boolean")
        self._in_onnx_export = value

    @property
    def autograd_inlining(self) -> bool:
        """Whether Autograd must be inlined."""
        return self._autograd_inlining

    @autograd_inlining.setter
    def autograd_inlining(self, value: bool) -> None:
        """Set the autograd-inlining flag.

        Raises:
            TypeError: If ``value`` is not a ``bool``.
        """
        # ``bool`` cannot be subclassed, so this matches an exact type check.
        if not isinstance(value, bool):
            raise TypeError("autograd_inlining must be a boolean")
        self._autograd_inlining = value


# Module-level instance holding the exporter's internal state.
GLOBALS = _InternalGlobals()
|