mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[FX] Rename Node._uses and refactor Node.all_input_nodes (#49415)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49415 Test Plan: Imported from OSS Reviewed By: zdevito Differential Revision: D25565341 Pulled By: jamesr66a fbshipit-source-id: 2290ab62572632788809ba16319578bf0c0260ee
This commit is contained in:
parent
46debe7f23
commit
e9d7d37ad0
|
|
@ -59,7 +59,10 @@ class Node:
|
||||||
self.target = target # for method/module/function, the name of the method/module/function/attr
|
self.target = target # for method/module/function, the name of the method/module/function/attr
|
||||||
# being invoked, e.g add, layer1, or torch.add
|
# being invoked, e.g add, layer1, or torch.add
|
||||||
|
|
||||||
self._uses : Dict[Node, None] = {}
|
# All `Node`-valued inputs. Key is the Node, value is don't-care.
|
||||||
|
# The public API for this is `all_input_nodes`, this private attribute
|
||||||
|
# should not be accessed directly.
|
||||||
|
self._input_nodes : Dict[Node, None] = {}
|
||||||
self._update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore
|
self._update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore
|
||||||
|
|
||||||
# All of the nodes that use the value produced by this Node
|
# All of the nodes that use the value produced by this Node
|
||||||
|
|
@ -191,10 +194,7 @@ class Node:
|
||||||
List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
|
List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
|
||||||
``Node``, in that order.
|
``Node``, in that order.
|
||||||
"""
|
"""
|
||||||
all_nodes : List['Node'] = []
|
return list(self._input_nodes.keys())
|
||||||
map_arg(self.args, lambda n: all_nodes.append(n))
|
|
||||||
map_arg(self.kwargs, lambda n: all_nodes.append(n))
|
|
||||||
return all_nodes
|
|
||||||
|
|
||||||
def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]):
|
def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]):
|
||||||
"""
|
"""
|
||||||
|
|
@ -203,14 +203,14 @@ class Node:
|
||||||
self._args = new_args
|
self._args = new_args
|
||||||
self._kwargs = new_kwargs
|
self._kwargs = new_kwargs
|
||||||
|
|
||||||
for old_use in self._uses.keys():
|
for old_use in self._input_nodes.keys():
|
||||||
old_use.users.pop(self)
|
old_use.users.pop(self)
|
||||||
|
|
||||||
self._uses = {}
|
self._input_nodes = {}
|
||||||
map_arg(self._args, lambda n: self._uses.setdefault(n))
|
map_arg(self._args, lambda n: self._input_nodes.setdefault(n))
|
||||||
map_arg(self._kwargs, lambda n: self._uses.setdefault(n))
|
map_arg(self._kwargs, lambda n: self._input_nodes.setdefault(n))
|
||||||
|
|
||||||
for new_use in self._uses.keys():
|
for new_use in self._input_nodes.keys():
|
||||||
new_use.users.setdefault(self)
|
new_use.users.setdefault(self)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user