mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[FX] Fix uses not updating when erasing a node (#47720)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47720 Test Plan: Imported from OSS Reviewed By: zdevito Differential Revision: D24875880 Pulled By: jamesr66a fbshipit-source-id: aae9ffd10f8085b599e7923152287c6e6950ff49
This commit is contained in:
parent
d1351c66a8
commit
dbfee42a7d
|
|
@ -700,6 +700,19 @@ class TestFX(JitTestCase):
|
|||
ref = torch.sin(mod.linear(input) + mod.bias)
|
||||
self.assertEqual(r, ref)
|
||||
|
||||
def test_remove_uses(self):
|
||||
g : torch.fx.Graph = Graph()
|
||||
x : torch.fx.Node = g.placeholder('x')
|
||||
relu : torch.fx.Node = g.call_function(torch.relu, (x,))
|
||||
neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
|
||||
g.output(neg)
|
||||
|
||||
neg.replace_all_uses_with(relu)
|
||||
g.erase_node(neg)
|
||||
|
||||
self.assertTrue(neg not in relu.users)
|
||||
|
||||
|
||||
def test_construct_root_dict(self):
|
||||
graph : torch.fx.Graph = torch.fx.Graph()
|
||||
a : torch.fx.Node = graph.create_node('placeholder', 'x')
|
||||
|
|
|
|||
|
|
@ -261,6 +261,15 @@ class Graph:
|
|||
to_erase._erased = True # iterators may retain handles to erased nodes
|
||||
self._len -= 1
|
||||
|
||||
# Null out this Node's argument nodes so that the Nodes referred to
|
||||
# can update their `users` accordingly
|
||||
new_args = map_arg(to_erase.args, lambda n: None)
|
||||
assert isinstance(new_args, tuple)
|
||||
to_erase.args = new_args
|
||||
new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
|
||||
assert isinstance(new_kwargs, dict)
|
||||
to_erase.kwargs = new_kwargs
|
||||
|
||||
def inserting_before(self, n: Optional[Node] = None):
|
||||
"""Set the point at which create_node and companion methods will insert into the graph.
|
||||
When used within a 'with' statement, this will temporary set the insert point and
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user