[Resending] [ONNX] Add eliminate_unused_items pass (#42743)

Summary:
This PR:

- Adds an eliminate_unused_items pass that removes unused graph inputs and initializers (see the sketch below).
- Fixes the run_embed_params function so that it no longer exports unnecessary parameters.
- Removes test_modifying_params from test_verify since it is no longer needed there; an updated version is added to test_utility_funs.
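
To make the pass's effect concrete, here is a minimal sketch (not part of this commit; the model and file names are made up) for a model that carries a parameter its forward never reads. With the pass in place, that parameter should no longer show up among the exported graph's initializers:

import torch
import onnx

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.used = torch.nn.Linear(4, 4)
        # Never read in forward(), so it would previously linger as an
        # unused initializer/input in the exported graph.
        self.unused = torch.nn.Parameter(torch.ones(8))

    def forward(self, x):
        return self.used(x)

torch.onnx.export(M(), torch.randn(2, 4), "m.onnx", do_constant_folding=False)
m = onnx.load("m.onnx")
# Expect only the Linear's weight and bias to survive as initializers.
print([init.name for init in m.graph.initializer])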

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42743

Reviewed By: hl475

Differential Revision: D23058954

Pulled By: houseroad

fbshipit-source-id: cd1e81463285a0bf4e60766c8c87fc9a350d9c7e
Author: Ksenija Stanojevic
Date: 2020-08-11 20:29:12 -07:00
Committed by: Facebook GitHub Bot
Parent: a846ed5ce7
Commit: e845b0ab51
12 changed files with 135 additions and 58 deletions

test/onnx/debug_embed_params.py

@@ -41,12 +41,7 @@ def run_embed_params(proto, model, input, state_dict=None, use_gpu=True):
         # sure our order is consistent with the model's order.
         # TODO: Even better: keyword arguments!
         for k in model.state_dict():
-            if k not in state_dict:
-                # Once PyTorch Module adds unnecessary parameter, the old pre-trained model does not have it.
-                # Just simply pass the new one.
-                # TODO: Please don't export unnecessary parameter.
-                parameters.append(model.state_dict()[k])
-            else:
+            if k in state_dict:
                 parameters.append(state_dict[k])
     else:
         parameters = list(model.state_dict().values())

test/onnx/test_utility_funs.py

@@ -7,6 +7,8 @@ from torch.onnx import utils, OperatorExportTypes, TrainingMode
 from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
 import torch.utils.cpp_extension
 from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
+import caffe2.python.onnx.backend as backend
+from verify import verify
 import torchvision
@@ -698,6 +700,39 @@ class TestUtilityFuns(TestCase):
         np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)

+    def test_unused_initializers(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1))
+                self.k_proj = torch.nn.Linear(5, 5, bias=True)
+
+            def forward(self, x):
+                x = self.conv2(x)
+                return x
+
+        x = torch.randn(20, 16, 50, 100)
+        _set_opset_version(self.opset_version)
+        _set_operator_export_type(OperatorExportTypes.ONNX)
+        _, params_dict, __ = utils._model_to_graph(Model(), (x, ), do_constant_folding=False,
+                                                   operator_export_type=OperatorExportTypes.ONNX)
+
+        assert len(params_dict) == 2
+
+    def test_modifying_params(self):
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super(MyModel, self).__init__()
+                self.param = torch.nn.Parameter(torch.tensor([2.0]))
+
+            def forward(self, x):
+                y = x * x
+                self.param.data.add_(1.0)
+                return y
+
+        x = torch.tensor([1, 2])
+        verify(MyModel(), x, backend, do_constant_folding=False)
+
     def test_fuse_conv_bn(self):
         class Fuse(torch.nn.Module):
             def __init__(self):
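
In test_unused_initializers above, the assertion len(params_dict) == 2 holds because only conv2's weight and bias are consumed by the graph; k_proj is never touched in forward, so its weight and bias are stripped by the new pass. Likewise, test_modifying_params now calls verify() directly and expects success, whereas the test_verify version removed below expected failure.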

test/onnx/test_verify.py

@@ -64,21 +64,6 @@ class TestVerify(TestCase):
         with self.assertRaisesRegex(RuntimeError, "state_dict changed"):
             verify(MyModel(), x, backend)

-    def test_modifying_params(self):
-        class MyModel(Module):
-            def __init__(self):
-                super(MyModel, self).__init__()
-                self.param = Parameter(torch.tensor([2.0]))
-
-            def forward(self, x):
-                y = x * x
-                self.param.data.add_(1.0)
-                return y
-
-        x = torch.tensor([1, 2])
-        # To keep the unused model parameter, need to set constant folding to False
-        self.assertVerifyExpectFail(MyModel(), x, backend, do_constant_folding=False)
-
     def test_dynamic_model_structure(self):
         class MyModel(Module):
             def __init__(self):

tools/build_variables.bzl

@@ -482,6 +482,7 @@ libtorch_python_core_sources = [
     "torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp",
     "torch/csrc/jit/passes/onnx/eval_peephole.cpp",
     "torch/csrc/jit/passes/onnx/constant_fold.cpp",
+    "torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp",
     "torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp",
     "torch/csrc/jit/passes/onnx/function_substitution.cpp",
     "torch/csrc/jit/passes/onnx/helper.cpp",

torch/csrc/jit/passes/onnx/constant_fold.cpp

@@ -14,40 +14,35 @@ using namespace ::c10::onnx;
 namespace {

+enum OnnxType : int {
+  ONNX_FLOAT = 1,
+  ONNX_UINT8,
+  ONNX_INT8,
+  ONNX_UINT16,
+  ONNX_INT16,
+  ONNX_INT32,
+  ONNX_INT64,
+  ONNX_FLOAT16 = 10,
+  ONNX_DOUBLE,
+  ONNX_UINT32,
+};
+
 std::unordered_map<int, at::ScalarType> onnxTypeToScalarTypeMap = {
     // Only conversion of ONNX numeric types is included here.
     // Unsigned ONNX types are mapped to the next higher signed
     // ScalarType type.
-    {1, at::kFloat},
-    {2, at::kByte},
-    {3, at::kChar},
-    {4, at::kInt},
-    {5, at::kShort},
-    {6, at::kInt},
-    {7, at::kLong},
-    {10, at::kFloat},
-    {11, at::kDouble},
-    {12, at::kLong},
+    {ONNX_FLOAT, at::kFloat},
+    {ONNX_UINT8, at::kByte},
+    {ONNX_INT8, at::kChar},
+    {ONNX_UINT16, at::kInt},
+    {ONNX_INT16, at::kShort},
+    {ONNX_INT32, at::kInt},
+    {ONNX_INT64, at::kLong},
+    {ONNX_FLOAT16, at::kFloat},
+    {ONNX_DOUBLE, at::kDouble},
+    {ONNX_UINT32, at::kLong},
 };

-void buildParamsMapFromValueToParamsMap(
-    const ValueToParamPairMap& valsToParamsMap,
-    ParamMap& paramsDict) {
-  paramsDict.clear();
-  for (const auto& nameTensorParamPair : valsToParamsMap) {
-    paramsDict.insert(nameTensorParamPair.second);
-  }
-}
-
-void eraseUnusedBlockInputs(Block* b) {
-  for (size_t i_1 = b->inputs().size(); i_1 > 0; --i_1) {
-    size_t i = i_1 - 1;
-    if (!b->inputs().at(i)->hasUses()) {
-      b->eraseInput(i);
-    }
-  }
-}
-
 void handleNegativeStartEndIndex(
     int64_t& start,
     int64_t& end,
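
The OnnxType enum above mirrors ONNX's TensorProto.DataType codes, which is why ONNX_FLOAT16 jumps to 10: codes 8 and 9 are STRING and BOOL, which this numeric-only map deliberately skips. A quick sanity check against the onnx Python package:

from onnx import TensorProto

# ONNX data type codes that the C++ enum mirrors.
assert TensorProto.FLOAT == 1
assert TensorProto.INT64 == 7
assert TensorProto.STRING == 8   # non-numeric, skipped by the map
assert TensorProto.BOOL == 9     # non-numeric, skipped by the map
assert TensorProto.FLOAT16 == 10
assert TensorProto.UINT32 == 12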

torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp

@@ -0,0 +1,23 @@
+#include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
+#include <torch/csrc/jit/passes/onnx/helper.h>
+
+#include <c10/util/Optional.h>
+#include <algorithm>
+
+namespace torch {
+namespace jit {
+
+namespace onnx {
+using namespace ::c10::onnx;
+}
+
+void EliminateUnusedItemsONNX(Block* b, ParamMap& paramsDict) {
+  auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
+  eraseUnusedValuesFromMap(valsToParamsMap);
+  eraseUnusedBlockInputs(b);
+  buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
+  return;
+}
+
+} // namespace jit
+} // namespace torch
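
The pass body is a straight composition of the helpers factored into helper.cpp below: build the value-to-param map, drop entries whose values have no remaining uses, erase the now-unused block inputs, then rebuild the flat params dict. A rough Python analogue of the net effect on the params dict (illustrative only; the function and its arguments are hypothetical, not PyTorch APIs):

def eliminate_unused_items(input_use_counts, params_dict):
    # input_use_counts: {block input name: number of uses in the graph}
    used = {name for name, n in input_use_counts.items() if n > 0}
    # Mirrors eraseUnusedValuesFromMap followed by buildParamsMapFromValueToParamsMap.
    return {name: value for name, value in params_dict.items() if name in used}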

torch/csrc/jit/passes/onnx/eliminate_unused_items.h

@@ -0,0 +1,17 @@
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
+// EliminateUnusedItemsONNX pass is removing unused
+// initializers and inputs, this is needed because
+// dce pass is only removing unused fork inputs
+void EliminateUnusedItemsONNX(
+    Block* b,
+    std::map<std::string, IValue>& paramDict);
+
+} // namespace jit
+} // namespace torch

torch/csrc/jit/passes/onnx/eval_peephole.cpp

@@ -135,15 +135,6 @@ static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) {
   }
 }

-void buildParamsMapFromValueToParamsMap(
-    const ValueToParamPairMap& valsToParamsMap,
-    ParamMap& paramsDict) {
-  paramsDict.clear();
-  for (const auto& nameTensorParamPair : valsToParamsMap) {
-    paramsDict.insert(nameTensorParamPair.second);
-  }
-}
-
 void EvalPeepholeONNX(Block* b, ParamMap& paramsDict) {
   auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
   fuseConvBatchNorm(b, valsToParamsMap);

torch/csrc/jit/passes/onnx/helper.cpp

@@ -21,6 +21,15 @@ ValueToParamPairMap buildValueToParamsMap(
   return valsToParamsMap;
 }

+void eraseUnusedBlockInputs(Block* b) {
+  for (size_t i_1 = b->inputs().size(); i_1 > 0; --i_1) {
+    size_t i = i_1 - 1;
+    if (!b->inputs().at(i)->hasUses()) {
+      b->eraseInput(i);
+    }
+  }
+}
+
 void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap) {
   auto it = valsToParamsMap.begin();
   while (it != valsToParamsMap.end()) {
@@ -31,5 +40,14 @@ void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap) {
     }
   }
 }

+void buildParamsMapFromValueToParamsMap(
+    const ValueToParamPairMap& valsToParamsMap,
+    ParamMap& paramsDict) {
+  paramsDict.clear();
+  for (const auto& nameTensorParamPair : valsToParamsMap) {
+    paramsDict.insert(nameTensorParamPair.second);
+  }
+}
+
 } // namespace jit
 } // namespace torch
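
A note on the loop shape in eraseUnusedBlockInputs: size_t is unsigned, so a naive for (size_t i = size - 1; i >= 0; --i) would wrap around instead of terminating; counting i_1 down from size and deriving i = i_1 - 1 avoids that. Walking the inputs back to front also keeps the remaining indices valid as entries are erased. The same idiom in Python, for comparison:

# Deleting from the end keeps earlier indices stable, the same reason
# eraseUnusedBlockInputs walks the block inputs in reverse.
items = ["a", "b", "c", "d"]
for i in range(len(items) - 1, -1, -1):
    if items[i] in ("b", "d"):  # stand-in for "input has no uses"
        del items[i]
assert items == ["a", "c"]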

torch/csrc/jit/passes/onnx/helper.h

@@ -22,6 +22,10 @@ using ParamMap = std::map<std::string, IValue>;
 ValueToParamPairMap buildValueToParamsMap(Block* b, const ParamMap& paramsDict);
 void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap);
+void eraseUnusedBlockInputs(Block* b);
+void buildParamsMapFromValueToParamsMap(
+    const ValueToParamPairMap& valsToParamsMap,
+    ParamMap& paramsDict);

 } // namespace jit
 } // namespace torch

torch/csrc/jit/python/init.cpp

@@ -32,6 +32,7 @@
 #include <torch/csrc/jit/passes/onnx.h>
 #include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
 #include <torch/csrc/jit/passes/onnx/constant_fold.h>
+#include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
 #include <torch/csrc/jit/passes/onnx/eval_peephole.h>
 #include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
 #include <torch/csrc/jit/passes/onnx/function_substitution.h>
@@ -168,6 +169,16 @@ void initJITBindings(PyObject* module) {
             return paramsDict;
           },
           pybind11::return_value_policy::move)
+      .def(
+          "_jit_pass_onnx_eliminate_unused_items",
+          [](std::shared_ptr<Graph>& graph,
+             std::map<std::string, IValue>& paramsDict) {
+            EliminateUnusedItemsONNX(
+                graph->block(),
+                paramsDict); // overload resolution
+            return paramsDict;
+          },
+          pybind11::return_value_policy::move)
       .def("_jit_pass_onnx_scalar_type_analysis", ScalarTypeAnalysisForONNX)
       .def(
           "_jit_pass_onnx_prepare_inplace_ops_for_onnx",

torch/onnx/utils.py

@@ -424,6 +424,8 @@ def _model_to_graph(model, args, verbose=False,
                                                 _export_onnx_opset_version)
     torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)

+    params_dict = torch._C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
+
     # For ONNX opset < 9, constants only have three data types: float16, float, double.
     # In this pass transform constants of other data types to float/double + cast operator.
     if _export_onnx_opset_version < 9:
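
Note the ordering in _model_to_graph: dead-code elimination runs first, which can leave parameters whose consuming nodes were just deleted dangling as unused graph inputs; the new pass then strips both those inputs and their entries in params_dict before the graph is handed to the later export stages.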