mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Allow ONNX models without parameters (#121904)
Currently, if initializers are available, they are included in the ONNX model. If they are not available, the model is serialized without them. However, there are times in which the initializers are avaialable, but the user prefers not to include them in the model, say for visualizing it on Netron or because the initialziers will be specified along with the inputs in the onnx runtime of choice. This PR allow users to pass `include_initializers` to `ONNXProgram.save()` API. Fixes #100996 Pull Request resolved: https://github.com/pytorch/pytorch/pull/121904 Approved by: https://github.com/titaiwangms
This commit is contained in:
parent
6ede882c0b
commit
9c2ac4476c
|
|
@ -762,6 +762,100 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
|
|||
onnx_program.save(tmp_onnx_file.name)
|
||||
onnx.checker.check_model(tmp_onnx_file.name, full_check=True)
|
||||
|
||||
@common_utils.parametrize(
|
||||
"include_initializer",
|
||||
[
|
||||
common_utils.subtest(
|
||||
True,
|
||||
name="include_initializer",
|
||||
),
|
||||
common_utils.subtest(
|
||||
False,
|
||||
name="dont_include_initializer",
|
||||
),
|
||||
],
|
||||
)
|
||||
@common_utils.parametrize(
|
||||
"use_fake_mode",
|
||||
[
|
||||
common_utils.subtest(
|
||||
True,
|
||||
name="use_fake_mode",
|
||||
),
|
||||
common_utils.subtest(
|
||||
False,
|
||||
name="no_fake_mode",
|
||||
),
|
||||
],
|
||||
)
|
||||
@common_utils.parametrize(
|
||||
"use_exported_program",
|
||||
[
|
||||
common_utils.subtest(
|
||||
True,
|
||||
name="use_exported_program",
|
||||
),
|
||||
common_utils.subtest(
|
||||
False,
|
||||
name="no_exported_program",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_save_with_without_initializer(
|
||||
self, include_initializer, use_fake_mode, use_exported_program
|
||||
):
|
||||
class MNISTModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
|
||||
self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
|
||||
self.fc1 = nn.Linear(9216, 128, bias=False)
|
||||
self.fc2 = nn.Linear(128, 10, bias=False)
|
||||
|
||||
def forward(self, tensor_x: torch.Tensor):
|
||||
tensor_x = self.conv1(tensor_x)
|
||||
tensor_x = F.sigmoid(tensor_x)
|
||||
tensor_x = self.conv2(tensor_x)
|
||||
tensor_x = F.sigmoid(tensor_x)
|
||||
tensor_x = F.max_pool2d(tensor_x, 2)
|
||||
tensor_x = torch.flatten(tensor_x, 1)
|
||||
tensor_x = self.fc1(tensor_x)
|
||||
tensor_x = F.sigmoid(tensor_x)
|
||||
tensor_x = self.fc2(tensor_x)
|
||||
output = F.log_softmax(tensor_x, dim=1)
|
||||
return output
|
||||
|
||||
state_dict = MNISTModel().state_dict()
|
||||
if use_fake_mode:
|
||||
with torch.onnx.enable_fake_mode() as ctx:
|
||||
model = MNISTModel()
|
||||
tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
|
||||
if use_exported_program:
|
||||
model = torch.export.export(model, args=(tensor_x,))
|
||||
export_options = torch.onnx.ExportOptions(fake_context=ctx)
|
||||
else:
|
||||
model = MNISTModel()
|
||||
tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
|
||||
if use_exported_program:
|
||||
model = torch.export.export(model, args=(tensor_x,))
|
||||
export_options = torch.onnx.ExportOptions()
|
||||
|
||||
onnx_program = torch.onnx.dynamo_export(
|
||||
model, tensor_x, export_options=export_options
|
||||
)
|
||||
with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
|
||||
onnx_program.save(
|
||||
tmp_onnx_file.name,
|
||||
include_initializers=include_initializer,
|
||||
model_state=state_dict if include_initializer else None,
|
||||
)
|
||||
onnx_model = onnx.load(tmp_onnx_file.name)
|
||||
self.assertEqual(
|
||||
(include_initializer and len(onnx_model.graph.initializer) > 0)
|
||||
or (not include_initializer and len(onnx_model.graph.initializer) == 0),
|
||||
True,
|
||||
)
|
||||
|
||||
def test_export_with_print(self):
|
||||
class PrintModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
|
|
|||
|
|
@ -1010,6 +1010,7 @@ class ONNXProgram:
|
|||
self,
|
||||
destination: Union[str, io.BufferedIOBase],
|
||||
*,
|
||||
include_initializers: bool = True,
|
||||
model_state: Optional[Union[Dict[str, Any], str]] = None,
|
||||
serializer: Optional[ONNXProgramSerializer] = None,
|
||||
) -> None:
|
||||
|
|
@ -1021,12 +1022,18 @@ class ONNXProgram:
|
|||
If `destination` is a string, besides saving the ONNX model into a file, model weights are also stored
|
||||
in separate files in the same directory as the ONNX model. E.g. for `destination="/path/model.onnx"`,
|
||||
the initializers are saved in "/path/" folder along with "onnx.model".
|
||||
include_initializers: Whether to include initializers in the ONNX graph as external data.
|
||||
Cannot be combined with `model_state_dict`.
|
||||
model_state: The state_dict of the PyTorch model containing all weights on it.
|
||||
It can be either a string with the path to a checkpoint or a dictionary with the actual model state.
|
||||
The supported file formats are the same as those supported by `torch.load` and `safetensors.safe_open`.
|
||||
Required when :func:`enable_fake_mode` is used but real initializers are needed on the ONNX graph.
|
||||
serializer: The serializer to use. If not specified, the model will be serialized as Protobuf.
|
||||
"""
|
||||
|
||||
assert (
|
||||
include_initializers is True or model_state is None
|
||||
), "Cannot specify both `include_initializers=False` and `model_state`."
|
||||
if serializer is None:
|
||||
if isinstance(destination, str):
|
||||
serializer = LargeProtobufONNXProgramSerializer(destination)
|
||||
|
|
@ -1035,21 +1042,27 @@ class ONNXProgram:
|
|||
|
||||
# Add initializers when symbolic tracing is enabled
|
||||
_model_state_files: List[Union[str, io.BytesIO, Dict[str, Any]]] = []
|
||||
if model_state is not None:
|
||||
assert isinstance(
|
||||
model_state, (dict, str)
|
||||
), "model_state must be a path to the model's state_dict or the actual state_dict"
|
||||
# NOTE: For dict, there can be performance penalty or high memory usage that might lead to OOM
|
||||
# if the dict wasn't loaded with torch.load(..., mmap=True, map_location="cpu")
|
||||
_model_state_files.append(model_state)
|
||||
elif self._fake_context and self._fake_context.state_dict_paths:
|
||||
# Load state from previous model.load_state_dict() call within enable_fake_mode() context
|
||||
for path in self._fake_context.state_dict_paths:
|
||||
if path in _model_state_files:
|
||||
# ignore duplicate
|
||||
continue
|
||||
if os.path.exists(path): # type: ignore[arg-type]
|
||||
_model_state_files.append(path)
|
||||
if include_initializers:
|
||||
if model_state is not None:
|
||||
assert isinstance(
|
||||
model_state, (dict, str)
|
||||
), "model_state must be a path to the model's state_dict or the actual state_dict"
|
||||
# NOTE: For dict, there can be performance penalty or high memory usage that might lead to OOM
|
||||
# if the dict wasn't loaded with torch.load(..., mmap=True, map_location="cpu")
|
||||
_model_state_files.append(model_state)
|
||||
elif self._fake_context and self._fake_context.state_dict_paths:
|
||||
# Load state from previous model.load_state_dict() call within enable_fake_mode() context
|
||||
for path in self._fake_context.state_dict_paths:
|
||||
if path in _model_state_files:
|
||||
# ignore duplicate
|
||||
continue
|
||||
if os.path.exists(path): # type: ignore[arg-type]
|
||||
_model_state_files.append(path)
|
||||
else:
|
||||
# self.model_proto.graph.initializer.clear() not available in older protobuf versions
|
||||
initializer_count = len(self.model_proto.graph.initializer)
|
||||
for _ in range(initializer_count):
|
||||
del self.model_proto.graph.initializer[0]
|
||||
|
||||
if _model_state_files:
|
||||
if not isinstance(destination, str):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user