mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Resending] [ONNX] Add eliminate_unused_items pass (#42743)
Summary: This PR: - Adds eliminate_unused_items pass that removes unused inputs and initializers. - Fixes run_embed_params function so it doesn't export unnecessary parameters. - Removes test_modifying_params in test_verify since it's no longer needed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/42743 Reviewed By: hl475 Differential Revision: D23058954 Pulled By: houseroad fbshipit-source-id: cd1e81463285a0bf4e60766c8c87fc9a350d9c7e
This commit is contained in:
parent
a846ed5ce7
commit
e845b0ab51
|
|
@ -41,12 +41,7 @@ def run_embed_params(proto, model, input, state_dict=None, use_gpu=True):
|
||||||
# sure our order is consistent with the model's order.
|
# sure our order is consistent with the model's order.
|
||||||
# TODO: Even better: keyword arguments!
|
# TODO: Even better: keyword arguments!
|
||||||
for k in model.state_dict():
|
for k in model.state_dict():
|
||||||
if k not in state_dict:
|
if k in state_dict:
|
||||||
# Once PyTorch Module adds unnecessary parameter, the old pre-trained model does not have it.
|
|
||||||
# Just simply pass the new one.
|
|
||||||
# TODO: Please don't export unnecessary parameter.
|
|
||||||
parameters.append(model.state_dict()[k])
|
|
||||||
else:
|
|
||||||
parameters.append(state_dict[k])
|
parameters.append(state_dict[k])
|
||||||
else:
|
else:
|
||||||
parameters = list(model.state_dict().values())
|
parameters = list(model.state_dict().values())
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ from torch.onnx import utils, OperatorExportTypes, TrainingMode
|
||||||
from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
|
from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
|
||||||
import torch.utils.cpp_extension
|
import torch.utils.cpp_extension
|
||||||
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
|
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
|
||||||
|
import caffe2.python.onnx.backend as backend
|
||||||
|
from verify import verify
|
||||||
|
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
|
|
@ -698,6 +700,39 @@ class TestUtilityFuns(TestCase):
|
||||||
|
|
||||||
np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)
|
np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)
|
||||||
|
|
||||||
|
def test_unused_initializers(self):
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1))
|
||||||
|
self.k_proj = torch.nn.Linear(5, 5, bias=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
x = torch.randn(20, 16, 50, 100)
|
||||||
|
_set_opset_version(self.opset_version)
|
||||||
|
_set_operator_export_type(OperatorExportTypes.ONNX)
|
||||||
|
_, params_dict, __ = utils._model_to_graph(Model(), (x, ), do_constant_folding=False,
|
||||||
|
operator_export_type=OperatorExportTypes.ONNX)
|
||||||
|
|
||||||
|
assert len(params_dict) == 2
|
||||||
|
|
||||||
|
def test_modifying_params(self):
|
||||||
|
class MyModel(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MyModel, self).__init__()
|
||||||
|
self.param = torch.nn.Parameter(torch.tensor([2.0]))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = x * x
|
||||||
|
self.param.data.add_(1.0)
|
||||||
|
return y
|
||||||
|
|
||||||
|
x = torch.tensor([1, 2])
|
||||||
|
verify(MyModel(), x, backend, do_constant_folding=False)
|
||||||
|
|
||||||
def test_fuse_conv_bn(self):
|
def test_fuse_conv_bn(self):
|
||||||
class Fuse(torch.nn.Module):
|
class Fuse(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
||||||
|
|
@ -64,21 +64,6 @@ class TestVerify(TestCase):
|
||||||
with self.assertRaisesRegex(RuntimeError, "state_dict changed"):
|
with self.assertRaisesRegex(RuntimeError, "state_dict changed"):
|
||||||
verify(MyModel(), x, backend)
|
verify(MyModel(), x, backend)
|
||||||
|
|
||||||
def test_modifying_params(self):
|
|
||||||
class MyModel(Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(MyModel, self).__init__()
|
|
||||||
self.param = Parameter(torch.tensor([2.0]))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
y = x * x
|
|
||||||
self.param.data.add_(1.0)
|
|
||||||
return y
|
|
||||||
|
|
||||||
x = torch.tensor([1, 2])
|
|
||||||
# To keep the unused model parameter, need to set constant folding to False
|
|
||||||
self.assertVerifyExpectFail(MyModel(), x, backend, do_constant_folding=False)
|
|
||||||
|
|
||||||
def test_dynamic_model_structure(self):
|
def test_dynamic_model_structure(self):
|
||||||
class MyModel(Module):
|
class MyModel(Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
||||||
|
|
@ -482,6 +482,7 @@ libtorch_python_core_sources = [
|
||||||
"torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp",
|
"torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp",
|
||||||
"torch/csrc/jit/passes/onnx/eval_peephole.cpp",
|
"torch/csrc/jit/passes/onnx/eval_peephole.cpp",
|
||||||
"torch/csrc/jit/passes/onnx/constant_fold.cpp",
|
"torch/csrc/jit/passes/onnx/constant_fold.cpp",
|
||||||
|
"torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp",
|
||||||
"torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp",
|
"torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp",
|
||||||
"torch/csrc/jit/passes/onnx/function_substitution.cpp",
|
"torch/csrc/jit/passes/onnx/function_substitution.cpp",
|
||||||
"torch/csrc/jit/passes/onnx/helper.cpp",
|
"torch/csrc/jit/passes/onnx/helper.cpp",
|
||||||
|
|
|
||||||
|
|
@ -14,40 +14,35 @@ using namespace ::c10::onnx;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
enum OnnxType : int {
|
||||||
|
ONNX_FLOAT = 1,
|
||||||
|
ONNX_UINT8,
|
||||||
|
ONNX_INT8,
|
||||||
|
ONNX_UINT16,
|
||||||
|
ONNX_INT16,
|
||||||
|
ONNX_INT32,
|
||||||
|
ONNX_INT64,
|
||||||
|
ONNX_FLOAT16 = 10,
|
||||||
|
ONNX_DOUBLE,
|
||||||
|
ONNX_UINT32,
|
||||||
|
};
|
||||||
|
|
||||||
std::unordered_map<int, at::ScalarType> onnxTypeToScalarTypeMap = {
|
std::unordered_map<int, at::ScalarType> onnxTypeToScalarTypeMap = {
|
||||||
// Only conversion of ONNX numeric types is included here.
|
// Only conversion of ONNX numeric types is included here.
|
||||||
// Unsigned ONNX types are mapped to the next higher signed
|
// Unsigned ONNX types are mapped to the next higher signed
|
||||||
// ScalarType type.
|
// ScalarType type.
|
||||||
{1, at::kFloat},
|
{ONNX_FLOAT, at::kFloat},
|
||||||
{2, at::kByte},
|
{ONNX_UINT8, at::kByte},
|
||||||
{3, at::kChar},
|
{ONNX_INT8, at::kChar},
|
||||||
{4, at::kInt},
|
{ONNX_UINT16, at::kInt},
|
||||||
{5, at::kShort},
|
{ONNX_INT16, at::kShort},
|
||||||
{6, at::kInt},
|
{ONNX_INT32, at::kInt},
|
||||||
{7, at::kLong},
|
{ONNX_INT64, at::kLong},
|
||||||
{10, at::kFloat},
|
{ONNX_FLOAT16, at::kFloat},
|
||||||
{11, at::kDouble},
|
{ONNX_DOUBLE, at::kDouble},
|
||||||
{12, at::kLong},
|
{ONNX_UINT32, at::kLong},
|
||||||
};
|
};
|
||||||
|
|
||||||
void buildParamsMapFromValueToParamsMap(
|
|
||||||
const ValueToParamPairMap& valsToParamsMap,
|
|
||||||
ParamMap& paramsDict) {
|
|
||||||
paramsDict.clear();
|
|
||||||
for (const auto& nameTensorParamPair : valsToParamsMap) {
|
|
||||||
paramsDict.insert(nameTensorParamPair.second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void eraseUnusedBlockInputs(Block* b) {
|
|
||||||
for (size_t i_1 = b->inputs().size(); i_1 > 0; --i_1) {
|
|
||||||
size_t i = i_1 - 1;
|
|
||||||
if (!b->inputs().at(i)->hasUses()) {
|
|
||||||
b->eraseInput(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void handleNegativeStartEndIndex(
|
void handleNegativeStartEndIndex(
|
||||||
int64_t& start,
|
int64_t& start,
|
||||||
int64_t& end,
|
int64_t& end,
|
||||||
|
|
|
||||||
23
torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp
Normal file
23
torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp
Normal file
|
|
@ -0,0 +1,23 @@
|
||||||
|
#include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
|
||||||
|
#include <torch/csrc/jit/passes/onnx/helper.h>
|
||||||
|
|
||||||
|
#include <c10/util/Optional.h>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
namespace onnx {
|
||||||
|
using namespace ::c10::onnx;
|
||||||
|
}
|
||||||
|
|
||||||
|
void EliminateUnusedItemsONNX(Block* b, ParamMap& paramsDict) {
|
||||||
|
auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
|
||||||
|
eraseUnusedValuesFromMap(valsToParamsMap);
|
||||||
|
eraseUnusedBlockInputs(b);
|
||||||
|
buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
17
torch/csrc/jit/passes/onnx/eliminate_unused_items.h
Normal file
17
torch/csrc/jit/passes/onnx/eliminate_unused_items.h
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/csrc/jit/ir/ir.h>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
|
||||||
|
// EliminateUnusedItemsONNX pass is removing unused
|
||||||
|
// initializers and inputs, this is needed because
|
||||||
|
// dce pass is only removing unused fork inputs
|
||||||
|
void EliminateUnusedItemsONNX(
|
||||||
|
Block* b,
|
||||||
|
std::map<std::string, IValue>& paramDict);
|
||||||
|
|
||||||
|
} // namespace jit
|
||||||
|
|
||||||
|
} // namespace torch
|
||||||
|
|
@ -135,15 +135,6 @@ static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void buildParamsMapFromValueToParamsMap(
|
|
||||||
const ValueToParamPairMap& valsToParamsMap,
|
|
||||||
ParamMap& paramsDict) {
|
|
||||||
paramsDict.clear();
|
|
||||||
for (const auto& nameTensorParamPair : valsToParamsMap) {
|
|
||||||
paramsDict.insert(nameTensorParamPair.second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void EvalPeepholeONNX(Block* b, ParamMap& paramsDict) {
|
void EvalPeepholeONNX(Block* b, ParamMap& paramsDict) {
|
||||||
auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
|
auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
|
||||||
fuseConvBatchNorm(b, valsToParamsMap);
|
fuseConvBatchNorm(b, valsToParamsMap);
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,15 @@ ValueToParamPairMap buildValueToParamsMap(
|
||||||
return valsToParamsMap;
|
return valsToParamsMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void eraseUnusedBlockInputs(Block* b) {
|
||||||
|
for (size_t i_1 = b->inputs().size(); i_1 > 0; --i_1) {
|
||||||
|
size_t i = i_1 - 1;
|
||||||
|
if (!b->inputs().at(i)->hasUses()) {
|
||||||
|
b->eraseInput(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap) {
|
void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap) {
|
||||||
auto it = valsToParamsMap.begin();
|
auto it = valsToParamsMap.begin();
|
||||||
while (it != valsToParamsMap.end()) {
|
while (it != valsToParamsMap.end()) {
|
||||||
|
|
@ -31,5 +40,14 @@ void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void buildParamsMapFromValueToParamsMap(
|
||||||
|
const ValueToParamPairMap& valsToParamsMap,
|
||||||
|
ParamMap& paramsDict) {
|
||||||
|
paramsDict.clear();
|
||||||
|
for (const auto& nameTensorParamPair : valsToParamsMap) {
|
||||||
|
paramsDict.insert(nameTensorParamPair.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,10 @@ using ParamMap = std::map<std::string, IValue>;
|
||||||
|
|
||||||
ValueToParamPairMap buildValueToParamsMap(Block* b, const ParamMap& paramsDict);
|
ValueToParamPairMap buildValueToParamsMap(Block* b, const ParamMap& paramsDict);
|
||||||
void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap);
|
void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap);
|
||||||
|
void eraseUnusedBlockInputs(Block* b);
|
||||||
|
void buildParamsMapFromValueToParamsMap(
|
||||||
|
const ValueToParamPairMap& valsToParamsMap,
|
||||||
|
ParamMap& paramsDict);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@
|
||||||
#include <torch/csrc/jit/passes/onnx.h>
|
#include <torch/csrc/jit/passes/onnx.h>
|
||||||
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
|
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
|
||||||
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
|
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
|
||||||
|
#include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
|
||||||
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
|
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
|
||||||
#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
|
#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
|
||||||
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
|
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
|
||||||
|
|
@ -168,6 +169,16 @@ void initJITBindings(PyObject* module) {
|
||||||
return paramsDict;
|
return paramsDict;
|
||||||
},
|
},
|
||||||
pybind11::return_value_policy::move)
|
pybind11::return_value_policy::move)
|
||||||
|
.def(
|
||||||
|
"_jit_pass_onnx_eliminate_unused_items",
|
||||||
|
[](std::shared_ptr<Graph>& graph,
|
||||||
|
std::map<std::string, IValue>& paramsDict) {
|
||||||
|
EliminateUnusedItemsONNX(
|
||||||
|
graph->block(),
|
||||||
|
paramsDict); // overload resolution
|
||||||
|
return paramsDict;
|
||||||
|
},
|
||||||
|
pybind11::return_value_policy::move)
|
||||||
.def("_jit_pass_onnx_scalar_type_analysis", ScalarTypeAnalysisForONNX)
|
.def("_jit_pass_onnx_scalar_type_analysis", ScalarTypeAnalysisForONNX)
|
||||||
.def(
|
.def(
|
||||||
"_jit_pass_onnx_prepare_inplace_ops_for_onnx",
|
"_jit_pass_onnx_prepare_inplace_ops_for_onnx",
|
||||||
|
|
|
||||||
|
|
@ -424,6 +424,8 @@ def _model_to_graph(model, args, verbose=False,
|
||||||
_export_onnx_opset_version)
|
_export_onnx_opset_version)
|
||||||
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
|
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
|
||||||
|
|
||||||
|
params_dict = torch._C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
|
||||||
|
|
||||||
# For ONNX opset < 9, constants only have three data types: float16, float, double.
|
# For ONNX opset < 9, constants only have three data types: float16, float, double.
|
||||||
# In this pass transform constants of other data types to float/double + cast operator.
|
# In this pass transform constants of other data types to float/double + cast operator.
|
||||||
if _export_onnx_opset_version < 9:
|
if _export_onnx_opset_version < 9:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user