From 7b07415aaa64b624e20417b6aa4be37f479ee41d Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Fri, 31 Jan 2025 10:00:13 +0000 Subject: [PATCH] [export] nested terms in nn_module_stack deserialization (#145901) Summary: accounting for terms like "getattr(getattr(a[0], b), c)". Test Plan: test_serialize Differential Revision: D68784736 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145901 Approved by: https://github.com/angelayi --- test/export/test_serialize.py | 41 ++++++++++++++++++++++++++++++++ torch/_export/serde/serialize.py | 26 +++++++++++--------- 2 files changed, 56 insertions(+), 11 deletions(-) diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 1324c8035d3..66baf655b5b 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -230,6 +230,47 @@ def forward(self, x): actual_out = loaded_ep.module()(*inp) self.assertEqual(exp_out, actual_out) + def test_nested_layer_split(self): + class Bar(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.layers = torch.nn.Sequential( + torch.nn.SiLU(), + torch.nn.SiLU(), + torch.nn.SiLU(), + ) + + def forward(self, x): + out_start, out_rest = self.layers[0], self.layers[1:] + h = out_start(x) + h = out_rest(h) + 2 + return h + + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_module("a[(1)]", Bar()) + self.register_module("b[(2)]", Bar()) + self.register_buffer("c:[22]", torch.randn(1)) + + def forward(self, x): + out_a, out_b = getattr(self, "a[(1)]"), getattr(self, "b[(2)]") + out_c = getattr(self, "c:[22]") + h = out_a(x) + h = out_b(h) + return h + out_c + + inp = (torch.ones(10),) + ep = export_for_training(Foo(), inp, strict=True) + buffer = io.BytesIO() + save(ep, buffer) + loaded_ep = load(buffer) + + # Check that both modules run to confirm load was successful. + exp_out = ep.module()(*inp) + actual_out = loaded_ep.module()(*inp) + self.assertEqual(exp_out, actual_out) + def test_serialize_constant_outputs(self): class MyModule(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index d591b1d24ad..e7269fdc88a 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -11,7 +11,6 @@ import keyword import logging import math import operator -import re import traceback import typing @@ -2489,17 +2488,22 @@ class GraphModuleDeserializer(metaclass=Final): def import_nn_module_stack(key, path, ty): return key, (path, ty) - # Helper function that splits strings by commas except for those - # encapsulated by parens, which are valid traces. - # TODO: Currently this is needed due to indexing Sequential - # layers introducing names in the form "layer.slice(1, None, None)". - # If that naming is improved, this fancier splitting can probably be - # reverted to a simple split by comma. + # Helper function to split string by commas, accounting for nested parentheses/brackets def metadata_split(metadata): - # Remove the parentheses and commas inside them - metadata = re.sub(r'\(.*?\)', '', metadata) - # Split the string by comma, except for those inside parentheses - return re.split(r'(?