mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Resending] [ONNX] Add eliminate_unused_items pass (#42743)
Summary: This PR: - Adds eliminate_unused_items pass that removes unused inputs and initializers. - Fixes run_embed_params function so it doesn't export unnecessary parameters. - Removes test_modifying_params in test_verify since it's no longer needed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/42743 Reviewed By: hl475 Differential Revision: D23058954 Pulled By: houseroad fbshipit-source-id: cd1e81463285a0bf4e60766c8c87fc9a350d9c7e
This commit is contained in:
parent
a846ed5ce7
commit
e845b0ab51
|
|
@ -41,12 +41,7 @@ def run_embed_params(proto, model, input, state_dict=None, use_gpu=True):
|
|||
# sure our order is consistent with the model's order.
|
||||
# TODO: Even better: keyword arguments!
|
||||
for k in model.state_dict():
|
||||
if k not in state_dict:
|
||||
# Once PyTorch Module adds unnecessary parameter, the old pre-trained model does not have it.
|
||||
# Just simply pass the new one.
|
||||
# TODO: Please don't export unnecessary parameter.
|
||||
parameters.append(model.state_dict()[k])
|
||||
else:
|
||||
if k in state_dict:
|
||||
parameters.append(state_dict[k])
|
||||
else:
|
||||
parameters = list(model.state_dict().values())
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from torch.onnx import utils, OperatorExportTypes, TrainingMode
|
|||
from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
|
||||
import torch.utils.cpp_extension
|
||||
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
|
||||
import caffe2.python.onnx.backend as backend
|
||||
from verify import verify
|
||||
|
||||
import torchvision
|
||||
|
||||
|
|
@ -698,6 +700,39 @@ class TestUtilityFuns(TestCase):
|
|||
|
||||
np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)
|
||||
|
||||
def test_unused_initializers(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1))
|
||||
self.k_proj = torch.nn.Linear(5, 5, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
x = torch.randn(20, 16, 50, 100)
|
||||
_set_opset_version(self.opset_version)
|
||||
_set_operator_export_type(OperatorExportTypes.ONNX)
|
||||
_, params_dict, __ = utils._model_to_graph(Model(), (x, ), do_constant_folding=False,
|
||||
operator_export_type=OperatorExportTypes.ONNX)
|
||||
|
||||
assert len(params_dict) == 2
|
||||
|
||||
def test_modifying_params(self):
|
||||
class MyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
self.param = torch.nn.Parameter(torch.tensor([2.0]))
|
||||
|
||||
def forward(self, x):
|
||||
y = x * x
|
||||
self.param.data.add_(1.0)
|
||||
return y
|
||||
|
||||
x = torch.tensor([1, 2])
|
||||
verify(MyModel(), x, backend, do_constant_folding=False)
|
||||
|
||||
def test_fuse_conv_bn(self):
|
||||
class Fuse(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -64,21 +64,6 @@ class TestVerify(TestCase):
|
|||
with self.assertRaisesRegex(RuntimeError, "state_dict changed"):
|
||||
verify(MyModel(), x, backend)
|
||||
|
||||
def test_modifying_params(self):
|
||||
class MyModel(Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
self.param = Parameter(torch.tensor([2.0]))
|
||||
|
||||
def forward(self, x):
|
||||
y = x * x
|
||||
self.param.data.add_(1.0)
|
||||
return y
|
||||
|
||||
x = torch.tensor([1, 2])
|
||||
# To keep the unused model parameter, need to set constant folding to False
|
||||
self.assertVerifyExpectFail(MyModel(), x, backend, do_constant_folding=False)
|
||||
|
||||
def test_dynamic_model_structure(self):
|
||||
class MyModel(Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -482,6 +482,7 @@ libtorch_python_core_sources = [
|
|||
"torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp",
|
||||
"torch/csrc/jit/passes/onnx/eval_peephole.cpp",
|
||||
"torch/csrc/jit/passes/onnx/constant_fold.cpp",
|
||||
"torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp",
|
||||
"torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp",
|
||||
"torch/csrc/jit/passes/onnx/function_substitution.cpp",
|
||||
"torch/csrc/jit/passes/onnx/helper.cpp",
|
||||
|
|
|
|||
|
|
@ -14,40 +14,35 @@ using namespace ::c10::onnx;
|
|||
|
||||
namespace {
|
||||
|
||||
enum OnnxType : int {
|
||||
ONNX_FLOAT = 1,
|
||||
ONNX_UINT8,
|
||||
ONNX_INT8,
|
||||
ONNX_UINT16,
|
||||
ONNX_INT16,
|
||||
ONNX_INT32,
|
||||
ONNX_INT64,
|
||||
ONNX_FLOAT16 = 10,
|
||||
ONNX_DOUBLE,
|
||||
ONNX_UINT32,
|
||||
};
|
||||
|
||||
std::unordered_map<int, at::ScalarType> onnxTypeToScalarTypeMap = {
|
||||
// Only conversion of ONNX numeric types is included here.
|
||||
// Unsigned ONNX types are mapped to the next higher signed
|
||||
// ScalarType type.
|
||||
{1, at::kFloat},
|
||||
{2, at::kByte},
|
||||
{3, at::kChar},
|
||||
{4, at::kInt},
|
||||
{5, at::kShort},
|
||||
{6, at::kInt},
|
||||
{7, at::kLong},
|
||||
{10, at::kFloat},
|
||||
{11, at::kDouble},
|
||||
{12, at::kLong},
|
||||
{ONNX_FLOAT, at::kFloat},
|
||||
{ONNX_UINT8, at::kByte},
|
||||
{ONNX_INT8, at::kChar},
|
||||
{ONNX_UINT16, at::kInt},
|
||||
{ONNX_INT16, at::kShort},
|
||||
{ONNX_INT32, at::kInt},
|
||||
{ONNX_INT64, at::kLong},
|
||||
{ONNX_FLOAT16, at::kFloat},
|
||||
{ONNX_DOUBLE, at::kDouble},
|
||||
{ONNX_UINT32, at::kLong},
|
||||
};
|
||||
|
||||
void buildParamsMapFromValueToParamsMap(
|
||||
const ValueToParamPairMap& valsToParamsMap,
|
||||
ParamMap& paramsDict) {
|
||||
paramsDict.clear();
|
||||
for (const auto& nameTensorParamPair : valsToParamsMap) {
|
||||
paramsDict.insert(nameTensorParamPair.second);
|
||||
}
|
||||
}
|
||||
|
||||
void eraseUnusedBlockInputs(Block* b) {
|
||||
for (size_t i_1 = b->inputs().size(); i_1 > 0; --i_1) {
|
||||
size_t i = i_1 - 1;
|
||||
if (!b->inputs().at(i)->hasUses()) {
|
||||
b->eraseInput(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void handleNegativeStartEndIndex(
|
||||
int64_t& start,
|
||||
int64_t& end,
|
||||
|
|
|
|||
23
torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp
Normal file
23
torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
#include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
|
||||
#include <torch/csrc/jit/passes/onnx/helper.h>
|
||||
|
||||
#include <c10/util/Optional.h>
|
||||
#include <algorithm>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace onnx {
|
||||
using namespace ::c10::onnx;
|
||||
}
|
||||
|
||||
void EliminateUnusedItemsONNX(Block* b, ParamMap& paramsDict) {
|
||||
auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
|
||||
eraseUnusedValuesFromMap(valsToParamsMap);
|
||||
eraseUnusedBlockInputs(b);
|
||||
buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
17
torch/csrc/jit/passes/onnx/eliminate_unused_items.h
Normal file
17
torch/csrc/jit/passes/onnx/eliminate_unused_items.h
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// EliminateUnusedItemsONNX pass is removing unused
|
||||
// initializers and inputs, this is needed because
|
||||
// dce pass is only removing unused fork inputs
|
||||
void EliminateUnusedItemsONNX(
|
||||
Block* b,
|
||||
std::map<std::string, IValue>& paramDict);
|
||||
|
||||
} // namespace jit
|
||||
|
||||
} // namespace torch
|
||||
|
|
@ -135,15 +135,6 @@ static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) {
|
|||
}
|
||||
}
|
||||
|
||||
void buildParamsMapFromValueToParamsMap(
|
||||
const ValueToParamPairMap& valsToParamsMap,
|
||||
ParamMap& paramsDict) {
|
||||
paramsDict.clear();
|
||||
for (const auto& nameTensorParamPair : valsToParamsMap) {
|
||||
paramsDict.insert(nameTensorParamPair.second);
|
||||
}
|
||||
}
|
||||
|
||||
void EvalPeepholeONNX(Block* b, ParamMap& paramsDict) {
|
||||
auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
|
||||
fuseConvBatchNorm(b, valsToParamsMap);
|
||||
|
|
|
|||
|
|
@ -21,6 +21,15 @@ ValueToParamPairMap buildValueToParamsMap(
|
|||
return valsToParamsMap;
|
||||
}
|
||||
|
||||
void eraseUnusedBlockInputs(Block* b) {
|
||||
for (size_t i_1 = b->inputs().size(); i_1 > 0; --i_1) {
|
||||
size_t i = i_1 - 1;
|
||||
if (!b->inputs().at(i)->hasUses()) {
|
||||
b->eraseInput(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap) {
|
||||
auto it = valsToParamsMap.begin();
|
||||
while (it != valsToParamsMap.end()) {
|
||||
|
|
@ -31,5 +40,14 @@ void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void buildParamsMapFromValueToParamsMap(
|
||||
const ValueToParamPairMap& valsToParamsMap,
|
||||
ParamMap& paramsDict) {
|
||||
paramsDict.clear();
|
||||
for (const auto& nameTensorParamPair : valsToParamsMap) {
|
||||
paramsDict.insert(nameTensorParamPair.second);
|
||||
}
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -22,6 +22,10 @@ using ParamMap = std::map<std::string, IValue>;
|
|||
|
||||
ValueToParamPairMap buildValueToParamsMap(Block* b, const ParamMap& paramsDict);
|
||||
void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap);
|
||||
void eraseUnusedBlockInputs(Block* b);
|
||||
void buildParamsMapFromValueToParamsMap(
|
||||
const ValueToParamPairMap& valsToParamsMap,
|
||||
ParamMap& paramsDict);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@
|
|||
#include <torch/csrc/jit/passes/onnx.h>
|
||||
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
|
||||
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
|
||||
#include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
|
||||
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
|
||||
#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
|
||||
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
|
||||
|
|
@ -168,6 +169,16 @@ void initJITBindings(PyObject* module) {
|
|||
return paramsDict;
|
||||
},
|
||||
pybind11::return_value_policy::move)
|
||||
.def(
|
||||
"_jit_pass_onnx_eliminate_unused_items",
|
||||
[](std::shared_ptr<Graph>& graph,
|
||||
std::map<std::string, IValue>& paramsDict) {
|
||||
EliminateUnusedItemsONNX(
|
||||
graph->block(),
|
||||
paramsDict); // overload resolution
|
||||
return paramsDict;
|
||||
},
|
||||
pybind11::return_value_policy::move)
|
||||
.def("_jit_pass_onnx_scalar_type_analysis", ScalarTypeAnalysisForONNX)
|
||||
.def(
|
||||
"_jit_pass_onnx_prepare_inplace_ops_for_onnx",
|
||||
|
|
|
|||
|
|
@ -424,6 +424,8 @@ def _model_to_graph(model, args, verbose=False,
|
|||
_export_onnx_opset_version)
|
||||
torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
|
||||
|
||||
params_dict = torch._C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
|
||||
|
||||
# For ONNX opset < 9, constants only have three data types: float16, float, double.
|
||||
# In this pass transform constants of other data types to float/double + cast operator.
|
||||
if _export_onnx_opset_version < 9:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user