[export] nested terms in nn_module_stack deserialization (#145901)

Summary: Account for nested terms like "getattr(getattr(a[0], b), c)" when splitting serialized nn_module_stack entries during deserialization.
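
For illustration, here is a minimal standalone sketch (not code from this PR) of why a plain comma split mangles such paths; the entry string below is made up for the example, not an actual serialized nn_module_stack item:

# Split a serialized "key,path,type" entry on commas that are not nested
# inside parentheses/brackets (hypothetical helper, for illustration only).
def bracket_aware_split(metadata):
    out, start, depth = [], 0, 0
    for end, c in enumerate(metadata):
        if c in "[(":
            depth += 1
        elif c in ")]":
            depth -= 1
        elif c == "," and depth == 0:
            out.append(metadata[start:end])
            start = end + 1
    out.append(metadata[start:])
    return out

item = 'layers_1,getattr(a[(1)], "layers").slice(1, None, None),torch.nn.modules.container.Sequential'
print(item.split(","))            # naive split: 6 pieces, the path is broken apart
print(bracket_aware_split(item))  # 3 pieces: key, path, type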

Test Plan: test_serialize

Differential Revision: D68784736

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145901
Approved by: https://github.com/angelayi
commit 7b07415aaa (parent 1f1a9965d5)
Author: Pian Pawakapan, 2025-01-31 10:00:13 +00:00
Committed by: PyTorch MergeBot
2 changed files with 56 additions and 11 deletions


@@ -230,6 +230,47 @@ def forward(self, x):
         actual_out = loaded_ep.module()(*inp)
         self.assertEqual(exp_out, actual_out)
 
+    def test_nested_layer_split(self):
+        class Bar(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.layers = torch.nn.Sequential(
+                    torch.nn.SiLU(),
+                    torch.nn.SiLU(),
+                    torch.nn.SiLU(),
+                )
+
+            def forward(self, x):
+                out_start, out_rest = self.layers[0], self.layers[1:]
+                h = out_start(x)
+                h = out_rest(h) + 2
+                return h
+
+        class Foo(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_module("a[(1)]", Bar())
+                self.register_module("b[(2)]", Bar())
+                self.register_buffer("c:[22]", torch.randn(1))
+
+            def forward(self, x):
+                out_a, out_b = getattr(self, "a[(1)]"), getattr(self, "b[(2)]")
+                out_c = getattr(self, "c:[22]")
+                h = out_a(x)
+                h = out_b(h)
+                return h + out_c
+
+        inp = (torch.ones(10),)
+        ep = export_for_training(Foo(), inp, strict=True)
+        buffer = io.BytesIO()
+        save(ep, buffer)
+        loaded_ep = load(buffer)
+        # Check that both modules run to confirm load was successful.
+        exp_out = ep.module()(*inp)
+        actual_out = loaded_ep.module()(*inp)
+        self.assertEqual(exp_out, actual_out)
+
     def test_serialize_constant_outputs(self):
         class MyModule(torch.nn.Module):
             def __init__(self) -> None:

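As a rough usage sketch (not part of this diff), the round trip above can also be checked at the metadata level by comparing the module paths recorded in nn_module_stack before and after save/load; this assumes the Foo module from the test above and the public torch.export API, and the exact meta contents may vary across versions:

import io

import torch
from torch.export import export_for_training, load, save

ep = export_for_training(Foo(), (torch.ones(10),), strict=True)
buffer = io.BytesIO()
save(ep, buffer)
buffer.seek(0)
loaded_ep = load(buffer)

# The paths contain names like 'a[(1)]', which the bracket-aware splitting
# must keep intact through deserialization.
for n_old, n_new in zip(ep.graph.nodes, loaded_ep.graph.nodes):
    old_paths = [path for path, _ in n_old.meta.get("nn_module_stack", {}).values()]
    new_paths = [path for path, _ in n_new.meta.get("nn_module_stack", {}).values()]
    assert old_paths == new_paths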

@@ -11,7 +11,6 @@ import keyword
 import logging
 import math
 import operator
-import re
 import traceback
 import typing
@@ -2489,17 +2488,22 @@ class GraphModuleDeserializer(metaclass=Final):
         def import_nn_module_stack(key, path, ty):
             return key, (path, ty)
 
-        # Helper function that splits strings by commas except for those
-        # encapsulated by parens, which are valid traces.
-        # TODO: Currently this is needed due to indexing Sequential
-        # layers introducing names in the form "layer.slice(1, None, None)".
-        # If that naming is improved, this fancier splitting can probably be
-        # reverted to a simple split by comma.
+        # Helper function to split string by commas, accounting for nested parentheses/brackets
         def metadata_split(metadata):
-            # Remove the parentheses and commas inside them
-            metadata = re.sub(r'\(.*?\)', '', metadata)
-            # Split the string by comma, except for those inside parentheses
-            return re.split(r'(?<!\()\s*,\s*(?!\()', metadata)
+            out = []
+            start, n = 0, 0
+            a, b = "[(", ")]"
+            for end, c in enumerate(metadata):
+                if c in a:
+                    n += 1
+                elif c in b:
+                    n -= 1
+                elif c == "," and n == 0:
+                    out.append(metadata[start:end])
+                    start = end + 1
+            out.append(metadata[start:])
+            assert len(out) == 3
+            return out
 
         nn_module_stack = dict(
             import_nn_module_stack(*metadata_split(item))