Allow ONNX models without parameters (#121904)

Currently, if initializers are available, they are included in the ONNX model. If they are not available, the model is serialized without them.

However, there are times when the initializers are available but the user prefers not to include them in the model, for example to visualize it in Netron, or because the initializers will be supplied along with the inputs by the ONNX runtime of choice.

This PR allows users to pass `include_initializers` to the `ONNXProgram.save()` API, as sketched below.
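For illustration, a minimal sketch of the resulting workflow, assuming a toy parameterized module (the module and file names are hypothetical, not from this PR):

    import torch

    # A tiny module with real parameters; any parameterized model works.
    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(3, 2)

        def forward(self, x):
            return self.fc(x)

    onnx_program = torch.onnx.dynamo_export(TinyModel(), torch.rand(1, 3))

    # Default behavior: parameters are serialized into the model as initializers.
    onnx_program.save("model_with_weights.onnx")

    # With this PR: drop the initializers, e.g. to inspect the graph in Netron
    # or to feed the weights as inputs in the ONNX runtime of choice.
    onnx_program.save("model_without_weights.onnx", include_initializers=False)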

Fixes #100996
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121904
Approved by: https://github.com/titaiwangms
Thiago Crepaldi 2024-04-22 15:53:38 +00:00 committed by PyTorch MergeBot
parent 6ede882c0b
commit 9c2ac4476c
2 changed files with 122 additions and 15 deletions


@@ -762,6 +762,100 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
onnx_program.save(tmp_onnx_file.name)
onnx.checker.check_model(tmp_onnx_file.name, full_check=True)
@common_utils.parametrize(
"include_initializer",
[
common_utils.subtest(
True,
name="include_initializer",
),
common_utils.subtest(
False,
name="dont_include_initializer",
),
],
)
@common_utils.parametrize(
"use_fake_mode",
[
common_utils.subtest(
True,
name="use_fake_mode",
),
common_utils.subtest(
False,
name="no_fake_mode",
),
],
)
@common_utils.parametrize(
"use_exported_program",
[
common_utils.subtest(
True,
name="use_exported_program",
),
common_utils.subtest(
False,
name="no_exported_program",
),
],
)
def test_save_with_without_initializer(
self, include_initializer, use_fake_mode, use_exported_program
):
class MNISTModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
self.fc1 = nn.Linear(9216, 128, bias=False)
self.fc2 = nn.Linear(128, 10, bias=False)
def forward(self, tensor_x: torch.Tensor):
tensor_x = self.conv1(tensor_x)
tensor_x = F.sigmoid(tensor_x)
tensor_x = self.conv2(tensor_x)
tensor_x = F.sigmoid(tensor_x)
tensor_x = F.max_pool2d(tensor_x, 2)
tensor_x = torch.flatten(tensor_x, 1)
tensor_x = self.fc1(tensor_x)
tensor_x = F.sigmoid(tensor_x)
tensor_x = self.fc2(tensor_x)
output = F.log_softmax(tensor_x, dim=1)
return output
state_dict = MNISTModel().state_dict()
if use_fake_mode:
with torch.onnx.enable_fake_mode() as ctx:
model = MNISTModel()
tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
if use_exported_program:
model = torch.export.export(model, args=(tensor_x,))
export_options = torch.onnx.ExportOptions(fake_context=ctx)
else:
model = MNISTModel()
tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
if use_exported_program:
model = torch.export.export(model, args=(tensor_x,))
export_options = torch.onnx.ExportOptions()
onnx_program = torch.onnx.dynamo_export(
model, tensor_x, export_options=export_options
)
with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
onnx_program.save(
tmp_onnx_file.name,
include_initializers=include_initializer,
model_state=state_dict if include_initializer else None,
)
onnx_model = onnx.load(tmp_onnx_file.name)
self.assertEqual(
(include_initializer and len(onnx_model.graph.initializer) > 0)
or (not include_initializer and len(onnx_model.graph.initializer) == 0),
True,
)
def test_export_with_print(self):
class PrintModule(torch.nn.Module):
def forward(self, x):


@@ -1010,6 +1010,7 @@ class ONNXProgram:
self,
destination: Union[str, io.BufferedIOBase],
*,
include_initializers: bool = True,
model_state: Optional[Union[Dict[str, Any], str]] = None,
serializer: Optional[ONNXProgramSerializer] = None,
) -> None:
@@ -1021,12 +1022,18 @@ class ONNXProgram:
If `destination` is a string, besides saving the ONNX model into a file, model weights are also stored
in separate files in the same directory as the ONNX model. E.g. for `destination="/path/model.onnx"`,
the initializers are saved in the "/path/" folder along with "model.onnx".
include_initializers: Whether to include initializers in the ONNX graph as external data.
`include_initializers=False` cannot be combined with `model_state`.
model_state: The state_dict of the PyTorch model containing all of its weights.
It can be either a string with the path to a checkpoint or a dictionary with the actual model state.
The supported file formats are the same as those supported by `torch.load` and `safetensors.safe_open`.
Required when :func:`enable_fake_mode` is used but real initializers are needed in the ONNX graph.
serializer: The serializer to use. If not specified, the model will be serialized as Protobuf.
"""
assert (
include_initializers is True or model_state is None
), "Cannot specify both `include_initializers=False` and `model_state`."
if serializer is None:
if isinstance(destination, str):
serializer = LargeProtobufONNXProgramSerializer(destination)
@@ -1035,21 +1042,27 @@
# Add initializers when symbolic tracing is enabled
_model_state_files: List[Union[str, io.BytesIO, Dict[str, Any]]] = []
if model_state is not None:
assert isinstance(
model_state, (dict, str)
), "model_state must be a path to the model's state_dict or the actual state_dict"
# NOTE: For a dict, there can be a performance penalty or high memory usage that might lead to OOM
# if the dict wasn't loaded with torch.load(..., mmap=True, map_location="cpu")
_model_state_files.append(model_state)
elif self._fake_context and self._fake_context.state_dict_paths:
# Load state from previous model.load_state_dict() call within enable_fake_mode() context
for path in self._fake_context.state_dict_paths:
if path in _model_state_files:
# ignore duplicate
continue
if os.path.exists(path): # type: ignore[arg-type]
_model_state_files.append(path)
if include_initializers:
if model_state is not None:
assert isinstance(
model_state, (dict, str)
), "model_state must be a path to the model's state_dict or the actual state_dict"
# NOTE: For a dict, there can be a performance penalty or high memory usage that might lead to OOM
# if the dict wasn't loaded with torch.load(..., mmap=True, map_location="cpu")
_model_state_files.append(model_state)
elif self._fake_context and self._fake_context.state_dict_paths:
# Load state from previous model.load_state_dict() call within enable_fake_mode() context
for path in self._fake_context.state_dict_paths:
if path in _model_state_files:
# ignore duplicate
continue
if os.path.exists(path): # type: ignore[arg-type]
_model_state_files.append(path)
else:
# self.model_proto.graph.initializer.clear() not available in older protobuf versions
initializer_count = len(self.model_proto.graph.initializer)
for _ in range(initializer_count):
del self.model_proto.graph.initializer[0]
if _model_state_files:
if not isinstance(destination, str):
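As context for the fake-mode branch above, a hedged sketch of how `model_state` is meant to be combined with :func:`enable_fake_mode` so that real weights still end up in the saved model; it mirrors the test in this PR, with an illustrative model and file name:

    import torch

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    # Capture real weights before entering fake mode.
    state_dict = TinyModel().state_dict()

    # Under fake mode, the exported graph carries no real initializers.
    with torch.onnx.enable_fake_mode() as ctx:
        model = TinyModel()
        x = torch.rand(1, 4)

    export_options = torch.onnx.ExportOptions(fake_context=ctx)
    onnx_program = torch.onnx.dynamo_export(model, x, export_options=export_options)

    # At save time, real initializers are injected from `model_state`
    # (a state_dict or a checkpoint path); with include_initializers=False
    # the graph would instead be saved without any weights at all.
    onnx_program.save("tiny.onnx", model_state=state_dict)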