From 4dc579838be94c343cc8542c7a80b9a9a8c15b51 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Wed, 19 Oct 2022 00:12:59 +0000 Subject: [PATCH] Allow fx.Graph.owning_module to be used as attribute. (#86822) Summary: The current behavior of owning_module setter is difficult to understand: it changes the owning_module to None if owners is not 0 but increments the owners count. If the owning_module is None, the owners count should be 0 as none of them is accessible. On the other hand, if the owners count increases, the owning_module should be a collection (e.g. a list). This diff changes owning_module to be a normal attribute. The semantic is that graph can have **at most one** owning module and can be assigned to new module. The alternative is to use a list to represent the owning_modules of a graph but it breaks backward compatibility and the exact use cases of having multiple owning_modules are not clear. Test Plan: Test with CI. Differential Revision: D40200624 Pull Request resolved: https://github.com/pytorch/pytorch/pull/86822 Approved by: https://github.com/tugsbayasgalan --- torch/fx/experimental/const_fold.py | 5 ----- torch/fx/graph.py | 10 +--------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index 56ffbcacc84..a9698030297 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -24,11 +24,6 @@ class FoldedGraphModule(torch.fx.GraphModule): fx_const_folded_attrs_name: str = None, device_for_folded_attrs: str = "cuda", ): - # In init, we set graph's owning module to root which will make graph's - # owning module be None because graph already have a owning module. We - # need owning module to run DCE. To work around we set the number of - # graph's owners to 0. - graph._owners = 0 super().__init__(root, graph) self.const_subgraph_module = ( None diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 271c43e9857..9397050bc29 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -697,7 +697,6 @@ class Graph: self._insert = self._root.prepend self._len = 0 self._graph_namespace = _Namespace() - self._owners = 0 self._owning_module = owning_module self._tracer_cls = tracer_cls self._tracer_extras = tracer_extras @@ -705,18 +704,11 @@ class Graph: @property def owning_module(self): - """ - Return the module that owns this ``GraphModule``, if there is one, - ``None`` if there is no owning module or if there are multiple owning - modules. - """ return self._owning_module @owning_module.setter def owning_module(self, mod: Optional["GraphModule"]): - if mod: - self._owning_module = mod if not self._owners else None - self._owners += 1 + self._owning_module = mod @property def nodes(self) -> _node_list: