Adds pretty-print diagnostic logging. For example:
```python
import io

import torch
from torch.onnx._internal import diagnostics


class CustomAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        return x + y

    @staticmethod
    def symbolic(g, x, y):
        return g.op("custom::CustomAdd", x, y)


class M(torch.nn.Module):
    def forward(self, x):
        return CustomAdd.apply(x, x)


# Trigger warning for missing shape inference.
# rule = diagnostics.rules.node_missing_onnx_shape_inference
torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO())
```
By default, a minimal summary of the diagnostics is printed:
```
========= Diagnostic Run torch.onnx.export version 1.14.0a0+git90a69c5 =========
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 3 WARNING 0 ERROR ========================
3 WARNING were not printed due to the log level.
```
Adjust the `verbose` and `level` arguments to control the output:
```python
diagnostics.engine.pretty_print(verbose=True, level=diagnostics.levels.WARNING)
```
This prints the full log:
```
=============================== 1 Diagnostic Run ===============================
========= Diagnostic Run torch.onnx.export version 1.14.0a0+git90a69c5 =========
verbose: True, log level: Level.WARNING
======================= 0 NONE 0 NOTE 3 WARNING 0 ERROR ========================
WARNING: node-missing-onnx-shape-inference
==========================================
The shape inference of custom::CustomAdd type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
--------------------------- Stack: Python call stack ---------------------------
frame: diagnostic = ExportDiagnostic(rule, level, message, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/diagnostics/_diagnostic.py:151
frame: n, utils._params_dict, GLOBALS.export_onnx_opset_version /home/bowbao/pytorch_dev/torch/onnx/_patch_torch.py:82
frame: <@beartype(torch.onnx._patch_torch._graph_op) at 0x7f62184b6710>:78
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: return function(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_deprecation.py:30
frame: return g.op("custom::CustomAdd", x, y) test_pretty_print.py:14
frame: return symbolic_fn(g, *args) /home/bowbao/pytorch_dev/torch/onnx/utils.py:1716
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: graph = _C._jit_pass_onnx(graph, operator_export_type) /home/bowbao/pytorch_dev/torch/onnx/utils.py:663
frame: <@beartype(torch.onnx.utils._optimize_graph) at 0x7f62180e05f0>:85
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: module=module, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1123
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: dynamic_axes=dynamic_axes, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1539
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: export_modules_as_functions=export_modules_as_functions, /home/bowbao/pytorch_dev/torch/onnx/utils.py:519
frame: <@beartype(torch.onnx.utils.export) at 0x7f62180e0170>:347
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) test_pretty_print.py:22
---------------------------- Stack: C++ call stack -----------------------------
frame: (<unknown frame>)
frame: (<unknown function> + 0x88411b (0x7f625b36011b in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (torch::jit::UpdateReliable(torch::jit::Value*, std::pair<bool, bool> const&) + 0x7d3 (0x7f625b351743 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (torch::jit::UpdateReliable(torch::jit::Node*) + 0x4f (0x7f625b35198f in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (torch::jit::ONNXShapeTypeInference(torch::jit::Node*, std::map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10::IValue, std::less<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, c10::IValue> > > const&, int) + 0xac9 (0x7f625b357179 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (<unknown function> + 0xabd026 (0x7f625b599026 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (<unknown function> + 0x3c0fda (0x7f625ae9cfda in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (<unknown frame>)
WARNING: node-missing-onnx-shape-inference
==========================================
The shape inference of custom::CustomAdd type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
--------------------------- Stack: Python call stack ---------------------------
frame: diagnostic = ExportDiagnostic(rule, level, message, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/diagnostics/_diagnostic.py:151
frame: graph, params_dict, GLOBALS.export_onnx_opset_version /home/bowbao/pytorch_dev/torch/onnx/utils.py:688
frame: <@beartype(torch.onnx.utils._optimize_graph) at 0x7f62180e05f0>:85
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: module=module, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1123
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: dynamic_axes=dynamic_axes, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1539
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: export_modules_as_functions=export_modules_as_functions, /home/bowbao/pytorch_dev/torch/onnx/utils.py:519
frame: <@beartype(torch.onnx.utils.export) at 0x7f62180e0170>:347
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) test_pretty_print.py:22
---------------------------- Stack: C++ call stack -----------------------------
frame: (<unknown frame>)
frame: (<unknown function> + 0x88411b (0x7f625b36011b in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (torch::jit::UpdateReliable(torch::jit::Value*, std::pair<bool, bool> const&) + 0x7d3 (0x7f625b351743 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (torch::jit::UpdateReliable(torch::jit::Node*) + 0x4f (0x7f625b35198f in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (torch::jit::ONNXShapeTypeInference(torch::jit::Node*, std::map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10::IValue, std::less<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, c10::IValue> > > const&, int) + 0xac9 (0x7f625b357179 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (<unknown function> + 0x87d6d1 (0x7f625b3596d1 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (torch::jit::ONNXShapeTypeInference(std::shared_ptr<torch::jit::Graph>&, std::map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10::IValue, std::less<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, c10::IValue> > > const&, int) + 0x33 (0x7f625b359cf3 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (<unknown function> + 0xabdbae (0x7f625b599bae in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (<unknown function> + 0x3c0fda (0x7f625ae9cfda in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (<unknown frame>)
WARNING: node-missing-onnx-shape-inference
==========================================
The shape inference of custom::CustomAdd type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.
--------------------------- Stack: Python call stack ---------------------------
frame: diagnostic = ExportDiagnostic(rule, level, message, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/diagnostics/_diagnostic.py:151
frame: graph, params_dict, GLOBALS.export_onnx_opset_version /home/bowbao/pytorch_dev/torch/onnx/utils.py:1179
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: dynamic_axes=dynamic_axes, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1539
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: export_modules_as_functions=export_modules_as_functions, /home/bowbao/pytorch_dev/torch/onnx/utils.py:519
frame: <@beartype(torch.onnx.utils.export) at 0x7f62180e0170>:347
frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81
frame: torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) test_pretty_print.py:22
---------------------------- Stack: C++ call stack -----------------------------
frame: (<unknown frame>)
frame: (<unknown function> + 0x88411b (0x7f625b36011b in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (torch::jit::UpdateReliable(torch::jit::Value*, std::pair<bool, bool> const&) + 0x7d3 (0x7f625b351743 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (torch::jit::UpdateReliable(torch::jit::Node*) + 0x4f (0x7f625b35198f in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (torch::jit::ONNXShapeTypeInference(torch::jit::Node*, std::map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10::IValue, std::less<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, c10::IValue> > > const&, int) + 0xac9 (0x7f625b357179 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (<unknown function> + 0x87d6d1 (0x7f625b3596d1 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (torch::jit::ONNXShapeTypeInference(std::shared_ptr<torch::jit::Graph>&, std::map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, c10::IValue, std::less<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, c10::IValue> > > const&, int) + 0x33 (0x7f625b359cf3 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (<unknown function> + 0xabdbae (0x7f625b599bae in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (<unknown function> + 0x3c0fda (0x7f625ae9cfda in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so))
frame: (<unknown frame>)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88261
Approved by: https://github.com/abock, https://github.com/justinchuby
# torch.onnx

Torch->ONNX converter / exporter.
- User-facing docs: https://pytorch.org/docs/master/onnx.html
- Developer docs: https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter
Read the following if you are contributing to `torch.onnx`.

## Symbolic functions

### Opsets
Opset 9 is the base version. It was selected as the base version because:

- It is the first opset version supported by PyTorch export.
- Opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations in that certain basic operators cannot be expressed in ONNX. Rather than building the base export around these limitations, we chose to handle them as special cases separately.

Backward support for opset versions earlier than opset 7 is not on our roadmap.
Opset versions other than 9 inherit, by default, the symbolic functions defined in `symbolic_opset9.py`.

To extend support for updated operators in different opset versions on top of opset 9, simply add the updated symbolic functions in the respective `symbolic_opset{version}.py` file. Check out `topk` in `symbolic_opset10.py` and `upsample_nearest2d` in `symbolic_opset8.py` for examples.
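As an illustration, here is a hedged sketch of overriding an operator for a newer opset. The `functools.partial` decorator pattern mirrors the existing `symbolic_opset{version}.py` files, but the operator itself is a placeholder:

```python
# Illustrative sketch (placeholder op; pattern as in symbolic_opset10.py):
# exporting with opset_version >= 10 picks up this definition, while
# exporting with opset 9 falls back to the one in symbolic_opset9.py.
import functools

from torch.onnx._internal import registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)


@_onnx_symbolic("aten::some_op")  # "aten::some_op" is a placeholder name
def some_op(g, self):
    # Emit the opset-10 flavor of the translation here.
    return g.op("SomeOp", self)
```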
### Editing Symbolic Files
- Use the internal `registration.onnx_symbolic` decorator to register a new symbolic function. Search for `def reshape(g, self, shape):` to see an example (a combined sketch follows this list).
- Parameter names must exactly match the names in aten/src/ATen/native/native_functions.yaml, because dispatch is done with keyword arguments.
- Looking for inplace ops? They're detected by `_jit_pass_onnx_remove_inplace_ops_for_onnx`, and transparently dispatched to their non-inplace versions in `run_symbolic_function`. See Note [Export inplace].
- Required: Annotate new symbolic functions with type annotations and decorate with `@_beartype.beartype` to enable runtime type checking. `@_beartype.beartype` should typically be the decorator closest to the function to ensure proper type checking.
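Putting those rules together, a hedged sketch of registering a typed symbolic function; the decorator ordering and import paths follow the existing symbolic files, while the op itself is hypothetical:

```python
# Illustrative sketch, not actual library code: a new symbolic function
# registered for opset 9 with runtime type checking enabled.
from torch import _C
from torch.onnx._internal import _beartype, registration


@registration.onnx_symbolic("aten::my_op", opset=9)  # "aten::my_op" is hypothetical
@_beartype.beartype  # keep this closest to the function for proper type checking
def my_op(g, self: _C.Value) -> _C.Value:
    # The parameter name ("self" here) must match native_functions.yaml,
    # because dispatch is done with keyword arguments.
    return g.op("Identity", self)
```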
### A note on Tensor types
In general, we should avoid depending on the type of Tensor Values contained within the trace graph. However, this is sometimes unavoidable (due to ONNX spec requirements, etc). The TensorType object has accessors for these properties that return the property if it is statically known and return nullopt otherwise.
In general, we should prefer to rely on the least specific information possible. For example, not relying on tensor properties at all is better than relying on the number of dimensions which is better than relying on concrete shapes. Doing so will make the export symbolics more robust to different graphs.
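For instance, a hedged sketch of this principle on the Python side, using the helper `symbolic_helper._get_tensor_rank` (which returns `None` when the rank is not statically known); the emitted ops are placeholders:

```python
# Illustrative sketch: query the least specific property needed and
# handle the statically-unknown case instead of assuming a shape.
from torch.onnx import symbolic_helper


def my_symbolic(g, self):  # hypothetical symbolic function
    rank = symbolic_helper._get_tensor_rank(self)  # Optional[int]
    if rank is None:
        # Rank unknown at export time: emit a graph that works for any rank.
        return g.op("Identity", self)
    # Rank known: it is safe to branch on it, but avoid relying on concrete
    # shapes unless the ONNX spec leaves no other choice.
    return g.op("Flatten", self, axis_i=rank - 1)
```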
### Extra context for symbolic functions
The first argument of a symbolic function is always a `GraphContext` object. `GraphContext` exposes all methods defined on a `torch.Graph` object, plus extra context for the symbolic function.

In general, symbolic functions only require the inputs and attributes of the original node. An example of a symbolic function needing context is `prim::Loop`: it needs access to the sub-block of the original node.
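As a hedged illustration (the `opset` field name is taken from `torch.onnx._internal.jit_utils.GraphContext` at the time of writing; the op names are placeholders):

```python
# Illustrative sketch: a symbolic function that consults the extra
# context carried by GraphContext in addition to emitting ops.
def my_context_aware_symbolic(g, self):  # hypothetical
    # Besides g.op(...), the context exposes e.g. the target opset
    # version, so one function can emit different ONNX depending on
    # the export configuration.
    if g.opset >= 13:
        return g.op("SomeNewOp", self)  # placeholder op name
    return g.op("SomeOldOp", self)      # placeholder op name
```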
### Export inplace
It would be better for us to export inplace annotations than to not export them, since they are useful information that can help the consumer of an ONNX export run more efficiently. However, ONNX doesn't currently formalize inplace. Fortunately, it is sound to drop inplace annotations, although we lose information this way.
### Pointwise by scalar
What happens if you add a tensor with a constant (e.g., `x + 2`)? There are some moving parts to implementing the ONNX translation in this case:

- By the time we get the scalar in a symbolic function here, it is no longer a Python long/float, but a PyTorch tensor with `numel == 1` (eventually, we want it to be a zero-dim tensor, but this change has not happened yet). However, the type of this scalar is exactly what the user wrote in Python, which may not match the tensor it is being added to. PyTorch will do implicit conversions on scalars; however, ONNX will not, so we must do the conversion ourselves. This is what `symbolic_helper._if_scalar_type_as()` and `_jit_pass_onnx_scalar_type_analysis` do.
- Dispatch to these functions takes advantage of an outrageous coincidence between the tensor and scalar name. When we add two tensors together, you get the dispatch:

  ```python
  add(*[self, other], **{"alpha": alpha})
  ```

  When you add a tensor and a scalar, you get the dispatch:

  ```python
  add(*[self], **{"other": other, "alpha": alpha})
  ```

  By having the argument name line up with the name of the scalar attribute, if it exists, we can write a single function for both overloads. A sketch follows this list.
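To make this concrete, a simplified, hedged sketch of an `add` symbolic in the spirit of the opset 9 implementation (error handling and quantization cases omitted):

```python
# Simplified sketch of an "add" symbolic handling both overloads.
# Whether `other` was a tensor or a scalar, it arrives bound to the
# parameter named "other", so one function covers both dispatches.
from torch.onnx import symbolic_helper


def add(g, self, other, alpha=None):
    # Fold a non-trivial alpha into `other` before the addition.
    if alpha and symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha)) != 1:
        other = g.op("Mul", other, alpha)
    return g.op("Add", self, other)
```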