Allow fx.Graph.owning_module to be used as attribute. (#86822)

Summary:
The current behavior of owning_module setter is difficult to understand: it changes the owning_module to None if owners is not 0 but increments the owners count. If the owning_module is None, the owners count should be 0 as none of them is accessible. On the other hand, if the owners count increases, the owning_module should be a collection (e.g. a list).

This diff changes owning_module to be a normal attribute. The semantic is that graph can have **at most one** owning module and can be assigned to new module.

The alternative is to use a list to represent the owning_modules of a graph but it breaks backward compatibility and the exact use cases of having multiple owning_modules are not clear.

Test Plan: Test with CI.

Differential Revision: D40200624

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86822
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
Yidi Wu 2022-10-19 00:12:59 +00:00 committed by PyTorch MergeBot
parent 3eb7429385
commit 4dc579838b
2 changed files with 1 additions and 14 deletions

View File

@ -24,11 +24,6 @@ class FoldedGraphModule(torch.fx.GraphModule):
fx_const_folded_attrs_name: str = None,
device_for_folded_attrs: str = "cuda",
):
# In init, we set graph's owning module to root which will make graph's
# owning module be None because graph already have a owning module. We
# need owning module to run DCE. To work around we set the number of
# graph's owners to 0.
graph._owners = 0
super().__init__(root, graph)
self.const_subgraph_module = (
None

View File

@ -697,7 +697,6 @@ class Graph:
self._insert = self._root.prepend
self._len = 0
self._graph_namespace = _Namespace()
self._owners = 0
self._owning_module = owning_module
self._tracer_cls = tracer_cls
self._tracer_extras = tracer_extras
@ -705,18 +704,11 @@ class Graph:
@property
def owning_module(self):
"""
Return the module that owns this ``GraphModule``, if there is one,
``None`` if there is no owning module or if there are multiple owning
modules.
"""
return self._owning_module
@owning_module.setter
def owning_module(self, mod: Optional["GraphModule"]):
if mod:
self._owning_module = mod if not self._owners else None
self._owners += 1
self._owning_module = mod
@property
def nodes(self) -> _node_list: