mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[Export] Improve metadata and output parsing during deserialization (#122793)
Summary: Deserialization of metadata could encounter a bug where commas are used in valid metadata names. This specifically occurs when a split of a `torch.nn.Sequential` stack is used, but may have other possible triggers. Because the deserialization relies on a comma based string split, such names trigger an error. This change uses a simple regular expression to ignore commas within parentheses to avoid the issue. I add a test that constructs one such problematic sequential stack and show that it can be properly round-tripped with the improved splitting. Similarly, deserialization could fail when outputs are not a tensor type. Although such outputs like None or constants are not very useful, they do show up in graphs and export should be able to support them. This change improves output node parsing and adds a corresponding test. Test Plan: buck test //caffe2/test:test_export -- TestSerialize Differential Revision: D55391674 Pull Request resolved: https://github.com/pytorch/pytorch/pull/122793 Approved by: https://github.com/zhxchen17
This commit is contained in:
parent
064a650b63
commit
0c8a165b43
|
|
@ -116,6 +116,61 @@ class TestSerialize(TestCase):
|
|||
loaded_out = loaded_model.module()(*loaded_args, **loaded_kwargs)
|
||||
self.assertEqual(orig_out, loaded_out)
|
||||
|
||||
def test_metadata_parsing_with_layer_split(self):
|
||||
# Tests that modules with more complicated layer patterns can be serialized
|
||||
# and deserialized correctly.
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.Sequential(
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.SiLU(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# Splitting layers of a sequential stack introduces commas and parens
|
||||
# into metadata trace.
|
||||
out_start, out_rest = self.layers[0], self.layers[1:]
|
||||
h = out_start(x)
|
||||
h = out_rest(h)
|
||||
return h
|
||||
|
||||
inp = (torch.ones(10),)
|
||||
# Module will only be able to roundtrip if metadata
|
||||
# can be correctly parsed.
|
||||
ep = export(MyModule(), inp)
|
||||
buffer = io.BytesIO()
|
||||
save(ep, buffer)
|
||||
loaded_ep = load(buffer)
|
||||
|
||||
# Check that both modules run to confirm load was successful.
|
||||
exp_out = ep.module()(*inp)
|
||||
actual_out = loaded_ep.module()(*inp)
|
||||
self.assertEqual(exp_out, actual_out)
|
||||
|
||||
def test_serialize_constant_outputs(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# Along with tensor output, return Nonetype
|
||||
# and constant. Although these outputs aren't
|
||||
# very useful, they do show up in graphs.
|
||||
return x + 1, None, 1024
|
||||
|
||||
# Check that module can be roundtripped, thereby confirming proper deserialization.
|
||||
inp = (torch.ones(10),)
|
||||
ep = export(MyModule(), inp)
|
||||
buffer = io.BytesIO()
|
||||
save(ep, buffer)
|
||||
loaded_ep = load(buffer)
|
||||
|
||||
exp_out = ep.module()(*inp)
|
||||
actual_out = loaded_ep.module()(*inp)
|
||||
self.assertEqual(exp_out, actual_out)
|
||||
|
||||
def test_serialize_multiple_returns_from_node(self) -> None:
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import math
|
|||
import operator
|
||||
import typing
|
||||
import copyreg
|
||||
import re
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
|
|
@ -1440,13 +1441,17 @@ class GraphModuleDeserializer:
|
|||
class_fqn=script_obj_meta.class_fqn,
|
||||
)
|
||||
|
||||
def deserialize_graph_output(self, output) -> torch.fx.Node:
|
||||
def deserialize_graph_output(self, output) -> Optional[Union[torch.fx.Node, int]]:
|
||||
if output.type == "as_tensor":
|
||||
return self.serialized_name_to_node[output.as_tensor.name]
|
||||
elif output.type == "as_sym_int":
|
||||
return self.serialized_name_to_node[output.as_sym_int.as_name]
|
||||
elif output.type == "as_sym_bool":
|
||||
return self.serialized_name_to_node[output.as_sym_bool.as_name]
|
||||
elif output.type == "as_int":
|
||||
return output.as_int
|
||||
elif output.type == "as_none":
|
||||
return None
|
||||
else:
|
||||
raise SerializeError(f"Unable to deserialize output node {output}")
|
||||
|
||||
|
|
@ -1516,7 +1521,8 @@ class GraphModuleDeserializer:
|
|||
output_node.meta["val"] = output_node.args[0].meta["val"]
|
||||
else:
|
||||
output_node.meta["val"] = tuple(
|
||||
arg.meta["val"] for arg in output_node.args[0]
|
||||
arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
|
||||
for arg in output_node.args[0]
|
||||
)
|
||||
|
||||
return self.graph
|
||||
|
|
@ -1976,8 +1982,20 @@ class GraphModuleDeserializer:
|
|||
def import_nn_module_stack(key, path, ty):
|
||||
return key, (path, ty)
|
||||
|
||||
# Helper function that splits strings by commas except for those
|
||||
# encapsulated by parens, which are valid traces.
|
||||
# TODO: Currently this is needed due to indexing Sequential
|
||||
# layers introducing names in the form "layer.slice(1, None, None)".
|
||||
# If that naming is improved, this fancier splitting can probably be
|
||||
# reverted to a simple split by comma.
|
||||
def metadata_split(metadata):
|
||||
# Remove the parentheses and commas inside them
|
||||
metadata = re.sub(r'\(.*?\)', '', metadata)
|
||||
# Split the string by comma, except for those inside parentheses
|
||||
return re.split(r'(?<!\()\s*,\s*(?!\()', metadata)
|
||||
|
||||
nn_module_stack = dict(
|
||||
import_nn_module_stack(*item.split(","))
|
||||
import_nn_module_stack(*metadata_split(item))
|
||||
for item in nn_module_stack_str.split(ST_DELIMITER)
|
||||
)
|
||||
ret["nn_module_stack"] = nn_module_stack
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user