mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[FX] Rename Node._uses and refactor Node.all_input_nodes (#49415)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49415 Test Plan: Imported from OSS Reviewed By: zdevito Differential Revision: D25565341 Pulled By: jamesr66a fbshipit-source-id: 2290ab62572632788809ba16319578bf0c0260ee
This commit is contained in:
parent
46debe7f23
commit
e9d7d37ad0
|
|
@ -59,7 +59,10 @@ class Node:
|
|||
self.target = target # for method/module/function, the name of the method/module/function/attr
|
||||
# being invoked, e.g add, layer1, or torch.add
|
||||
|
||||
self._uses : Dict[Node, None] = {}
|
||||
# All `Node`-valued inputs. Key is the Node, value is don't-care.
|
||||
# The public API for this is `all_input_nodes`, this private attribute
|
||||
# should not be accessed directly.
|
||||
self._input_nodes : Dict[Node, None] = {}
|
||||
self._update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore
|
||||
|
||||
# All of the nodes that use the value produced by this Node
|
||||
|
|
@ -191,10 +194,7 @@ class Node:
|
|||
List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
|
||||
``Node``, in that order.
|
||||
"""
|
||||
all_nodes : List['Node'] = []
|
||||
map_arg(self.args, lambda n: all_nodes.append(n))
|
||||
map_arg(self.kwargs, lambda n: all_nodes.append(n))
|
||||
return all_nodes
|
||||
return list(self._input_nodes.keys())
|
||||
|
||||
def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]):
|
||||
"""
|
||||
|
|
@ -203,14 +203,14 @@ class Node:
|
|||
self._args = new_args
|
||||
self._kwargs = new_kwargs
|
||||
|
||||
for old_use in self._uses.keys():
|
||||
for old_use in self._input_nodes.keys():
|
||||
old_use.users.pop(self)
|
||||
|
||||
self._uses = {}
|
||||
map_arg(self._args, lambda n: self._uses.setdefault(n))
|
||||
map_arg(self._kwargs, lambda n: self._uses.setdefault(n))
|
||||
self._input_nodes = {}
|
||||
map_arg(self._args, lambda n: self._input_nodes.setdefault(n))
|
||||
map_arg(self._kwargs, lambda n: self._input_nodes.setdefault(n))
|
||||
|
||||
for new_use in self._uses.keys():
|
||||
for new_use in self._input_nodes.keys():
|
||||
new_use.users.setdefault(self)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user