mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] nested terms in nn_module_stack deserialization (#145901)
Summary: accounting for terms like "getattr(getattr(a[0], b), c)". Test Plan: test_serialize Differential Revision: D68784736 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145901 Approved by: https://github.com/angelayi
This commit is contained in:
parent
1f1a9965d5
commit
7b07415aaa
|
|
@ -230,6 +230,47 @@ def forward(self, x):
|
|||
actual_out = loaded_ep.module()(*inp)
|
||||
self.assertEqual(exp_out, actual_out)
|
||||
|
||||
def test_nested_layer_split(self):
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layers = torch.nn.Sequential(
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.SiLU(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out_start, out_rest = self.layers[0], self.layers[1:]
|
||||
h = out_start(x)
|
||||
h = out_rest(h) + 2
|
||||
return h
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_module("a[(1)]", Bar())
|
||||
self.register_module("b[(2)]", Bar())
|
||||
self.register_buffer("c:[22]", torch.randn(1))
|
||||
|
||||
def forward(self, x):
|
||||
out_a, out_b = getattr(self, "a[(1)]"), getattr(self, "b[(2)]")
|
||||
out_c = getattr(self, "c:[22]")
|
||||
h = out_a(x)
|
||||
h = out_b(h)
|
||||
return h + out_c
|
||||
|
||||
inp = (torch.ones(10),)
|
||||
ep = export_for_training(Foo(), inp, strict=True)
|
||||
buffer = io.BytesIO()
|
||||
save(ep, buffer)
|
||||
loaded_ep = load(buffer)
|
||||
|
||||
# Check that both modules run to confirm load was successful.
|
||||
exp_out = ep.module()(*inp)
|
||||
actual_out = loaded_ep.module()(*inp)
|
||||
self.assertEqual(exp_out, actual_out)
|
||||
|
||||
def test_serialize_constant_outputs(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import keyword
|
|||
import logging
|
||||
import math
|
||||
import operator
|
||||
import re
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
|
|
@ -2489,17 +2488,22 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||
def import_nn_module_stack(key, path, ty):
|
||||
return key, (path, ty)
|
||||
|
||||
# Helper function that splits strings by commas except for those
|
||||
# encapsulated by parens, which are valid traces.
|
||||
# TODO: Currently this is needed due to indexing Sequential
|
||||
# layers introducing names in the form "layer.slice(1, None, None)".
|
||||
# If that naming is improved, this fancier splitting can probably be
|
||||
# reverted to a simple split by comma.
|
||||
# Helper function to split string by commas, accounting for nested parentheses/brackets
|
||||
def metadata_split(metadata):
|
||||
# Remove the parentheses and commas inside them
|
||||
metadata = re.sub(r'\(.*?\)', '', metadata)
|
||||
# Split the string by comma, except for those inside parentheses
|
||||
return re.split(r'(?<!\()\s*,\s*(?!\()', metadata)
|
||||
out = []
|
||||
start, n = 0, 0
|
||||
a, b = "[(", ")]"
|
||||
for end, c in enumerate(metadata):
|
||||
if c in a:
|
||||
n += 1
|
||||
elif c in b:
|
||||
n -= 1
|
||||
elif c == "," and n == 0:
|
||||
out.append(metadata[start : end])
|
||||
start = end + 1
|
||||
out.append(metadata[start:])
|
||||
assert len(out) == 3
|
||||
return out
|
||||
|
||||
nn_module_stack = dict(
|
||||
import_nn_module_stack(*metadata_split(item))
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user