torch.onnx
============

.. contents:: :local:

.. automodule:: torch.onnx

Example: End-to-end AlexNet from PyTorch to ONNX
------------------------------------------------

Here is a simple script which exports a pretrained AlexNet as defined in
torchvision into ONNX. It runs a single round of inference and then
saves the resulting traced model to ``alexnet.onnx``::

    import torch
    import torchvision

    dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
    model = torchvision.models.alexnet(pretrained=True).cuda()

    # Providing input and output names sets the display names for values
    # within the model's graph. Setting these does not change the semantics
    # of the graph; it is only for readability.
    #
    # The inputs to the network consist of the flat list of inputs (i.e.
    # the values you would pass to the forward() method) followed by the
    # flat list of parameters. You can partially specify names, i.e. provide
    # a list here shorter than the number of inputs to the model, and we will
    # only set that subset of names, starting from the beginning.
    input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
    output_names = [ "output1" ]

    torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

The resulting ``alexnet.onnx`` is a binary protobuf file which contains both
the network structure and parameters of the model you exported
(in this case, AlexNet). The keyword argument ``verbose=True`` causes the
exporter to print out a human-readable representation of the network::

    # These are the inputs and parameters to the network, which have taken on
    # the names we specified earlier.
    graph(%actual_input_1 : Float(10, 3, 224, 224)
          %learned_0 : Float(64, 3, 11, 11)
          %learned_1 : Float(64)
          %learned_2 : Float(192, 64, 5, 5)
          %learned_3 : Float(192)
          # ---- omitted for brevity ----
          %learned_14 : Float(1000, 4096)
          %learned_15 : Float(1000)) {
      # Every statement consists of some output tensors (and their types),
      # the operator to be run (with its attributes, e.g., kernels, strides,
      # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
      %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
      %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
      %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
      # ---- omitted for brevity ----
      %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
      # Dynamic means that the shape is not known. This may be because of a
      # limitation of our implementation (which we would like to fix in a
      # future release) or shapes which are truly dynamic.
      %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
      %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
      %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
      %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
      # ---- omitted for brevity ----
      %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
      return (%output1);
    }

You can also verify the protobuf using the `ONNX <https://github.com/onnx/onnx/>`_ library.
You can install ``ONNX`` with conda::

    conda install -c conda-forge onnx

Then, you can run::

    import onnx

    # Load the ONNX model
    model = onnx.load("alexnet.onnx")

    # Check that the IR is well formed
    onnx.checker.check_model(model)

    # Print a human readable representation of the graph
    onnx.helper.printable_graph(model.graph)

To run the exported model with `Caffe2 <https://caffe2.ai/>`_, you will need
to install Caffe2. If you don't have it installed already, please `follow the
install instructions <https://caffe2.ai/docs/getting-started.html>`_.

Once these are installed, you can use the backend for Caffe2::

    # ...continuing from above
    import caffe2.python.onnx.backend as backend
    import numpy as np

    rep = backend.prepare(model, device="CUDA:0") # or "CPU"
    # For the Caffe2 backend:
    #     rep.predict_net is the Caffe2 protobuf for the network
    #     rep.workspace is the Caffe2 workspace for the network
    #     (see the class caffe2.python.onnx.backend.Workspace)
    outputs = rep.run(np.random.randn(10, 3, 224, 224).astype(np.float32))
    # To run networks with more than one input, pass a tuple
    # rather than a single numpy ndarray.
    print(outputs[0])

You can also run the exported model with `ONNX Runtime <https://github.com/microsoft/onnxruntime>`_.
You will need to install ONNX Runtime first: please `follow these instructions <https://github.com/microsoft/onnxruntime#installation>`_.

Once these are installed, you can use the backend for ONNX Runtime::

    # ...continuing from above
    import onnxruntime as ort

    ort_session = ort.InferenceSession('alexnet.onnx')

    outputs = ort_session.run(None, {'actual_input_1': np.random.randn(10, 3, 224, 224).astype(np.float32)})

    print(outputs[0])

Here is another `tutorial on exporting the SuperResolution model to ONNX <https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html>`_.

In the future, there will be backends for other frameworks as well.

Tracing vs Scripting
--------------------

The ONNX exporter can be both a *trace-based* and a *script-based* exporter.

* *trace-based* means that it operates by executing your model once, and
  exporting the operators which were actually run during this run. This means
  that if your model is dynamic, e.g., changes behavior depending on input
  data, the export won't be accurate. Similarly, a trace is likely to be valid
  only for a specific input size (which is one reason why we require explicit
  inputs on tracing). We recommend examining the model trace and making sure
  the traced operators look reasonable. If your model contains control flow
  like for loops and if conditions, the *trace-based* exporter will unroll the
  loops and if conditions, exporting a static graph that is exactly the same
  as this run. If you want to export your model with dynamic control flow, you
  will need to use the *script-based* exporter.

* *script-based* means that the model you are trying to export is a
  `ScriptModule <jit.html>`_. ``ScriptModule`` is the core data structure in
  TorchScript, and TorchScript is a subset of the Python language that creates
  serializable and optimizable models from PyTorch code.

We allow mixing tracing and scripting. You can compose tracing and scripting
to suit the particular requirements of a part of a model. Check out this
example::

    import torch

    # Trace-based only

    class LoopModel(torch.nn.Module):
        def forward(self, x, y):
            for i in range(y):
                x = x + i
            return x

    model = LoopModel()
    dummy_input = torch.ones(2, 3, dtype=torch.long)
    loop_count = torch.tensor(5, dtype=torch.long)

    torch.onnx.export(model, (dummy_input, loop_count), 'loop.onnx', verbose=True)

With the *trace-based* exporter, we get an ONNX graph which unrolls the for
loop::

    graph(%0 : Long(2, 3),
          %1 : Long()):
      %2 : Tensor = onnx::Constant[value={1}]()
      %3 : Tensor = onnx::Add(%0, %2)
      %4 : Tensor = onnx::Constant[value={2}]()
      %5 : Tensor = onnx::Add(%3, %4)
      %6 : Tensor = onnx::Constant[value={3}]()
      %7 : Tensor = onnx::Add(%5, %6)
      %8 : Tensor = onnx::Constant[value={4}]()
      %9 : Tensor = onnx::Add(%7, %8)
      return (%9)

To capture the dynamic loop with the *script-based* exporter, we can write the
loop in script and call it from a regular ``nn.Module``::

    # Mixing tracing and scripting

    @torch.jit.script
    def loop(x, y):
        for i in range(int(y)):
            x = x + i
        return x

    class LoopModel2(torch.nn.Module):
        def forward(self, x, y):
            return loop(x, y)

    model = LoopModel2()
    dummy_input = torch.ones(2, 3, dtype=torch.long)
    loop_count = torch.tensor(5, dtype=torch.long)
    torch.onnx.export(model, (dummy_input, loop_count), 'loop.onnx', verbose=True,
                      input_names=['input_data', 'loop_range'])

Now the exported ONNX graph becomes::

    graph(%input_data : Long(2, 3),
          %loop_range : Long()):
      %2 : Long() = onnx::Constant[value={1}](), scope: LoopModel2/loop
      %3 : Tensor = onnx::Cast[to=9](%2)
      %4 : Long(2, 3) = onnx::Loop(%loop_range, %3, %input_data), scope: LoopModel2/loop
        block0(%i.1 : Long(), %cond : bool, %x.6 : Long(2, 3)):
          %8 : Long(2, 3) = onnx::Add(%x.6, %i.1), scope: LoopModel2/loop
          %9 : Tensor = onnx::Cast[to=9](%2)
          -> (%9, %8)
      return (%4)

The dynamic control flow is captured correctly, and we can verify it in
backends with a different loop range::

    import caffe2.python.onnx.backend as backend
    import numpy as np
    import onnx
    model = onnx.load('loop.onnx')

    rep = backend.prepare(model)
    outputs = rep.run((dummy_input.numpy(), np.array(9).astype(np.int64)))
    print(outputs[0])
    # [[37 37 37]
    #  [37 37 37]]

    import onnxruntime as ort
    ort_sess = ort.InferenceSession('loop.onnx')
    outputs = ort_sess.run(None, {'input_data': dummy_input.numpy(),
                                  'loop_range': np.array(9).astype(np.int64)})
    print(outputs)
    # [array([[37, 37, 37],
    #         [37, 37, 37]], dtype=int64)]

To avoid exporting a variable scalar tensor as a fixed-value constant in the
ONNX model, avoid the use of ``torch.Tensor.item()``. PyTorch supports
implicit casting of single-element tensors to numbers. E.g.::

    class LoopModel(torch.nn.Module):
        def forward(self, x, y):
            res = []
            arr = x.split(2, 0)
            for i in range(int(y)):
                res += [arr[i].sum(0, False)]
            return torch.stack(res)

    model = torch.jit.script(LoopModel())
    inputs = (torch.randn(16), torch.tensor(8))

    out = model(*inputs)
    torch.onnx.export(model, inputs, 'loop_and_list.onnx', opset_version=11, example_outputs=out)

Type Annotations
--------------------------------

TorchScript only supports a subset of Python types. You can find more details
about type annotations `here <https://pytorch.org/docs/stable/jit_language_reference.html#id8>`_.

For optimization purposes, TorchScript requires each variable in a script
function to have a single static type. By default, each variable is assumed
to be a Tensor. If an argument to a ScriptModule function is not a Tensor,
its type should be specified using MyPy-style annotations::

    import torch

    class Module(torch.nn.Module):
        def forward(self, x, tup):
            # type: (int, Tuple[Tensor, Tensor]) -> Tensor
            t0, t1 = tup
            return t0 + t1 + x

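With the annotation in place, the module compiles and runs under TorchScript.
A minimal usage sketch (the argument values here are purely illustrative)::

    scripted = torch.jit.script(Module())
    # An int argument plus a tuple of two tensors, per the annotation above.
    out = scripted(3, (torch.ones(2), torch.zeros(2)))  # tensor([4., 4.])
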
If the type annotation is not specified, the TorchScript compiler fails with
the runtime error below::

    RuntimeError:
    Tensor (inferred) cannot be used as a tuple:
    File <filename>
        def forward(self, x, tup):
            t0, t1 = tup
                     ~~~ <--- HERE
            return t0 + t1 + x

Write PyTorch model in Torch way
--------------------------------

Avoid using numpy
~~~~~~~~~~~~~~~~~

PyTorch models can be written using NumPy manipulations, but this does not
carry over when converting to ONNX. With the trace-based exporter, tracing
treats NumPy values as constant nodes, so the exported model computes the
wrong result if the input changes. The PyTorch model therefore needs to be
implemented using torch operators. For example, do not use NumPy operators on
NumPy arrays::

    np.concatenate((x, y, z), axis=1)

do not convert to NumPy types::

    y = x.astype(np.int)

Always use torch tensors and torch operators: ``torch.cat``, etc.

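For example, the torch equivalents of the two snippets above, which are traced
correctly::

    out = torch.cat((x, y, z), dim=1)
    y = x.to(torch.int64)
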
In addition, the Dropout layer needs to be defined in the ``__init__``
function so that inference can handle it properly, i.e.::

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.dropout = nn.Dropout(0.5)

        def forward(self, x):
            x = self.dropout(x)
            return x

Avoid using .data field
~~~~~~~~~~~~~~~~~~~~~~~

The ``.data`` field is an old field that is kept for backward compatibility
and should be avoided when writing models. Its usage is dangerous and can make
computations wrong; furthermore, it can produce an incorrect trace graph and
therefore an incorrect ONNX graph. A safer alternative is to use
``.detach()`` instead.

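A minimal sketch of the difference::

    w = torch.tensor([1., 2., 3.], requires_grad=True)

    # Dangerous: ``w.data`` bypasses autograd, so in-place updates through it
    # are invisible to autograd and to the tracer.
    # w.data.mul_(2)

    # Safer: ``w.detach()`` returns a tensor that shares storage with ``w``
    # but is explicitly excluded from gradient tracking.
    w_detached = w.detach()
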
Using dictionaries to handle Named Arguments as model inputs
------------------------------------------------------------

There are two ways to handle models which consist of named parameters or
keyword arguments as inputs:

* The first method is to pass all the inputs in the same order as required by
  the model, and pass None values for the keyword arguments that do not
  require a value to be passed.

* The second and more intuitive method is to represent the keyword arguments
  as key-value pairs, where the key represents the name of the argument in
  the model signature and the value represents the value of the argument to
  be passed.

For example, in the model::

    class Model(torch.nn.Module):
        def forward(self, x, y=None, z=None):
            if y is not None:
                return x + y
            if z is not None:
                return x + z
            return x
    m = Model()
    x = torch.randn(2, 3)
    z = torch.randn(2, 3)

There are two ways of exporting the model:

* Not using a dictionary for the keyword arguments and passing all the inputs
  in the same order as required by the model ::

    torch.onnx.export(m, (x, None, z), 'test.onnx')

* Using a dictionary to represent the keyword arguments. This dictionary is
  always passed in addition to the non-keyword arguments and is always the
  last argument in the args tuple. ::

    torch.onnx.export(m, (x, {'y': None, 'z': z}), 'test.onnx')

For cases in which there are no keyword arguments, models can be exported
with either an empty or no dictionary. For example, ::

    torch.onnx.export(m, (x, {}), 'test.onnx')

or ::

    torch.onnx.export(m, (x, ), 'test.onnx')

An exception to this rule are cases in which the last input is also of a
dictionary type. In these cases it is mandatory to have an empty dictionary
as the last argument in the args tuple. For example, ::

    class Model(torch.nn.Module):
        def forward(self, k, x):
            ...
            return x
    m = Model()
    k = torch.randn(2, 3)
    x = {torch.tensor(1.): torch.randn(2, 3)}

Without the presence of the empty dictionary, the export call assumes that
the ``x`` input is intended to represent the optional dictionary of named
arguments. To prevent this from being an issue, a constraint is placed
requiring an empty dictionary as the last input in the args tuple in such
cases. The new call would look like this::

    torch.onnx.export(m, (k, x, {}), 'test.onnx')

Indexing
--------

Tensor indexing in PyTorch is very flexible and complicated. There are two
categories of indexing, and both are largely supported in exporting today.
If you are experiencing issues exporting indexing that belongs to the
supported patterns below, please double check that you are exporting with the
latest opset (opset_version=12).

Getter
~~~~~~

This type of indexing occurs on the RHS. Export is supported for ONNX opset
version >= 9. E.g.::

    data = torch.randn(3, 4)
    index = torch.tensor([1, 2])

    # RHS indexing is supported in ONNX opset >= 9.
    class RHSIndexing(torch.nn.Module):
        def forward(self, data, index):
            return data[index]

    out = RHSIndexing()(data, index)

    torch.onnx.export(RHSIndexing(), (data, index), 'indexing.onnx', opset_version=9)

    # onnxruntime
    import onnxruntime
    sess = onnxruntime.InferenceSession('indexing.onnx')
    out_ort = sess.run(None, {
        sess.get_inputs()[0].name: data.numpy(),
        sess.get_inputs()[1].name: index.numpy(),
    })

    assert torch.all(torch.eq(out, torch.tensor(out_ort)))

Below is the list of supported patterns for RHS indexing. ::

    # Scalar indices
    data[0, 1]

    # Slice indices
    data[:3]

    # Tensor indices
    data[torch.tensor([[1, 2], [2, 3]])]
    data[torch.tensor([2, 3]), torch.tensor([1, 2])]
    data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])]
    data[torch.tensor([2, 3]), :, torch.tensor([1, 2])]

    # Ellipsis followed by tensor indexing
    # Not supported in scripting
    # i.e. torch.jit.script(model) will fail if model contains this pattern.
    # Export is supported under tracing
    # i.e. torch.onnx.export(model)
    data[..., torch.tensor([2, 1])]

    # The combination of above
    data[2, ..., torch.tensor([2, 1, 3]), 2:4, torch.tensor([[1], [2]])]

    # Boolean mask (supported for ONNX opset version >= 11)
    data[data != 1]

And below is the list of unsupported patterns for RHS indexing. ::

    # Tensor indices that include negative values.
    data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]

Setter
~~~~~~

In code, this type of indexing occurs on the LHS. Export is supported for
ONNX opset version >= 11. E.g.::

    data = torch.zeros(3, 4)
    new_data = torch.arange(4).to(torch.float32)

    # LHS indexing is supported in ONNX opset >= 11.
    class LHSIndexing(torch.nn.Module):
        def forward(self, data, new_data):
            data[1] = new_data
            return data

    out = LHSIndexing()(data, new_data)

    data = torch.zeros(3, 4)
    new_data = torch.arange(4).to(torch.float32)
    torch.onnx.export(LHSIndexing(), (data, new_data), 'inplace_assign.onnx', opset_version=11)

    # onnxruntime
    import onnxruntime
    sess = onnxruntime.InferenceSession('inplace_assign.onnx')
    out_ort = sess.run(None, {
        sess.get_inputs()[0].name: torch.zeros(3, 4).numpy(),
        sess.get_inputs()[1].name: new_data.numpy(),
    })

    assert torch.all(torch.eq(out, torch.tensor(out_ort)))

Below is the list of supported patterns for LHS indexing. ::

    # Scalar indices
    data[0, 1] = new_data

    # Slice indices
    data[:3] = new_data

    # Tensor indices
    # If more than one tensor is used as an index, only consecutive 1-d
    # tensor indices are supported.
    data[torch.tensor([[1, 2], [2, 3]])] = new_data
    data[torch.tensor([2, 3]), torch.tensor([1, 2])] = new_data

    # Ellipsis followed by tensor indexing
    # Not supported to export in script modules
    # i.e. torch.onnx.export(torch.jit.script(model)) will fail if model contains this pattern.
    # Export is supported under tracing
    # i.e. torch.onnx.export(model)
    data[..., torch.tensor([2, 1])] = new_data

    # The combination of above
    data[2, ..., torch.tensor([2, 1, 3]), 2:4] += update

    # Boolean mask
    data[data != 1] = new_data

And below is the list of unsupported patterns for LHS indexing. ::

    # Multiple tensor indices if any has rank >= 2
    data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data

    # Multiple tensor indices that are not consecutive
    data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data

    # Tensor indices that include negative values.
    data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data

If you are experiencing issues exporting indexing that belongs to the above
supported patterns, please double check that you are exporting with the
latest opset (opset_version=12).

TorchVision support
-------------------

All TorchVision models, except for quantized versions, are exportable to
ONNX. More details can be found in `TorchVision <torchvision/models.html>`_.

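A minimal export sketch along the lines of the AlexNet example above (the
choice of model here is illustrative)::

    import torch
    import torchvision

    model = torchvision.models.resnet18(pretrained=True).eval()
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(model, dummy_input, 'resnet18.onnx', opset_version=11)
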
Limitations
-----------

* Only tuples, lists and Variables are supported as JIT inputs/outputs.
  Dictionaries and strings are also accepted, but their usage is not
  recommended. Users need to verify their dict inputs carefully, and keep in
  mind that dynamic lookups are not available.

* PyTorch and ONNX backends (Caffe2, ONNX Runtime, etc.) often have
  implementations of operators with some numeric differences. Depending on
  model structure, these differences may be negligible, but they can also
  cause major divergences in behavior (especially on untrained models). We
  allow Caffe2 to call directly to Torch implementations of operators, to
  help you smooth over these differences when precision is important, and to
  also document these differences.

Supported operators
-------------------

The following operators are supported:

* BatchNorm
* ConstantPadNd
* Conv
* Dropout
* Embedding (no optional arguments supported)
* EmbeddingBag
* FeatureDropout (training mode not supported)
* Index
* MaxPool1d
* MaxPool2d
* MaxPool3d
* RNN
* abs
* absolute
* acos
* adaptive_avg_pool1d
* adaptive_avg_pool2d
* adaptive_avg_pool3d
* adaptive_max_pool1d
* adaptive_max_pool2d
* adaptive_max_pool3d
* add (nonzero alpha not supported)
* addmm
* and
* arange
* argmax
* argmin
* asin
* atan
* avg_pool1d
* avg_pool2d
* avg_pool3d
* as_strided
* baddbmm
* bitshift
* cat
* ceil
* celu
* clamp
* clamp_max
* clamp_min
* concat
* copy
* cos
* cumsum
* det
* dim_arange
* div
* dropout
* einsum
* elu
* empty
* empty_like
* eq
* erf
* exp
* expand
* expand_as
* eye
* flatten
* floor
* floor_divide
* frobenius_norm
* full
* full_like
* gather
* ge
* gelu
* glu
* group_norm
* gt
* hardsigmoid
* hardswish
* hardtanh
* im2col
* index_copy
* index_fill
* index_put
* index_select
* instance_norm
* interpolate
* isnan
* KLDivLoss
* layer_norm
* le
* leaky_relu
* len
* log
* log1p
* log2
* log_sigmoid
* log_softmax
* logdet
* logsumexp
* lt
* masked_fill
* masked_scatter
* masked_select
* max
* mean
* min
* mm
* mul
* multinomial
* narrow
* ne
* neg
* new_empty
* new_full
* new_zeros
* nll_loss
* nonzero
* norm
* ones
* ones_like
* or
* permute
* pixel_shuffle
* pow
* prelu (single weight shared among input channels not supported)
* prod
* rand
* randn
* randn_like
* reciprocal
* reflection_pad
* relu
* repeat
* replication_pad
* reshape
* reshape_as
* round
* rrelu
* rsqrt
* rsub
* scalar_tensor
* scatter
* scatter_add
* select
* selu
* sigmoid
* sign
* sin
* size
* slice
* softmax
* softplus
* sort
* split
* sqrt
* squeeze
* stack
* std
* sub (nonzero alpha not supported)
* sum
* t
* tan
* tanh
* threshold (non-zero threshold/non-zero value not supported)
* to
* topk
* transpose
* true_divide
* type_as
* unbind
* unfold (experimental support with ATen-Caffe2 integration)
* unique
* unsqueeze
* upsample_nearest1d
* upsample_nearest2d
* upsample_nearest3d
* view
* weight_norm
* where
* zeros
* zeros_like

The operator set above is sufficient to export the following models:

* AlexNet
* DCGAN
* DenseNet
* Inception (warning: this model is highly sensitive to changes in operator
  implementation)
* ResNet
* SuperResolution
* VGG
* `word_language_model <https://github.com/pytorch/examples/tree/master/word_language_model>`_

Adding support for operators
----------------------------

Adding export support for operators is an *advanced usage*.

To achieve this, developers need to touch the source code of PyTorch.
Please follow the `instructions <https://github.com/pytorch/pytorch#from-source>`_
for installing PyTorch from source.
If the wanted operator is standardized in ONNX, it should be easy to add
support for exporting such an operator (adding a symbolic function for the
operator). To confirm whether the operator is standardized or not, please
check the `ONNX operator list <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_.

ATen operators
~~~~~~~~~~~~~~

If the operator is an ATen operator, which means you can find the declaration
of the function in ``torch/csrc/autograd/generated/VariableType.h``
(available in generated code in the PyTorch install dir), you should add the
symbolic function in ``torch/onnx/symbolic_opset<version>.py`` and follow the
instructions listed below:

* Define the symbolic function in ``torch/onnx/symbolic_opset<version>.py``,
  for example `torch/onnx/symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py>`_.
  Make sure the function has the same name as the ATen operator/function
  defined in ``VariableType.h``.
* The first parameter is always the exported ONNX graph.
  Parameter names must EXACTLY match the names in ``VariableType.h``,
  because dispatch is done with keyword arguments.
* Parameter ordering does NOT necessarily match what is in
  ``VariableType.h``: tensors (inputs) are always first, then non-tensor
  arguments.
* In the symbolic function, if the operator is already standardized in ONNX,
  we only need to create a node to represent the ONNX operator in the graph.
* If the input argument is a tensor, but ONNX asks for a scalar, we have to
  explicitly do the conversion. The helper function ``_scalar`` can convert a
  scalar tensor into a Python scalar, and ``_if_scalar_type_as`` can turn a
  Python scalar into a PyTorch tensor. See the sketch after this list.

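A minimal sketch of the last point, assuming a hypothetical ATen operator
``my_op(input, alpha)`` whose standardized ONNX counterpart ``MyOp`` takes
``alpha`` as a float attribute::

    from torch.onnx.symbolic_helper import _scalar

    def my_op(g, input, alpha):
        # ``alpha`` arrives as a scalar tensor; ``_scalar`` converts it to a
        # Python number so it can be attached as the ``alpha_f`` attribute.
        return g.op("MyOp", input, alpha_f=_scalar(alpha))
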
Non-ATen operators
~~~~~~~~~~~~~~~~~~

If the operator is a non-ATen operator, the symbolic function has to be added
in the corresponding PyTorch Function class (the `Autograd Function`_ section
below shows a concrete example). Please read the following instructions:

* Create a symbolic function named ``symbolic`` in the corresponding Function
  class.
* The first parameter is always the exported ONNX graph.
* Parameter names except the first must EXACTLY match the names in
  ``forward``.
* The output tuple size must match the outputs of ``forward``.
* In the symbolic function, if the operator is already standardized in ONNX,
  we just need to create a node to represent the ONNX operator in the graph.

Symbolic functions should be implemented in Python. All of these functions
interact with Python methods which are implemented via C++-Python bindings,
but intuitively the interface they provide looks like this::

    def operator/symbolic(g, *inputs):
        """
        Modifies Graph (e.g., using "op"), adding the ONNX operations representing
        this PyTorch function, and returning a Value or tuple of Values specifying the
        ONNX outputs whose values correspond to the original PyTorch return values
        of the autograd Function (or None if an output is not supported by ONNX).

        Args:
            g (Graph): graph to write the ONNX representation into
            inputs (Value...): list of values representing the variables which contain
                the inputs for this function
        """

    class Value(object):
        """Represents an intermediate tensor value computed in ONNX."""
        def type(self):
            """Returns the Type of the value."""

    class Type(object):
        def sizes(self):
            """Returns a tuple of ints representing the shape of a tensor this describes."""

    class Graph(object):
        def op(self, opname, *inputs, **attrs):
            """
            Create an ONNX operator 'opname', taking 'inputs' as inputs and
            'attrs' as attributes, and add it as a node to the current graph,
            returning the value representing the single output of this
            operator (see the `outputs` keyword argument for multi-return
            nodes).

            The set of operators and the inputs/attributes they take
            is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

            Args:
                opname (string): The ONNX operator name, e.g., `Abs` or `Add`.
                inputs (Value...): The inputs to the operator; usually provided
                    as arguments to the `symbolic` definition.
                attrs: The attributes of the ONNX operator, with keys named
                    according to the following convention: `alpha_f` indicates
                    the `alpha` attribute with type `f`. The valid type specifiers are
                    `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute
                    specified with type float accepts either a single float, or a
                    list of floats (e.g., you would say `dims_i` for a `dims` attribute
                    that takes a list of integers).
                outputs (int, optional): The number of outputs this operator returns;
                    by default an operator is assumed to return a single output.
                    If `outputs` is greater than one, this function returns a tuple
                    of output `Value`, representing each output of the ONNX operator
                    in positional order.
            """

The ONNX graph C++ definition is in ``torch/csrc/jit/ir/ir.h``.

Here is an example of handling a missing symbolic function for the ``elu``
operator. We try to export the model and see the error message as below::

    UserWarning: ONNX export failed on elu because torch.onnx.symbolic_opset9.elu does not exist
    RuntimeError: ONNX export failed: Couldn't export operator elu

The export fails because PyTorch does not support exporting the ``elu``
operator. We find ``virtual Tensor elu(const Tensor & input, Scalar alpha, bool inplace) const override;``
in ``VariableType.h``. This means ``elu`` is an ATen operator.
We check the `ONNX operator list <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
and confirm that ``Elu`` is standardized in ONNX.
We add the following lines to ``symbolic_opset9.py``::

    def elu(g, input, alpha, inplace=False):
        return g.op("Elu", input, alpha_f=_scalar(alpha))

Now PyTorch is able to export the ``elu`` operator.

There are more examples in
`symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py>`_ and
`symbolic_opset10.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset10.py>`_.

The interface for specifying operator definitions is experimental;
adventurous users should note that the APIs will probably change in a future
release.

Autograd Function
~~~~~~~~~~~~~~~~~

Autograd functions can be used to compute operation results and gradients and
to save the history of the operations. More information on extending the
torch.autograd engine can be found on this `page <https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function>`_.
There are two ways to export a subclass of ``torch.autograd.Function``.

Symbolic Static Method
^^^^^^^^^^^^^^^^^^^^^^

You can add a symbolic static method to your function class. The symbolic
method should contain a set of ONNX operators that represent the operator's
behavior in ONNX. Here's an example of adding a symbolic method to your
autograd function::

    class MyRelu(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            return input.clamp(min=0)

        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors
            grad_input = grad_output.clone()
            grad_input[input < 0] = 0
            return grad_input

        @staticmethod
        def symbolic(g, self):
            return g.op("Clip", self, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)))

    class MyModel(torch.nn.Module):
        def forward(self, x):
            my_relu = MyRelu.apply
            return my_relu(x)

    input_tensor = torch.randn(2, 3, 224, 224, requires_grad=True)
    torch.onnx.export(MyModel(), input_tensor, 'test.onnx', opset_version=12)

Export as Custom Operator
^^^^^^^^^^^^^^^^^^^^^^^^^

Alternatively, you can register a custom symbolic function to export the
autograd function as a custom operator. This is designed for more advanced
usage, because it gives you access to more info on the export symbolic side.
Autograd functions are emitted in the IR graph as ``prim::PythonOp`` nodes.
The attribute ``name`` identifies the original function name. The
``prim::PythonOp`` node object can be accessed through the argument ``n``.
The example below shows how you can access ``requires_grad`` info of the
inputs and outputs of the autograd function. ::

    class MyClip(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, scalar):
            ctx.save_for_backward(input)
            return input.clamp(min=scalar)

    class MyRelu(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            return input.clamp(min=0)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.clip = MyClip.apply
            self.relu = MyRelu.apply

        def forward(self, x):
            h = self.clip(x, 2)
            h = self.relu(h)
            return h

    def symbolic_pythonop(g, n, *args, **kwargs):
        import torch.onnx.symbolic_helper as sym_helper
        # print information
        print('original node: ', n)
        for i, out in enumerate(n.outputs()):
            print('original output {}: {}, requires grad: {}'.format(i, out, out.requiresGrad()))
        for i, arg in enumerate(args):
            print('arg {}: {}, requires grad: {}'.format(i, arg, arg.requiresGrad() if sym_helper._is_value(arg) else False))

        name = kwargs['name']
        if name == "MyClip":
            return g.op("Clip", args[0], min_f=args[1])
        elif name == "MyRelu":
            return g.op("Relu", args[0])
        else:
            # Reject autograd functions we don't know how to export.
            return sym_helper._unimplemented("prim::PythonOp", "unknown node kind: " + name)

    from torch.onnx import register_custom_op_symbolic
    register_custom_op_symbolic('::prim_PythonOp', symbolic_pythonop, 1)

    input_tensor = torch.randn(2, 3, 224, 224, requires_grad=True)
    torch.onnx.export(MyModule(), input_tensor, 'test.onnx', opset_version=12)

Custom operators
~~~~~~~~~~~~~~~~

Following the tutorial `Extending TorchScript with Custom C++ Operators <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_,
you can create and register your own custom ops implementation in PyTorch.
Here's how to export such a model to ONNX::

    # Create custom symbolic function
    from torch.onnx.symbolic_helper import parse_args
    @parse_args('v', 'v', 'f', 'i')
    def symbolic_foo_forward(g, input1, input2, attr1, attr2):
        return g.op("Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)

    # Register custom symbolic function
    from torch.onnx import register_custom_op_symbolic
    register_custom_op_symbolic('custom_ops::foo_forward', symbolic_foo_forward, 9)

    class FooModel(torch.nn.Module):
        def __init__(self, attr1, attr2):
            super(FooModel, self).__init__()
            self.attr1 = attr1
            self.attr2 = attr2

        def forward(self, input1, input2):
            # Calling custom op
            return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)

    model = FooModel(attr1, attr2)
    torch.onnx.export(model, (dummy_input1, dummy_input2), 'model.onnx', custom_opsets={"custom_domain": 2})

Depending on the custom operator, you can export it as one or a combination
of existing ONNX ops. You can also export it as a custom op in ONNX. In that
case, you can specify the custom domain and version (custom opset) using the
``custom_opsets`` dictionary at export. If not explicitly specified, the
custom opset version is set to 1 by default. When using custom ONNX ops, you
will need to extend the backend of your choice with a matching implementation
of the custom ops, e.g. `Caffe2 custom ops <https://caffe2.ai/docs/custom-operators.html>`_,
`ONNX Runtime custom ops <https://github.com/microsoft/onnxruntime/blob/master/docs/AddingCustomOp.md>`_.

Operator Export Type
------------------------------------------------

Exporting models with unsupported ONNX operators can be achieved using the
``operator_export_type`` flag in the export API. This flag is useful when
users try to export ATen and non-ATen operators that are not registered and
supported in ONNX.

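For example, assuming ``model`` and ``dummy_input`` are defined as in the
examples above, the flag is passed at export time::

    torch.onnx.export(model, dummy_input, 'model.onnx',
                      operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
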
ONNX
~~~~

This mode is used to export all operators as regular ONNX operators. This is
the default ``operator_export_type`` mode. ::

    Example torch ir graph:

        graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])):
          %3 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::exp(%0)
          %4 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::div(%0, %3)
          return (%4)

    Is exported as:

        graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])):
          %1 : Float(2, 3, 4, strides=[12, 4, 1]) = onnx::Exp(%0)
          %2 : Float(2, 3, 4, strides=[12, 4, 1]) = onnx::Div(%0, %1)
          return (%2)

ONNX_ATEN
~~~~~~~~~

This mode is used to export all operators as ATen ops, and avoid conversion
to ONNX. ::

    Example torch ir graph:

        graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])):
          %3 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::exp(%0)
          %4 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::div(%0, %3)
          return (%4)

    Is exported as:

        graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])):
          %1 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::ATen[operator="exp"](%0)
          %2 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::ATen[operator="div"](%0, %1)
          return (%2)

ONNX_ATEN_FALLBACK
~~~~~~~~~~~~~~~~~~

In this mode, unsupported ATen operators fall back to being exported as ATen
ops, while supported operators are exported to ONNX as usual. In the
following example, ``aten::triu`` is not supported in ONNX, so the exporter
falls back on this operator. ::

    Example torch ir graph:

        graph(%0 : Float):
          %3 : int = prim::Constant[value=0]()
          %4 : Float = aten::triu(%0, %3) # unsupported op
          %5 : Float = aten::mul(%4, %0) # registered op
          return (%5)

    is exported as:

        graph(%0 : Float):
          %1 : Long() = onnx::Constant[value={0}]()
          %2 : Float = aten::ATen[operator="triu"](%0, %1) # unsupported op
          %3 : Float = onnx::Mul(%2, %0) # registered op
          return (%3)

RAW
~~~

This mode exports the raw IR. ::

    Example torch ir graph:

        graph(%x.1 : Float(1, strides=[1])):
          %1 : Tensor = aten::exp(%x.1)
          %2 : Tensor = aten::div(%x.1, %1)
          %y.1 : Tensor[] = prim::ListConstruct(%2)
          return (%y.1)

    is exported as:

        graph(%x.1 : Float(1, strides=[1])):
          %1 : Tensor = aten::exp(%x.1)
          %2 : Tensor = aten::div(%x.1, %1)
          %y.1 : Tensor[] = prim::ListConstruct(%2)
          return (%y.1)

ONNX_FALLTHROUGH
~~~~~~~~~~~~~~~~

This mode can be used to export any operator (ATen or non-ATen) that is not
registered and supported in ONNX. The exporter falls through and exports the
operator as-is, as a custom op. Exporting custom operators enables users to
register and implement the operator as part of their runtime backend. ::

    Example torch ir graph:

        graph(%0 : Float(2, 3, 4, strides=[12, 4, 1]),
              %1 : Float(2, 3, 4, strides=[12, 4, 1])):
          %6 : Float(2, 3, 4, strides=[12, 4, 1]) = foo_namespace::bar(%0, %1) # custom op
          %7 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::div(%6, %0) # registered op
          return (%7)

    is exported as:

        graph(%0 : Float(2, 3, 4, strides=[12, 4, 1]),
              %1 : Float(2, 3, 4, strides=[12, 4, 1])):
          %2 : Float(2, 3, 4, strides=[12, 4, 1]) = foo_namespace::bar(%0, %1) # custom op
          %3 : Float(2, 3, 4, strides=[12, 4, 1]) = onnx::Div(%2, %0) # registered op
          return (%3)

Frequently Asked Questions
--------------------------

Q: I have exported my LSTM model, but its input size seems to be fixed?

The tracer records the shapes of the example inputs in the graph. If the
model should accept inputs of dynamic shape, you can use the
``dynamic_axes`` parameter in the export API. ::

    import torch
    import torch.nn as nn
    import onnx

    layer_count = 4

    model = nn.LSTM(10, 20, num_layers=layer_count, bidirectional=True)
    model.eval()

    with torch.no_grad():
        input = torch.randn(5, 3, 10)
        h0 = torch.randn(layer_count * 2, 3, 20)
        c0 = torch.randn(layer_count * 2, 3, 20)
        output, (hn, cn) = model(input, (h0, c0))

        # default export
        torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx')
        onnx_model = onnx.load('lstm.onnx')
        # input shape [5, 3, 10]
        print(onnx_model.graph.input[0])

        # export with `dynamic_axes`
        torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx',
                          input_names=['input', 'h0', 'c0'],
                          output_names=['output', 'hn', 'cn'],
                          dynamic_axes={'input': {0: 'sequence'}, 'output': {0: 'sequence'}})
        onnx_model = onnx.load('lstm.onnx')
        # input shape ['sequence', 3, 10]
        print(onnx_model.graph.input[0])

Q: How to export models with loops in them?

Please check out `Tracing vs Scripting`_.

Q: How to export models with primitive type inputs?

Support for primitive type inputs will be added starting from the PyTorch 1.9
release. However, the exporter does not support conversion of models with
string inputs.

Q: Does ONNX support implicit scalar datatype casting?

No, but the exporter will try to handle that part. Scalars are converted to
constant tensors in ONNX. The exporter will try to figure out the right
datatype for scalars. However, in cases where it fails to do so, you will
need to provide the datatype information manually. This often happens with
scripted models, where the datatypes are not recorded. We are trying to
improve the datatype propagation in the exporter such that manual changes
are not required in the future. ::

    class ImplicitCastType(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            # Exporter knows x is float32, will export '2' as float32 as well.
            y = x + 2
            # Without type propagation, exporter doesn't know the datatype of y.
            # Thus '3' is exported as int64 by default.
            return y + 3
            # The following will export correctly.
            # return y + torch.tensor([3], dtype=torch.float32)

    x = torch.tensor([1.0], dtype=torch.float32)
    torch.onnx.export(ImplicitCastType(), x, 'models/implicit_cast.onnx',
                      example_outputs=ImplicitCastType()(x))

Q: Is tensor in-place indexed assignment like ``data[index] = new_data`` supported?

Yes, this is supported for ONNX opset version >= 11. Please check out
`Indexing`_.

Q: Is tensor list exportable to ONNX?

Yes, this is supported for ONNX opset version >= 11. ONNX introduced the
concept of Sequence in opset 11. Similar to a list, a Sequence is a data type
that contains an arbitrary number of Tensors. Associated operators are also
introduced in ONNX, such as SequenceInsert, SequenceAt, etc. However,
in-place list append within loops is not exportable to ONNX. To implement
this, please use the in-place add operator instead. E.g.::

    class ListLoopModel(torch.nn.Module):
        def forward(self, x):
            res = []
            res1 = []
            arr = x.split(2, 0)
            res2 = torch.zeros(3, 4, dtype=torch.long)
            for i in range(len(arr)):
                res += [arr[i].sum(0, False)]
                res1 += [arr[-1 - i].sum(0, False)]
                res2 += 1
            return torch.stack(res), torch.stack(res1), res2

    model = torch.jit.script(ListLoopModel())
    inputs = torch.randn(16)

    out = model(inputs)
    torch.onnx.export(model, (inputs, ), 'loop_and_list.onnx', opset_version=11, example_outputs=out)

    # onnxruntime
    import onnxruntime
    sess = onnxruntime.InferenceSession('loop_and_list.onnx')
    out_ort = sess.run(None, {
        sess.get_inputs()[0].name: inputs.numpy(),
    })

    assert [torch.allclose(o, torch.tensor(o_ort)) for o, o_ort in zip(out, out_ort)]

Use external data format
------------------------

The ``use_external_data_format`` argument in the export API enables export of
models in the ONNX external data format. With this option enabled, the
exporter stores some model parameters in external binary files, rather than
in the ONNX file itself. These external binary files are stored in the same
location as the ONNX file. The argument ``f`` must be a string specifying the
location of the model file. ::

    model = torchvision.models.mobilenet_v2(pretrained=True)
    input = torch.randn(2, 3, 224, 224, requires_grad=True)
    torch.onnx.export(model, (input, ), './large_model.onnx', use_external_data_format=True)

This argument enables export of large models to ONNX. Models larger than 2GB
cannot be exported in one file because of the protobuf size limit. Users
should set ``use_external_data_format`` to ``True`` to successfully export
such models.

Training
--------

The ``training`` argument in the export API allows users to export models in
a training-friendly mode. ``TrainingMode.TRAINING`` exports the model in a
training-friendly mode that avoids certain model optimizations which might
interfere with model parameter training. ``TrainingMode.PRESERVE`` exports
the model in inference mode if ``model.training`` is ``False``; otherwise, it
exports the model in a training-friendly mode. The default mode for this
argument is ``TrainingMode.EVAL``, which exports the model in inference mode.

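For example, to export in a training-friendly mode (a minimal sketch,
assuming ``model`` and ``dummy_input`` are defined)::

    torch.onnx.export(model, dummy_input, 'model.onnx',
                      training=torch.onnx.TrainingMode.TRAINING)
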
Functions
--------------------------

.. autofunction:: export
.. autofunction:: export_to_pretty_string
.. autofunction:: register_custom_op_symbolic
.. autofunction:: torch.onnx.operators.shape_as_tensor
.. autofunction:: select_model_mode_for_export
.. autofunction:: is_in_onnx_export