From e9d7d37ad0a920a763bbabf13ea3ba7fbd3e939d Mon Sep 17 00:00:00 2001 From: James Reed Date: Tue, 15 Dec 2020 17:12:01 -0800 Subject: [PATCH] [FX] Rename Node._uses and refactor Node.all_input_nodes (#49415) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49415 Test Plan: Imported from OSS Reviewed By: zdevito Differential Revision: D25565341 Pulled By: jamesr66a fbshipit-source-id: 2290ab62572632788809ba16319578bf0c0260ee --- torch/fx/node.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torch/fx/node.py b/torch/fx/node.py index d304a4c0a47..247896b5a92 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -59,7 +59,10 @@ class Node: self.target = target # for method/module/function, the name of the method/module/function/attr # being invoked, e.g add, layer1, or torch.add - self._uses : Dict[Node, None] = {} + # All `Node`-valued inputs. Key is the Node, value is don't-care. + # The public API for this is `all_input_nodes`, this private attribute + # should not be accessed directly. + self._input_nodes : Dict[Node, None] = {} self._update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore # All of the nodes that use the value produced by this Node @@ -191,10 +194,7 @@ class Node: List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this ``Node``, in that order. """ - all_nodes : List['Node'] = [] - map_arg(self.args, lambda n: all_nodes.append(n)) - map_arg(self.kwargs, lambda n: all_nodes.append(n)) - return all_nodes + return list(self._input_nodes.keys()) def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]): """ @@ -203,14 +203,14 @@ class Node: self._args = new_args self._kwargs = new_kwargs - for old_use in self._uses.keys(): + for old_use in self._input_nodes.keys(): old_use.users.pop(self) - self._uses = {} - map_arg(self._args, lambda n: self._uses.setdefault(n)) - map_arg(self._kwargs, lambda n: self._uses.setdefault(n)) + self._input_nodes = {} + map_arg(self._args, lambda n: self._input_nodes.setdefault(n)) + map_arg(self._kwargs, lambda n: self._input_nodes.setdefault(n)) - for new_use in self._uses.keys(): + for new_use in self._input_nodes.keys(): new_use.users.setdefault(self) def __repr__(self) -> str: