ONNX: Update training ops and training-amenable export API (#35567)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35567

Reviewed By: hl475

Differential Revision: D20715339

Pulled By: houseroad

fbshipit-source-id: ad88097e76b169035ab5814b769dc1bed54c6008
Lara Haidar 2020-03-29 23:12:32 -07:00 committed by Facebook GitHub Bot
parent 1f759936f0
commit 728c7dcea3
19 changed files with 718 additions and 146 deletions
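
For orientation, a minimal sketch of the training-amenable export API this PR introduces (module and output file name are illustrative, not part of the diff):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Dropout(0.4))
x = torch.randn(1, 3, 32, 32)

# Export a training-amenable graph: Dropout is kept as a real node and
# BatchNormalization is emitted with its training outputs (opset 12 has
# the best support for these ops, per the changes below).
torch.onnx.export(model, x, "model_training.onnx", opset_version=12,
                  training=torch.onnx.TrainingMode.TRAINING)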

View File

@ -160,6 +160,7 @@ jobs:
-g"-torch/csrc/jit/export.cpp" \
-g"-torch/csrc/jit/import.cpp" \
-g"-torch/csrc/jit/netdef_converter.cpp" \
-g"-torch/csrc/onnx/init.cpp" \
"$@" > ${GITHUB_WORKSPACE}/clang-tidy-output.txt
cat ${GITHUB_WORKSPACE}/clang-tidy-output.txt

View File

@ -54,6 +54,7 @@ pytest "${args[@]}" \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py" \
--ignore "$top_dir/test/onnx/test_custom_ops.py" \
--ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
--ignore "$top_dir/test/onnx/test_utility_funs.py" \
"${test_paths[@]}"
# onnxruntime only supports py3
@ -64,7 +65,8 @@ if [[ "$BUILD_ENVIRONMENT" == *ort1-py3.6* ]]; then
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset8" \
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime" \
"$top_dir/test/onnx/test_custom_ops.py" \
"$top_dir/test/onnx/test_models_onnxruntime.py"
"$top_dir/test/onnx/test_models_onnxruntime.py" \
"$top_dir/test/onnx/test_utility_funs.py"
fi
if [[ "$BUILD_ENVIRONMENT" == *ort2-py3.6* ]]; then
# Update the loop for new opsets

View File

@ -0,0 +1,167 @@
ir_version: 6
producer_name: "pytorch"
producer_version: "1.5"
graph {
node {
output: "6"
name: "Constant_0"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 9
raw_data: "\001"
}
type: TENSOR
}
}
node {
input: "input"
input: "weight"
input: "bias"
input: "running_mean"
input: "running_var"
input: "6"
output: "7"
output: "8"
output: "9"
output: "batch_norm_dead_output-14"
output: "batch_norm_dead_output-15"
name: "BatchNormalization_1"
op_type: "BatchNormalization"
attribute {
name: "epsilon"
f: 1e-05
type: FLOAT
}
attribute {
name: "momentum"
f: 0.9
type: FLOAT
}
}
name: "torch-jit-export"
initializer {
dims: 2
data_type: 1
name: "bias"
raw_data: "\000\000\000\000\000\000\000\000"
}
initializer {
dims: 2
data_type: 1
name: "running_mean"
raw_data: "\315\314\314=\315\314\314="
}
initializer {
dims: 2
data_type: 1
name: "running_var"
raw_data: "fff?fff?"
}
initializer {
dims: 2
data_type: 1
name: "weight"
raw_data: "\000\000\200?\000\000\200?"
}
input {
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "weight"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "bias"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "running_mean"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "running_var"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "7"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 12
}

View File

@ -0,0 +1,46 @@
ir_version: 6
producer_name: "pytorch"
producer_version: "1.5"
graph {
node {
input: "x"
output: "1"
name: "ReduceMax_0"
op_type: "ReduceMax"
attribute {
name: "keepdims"
i: 0
type: INT
}
}
name: "torch-jit-export"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "1"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 9
}

View File

@ -0,0 +1,58 @@
ir_version: 6
producer_name: "pytorch"
producer_version: "1.5"
graph {
node {
input: "x"
output: "1"
output: "2"
name: "Dropout_0"
op_type: "Dropout"
attribute {
name: "ratio"
f: 0.5
type: FLOAT
}
}
node {
input: "1"
output: "3"
name: "ReduceMax_1"
op_type: "ReduceMax"
attribute {
name: "keepdims"
i: 0
type: INT
}
}
name: "torch-jit-export"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "3"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 9
}

View File

@ -0,0 +1,67 @@
ir_version: 6
producer_name: "pytorch"
producer_version: "1.5"
graph {
node {
output: "1"
name: "Constant_0"
op_type: "Constant"
attribute {
name: "value"
t {
data_type: 1
raw_data: "\000\000\000?"
}
type: TENSOR
}
}
node {
input: "x"
input: "1"
output: "2"
output: "3"
name: "Dropout_1"
op_type: "Dropout"
}
node {
input: "2"
output: "4"
name: "ReduceMax_2"
op_type: "ReduceMax"
attribute {
name: "keepdims"
i: 0
type: INT
}
}
name: "torch-jit-export"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "4"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 12
}

View File

@ -37,9 +37,10 @@ BATCH_SIZE = 2
class TestModels(TestCase):
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7):
graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
torch._C._jit_pass_lint(graph)
verify(model, inputs, backend, rtol=rtol, atol=atol)
with torch.onnx.select_model_mode_for_export(model, None):
graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
torch._C._jit_pass_lint(graph)
verify(model, inputs, backend, rtol=rtol, atol=atol)
def test_ops(self):
x = Variable(

View File

@ -45,7 +45,7 @@ def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_versi
assert attributes[j][attribute_field] == getattr(graph.node[i].attribute[j], attribute_field)
def check_onnx_opsets_operator(module, x, ops, opset_versions, training=False, example_outputs=None):
def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL, example_outputs=None):
for opset_version in opset_versions:
f = io.BytesIO()
torch.onnx.export(module, x, f,
@ -238,12 +238,12 @@ class TestONNXOpset(TestCase):
# test training mode
ops = [{"op_name" : "Dropout", "attributes" : [{"name" : "ratio", "f" : 0.5, "type" : 1}]}]
ops = {9 : ops, 10 : ops}
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=True)
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=torch.onnx.TrainingMode.TRAINING)
# test eval mode
ops = []
ops = {9 : ops, 10 : ops}
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=False)
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=torch.onnx.TrainingMode.EVAL)
def test_full(self):
class MyModule(Module):

View File

@ -16,7 +16,6 @@ import os
import shutil
import torch.testing._internal.common_utils as common
'''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
--no-onnx: run without the onnx Python dependency
--produce-onnx-test-data: generate onnx test data
@ -255,7 +254,12 @@ class TestOperators(TestCase):
def test_batchnorm_training(self):
x = torch.ones(2, 2, 2, 2, requires_grad=True)
self.assertONNX(nn.BatchNorm2d(2), x, training=True, keep_initializers_as_inputs=True)
self.assertONNX(nn.BatchNorm2d(2), x, training=torch.onnx.TrainingMode.TRAINING, keep_initializers_as_inputs=True)
def test_batchnorm_training_opset12(self):
x = torch.ones(2, 2, 2, 2, requires_grad=True)
self.assertONNX(nn.BatchNorm2d(2), x, training=torch.onnx.TrainingMode.TRAINING,
keep_initializers_as_inputs=True, opset_version=12)
def test_conv(self):
x = torch.ones(20, 16, 50, 40, requires_grad=True)
@ -672,6 +676,18 @@ class TestOperators(TestCase):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x)
def test_dropout_default(self):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.max(functional.dropout(x,)), x)
def test_dropout_training(self):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.max(functional.dropout(x)), x, training=torch.onnx.TrainingMode.TRAINING)
def test_dropout_training_opset12(self):
x = torch.randn(3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.max(functional.dropout(x)), x, opset_version=12, training=torch.onnx.TrainingMode.TRAINING)
def test_nonzero(self):
x = torch.tensor([[[2., 2.], [1., 0.]], [[0., 0.], [1., 1.]]], requires_grad=True)
self.assertONNX(lambda x: torch.nonzero(x), x)

View File

@ -5,8 +5,11 @@ import torch
import torch.onnx
from torch.onnx import utils, OperatorExportTypes
from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
from test_pytorch_common import skipIfUnsupportedOpsetVersion
import onnx
import onnxruntime # noqa
import numpy as np
import io
import copy
@ -52,6 +55,8 @@ class TestUtilityFuns(TestCase):
assert "Provided key invalid_name2 for dynamic axes is not a valid input/output name" in messages
assert len(messages) == 2
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_transpose(self):
class TransposeModule(torch.nn.Module):
def forward(self, x):
@ -72,6 +77,8 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_slice(self):
class NarrowModule(torch.nn.Module):
def forward(self, x):
@ -92,6 +99,8 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_slice_index_exceeds_dim(self):
class SliceIndexExceedsDimModule(torch.nn.Module):
def forward(self, x):
@ -113,6 +122,8 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_slice_negative_index(self):
class SliceNegativeIndexModule(torch.nn.Module):
def forward(self, x):
@ -133,6 +144,8 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_unsqueeze(self):
class UnsqueezeModule(torch.nn.Module):
def forward(self, x):
@ -153,6 +166,8 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_concat(self):
class ConcatModule(torch.nn.Module):
def forward(self, x):
@ -190,6 +205,8 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Constant"
assert len(list(graph.nodes())) == 2
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_lstm(self):
class GruNet(torch.nn.Module):
def __init__(self):
@ -212,6 +229,8 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Unsqueeze"
assert len(list(graph.nodes())) == 3
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_transpose_matmul(self):
class MatMulNet(torch.nn.Module):
def __init__(self):
@ -233,6 +252,8 @@ class TestUtilityFuns(TestCase):
# TODO we need to figure out the root cause and fix the problem
@skip("causing segmentation fault")
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_reshape(self):
class ReshapeModule(torch.nn.Module):
def __init__(self, ):
@ -252,6 +273,8 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Reshape"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_div(self):
class Module(torch.nn.Module):
def __init__(self, ):
@ -271,6 +294,8 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Div"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_mul(self):
class Module(torch.nn.Module):
def __init__(self, ):
@ -290,6 +315,8 @@ class TestUtilityFuns(TestCase):
assert node.kind() != "onnx::Mul"
assert len(list(graph.nodes())) == 1
# TODO : enable when constant folding is enabled for opset 12
@skipIfUnsupportedOpsetVersion([12])
def test_constant_fold_sqrt(self):
class Module(torch.nn.Module):
def __init__(self, ):
@ -342,6 +369,95 @@ class TestUtilityFuns(TestCase):
'unwrap model from torch.nn.DataParallel. Try '):
torch.onnx.export(model, x, f, opset_version=self.opset_version)
def test_export_mode(self):
class MyModule(torch.nn.Module):
def forward(self, x):
y = x + 1
return y
model = MyModule()
x = torch.randn(10, 3, 128, 128)
f = io.BytesIO()
# set model to inference mode and export in training mode
model.eval()
old_state = model.training
torch.onnx.export(model, (x,), f,
opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
# verify that the model state is preserved
assert model.training == old_state
# set model to training mode and export in inference mode
model.train()
old_state = model.training
torch.onnx.export(model, (x,), f,
opset_version=self.opset_version, training=torch.onnx.TrainingMode.EVAL)
# verify that the model state is preserved
assert model.training == old_state
# TODO: Enable test when BatchNorm is implemented in ORT for opset 12.
@skipIfUnsupportedOpsetVersion([12])
def test_batchnorm_training(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.bn = torch.nn.BatchNorm2d(3, affine=True)
def forward(self, x):
bn = self.bn(x)
return bn
model = MyModule()
x = torch.randn(10, 3, 128, 128)
model.train()
out = model(x)
# state after one training forward pass
running_mean = model.bn.running_mean
running_var = model.bn.running_var
saved_mean = x.mean((0, 2, 3))
saved_var = x.var((0, 2, 3))
pytorch_out = [out.detach().numpy(),
running_mean.cpu().numpy(), running_var.cpu().numpy(),
saved_mean.cpu().numpy(), saved_var.cpu().numpy()]
model_export = MyModule()
f = io.BytesIO()
torch.onnx.export(model_export, (x,), f,
opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = {ort_sess.get_inputs()[0].name : x.cpu().numpy()}
ort_outs = ort_sess.run(None, ort_inputs)
[np.testing.assert_allclose(p_out, ort_out, atol=10e-3, rtol=10e-3) for p_out, ort_out in zip(pytorch_out, ort_outs)]
# TODO: Enable test when Dropout is implemented in ORT for opset 12.
@skipIfUnsupportedOpsetVersion([12])
def test_dropout_training(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.dropout = torch.nn.Dropout(0.4)
def forward(self, x):
dropout = self.dropout(x)
return dropout
model = MyModule()
x = torch.randn(10, 3, 128, 128)
model.train()
f = io.BytesIO()
torch.onnx.export(model, (x,), f,
opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = {ort_sess.get_inputs()[0].name : x.cpu().numpy()}
ort_outs = ort_sess.run(None, ort_inputs)
# in training mode, dropout should have modified the input
assert not np.allclose(x.cpu().numpy(), ort_outs[0])
# opset 10 tests
TestUtilityFuns_opset10 = type(str("TestUtilityFuns_opset10"),
@ -354,6 +470,11 @@ TestUtilityFuns_opset11 = type(str("TestUtilityFuns_opset11"),
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=11))
# opset 12 tests
TestUtilityFuns_opset12 = type(str("TestUtilityFuns_opset12"),
(TestCase,),
dict(TestUtilityFuns.__dict__, opset_version=12))

View File

@ -8,7 +8,6 @@ import onnx.helper
import numpy as np
import difflib
import contextlib
import io
@ -226,24 +225,7 @@ class Errors(object):
if exc_type == self.exc_class:
raise RuntimeError("ShortCircuit was raised, but no errors were recorded")
@contextlib.contextmanager
def set_training(model, mode):
"""
A context manager to temporarily set the training mode of 'model'
to 'mode', resetting it when we exit the with-block.
"""
old_mode = model.training
if old_mode != mode:
model.train(mode)
try:
yield
finally:
if old_mode != mode:
model.train(old_mode)
def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol=1e-7,
def verify(model, args, backend, verbose=False, training=torch.onnx.TrainingMode.EVAL, rtol=1e-3, atol=1e-7,
test_args=2, do_constant_folding=True, example_outputs=None, opset_version=None,
keep_initializers_as_inputs=True, add_node_names=False):
"""
@ -359,7 +341,7 @@ def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol=
if isinstance(args, torch.Tensor):
args = (args,)
with set_training(model, training):
with torch.onnx.select_model_mode_for_export(model, training):
proto_bytes = io.BytesIO()
torch_out = torch.onnx._export(model, args, proto_bytes, verbose=verbose,
do_constant_folding=do_constant_folding,
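A hedged usage sketch of the updated verify signature (model, x, and backend as already set up in this test helper):

# training now takes the TrainingMode enum instead of a bool:
verify(model, x, backend, training=torch.onnx.TrainingMode.TRAINING)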

View File

@ -30,6 +30,11 @@ void initONNXBindings(PyObject* module) {
.value("ONNX_ATEN_FALLBACK", OperatorExportTypes::ONNX_ATEN_FALLBACK)
.value("RAW", OperatorExportTypes::RAW);
py::enum_<TrainingMode>(onnx, "TrainingMode")
.value("EVAL", TrainingMode::EVAL)
.value("PRESERVE", TrainingMode::PRESERVE)
.value("TRAINING", TrainingMode::TRAINING);
onnx.attr("IR_VERSION") = IR_VERSION;
onnx.attr("PRODUCER_VERSION") = py::str(PRODUCER_VERSION);

View File

@ -9,6 +9,12 @@ enum class OperatorExportTypes {
RAW, // Raw export (no ONNX)
};
enum class TrainingMode {
EVAL, // Inference mode
PRESERVE, // Preserve model state (eval/training)
TRAINING, // Training mode
};
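These values are surfaced to Python through the pybind11 binding shown above; a brief sketch of how they are referenced:

import torch

mode = torch.onnx.TrainingMode.PRESERVE  # export follows model.training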
// we pin IR version to version 6 (12/11/2019) instead of using
// onnx::IR_VERSION. with this change, the test_operators.py will be more
// stable. only bump it when it's necessary

View File

@ -2,6 +2,7 @@ import torch._C as _C
TensorProtoDataType = _C._onnx.TensorProtoDataType
OperatorExportTypes = _C._onnx.OperatorExportTypes
TrainingMode = _C._onnx.TrainingMode
PYTORCH_ONNX_CAFFE2_BUNDLE = _C._onnx.PYTORCH_ONNX_CAFFE2_BUNDLE
ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
@ -28,7 +29,7 @@ def _export(*args, **kwargs):
return result
def export(model, args, f, export_params=True, verbose=False, training=False,
def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
input_names=None, output_names=None, aten=False, export_raw_ir=False,
operator_export_type=None, opset_version=None, _retain_param_name=True,
do_constant_folding=True, example_outputs=None, strip_doc_string=True,
@ -59,9 +60,11 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
as arguments, the ordering as specified by ``model.state_dict().values()``
verbose (bool, default False): if specified, we will print out a debug
description of the trace being exported.
training (bool, default False): export the model in training mode. At
the moment, ONNX is oriented towards exporting models for inference
only, so you will generally not need to set this to True.
training (enum, default TrainingMode.EVAL):
TrainingMode.EVAL: export the model in inference mode.
TrainingMode.PRESERVE: export the model in inference mode if model.training is
False and in a training-friendly mode if model.training is True.
TrainingMode.TRAINING: export the model in a training-friendly mode.
input_names(list of strings, default empty list): names to assign to the
input nodes of the graph, in order
output_names(list of strings, default empty list): names to assign to the
@ -184,7 +187,7 @@ def _optimize_trace(graph, operator_export_type):
return utils._optimize_graph(graph, operator_export_type)
def set_training(model, mode):
def select_model_mode_for_export(model, mode):
r"""
A context manager to temporarily set the training mode of 'model'
to 'mode', resetting it when we exit the with-block. A no-op if
@ -192,7 +195,7 @@ def set_training(model, mode):
"""
from torch.onnx import utils
return utils.set_training(model, mode)
return utils.select_model_mode_for_export(model, mode)
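A short, self-contained sketch of the renamed context manager (formerly set_training); the module here is a placeholder:

import torch
import torch.onnx

model = torch.nn.Linear(4, 2)     # placeholder module
args = (torch.randn(1, 4),)

# Temporarily force eval mode for tracing; the original mode is
# restored when the with-block exits.
with torch.onnx.select_model_mode_for_export(model, torch.onnx.TrainingMode.EVAL):
    trace = torch.jit.trace(model, args)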
def _run_symbolic_function(*args, **kwargs):

View File

@ -413,6 +413,20 @@ def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, na
padding = tuple(tuple_fn(padding))
return padding
def assert_training_mode(op_mode, op_name):
global _training_mode
op_mode = op_mode == 1
if op_mode != _training_mode:
op_mode = "training" if op_mode else "inference"
training_mode = "training" if _training_mode else "inference"
# setting the model mode can result in op_mode != _training_mode
# if the model is a FuncModule; in that case, warn the user of
# the mismatch and export according to training_mode
warnings.warn("ONNX export mode is set to " + training_mode +
" mode, but operator " + op_name + " is set to " +
op_mode + " mode. The model will be exported in " +
training_mode + " mode, as specified by the export mode.")
# ---------------------------------------------------------------------
# ONNX operator version
# ---------------------------------------------------------------------
@ -461,6 +475,11 @@ def _set_operator_export_type(operator_export_type):
global _operator_export_type
_operator_export_type = operator_export_type
_training_mode = None
def _set_training_mode(training_mode):
global _training_mode
_training_mode = training_mode
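A hedged sketch of how these internal helpers cooperate during export (internal names taken from this diff, not a public API):

from torch.onnx import symbolic_helper as sym_help

sym_help._set_training_mode(True)  # done by select_model_mode_for_export
# A node whose train param is 0 now triggers the mismatch warning and
# is exported according to the global training mode:
sym_help.assert_training_mode(0, "dropout")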
# Metaprogram symbolics for each ATen native specialized cast operator.
# For e.g. we specify a function named `_cast_uint8_t` that instantiates an
# ONNX cast node with `to` attribute 'UINT8'

View File

@ -1,5 +1,6 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args
@ -9,11 +10,62 @@ from torch.onnx.symbolic_helper import parse_args
# This file exports ONNX ops for opset 12
black_listed_operators = [
"ArgMin", "ArgMax"
]
@parse_args('s', 'v')
def einsum(g, equation, tensor_list):
tensors = sym_help._unpack_list(tensor_list)
return g.op("Einsum", *tensors, equation_s=equation)
@parse_args('v', 'f', 'i')
def dropout(g, input, p, train):
sym_help.assert_training_mode(train, "dropout")
# dropout is a no-op in eval mode: if the node's train param is False, just pass the input through
if not sym_help._training_mode:
return input
p = g.op("Constant", value_t=torch.tensor(p))
r, _ = g.op("Dropout", input, p, outputs=2)
return r
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
sym_help.assert_training_mode(training, "batch_norm")
input_sizes = input.type().sizes()
if weight is None or sym_help._is_none(weight):
assert len(input_sizes) > 1
weight_value = torch.tensor([1.] * input_sizes[1]).type(
'torch.' + input.type().scalarType() + 'Tensor')
weight = g.op("Constant", value_t=weight_value)
if bias is None or sym_help._is_none(bias):
assert len(input_sizes) > 1
bias_value = torch.tensor([0.] * input_sizes[1]).type(
'torch.' + input.type().scalarType() + 'Tensor')
bias = g.op("Constant", value_t=bias_value)
if not sym_help._training_mode:
out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
epsilon_f=eps,
momentum_f=1 - momentum,
outputs=1)
return out
else:
training_mode = g.op("Constant", value_t=torch.tensor(True))
res, new_running_mean, new_running_var, saved_mean, saved_var = g.op("BatchNormalization",
input,
weight, bias,
running_mean, running_var, training_mode,
epsilon_f=eps,
momentum_f=1 - momentum,
outputs=5)
new_running_mean.setType(running_mean.type())
new_running_var.setType(running_var.type())
saved_mean.setDebugName("batch_norm_dead_output-" + saved_mean.debugName())
saved_var.setDebugName("batch_norm_dead_output-" + saved_var.debugName())
return res
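For reference, a minimal sketch that exercises this batch_norm path and produces the five-output BatchNormalization shown in the expect file earlier (output file name illustrative):

import torch

torch.onnx.export(torch.nn.BatchNorm2d(2), torch.ones(2, 2, 2, 2),
                  "batchnorm_training.onnx", opset_version=12,
                  keep_initializers_as_inputs=True,
                  training=torch.onnx.TrainingMode.TRAINING)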
def nll_loss(g, self, target, weight, reduction, ignore_index):
# none reduction : onnx::Constant[value={0}]
# mean reduction : onnx::Constant[value={1}]

View File

@ -1049,6 +1049,7 @@ def conv_transpose3d(g, input, weight, bias, stride, padding, output_padding, gr
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
sym_help.assert_training_mode(training, "batch_norm")
input_sizes = input.type().sizes()
if weight is None or sym_help._is_none(weight):
@ -1065,8 +1066,8 @@ def batch_norm(g, input, weight, bias, running_mean, running_var, training, mome
out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
epsilon_f=eps,
momentum_f=1 - momentum,
outputs=1 if not training else 5)
if not training:
outputs=1 if not sym_help._training_mode else 5)
if not sym_help._training_mode:
return out
else:
res, new_running_mean, new_running_var, saved_mean, saved_var = out
@ -1298,7 +1299,9 @@ def exp(g, self):
@parse_args('v', 'f', 'i')
def dropout(g, input, p, train):
if not train: # in eval mode, dropout is non-op
sym_help.assert_training_mode(train, "dropout")
# dropout is a no-op in eval mode: if the node's train param is False, just pass the input through
if not sym_help._training_mode:
return input
warnings.warn("Dropout is a training op and should not be exported in inference mode. "
"Make sure to call eval() on the model, and to export it with param training=False.")

View File

@ -17,7 +17,7 @@ import numbers
import warnings
from torch._six import string_classes
from torch.jit import _unique_state_dict
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode
from torch._C import ListType, _propagate_and_assign_input_shapes, _assign_output_shapes, _check_onnx_proto
@ -31,21 +31,44 @@ def is_in_onnx_export():
@contextlib.contextmanager
def set_training(model, mode):
if mode is None:
yield
return
old_mode = model.training
if old_mode != mode:
model.train(mode)
def select_model_mode_for_export(model, mode):
if not isinstance(model, torch.jit.ScriptFunction):
is_originally_training = model.training
if mode is None:
mode = TrainingMode.EVAL
# if the model is in training mode but the user did not specify
# to export the model in training mode, export the model in inference
# mode (default) and warn them
if is_originally_training:
warnings.warn("You are exporting the model to ONNX while in training mode with "
"'train' parameter not specified. The model will default to inference mode export. "
"If you wish to export a training amenable ONNX model, specify train=TrainingMode.TRAIN or "
"train=TrainingMode.PRESERVE (to preserve the original model state) in torch.onnx.export().")
# if mode == TrainingMode.EVAL or (mode == TrainingMode.PRESERVE and not is_originally_training) => is_export_training = False
is_export_training = False
# ONNX opset 12 has better support for training-amenable models, with updated
# versions of the dropout and batch_norm operators
if mode == TrainingMode.TRAINING or (mode == TrainingMode.PRESERVE and is_originally_training):
from torch.onnx.symbolic_helper import _export_onnx_opset_version
if _export_onnx_opset_version < 12:
warnings.warn("You are exporting the model in training mode with onnx opset version {}. "
"Opset versions lower than opset 12 will not be able to export nodes such as"
"Dropout and BatchNorm correctly.".format(_export_onnx_opset_version))
is_export_training = True
from torch.onnx.symbolic_helper import _set_training_mode
_set_training_mode(is_export_training)
model.train(is_export_training)
try:
yield
finally:
if old_mode != mode:
model.train(old_mode)
if not isinstance(model, torch.jit.ScriptFunction):
model.train(is_originally_training)
def export(model, args, f, export_params=True, verbose=False, training=False,
def export(model, args, f, export_params=True, verbose=False, training=None,
input_names=None, output_names=None, aten=False, export_raw_ir=False,
operator_export_type=None, opset_version=None, _retain_param_name=True,
do_constant_folding=True, example_outputs=None, strip_doc_string=True,
@ -275,21 +298,15 @@ def _trace(func, args, operator_export_type, return_outs=False):
return trace_graph
def _trace_and_get_graph_from_model(model, args, training):
def _trace_and_get_graph_from_model(model, args):
# A basic sanity check: make sure the state_dict keys are the same
# before and after running the model. Fail fast!
orig_state_dict_keys = _unique_state_dict(model).keys()
# By default, training=False, which is good because running a model in
# training mode could result in internal buffers getting updated, dropout
# getting applied, etc. If you really know what you're doing, you
# can turn training=True (or None, to preserve whatever the original
# training mode was.)
with set_training(model, training):
trace_graph, torch_out, inputs_states = \
torch.jit._get_trace_graph(model, args, _force_outplace=False, _return_inputs_states=True)
warn_on_static_input_change(inputs_states)
trace_graph, torch_out, inputs_states = \
torch.jit._get_trace_graph(model, args, _force_outplace=False, _return_inputs_states=True)
warn_on_static_input_change(inputs_states)
if orig_state_dict_keys != _unique_state_dict(model).keys():
raise RuntimeError("state_dict changed after running the tracer; "
@ -298,7 +315,7 @@ def _trace_and_get_graph_from_model(model, args, training):
return trace_graph, torch_out
def _model_to_graph(model, args, verbose=False, training=False,
def _model_to_graph(model, args, verbose=False,
input_names=None, output_names=None,
operator_export_type=OperatorExportTypes.ONNX,
example_outputs=None, propagate=False,
@ -331,7 +348,7 @@ def _model_to_graph(model, args, verbose=False, training=False,
graph = _propagate_and_assign_input_shapes(
model.graph, tuple(in_vars), False, propagate)
else:
graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
graph, torch_out = _trace_and_get_graph_from_model(model, args)
state_dict = _unique_state_dict(model)
params = list(state_dict.values())
if _retain_param_name:
@ -387,7 +404,7 @@ def _model_to_graph(model, args, verbose=False, training=False,
return graph, params_dict, torch_out
def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
input_names=None, output_names=None, aten=False, export_raw_ir=False,
operator_export_type=None, export_type=ExportTypes.PROTOBUF_FILE,
example_outputs=None, propagate=False, google_printer=False,
@ -410,7 +427,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
custom_opsets=custom_opsets)
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
google_printer=False, opset_version=None, _retain_param_name=False,
@ -424,27 +441,27 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
custom_opsets = {}
_set_opset_version(opset_version)
_set_operator_export_type(operator_export_type)
val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
operator_export_type,
opset_version)
val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type)
graph, params_dict, torch_out = _model_to_graph(model, args, verbose,
training, input_names,
output_names, operator_export_type,
example_outputs, propagate, _retain_param_name,
val_do_constant_folding, fixed_batch_size=fixed_batch_size)
with select_model_mode_for_export(model, training):
val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
operator_export_type,
opset_version)
val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type)
graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names,
output_names, operator_export_type,
example_outputs, propagate, _retain_param_name,
val_do_constant_folding, fixed_batch_size=fixed_batch_size)
return graph._pretty_print_onnx(params_dict, opset_version, False,
operator_export_type, google_printer,
val_keep_init_as_ip, custom_opsets, val_add_node_names)
return graph._pretty_print_onnx(params_dict, opset_version, False,
operator_export_type, google_printer,
val_keep_init_as_ip, custom_opsets, val_add_node_names)
# NOTE: the output `torch_out` will contain the output tensors resulting from
# the trace of a Module. In the case that a torch.nn.ScriptModule is passed in,
# this output will be None, since we are not doing any tracing but rather
# directly extracting the graph.
def _export(model, args, f, export_params=True, verbose=False, training=False,
def _export(model, args, f, export_params=True, verbose=False, training=None,
input_names=None, output_names=None, operator_export_type=None,
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
opset_version=None, _retain_param_name=False, do_constant_folding=True,
@ -470,80 +487,86 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
else:
operator_export_type = OperatorExportTypes.ONNX
_set_opset_version(opset_version)
_set_operator_export_type(operator_export_type)
val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
operator_export_type,
opset_version)
val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type)
val_use_external_data_format, model_file_location = _decide_external_data_format(use_external_data_format,
operator_export_type,
f)
graph, params_dict, torch_out = _model_to_graph(model, args, verbose,
training, input_names,
output_names, operator_export_type,
example_outputs, propagate,
_retain_param_name, val_do_constant_folding,
fixed_batch_size=fixed_batch_size)
# By default, training=None (which maps to TrainingMode.EVAL),
# which is good because running a model in training mode could result in
# internal buffers getting updated, dropout getting applied, etc.
# If you really know what you're doing, you can set
# training=TrainingMode.TRAINING, or training=TrainingMode.PRESERVE
# to keep whatever the original training mode was.
with select_model_mode_for_export(model, training):
_set_opset_version(opset_version)
_set_operator_export_type(operator_export_type)
val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
operator_export_type,
opset_version)
val_add_node_names = _decide_add_node_names(add_node_names, operator_export_type)
val_do_constant_folding = _decide_constant_folding(do_constant_folding, operator_export_type)
val_use_external_data_format, model_file_location = _decide_external_data_format(use_external_data_format,
operator_export_type,
f)
graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names,
output_names, operator_export_type,
example_outputs, propagate,
_retain_param_name, val_do_constant_folding,
fixed_batch_size=fixed_batch_size)
# TODO: Don't allocate an in-memory string for the protobuf
defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
if dynamic_axes is None:
dynamic_axes = {}
if custom_opsets is None:
custom_opsets = {}
# TODO: Don't allocate an in-memory string for the protobuf
defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
if dynamic_axes is None:
dynamic_axes = {}
if custom_opsets is None:
custom_opsets = {}
_validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
_validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
if export_params:
proto, export_map = graph._export_onnx(
params_dict, opset_version, dynamic_axes, defer_weight_export,
operator_export_type, strip_doc_string, val_keep_init_as_ip, custom_opsets,
val_add_node_names, val_use_external_data_format, model_file_location)
else:
proto, export_map = graph._export_onnx(
{}, opset_version, dynamic_axes, False, operator_export_type,
strip_doc_string, val_keep_init_as_ip, custom_opsets, val_add_node_names,
val_use_external_data_format, model_file_location)
if enable_onnx_checker and \
operator_export_type is OperatorExportTypes.ONNX and \
not val_use_external_data_format:
# Only run checker if enabled and we are not using ATEN fallback and
# large model format export is not enabled.
_check_onnx_proto(proto)
if export_type == ExportTypes.PROTOBUF_FILE:
assert(len(export_map) == 0)
with torch.serialization._open_file_like(f, 'wb') as opened_file:
opened_file.write(proto)
elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
import zipfile
compression = zipfile.ZIP_DEFLATED \
if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
else zipfile.ZIP_STORED
with zipfile.ZipFile(f, 'w', compression=compression) as z:
z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
for k, v in export_map.items():
z.writestr(k, v)
elif export_type == ExportTypes.DIRECTORY:
import os
if os.path.exists(f):
assert(os.path.isdir(f))
if export_params:
proto, export_map = graph._export_onnx(
params_dict, opset_version, dynamic_axes, defer_weight_export,
operator_export_type, strip_doc_string, val_keep_init_as_ip, custom_opsets,
val_add_node_names, val_use_external_data_format, model_file_location)
else:
os.makedirs(f)
proto, export_map = graph._export_onnx(
{}, opset_version, dynamic_axes, False, operator_export_type,
strip_doc_string, val_keep_init_as_ip, custom_opsets, val_add_node_names,
val_use_external_data_format, model_file_location)
model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
with torch.serialization._open_file_like(model_proto_file, 'wb') as opened_file:
opened_file.write(proto)
if enable_onnx_checker and \
operator_export_type is OperatorExportTypes.ONNX and \
not val_use_external_data_format:
# Only run checker if enabled and we are not using ATEN fallback and
# large model format export is not enabled.
_check_onnx_proto(proto)
for k, v in export_map.items():
weight_proto_file = os.path.join(f, k)
with torch.serialization._open_file_like(weight_proto_file, 'wb') as opened_file:
opened_file.write(v)
else:
raise RuntimeError('Unknown export type')
if export_type == ExportTypes.PROTOBUF_FILE:
assert(len(export_map) == 0)
with torch.serialization._open_file_like(f, 'wb') as opened_file:
opened_file.write(proto)
elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
import zipfile
compression = zipfile.ZIP_DEFLATED \
if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
else zipfile.ZIP_STORED
with zipfile.ZipFile(f, 'w', compression=compression) as z:
z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
for k, v in export_map.items():
z.writestr(k, v)
elif export_type == ExportTypes.DIRECTORY:
import os
if os.path.exists(f):
assert(os.path.isdir(f))
else:
os.makedirs(f)
model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
with torch.serialization._open_file_like(model_proto_file, 'wb') as opened_file:
opened_file.write(proto)
for k, v in export_map.items():
weight_proto_file = os.path.join(f, k)
with torch.serialization._open_file_like(weight_proto_file, 'wb') as opened_file:
opened_file.write(v)
else:
raise RuntimeError('Unknown export type')
finally:
assert __IN_ONNX_EXPORT
__IN_ONNX_EXPORT = False
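
A hedged sketch of the state-restoration contract implemented above (placeholder module; the original training flag survives the export):

import io
import torch

model = torch.nn.Dropout(0.4)   # placeholder module
x = torch.randn(2, 3)
model.train()

f = io.BytesIO()
torch.onnx.export(model, (x,), f, opset_version=12,
                  training=torch.onnx.TrainingMode.PRESERVE)
assert model.training  # the original mode is restored after export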

View File

@ -280,7 +280,7 @@ def graph(model, args, verbose=False):
verbose (bool): Whether to print out verbose information while
processing.
"""
with torch.onnx.set_training(model, False): # TODO: move outside of torch.onnx?
with torch.onnx.select_model_mode_for_export(model, torch.onnx.TrainingMode.EVAL): # TODO: move outside of torch.onnx?
try:
trace = torch.jit.trace(model, args)
graph = trace.graph