[FX] Rename Node._uses and refactor Node.all_input_nodes (#49415)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49415

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D25565341

Pulled By: jamesr66a

fbshipit-source-id: 2290ab62572632788809ba16319578bf0c0260ee
This commit is contained in:
James Reed 2020-12-15 17:12:01 -08:00 committed by Facebook GitHub Bot
parent 46debe7f23
commit e9d7d37ad0

View File

@ -59,7 +59,10 @@ class Node:
self.target = target # for method/module/function, the name of the method/module/function/attr self.target = target # for method/module/function, the name of the method/module/function/attr
# being invoked, e.g add, layer1, or torch.add # being invoked, e.g add, layer1, or torch.add
self._uses : Dict[Node, None] = {} # All `Node`-valued inputs. Key is the Node, value is don't-care.
# The public API for this is `all_input_nodes`, this private attribute
# should not be accessed directly.
self._input_nodes : Dict[Node, None] = {}
self._update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore self._update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore
# All of the nodes that use the value produced by this Node # All of the nodes that use the value produced by this Node
@ -191,10 +194,7 @@ class Node:
List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
``Node``, in that order. ``Node``, in that order.
""" """
all_nodes : List['Node'] = [] return list(self._input_nodes.keys())
map_arg(self.args, lambda n: all_nodes.append(n))
map_arg(self.kwargs, lambda n: all_nodes.append(n))
return all_nodes
def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]): def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]):
""" """
@ -203,14 +203,14 @@ class Node:
self._args = new_args self._args = new_args
self._kwargs = new_kwargs self._kwargs = new_kwargs
for old_use in self._uses.keys(): for old_use in self._input_nodes.keys():
old_use.users.pop(self) old_use.users.pop(self)
self._uses = {} self._input_nodes = {}
map_arg(self._args, lambda n: self._uses.setdefault(n)) map_arg(self._args, lambda n: self._input_nodes.setdefault(n))
map_arg(self._kwargs, lambda n: self._uses.setdefault(n)) map_arg(self._kwargs, lambda n: self._input_nodes.setdefault(n))
for new_use in self._uses.keys(): for new_use in self._input_nodes.keys():
new_use.users.setdefault(self) new_use.users.setdefault(self)
def __repr__(self) -> str: def __repr__(self) -> str: