diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 687a7a8962a..3a57291b0bc 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1079,6 +1079,14 @@ class OutputGraph(OutputGraphCommon): def register_static_attr_and_return_proxy( self, attr_prefix: str, attr_value: Any ) -> fx.Proxy: + # Check if the module already exists, if it does, return the already + # added proxy. This is important for executorch tests. + if isinstance(attr_value, torch.nn.Module): + for name, mod in self.nn_modules.items(): + if mod is attr_value: + proxy = self.create_proxy("get_attr", name, (), {}) + return proxy + attr_name = get_unique_name_wrt(attr_prefix, self.nn_modules) # TODO `nn_modules` has been historically overloaded to store a lot more # than just nn module objects, fix that.